 
 class Mlp(nn.Module):
     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
+
+    NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
     """
     def __init__(
             self,
@@ -51,6 +53,8 @@ def forward(self, x): |
 class GluMlp(nn.Module):
     """ MLP w/ GLU style gating
     See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
+
+    NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.
     """
     def __init__(
             self,
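A minimal usage sketch of the layout convention noted in the two docstrings above (illustration only, not part of the diff; it assumes timm.layers exports Mlp and GluMlp with the in_features / hidden_features / use_conv parameters used here):

import torch
from timm.layers import Mlp, GluMlp

# Default use_conv=False: nn.Linear acts on the last dim, so any (..., C) "N*C" input works.
tokens = torch.randn(2, 196, 64)                            # e.g. a ViT token sequence
out = Mlp(64, hidden_features=256)(tokens)                  # -> (2, 196, 64)

# use_conv=True: 1x1 nn.Conv2d layers, so a 2D NCHW feature map is expected.
fmap = torch.randn(2, 64, 14, 14)
out = Mlp(64, hidden_features=256, use_conv=True)(fmap)     # -> (2, 64, 14, 14)
out = GluMlp(64, hidden_features=256, use_conv=True)(fmap)  # same convention; hidden_features must be even for the GLU split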
@@ -192,7 +196,7 @@ def forward(self, x): |
 
 
 class ConvMlp(nn.Module):
-    """ MLP using 1x1 convs that keeps spatial dims
+    """ MLP using 1x1 convs that keeps spatial dims (for 2D NCHW tensors)
     """
     def __init__(
             self,
@@ -226,6 +230,8 @@ def forward(self, x): |
 
 class GlobalResponseNormMlp(nn.Module):
     """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
+
+    NOTE: Intended for '2D' NCHW (use_conv=True) or NHWC (use_conv=False, channels-last) tensor layouts
     """
     def __init__(
             self,
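A similar sketch for the ConvMlp and GlobalResponseNormMlp notes (again an illustration, assuming timm.layers exports both classes with the parameters shown):

import torch
from timm.layers import ConvMlp, GlobalResponseNormMlp

# ConvMlp is 1x1-conv only, so it always takes and returns 2D NCHW feature maps.
x_nchw = torch.randn(2, 64, 14, 14)
out = ConvMlp(64, hidden_features=256)(x_nchw)                               # -> (2, 64, 14, 14)

# GlobalResponseNormMlp: NCHW with use_conv=True, or channels-last NHWC with the default nn.Linear path.
out = GlobalResponseNormMlp(64, hidden_features=256, use_conv=True)(x_nchw)  # -> (2, 64, 14, 14)

x_nhwc = torch.randn(2, 14, 14, 64)                                          # channels-last, as in ConvNeXt-V2 style blocks
out = GlobalResponseNormMlp(64, hidden_features=256)(x_nhwc)                 # -> (2, 14, 14, 64)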