Checkpoint Management

This page documents the checkpoint management system in Templar, which is responsible for saving and loading model states, optimizers, schedulers, and momentum values during distributed training. The system is crucial for training resilience: it enables nodes to recover from failures, synchronize with the network, and resume training from previous states. For information about how checkpoint management interacts with blockchain commitments, see Chain Integration.

A Templar checkpoint contains all the state information needed to fully restore training:

classDiagram
    class Checkpoint {
        +Dict "model_state_dict"
        +Dict "optimizer_state_dict"
        +Dict "scheduler_state_dict"
        +Dict "momentum"
        +Int "start_window"
        +Int "current_window"
    }
    
    class ModelState {
        +Tensor "weight_tensors"
        +Tensor "bias_tensors"
    }
    
    class OptimizerState {
        +Dict "state"
        +List "param_groups"
    }
    
    class SchedulerState {
        +Int "last_epoch"
        +Float "base_lrs"
    }
    
    class Momentum {
        +Tensor "param_momentum_tensors"
    }
    
    Checkpoint --> ModelState : "contains"
    Checkpoint --> OptimizerState : "contains"
    Checkpoint --> SchedulerState : "contains"
    Checkpoint --> Momentum : "contains"

Sources: neurons/miner.py:265-282 , tests/test_checkpoints.py:40-54

Each component serves a specific purpose:

| Component | Description |
|---|---|
| model_state_dict | Parameter tensors for the LLaMA model |
| optimizer_state_dict | SGD optimizer state (step counts, parameter-specific states) |
| scheduler_state_dict | Learning rate scheduler state (current epoch, base learning rates) |
| momentum | Momentum tensors for gradient accumulation |
| start_window | Training start window (used for global step calculation) |
| current_window | Window at which the checkpoint was saved |

All tensors in checkpoints are stored on CPU to ensure compatibility when loading across different devices.

Sources: src/tplr/comms.py:924-937
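
As a minimal sketch (not the exact Comms code), assembling such a checkpoint might look like the following, with a small helper that recursively moves tensors to CPU:

```python
import torch


def tensors_to_cpu(obj):
    """Recursively move any tensors inside nested dicts/lists to CPU."""
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {k: tensors_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [tensors_to_cpu(v) for v in obj]
    return obj


def build_checkpoint(model, optimizer, scheduler, momentum, start_window, current_window):
    """Illustrative helper: gather all training state into one checkpoint dict."""
    return {
        "model_state_dict": tensors_to_cpu(model.state_dict()),
        "optimizer_state_dict": tensors_to_cpu(optimizer.state_dict()),
        "scheduler_state_dict": scheduler.state_dict(),  # plain Python values
        "momentum": tensors_to_cpu(momentum),
        "start_window": start_window,
        "current_window": current_window,
    }
```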

Templar uses Cloudflare R2 Storage as the primary checkpoint repository, with local filesystem caching for performance.

graph TD
    subgraph "Storage"
        R2["Cloudflare R2 Storage"]
        Local["Local Cache (/tmp)"]
    end
    
    subgraph "Comms.Checkpoint Methods"
        Save["save_checkpoint()"]
        Load["load_checkpoint()"]
        GetLatest["get_latest_checkpoint()"]
    end
    
    subgraph "Neurons"
        Miner["Miner"]
        Validator["Validator"]
    end
    
    Miner -->|"triggers save/load"| Save
    Validator -->|"triggers save/load"| Load
    
    Save -->|"writes to"| R2
    Save -->|"caches to"| Local
    Load -->|"requests"| GetLatest
    GetLatest -->|"queries"| R2
    GetLatest -->|"falls back to"| Local

Sources: src/tplr/comms.py:122-148 , neurons/miner.py:730-747 , neurons/validator.py:582-613

Checkpoint files follow this naming convention:

checkpoint-{global_step}-{uid}-v{version}.pt

Where:

  • global_step: Training step at which the checkpoint was saved
  • uid: Unique identifier of the node that created the checkpoint
  • version: Code version (from tplr.__version__)

This convention enables efficient filtering and retrieval of checkpoints by version, step, or node.

Sources: tests/test_checkpoints.py:83-87
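
A small illustration of building and parsing these names (hypothetical helpers; in Templar the version string comes from tplr.__version__, and UIDs are assumed numeric here):

```python
import re

CHECKPOINT_NAME_RE = re.compile(
    r"checkpoint-(?P<global_step>\d+)-(?P<uid>\d+)-v(?P<version>[\w.]+)\.pt$"
)


def checkpoint_filename(global_step: int, uid: int, version: str) -> str:
    """Build a checkpoint object key following the documented convention."""
    return f"checkpoint-{global_step}-{uid}-v{version}.pt"


def parse_checkpoint_filename(name: str) -> dict | None:
    """Return the step, uid, and version encoded in a checkpoint filename."""
    match = CHECKPOINT_NAME_RE.match(name)
    return match.groupdict() if match else None


# Example: checkpoint_filename(1200, 42, "0.2.73") -> "checkpoint-1200-42-v0.2.73.pt"
```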

Checkpoints are saved periodically during training based on the checkpoint_frequency parameter in hparams.json.

sequenceDiagram
    participant Neuron as "Miner/Validator"
    participant Comms as "Comms"
    participant R2 as "R2 Storage"
    
    Note over Neuron: "Check if global_step % checkpoint_frequency == 0"
    
    Neuron->>Comms: "save_checkpoint(model, optimizer, scheduler, momentum, ...)"
    Comms->>Comms: "Create checkpoint dictionary"
    Comms->>Comms: "Move all tensors to CPU"
    Comms->>R2: "Upload checkpoint to R2 bucket"
    R2-->>Comms: "Upload confirmation"
    Comms-->>Neuron: "Checkpoint saved"

The checkpoint saving process:

  1. Creates a checkpoint dictionary containing all state components
  2. Ensures all tensors are moved to CPU for compatibility
  3. Saves the checkpoint to R2 storage with versioning information
  4. Handles large file uploads using multipart upload when necessary

Sources: neurons/miner.py:730-747 , src/tplr/comms.py:894-949
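
Sketched in code, assuming the s3_put_object helper shown later in the Comms class diagram and a hypothetical save_checkpoint_sketch wrapper, the save path looks roughly like this:

```python
import os
import tempfile

import torch


async def save_checkpoint_sketch(comms, model, optimizer, scheduler, momentum,
                                 global_step, current_window, start_window,
                                 uid, version):
    """Illustrative only: mirrors the four documented steps of Comms.save_checkpoint."""
    # 1-2. Build the checkpoint dictionary with all tensors moved to CPU.
    checkpoint = {
        "model_state_dict": {k: v.cpu() for k, v in model.state_dict().items()},
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": scheduler.state_dict(),
        "momentum": {k: v.cpu() for k, v in momentum.items()},
        "start_window": start_window,
        "current_window": current_window,
    }
    # 3. Serialize locally (also acts as the /tmp cache), then upload to R2.
    filename = f"checkpoint-{global_step}-{uid}-v{version}.pt"
    local_path = os.path.join(tempfile.gettempdir(), filename)
    torch.save(checkpoint, local_path)
    # 4. Comms handles multipart uploads internally for large files.
    await comms.s3_put_object(key=filename, file_path=local_path)
```

As shown in the sequence diagram, this path is only taken when global_step % checkpoint_frequency == 0.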

Loading checkpoints is performed at node startup and involves several steps:

sequenceDiagram
    participant Neuron as "Miner/Validator"
    participant Comms as "Comms"
    participant R2 as "R2 Storage"
    
    Neuron->>Comms: "load_checkpoint(model, optimizer, scheduler, ...)"
    Comms->>R2: "Get latest checkpoint"
    R2-->>Comms: "Return checkpoint data"
    
    Comms->>Comms: "Move tensors to target device"
    Comms->>Comms: "Restore model state"
    Comms->>Comms: "Restore optimizer state"
    Comms->>Comms: "Restore scheduler state"
    
    Comms->>Comms: "Calculate window difference"
    
    alt "Catch-up needed"
        Comms->>Comms: "Apply catch-up updates"
    end
    
    Comms-->>Neuron: "Return (success, momentum, loaded_checkpoint_window, optimizer, scheduler)"

The checkpoint loading process:

  1. Retrieves the latest compatible checkpoint from R2 storage
  2. Moves tensors to the appropriate device (CPU, CUDA)
  3. Restores model, optimizer, and scheduler states
  4. Determines if catch-up is needed
  5. Applies catch-up updates if necessary

Sources: neurons/miner.py:273-316 , neurons/validator.py:582-613 , src/tplr/comms.py:955-1073
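
The restore half can be sketched as follows (a hypothetical helper operating on an already-downloaded checkpoint dictionary; the real Comms.load_checkpoint also handles retrieval and catch-up):

```python
def restore_from_checkpoint(checkpoint, model, optimizer, scheduler, device, current_window):
    """Illustrative sketch of steps 2-4 of the loading process."""
    # Move model weights to the target device and restore them.
    model.load_state_dict(
        {k: v.to(device) for k, v in checkpoint["model_state_dict"].items()}
    )
    # Restore optimizer and scheduler state.
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    # Work out whether catch-up is needed.
    momentum = checkpoint["momentum"]
    checkpoint_window = checkpoint["current_window"]
    window_difference = current_window - checkpoint_window
    return momentum, checkpoint_window, window_difference
```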

The catch-up mechanism brings models up-to-date when loading checkpoints from earlier windows:

graph TD
    LC["load_checkpoint()"] --> CW["Calculate window_difference = current_window - checkpoint_window"]
    
    subgraph "Catch-up Process"
        BC["Batch windows into manageable chunks"]
        AO["Apply optimizer steps for each missing window"]
        AS["Apply scheduler steps for each missing window"]
        UG["Update global_step"]
    end
    
    CW -->|"if window_difference > 0"| BC
    BC --> AO
    AO --> AS
    AS --> UG

This ensures learning rates and optimizer states match current training progress when loading an older checkpoint.

Sources: neurons/miner.py:300-316 , tests/test_checkpoints.py:472-543
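
A hedged sketch of that loop (illustrative only; the real catch-up logic lives in neurons/miner.py):

```python
def catch_up_sketch(optimizer, scheduler, window_difference, chunk_size=10):
    """Advance optimizer/scheduler state for each missed window, in chunks."""
    steps_applied = 0
    remaining = max(window_difference, 0)
    while remaining > 0:
        chunk = min(chunk_size, remaining)   # batch windows into manageable chunks
        for _ in range(chunk):
            optimizer.step()                 # apply an optimizer step for the missing window
            scheduler.step()                 # keep the learning-rate schedule in sync
        steps_applied += chunk
        remaining -= chunk
    return steps_applied                     # used to update global_step
```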

Templar’s checkpoint system handles version compatibility through:

  1. Version-specific checkpoint files (v{version}.pt suffix)
  2. Bootstrap version configuration (checkpoint_init_version in hparams.json)
  3. Fallback to local cache when compatible R2 versions are unavailable

During startup, miners and validators will attempt to load the latest checkpoint matching their current version. For initial setup, they use the configured bootstrap version.

Sources: neurons/miner.py:167-168 , neurons/validator.py:201-205 , hparams.json:52
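
A sketch of how that fallback order might be expressed, assuming get_latest_checkpoint returns None when no checkpoint matches a version (the exact return shape is not shown here):

```python
async def find_compatible_checkpoint(comms, current_version, bootstrap_version):
    """Hypothetical helper: prefer the running code version, then the bootstrap version."""
    for version in (current_version, bootstrap_version):
        checkpoint = await comms.get_latest_checkpoint(version)
        if checkpoint is not None:
            return checkpoint, version
    return None, None   # no compatible checkpoint in R2; caller may try the local cache
```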

The checkpoint system is configured through hyperparameters:

| Parameter | Description | Default |
|---|---|---|
| checkpoint_frequency | How often to save checkpoints (in global steps) | 100 |
| checkpoint_init_version | Version to use for initial checkpoint loading | "0.2.73" |

Sources: hparams.json:31-52
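
Both values live in hparams.json and can be read directly; a minimal sketch:

```python
import json

with open("hparams.json") as f:
    hparams = json.load(f)

checkpoint_frequency = hparams["checkpoint_frequency"]        # e.g. 100
checkpoint_init_version = hparams["checkpoint_init_version"]  # e.g. "0.2.73"
```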

The checkpoint management system is implemented in the Comms class with these core methods:

classDiagram
    class Comms {
        +async "save_checkpoint(model, optimizer, scheduler, momentum, global_step, current_window, start_window)"
        +async "load_checkpoint(model, optimizer, scheduler, current_window, device, init_version)"
        +async "get_latest_checkpoint(version)"
        -async "s3_put_object(key, file_path)"
        -async "s3_get_object(key, bucket, timeout)"
        -async "upload_large_file(file_path, key, s3_client)"
        -async "download_large_file(s3_client, bucket, key, file_size, temp_file_path)"
    }

The system handles files of varying sizes, using dedicated methods for large file transfers and managing asynchronous I/O throughout.

Sources: src/tplr/comms.py:894-1073

The checkpoint system includes robust error handling for:

  • Network failures during upload/download operations
  • Corrupted checkpoint files
  • Version incompatibilities
  • Missing checkpoint files

It implements:

  • Retry logic with exponential backoff
  • Local cache fallback
  • Detailed error logging
  • Graceful failure modes that won’t crash the application

Sources: src/tplr/comms.py:366-371 , src/tplr/comms.py:423-427
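
As an illustration of the retry pattern (a generic sketch, not the exact code in src/tplr/comms.py):

```python
import asyncio
import random


async def with_retries(operation, max_attempts=5, base_delay=1.0):
    """Retry an async upload/download with exponential backoff and jitter."""
    for attempt in range(max_attempts):
        try:
            return await operation()
        except Exception as exc:          # network failure, timeout, corrupted download, ...
            if attempt == max_attempts - 1:
                # Graceful failure: log and return None so the caller can
                # fall back (e.g. to the local cache) instead of crashing.
                print(f"checkpoint transfer failed after {max_attempts} attempts: {exc}")
                return None
            await asyncio.sleep(base_delay * 2 ** attempt + random.uniform(0, 0.5))
```

The next diagram shows where checkpoint loading and saving sit in the miner's main loop: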

graph TD
    MS["Miner.run()"] --> LC["Load latest checkpoint"]
    LC -->|"Success"| UC["Catch up if needed"]
    LC -->|"Failure"| IM["Initialize from scratch"]
    UC --> TR["Train for current window"]
    IM --> TR
    TR --> CF["Check if global_step % checkpoint_frequency == 0"]
    CF -->|"Yes"| SC["Save checkpoint"]
    CF -->|"No"| NW["Wait for next window"]
    SC --> NW
    NW --> TR

Sources: neurons/miner.py:267-317 , neurons/miner.py:730-747
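
The validator follows the same lifecycle, but aggregates and evaluates gradients between checkpoint operations: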

graph TD
    VS["Validator.run()"] --> LC["Load latest checkpoint"]
    LC -->|"Success"| UC["Catch up if needed"]
    LC -->|"Failure"| IM["Initialize from scratch"]
    UC --> AG["Aggregate/evaluate gradients"]
    IM --> AG
    AG --> CF["Check if global_step % checkpoint_frequency == 0"]
    CF -->|"Yes"| SC["Save checkpoint"]
    CF -->|"No"| NW["Process next window"]
    SC --> NW
    NW --> AG

Sources: neurons/validator.py:576-620 , neurons/validator.py:729-735

The Evaluator service uses the checkpoint system to periodically load the latest model checkpoints and evaluate their performance on benchmarks. It maintains a record of the last evaluated window to prevent duplicate evaluations.

graph TD
    ES["Evaluator Service"] --> GLC["get_latest_checkpoint()"]
    GLC --> CW["Compare checkpoint window to last_eval_window"]
    CW -->|"window > last_eval_window"| LC["Load checkpoint"]
    LC --> EM["Evaluate model performance"]
    EM --> UL["Update last_eval_window"]
    CW -->|"window <= last_eval_window"| SK["Skip (already evaluated)"]

Sources: tests/test_evaluator.py:60-146
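
A sketch of that check, assuming get_latest_checkpoint returns a checkpoint dictionary (or None) and that evaluate_model is the evaluator's own benchmark routine:

```python
async def maybe_evaluate_latest(comms, evaluate_model, last_eval_window, version):
    """Hypothetical helper: evaluate only checkpoints from windows not yet seen."""
    checkpoint = await comms.get_latest_checkpoint(version)
    if checkpoint is None:
        return last_eval_window                        # nothing to evaluate yet
    if checkpoint["current_window"] <= last_eval_window:
        return last_eval_window                        # skip: already evaluated
    evaluate_model(checkpoint["model_state_dict"])     # run the benchmarks
    return checkpoint["current_window"]                # becomes the new last_eval_window
```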

Common checkpoint-related issues and solutions:

| Issue | Possible Causes | Solution |
|---|---|---|
| Checkpoint loading fails | Version mismatch, corrupted file | Check version compatibility, verify R2 access |
| Catch-up process errors | Large window gap, memory issues | Reduce catch-up batch size, ensure sufficient memory |
| Slow checkpoint saving | Large model size, network issues | Check network connectivity, monitor R2 performance |
| Missing checkpoint | Process started for first time | Node will initialize from scratch |

Sources: src/tplr/comms.py:423-427 , neurons/miner.py:305-316

The checkpoint management system is a critical component of Templar that ensures training resilience and continuity. By periodically saving complete training state and providing efficient loading mechanisms, it enables nodes to recover from failures, sync with the network, and maintain training progress in a distributed environment.

Through careful version management and the catch-up mechanism, the system ensures that nodes can join or rejoin training seamlessly, maintaining the integrity of the distributed training process.