-
-
Notifications
You must be signed in to change notification settings - Fork 5.1k
Fix bug in timm.layers.drop.drop_block_2d; unify fast/slow versions; add model and unit tests. #2569
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There are two bugs in the `valid_block` code for `drop_block_2d`.
- a (W, H) grid being reshaped as (H, W)
The current code uses (W, H) to generate the meshgrid;
but then uses a `.reshape((1, 1, H, W))` to unsqueeze the block map.
The simplest fix to the first bug is a one-line change:
```python
h_i, w_i = ndgrid(torch.arange(H), torch.arange(W))
```
This is a longer patch, that attempts to make the code testable.
Note: The current code behaves oddly when the block_size or
clipped_block_size is even; I've added tests exposing the behavior;
but have not changed it.
When you trigger the reshape bug, you get wild results:
```
$ python scratch.py
{'H': 4, 'W': 5, 'block_size': 3, 'fix_reshape': False}
grid.shape=torch.Size([1, 1, 4, 5])
tensor([[[[False, False, False, False, False],
[ True, True, False, False, True],
[ True, False, False, True, True],
[False, False, False, False, False]]]])
{'H': 4, 'W': 5, 'block_size': 3, 'fix_reshape': True}
grid.shape=torch.Size([1, 1, 4, 5])
tensor([[[[False, False, False, False, False],
[False, True, True, True, False],
[False, True, True, True, False],
[False, False, False, False, False]]]])
```
Here's a tiny exceprt script, showing the problem;
it generated the above output.
```python
import torch
from typing import Tuple
def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
"""generate N-D grid in dimension order.
The ndgrid function is like meshgrid except that the order of the first two input arguments are switched.
That is, the statement
[X1,X2,X3] = ndgrid(x1,x2,x3)
produces the same result as
[X2,X1,X3] = meshgrid(x2,x1,x3)
This naming is based on MATLAB, the purpose is to avoid confusion due to torch's change to make
torch.meshgrid behaviour move from matching ndgrid ('ij') indexing to numpy meshgrid defaults of ('xy').
"""
try:
return torch.meshgrid(*tensors, indexing='ij')
except TypeError:
# old PyTorch < 1.10 will follow this path as it does not have indexing arg,
# the old behaviour of meshgrid was 'ij'
return torch.meshgrid(*tensors)
def valid_block(H, W, block_size, fix_reshape=False):
clipped_block_size = min(block_size, H, W)
if fix_reshape:
# This should match the .reshape() dimension order below.
h_i, w_i = ndgrid(torch.arange(H), torch.arange(W))
else:
# The original produces crazy stride patterns, due to .reshape() offset winding.
# This is only visible when H != W.
w_i, h_i = ndgrid(torch.arange(W), torch.arange(H))
valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
valid_block = torch.reshape(valid_block, (1, 1, H, W))
return valid_block
def main():
common_args = dict(H=4, W=5, block_size=3)
for fix in [False, True]:
args = dict(H=4, W=5, block_size=3, fix_reshape=fix)
grid = valid_block(**args)
print(args)
print(f"{grid.shape=}")
print(grid)
print()
if __name__ == "__main__":
main()
```
I realized I'd been so focused on fixing the bug; I hadn't noticed that all of the meshgrid stuff was entirely un-needed. Switched to slice assignment; it does the same thing.
@crutcher indeed yeah original should have flipped the assignmetn with ndgrid or used meshgrid + indexing='xy' ... but slice assignment is clearer. Additionally, with slice assignment there's no point in using bool, should create zeroes array with x.dtype and assign 1.0 to avoid another allocation + dtype conversion.
I think there's also technically a problem with even block sizes and the padding + feature map sizing. I believe max pool w/ 'same' padding (asymmetric padding) is needed to for even blocks to work no?
HuggingFaceDocBuilderDev
commented
Aug 18, 2025
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@rwightman I deep dove on this yesterday, (I'm implementing it for burn, which is how I found this).
The weird behavior with even-sized kernels is correct, but I don't want to submit yet.
max_pool2d (and conv2d that it sits on) has weird but defined behavior about where it considers the midpoint location of even-sized kernels. I now (think) i understand the implications, and I want to document (and test) this so that the next person doesn't spend another day on it.
@rwightman Ok; i think this in a much better state.
H != W. (削除ここまで)I don't know enough about this test base to mask out the failing int/bool dtype tests for jit, so I just killed them.
@crutcher I'm out on vacation for a week and a bit so won't be able to take a closer look re merge for a bit...
Now that it is not using meshgrid, batchwise is a huge speed win (proportional to batch size), and partial_edge_blocks is extremely minor.
Also drop_block_2d_fast wasn't doing batchwise, so wasn't really "fast".
@rwightman ping?
@crutcher got busy with dinov3 and some other things after vacation, also, couldn't move forward with this one because there were too many unrelated formatting changes, I forgot to @ you though so my bad. I don't use Black style indentation, but PEP double indent for args, and would like to keep the diff to functional changes related to PR.
There are two bugs in the
valid_blockcode fordrop_block_2d.The current code uses (W, H) to generate the meshgrid; but then uses a
.reshape((1, 1, H, W))to unsqueeze the block map.The simplest fix to the first bug is a one-line change:
This is a longer patch, that attempts to make the code testable.
Note: The current code behaves oddly when the block_size or clipped_block_size is even; I've added tests exposing the behavior; but have not changed it.
When you trigger the reshape bug, you get wild results:
Here's a tiny exceprt script, showing the problem; it generated the above output.