PyTorch Project Template

A comprehensive, production-ready PyTorch project template with modular architecture, distributed training support, and modern tooling.

Features

  • 🧩 Modular Architecture: Registry-based component system for easy extensibility
  • ⚙️ Configuration Management: Hierarchical config system with inheritance and CLI overrides
  • 🚀 Distributed Training: Multi-node/multi-GPU training with DDP, FSDP, and DataParallel
  • 📊 Experiment Tracking: MLflow and Weights & Biases integration with auto-visualization
  • 🔧 Modern Tooling: uv package management, pre-commit hooks, Docker support
  • 💾 Resume Training: Automatic checkpoint saving and loading with state preservation
  • 🌐 Cross-Platform: Development support on macOS (Apple Silicon MPS) and Linux with optimized builds
  • 🐳 Development Environment: Devcontainer and Jupyter Lab integration
  • ⚡ Performance Optimization: RAM caching, mixed precision, torch.compile support
  • 📚 Auto Documentation: Sphinx-based API docs with live reloading
  • 📱 Slack Notifications: Training completion and error notifications
  • 🛡️ Error Handling: Robust error recovery and automatic retries

Requirements

  • Python: 3.11+
  • Package Manager: uv
  • CUDA: 12.8 (for GPU training)
  • PyTorch: 2.7.1

Quick Start

1. Setup Project

Create a new project using this template:

# Option 1: Use as GitHub template (recommended)
# Click "Use this template" on GitHub
# Option 2: Clone and setup manually
git clone <your-repo-url>
cd your-project-name
# Option 3: Merge updates from this template
git remote add upstream https://github.com/mjun0812/PyTorch-Project-Template.git
git fetch upstream main
git merge --allow-unrelated-histories --squash upstream/main

2. Environment Configuration

# Copy environment template
cp template.env .env
# Edit .env with your API keys and settings

Example .env configuration:

# Slack notifications (optional)
# You can use either SLACK_TOKEN or SLACK_WEBHOOK_URL
SLACK_TOKEN="xoxb-your-token"
SLACK_CHANNEL="#notifications"
SLACK_USERNAME="Training Bot"
# Alternative: Webhook URL (simpler setup)
SLACK_WEBHOOK_URL="https://hooks.slack.com/services/..."
# MLflow tracking
MLFLOW_TRACKING_URI="./result/mlruns" # or remote URI
# Weights & Biases (optional)
WANDB_API_KEY="your-wandb-key"

3. Installation

Choose your preferred installation method:

Option A: Local Development (Recommended)

# Install dependencies
uv sync
# Setup development environment
uv run pre-commit install
# Run training
uv run python train.py config/dummy.yaml

Option B: Docker

# Build container
./docker/build.sh
# Run training in container
./docker/run.sh python train.py config/dummy.yaml

Option C: Development Container

Open the project in VS Code and use the devcontainer configuration for a consistent development environment.

Usage

Basic Training

Start with the dummy configuration to test your setup:

# Basic training with dummy dataset
python train.py config/dummy.yaml
# Override configuration from command line
python train.py config/dummy.yaml batch=32 gpu.use=0 optimizer.lr=0.001

Configuration Management

This template uses hierarchical configuration with inheritance support:

# Use dot notation to modify nested values
python train.py config/dummy.yaml gpu.use=0,1 model.backbone.depth=50
# Multiple overrides
python train.py config/dummy.yaml batch=64 epoch=100 optimizer.lr=0.01
# View current configuration
python script/show_config.py config/dummy.yaml
# Batch edit configuration files
python script/edit_configs.py config/dummy.yaml "optimizer.lr=0.01,batch=64"

Configuration hierarchy (later entries override earlier ones; a toy sketch of the merge order follows the list):

  1. Dataclass defaults (src/config/config.py)
  2. Base configs (config/__base__/)
  3. Experiment configs (config/*.yaml) with __base__ inheritance
  4. CLI overrides
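A toy Python sketch of that merge order. This is not the template's actual loader; the names and values are illustrative.

from dataclasses import asdict, dataclass

@dataclass
class Defaults:
    batch: int = 16
    epoch: int = 10

def merge(*layers: dict) -> dict:
    merged: dict = {}
    for layer in layers:
        merged.update(layer)  # later layers win
    return merged

config = merge(
    asdict(Defaults()),  # 1. dataclass defaults
    {"batch": 32},       # 2. base config values
    {"epoch": 100},      # 3. experiment config values
    {"batch": 64},       # 4. CLI overrides
)
print(config)  # {'batch': 64, 'epoch': 100}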

Development Tools

# Launch Jupyter Lab for experimentation
./script/run_notebook.sh
# Start MLflow UI for experiment tracking
./script/run_mlflow.sh
# View all registered components
python script/show_registers.py
# View model architecture
python script/show_model.py
# Visualize learning rate schedules
python script/show_scheduler.py
# View data transformation pipeline
python script/show_transform.py
# Clean up orphaned result directories
python script/clean_result.py
# Aggregate MLflow results to CSV
python script/aggregate_mlflow.py all
# Start documentation server (auto-reloads on changes)
./script/run_docs.sh

Distributed Training

Scale your training across multiple GPUs and nodes:

Single Node, Multiple GPUs

# Use torchrun for DDP training (recommended)
./torchrun.sh 4 train.py config/dummy.yaml gpu.use="0,1,2,3"
# Alternative: DataParallel (not recommended for production)
python train.py config/dummy.yaml gpu.use="0,1,2,3" gpu.multi_strategy="dp"
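Under torchrun, the DDP setup the runner performs amounts to roughly the following standard PyTorch boilerplate (a sketch with a placeholder model; the template's runner handles this for you):

import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

dist.init_process_group(backend="nccl")  # torchrun provides rank/world-size env vars
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(8, 2).cuda()  # placeholder model
model = DDP(model, device_ids=[local_rank])  # gradients sync across ranks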

Multi-Node Training

# Master node (node 0)
./multinode.sh 2 4 12345 0 master-ip:12345 train.py config/dummy.yaml gpu.use="0,1,2,3"
# Worker nodes (node 1+)
./multinode.sh 2 4 12345 1 master-ip:12345 train.py config/dummy.yaml gpu.use="0,1,2,3"

FSDP (Fully Sharded Data Parallel)

For very large models that don't fit in a single GPU's memory:

python train.py config/dummy.yaml gpu.multi_strategy="fsdp" gpu.fsdp.min_num_params=100000000
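The min_num_params option corresponds to PyTorch's size-based auto-wrap policy: submodules above the parameter threshold are wrapped and sharded individually. A minimal sketch using the standard FSDP API with a placeholder model (the template's runner does the equivalent from the config):

import functools

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

dist.init_process_group(backend="nccl")
policy = functools.partial(size_based_auto_wrap_policy, min_num_params=100_000_000)
model = torch.nn.Linear(8, 2).cuda()  # placeholder model
model = FSDP(model, auto_wrap_policy=policy)  # shards params, grads, and optimizer state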

Results and Checkpointing

Training results are automatically saved to:

result/[dataset_name]/[date]_[model_name]_[tag]/
├── config.yaml   # Complete configuration used
├── models/       # Model checkpoints (latest.pth, best.pth, epoch_N.pth)
├── optimizers/   # Optimizer states
└── schedulers/   # Scheduler states
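Saving and restoring these states amounts to standard PyTorch checkpointing; a minimal sketch (the file names here are illustrative, the template manages the actual paths under result/):

import torch
from torch import nn

model = nn.Linear(8, 2)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Save: one state_dict per component, as in models/ and optimizers/ above
torch.save(model.state_dict(), "latest.pth")
torch.save(optimizer.state_dict(), "optimizer_latest.pth")

# Restore: load the states back before continuing training
model.load_state_dict(torch.load("latest.pth"))
optimizer.load_state_dict(torch.load("optimizer_latest.pth"))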

Resume Training

Resume interrupted training using saved checkpoints:

# Resume from automatically saved checkpoint
python train.py result/dataset_name/20240108_ResNet_experiment/config.yaml
# Resume and extend training
python train.py result/dataset_name/20240108_ResNet_experiment/config.yaml epoch=200
# Resume with different configuration
python train.py result/dataset_name/20240108_ResNet_experiment/config.yaml gpu.use=1 batch=64

Evaluation

Run evaluation separately from training:

# Evaluate using saved model configuration
python test.py result/dataset_name/20240108_ResNet_experiment/config.yaml
# Evaluate with different GPU
python test.py result/dataset_name/20240108_ResNet_experiment/config.yaml gpu.use=1

Performance Optimization

RAM Caching

Speed up training by caching datasets in RAM:

python train.py config/dummy.yaml use_ram_cache=true ram_cache_size_gb=16

Implement caching in your custom dataset:

def __getitem__(self, idx):
    if self.cache is not None and idx in self.cache:
        data = self.cache.get(idx)
    else:
        data = self.load_data(idx)  # Your data loading logic
        if self.cache is not None:
            self.cache.set(idx, data)
    return data
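The cache object itself can be as simple as a dict with a size budget. A minimal sketch of the idea (the template ships its own cache implementation, which may differ):

import sys

class RamCache:
    def __init__(self, size_gb: float):
        self.limit_bytes = int(size_gb * 1024**3)
        self.used_bytes = 0
        self._store = {}

    def __contains__(self, key):
        return key in self._store

    def get(self, key):
        return self._store[key]

    def set(self, key, value):
        # Rough size estimate; tensors would need element_size() * nelement()
        size = sys.getsizeof(value)
        if self.used_bytes + size <= self.limit_bytes:
            self._store[key] = value
            self.used_bytes += size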

Mixed Precision Training

# Enable automatic mixed precision with fp16
python train.py config/dummy.yaml use_amp=true amp_dtype="fp16"
# Use bfloat16 for newer hardware (A100, H100)
python train.py config/dummy.yaml use_amp=true amp_dtype="bf16"

torch.compile

# Enable PyTorch 2.0 compilation for speedup
python train.py config/dummy.yaml use_compile=true compile_backend="inductor"
# Alternative backends
python train.py config/dummy.yaml use_compile=true compile_backend="aot_eager"

Slack Notifications

Get notified about training progress and errors:

# Training will automatically send notifications on completion/error
python train.py config/dummy.yaml
# Manual notification testing
uv run --frozen pytest tests/test_slack_notification.py -v
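If you want to post a message yourself using the SLACK_WEBHOOK_URL from .env, a minimal standard-library sketch (the template has its own notifier; this is only an illustration):

import json
import os
import urllib.request

payload = {
    "text": "Training finished",
    "username": os.getenv("SLACK_USERNAME", "Training Bot"),
}
req = urllib.request.Request(
    os.environ["SLACK_WEBHOOK_URL"],
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
urllib.request.urlopen(req)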

Architecture

Project Structure

src/
├── config/        # Configuration management with inheritance
├── dataloaders/   # Dataset and DataLoader implementations
├── models/        # Model definitions and backbones
│   ├── backbone/  # Pre-trained backbones (ResNet, Swin, etc.)
│   ├── layers/    # Custom layers and building blocks
│   └── losses/    # Loss function implementations
├── optimizer/     # Optimizer builders (including ScheduleFree)
├── scheduler/     # Learning rate schedulers
├── transform/     # Data preprocessing and augmentation
├── evaluator/     # Metrics and evaluation
├── runner/        # Training and testing loops
└── utils/         # Utilities (logger, registry, torch utils)
config/
├── __base__/      # Base configuration templates
└── *.yaml         # Experiment configurations
script/            # Utility scripts
├── run_*.sh       # Service startup scripts
├── show_*.py      # Visualization tools
└── aggregate_*.py # Result aggregation tools

Registry System

Components are registered using decorators for dynamic instantiation:

from src.models import MODEL_REGISTRY
from src.models import BaseModel  # assuming BaseModel is exported alongside the registry

@MODEL_REGISTRY.register()
class MyModel(BaseModel):
    def __init__(self, **kwargs):
        super().__init__()
        # Model implementation

# Custom name registration
@MODEL_REGISTRY.register("custom_name")
class AnotherModel(BaseModel):
    pass

Available registries (a lookup sketch follows the list):

  • MODEL_REGISTRY: Model architectures
  • DATASET_REGISTRY: Dataset implementations
  • TRANSFORM_REGISTRY: Data transformations
  • OPTIMIZER_REGISTRY: Optimizers
  • LR_SCHEDULER_REGISTRY: Learning rate schedulers
  • EVALUATOR_REGISTRY: Evaluation metrics
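Once registered, a component can be looked up by its registered name for dynamic instantiation. A minimal sketch (the accessor name get is an assumption for illustration, not taken from the template's API):

from src.models import MODEL_REGISTRY

# Hypothetical lookup by registered name; the actual accessor may differ.
model_cls = MODEL_REGISTRY.get("custom_name")
model = model_cls()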

Configuration System

The configuration system supports inheritance and modular composition:

# config/my_experiment.yaml
__base__: "__base__/config.yaml"

# Override specific values
batch: 64
optimizer:
  lr: 0.001

# Import specific sections
transform:
  __import__: "__base__/transform/imagenet.yaml"

Error Handling and Notifications

The template includes comprehensive error handling:

  • Automatic Slack notifications for training completion and errors
  • Graceful error recovery with detailed logging
  • Checkpoint preservation even during failures
  • Distributed training fault tolerance

Development

Testing

# Run all tests
uv run --frozen pytest
# Run specific test modules
uv run --frozen pytest tests/test_modules.py
uv run --frozen pytest tests/test_slack_notification.py -v
# Run with verbose output
uv run --frozen pytest -v

Code Quality

# Format code
uv run --frozen ruff format .
# Check code quality
uv run --frozen ruff check .
# Fix auto-fixable issues
uv run --frozen ruff check . --fix

Documentation

# Start documentation server with live reload
./script/run_docs.sh

Docker Development

# Build development image
./docker/build.sh
# Run commands in container
./docker/run.sh python train.py config/dummy.yaml
./docker/run.sh bash # Interactive shell
