Commit 0018b62

Revert "try to make dreambooth script work; accelerator backward not playing well"

This reverts commit 768d0ea.

1 parent: 768d0ea

File tree

1 file changed: +4, -29 lines

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 4 additions & 29 deletions
@@ -67,8 +67,6 @@
     FlowMatchEulerDiscreteScheduler,
     FluxPipeline,
     FluxTransformer2DModel,
-    ParallelConfig,
-    enable_parallelism,
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
@@ -807,8 +805,6 @@ def parse_args(input_args=None):
         ],
         help="The image interpolation method to use for resizing images.",
     )
-    parser.add_argument("--context_parallel_degree", type=int, default=1, help="The degree for context parallelism.")
-    parser.add_argument("--context_parallel_type", type=str, default="ulysses", help="The type of context parallelism to use. Choose between 'ulysses' and 'ring'.")
 
     if input_args is not None:
         args = parser.parse_args(input_args)
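For reference, a minimal standalone sketch of the two flags this hunk deletes (definitions copied from the removed lines; in the experimental version they fed `cp_degree = args.context_parallel_degree` later in `main()`):

import argparse

# Standalone parser reproducing only the two removed flags (illustrative, not the script's full parser).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--context_parallel_degree",
    type=int,
    default=1,
    help="The degree for context parallelism.",
)
parser.add_argument(
    "--context_parallel_type",
    type=str,
    default="ulysses",
    help="The type of context parallelism to use. Choose between 'ulysses' and 'ring'.",
)

# Example invocation: ring-style context parallelism across 2 ranks.
args = parser.parse_args(["--context_parallel_degree", "2", "--context_parallel_type", "ring"])
cp_degree = args.context_parallel_degree  # > 1 activated the now-removed context-parallel code path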
@@ -1351,28 +1347,15 @@ def main(args):
 
     logging_dir = Path(args.output_dir, args.logging_dir)
 
-    cp_degree = args.context_parallel_degree
     accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
-    if cp_degree > 1:
-        kwargs = []
-    else:
-        kwargs = [DistributedDataParallelKwargs(find_unused_parameters=True)]
+    kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
         log_with=args.report_to,
         project_config=accelerator_project_config,
-        kwargs_handlers=kwargs,
-    )
-    if cp_degree > 1 and not torch.distributed.is_initialized():
-        if not torch.cuda.is_available():
-            raise ValueError("Context parallelism is only tested on CUDA devices.")
-        if os.environ.get("WORLD_SIZE", None) is None:
-            raise ValueError("Try launching the program with `torchrun --nproc_per_node <NUM_GPUS>` instead of `accelerate launch <NUM_GPUS>`.")
-        torch.distributed.init_process_group("nccl")
-        rank = torch.distributed.get_rank()
-        rank = accelerator.process_index
-        torch.cuda.set_device(torch.device("cuda", rank % torch.cuda.device_count()))
+        kwargs_handlers=[kwargs],
+    )
 
     # Disable AMP for MPS.
     if torch.backends.mps.is_available():
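To make the restored configuration easier to read in isolation, here is a minimal, runnable sketch of the `+` side of this hunk (placeholder values stand in for the script's parsed arguments, and `log_with` is omitted so the snippet has no tracker dependency):

from accelerate import Accelerator, DistributedDataParallelKwargs
from accelerate.utils import ProjectConfiguration

# Placeholder paths; the script derives these from args.output_dir / args.logging_dir.
accelerator_project_config = ProjectConfiguration(project_dir="output", logging_dir="output/logs")

# find_unused_parameters=True lets DDP tolerate parameters that receive no gradient
# in a given step (e.g. frozen text encoders during pivotal tuning).
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

accelerator = Accelerator(
    gradient_accumulation_steps=1,  # args.gradient_accumulation_steps in the script
    mixed_precision="no",           # args.mixed_precision in the script
    project_config=accelerator_project_config,
    kwargs_handlers=[kwargs],       # the revert wraps the single handler in a list again
)
print(accelerator.device)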
@@ -1994,14 +1977,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         power=args.lr_power,
     )
 
-    # Enable context parallelism
-    if cp_degree > 1:
-        ring_degree = cp_degree if args.context_parallel_type == "ring" else None
-        ulysses_degree = cp_degree if args.context_parallel_type == "ulysses" else None
-        transformer.parallelize(config=ParallelConfig(ring_degree=ring_degree, ulysses_degree=ulysses_degree))
-        transformer.set_attention_backend("_native_cudnn")
-    parallel_context = enable_parallelism(transformer) if cp_degree > 1 else nullcontext()
-
     # Prepare everything with our `accelerator`.
     if not freeze_text_encoder:
         if args.enable_t5_ti:
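Pulled out of the deleted lines above, the experimental context-parallel setup amounted to the following helper. This is a reconstruction for readability, not the author's code: `ParallelConfig`, `enable_parallelism`, `transformer.parallelize`, and `transformer.set_attention_backend("_native_cudnn")` come straight from the removed lines and assume a diffusers build that exposes them.

from contextlib import nullcontext

from diffusers import ParallelConfig, enable_parallelism  # per the removed imports; may not exist in released diffusers


def build_parallel_context(transformer, cp_degree, context_parallel_type):
    """Return a context manager mirroring the removed `parallel_context` logic."""
    if cp_degree > 1:
        # Exactly one of ring/ulysses degree is set, depending on --context_parallel_type.
        ring_degree = cp_degree if context_parallel_type == "ring" else None
        ulysses_degree = cp_degree if context_parallel_type == "ulysses" else None
        transformer.parallelize(config=ParallelConfig(ring_degree=ring_degree, ulysses_degree=ulysses_degree))
        transformer.set_attention_backend("_native_cudnn")
        return enable_parallelism(transformer)
    # cp_degree == 1: plain DDP path, no-op context manager.
    return nullcontext()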
@@ -2156,7 +2131,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 logger.info(f"PIVOT TRANSFORMER {epoch}")
                 optimizer.param_groups[0]["lr"] = 0.0
 
-            with accelerator.accumulate(models_to_accumulate), parallel_context:
+            with accelerator.accumulate(models_to_accumulate):
                 prompts = batch["prompts"]
 
                 # encode batch prompts when custom prompts are provided for each image -
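The only change in the training loop is dropping `parallel_context` from the `with` statement. The removed code relied on `contextlib.nullcontext` so that a single `with` line worked whether or not context parallelism was active; a tiny self-contained illustration of that pattern (the context manager below is a stand-in, not the real `enable_parallelism`):

from contextlib import contextmanager, nullcontext


@contextmanager
def fake_parallel_context():
    # Stand-in for enable_parallelism(transformer); purely illustrative.
    print("context parallelism: on")
    yield
    print("context parallelism: off")


cp_degree = 1  # mirrors args.context_parallel_degree in the removed code

# One `with` statement covers both paths because nullcontext() is a no-op.
parallel_context = fake_parallel_context() if cp_degree > 1 else nullcontext()
with parallel_context:
    print("training step")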

0 commit comments
