A lightweight forward-hook NaN/Inf detector for PyTorch. NaNs are silent killers in PyTorch training: this hook catches the exact layer and batch where they first appear, with ~3–4 ms overhead versus ~7–8 ms for `set_detect_anomaly`.
Companion code for the Towards Data Science article:
PyTorch NaNs Are Silent Killers — I Built a 3 ms Hook That Pinpoints Them to the Exact Layer and Batch
```text
Training loop → Forward hooks → NaNEvent → layer + batch + stats
      ↑
Gradient norm guard (catches explosion before NaN)
```
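The core forward-hook mechanism can be sketched with nothing but PyTorch's standard `register_forward_hook`. This is a minimal illustration under simplifying assumptions, not the code in `nan_detector.py`: `MinimalNaNHook` is a hypothetical name, it only inspects outputs that are plain tensors, and it records just the first offending layer, since every later layer merely inherits the NaN.

```python
import torch
import torch.nn as nn


class MinimalNaNHook:
    """Hypothetical sketch: remember the first submodule whose output is non-finite."""

    def __init__(self, model: nn.Module):
        self.batch = None        # set by the training loop before each forward pass
        self.first_hit = None    # (batch, layer_name, layer_type) of the first bad output
        self._handles = []
        for name, module in model.named_modules():
            if name == "":       # skip the root module; we want the individual layers
                continue
            self._handles.append(module.register_forward_hook(self._make_hook(name)))

    def _make_hook(self, name):
        def hook(module, inputs, output):
            # Hooks fire as each module finishes, so the first hit is the
            # earliest module whose output went non-finite.
            if self.first_hit is None and isinstance(output, torch.Tensor):
                if not torch.isfinite(output).all():
                    self.first_hit = (self.batch, name, type(module).__name__)
        return hook

    def remove(self):
        for handle in self._handles:
            handle.remove()
        self._handles.clear()


# Usage: corrupt one weight on purpose and see which layer gets reported.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
det = MinimalNaNHook(model)
with torch.no_grad():
    model[2].weight[0, 0] = float("nan")
det.batch = 12
model(torch.randn(2, 4))
print(det.first_hit)  # (12, '2', 'Linear')
det.remove()
```

Storing only the first hit keeps the per-layer work to one `isfinite` reduction and avoids flooding the log with every downstream layer.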
| Method | Mean latency | Overhead vs. baseline |
|---|---|---|
| No detection | ~0.60 ms | baseline |
| `NaNDetector` | ~3–4 ms | ~5–7× |
| `set_detect_anomaly` | ~7–8 ms | ~12–13× |
Measured on CPU · 4-layer MLP · batch size 64 · 30 forward passes.
On GPU with large models, the `set_detect_anomaly` overhead multiplier is reported to be higher still.
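To check a comparison like this on your own hardware, a rough timing harness is enough. Treat it as a sketch under stated assumptions rather than the script behind the table: the layer widths are invented, only forward passes are timed, and the hook performs a bare `isfinite` check, so absolute numbers will differ from machine to machine.

```python
import time

import torch
import torch.nn as nn


def time_forward(model: nn.Module, x: torch.Tensor, n: int = 30) -> float:
    """Mean wall-clock time of one forward pass, in milliseconds."""
    model(x)  # warm-up pass so one-time allocation cost is not measured
    start = time.perf_counter()
    for _ in range(n):
        model(x)
    return (time.perf_counter() - start) / n * 1000.0


def finite_check_hook(module, inputs, output):
    # The same kind of per-layer work a NaN hook does: one isfinite() reduction.
    if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
        raise RuntimeError(f"non-finite output in {type(module).__name__}")


torch.manual_seed(0)
# Assumed setup loosely mirroring the table: 4-layer MLP, batch size 64, CPU.
model = nn.Sequential(
    nn.Linear(128, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 1),
)
x = torch.randn(64, 128)

baseline = time_forward(model, x)

handles = [m.register_forward_hook(finite_check_hook) for m in model]
hooked = time_forward(model, x)
for h in handles:
    h.remove()

# Anomaly mode records a traceback for every autograd node built in forward.
with torch.autograd.set_detect_anomaly(True):
    anomaly = time_forward(model, x)

print(f"baseline {baseline:.3f} ms | hooks {hooked:.3f} ms | anomaly {anomaly:.3f} ms")
```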
| Operation | Latency |
|---|---|
| Hook check per layer | ~0.02 ms |
| Full forward pass overhead | ~0.10 ms |
| `set_detect_anomaly` equivalent | ~7–8 ms |
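The per-layer figure is essentially the cost of one `isfinite` reduction, so total overhead grows linearly with the number of hooked layers. A quick way to estimate it for your own shapes, assuming a `(64, 256)` activation (an arbitrary choice) and a hypothetical layer count:

```python
import timeit

import torch

activation = torch.randn(64, 256)   # assumed activation shape; substitute your model's real ones
n_runs = 1_000

# One hook check = one elementwise isfinite() plus an all() reduction.
seconds = timeit.timeit(lambda: bool(torch.isfinite(activation).all()), number=n_runs)
per_check_ms = seconds / n_runs * 1000.0

n_hooked_layers = 5                 # hypothetical: number of modules you attach hooks to
estimated_overhead_ms = per_check_ms * n_hooked_layers
print(f"~{per_check_ms:.4f} ms per check, ~{estimated_overhead_ms:.4f} ms per forward pass")
```

This linear scaling is why excluding non-parametric layers with `skip_types` pays off on deep models: each excluded layer removes one check from every forward pass.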
No package. Single file; drop it into your project:

```bash
curl -O https://raw.githubusercontent.com/Emmimal/pytorch-nan-detector/main/nan_detector.py
```

Requires: torch>=1.11 · matplotlib>=3.5 · Python>=3.10
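Usage:

```python
from nan_detector import NaNDetector

# model, loader, criterion and optimizer are your existing training objects
with NaNDetector(model) as det:
    for batch_idx, (x, y) in enumerate(loader):
        det.set_batch(batch_idx)        # tag hook events with the current batch
        optimizer.zero_grad()           # clear gradients left over from the previous step
        loss = criterion(model(x), y)
        loss.backward()
        det.check_grad_norms()          # gradient-norm guard
        optimizer.step()
        if det.triggered:
            print(det.event)            # layer, batch, flags and stats of the first NaN/Inf
            break
```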
Example output:

```text
NaN/Inf detected! [FORWARD PASS]
Batch     : 12
Layer     : layer4
Type      : Linear
Flags     : NaN in INPUT, NaN in OUTPUT
Out shape : (8, 1)
Out stats : min=n/a max=n/a mean=n/a (all non-finite)
```
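The `Out stats` line is only useful if the statistics skip non-finite entries; otherwise `min`, `max` and `mean` would all be NaN themselves. A sketch of how such a report could be computed, assuming the formatting of the example output; `finite_stats` and `nonfinite_flags` are hypothetical helpers, not functions exported by `nan_detector.py`:

```python
import torch


def finite_stats(t: torch.Tensor) -> str:
    """min/max/mean over the finite entries only; 'n/a' when nothing finite is left."""
    finite = t[torch.isfinite(t)].float()
    if finite.numel() == 0:
        return "min=n/a max=n/a mean=n/a (all non-finite)"
    n_bad = t.numel() - finite.numel()
    return (f"min={finite.min().item():.4g} max={finite.max().item():.4g} "
            f"mean={finite.mean().item():.4g} ({n_bad} non-finite)")


def nonfinite_flags(inputs: tuple, output: torch.Tensor) -> list[str]:
    """Report whether the module's inputs/output contain NaN or Inf."""
    groups = {
        "INPUT": [i for i in inputs if isinstance(i, torch.Tensor)],
        "OUTPUT": [output],
    }
    flags = []
    for label, tensors in groups.items():
        if any(torch.isnan(t).any() for t in tensors):
            flags.append(f"NaN in {label}")
        if any(torch.isinf(t).any() for t in tensors):
            flags.append(f"Inf in {label}")
    return flags


nan_out = torch.full((8, 1), float("nan"))
print(", ".join(nonfinite_flags((nan_out,), nan_out)))       # NaN in INPUT, NaN in OUTPUT
print(finite_stats(nan_out))                                  # min=n/a max=n/a mean=n/a (all non-finite)
print(finite_stats(torch.tensor([1.0, float("inf"), 3.0])))  # min=1 max=3 mean=2 (1 non-finite)
```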
Run all three demos and generate plots:

```bash
python nan_detector.py
```
Loss curve: NaN detected at batch 12.

Gradient norm explosion: caught one step before NaN.
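Catching an explosion one step early comes down to watching the global gradient norm between `backward()` and `optimizer.step()`. The exact rule inside `check_grad_norms()` is not spelled out here, so the following is a hypothetical sketch of one common heuristic: flag a step when the norm is non-finite or exceeds a multiple of its recent average. `GradNormGuard`, `spike_factor` and `window` are illustrative names and defaults, not the library's API.

```python
import math

import torch
import torch.nn as nn


class GradNormGuard:
    """Hypothetical guard: flag a step whose global gradient norm is non-finite
    or jumps far above its recent average."""

    def __init__(self, spike_factor: float = 10.0, window: int = 20):
        self.spike_factor = spike_factor   # multiple of the recent mean that counts as a spike
        self.window = window               # number of past steps in the "recent" average
        self.history: list[float] = []

    def check(self, model: nn.Module) -> bool:
        grads = [p.grad.detach() for p in model.parameters() if p.grad is not None]
        if not grads:
            return False
        # Global L2 norm over all parameters: the norm of the per-parameter norms.
        total = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(g) for g in grads])
        ).item()
        if not math.isfinite(total):
            return True                    # gradient is already NaN/Inf
        recent = self.history[-self.window:]
        spiked = bool(recent) and total > self.spike_factor * sum(recent) / len(recent)
        self.history.append(total)
        return spiked


# Usage: identical steps stay quiet; a 1000x gradient or a NaN gradient is flagged.
torch.manual_seed(0)
model = nn.Linear(4, 1)
guard = GradNormGuard()
x = torch.ones(8, 4)


def backward_scaled(scale: float) -> None:
    model.zero_grad()
    (scale * model(x).pow(2).mean()).backward()


results = []
for scale in (1.0, 1.0, 1.0, 1000.0, float("nan")):
    backward_scaled(scale)
    results.append(guard.check(model))
print(results)  # [False, False, False, True, True]
```

Comparing against a running average rather than a fixed threshold means the same guard works whether a healthy model's gradient norm sits around 0.1 or around 100.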
Worth it when you have:
- Training runs longer than a few minutes, where the `set_detect_anomaly` slowdown is unacceptable
- A need to know which layer originated the NaN, not just that one occurred
- Multi-worker `DataLoader` setups where anomaly detection is unusable at scale
Skip it when you have:
- Quick single-run debugging on a tiny model; `set_detect_anomaly` is fine
- NaNs originating inside a custom CUDA kernel (forward hooks won't see it)
- Hard latency requirements under 1 ms per forward pass
Known limitations:
- Forward hooks won't catch NaNs inside `torch.autograd.Function.backward()`; use `check_backward=True`
- Hook overhead scales with model depth; use `skip_types` to exclude non-parametric layers on very deep models
- Token estimation and GPU benchmarks are not included; the GPU slowdown figure for `set_detect_anomaly` comes from the PyTorch docs, not from measurements here
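To see why forward hooks are blind to backward-pass NaNs, and one generic way to watch gradients instead, the sketch below registers `Tensor.register_hook` callbacks on every parameter. `NaNInBackward` is a toy op built for the demonstration and `watch_param_grads` is a hypothetical helper; this is not a description of how `check_backward=True` works internally. Layers without parameters (ReLU, for example) get no hook at all, the same kind of saving `skip_types` aims for.

```python
import torch
import torch.nn as nn


class NaNInBackward(torch.autograd.Function):
    """Toy op: finite forward output, NaN gradient. Forward hooks cannot see this."""

    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * float("nan")   # deliberately corrupt the gradient


def watch_param_grads(model: nn.Module, flagged: list[str]) -> list:
    """Attach a gradient hook to every parameter and record non-finite gradients."""
    handles = []
    for name, param in model.named_parameters():
        def hook(grad, name=name):          # bind the current name, not the loop's last one
            if not torch.isfinite(grad).all():
                flagged.append(name)
            # returning None leaves the gradient unchanged
        handles.append(param.register_hook(hook))
    return handles


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
flagged: list[str] = []
handles = watch_param_grads(model, flagged)

out = NaNInBackward.apply(model(torch.randn(16, 4)))
assert torch.isfinite(out).all()   # the forward pass looks perfectly healthy
out.sum().backward()
print(flagged)                     # parameter names whose gradient went non-finite, e.g. '2.bias'
for h in handles:
    h.remove()
```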
MIT — see LICENSE