thsolver.solver
The base solver implements the reusable training, testing, evaluation, and profiling loops.
- class Solver(FLAGS, is_master=True)[source]
A lightweight base class for PyTorch training and evaluation loops.
Subclasses implement the model, dataset, and step-specific logic, while this class handles logging, checkpointing, distributed training, and scheduling.
- get_dataset(flags)[source]
Returns a dataset and its collate function.
- Parameters:
flags – The dataset config node.
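As an illustration of the return contract, the following plain-Python sketch shows a get_dataset that returns a dataset together with its collate function. ToyDataset, collate_batch, and the flags.length field are hypothetical names introduced for this example and are not part of thsolver; a real subclass would more likely return a torch.utils.data.Dataset.

```python
from types import SimpleNamespace


class ToyDataset:
    """Hypothetical map-style dataset: supports len() and integer indexing."""

    def __init__(self, length):
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if not 0 <= idx < self.length:
            raise IndexError(idx)
        # Each sample is a dict, so a collated batch is also a dict.
        return {"x": float(idx), "label": idx % 2}


def collate_batch(samples):
    """Merges a list of sample dicts into one dict of lists."""
    return {key: [s[key] for s in samples] for key in samples[0]}


def get_dataset(flags):
    """Returns (dataset, collate_fn), mirroring the documented contract."""
    # `flags.length` is an assumed config field, used only for this sketch.
    return ToyDataset(flags.length), collate_batch


flags = SimpleNamespace(length=4)
dataset, collate_fn = get_dataset(flags)
batch = collate_fn([dataset[0], dataset[1]])
```

Returning the collate function alongside the dataset keeps the batching logic next to the sample format it depends on, so the dataloader builder does not need to know how samples are structured.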
- train_step(batch)[source]
Returns the outputs of one training step.
The returned dict should contain the key train/loss, which is the scalar loss used for backpropagation and optimizer updates.
- Parameters:
batch (dict) – One batch produced by the training dataloader.
- test_step(batch)[source]
Returns the outputs of one testing step.
- Parameters:
batch (dict) – One batch produced by the test dataloader.
- eval_step(batch)[source]
Evaluates the model on a batch.
- Parameters:
batch (dict) – One batch produced by the evaluation dataloader.
- result_callback(avg_tracker: AverageTracker, epoch)[source]
Runs custom logic after a test epoch finishes.
- Parameters:
avg_tracker (AverageTracker) – The epoch-level tracker.
epoch (int) – The current epoch number.
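To make the role of the epoch-level tracker more concrete, here is a small sketch of a result_callback that formats averaged test metrics. The AverageTracker API is not described in this section, so MeanTracker below is a hypothetical stand-in with assumed update() and average() methods, and the metric names test/loss and test/accu are also assumptions.

```python
from collections import defaultdict


class MeanTracker:
    """Hypothetical stand-in for AverageTracker; the real API may differ."""

    def __init__(self):
        self.sums = defaultdict(float)
        self.count = 0

    def update(self, outputs):
        # Accumulate every metric reported by one step.
        for key, value in outputs.items():
            self.sums[key] += float(value)
        self.count += 1

    def average(self):
        # Mean of each metric over all recorded steps.
        return {key: total / self.count for key, total in self.sums.items()}


def result_callback(avg_tracker, epoch):
    """Example hook: formats the averaged test metrics for one epoch."""
    metrics = avg_tracker.average()
    parts = [f"{key}={value:.3f}" for key, value in sorted(metrics.items())]
    return f"epoch {epoch}: " + ", ".join(parts)


tracker = MeanTracker()
tracker.update({"test/loss": 0.5, "test/accu": 0.75})
tracker.update({"test/loss": 0.25, "test/accu": 1.0})
summary = result_callback(tracker, epoch=3)
```

In a subclass, the same hook could instead write the averaged metrics to a file or compute a derived score once the test epoch finishes.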
- config_dataloader(disable_train_data=False)[source]
Builds the train and test dataloaders when enabled.
- Parameters:
disable_train_data (bool) – If True, skips the training dataloader.
- get_dataloader(flags)[source]
Builds one dataloader from a dataset config node.
- Parameters:
flags – The dataset config node.
- Returns:
The configured dataloader.
- Return type:
- config_optimizer()[source]
Builds the optimizer for the current model.
The base learning rate is scaled by world_size to match the effective global batch size in distributed training.
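The scaling described above can be sketched as simple arithmetic. This example assumes that "scaled by" means plain multiplication (the common linear scaling rule); scale_learning_rate is a hypothetical helper and not a thsolver function.

```python
def scale_learning_rate(base_lr, world_size):
    """Scales a per-process learning rate by the number of processes."""
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    return base_lr * world_size


# One process with batch size 32 sees 32 samples per step; four processes
# see 128 samples per step, so the learning rate grows by the same factor.
single_gpu_lr = scale_learning_rate(0.001, world_size=1)
four_gpu_lr = scale_learning_rate(0.001, world_size=4)
```

Under this assumption, a learning rate tuned for a single process can be left unchanged in the config when the job is launched on more devices.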
- configure_log(set_writer=True)[source]
Configures the log directory, checkpoint directory, and writers.
- Parameters:
set_writer (bool) – If True, creates the TensorBoard writer.
- train_epoch(epoch)[source]
Runs one full training epoch.
- Parameters:
epoch (int) – The current epoch number.
- test_epoch(epoch)[source]
Runs one full test epoch.
- Parameters:
epoch (int) – The current epoch number.
- eval_epoch(epoch)[source]
Runs one evaluation epoch in evaluate mode.
- Parameters:
epoch (int) – The current epoch number.
- save_best_checkpoint(tracker: AverageTracker, epoch: int)[source]
Saves the best-performing model according to SOLVER.best_val.
- Parameters:
tracker (AverageTracker) – The tracker holding averaged test metrics.
epoch (int) – The current epoch number.
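The selection logic behind a "best" checkpoint might look like the sketch below. The format of SOLVER.best_val is not described in this section; the example assumes a spec of the form "min:loss" or "max:accu" (a comparison mode followed by a metric name) and assumes that test metrics are stored under a test/ prefix. The helper is_new_best is hypothetical.

```python
def is_new_best(best_val, metrics, best_so_far):
    """Returns (improved, value) for an assumed 'mode:metric' spec."""
    mode, name = best_val.split(":")
    if mode not in ("min", "max"):
        raise ValueError(f"unknown mode: {mode}")
    value = metrics["test/" + name]
    if best_so_far is None:
        # The first epoch always counts as the best seen so far.
        return True, value
    improved = value < best_so_far if mode == "min" else value > best_so_far
    return improved, value


best = None
saved_epochs = []
history = [{"test/loss": 0.9}, {"test/loss": 0.7}, {"test/loss": 0.8}]
for epoch, metrics in enumerate(history, start=1):
    improved, value = is_new_best("min:loss", metrics, best)
    if improved:
        best = value
        saved_epochs.append(epoch)  # a real solver would write the model here
```

With the loss history above, only the first two epochs improve on the running minimum, so a checkpoint would be written at epochs 1 and 2 but not at epoch 3.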
- save_checkpoint(epoch)[source]
Saves a training checkpoint for the given epoch.
- Parameters:
epoch (int) – The epoch number used in the checkpoint filename.
- load_checkpoint()[source]
Loads the requested checkpoint or the latest checkpoint defined in SOLVER.ckpt.
- profile()[source]
Profiles a few training iterations with the PyTorch profiler.
Set DATA.train.num_workers to 0 when using this function.
- static get_world_size(FLAGS)[source]
Returns the distributed world size implied by the launch mode.
- Parameters:
FLAGS – The experiment config tree.
- classmethod update_configs()[source]
Updates class-specific configs before parsing command-line arguments.