Structured State-Space Duality Experiments

This repository hosts the experiments that illustrate the duality between diagonal state-space models (SSMs) and semi-separable attention, plus a compact suite of Mamba-vs-baseline comparisons.

Repository Layout

state_space_duality/
├── README.md                        # This document
├── LICENSE
├── experiments/                     # Structured state-space duality demos
│   ├── __init__.py
│   ├── common.py                    # Shared utilities and ExperimentResult dataclass
│   ├── exp1_scalar_equivalence.py
│   ├── exp2_diagonal_equivalence.py
│   ├── exp2b_timevarying_diagonal.py
│   ├── exp3_rank_vs_state_dim.py
│   ├── exp3b_rank_vs_state_dim.py
│   ├── exp4_time_scaling.py
│   ├── exp4b_time_scaling_errorbars.py
│   ├── exp5_softmax_rank_growth.py
│   ├── exp5b_softmax_rank_growth.py
│   ├── verify_ssd.py                # Aggregate runner (uv run python -m experiments.verify_ssd)
│   ├── run_and_log.py               # CLI with logging + grids
│   └── check_logs.py                # Sanity checks for logged results
├── mamba_experiments/               # Small-scale Mamba vs. baseline comparisons
├── time_series_experiments/         # Synthetic time-series regression experiments + plotting
├── pyproject.toml                   # Project metadata + optional extras
└── uv.lock                          # Reproducible dependency lockfile (uv)

Setup

The SSD experiments require only NumPy (plus Matplotlib for the optional Exp4 plots), while mamba_experiments/ additionally uses PyTorch, datasets, einops, etc. Pick your preferred workflow:

  • Install via pyproject.toml extras

    uv sync

Quick Start

Run all experiments once:

uv run python -m experiments.verify_ssd

Run a single experiment module directly:

uv run python -m experiments.exp1_scalar_equivalence
uv run python -m experiments.exp2_diagonal_equivalence
uv run python -m experiments.exp2b_timevarying_diagonal
uv run python -m experiments.exp3_rank_vs_state_dim
uv run python -m experiments.exp3b_rank_vs_state_dim
uv run python -m experiments.exp4_time_scaling
uv run python -m experiments.exp4b_time_scaling_errorbars
uv run python -m experiments.exp5_softmax_rank_growth
uv run python -m experiments.exp5b_softmax_rank_growth

Running Individual Experiments

Invoke the modules directly (the examples below prepend uv run; drop it if uv does not manage your environment):

  • uv run python -m experiments.exp1_scalar_equivalence --T 32 --a 0.9 --seed 1
  • uv run python -m experiments.exp2_diagonal_equivalence --T 40 --seed 2 --decays 0.5 0.8
  • uv run python -m experiments.exp2b_timevarying_diagonal --T 16 --N 4 --seed 3
  • uv run python -m experiments.exp3_rank_vs_state_dim --T-values 15 30 --seeds 0 1
  • uv run python -m experiments.exp4_time_scaling --T-values 300 600 1200 --end-to-end --exp4-plot
  • uv run python -m experiments.exp4b_time_scaling_errorbars --n-trials 8 --n-repeats 5
  • uv run python -m experiments.run_and_log --experiment exp4 --repeats 10 --exp4-plot --exp4-plot-path outputs/exp4_time_scaling.png
  • uv run python -m experiments.exp5_softmax_rank_growth --T-values 80 160 320 --d-k 16

Every script supports --help for the full list of flags.

Experiment Overview

  1. exp1_scalar_equivalence.py – Scalar SSM ≡ 1-SS attention
    Propagates a length-T sequence through a single-state recurrence and matches it against a causal 1-semiseparable attention kernel built from powers of a scalar decay a. The experiment reports the maximum absolute difference between the recurrent and attention outputs, along with the seed and decay; a minimal standalone NumPy sketch of this equivalence follows this list.

  2. exp2_diagonal_equivalence.py – Diagonal SSM vs. sum of 1-SS heads
    Generalizes Exp1 to an N-state diagonal SSM. Each diagonal decay contributes one causal head whose kernel is (C[m]*B[m]) * a_m^{t-s}. The script also records the Vandermonde generator rank to show that distinct eigenvalues guarantee full-rank semiseparable generators. Control the state dimension via --exp2-n-list when using run_and_log.py.

  3. exp2b_timevarying_diagonal.py – Time-varying diagonal decays
    Allows both the recurrence weights and the input/output couplings to change per timestep. _build_time_varying_mask constructs the causal kernel ∏_{k=s+1}^t A_k. The experiment confirms that the resulting masked attention still reproduces the exact recurrent outputs even under non-stationary dynamics.

  4. exp3_rank_vs_state_dim.py – Generator rank studies
    Sweeps over sequence lengths, hand-crafted decay sets, and optional random configurations to compare the theoretical semiseparable generator rank with the empirical matrix rank of the induced attention kernel. Use --exp3-t-list, --exp3-seeds, and --exp3-random-n in run_and_log.py to extend the grid.

  5. exp3b_rank_vs_state_dim.py – Rank vs. state dimension plot
    Wraps Exp3 to sweep N and seeds with spaced decays, then plots the generator rank mean ± std against N with a y = N reference line. Saves to outputs/exp3b_rank_vs_state_dim.png by default.

  6. exp4_time_scaling.py – O(T) recurrence vs. O(T²) attention cost
    Benchmarks the wall-clock time of directly running the diagonal recurrence against explicitly forming the T × T kernel and multiplying inputs. Toggle --exp4-end-to-end to include kernel construction time, and --exp4-plot --exp4-plot-path ... to save the matplotlib visualization generated by plot_results.

  7. exp4b_time_scaling_errorbars.py – Runtime scaling with 95% CIs
    Runs multiple random trials per sequence length and plots the mean ± 95% confidence interval for both recurrence and attention costs (styled via the Flow-KL settings) into outputs/exp4b_time_scaling.png. Configure the number of trials via --n-trials, repeats per trial via --n-repeats, and switch to end-to-end timing with --end-to-end.

  8. exp5_softmax_rank_growth.py – Softmax attention negative test
    Generates random query/key pairs, applies numerically stable causal softmax attention, and measures the matrix rank as T grows. The rank growth illustrates why unrestricted softmax attention does not stay low-rank, in contrast to the semiseparable kernels above; a small sketch of this test also follows the list.

  9. exp5b_softmax_rank_growth.py – Softmax rank growth with variance bands
    Runs Exp5 across multiple seeds, plots the mean ± std rank vs. T, and optionally saves a rank-gap (T - rank) curve. Defaults to outputs/exp5b_rank_growth.png and outputs/exp5b_rank_growth_gap.png.

Each experiment exposes a run(...) function that returns an ExperimentResult (name, human-readable details, metadata dict) for easy downstream logging.
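
For a quick feel for what Exp1 and Exp2 verify, here is a minimal standalone NumPy sketch of the same equivalence. It is written for this README rather than taken from experiments/ (variable names and sizes are illustrative): it runs a diagonal recurrence, then reproduces the output with a lower-triangular kernel whose entries are C[m]*B[m]*a_m^(t-s). Exp2b extends the same construction to decays that vary per timestep.

# Illustrative sketch only -- not the repository's implementation.
import numpy as np

rng = np.random.default_rng(0)
T, N = 32, 4                      # sequence length, state dimension
a = rng.uniform(0.3, 0.95, N)     # diagonal decays (distinct eigenvalues)
B = rng.normal(size=N)            # input coupling
C = rng.normal(size=N)            # output coupling
x = rng.normal(size=T)            # input sequence

# Recurrent form: h_t = a * h_{t-1} + B * x_t,  y_t = C . h_t
h = np.zeros(N)
y_rec = np.empty(T)
for t in range(T):
    h = a * h + B * x[t]
    y_rec[t] = C @ h

# Attention form: y = K @ x with K[t, s] = sum_m C[m] * B[m] * a_m**(t - s) for s <= t
t_idx, s_idx = np.tril_indices(T)
K = np.zeros((T, T))
K[t_idx, s_idx] = ((C * B)[None, :] * a[None, :] ** (t_idx - s_idx)[:, None]).sum(axis=1)
y_att = K @ x

print("max |recurrent - attention| =", np.max(np.abs(y_rec - y_att)))  # ~ machine precision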
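
The Exp5 negative test is just as short. The sketch below (again illustrative, not the repository's code) builds a causal softmax attention matrix from random queries and keys with d_k = 16 and checks its numerical rank at the T values used above; unlike the rank-N semiseparable kernels, the rank keeps growing with T.

# Illustrative sketch only -- not the repository's implementation.
import numpy as np

rng = np.random.default_rng(0)
d_k = 16
for T in (80, 160, 320):
    Q = rng.normal(size=(T, d_k))
    K = rng.normal(size=(T, d_k))
    scores = Q @ K.T / np.sqrt(d_k)
    causal = np.tril(np.ones((T, T), dtype=bool))
    scores = np.where(causal, scores, -np.inf)       # mask out future positions
    scores -= scores.max(axis=1, keepdims=True)      # stabilize the softmax
    A = np.exp(scores)
    A /= A.sum(axis=1, keepdims=True)
    print(T, np.linalg.matrix_rank(A))               # rank grows with T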

Mamba Experiments

mamba_experiments/ compares two small Mamba variants (mamba_simple, mamba_SSD_diag_exp) on WikiText-2.

  • WikiText-2 example:

    uv run python -m mamba_experiments.train \
      --model mamba_SSD_diag_exp \
      --dataset wikitext2 \
      --seq-len 128 \
      --max-vocab 20000 \
      --train-max-samples 10000 \
      --val-max-samples 2000
  • Compare Original vs SSD Mamba with variance bands (WikiText-2, seeds 0–9):

    RUN_DIR=outputs/mamba_experiments/compare_wikitext2_s128_d64_n16_l2_e10_b64_lr1e-3_m5k1k
    for model in mamba_simple mamba_SSD_diag_exp; do
      for seed in $(seq 0 9); do
        uv run python -m mamba_experiments.train \
          --model $model \
          --dataset wikitext2 \
          --seq-len 128 \
          --max-vocab 20000 \
          --train-max-samples 5000 \
          --val-max-samples 1000 \
          --d-model 64 \
          --n-state 16 \
          --n-layers 2 \
          --epochs 10 \
          --batch-size 64 \
          --lr 1e-3 \
          --seed $seed \
          --save-plots \
          --plot-dir $RUN_DIR
      done
    done
    MPLCONFIGDIR=outputs/.mplconfig uv run python -m mamba_experiments.plot_variance_bands \
      --log-path $RUN_DIR/train_runs.jsonl \
      --out-dir $RUN_DIR/plots_compare \
      --dataset wikitext2 \
      --models mamba_simple mamba_SSD_diag_exp \
      --metric val_loss \
      --require-all-models

Helpful flags (--help for full list):

  • --model {mamba_simple,mamba_SSD_diag_exp}
  • --dataset {wikitext2} with dataset-specific args (--max-vocab, etc.)
  • Architecture/training knobs: --d-model, --n-state, --n-layers, --epochs, --batch-size, --lr
  • Logging knobs: --save-plots, --plot-dir (plots and a JSONL log appended to <plot_dir>/train_runs.jsonl; a small reading sketch appears at the end of this section)

Additional modules (mamba_simple.py, mamba_SSD.py, etc.) live in the same folder if you want to extend the blocks or plug them into other projects.
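
If you want to post-process those logs yourself, train_runs.jsonl can be read with a few lines of Python. The snippet below is a sketch under assumptions: it treats each line as a JSON object with model and val_loss keys and reuses the RUN_DIR path from the example above; check your own log for the exact field names, which may differ.

# Sketch only: the field names ("model", "val_loss") and the path are assumptions.
import json
from collections import defaultdict
from statistics import mean

log_path = "outputs/mamba_experiments/compare_wikitext2_s128_d64_n16_l2_e10_b64_lr1e-3_m5k1k/train_runs.jsonl"

runs = defaultdict(list)
with open(log_path) as f:
    for line in f:
        entry = json.loads(line)
        runs[entry["model"]].append(entry["val_loss"])   # rename keys if your log differs

for model, losses in runs.items():
    print(f"{model}: mean val_loss {mean(losses):.4f} over {len(losses)} runs")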

Time-Series Experiments

time_series_experiments/ contains a small synthetic regression benchmark (mixture-of-decays) plus JSONL logging and plotting helpers.

Run the synthetic N-sweep (diagonal SSD-style block, N=d_state):

uv run python -m time_series_experiments.exp_synthetic_ndecay --models mamba_diag_exp --N_values 1 2 --epochs 60 --seed 42 --T 100 --lambdas 0.9 0.5 --coeffs 1.0 0.7 --noise_std 1e-4
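
For reference, a mixture-of-decays target of the kind those flags suggest can be generated in a few lines of NumPy. This is only a sketch of one plausible definition (each (lambda, coeff) pair contributes an exponentially decayed copy of the input, plus Gaussian noise); the actual data generator lives in time_series_experiments/exp_synthetic_ndecay.py and may differ in its details.

# Sketch of an assumed mixture-of-decays generator -- see exp_synthetic_ndecay.py for the real one.
import numpy as np

rng = np.random.default_rng(42)
T = 100
lambdas = np.array([0.9, 0.5])    # decay rates, as in --lambdas above
coeffs = np.array([1.0, 0.7])     # mixture weights, as in --coeffs above
noise_std = 1e-4

x = rng.normal(size=T)            # input sequence
state = np.zeros(len(lambdas))
y = np.empty(T)
for t in range(T):
    state = lambdas * state + x[t]    # one scalar exponential filter per decay
    y[t] = coeffs @ state
y += noise_std * rng.normal(size=T)   # regression target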

Plot the latest results (reads outputs/time_series_experiments/ts_runs.jsonl):

uv run python -m time_series_experiments.plot_runs --task synthetic_decays --latest_only --log_val_curves

Plots are written to outputs/time_series_experiments/ by default.

Notes

  • The SSD scripts in experiments/ run purely on CPU with double-precision NumPy.
  • mamba_experiments/ relies on PyTorch (and optionally GPUs) plus its extra dependencies.
  • Logged JSONL/CSV entries include the experiment name, seed, human-readable summary, and a meta dict (decays, ranks, timings, tolerances, etc.).
  • Matplotlib is only required for Exp4 plots or optional Mamba training curves.

License

This repository is licensed under the Apache License. See LICENSE for details.
