ValueError occurs when parameter "layer_scale" is used in torch #2103

Open
Assignees
@nalnez13

Description

If I define a parameter named "layer_scale" in a PyTorch nn.Module, as in the following code, a ValueError is raised during quantization.

class ConvEncoder(nn.Module):
    """
    Implementation of ConvEncoder with 3*3 and 1*1 convolutions.
    Input: tensor with shape [B, C, H, W]
    Output: tensor with shape [B, C, H, W]
    """

    def __init__(
        self, dim, hidden_dim=64, kernel_size=3, drop_path=0.0, use_layer_scale=True
    ):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim
        )
        self.norm = nn.BatchNorm2d(dim)
        self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(
                torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True
            )
        self.apply(self._init_weights)

The cause is in neural_compressor/adaptor/pytorch.py, line 4174, inside _get_module_scale_zeropoint:

for node in model.graph.nodes:
    if node.op == "get_attr":
        if prefix:
            sub_name = prefix + "--" + node.target
        else:
            sub_name = node.target
        if not hasattr(model, node.target):
            continue
        if "scale" in node.target:  #### This condition is not suitable
            tune_cfg["get_attr"][sub_name] = float(getattr(model, node.target))
        elif "zero_point" in node.target:
            tune_cfg["get_attr"][sub_name] = int(getattr(model, node.target))
        else:
            pass

Quantization then fails with the following traceback:

  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/utils/utility.py", line 347, in fi
    res = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 3658, in quantize
    self._get_scale_zeropoint(q_model._model, q_model.q_config)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4217, in _get_scale_zeropoint
    self._get_sub_module_scale_zeropoint(model, tune_cfg)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4199, in _get_sub_module_scale_zeropoint
    self._get_sub_module_scale_zeropoint(module, tune_cfg, op_name)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4199, in _get_sub_module_scale_zeropoint
    self._get_sub_module_scale_zeropoint(module, tune_cfg, op_name)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4199, in _get_sub_module_scale_zeropoint
    self._get_sub_module_scale_zeropoint(module, tune_cfg, op_name)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4197, in _get_sub_module_scale_zeropoint
    self._get_module_scale_zeropoint(module, tune_cfg, op_name)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4175, in _get_module_scale_zeropoint
    tune_cfg["get_attr"][sub_name] = float(getattr(model, node.target))
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: only one element tensors can be converted to Python scalars

Printing each matched name together with its tensor value shows that layer_scale, a multi-element model parameter, is treated as if it were a quantization scale:

feature_extractor.patch_embed--0_input_scale_0 tensor(0.0355)
feature_extractor.network.0.0--dwconv_input_scale_0 tensor(0.0338)
feature_extractor.network.0.0--pwconv2_input_scale_0 tensor(0.2263)
feature_extractor.network.0.0--layer_scale Parameter containing:
tensor([[[ 0.0336]],
 [[ 0.0153]],
 [[ 0.0214]],
 [[ 0.0068]],
 [[ 0.0229]],
 [[ 0.0136]],
 [[ 0.0491]],
 [[ 0.0202]],
 [[ 0.0420]],
 [[ 0.0495]],
 [[ 0.0060]],
 [[ 0.0225]],
 [[ 0.0311]],
 [[ 0.0303]],
 [[ 0.0556]],
 [[ 0.0290]],
 [[ 0.0222]],
 [[ 0.0153]],
 [[ 0.0332]],
 [[ 0.0667]],
 [[ 0.0168]],
 [[ 0.0416]],
 [[ 0.0258]],
 [[ 0.0200]],
 [[ 0.0259]],
 [[ 0.0044]],
 [[ 0.0514]],
 [[ 0.0190]],
 [[ 0.0545]],
 [[ 0.0119]],
 [[ 0.0220]],
 [[ 0.0481]],
 [[ 0.0115]],
 [[ 0.0707]],
 [[ 0.0299]],
 [[ 0.0105]],
 [[ 0.0266]],
 [[ 0.0156]],
 [[ 0.0380]],
 [[ 0.0160]],
 [[ 0.0521]],
 [[ 0.0094]],
 [[-0.0133]],
 [[ 0.0585]],
 [[ 0.0216]],
 [[ 0.0102]],
 [[ 0.0297]],
 [[ 0.0104]]], requires_grad=True)

Changing the condition as follows avoids the error, but it is not a robust fix, because it only special-cases this one parameter name:

if "scale" in node.target and "layer_scale" not in node.target:
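A less name-specific alternative might be to match only the suffix pattern that the quantization attributes in the output above appear to follow (for example "dwconv_input_scale_0"), rather than any name containing "scale". The sketch below is only an illustration under that assumption: the helper qparam_kind and the pattern are hypothetical and not part of neural_compressor, and the "zero_point_<n>" name is assumed by analogy with the scale names, since the output above shows none.

```python
import re

# Assumption: quantization attributes end in "scale_<n>" or "zero_point_<n>",
# like the "..._input_scale_0" names printed above. The zero_point form is
# assumed by analogy and does not appear in the issue's output.
_QPARAM_SUFFIX = re.compile(r"(?:^|_)(scale|zero_point)_\d+$")


def qparam_kind(target):
    """Return "scale", "zero_point", or None for a get_attr target name.

    Hypothetical helper for illustration; not a neural_compressor API.
    """
    match = _QPARAM_SUFFIX.search(target)
    return match.group(1) if match else None


if __name__ == "__main__":
    for name in (
        "0_input_scale_0",
        "dwconv_input_scale_0",
        "layer_scale",
        "pwconv2_input_zero_point_0",  # hypothetical name
    ):
        print(name, "->", qparam_kind(name))
```

With this check, "layer_scale" is no longer classified as a quantization scale, while the "..._input_scale_0" names still are. A user parameter that happens to be named like "layer_scale_0" would still match, so checking that the tensor has exactly one element before calling float() or int() may be a safer complement.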
