@@ -67,8 +67,6 @@
     FlowMatchEulerDiscreteScheduler,
     FluxPipeline,
     FluxTransformer2DModel,
-    ParallelConfig,
-    enable_parallelism,
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
@@ -807,8 +805,6 @@ def parse_args(input_args=None):
         ],
         help="The image interpolation method to use for resizing images.",
     )
-    parser.add_argument("--context_parallel_degree", type=int, default=1, help="The degree for context parallelism.")
-    parser.add_argument("--context_parallel_type", type=str, default="ulysses", help="The type of context parallelism to use. Choose between 'ulysses' and 'ring'.")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1351,28 +1347,15 @@ def main(args):
 
     logging_dir = Path(args.output_dir, args.logging_dir)
 
-    cp_degree = args.context_parallel_degree
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-    if cp_degree > 1:
-        kwargs = []
-    else:
-        kwargs = [DistributedDataParallelKwargs(find_unused_parameters=True)]
+    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         project_config=accelerator_project_config,
-        kwargs_handlers=kwargs,
-    )
-    if cp_degree > 1 and not torch.distributed.is_initialized():
-        if not torch.cuda.is_available():
-            raise ValueError("Context parallelism is only tested on CUDA devices.")
-        if os.environ.get("WORLD_SIZE", None) is None:
-            raise ValueError("Try launching the program with `torchrun --nproc_per_node <NUM_GPUS>` instead of `accelerate launch <NUM_GPUS>`.")
-        torch.distributed.init_process_group("nccl")
-        rank = torch.distributed.get_rank()
-        rank = accelerator.process_index
-        torch.cuda.set_device(torch.device("cuda", rank % torch.cuda.device_count()))
+        kwargs_handlers=[kwargs],
+    )
 
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
@@ -1994,14 +1977,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         power=args.lr_power,
     )
 
-    # Enable context parallelism
-    if cp_degree > 1:
-        ring_degree = cp_degree if args.context_parallel_type == "ring" else None
-        ulysses_degree = cp_degree if args.context_parallel_type == "ulysses" else None
-        transformer.parallelize(config=ParallelConfig(ring_degree=ring_degree, ulysses_degree=ulysses_degree))
-        transformer.set_attention_backend("_native_cudnn")
-    parallel_context = enable_parallelism(transformer) if cp_degree > 1 else nullcontext()
-
     # Prepare everything with our `accelerator`.
     if not freeze_text_encoder:
         if args.enable_t5_ti:
@@ -2156,7 +2131,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 logger.info(f"PIVOT TRANSFORMER {epoch}")
                 optimizer.param_groups[0]["lr"] = 0.0
 
-            with accelerator.accumulate(models_to_accumulate), parallel_context:
+            with accelerator.accumulate(models_to_accumulate):
                 prompts = batch["prompts"]
 
                 # encode batch prompts when custom prompts are provided for each image -
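
For context, here is a minimal sketch of the simplified setup this change reverts to: the Accelerator is always constructed with a single DistributedDataParallelKwargs handler, with no context-parallel branch. This is an illustration only; the directory, precision, and logging values below are placeholders standing in for the script's parsed arguments.

from pathlib import Path

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration

# Placeholder values; in the training script these come from parse_args().
output_dir = "trained-flux-lora"
logging_dir = Path(output_dir, "logs")

accelerator_project_config = ProjectConfiguration(project_dir=output_dir, logging_dir=logging_dir)
# DDP kwargs are now passed unconditionally instead of being gated on a context-parallel degree.
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(
    gradient_accumulation_steps=1,  # placeholder for args.gradient_accumulation_steps
    mixed_precision="bf16",         # placeholder for args.mixed_precision
    log_with="tensorboard",         # placeholder for args.report_to
    project_config=accelerator_project_config,
    kwargs_handlers=[kwargs],
)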