Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 2b7deff

Browse files
vladmandic, DN6, and github-actions[bot]
authored
fix scale_shift_factor being on cpu for wan and ltx (#12347)
* wan: fix scale_shift_factor being on cpu
* apply device cast to ltx transformer
* apply style fixes

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 941ac9c commit 2b7deff

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

‎src/diffusers/models/transformers/transformer_ltx.py‎

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,9 @@ def forward(
353353
norm_hidden_states = self.norm1(hidden_states)
354354

355355
num_ada_params = self.scale_shift_table.shape[0]
356-
ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
356+
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
357+
batch_size, temb.size(1), num_ada_params, -1
358+
)
357359
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
358360
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
359361

‎src/diffusers/models/transformers/transformer_wan.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -682,12 +682,12 @@ def forward(
682682
# 5. Output norm, projection & unpatchify
683683
if temb.ndim == 3:
684684
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
685-
shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
685+
shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
686686
shift = shift.squeeze(2)
687687
scale = scale.squeeze(2)
688688
else:
689689
# batch_size, inner_dim
690-
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
690+
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
691691

692692
# Move the shift and scale tensors to the same device as hidden_states.
693693
# When using multi-GPU inference via accelerate these will be on the

‎src/diffusers/models/transformers/transformer_wan_vace.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def forward(
103103
control_hidden_states = control_hidden_states + hidden_states
104104

105105
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
106-
self.scale_shift_table + temb.float()
106+
self.scale_shift_table.to(temb.device) + temb.float()
107107
).chunk(6, dim=1)
108108

109109
# 1. Self-attention
@@ -361,7 +361,7 @@ def forward(
361361
hidden_states = hidden_states + control_hint * scale
362362

363363
# 6. Output norm, projection & unpatchify
364-
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
364+
shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
365365

366366
# Move the shift and scale tensors to the same device as hidden_states.
367367
# When using multi-GPU inference via accelerate these will be on the

0 commit comments

Comments
(0)

Page converted by AltStyle (→ view original)