Fixes #11981.
Requires #11990 to be merged first.

Tested for 100 rounds with:

```python
import contextlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.hooks import apply_group_offloading, ModelHook, HookRegistry
from diffusers.models import ModelMixin
from diffusers.utils.logging import set_verbosity_debug
from torch.profiler import profile, record_function, ProfilerActivity
set_verbosity_debug()
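
# Records every forward output of the module it is attached to, so per-layer
# divergence between the offloaded models and the reference can be inspected later.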
class LayerOutputTrackerHook(ModelHook):
    def __init__(self):
        super().__init__()
        self.outputs = []

    def post_forward(self, module, output):
        self.outputs.append(output)
        return output
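
# Minimal model that reproduces the issue; the trainable LayerNorm
# (elementwise_affine=True) is the trigger, per the comments below.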
class Model(ModelMixin):
    def __init__(self, d_model=1024, num_layers=1):
        super().__init__()
        self.d_model = d_model
        self.input_proj = nn.Linear(1024, d_model)
        self.blocks = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)])
        # This is problematic
        self.norm = nn.LayerNorm(d_model, elementwise_affine=True)
        # This works
        # self.norm = nn.LayerNorm(d_model, elementwise_affine=False)
        self.output_proj = nn.Linear(d_model, 1024)

    def forward(self, x):
        x = self.input_proj(x)
        for block in self.blocks:
            x = block(x)
            x = F.relu(x)
        x = self.norm(x)
        x = self.output_proj(x)
        return x
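
# Attach a LayerOutputTrackerHook to every Linear and LayerNorm submodule.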
def apply_layer_output_tracker_hook(model: Model):
    for name, module in model.named_modules():
        if not isinstance(module, (torch.nn.Linear, torch.nn.LayerNorm)):
            continue
        registry = HookRegistry.check_if_exists_or_initialize(module)
        hook = LayerOutputTrackerHook()
        registry.register_hook(hook, "layer_output_tracker")
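
# Compare the cached per-layer outputs of an offloaded model against the reference.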
def print_output_diffs(ref_model: Model, model: Model):
    for (ref_name, ref_module), (name, module) in zip(ref_model.named_modules(), model.named_modules()):
        assert ref_name == name
        if not isinstance(ref_module, (torch.nn.Linear, torch.nn.LayerNorm)):
            continue
        ref_outputs = HookRegistry.check_if_exists_or_initialize(ref_module).get_hook("layer_output_tracker").outputs
        outputs = HookRegistry.check_if_exists_or_initialize(module).get_hook("layer_output_tracker").outputs
        cumulated_absmax = 0.0
        for i in range(len(outputs)):
            # The reference model runs only once, so each tracked module has a
            # single reference output; compare every offloaded run against it.
            diff = ref_outputs[0] - outputs[i]
            absdiff = diff.abs()
            absmax = absdiff.max().item()
            cumulated_absmax += absmax
            if ref_name == "output_proj":
                print(f"{ref_name} absmax {i}: {absmax}")
        print(f"{name}: cumulated_absmax={cumulated_absmax:.5f}, num_outputs={len(outputs)}")
torch.manual_seed(42)
model_ref = Model()
model1 = Model()
model2 = Model()
model1.load_state_dict(model_ref.state_dict())
model2.load_state_dict(model_ref.state_dict())
model_ref.eval()
model1.eval()
model2.eval()
onload_device = torch.device("cuda:0")
offload_device = torch.device("cpu")
model_ref = model_ref.to(onload_device)
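
# Offload model1 at block level (one block per group) and model2 at leaf level,
# both using a CUDA stream for asynchronous transfers.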
apply_group_offloading(
    model1,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=1,
    use_stream=True,
)
apply_group_offloading(
    model2,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)
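
# Track layer outputs on all three models, then run the reference once to
# record the baseline outputs.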
apply_layer_output_tracker_hook(model_ref)
apply_layer_output_tracker_hook(model1)
apply_layer_output_tracker_hook(model2)
x = torch.randn(1, 512, 1024).to("cuda")
out_ref = model_ref(x)
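
# Summary metrics for the final model outputs: max abs diff, MAE, MSE, cosine similarity.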
def compare_outputs(out1, out2):
    diff = out1 - out2
    absdiff = diff.abs()
    absmax = absdiff.max()
    mae = absdiff.mean()
    mse = (absdiff ** 2).mean()
    cossim = F.cosine_similarity(out1.flatten(), out2.flatten(), dim=0)
    print(f"{absmax=:.5f}, {mae=:.5f}, {mse=:.5f}, {cossim=:.5f}")
for _ in range(2):
    model1(x)
    print("=" * 80)
    model2(x)
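
# Optional torch.profiler capture, off by default: the bug does not reproduce
# while profiling (see the heisenbug note below).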
do_profile = False
activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
context = (
    profile(
        activities=activities,
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    )
    if do_profile
    else contextlib.nullcontext()
)
with context as prof:
    with torch.inference_mode():
        for i in range(10):
            with record_function(f"model_1_run_{i}"):
                output1 = model1(x)
            print(i)
            compare_outputs(out_ref, output1)
            print()
        print("=" * 80)
        for i in range(10):
            with record_function(f"model_2_run_{i}"):
                output2 = model2(x)
            print(i)
            compare_outputs(out_ref, output2)
            print()
        print_output_diffs(model_ref, model1)
        print()
        print_output_diffs(model_ref, model2)

if do_profile:
    prof.export_chrome_trace("dump_trace.json")
    print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=1000))
```

Testing with profiling is not helpful because the problem never shows up. See the heisenbug thread: https://huggingface.slack.com/archives/C065E480NN9/p1754035222558869