Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 69cdc25

Browse files
a-r-r-o-wsayakpaul
andauthored
Fix group offloading synchronization bug for parameter-only GroupModule's (#12077)
* update * update * refactor * fuck yeah * make style * Update src/diffusers/hooks/group_offloading.py Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * Update src/diffusers/hooks/group_offloading.py --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent cfd6ec7 commit 69cdc25

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

‎src/diffusers/hooks/group_offloading.py‎

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,6 @@ def _offload_to_memory(self):
245245
param.data = self.cpu_param_dict[param]
246246
for buffer in self.buffers:
247247
buffer.data = self.cpu_param_dict[buffer]
248-
249248
else:
250249
for group_module in self.modules:
251250
group_module.to(self.offload_device, non_blocking=False)
@@ -303,9 +302,23 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
303302
if self.group.onload_leader == module:
304303
if self.group.onload_self:
305304
self.group.onload_()
306-
if self.next_group is not None and not self.next_group.onload_self:
305+
306+
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
307+
if should_onload_next_group:
307308
self.next_group.onload_()
308309

310+
should_synchronize = (
311+
not self.group.onload_self and self.group.stream is not None and not should_onload_next_group
312+
)
313+
if should_synchronize:
314+
# If this group didn't onload itself, it means it was asynchronously onloaded by the
315+
# previous group. We need to synchronize the side stream to ensure parameters
316+
# are completely loaded to proceed with forward pass. Without this, uninitialized
317+
# weights will be used in the computation, leading to incorrect results
318+
# Also, we should only do this synchronization if we don't already do it from the sync call in
319+
# self.next_group.onload_, hence the `not should_onload_next_group` check.
320+
self.group.stream.synchronize()
321+
309322
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
310323
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
311324
return args, kwargs

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /