Commit d23facd

Merge pull request #2388 from laclouis5/fix-mqa-v2
Fix MQA V2
2 parents: 2d734d9 + 2d5277e

File tree: 2 files changed, +23 -9 lines

tests/test_layers.py

Lines changed: 18 additions & 4 deletions
@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn
 
-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d, MultiQueryAttentionV2
 
 import importlib
 import os
@@ -121,6 +121,23 @@ def test_get_act_fn_none():
     assert get_act_fn('') is None
 
 
+@pytest.mark.parametrize("dim", [128])
+@pytest.mark.parametrize("dim_out", [128, 256])
+@pytest.mark.parametrize("use_m", [True, False])
+def test_mqa_v2(dim, dim_out, use_m):
+    mqa = MultiQueryAttentionV2(dim, dim_out)
+
+    x = torch.randn(1, dim, 32, 48)
+    if use_m:
+        m = torch.randn(1, dim, 16, 24)
+    else:
+        m = None
+
+    y = mqa(x, m=m)
+
+    assert (y.shape) == (1, dim_out, 32, 48)
+
+
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("expand_first", [True, False])
 @pytest.mark.parametrize("head_first", [True, False])
@@ -141,6 +158,3 @@ def test_attn2d(bias, expand_first, head_first, attn_mask):
     o2 = attn(x, mask)
 
     assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
-
-
-
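
For context, here is a minimal usage sketch mirroring the new test above. The constructor call MultiQueryAttentionV2(dim, dim_out) and the m= keyword argument come directly from the test; the batch and spatial sizes are illustrative only:

import torch
from timm.layers import MultiQueryAttentionV2

mqa = MultiQueryAttentionV2(128, 256)  # dim=128 in, dim_out=256 out

# Self-attention path: when m is omitted, the layer falls back to m = x.
x = torch.randn(1, 128, 32, 48)
y = mqa(x)
assert y.shape == (1, 256, 32, 48)

# Cross-attention path: m may have a different spatial size than x;
# the output spatial size follows x.
m = torch.randn(1, 128, 16, 24)
y = mqa(x, m=m)
assert y.shape == (1, 256, 32, 48)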

timm/layers/attention2d.py

Lines changed: 5 additions & 5 deletions
@@ -59,24 +59,24 @@ def _reshape_input(self, t):
 
     def forward(self, x, m: Optional[torch.Tensor] = None):
         """Run layer computation."""
-        s = x.shape
-        m = m or x
+        b, _, h, w = x.shape
+        m = m if m is not None else x
 
         reshaped_x = self._reshape_input(x)
         reshaped_m = self._reshape_input(m)
 
         q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
         k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)
 
-        attn = torch.einsum('bnhk,bmk->bnhm', q, k)
+        attn = torch.einsum('bnhk,bmk->bnhm', q, k) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
         v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
         o = torch.einsum('bnhm,bmv->bnhv', attn, v)
-        result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj)
+        result = torch.einsum('bnhv,dhv->bdn', o, self.out_proj)
         result = self.proj_drop(result)
-        return result.reshape(s)
+        return result.reshape(b, -1, h, w)
 
 
 class MultiQueryAttention2d(nn.Module):
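
Taken together, the patch fixes three problems in MultiQueryAttentionV2.forward: m = m or x evaluated the truth value of a tensor whenever m was actually passed, which raises a RuntimeError for tensors with more than one element; the attention logits were never multiplied by self.scale; and the final einsum produced a (B, N, C) layout that reshape(s) then reinterpreted as (B, C, H, W), scrambling the output and implicitly assuming dim_out == dim. The sketch below replays the corrected einsum chain in isolation; it is an illustration, not the layer itself: the head and key/value sizes and the scale value are assumptions, the projection tensors are random stand-ins with shapes inferred from the einsum subscripts, and _reshape_input is taken to flatten NCHW to (B, N, C) as the 'bnd' subscripts imply.

import torch

# Assumed sizes for illustration; not the layer's actual defaults.
B, dim, dim_out, H, W = 1, 128, 256, 32, 48
num_heads, key_dim, value_dim = 8, 64, 64
scale = key_dim ** -0.5  # conventional scaling; assumed value of self.scale

# Random stand-ins for the layer's projection parameters; shapes are
# inferred from the einsum subscripts in the diff above.
query_proj = torch.randn(num_heads, key_dim, dim)      # 'hkd'
key_proj = torch.randn(dim, key_dim)                   # 'dk'
value_proj = torch.randn(dim, value_dim)               # 'dv'
out_proj = torch.randn(dim_out, num_heads, value_dim)  # 'dhv'

x = torch.randn(B, dim, H, W)
m = x  # self-attention case (the "m is None" branch)

# Flatten NCHW to (B, N, C) with N = H*W, mirroring _reshape_input.
rx = x.reshape(B, dim, -1).transpose(1, 2)
rm = m.reshape(B, dim, -1).transpose(1, 2)

q = torch.einsum('bnd,hkd->bnhk', rx, query_proj)
k = torch.einsum('bmd,dk->bmk', rm, key_proj)
attn = (torch.einsum('bnhk,bmk->bnhm', q, k) * scale).softmax(dim=-1)
v = torch.einsum('bmd,dv->bmv', rm, value_proj)
o = torch.einsum('bnhm,bmv->bnhv', attn, v)

# Fixed output path: 'bdn' keeps channels ahead of the flattened
# positions, so the reshape restores (B, dim_out, H, W) correctly.
out = torch.einsum('bnhv,dhv->bdn', o, out_proj).reshape(B, -1, H, W)
assert out.shape == (B, dim_out, H, W)

With the old 'bnhv,dhv->bnd' subscripts, the reshape would fail outright whenever dim_out != dim and silently mix channels with positions when they matched, which is what the shape assertion in the new test_mqa_v2 guards against.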
