# PyRecover

Author: Shaswat Gupta
Contact: Email
Project Repository: https://github.com/Shaswat-G/PyRecover

## Overview
PyRecover is a robust distributed checkpointing and job management system for multi-GPU SLURM workloads. It enables efficient, time-aware checkpointing to maximize cluster utilization and prevent loss of training progress.
## Key Features

- Distributed checkpointing for large models and multi-GPU jobs
- Time-aware job management to avoid job preemption and maximize resource usage
- Seamless SLURM integration for easy deployment on HPC clusters
- Fault-tolerant training with automatic resume from the latest checkpoint
- Support for Flash Attention and other advanced optimizations
- Flexible configuration for both single-node and multi-node jobs
- Comprehensive benchmarking and loss convergence tracking
- Open source under the MIT License
## Installation

```bash
# Clone the repository
git clone https://github.com/Shaswat-G/PyRecover
cd PyRecover

# Create and activate the conda environment
conda env create -f env.yml
conda activate pyrecover
```
### Installation with Flash Attention

Ensure the CUDA toolkit, a C++ compiler, and the Python development headers are installed. Then:

```bash
./setup_flashattention.sh
# or
pip install ".[flash-attention]"
```
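If the build completes, a quick import check can confirm that the extension is usable (this assumes the upstream flash_attn package name; adjust if your environment differs):

```bash
# Verify the Flash Attention package imports and print its version
python -c "import flash_attn; print(flash_attn.__version__)"
```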
## Project Structure

```
pyrecover/
├── train.py                      # Main training script
├── env.yml                       # Conda environment file
├── submit-training-simple.sh     # SLURM submission script
├── setup_flashattention.sh       # Flash Attention setup
├── tests/                        # Benchmark and test scripts
│   └── check_weights_equality.py # Checkpoint equality checker
└── ...                           # Other modules and utilities
```
## Quick Start

### Non-distributed Training

```bash
sbatch submit-training-simple.sh --exp_name=my_experiment
```

### Distributed Training

```bash
sbatch submit-training-simple.sh --distributed --exp_name=distributed_exp
```

### Resume from Checkpoint

```bash
sbatch submit-training-simple.sh --distributed --continue --use_torch_distributed_ckpt
```
## Command Line Arguments

The training script (train.py) accepts various arguments:

| Argument | Description | Default |
|---|---|---|
| --dataset | Path to parquet file with text data | /capstor/store/cscs/ethz/large-sc/datasets/train_data.parquet |
| --sequence-length | Maximum sequence length | 2048 |
| --batch-size | Batch size per GPU | 1 |
| --learning-rate | Learning rate | 1e-5 |
| --training-steps | Number of training steps | 1000 |
| --distributed | Enable distributed training | False |
| --model-dtype | Model precision (fp16/bf16/fp32/fp64) | "bf16" |
| --checkpoint-dir | Directory for checkpoints | "checkpoints/" |
| --checkpoint-frequency | Save a checkpoint every N steps | 10 |
| --resume-from-checkpoint | Path to checkpoint or "latest" | None |
| --profile | Enable profiling support for nsys | False |
| --experiment_name | Name of the experiment (used as checkpoint subfolder) | "default-exp" |
| --use-torch-distributed-ckpt | Use distributed checkpointing | False |
| --compile | Compile the model with torch.compile | False |
| --fused-optimizer | Use a fused optimizer | False |
| --use_flash_attention | Use Flash Attention in the model | False |
| --log-loss-to-csv | Log loss to a CSV for plots/comparison | False |
| --timeaware-checkpointing | Enable time-aware checkpointing | False |

For a complete list, run:

```bash
python train.py --help
```
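As an illustration of how the arguments combine, a short single-GPU run with frequent checkpoints and CSV loss logging might look like the following (the experiment name is a placeholder, and the boolean flags are assumed to be simple on/off switches):

```bash
# Short run: 100 steps, checkpoint every 10 steps, log loss to CSV
python train.py \
    --experiment_name=smoke-test \
    --sequence-length=2048 \
    --batch-size=1 \
    --training-steps=100 \
    --checkpoint-frequency=10 \
    --log-loss-to-csv
```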
## SLURM Submission Script

The script submit-training-simple.sh launches jobs on SLURM clusters. Key parameters:

| SLURM Parameter | Description |
|---|---|
| --nodes | Number of nodes to allocate |
| --ntasks-per-node | Tasks per node (typically 1 per GPU) |
| --gpus-per-node | GPUs to use per node |
| --time | Time limit for the job |
| --partition | SLURM partition to use |
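These parameters correspond to standard #SBATCH directives in the submission script's header. The sketch below only shows the general shape with placeholder values; it is not a copy of submit-training-simple.sh, whose defaults and launch command may differ:

```bash
#!/bin/bash
#SBATCH --nodes=2               # number of nodes to allocate
#SBATCH --ntasks-per-node=4     # one task per GPU is typical
#SBATCH --gpus-per-node=4       # GPUs to use on each node
#SBATCH --time=04:00:00         # walltime limit for the job
#SBATCH --partition=normal      # SLURM partition (cluster-specific)

# Launch one training process per task; "$@" forwards any extra
# train.py arguments passed at submission time (illustrative only).
srun python train.py --distributed "$@"
```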
## Checkpointing

- Vanilla Checkpointing: standard PyTorch checkpointing (default)
- Distributed Checkpointing: faster for large models (enable with --use_torch_distributed_ckpt)
- Time-Aware Checkpointing: add --timeaware-checkpointing to auto-save before the walltime ends (see the example below)
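The three modes compose with the flags already shown in the Quick Start. For example, a distributed job that resumes from its last checkpoint, uses distributed checkpointing, and saves again shortly before the walltime limit could be submitted like this (the experiment name is a placeholder):

```bash
# Resume a distributed run and save a final checkpoint before walltime expires
sbatch submit-training-simple.sh \
    --distributed \
    --continue \
    --use_torch_distributed_ckpt \
    --timeaware-checkpointing \
    --exp_name=long_running_exp
```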
## Benchmarks

- Check equality of weights between two checkpoints (see the usage example below):

  ```bash
  python check_weights_equality.py <checkpoint1> <checkpoint2> [--distributed] [--tolerance 1e-7] [--verbose]
  ```

- Loss convergence: add --log-loss-to-csv to log step-wise loss for analysis
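For example, to verify that a resumed run reproduced the weights of an uninterrupted baseline, the checker might be invoked as follows (the checkpoint paths are hypothetical):

```bash
# Compare two checkpoints element-wise within the given tolerance
python check_weights_equality.py \
    checkpoints/baseline-exp/step_100 \
    checkpoints/resumed-exp/step_100 \
    --tolerance 1e-7 --verbose
```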
## License

This project is licensed under the MIT License.

## Acknowledgments

Developed at ETH Zurich for robust, large-scale deep learning on HPC clusters.