
PyTorch NaNs Are Silent Killers — So I Built a 3ms Hook to Catch Them at the Exact Layer

Emmimal P Alexander · 10 min read

NaNs don’t crash your training — they quietly destroy it. After losing hours to a silent failure in a ResNet training run, I built a lightweight detector that pinpoints the exact layer and batch where things break. Using forward hooks and gradient checks, it catches issues early with minimal overhead — without slowing your model to a crawl.

Original article
Towards Data Science · Emmimal P Alexander
Read the full article at Towards Data Science →
Full article excerpt

Deep Learning

This forward-hook detector catches NaNs and exploding gradients at the exact layer and batch they first appear — with ~3–4 ms overhead vs ~7–8 ms for set_detect_anomaly on CPU. On GPU, the gap becomes significantly larger.

Emmimal P Alexander · Apr 28, 2026 · 11 min read

[Image by the author, generated with ChatGPT (DALL·E)]

TL;DR

- NaNs don’t originate where they appear — they silently propagate across layers
- torch.autograd.set_detect_anomaly is too slow and often misleading for real debugging
- A forward hook–based detector can catch NaNs at the exact layer and batch they first occur
- Overhead is ~3–4 ms per forward pass, far lower than anomaly detection (especially on GPU)
- Gradient explosion is the real root cause in most cases — catching it early prevents NaNs entirely
- The system logs structured events (layer, batch, stats) for precise debugging
- Designed for production: thread-safe, memory-bounded, and scalable

It was batch 47,000. A ResNet variant I had been training for six hours on a custom medical imaging dataset. The loss was converging cleanly — 1.4, 1.1, 0.87, 0.73 — and then, nothing. Not an error. Not a crash. Just nan.

I added torch.autograd.set_detect_anomaly(True) and restarted. The training slowed to a crawl — roughly 7–10× longer per batch on CPU alone — and after three hours I finally got a stack trace pointing to a layer that, frankly, looked fine. The real culprit was a learning rate scheduler interacting badly with a custom normalization layer two layers upstream. set_detect_anomaly had pointed me at the symptom, not the source.

That debugging session cost me most of a day. So I built something better.

NaNs don’t crash your model — they quietly corrupt it. By the time you notice, you’re already debugging the wrong layer.

Complete code: https://github.com/Emmimal/pytorch-nan-detector/

The Problem with set_detect_anomaly

PyTorch ships with torch.autograd.set_detect_anomaly(True), which is the standard recommendation for debugging NaN issues. It works by retaining the full computation graph and checking for anomalies during the backward pass. This is powerful, but it comes with serious costs that make it unsuitable for anything beyond a quick local sanity check.

The core issue is that it forces PyTorch’s autograd engine into a synchronous mode where it saves intermediate activations for every single operation. On GPU, this means breaking the asynchronous execution pipeline — every kernel launch has to complete before the next one begins. The result, as reported in the PyTorch documentation and widely observed in practice, is an overhead that ranges from roughly 10–15× on CPU to 50–100× on GPU for larger models [1][2].

There is a second problem: set_detect_anomaly points you at where the NaN propagated to in the backward pass, not necessarily where it originated. If a NaN enters your network at layer 3 of a 50-layer model, the backward pass will surface an error somewhere in the gradient computation for a later layer, and you are left working backward from there.

My benchmark, run on a small CPU MLP (64→256→256→10), measured:

| Method | Mean latency | Overhead vs baseline |
| --- | --- | --- |
| No detection | ~0.60 ms | baseline |
| NaNDetector (forward hooks) | ~3–4 ms | ~5–6× |
| set_detect_anomaly | ~7–8 ms | ~12–13× |

Forward hook–based NaN detection adds ~3 ms per pass, while set_detect_anomaly adds ~7 ms — a small gap here, but a major slowdown at scale, especially on GPU. …
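The excerpt cuts off before the implementation, but the approach it describes is straightforward to sketch. Below is a minimal, illustrative version of a forward-hook detector with a gradient-norm check; the class name, threshold, and log format here are assumptions for discussion, not the actual API of the author's pytorch-nan-detector repository.

```python
import logging
import torch
import torch.nn as nn

class NaNDetector:
    """Illustrative sketch: flag non-finite activations and exploding
    gradients at the first layer/batch where they appear. Not the
    author's actual implementation."""

    def __init__(self, model: nn.Module, grad_norm_threshold: float = 1e3):
        self.batch = 0
        self.handles = []
        for name, module in model.named_modules():
            if next(module.children(), None) is None:  # leaf modules only
                self.handles.append(
                    module.register_forward_hook(self._forward_hook(name)))
                self.handles.append(
                    module.register_full_backward_hook(
                        self._backward_hook(name, grad_norm_threshold)))

    def _forward_hook(self, name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                # Structured event: layer, batch, and basic stats.
                logging.error("non-finite activation layer=%s batch=%d min=%g max=%g",
                              name, self.batch,
                              output.min().item(), output.max().item())
        return hook

    def _backward_hook(self, name, threshold):
        def hook(module, grad_input, grad_output):
            for g in grad_output:
                if g is None:
                    continue
                # A NaN norm fails the isfinite check; a huge norm
                # flags the explosion before it turns into NaNs.
                if not torch.isfinite(g).all() or g.norm().item() > threshold:
                    logging.warning("suspect gradient layer=%s batch=%d norm=%g",
                                    name, self.batch, g.norm().item())
        return hook

    def step(self):    # call once at the end of each training batch
        self.batch += 1

    def remove(self):  # detach all hooks when done
        for h in self.handles:
            h.remove()
```

Wired into a training loop, `detector = NaNDetector(model)` plus a `detector.step()` per batch is all it takes; the log then names the first offending layer and batch, rather than wherever the backward pass happens to fail.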
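For the overhead comparison, a rough harness like the following reproduces the flavor of the benchmark above on your own hardware. The MLP shape matches the article's description; the batch size, step count, and loss are my assumptions, and absolute numbers will differ by machine.

```python
import time
import torch
import torch.nn as nn

# Hypothetical timing harness mirroring the article's CPU setup:
# a 64 -> 256 -> 256 -> 10 MLP timed with and without anomaly mode.
model = nn.Sequential(
    nn.Linear(64, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 10),
)
x = torch.randn(128, 64)
target = torch.randn(128, 10)
loss_fn = nn.MSELoss()

def ms_per_step(steps: int = 200) -> float:
    """Average forward+backward time in milliseconds."""
    start = time.perf_counter()
    for _ in range(steps):
        model.zero_grad()
        loss_fn(model(x), target).backward()
    return (time.perf_counter() - start) / steps * 1000

baseline = ms_per_step()
with torch.autograd.set_detect_anomaly(True):  # usable as a context manager
    anomaly = ms_per_step()
print(f"baseline: {baseline:.2f} ms  detect_anomaly: {anomaly:.2f} ms "
      f"({anomaly / baseline:.1f}x)")
```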

This excerpt is published under fair use for community discussion. Read the full article at Towards Data Science.

