stchakwdev/Mamba_KAN

Mamba-KAN

A Rigorous Factorial Comparison of Neural Network Architectures

Python 3.9+ | PyTorch 2.0+ | License: MIT | CI | Code style: black

Documentation | Quick Start | Results | Citation


Overview

This project implements a 2×3 factorial experiment comparing neural network architectures. It crosses two sequence models (Transformer, Mamba) with three feedforward types (MLP, MLP + B-spline, full KAN) to study how the feedforward component and the sequence modeling approach interact.

Research Question

Do Kolmogorov-Arnold Networks (KAN) outperform MLPs due to their learnable B-spline activation functions, or their unique network topology?

Following Wu et al. (2024), we isolate these effects by including MLP+B-spline baselines alongside Transformer and Mamba sequence models.
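
To make the activation-versus-topology distinction concrete, here is a minimal pure-Python sketch. It is not the repository's implementation: the function names, the five-knot grid, and the use of degree-1 ("hat") B-splines are assumptions chosen for brevity, and real KAN layers typically use higher-order splines plus a base activation. The MLP + B-spline layer mixes inputs linearly and then applies one learnable function per output node, while the KAN layer applies one learnable function per input-output edge and then sums.

```python
import math

KNOTS = [-1.0, -0.5, 0.0, 0.5, 1.0]  # uniform grid; an illustrative choice


def hat_basis(x, knots=KNOTS):
    """Degree-1 B-spline ("hat") basis functions evaluated at x."""
    h = knots[1] - knots[0]
    return [max(0.0, 1.0 - abs(x - t) / h) for t in knots]


def spline(x, coeffs, knots=KNOTS):
    """Learnable 1-D function: a weighted sum of the basis functions."""
    return sum(c * b for c, b in zip(coeffs, hat_basis(x, knots)))


def bspline_mlp_layer(xs, weights, biases, node_coeffs):
    """MLP topology: linear mix first, then ONE learnable activation per output node."""
    outputs = []
    for w_row, bias, coeffs in zip(weights, biases, node_coeffs):
        pre_activation = sum(w * x for w, x in zip(w_row, xs)) + bias
        outputs.append(spline(pre_activation, coeffs))
    return outputs


def kan_layer(xs, edge_coeffs):
    """KAN topology: one learnable function per (output, input) edge, then a plain sum."""
    return [
        sum(spline(x, coeffs) for x, coeffs in zip(xs, row))
        for row in edge_coeffs
    ]


# With coefficients equal to the knot positions, each spline is the identity on [-1, 1].
identity = list(KNOTS)
x = [0.25, -0.5]
mlp_out = bspline_mlp_layer(x, weights=[[1.0, 1.0]], biases=[0.0], node_coeffs=[identity])
kan_out = kan_layer(x, edge_coeffs=[[identity, identity]])
print(mlp_out, kan_out)
```

For d inputs, m outputs, and G basis functions, this sketch gives the KAN layer d·m·G spline coefficients, while the MLP + B-spline layer has d·m + m linear parameters plus m·G spline coefficients. That difference is the "topology" factor the experiment tries to separate from the activation itself.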


Architecture

┌─────────────────────────────────────────────────────────────────┐
│                      2×3 FACTORIAL DESIGN                       │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  Feedforward Type         Sequence Model                        │
│  ════════════════         ══════════════                        │
│                                                                 │
│  ┌─────────────┐          ┌─────────────┐                       │
│  │     MLP     │──────────│ Transformer │ ─► mlp_transformer    │
│  │  (ReLU/GELU)│          │ (Attention) │                       │
│  └─────────────┘          └─────────────┘                       │
│         │                        │                              │
│         │                        │                              │
│  ┌─────────────┐          ┌─────────────┐                       │
│  │    MLP +    │──────────│ Transformer │ ─► bspline_transformer│
│  │  B-spline   │          │ (Attention) │                       │
│  └─────────────┘          └─────────────┘                       │
│         │                        │                              │
│         │                        │                              │
│  ┌─────────────┐          ┌─────────────┐                       │
│  │  Full KAN   │──────────│ Transformer │ ─► kan_transformer    │
│  │  (Learnable)│          │ (Attention) │                       │
│  └─────────────┘          └─────────────┘                       │
│         │                                                       │
│         │                 ┌─────────────┐                       │
│         └─────────────────│    Mamba    │ ─► *_mamba variants   │
│                           │    (SSM)    │                       │
│                           └─────────────┘                       │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Six Model Variants

Variant              Feedforward      Sequence   Purpose
mlp_transformer      MLP (ReLU/GELU)  Attention  Baseline
bspline_transformer  MLP + B-spline   Attention  Isolate activation effect
kan_transformer      Full KAN         Attention  Full KAN architecture
mlp_mamba            MLP (ReLU/GELU)  SSM        Mamba baseline
bspline_mamba        MLP + B-spline   SSM        Activation + SSM
kan_mamba            Full KAN         SSM        Full KAN + SSM (novel)
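
The six variant names are the cross product of the two factors. A short sketch of that enumeration follows; the constant and function names are illustrative, not taken from the repository.

```python
from itertools import product

FEEDFORWARD = ["mlp", "bspline", "kan"]   # feedforward factor (3 levels)
SEQUENCE = ["transformer", "mamba"]       # sequence-model factor (2 levels)


def variant_names():
    """Return every feedforward/sequence combination in the order of the table above."""
    return [f"{ff}_{seq}" for seq, ff in product(SEQUENCE, FEEDFORWARD)]


for name in variant_names():
    print(name)
```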

Key Results

60 experiments completed: 3 models × 2 tasks × 10 seeds on NVIDIA H100 80GB

Results Visualization

[Figure: Experiment Results]

Accuracy Comparison

Task                 MLP Transformer   KAN Transformer   B-spline Transformer
Symbolic Regression  0.0077 ± 0.0023   0.0080 ± 0.0028   0.0082 ± 0.0038
Language Modeling    10.8366 ± 0.0007  10.8373 ± 0.0006  10.8363 ± 0.0017

Training Speed Comparison

Model                 Speed (steps/s)  Time per Experiment  Slowdown vs MLP
MLP Transformer       92.4             26s                  1× (baseline)
KAN Transformer       52.0             50s                  slower
B-spline Transformer  19.6             633s                 slowest
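
Slowdown factors can be recomputed directly from the two measured columns of this table. The sketch below does that arithmetic; the dictionary layout and function name are illustrative. The throughput ratio and the wall-time ratio differ for the B-spline model, so any quoted slowdown should say which column it is based on.

```python
# Values copied from the Training Speed Comparison table.
measurements = {
    "MLP Transformer": {"steps_per_s": 92.4, "seconds": 26},
    "KAN Transformer": {"steps_per_s": 52.0, "seconds": 50},
    "B-spline Transformer": {"steps_per_s": 19.6, "seconds": 633},
}


def slowdowns(table, baseline="MLP Transformer"):
    """Slowdown of each model relative to the baseline, by throughput and by wall time."""
    base = table[baseline]
    return {
        name: {
            "by_throughput": base["steps_per_s"] / row["steps_per_s"],
            "by_wall_time": row["seconds"] / base["seconds"],
        }
        for name, row in table.items()
    }


for name, ratio in slowdowns(measurements).items():
    print(f"{name}: {ratio['by_throughput']:.2f}x by steps/s, "
          f"{ratio['by_wall_time']:.2f}x by time per experiment")
```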

Key Findings

Metric          MLP       KAN                         B-spline
Accuracy        Best      ~Equal                      ~Equal
Speed           Fastest   Slower                      Slowest
Recommendation  Use this  If interpretability needed  Not recommended

[Figure: Model Comparison]

Conclusions

  1. All models perform similarly on accuracy - differences are within statistical noise
  2. MLP wins on speed - fastest training with best or equal accuracy
  3. KAN is practical with efficient-kan - its slowdown relative to MLP is smaller than with pykan
  4. B-spline provides no benefit - slowest model without accuracy gains
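
As a rough sanity check on the first conclusion, the sketch below tests whether the mean ± spread intervals from the Accuracy Comparison table overlap for every pair of models. It assumes the ± values are one standard deviation across seeds, which the table does not state, and interval overlap is only a crude heuristic. It is not a substitute for the Friedman and Wilcoxon tests described under Statistical Analysis.

```python
from itertools import combinations

# (mean, spread) pairs copied from the Accuracy Comparison table.
results = {
    "Symbolic Regression": {
        "MLP": (0.0077, 0.0023),
        "KAN": (0.0080, 0.0028),
        "B-spline": (0.0082, 0.0038),
    },
    "Language Modeling": {
        "MLP": (10.8366, 0.0007),
        "KAN": (10.8373, 0.0006),
        "B-spline": (10.8363, 0.0017),
    },
}


def intervals_overlap(a, b):
    """True if [mean - spread, mean + spread] of a and b share at least one point."""
    (mean_a, spread_a), (mean_b, spread_b) = a, b
    return max(mean_a - spread_a, mean_b - spread_b) <= min(mean_a + spread_a, mean_b + spread_b)


def all_pairs_overlap(task_results):
    """Check every pair of models within one task."""
    return all(
        intervals_overlap(task_results[x], task_results[y])
        for x, y in combinations(task_results, 2)
    )


for task, task_results in results.items():
    print(task, "all intervals overlap:", all_pairs_overlap(task_results))
```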

Quick Start

Installation

# Clone the repository
git clone https://github.com/stchakwdev/Mamba_KAN.git
cd Mamba_KAN
# Create environment
conda create -n mamba_kan python=3.10
conda activate mamba_kan
# Install PyTorch (adjust CUDA version as needed)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# Install package with development dependencies
pip install -e ".[dev]"

Run Quick Validation

# Quick test (1 model, 1 seed, 100 steps)
make train
# Or directly:
python scripts/run_experiment.py \
 --model mlp_transformer \
 --task symbolic \
 --seeds 1 \
 --max-steps 100 \
 --no-wandb

Run Full Comparison

# Full factorial experiment (all models, 10 seeds)
python scripts/run_full_comparison.py \
 --tasks symbolic special_functions timeseries language long_context \
 --seeds 10 \
 --max-steps 10000 \
 --output-dir ./results \
 --generate-report

Benchmark Tasks

1. Symbolic Regression (KAN-favorable)

Tests function approximation with learnable activations:

  • Basic functions: sin, cos, exp, log, sqrt
  • Special functions: Bessel (J0, J1), Legendre (P2, P3, P4)
  • Deep compositions: sin(exp(cos(x))), nested trigonometrics
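
A minimal sketch of how such regression targets could be sampled, using only the standard library. The function and dictionary names are hypothetical and the input range is an assumption; the repository's `datasets.py` may differ. Note that log and sqrt would need a strictly positive input range.

```python
import math
import random

# A few of the target functions listed above.
TARGETS = {
    "sin": math.sin,
    "cos": math.cos,
    "exp": math.exp,
    "sin_exp_cos": lambda x: math.sin(math.exp(math.cos(x))),
}


def make_dataset(fn, n_samples, low=-1.0, high=1.0, seed=0):
    """Sample x uniformly in [low, high] and return (x, fn(x)) pairs."""
    rng = random.Random(seed)  # fixed seed so each run sees the same points
    xs = [rng.uniform(low, high) for _ in range(n_samples)]
    return [(x, fn(x)) for x in xs]


train = make_dataset(TARGETS["sin_exp_cos"], n_samples=5)
for x, y in train:
    print(f"x={x:+.4f}  y={y:+.4f}")
```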

2. Time Series Forecasting

Tests temporal pattern recognition:

  • Patterns: sine with trend, multi-seasonal, chaotic, AR process
  • Sequence lengths: 128, 256, 512
  • Prediction horizons: 10, 20 steps
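
The forecasting setup can be sketched as a sliding window over a synthetic series: each example pairs `seq_len` past values with the next `horizon` values. The generator below covers only the "sine with trend" pattern, and its period and slope are made-up parameters. Both function names are illustrative.

```python
import math


def sine_with_trend(n_steps, period=32.0, slope=0.01):
    """Synthetic series: a sine wave plus a linear trend."""
    return [math.sin(2 * math.pi * t / period) + slope * t for t in range(n_steps)]


def make_windows(series, seq_len, horizon):
    """Split a series into (context, target) pairs for forecasting."""
    windows = []
    for start in range(len(series) - seq_len - horizon + 1):
        context = series[start:start + seq_len]
        target = series[start + seq_len:start + seq_len + horizon]
        windows.append((context, target))
    return windows


series = sine_with_trend(200)
windows = make_windows(series, seq_len=128, horizon=10)
print(len(windows), len(windows[0][0]), len(windows[0][1]))
```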

3. Language Modeling (Mamba-favorable)

Tests long-range dependency modeling:

  • Standard sequences: 256, 512 tokens
  • Long-context: 2048, 4096 tokens
  • Mamba's O(n) cost in sequence length, versus O(n²) for self-attention, is expected to be a significant advantage at these lengths
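
The scaling argument can be illustrated with a toy linear recurrence, which visits each token once, next to the number of token pairs that full self-attention scores. This is a simplified sketch with fixed scalar coefficients. It is not Mamba's selective scan, whose parameters depend on the input, and it ignores hidden-size constants and hardware effects.

```python
def linear_scan(xs, a=0.9, b=0.1):
    """Toy state-space recurrence h_t = a*h_{t-1} + b*x_t: one update per token."""
    h, outputs = 0.0, []
    for x in xs:
        h = a * h + b * x
        outputs.append(h)
    return outputs


def scan_steps(n):
    """Number of recurrence updates for a sequence of length n."""
    return n


def attention_pairs(n):
    """Number of query-key scores in full (unmasked) self-attention."""
    return n * n


for n in (256, 512, 2048, 4096):
    ratio = attention_pairs(n) // scan_steps(n)
    print(f"n={n}: scan steps={scan_steps(n)}, attention pairs={attention_pairs(n)}, ratio={ratio}")
```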

Project Structure

mamba_kan/
├── models/ # Model implementations
│ ├── base.py # Task-aware base class
│ ├── mlp_transformer.py # MLP-Transformer (baseline)
│ ├── bspline_transformer.py # B-spline activation baseline
│ ├── kan_transformer.py # Full KAN-Transformer
│ ├── *_mamba.py # Mamba variants
│ └── components/
│ ├── bspline_mlp.py # Learnable B-spline activation
│ ├── kan_layers.py # KAN building blocks
│ ├── mamba_layers.py # Mamba with B-spline support
│ └── transformer_layers.py
├── training/
│ ├── trainer.py # PyTorch Lightning module
│ ├── scheduler.py # Learning rate schedules
│ └── callbacks.py # Training monitoring
├── analysis/
│ └── statistics.py # Friedman, Wilcoxon, bootstrap CI
├── visualization/ # Plotting and dashboards
│ ├── plots.py # Training curves, comparisons
│ ├── heatmaps.py # Statistical visualizations
│ ├── animations.py # GIF generation
│ └── dashboard.py # Interactive HTML reports
├── data/
│ └── datasets.py # All benchmark datasets
└── configs/
 └── base_config.py # Configuration system
scripts/
├── run_experiment.py # Single experiment runner
├── run_full_comparison.py # Full factorial comparison
├── generate_assets.py # Generate README visualizations
└── runpod_setup.sh # Cloud GPU setup

Statistical Analysis

The project implements rigorous statistical testing following Demšar (2006):

  • Friedman Test: Non-parametric comparison across multiple classifiers
  • Wilcoxon Signed-Rank: Pairwise post-hoc comparisons
  • Holm-Bonferroni Correction: Multiple comparison adjustment
  • Bootstrap Confidence Intervals: Effect size uncertainty quantification

Example usage:

from mamba_kan.analysis import run_full_analysis, print_analysis_summary

# Map each model name to its list of per-seed losses
results = {
    'mlp_transformer': [0.52, 0.51, 0.53, ...],  # Loss per seed
    'bspline_transformer': [0.48, 0.47, 0.49, ...],
    'kan_transformer': [0.45, 0.44, 0.46, ...],
    # ... other models
}

analysis = run_full_analysis(results)
print_analysis_summary(analysis)
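
For readers unfamiliar with the Holm-Bonferroni step, here is a standalone sketch of the adjustment applied to a set of pairwise p-values. It is a generic textbook version, not the code in `analysis/statistics.py`, and the p-values in the example are made up for illustration.

```python
def holm_bonferroni(p_values):
    """Return Holm-adjusted p-values in the same order as the input.

    The smallest p-value is multiplied by m, the next by m - 1, and so on.
    A running maximum keeps the adjusted values monotone, and each is capped at 1.
    """
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])  # indices, smallest p first
    adjusted = [0.0] * m
    running_max = 0.0
    for rank, idx in enumerate(order):
        candidate = min(1.0, (m - rank) * p_values[idx])
        running_max = max(running_max, candidate)
        adjusted[idx] = running_max
    return adjusted


# Hypothetical raw p-values for three pairwise comparisons.
raw = [0.01, 0.04, 0.03]
print(holm_bonferroni(raw))
```

A comparison is then declared significant when its adjusted p-value falls below the chosen significance level, such as 0.05.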

Hardware Requirements

Configuration     Specification
Minimum           NVIDIA GPU, 8GB VRAM, 16GB RAM
Recommended       RTX 3080+ or A100, 32GB RAM
Full experiments  H100, 80GB VRAM

Cloud Deployment

# RunPod H100 setup
chmod +x scripts/runpod_setup.sh
./scripts/runpod_setup.sh
# Run full experiment suite
python scripts/run_full_comparison.py --task all --seeds 10

Development

# Install dev dependencies
pip install -e ".[dev]"
# Run tests
make test
# Run linting
make lint
# Format code
make format
# Generate visualizations from results
make visualize


References

Resources

  • efficient-kan - Fast KAN implementation (used in this project, faster than pykan)
  • pykan - Official KAN implementation
  • mamba-ssm - Official Mamba implementation
  • awesome-kan - Comprehensive KAN resources

Citation

@misc{mamba_kan_2025,
 title={Mamba-KAN: A Factorial Comparison of Neural Network Architectures},
 author={Samuel T. Chakwera},
 year={2025},
 url={https://github.com/stchakwdev/Mamba_KAN},
 note={Investigating whether KAN advantages stem from B-spline activations or network topology}
}

License

MIT License - see LICENSE for details.


Made with PyTorch Lightning and scientific rigor

About

A rigorous 2x3 factorial comparison of neural network architectures: KAN vs MLP feedforward layers combined with Transformer vs Mamba sequence models. Investigates whether KAN advantages stem from B-spline activations or network topology.
