ahritik/StableSleep


Stabilize-While-You-Sleep: Source-Free TTA for Sleep-EDF

This project trains a compact CNN sleep-stage classifier and evaluates source-free test-time adaptation (TTA) using Tent (entropy minimization) with simple safety rails.

Why TTA? Subject and recording shift makes static models brittle. In deployment there are no labels, and the source data often cannot be kept. Tent adapts only the BatchNorm parameters online, using unlabeled test streams, which keeps adaptation fast and privacy-friendly because no source data or labels are needed at test time.
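
As a rough illustration of the quantity Tent minimizes, the sketch below computes the Shannon entropy of softmax predictions in plain Python. The function names (softmax, prediction_entropy, batch_entropy) are illustrative and are not taken from tta_tent.py; a real loop would backpropagate this scalar into the BatchNorm parameters only.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def prediction_entropy(logits):
    """Shannon entropy (in nats) of the softmax distribution."""
    probs = softmax(logits)
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def batch_entropy(batch_logits):
    """Mean entropy over a batch: the scalar an entropy-minimization step reduces."""
    return sum(prediction_entropy(l) for l in batch_logits) / len(batch_logits)

confident = [8.0, 0.0, 0.0, 0.0, 0.0]  # peaked prediction, entropy close to 0
uncertain = [0.0, 0.0, 0.0, 0.0, 0.0]  # uniform over 5 stages, entropy = log(5)
print(prediction_entropy(confident), prediction_entropy(uncertain))
print(batch_entropy([confident, uncertain]))
```

A peaked prediction yields an entropy near zero, while a uniform prediction over the five stages yields log(5), the maximum for this problem.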


Folder layout

tta_sleepedf/
 data/ sleepedf_mne.py # MNE-based EDF → epochs + subject splits (+ repack)
 models/ cnn_sleep.py # Compact CNN (+ tiny Temporal-SE), keeps BN for TTA
 utils/ metrics.py, common.py # Metrics, losses, scheduler, dataset, samplers
 train_source.py # Imbalance-aware source training (prior-bias init)
 tta_tent.py # Tent loop with entropy gate, EMA resets, budget
 eval_sleep.py # none / BN-only / Tent evaluation
 eda.py # Stage distribution + sample epoch plots
 figs/ paper/
 config.yaml # All hyperparameters and paths
 requirements.txt
 run_sleep_pipeline.sh # End-to-end script
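
The layout comment on train_source.py mentions a prior-bias init. One common reading of that phrase, assumed here rather than confirmed by the source, is to set the final-layer bias to the log of the empirical class priors so that an untrained network starts out predicting the class frequencies. The counts below are made up for illustration, and log_prior_bias is a hypothetical helper name.

```python
import math

def log_prior_bias(class_counts):
    """Return log(count / total) per class, usable as an initial output bias."""
    if any(c <= 0 for c in class_counts):
        raise ValueError("every class needs at least one example")
    total = sum(class_counts)
    return [math.log(c / total) for c in class_counts]

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

# Hypothetical epoch counts for W, N1, N2, N3, REM (not real Sleep-EDF numbers)
counts = [4000, 600, 3500, 1100, 1500]
bias = log_prior_bias(counts)

# With zero weights, the logits equal the bias, so softmax recovers the priors
initial_probs = softmax(bias)
priors = [c / sum(counts) for c in counts]
print([round(p, 4) for p in initial_probs])
```

With this initialization the starting loss reflects the class imbalance directly, so early gradient steps are not spent learning the marginal stage frequencies.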

Data format (raw → processed)

1) Raw input (must be in a folder named sleep-edfx)

Expected structure (examples; exact filenames may vary):

sleep-edfx/
 ├── sleep-cassette/
 │   ├── SC4001E0-PSG.edf
 │   ├── SC4001EC-Hypnogram.edf
 │   ├── SC4002E0-PSG.edf
 │   └── SC4002EC-Hypnogram.edf
 └── sleep-telemetry/
     ├── ST7011J0-PSG.edf
     ├── ST7011JV-Hypnogram.edf
     └── ...

Notes

  • Hypnogram files are detected by the substring Hypnogram in the filename (case-insensitive).
  • We start with EEG channel Fpz-Cz by default. You can pass multiple channels via --channels "Fpz-Cz,Pz-Oz" during transform.
  • Default preprocessing: resample to 100 Hz, epoch to 30 s, per-recording z-score per channel.
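
The default preprocessing can be sketched as follows, assuming the signal has already been resampled to 100 Hz and is held as a (C, T_total) NumPy array. The helper name zscore_and_epoch, the eps guard, and the choice to drop a trailing partial epoch are assumptions, not details taken from sleepedf_mne.py.

```python
import numpy as np

def zscore_and_epoch(signal, fs=100, epoch_sec=30, eps=1e-8):
    """Z-score each channel over the whole recording, then cut into epochs.

    signal: array of shape (C, T_total)
    returns: float32 array of shape (N, C, fs * epoch_sec)
    """
    signal = np.asarray(signal, dtype=np.float64)
    # Per-recording statistics, computed separately for each channel
    mean = signal.mean(axis=1, keepdims=True)
    std = signal.std(axis=1, keepdims=True)
    normed = (signal - mean) / (std + eps)

    samples = fs * epoch_sec
    n_epochs = normed.shape[1] // samples
    trimmed = normed[:, : n_epochs * samples]  # drop the trailing partial epoch
    # (C, N*T) -> (C, N, T) -> (N, C, T)
    epochs = trimmed.reshape(signal.shape[0], n_epochs, samples).transpose(1, 0, 2)
    return epochs.astype(np.float32)

# Synthetic single-channel recording: four full epochs plus 123 leftover samples
rng = np.random.default_rng(0)
raw = rng.normal(loc=5.0, scale=20.0, size=(1, 100 * 30 * 4 + 123))
X = zscore_and_epoch(raw)
print(X.shape)  # (4, 1, 3000)
```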

2) Processed shards (NPZ) + manifest

After running the transform step, you will see:

data/processed/
 ├── rec_0000.npz # one recording → many epochs
 ├── rec_0001.npz
 ├── ...
 └── manifest.json # split info + meta

Then we "repack" (copy/symlink + fresh manifest) to:

data/processed_npy/
 ├── rec_0000.npz
 ├── rec_0001.npz
 ├── ...
 └── manifest.json

Each rec_xxxx.npz contains:

  • X: float32 array with shape (N, C, T)
    • N = number of 30 s epochs in the recording
    • C = number of selected channels (default 1)
    • T = fs * epoch_sec samples (default 100 * 30 = 3000)
  • y: int64 array with shape (N, ) (sleep stage per epoch)
  • subject: string ID (e.g., "rec_0000")

Label mapping (AASM):

0: W (Wake)
1: N1
2: N2
3: N3 (R&K stages 3 and 4 merged)
4: REM

Unknown/movement epochs are ignored during transform.
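
For reference, the mapping can be written as a lookup from the annotation descriptions used in Sleep-EDF hypnograms to the five class indices above. The dictionary and helper names are illustrative, not the ones in sleepedf_mne.py; returning None for unscored or movement annotations mirrors the note that such epochs are ignored.

```python
# Annotation description -> class index; stages 3 and 4 both map to N3
STAGE_TO_LABEL = {
    "Sleep stage W": 0,
    "Sleep stage 1": 1,
    "Sleep stage 2": 2,
    "Sleep stage 3": 3,
    "Sleep stage 4": 3,
    "Sleep stage R": 4,
}

def map_stage(description):
    """Return the class index, or None for unknown/movement annotations."""
    return STAGE_TO_LABEL.get(description.strip())

annotations = [
    "Sleep stage W",
    "Sleep stage 4",
    "Movement time",
    "Sleep stage ?",
    "Sleep stage R",
]
labels = [map_stage(a) for a in annotations]
kept = [label for label in labels if label is not None]
print(kept)  # [0, 3, 4]
```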

manifest.json schema (example):

{
  "channels": ["Fpz-Cz"],
  "fs": 100,
  "epoch_sec": 30,
  "splits": {
    "train": [
      {"id": "rec_0000", "path": "rec_0000.npz", "n": 950},
      {"id": "rec_0001", "path": "rec_0001.npz", "n": 1020}
    ],
    "val": [
      {"id": "rec_0008", "path": "rec_0008.npz", "n": 880}
    ],
    "test": [
      {"id": "rec_0012", "path": "rec_0012.npz", "n": 910}
    ]
  }
}
  • Splits are by recording (proxy for subject-level split) using the train_frac/val_frac settings.
  • You can change channels, sample rate, epoch length, and z-score behavior via CLI or config.yaml.
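
A recording-level split driven by train_frac and val_frac might look like the sketch below. The shuffling, the seed, and the truncating rounding policy are assumptions; the transform script may order or round recordings differently. split_recordings is a hypothetical helper name.

```python
import random

def split_recordings(rec_ids, train_frac=0.6, val_frac=0.2, seed=0):
    """Partition recording IDs into train/val/test; the remainder goes to test."""
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1:
        raise ValueError("fractions must be non-negative and sum to at most 1")
    ids = sorted(rec_ids)              # fixed starting order for reproducibility
    random.Random(seed).shuffle(ids)   # local RNG, global state is untouched
    n_train = int(len(ids) * train_frac)
    n_val = int(len(ids) * val_frac)
    return {
        "train": ids[:n_train],
        "val": ids[n_train:n_train + n_val],
        "test": ids[n_train + n_val:],
    }

rec_ids = [f"rec_{i:04d}" for i in range(10)]
splits = split_recordings(rec_ids)
print({name: len(members) for name, members in splits.items()})
```

Because whole recordings are assigned to a split, no epoch from a test recording can appear in training. Two recordings from the same person could still land in different splits, which is why the text calls this a proxy for a subject-level split.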

Python snippet to read a shard

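import json
import numpy as np
# Load one recording shard: X holds the epochs, y the stage labels
npz = np.load("data/processed_npy/rec_0000.npz")
X, y = npz["X"], npz["y"]  # X: (N, C, T) float32, y: (N,) int64
# The sampling rate is stored in the manifest next to the shards
with open("data/processed_npy/manifest.json") as f:
    fs = json.load(f)["fs"]
print(X.shape, y.shape, fs)  # e.g., (950, 1, 3000) (950,) 100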
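
Building on the snippet above, a whole split can be assembled by walking the manifest. This sketch assumes that each entry's path is relative to the folder holding manifest.json and that n is the number of epochs in the shard, as the example schema suggests. load_split is a hypothetical helper, not part of the package.

```python
import json
from pathlib import Path

import numpy as np

def load_split(root, split):
    """Concatenate all shards of one split into (X, y) arrays."""
    root = Path(root)
    with open(root / "manifest.json") as f:
        manifest = json.load(f)

    xs, ys = [], []
    for entry in manifest["splits"][split]:
        with np.load(root / entry["path"]) as npz:
            X, y = npz["X"], npz["y"]
        # Cross-check the shard against the epoch count recorded in the manifest
        if X.shape[0] != entry["n"] or y.shape[0] != entry["n"]:
            raise ValueError(f"{entry['id']}: expected {entry['n']} epochs")
        xs.append(X)
        ys.append(y)
    # Raises ValueError if the split has no recordings
    return np.concatenate(xs), np.concatenate(ys)
```

A call such as load_split("data/processed_npy", "train") would then return arrays of shape (N_total, C, T) and (N_total,).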

Quickstart

# 0) Install deps
python -m pip install --upgrade pip
pip install -r requirements.txt
# 1) Transform raw Sleep-EDF (adjust RAW to your 'sleep-edfx' path)
export RAW="/absolute/path/to/sleep-edfx"
python -m tta_sleepedf.data.sleepedf_mne --root "$RAW" --out data/processed --channels "Fpz-Cz" --fs 100 --epoch 30 --zscore --train_frac 0.6 --val_frac 0.2 --verbose
# 2) Repack to fast folder (symlink/copy + fresh manifest)
python -m tta_sleepedf.data.sleepedf_mne repack --src data/processed --dst data/processed_npy
# 3) Train source model (imbalance-aware)
python -m tta_sleepedf.train_source --cfg config.yaml
# 4) Evaluate (none / BN-only / Tent)
python -m tta_sleepedf.eval_sleep --cfg config.yaml --tta none
python -m tta_sleepedf.eval_sleep --cfg config.yaml --tta bn_only
python -m tta_sleepedf.eval_sleep --cfg config.yaml --tta tent
# 5) EDA
python -m tta_sleepedf.eda --root data/processed_npy

One-liner end-to-end

chmod +x run_sleep_pipeline.sh
RAW="$HOME/data/sleep-edfx" ./run_sleep_pipeline.sh

Notes & tips

  • To add a second EEG channel (e.g., Pz-Oz), either pass --channels "Fpz-Cz,Pz-Oz" during transform or edit data.channels in config.yaml, then re-run steps 1–2.
  • If early training seems unstable, reduce time masking / jitter in config.yaml:
    train:
      augment:
        time_mask_prob: 0.1
        jitter_std: 0.005
  • For a CE vs Focal comparison, set train.loss: ce in config.yaml and re-train.
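
To make the CE vs Focal comparison concrete, the sketch below evaluates both losses for a single example in plain Python. It uses the standard focal-loss form FL = -(1 - p_t)^gamma * log(p_t). The value gamma = 2.0 is a common default and an assumption here, as are the function names; the loss in utils/ may add class weights or differ in other details.

```python
import math

def cross_entropy(probs, target):
    """Standard cross-entropy for one example: -log(p_target)."""
    return -math.log(probs[target])

def focal_loss(probs, target, gamma=2.0):
    """Focal loss: cross-entropy scaled by (1 - p_target) ** gamma."""
    p_t = probs[target]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

# Predicted probabilities over W, N1, N2, N3, REM (made-up values)
easy = [0.90, 0.02, 0.05, 0.01, 0.02]  # true stage W, predicted confidently
hard = [0.50, 0.10, 0.30, 0.05, 0.05]  # true stage N1, predicted poorly

for name, probs, target in [("easy", easy, 0), ("hard", hard, 1)]:
    ce = cross_entropy(probs, target)
    fl = focal_loss(probs, target)
    print(f"{name}: CE={ce:.4f} focal={fl:.4f} ratio={fl / ce:.3f}")
```

The ratio column shows the effect: the well-classified example is down-weighted by a factor of (1 - 0.9)^2 = 0.01, while the poorly classified minority-stage example keeps most of its loss, which is why focal loss is often tried on imbalanced stage distributions.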

EDA outputs

Running python -m tta_sleepedf.eda --root data/processed_npy produces:

  • figs/stage_distribution_counts.png
  • figs/stage_distribution_fraction.png
  • figs/example_epoch_0.png ... example_epoch_4.png

These help verify class imbalance, stage presence, and basic signal quality.
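
The counts and fractions behind the first two figures can also be checked numerically. The sketch below tallies a label vector with collections.Counter; the label list is synthetic and stage_distribution is an illustrative name, not a function from eda.py.

```python
from collections import Counter

STAGE_NAMES = ["W", "N1", "N2", "N3", "REM"]

def stage_distribution(labels):
    """Return {stage: (count, fraction)} for integer labels in 0..4."""
    counts = Counter(int(label) for label in labels)
    total = sum(counts.values())
    return {
        name: (counts.get(i, 0), counts.get(i, 0) / total if total else 0.0)
        for i, name in enumerate(STAGE_NAMES)
    }

# Synthetic labels: 100 epochs with a deliberately small N1 class
y = [0] * 40 + [1] * 5 + [2] * 35 + [3] * 10 + [4] * 10
dist = stage_distribution(y)
for stage, (n, frac) in dist.items():
    print(f"{stage:>3}: {n:4d} ({frac:.1%})")
```

Passing the y array from a loaded shard in place of the synthetic list gives the per-recording distribution.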


Clinical note

This code is research-grade. Any clinical use requires additional validation, governance, and regulatory review.
