Checkpoint Management
Purpose and Scope
This page documents the checkpoint management system in Templar, which saves and loads model, optimizer, scheduler, and momentum state during distributed training. Checkpointing is central to training resilience: it lets nodes recover from failures, synchronize with the network, and resume training from a previous state. For information about how checkpoint management interacts with blockchain commitments, see Chain Integration.
Checkpoint Structure and Contents
A Templar checkpoint contains all the state information needed to fully restore training:
```mermaid
classDiagram
    class Checkpoint {
        +Dict "model_state_dict"
        +Dict "optimizer_state_dict"
        +Dict "scheduler_state_dict"
        +Dict "momentum"
        +Int "start_window"
        +Int "current_window"
    }
    class ModelState {
        +Tensor "weight_tensors"
        +Tensor "bias_tensors"
    }
    class OptimizerState {
        +Dict "state"
        +List "param_groups"
    }
    class SchedulerState {
        +Int "last_epoch"
        +Float "base_lrs"
    }
    class Momentum {
        +Tensor "param_momentum_tensors"
    }
    Checkpoint --> ModelState : "contains"
    Checkpoint --> OptimizerState : "contains"
    Checkpoint --> SchedulerState : "contains"
    Checkpoint --> Momentum : "contains"
```
Sources: neurons/miner.py:265-282 , tests/test_checkpoints.py:40-54
Each component serves a specific purpose:
Component | Description |
---|---|
model_state_dict | Parameter tensors for the LLaMA model |
optimizer_state_dict | SGD optimizer state (step counts, parameter-specific states) |
scheduler_state_dict | Learning rate scheduler state (current epoch, base learning rates) |
momentum | Momentum tensors for gradient accumulation |
start_window | Training start window (for global step calculation) |
current_window | Window at which the checkpoint was saved |
All tensors in checkpoints are stored on CPU to ensure compatibility when loading across different devices.
Sources: src/tplr/comms.py:924-937
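As a rough illustration of the structure above, a checkpoint dictionary can be assembled as follows. This is a minimal sketch, not the actual `Comms` code: `build_checkpoint` is a hypothetical helper, and only the key names and the move-to-CPU behavior come from the documentation above.

```python
import torch  # the checkpoint is ultimately written with torch.save()

def build_checkpoint(model, optimizer, scheduler, momentum,
                     start_window: int, current_window: int) -> dict:
    """Assemble a checkpoint dictionary with tensors moved to CPU."""
    return {
        # CPU copies keep the file loadable on any device.
        "model_state_dict": {k: v.cpu().clone() for k, v in model.state_dict().items()},
        # The real system also moves nested optimizer tensors to CPU.
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "momentum": {k: v.cpu().clone() for k, v in momentum.items()},
        "start_window": start_window,
        "current_window": current_window,
    }

# Typical usage (filename pattern documented below):
# torch.save(build_checkpoint(...), "checkpoint-<step>-<uid>-v<version>.pt")
```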
Storage System
Templar uses Cloudflare R2 Storage as the primary checkpoint repository, with local filesystem caching for performance.
```mermaid
graph TD
    subgraph "Storage"
        R2["Cloudflare R2 Storage"]
        Local["Local Cache (/tmp)"]
    end
    subgraph "Comms.Checkpoint Methods"
        Save["save_checkpoint()"]
        Load["load_checkpoint()"]
        GetLatest["get_latest_checkpoint()"]
    end
    subgraph "Neurons"
        Miner["Miner"]
        Validator["Validator"]
    end
    Miner -->|"triggers save/load"| Save
    Validator -->|"triggers save/load"| Load
    Save -->|"writes to"| R2
    Save -->|"caches to"| Local
    Load -->|"requests"| GetLatest
    GetLatest -->|"queries"| R2
    GetLatest -->|"falls back to"| Local
```
Sources: src/tplr/comms.py:122-148 , neurons/miner.py:730-747 , neurons/validator.py:582-613
Checkpoint File Naming
Checkpoint files follow this naming convention:
`checkpoint-{global_step}-{uid}-v{version}.pt`

Where:

- `global_step`: Training step at which the checkpoint was saved
- `uid`: Unique identifier of the node that created the checkpoint
- `version`: Code version (from `tplr.__version__`)
This convention enables efficient filtering and retrieval of checkpoints by version, step, or node.
Sources: tests/test_checkpoints.py:83-87
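For illustration, composing and parsing a name in this convention might look like the snippet below; the helper names and the regular expression (which assumes a numeric `uid`) are hypothetical, and only the pattern itself comes from the convention above.

```python
import re

def checkpoint_filename(global_step: int, uid: int, version: str) -> str:
    """Build a checkpoint file name following the documented convention."""
    return f"checkpoint-{global_step}-{uid}-v{version}.pt"

# Parse a name back into its components (assumes a numeric uid).
CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)-(\d+)-v(.+)\.pt$")

def parse_checkpoint_filename(name: str):
    """Return (global_step, uid, version), or None if the name does not match."""
    m = CHECKPOINT_RE.match(name)
    if m is None:
        return None
    return int(m.group(1)), int(m.group(2)), m.group(3)

print(checkpoint_filename(1200, 42, "0.2.73"))
# checkpoint-1200-42-v0.2.73.pt
```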
Checkpoint Operations
Saving Checkpoints
Checkpoints are saved periodically during training, based on the `checkpoint_frequency` parameter in `hparams.json`.
```mermaid
sequenceDiagram
    participant Neuron as "Miner/Validator"
    participant Comms as "Comms"
    participant R2 as "R2 Storage"
    Note over Neuron: "Check if global_step % checkpoint_frequency == 0"
    Neuron->>Comms: "save_checkpoint(model, optimizer, scheduler, momentum, ...)"
    Comms->>Comms: "Create checkpoint dictionary"
    Comms->>Comms: "Move all tensors to CPU"
    Comms->>R2: "Upload checkpoint to R2 bucket"
    R2-->>Comms: "Upload confirmation"
    Comms-->>Neuron: "Checkpoint saved"
```
The checkpoint saving process:
- Creates a checkpoint dictionary containing all state components
- Ensures all tensors are moved to CPU for compatibility
- Saves the checkpoint to R2 storage with versioning information
- Handles large file uploads using multipart upload when necessary
Sources: neurons/miner.py:730-747 , src/tplr/comms.py:894-949
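A minimal sketch of how a neuron might trigger this path from its training loop is shown below. It assumes a `comms` object exposing the async `save_checkpoint()` method documented later on this page; the wrapper function, keyword-style call, and surrounding variables are illustrative.

```python
async def maybe_save_checkpoint(comms, model, optimizer, scheduler, momentum,
                                global_step, current_window, start_window,
                                checkpoint_frequency=100):
    """Save a checkpoint every `checkpoint_frequency` global steps."""
    if global_step % checkpoint_frequency != 0:
        return
    # save_checkpoint() builds the CPU-only checkpoint dictionary and uploads
    # it to the R2 bucket (multipart upload for large files).
    await comms.save_checkpoint(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        momentum=momentum,
        global_step=global_step,
        current_window=current_window,
        start_window=start_window,
    )
```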
Loading Checkpoints
Loading checkpoints is performed at node startup and involves several steps:
```mermaid
sequenceDiagram
    participant Neuron as "Miner/Validator"
    participant Comms as "Comms"
    participant R2 as "R2 Storage"
    Neuron->>Comms: "load_checkpoint(model, optimizer, scheduler, ...)"
    Comms->>R2: "Get latest checkpoint"
    R2-->>Comms: "Return checkpoint data"
    Comms->>Comms: "Move tensors to target device"
    Comms->>Comms: "Restore model state"
    Comms->>Comms: "Restore optimizer state"
    Comms->>Comms: "Restore scheduler state"
    Comms->>Comms: "Calculate window difference"
    alt "Catch-up needed"
        Comms->>Comms: "Apply catch-up updates"
    end
    Comms-->>Neuron: "Return (success, momentum, loaded_checkpoint_window, optimizer, scheduler)"
```
The checkpoint loading process:
- Retrieves the latest compatible checkpoint from R2 storage
- Moves tensors to the appropriate device (CPU or CUDA)
- Restores model, optimizer, and scheduler states
- Determines if catch-up is needed
- Applies catch-up updates if necessary
Sources: neurons/miner.py:273-316 , neurons/validator.py:582-613 , src/tplr/comms.py:955-1073
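The calling side can be sketched as follows, assuming the documented return tuple `(success, momentum, loaded_checkpoint_window, optimizer, scheduler)`; the wrapper name, keyword-style call, and fallback branch are illustrative.

```python
async def restore_or_init(comms, model, optimizer, scheduler, current_window,
                          device="cuda", init_version=None):
    """Restore training state from the latest checkpoint, or start fresh."""
    success, momentum, loaded_window, optimizer, scheduler = await comms.load_checkpoint(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        current_window=current_window,
        device=device,
        init_version=init_version,
    )
    if not success:
        # No compatible checkpoint found: initialize from scratch.
        momentum = {}
        loaded_window = current_window
    return momentum, loaded_window, optimizer, scheduler
```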
Catch-up Mechanism
The catch-up mechanism brings a model up to date when the loaded checkpoint was saved at an earlier window:
```mermaid
graph TD
    LC["load_checkpoint()"] --> CW["Calculate window_difference = current_window - checkpoint_window"]
    subgraph "Catch-up Process"
        BC["Batch windows into manageable chunks"]
        AO["Apply optimizer steps for each missing window"]
        AS["Apply scheduler steps for each missing window"]
        UG["Update global_step"]
    end
    CW -->|"if window_difference > 0"| BC
    BC --> AO
    AO --> AS
    AS --> UG
```
This ensures learning rates and optimizer states match current training progress when loading an older checkpoint.
Sources: neurons/miner.py:300-316 , tests/test_checkpoints.py:472-543
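In outline, the catch-up loop looks like the sketch below. The chunk size and function name are assumptions; the real code also applies optimizer updates for each missed window, while this sketch only advances the scheduler and the global step.

```python
def catch_up(scheduler, checkpoint_window: int, current_window: int,
             global_step: int, batch_size: int = 10) -> int:
    """Advance schedules for every window missed since the checkpoint was saved."""
    window_difference = current_window - checkpoint_window
    if window_difference <= 0:
        return global_step  # checkpoint is already current
    remaining = window_difference
    while remaining > 0:
        chunk = min(batch_size, remaining)  # process windows in manageable chunks
        for _ in range(chunk):
            # Advance the learning-rate schedule one step per missed window.
            scheduler.step()
        global_step += chunk
        remaining -= chunk
    return global_step
```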
Version Management
Templar’s checkpoint system handles version compatibility through:

- Version-specific checkpoint files (`v{version}.pt` suffix)
- Bootstrap version configuration (`checkpoint_init_version` in `hparams.json`)
- Fallback to local cache when compatible R2 versions are unavailable
During startup, miners and validators will attempt to load the latest checkpoint matching their current version. For initial setup, they use the configured bootstrap version.
Sources: neurons/miner.py:167-168 , neurons/validator.py:201-205 , hparams.json:52
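The selection logic can be sketched like this; the helper name and the `version=` keyword are assumptions, while `get_latest_checkpoint()` and `checkpoint_init_version` come from the documentation above.

```python
import tplr  # provides tplr.__version__

async def pick_checkpoint(comms, checkpoint_init_version: str):
    """Prefer a checkpoint for the running code version; fall back to the bootstrap version."""
    ckpt = await comms.get_latest_checkpoint(version=tplr.__version__)
    if ckpt is None:
        # First start-up, or no checkpoint published for this version yet:
        # use the bootstrap version configured in hparams.json.
        ckpt = await comms.get_latest_checkpoint(version=checkpoint_init_version)
    return ckpt
```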
Configuration
The checkpoint system is configured through hyperparameters:
Parameter | Description | Default |
---|---|---|
checkpoint_frequency | How often to save checkpoints (in global steps) | 100 |
checkpoint_init_version | Version to use for initial checkpoint loading | "0.2.73" |
Sources: hparams.json:31-52
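For illustration, the two values can be read from `hparams.json` as follows; the fallback defaults mirror the table above, while the loading code itself is just a sketch.

```python
import json

with open("hparams.json") as f:
    hparams = json.load(f)

# Fall back to the documented defaults if a key is absent.
checkpoint_frequency = hparams.get("checkpoint_frequency", 100)
checkpoint_init_version = hparams.get("checkpoint_init_version", "0.2.73")
```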
Implementation Details
Key Methods
The checkpoint management system is implemented in the `Comms` class with these core methods:
```mermaid
classDiagram
    class Comms {
        +async "save_checkpoint(model, optimizer, scheduler, momentum, global_step, current_window, start_window)"
        +async "load_checkpoint(model, optimizer, scheduler, current_window, device, init_version)"
        +async "get_latest_checkpoint(version)"
        -async "s3_put_object(key, file_path)"
        -async "s3_get_object(key, bucket, timeout)"
        -async "upload_large_file(file_path, key, s3_client)"
        -async "download_large_file(s3_client, bucket, key, file_size, temp_file_path)"
    }
```
Dedicated methods handle large file transfers (multipart uploads and chunked downloads), and all storage I/O is performed asynchronously.
Sources: src/tplr/comms.py:894-1073
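One plausible way to route between the plain put and the multipart path is sketched below; the 100 MiB threshold and the wrapper function are assumptions, and only the method names come from the class diagram above.

```python
import os

LARGE_FILE_THRESHOLD = 100 * 1024 * 1024  # illustrative cutoff; the real value may differ

async def put_checkpoint(comms, key: str, file_path: str, s3_client):
    """Route the upload by size: plain put for small files, multipart for large ones."""
    if os.path.getsize(file_path) > LARGE_FILE_THRESHOLD:
        await comms.upload_large_file(file_path, key, s3_client)
    else:
        await comms.s3_put_object(key, file_path)
```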
Error Handling
The checkpoint system includes robust error handling for:
- Network failures during upload/download operations
- Corrupted checkpoint files
- Version incompatibilities
- Missing checkpoint files
It implements:
- Retry logic with exponential backoff
- Local cache fallback
- Detailed error logging
- Graceful failure modes that won’t crash the application
Sources: src/tplr/comms.py:366-371 , src/tplr/comms.py:423-427
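A generic retry wrapper of the kind described above might look like this; it is a sketch of the pattern, not the `Comms` implementation, and the attempt count and delays are illustrative.

```python
import asyncio
import logging

logger = logging.getLogger("tplr.checkpoints")

async def with_retries(operation, attempts: int = 3, base_delay: float = 1.0):
    """Run an async operation, retrying with exponential backoff on failure."""
    for attempt in range(attempts):
        try:
            return await operation()
        except Exception as exc:  # network errors, timeouts, corrupted files
            if attempt == attempts - 1:
                logger.error("giving up after %d attempts: %s", attempts, exc)
                raise
            delay = base_delay * (2 ** attempt)  # 1s, 2s, 4s, ...
            logger.warning("attempt %d failed (%s); retrying in %.1fs",
                           attempt + 1, exc, delay)
            await asyncio.sleep(delay)

# Example: await with_retries(lambda: comms.get_latest_checkpoint(version))
```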
Usage Patterns
In Miner Nodes
```mermaid
graph TD
    MS["Miner.run()"] --> LC["Load latest checkpoint"]
    LC -->|"Success"| UC["Catch up if needed"]
    LC -->|"Failure"| IM["Initialize from scratch"]
    UC --> TR["Train for current window"]
    IM --> TR
    TR --> CF["Check if global_step % checkpoint_frequency == 0"]
    CF -->|"Yes"| SC["Save checkpoint"]
    CF -->|"No"| NW["Wait for next window"]
    SC --> NW
    NW --> TR
```
Sources: neurons/miner.py:267-317 , neurons/miner.py:730-747
In Validator Nodes
```mermaid
graph TD
    VS["Validator.run()"] --> LC["Load latest checkpoint"]
    LC -->|"Success"| UC["Catch up if needed"]
    LC -->|"Failure"| IM["Initialize from scratch"]
    UC --> AG["Aggregate/evaluate gradients"]
    IM --> AG
    AG --> CF["Check if global_step % checkpoint_frequency == 0"]
    CF -->|"Yes"| SC["Save checkpoint"]
    CF -->|"No"| NW["Process next window"]
    SC --> NW
    NW --> AG
```
Sources: neurons/validator.py:576-620 , neurons/validator.py:729-735
Evaluator Integration
The Evaluator service uses the checkpoint system to periodically load the latest model checkpoints and evaluate their performance on benchmarks. It maintains a record of the last evaluated window to prevent duplicate evaluations.
```mermaid
graph TD
    ES["Evaluator Service"] --> GLC["get_latest_checkpoint()"]
    GLC --> CW["Compare checkpoint window to last_eval_window"]
    CW -->|"window > last_eval_window"| LC["Load checkpoint"]
    LC --> EM["Evaluate model performance"]
    EM --> UL["Update last_eval_window"]
    CW -->|"window <= last_eval_window"| SK["Skip (already evaluated)"]
```
Sources: tests/test_evaluator.py:60-146
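The dedupe-by-window loop can be sketched as follows; `evaluate_model` is passed in as a hypothetical callback, and the exact return shape of `get_latest_checkpoint()` is an assumption noted in the comments.

```python
import tplr  # provides tplr.__version__

async def evaluate_latest(comms, last_eval_window: int, evaluate_model) -> int:
    """Evaluate the newest checkpoint at most once per window."""
    ckpt = await comms.get_latest_checkpoint(version=tplr.__version__)
    if ckpt is None:
        return last_eval_window  # nothing to evaluate yet
    # Assumption: the call returns the checkpoint dictionary (possibly inside a tuple).
    checkpoint_data = ckpt[0] if isinstance(ckpt, tuple) else ckpt
    window = checkpoint_data["current_window"]
    if window <= last_eval_window:
        return last_eval_window  # skip: this window was already evaluated
    evaluate_model(checkpoint_data["model_state_dict"])  # run the benchmark suite
    return window
```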
Troubleshooting
Common checkpoint-related issues and solutions:
Issue | Possible Causes | Solution |
---|---|---|
Checkpoint loading fails | Version mismatch, corrupted file | Check version compatibility, verify R2 access |
Catch-up process errors | Large window gap, memory issues | Reduce catch-up batch size, ensure sufficient memory |
Slow checkpoint saving | Large model size, network issues | Check network connectivity, monitor R2 performance |
Missing checkpoint | Process started for first time | Node will initialize from scratch |
Sources: src/tplr/comms.py:423-427 , neurons/miner.py:305-316
Summary
The checkpoint management system is a critical component of Templar that ensures training resilience and continuity. By periodically saving the complete training state and providing efficient loading mechanisms, it enables nodes to recover from failures, synchronize with the network, and maintain training progress in a distributed environment.
Through careful version management and the catch-up mechanism, the system ensures that nodes can join or rejoin training seamlessly, maintaining the integrity of the distributed training process.