fix(train): un-rot tch-backend vs workspace tch 0.24 + clamp/floor-div runtime bugs + real --val-dir validation and --eval-only mode #1014
Open
stuinfla wants to merge 1 commit into
...ixes, clamp_min + floor-div runtime bugs, rpath build.rs, real --val-dir validation and --eval-only mode (ruvnet#1010)

The optional tch-backend feature is not CI-gated (mod metrics/model/trainer/losses/proof are cfg'd behind it), so it drifted: 13 compile errors against the workspace's own tch = "0.24" pin, plus latent runtime bugs the revived tests exposed. All verified by compiling, running the unit suite, and training + evaluating end-to-end on real CSI:

- tch 0.24 API: Vec<f64>::from(tensor) -> TryFrom (model.rs, proof.rs, trainer.rs; 8 sites); (&tensor % i64) Rem dropped -> .fmod() (trainer.rs); t.numel() usize/i64 sum mismatch (model.rs:185); petgraph EdgeReference methods need use petgraph::visit::EdgeRef (metrics.rs).
- Runtime panic: losses.rs .clamp(1.0, f64::MAX): PyTorch >= 2.x rejects f64::MAX -> f32 tensor conversion; first loss call panics. .clamp_min(1.0) is the intended semantic.
- Latent decode bug: trainer.rs heatmap_to_keypoints used (&arg / w) on the integer argmax tensor, which is TRUE division in torch >= 1.6 (36/8 = 4.5), so every decoded y row was skewed by up to one row. floor_divide_scalar restores integer row decomposition.
- macOS/Linux + LIBTORCH_USE_PYTORCH=1: binaries and the unit-test binary failed at dyld (libtorch_cpu not found) because torch-sys only rpaths its own static lib. New build.rs emits cargo:rustc-link-arg=-Wl,-rpath,<torch>/lib (plain variant: -bins/-tests do not reach the lib unit-test binary).
- Test fixes (pre-existing, tch-gated so invisible to CI): losses.rs:749 E0507 move-out-of-FnMut-capture; two gaussian-peak thresholds that were mathematically unreachable at half-pixel keypoints (max attainable exp(-0.5/(2*sigma^2)) = 0.9394 / 0.8007 vs asserted 0.95 / 0.9).
- Real validation: --val-dir <MM-Fi dir> uses a held-out directory as the validation set instead of the always-synthetic one (which made val_pck noise and checkpoint selection blind on real data). New --eval-only --checkpoint <FILE> [--dump-preds <JSONL>] mode runs the validation loop standalone via Trainer::evaluate_with_dump.

Suite: 194/195 with --features tch-backend (the 1 failure is a tch-0.24 vs PyTorch-2.11 .pt jit reload interplay, pre-existing; checkpoint reload verified via .safetensors reproducing PCK@0.2=0.3999 exactly).

Closes ruvnet#1010

Co-Authored-By: claude-flow <ruv@ruv.net>
## Problem
`wifi-densepose-train`'s optional `tch-backend` feature does not compile against the workspace's own `tch = "0.24"` pin (13 errors), and even after compiling, the first real training step panics. Because `mod metrics`, `mod model`, `mod trainer`, etc. are all `#[cfg(feature = "tch-backend")]` and CI builds `--no-default-features`, none of this is visible to the PR gate, so the feature bit-rotted silently. Full write-up with error sites: #1010.

On top of that,
`bin/train.rs` always pairs a real `--data-dir` with a synthetic validation set, so on real data `val_pck` is noise (~0.02–0.06) and best-checkpoint selection / early stopping are silently driven by that noise.

## Root cause → fix (file:line against current main)
1. tch 0.24 API drift (compile errors):
   - `Vec::<f64>::from(tensor)` → conversions moved to `TryFrom` in tch 0.24: `model.rs:300,359`, `proof.rs:156,160,264`, `trainer.rs:585,590,643,657`
   - `(&tensor % i64)` → scalar `Rem` impl dropped; use `.fmod()`: `trainer.rs:627`
   - `t.numel()` returns `usize`, summed into `i64`: `model.rs:185` (`as i64`)
   - `EdgeReference::id()` / `.target()` / `.weight()` now need `use petgraph::visit::EdgeRef;`: `metrics.rs` (errors at 1103–1105)
2. Runtime panic:
`losses.rs:121,168,237`, `.clamp(1.0, f64::MAX)`: PyTorch ≥ 2.x rejects converting `f64::MAX` into an f32 tensor ("cannot be converted to type float without overflow"), so the first loss call panics. `.clamp_min(1.0)` is the exact intended semantic.
3. dyld failure with
`LIBTORCH_USE_PYTORCH=1` (macOS/Linux): torch-sys only rpaths its own cc-built static lib, which does not propagate to downstream binaries, so `train` dies at launch with `libtorch_cpu` not found, and so does the crate's own unit-test binary under `cargo test --features tch-backend`. New `build.rs` emits `cargo:rustc-link-arg=-Wl,-rpath,<torch>/lib` when that env var is set (plain `rustc-link-arg`, because the `-bins`/`-tests` variants do not reach the lib unit-test binary; verified with `otool -l`); a no-op otherwise.
4. Latent decode bug exposed by the revived tests:
`trainer.rs:694` `heatmap_to_keypoints`: `(&arg / w)` on the integer argmax tensor is TRUE division in torch ≥ 1.6 (36/8 = 4.5), so every decoded y row was skewed by up to one row. `floor_divide_scalar(w)` restores integer row decomposition (the unit test pins center peak (4,4) → 4/7 exactly).
5. Test-code fixes (pre-existing, tch-gated so invisible to CI):
   - `losses.rs:749`, E0507: a `move` closure inside an `FnMut` closure moves the captured `heatmaps` array on every outer iteration; shadowed with `let heatmaps = &heatmaps;`
   - `losses.rs` gaussian-peak thresholds were mathematically unreachable: kp 0.5 sits on a half-pixel, so the best attainable peak is exp(−0.5/(2σ²)), which is 0.9394 for σ=2 (test demanded > 0.95) and 0.8007 for σ=1.5 (test demanded > 0.9). Relaxed to 0.93 / 0.79 with the math in comments.
6. No real-validation path: new
`--val-dir <DIR>` loads a held-out MM-Fi directory via `MmFiDataset::discover` as the validation set (synthetic fallback unchanged when the flag is absent). New `--eval-only --checkpoint <FILE> [--dump-preds <JSONL>]` mode loads a checkpoint and runs the existing validation loop standalone (`Trainer::evaluate_with_dump` dumps per-sample pred+GT keypoints in dataset order for offline rendering/analysis).

## Test evidence
- `cargo check -p wifi-densepose-cli -p wifi-densepose-signal -p wifi-densepose-train --no-default-features`: clean (matches CI gate)
- `LIBTORCH_USE_PYTORCH=1 cargo check -p wifi-densepose-train --features tch-backend`: clean against PyTorch 2.11.0 (the 3 remaining warnings, `model.rs:31` unused `nn::ModuleT`, `proof.rs:28` unused `CsiDataset`, and `trainer.rs:723` unused `num_kp`, are pre-existing on main and deliberately untouched)
- `cargo test -p wifi-densepose-train --no-default-features`: 7/7 pass
- `LIBTORCH_USE_PYTORCH=1 cargo test -p wifi-densepose-train --features tch-backend --lib -- --test-threads=1`: 194/195 pass. Honest accounting of what still fails and why (all pre-existing, none regressions from this PR):
  - `model::tests::save_and_load_roundtrip` fails against PyTorch 2.11 libtorch: tch 0.24's `.pt` (`jit_load_parameters`) reload hits `Expected GenericDict but got Object`. Version interplay, not this code; real checkpoint reload verified via `.safetensors` (below). Likely passes against the libtorch generation tch 0.24 targets.
  - `proof::tests::hash_model_weights_is_deterministic` and `generate_and_verify_hash_matches` are racy under parallel test threads (other tests draw from torch's shared global RNG between the two `manual_seed` calls); both pass with `--test-threads=1`, so they are counted above as passing. A proper fix (RNG mutex) felt out of scope.
- `--val-dir` reached 40% cross-subject PCK@0.2 (vs 11.6% reported in ADR-150), with checkpoint selection driven by the real metric instead of synthetic noise.
- `--eval-only` round-trip verified: reloading `best_epoch0016_pck0.3999` (as `.safetensors`; see the PyTorch 2.11 `.pt` note above) and re-running the validation loop reproduces exactly PCK@0.2 = 0.3999 / OKS = 0.1277 (n=3191) on the original val set, so the eval harness is bit-faithful to training-time validation.

Formatting note:
`src/ablation.rs` and `src/bin/aa_score_runner.rs` currently fail `cargo fmt --check` on main; left untouched here to keep this PR surgical.

Closes #1010
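
For reviewers, a minimal plain-Rust sketch of the row/column decomposition behind the decode fix (item 4). It does not use tch, and `decode_flat_index`, `true_division_row`, and `normalize` are hypothetical helpers written for illustration, not code from `trainer.rs`. Normalizing by `size - 1` is an assumption inferred from the "4/7" figure above.

```rust
/// Split a flat argmax index into (row, col) for a heatmap of width `w`.
/// Hypothetical helper: floor division gives the row, the remainder the column.
fn decode_flat_index(idx: i64, w: i64) -> (i64, i64) {
    (idx.div_euclid(w), idx.rem_euclid(w))
}

/// The same row computed with true (floating-point) division.
/// The result is fractional whenever the column is non-zero.
fn true_division_row(idx: i64, w: i64) -> f64 {
    idx as f64 / w as f64
}

/// Map a pixel coordinate to [0, 1], assuming normalization by `size - 1`.
fn normalize(coord: i64, size: i64) -> f64 {
    coord as f64 / (size - 1) as f64
}
```

On an 8×8 heatmap, the center peak at (4, 4) has flat index 36. `decode_flat_index(36, 8)` gives row 4 and column 4, while `true_division_row(36, 8)` gives 4.5, which is the half-row skew described above. `normalize(4, 8)` is 4/7.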
🤖 Generated with claude-flow
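
A small sketch of the arithmetic behind the relaxed gaussian-peak thresholds (item 5), assuming the heatmap is an unnormalized Gaussian of the form exp(−d²/(2σ²)), where d² is the squared pixel distance from the keypoint to the nearest sampled pixel. `gaussian_peak` is a hypothetical helper, not the function under test. Under this form, 0.9394 corresponds to d² = 0.5 (half a pixel on both axes) at σ = 2, and 0.8007 corresponds to d² = 1.0 at σ = 1.5; the second offset is inferred from the reported number rather than read from the test.

```rust
/// Value of an unnormalized 2-D Gaussian at squared distance `d2`
/// (pixels squared) from its center: exp(-d2 / (2 * sigma^2)).
/// Hypothetical helper for checking the reported peak bounds.
fn gaussian_peak(d2: f64, sigma: f64) -> f64 {
    (-d2 / (2.0 * sigma * sigma)).exp()
}

/// True when `threshold` is strictly below the best attainable peak,
/// i.e. an assertion `peak > threshold` can actually pass.
fn threshold_reachable(threshold: f64, d2: f64, sigma: f64) -> bool {
    gaussian_peak(d2, sigma) > threshold
}
```

With these assumptions, `threshold_reachable(0.95, 0.5, 2.0)` is false and `threshold_reachable(0.93, 0.5, 2.0)` is true, which matches the old and relaxed thresholds for the σ = 2 test.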