In `train_dreambooth_lora_sdxl.py` you can see this code:
```python
if not args.train_text_encoder:
    unet_added_conditions = {
        "time_ids": add_time_ids,
        "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
    }
    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
    model_pred = unet(
        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
        timesteps,
        prompt_embeds_input,
        added_cond_kwargs=unet_added_conditions,
        return_dict=False,
    )[0]
else:
    unet_added_conditions = {"time_ids": add_time_ids}
    prompt_embeds, pooled_prompt_embeds = encode_prompt(
        text_encoders=[text_encoder_one, text_encoder_two],
        tokenizers=None,
        prompt=None,
        text_input_ids_list=[tokens_one, tokens_two],
    )
    unet_added_conditions.update(
        {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
    )
    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
    model_pred = unet(
        inp_noisy_latents if args.do_edm_style_training else noisy_model_input,
        timesteps,
        prompt_embeds_input,
        added_cond_kwargs=unet_added_conditions,
        return_dict=False,
    )[0]
```
This means `prompt_embeds` is repeated so that there is one copy per image in the batch.
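To make the shapes concrete, here is a minimal sketch (toy sizes, not the script's real values) of what that `repeat` does in the simple case where every image shares one prompt:

```python
import torch

# Toy sizes; the real script uses the tokenizer's sequence length and the
# SDXL text-embedding dim.
bsz = 4
prompt_embeds = torch.randn(1, 77, 2048)   # one shared prompt embed
elems_to_repeat_text_embeds = bsz          # assumption: repeat once per image

prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
print(prompt_embeds_input.shape)  # torch.Size([4, 77, 2048]) -> one row per image
```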
But in `collate_fn`:
```python
def collate_fn(examples, with_prior_preservation=False):
    pixel_values = [example["instance_images"] for example in examples]
    prompts = [example["instance_prompt"] for example in examples]
    original_sizes = [example["original_size"] for example in examples]
    crop_top_lefts = [example["crop_top_left"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        pixel_values += [example["class_images"] for example in examples]
        prompts += [example["class_prompt"] for example in examples]
        original_sizes += [example["original_size"] for example in examples]
        crop_top_lefts += [example["crop_top_left"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    batch = {
        "pixel_values": pixel_values,
        "prompts": prompts,
        "original_sizes": original_sizes,
        "crop_top_lefts": crop_top_lefts,
    }
    return batch
```
You can see the `class_images` are appended directly after all the instance images, so the pixel batch order is `[ins_0, ..., ins_{n-1}, cls_0, ..., cls_{n-1}]`.
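A minimal sketch (toy tensors, not the real dataset) of the ordering this produces for a batch of 2 examples with prior preservation:

```python
import torch

ins = [torch.full((3, 4, 4), float(i)) for i in range(2)]       # instance images
cls = [torch.full((3, 4, 4), float(10 + i)) for i in range(2)]  # class images

# Same concatenation order as collate_fn: instance images first, then class images.
pixel_values = torch.stack(ins + cls)
print(pixel_values[:, 0, 0, 0])  # tensor([ 0.,  1., 10., 11.]) -> [ins_0, ins_1, cls_0, cls_1]
```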
But when `train_dataset.custom_instance_prompts` is not provided, the prompt embeds are built like this:
```python
if not train_dataset.custom_instance_prompts:
    if not args.train_text_encoder:
        prompt_embeds = instance_prompt_hidden_states
        unet_add_text_embeds = instance_pooled_prompt_embeds
        if args.with_prior_preservation:
            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
            unet_add_text_embeds = torch.cat(
                [unet_add_text_embeds, class_pooled_prompt_embeds], dim=0
            )
    # if we're optimizing the text encoder (both if instance prompt is used for all images
    # or custom prompts) we need to tokenize and encode the batch prompts on all training steps
    else:
        tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
        tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
        if args.with_prior_preservation:
            class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
            class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
            tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
            tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
```
So with prior preservation they seem to be `[ins_token, cls_token]` or `[ins_embed, cls_embed]` along dim 0.
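That structure is easy to verify with a toy sketch (hypothetical shapes, stand-ins for the real hidden states):

```python
import torch

# Stand-ins for the real hidden states; shapes are made up for illustration.
instance_prompt_hidden_states = torch.zeros(1, 77, 2048)
class_prompt_hidden_states = torch.ones(1, 77, 2048)

prompt_embeds = torch.cat(
    [instance_prompt_hidden_states, class_prompt_hidden_states], dim=0
)
print(prompt_embeds.shape)     # torch.Size([2, 77, 2048])
print(prompt_embeds[:, 0, 0])  # tensor([0., 1.]) -> row 0 = instance, row 1 = class
```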
So back to this line:

`prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)`

Isn't this wrong? `Tensor.repeat` tiles the whole tensor along dim 0, so the result is interleaved:

`prompt_embeds_input = [ins_embed, cls_embed, ins_embed, cls_embed, ...]`

But to match the pixel order from `collate_fn`, I think it should be blocked:

`[ins_embed, ins_embed, ..., cls_embed, cls_embed]`
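This is easy to check in isolation; here is a toy comparison of `repeat` against `repeat_interleave` (stand-in tensors, not the script's real embeds):

```python
import torch

ins_embed = torch.zeros(1, 2, 3)  # stand-in for the instance embed
cls_embed = torch.ones(1, 2, 3)   # stand-in for the class embed
prompt_embeds = torch.cat([ins_embed, cls_embed], dim=0)  # shape (2, 2, 3)

interleaved = prompt_embeds.repeat(2, 1, 1)  # tiles the whole tensor
print(interleaved[:, 0, 0])  # tensor([0., 1., 0., 1.]) -> [ins, cls, ins, cls]

blocked = prompt_embeds.repeat_interleave(2, dim=0)  # repeats each row in place
print(blocked[:, 0, 0])      # tensor([0., 0., 1., 1.]) -> [ins, ins, cls, cls]
```

So if the pixel batch is blocked but the embeds are tiled, images would be paired with the wrong prompts.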