from diffusers.utils import is_torch_available

from ..testing_utils import (
    backend_empty_cache,
    backend_max_memory_allocated,
    backend_reset_peak_memory_stats,
    torch_device,
)


# Everything below needs torch at definition time (base class, decorators),
# so it is only defined when torch is importable.
if is_torch_available():
    import torch
    import torch.nn as nn

    class LoRALayer(nn.Module):
        """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only.

        The adapter is a bias-free two-layer bottleneck
        (``in_features -> rank -> out_features``) whose output is added to the
        output of the wrapped module.

        Taken from
        https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77

        Args:
            module: The layer to wrap. It must expose ``in_features``,
                ``out_features`` and ``weight`` attributes.
            rank: Width of the adapter bottleneck.
        """

        def __init__(self, module: nn.Module, rank: int):
            super().__init__()
            self.module = module
            self.adapter = nn.Sequential(
                nn.Linear(module.in_features, rank, bias=False),
                nn.Linear(rank, module.out_features, bias=False),
            )
            small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
            nn.init.normal_(self.adapter[0].weight, std=small_std)
            # The second projection starts at zero and has no bias, so the
            # adapter initially contributes nothing to the wrapped output.
            nn.init.zeros_(self.adapter[1].weight)
            # Keep the adapter on the same device as the wrapped layer's weight.
            self.adapter.to(module.weight.device)

        def forward(self, input, *args, **kwargs):
            """Return ``module(input, *args, **kwargs) + adapter(input)``.

            Extra positional and keyword arguments are forwarded to the wrapped
            module only; the adapter receives just ``input``.
            """
            return self.module(input, *args, **kwargs) + self.adapter(input)

    @torch.no_grad()
    @torch.inference_mode()
    def get_memory_consumption_stat(model, inputs):
        """Run one forward pass of ``model`` and return the peak allocated memory.

        Args:
            model: Callable invoked as ``model(**inputs)``.
            inputs: Mapping of keyword arguments for the forward call.

        Returns:
            The value reported by ``backend_max_memory_allocated(torch_device)``
            after the forward pass (presumably bytes - confirm against
            ``testing_utils``).
        """
        # Reset the peak counter and empty the cache before the forward pass,
        # so the value read afterwards is not a peak recorded earlier.
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)

        model(**inputs)
        max_mem_allocated = backend_max_memory_allocated(torch_device)
        return max_mem_allocated
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。