Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

NorMuon: opt-in nan_guard_fallback skips step on non-finite NS output#79

Open
JohnLangford wants to merge 2 commits into
main from
jcl/issue-76-nan-fallback
Open

NorMuon: opt-in nan_guard_fallback skips step on non-finite NS output #79
JohnLangford wants to merge 2 commits into
main from
jcl/issue-76-nan-fallback

Conversation

@JohnLangford

@JohnLangford JohnLangford commented May 8, 2026

Copy link
Copy Markdown
Contributor

Summary

Re: #76. Defensive guard for the intermittent NaN bug in NorMuon +
gram-newton-schulz + quack-kernels 0.4.1: once a single rank's NS output
goes non-finite, the bad value poisons the parameter, all-reduces into
the next step, and the run is dead. With this PR, opting into
nan_guard_fallback=True lets all ranks agree (via a single
all_reduce(MAX) of one byte) to skip the entire post-ortho update on
detection.

  • V (variance buffer) and X (param) stay strictly unchanged on a
    skipped step. M (momentum) was updated from the clean gradient in
    pre-orthogonalize and is left alone.
  • The early-return semantics are deliberately "this batch never
    happened" rather than "zero out U and run normalization", because
    decaying V on zero updates contaminates future steps and applying
    weight decay alone is a half-update we don't want.

Companion to #78 (the env-gated capture wrapper that helps identify the
offending input). #78 is for diagnosis; this PR is for surviving the
issue in production until the upstream quack/gns regression is fixed.

Cost

Off (default) On
Per-step 1 Python branch 1-byte all_reduce(MAX) + .item() per shape group
Behavior change None Skip step on detection; emit RuntimeWarning on rank 0

The sync is required: if rank 0 takes the fallback and rank 1 doesn't,
DDP weights drift or the next megabatch alltoall mismatches and
deadlocks. Single-byte allreduce is microseconds; total overhead is
dominated by the device->host .item() sync.

Usage

opt = NorMuon(
 params,
 distributed_mesh=process_group,
 use_gram_newton_schulz=True,
 use_triton=True,
 nan_guard_fallback=True, # opt-in defense for issue #76
)

Test plan

  • test_fallback_off_lets_nan_propagate_to_params — baselines that
    the bug actually exists in unit-test form so the positive test isn't
    trivially green
  • test_fallback_on_skips_step_when_ns_returns_nan — bit-exact
    param + zero variance buffer + RuntimeWarning on rank 0
  • test_fallback_on_does_not_block_normal_step — guard is inert
    when NS output is finite
  • Existing test_optimizers.py::TestNorMuon (5 tests) still pass

All run single-rank on CUDA without NCCL.

Issue #76 reports intermittent NaN parameters with NorMuon +
gram-newton-schulz + quack-kernels 0.4.1 on multi-GPU DDP. Once a single
rank's NS output goes non-finite, the post-ortho update poisons the
parameter, the bad value gets all-reduced into the gradient/state on the
next step, and the run is dead.
Add an opt-in ``nan_guard_fallback`` arg to NorMuon. After the megabatch
orthogonalization completes, all ranks check ``isfinite(U_stacked).all()``
and exchange the result via a single-byte ``all_reduce(MAX)``. If any
rank flagged non-finite, every rank early-returns from
``normuon_update_megabatch_async``: V (variance buffer) and X (param)
stay strictly unchanged, so the run state is bit-identical to "this
batch never happened" for these params. M (momentum) was already
updated in pre-orthogonalize from the clean gradient and is left alone.
Why early-return rather than zero U + run normalization:
- normuon_normalization_stacked on zero U decays V toward 0, contaminating
 future steps' normalization.
- weight_decay applies in the post-ortho path; when we know the optimizer
 step is junk, it's cleaner to skip everything than to apply a
 half-update.
Cost when triggered: one tiny allreduce + one device->host sync per
shape group per step. Cost when ``nan_guard_fallback=False`` (default):
zero - just one Python-level branch in the optimizer step.
Sync is required for correctness in the distributed case: if rank 0
takes the fallback and rank 1 doesn't, DDP weights drift or a future
collective deadlocks because rank 1 advances to the next step's
megabatch alltoall while rank 0 hasn't.
Tests (single-rank CUDA, no NCCL):
- ``test_fallback_off_lets_nan_propagate_to_params`` baselines that the
 bug exists when the guard is off, so the next test isn't trivially
 green.
- ``test_fallback_on_skips_step_when_ns_returns_nan`` verifies bit-exact
 param + zero variance buffer + RuntimeWarning emission on rank 0.
- ``test_fallback_on_does_not_block_normal_step`` verifies the guard is
 inert when NS output is finite (params change as usual).
Existing NorMuon tests in test_optimizers.py still pass.
...gabatch's
The head-split and FSDP2 batch-sharded paths in `_create_ortho_tasks`
deliberately override `process_group=None` for the megabatch because no
alltoall is needed there. Gating the nan-flag allreduce on that same
`process_group` silently disabled the sync in those cases:
 - Head-split + DDP: params replicated across ranks, but only some
 ranks may take the fallback. DDP weights drift -- the exact failure
 mode this guard was supposed to prevent.
 - Batch-sharded FSDP2: each rank owns a different shard of the same
 logical param. Divergent skips leave the logical tensor torn (some
 shards stepped, some not), violating the "this batch never happened"
 invariant.
Thread the optimizer's full `self._process_group` through as a separate
`nan_sync_process_group` argument so the nan-skip decision agrees across
all ranks regardless of the megabatch's local-vs-collective config.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Reviewers

No reviews

Assignees

No one assigned

Labels

None yet

Projects

None yet

Milestone

No milestone

Development

Successfully merging this pull request may close these issues.

1 participant

AltStyle によって変換されたページ (->オリジナル) /