Why use `torch.Tensor.repeat` instead of `torch.repeat_interleave` in train_dreambooth_lora_sdxl? #12291

Light-yzc asked this question in Q&A

In `train_dreambooth_lora_sdxl.py` you can see this code:

```python
if not args.train_text_encoder:
    unet_added_conditions = {
        "time_ids": add_time_ids,
        "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
    }
    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
    model_pred = unet(
        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
        timesteps,
        prompt_embeds_input,
        added_cond_kwargs=unet_added_conditions,
        return_dict=False,
    )[0]
else:
    unet_added_conditions = {"time_ids": add_time_ids}
    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        text_encoders=[text_encoder_one, text_encoder_two],
        tokenizers=None,
        prompt=None,
        text_input_ids_list=[tokens_one, tokens_two],
    )
    unet_added_conditions.update(
        {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
    )
    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
    model_pred = unet(
        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
        timesteps,
        prompt_embeds_input,
        added_cond_kwargs=unet_added_conditions,
        return_dict=False,
    )[0]
```

This means `prompt_embeds` is repeated `elems_to_repeat_text_embeds` times along the batch dimension, so that every image in the batch gets a prompt embedding.
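A minimal sketch of what `Tensor.repeat(n, 1, 1)` does to a single prompt embedding. The sizes are toy values (not SDXL's real dimensions), and `n` is a hypothetical stand-in for `elems_to_repeat_text_embeds`:

```python
import torch

# Toy stand-in for one encoded prompt: (batch=1, seq_len=4, hidden=3).
prompt_embeds = torch.randn(1, 4, 3)

n = 5  # plays the role of elems_to_repeat_text_embeds
prompt_embeds_input = prompt_embeds.repeat(n, 1, 1)

print(prompt_embeds_input.shape)  # torch.Size([5, 4, 3])
# Every slice along dim 0 is an identical copy of the single prompt.
assert all(torch.equal(row, prompt_embeds[0]) for row in prompt_embeds_input)
```

With only one prompt along dim 0, `repeat` and `repeat_interleave` give the same result. The order only matters once dim 0 holds more than one distinct prompt.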

But in `collate_fn`:

```python
def collate_fn(examples, with_prior_preservation=False):
    pixel_values = [example["instance_images"] for example in examples]
    prompts = [example["instance_prompt"] for example in examples]
    original_sizes = [example["original_size"] for example in examples]
    crop_top_lefts = [example["crop_top_left"] for example in examples]
    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        pixel_values += [example["class_images"] for example in examples]
        prompts += [example["class_prompt"] for example in examples]
        original_sizes += [example["original_size"] for example in examples]
        crop_top_lefts += [example["crop_top_left"] for example in examples]
    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    batch = {
        "pixel_values": pixel_values,
        "prompts": prompts,
        "original_sizes": original_sizes,
        "crop_top_lefts": crop_top_lefts,
    }
    return batch
```

Here the class images are appended directly after the instance images, so the batch is ordered `[ins_img, ins_img, ..., cls_img, cls_img, ...]`.
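To make the resulting order concrete, here is a toy sketch of the same concatenation pattern. `toy_collate` is a hypothetical, stripped-down version of `collate_fn`, and strings stand in for the image tensors:

```python
# Hypothetical toy examples: strings stand in for image tensors.
examples = [
    {"instance_images": f"ins_img_{i}", "class_images": f"cls_img_{i}"}
    for i in range(3)
]

def toy_collate(examples, with_prior_preservation=False):
    # Same pattern as collate_fn: all instance items first,
    # then all class items appended after them.
    pixel_values = [ex["instance_images"] for ex in examples]
    if with_prior_preservation:
        pixel_values += [ex["class_images"] for ex in examples]
    return pixel_values

order = toy_collate(examples, with_prior_preservation=True)
print(order)
# ['ins_img_0', 'ins_img_1', 'ins_img_2', 'cls_img_0', 'cls_img_1', 'cls_img_2']
```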

But when no custom instance prompts are provided (`train_dataset.custom_instance_prompts` is empty), the prompt embeddings are built like this:

```python
if not train_dataset.custom_instance_prompts:
    if not args.train_text_encoder:
        prompt_embeds = instance_prompt_hidden_states
        unet_add_text_embeds = instance_pooled_prompt_embeds
        if args.with_prior_preservation:
            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
            unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
    # if we're optimizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the
    # batch prompts on all training steps
    else:
        tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
        tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
        if args.with_prior_preservation:
            class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
            class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
            tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
            tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
```

So the tokens are ordered `[ins_token, cls_token]` and the embeddings `[ins_embed, cls_embed]`.
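A toy sketch of that `torch.cat` ordering. The shapes are illustrative only, and each prompt's rows are tagged by a constant value (1.0 for the instance prompt, 2.0 for the class prompt) so they are easy to track:

```python
import torch

# Toy embeddings tagged by value: instance rows are all 1.0, class rows all 2.0.
instance_prompt_hidden_states = torch.full((1, 4, 3), 1.0)
class_prompt_hidden_states = torch.full((1, 4, 3), 2.0)

prompt_embeds = torch.cat([instance_prompt_hidden_states, class_prompt_hidden_states], dim=0)
print(prompt_embeds.shape)     # torch.Size([2, 4, 3])
print(prompt_embeds[:, 0, 0])  # tensor([1., 2.])
```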

So, back to this line:

```python
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
```

Isn't this wrong? With `repeat`, the embeddings end up as `prompt_embeds_input = [ins_embed, cls_embed, ins_embed, cls_embed, ...]`,

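but I think it should be `[ins_embed, ins_embed, ..., cls_embed, cls_embed]`, so that each embedding lines up with its image in the order produced by `collate_fn`. That is what `prompt_embeds.repeat_interleave(elems_to_repeat_text_embeds, dim=0)` would produce.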
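A small sketch reproducing the mismatch described in the question. It assumes `elems_to_repeat_text_embeds` equals the number of instance images in the batch (half the batch when prior preservation is on); the variable `n`, the toy shapes, and the value tags (1.0 = instance prompt, 2.0 = class prompt) are illustrative only:

```python
import torch

n = 3  # assumed number of instance images per batch (= number of class images)

# Row tags: 1.0 = instance prompt, 2.0 = class prompt; ordered [ins, cls] as after torch.cat.
prompt_embeds = torch.cat([torch.full((1, 4, 3), 1.0), torch.full((1, 4, 3), 2.0)], dim=0)

# Image order from collate_fn with prior preservation: all instances, then all classes.
image_tags = torch.tensor([1.0] * n + [2.0] * n)

tiled = prompt_embeds.repeat(n, 1, 1)                    # what the script does
interleaved = prompt_embeds.repeat_interleave(n, dim=0)  # what the question proposes

print(tiled[:, 0, 0])        # tensor([1., 2., 1., 2., 1., 2.])
print(interleaved[:, 0, 0])  # tensor([1., 1., 1., 2., 2., 2.])

print(torch.equal(tiled[:, 0, 0], image_tags))        # False
print(torch.equal(interleaved[:, 0, 0], image_tags))  # True
```

The pooled embeddings are 2-D, but `unet_add_text_embeds.repeat(n, 1)` tiles along dim 0 in the same way, so the same ordering question applies to `text_embeds`.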


Replies: 0 comments
