thsolver.solver

The base solver implements the reusable training, testing, evaluation, and profiling loops.

class Solver(FLAGS, is_master=True)[source]

A lightweight base class for PyTorch training and evaluation loops.

Subclasses implement the model, dataset, and step-specific logic, while this class handles logging, checkpointing, distributed training, and scheduling.
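The split between subclass hooks and base-class loops resembles a template-method pattern. The sketch below is a pure-Python stand-in, not thsolver code: MiniSolver, MeanSquareSolver, and the plain lists used as batches are hypothetical names and simplifications, and scalar floats take the place of tensors.

```python
class MiniSolver:
    """Hypothetical stand-in for a base solver; not the thsolver API."""

    def train_step(self, batch):
        # Subclasses supply the step-specific logic.
        raise NotImplementedError

    def train_epoch(self, batches):
        # The base class owns the loop and the bookkeeping.
        totals, count = {}, 0
        for batch in batches:
            output = self.train_step(batch)
            for key, value in output.items():
                totals[key] = totals.get(key, 0.0) + float(value)
            count += 1
        if count == 0:
            return {}
        # Average every logged value over the epoch.
        return {key: total / count for key, total in totals.items()}


class MeanSquareSolver(MiniSolver):
    """Hypothetical subclass: the 'loss' is the mean square of the batch."""

    def train_step(self, batch):
        loss = sum(x * x for x in batch) / len(batch)
        # The output dict carries the loss under the key "train/loss".
        return {"train/loss": loss}


solver = MeanSquareSolver()
averages = solver.train_epoch([[1.0, 1.0], [2.0, 2.0]])
print(averages)
```

In a real subclass the hooks would return PyTorch objects, and the base class would additionally take care of logging, checkpointing, and scheduling.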

get_model()[source]

Returns the model used by the current experiment.

get_dataset(flags)[source]

Returns a dataset and its collate function.

Parameters:

flags – The dataset config node.

train_step(batch)[source]

Returns the outputs of one training step.

The returned dict should contain the key train/loss, whose value is the scalar loss used for backpropagation and the optimizer update.

Parameters:

batch (dict) – One batch produced by the training dataloader.

test_step(batch)[source]

Returns the outputs of one testing step.

Parameters:

batch (dict) – One batch produced by the test dataloader.

eval_step(batch)[source]

Evaluates the model on a batch.

Parameters:

batch (dict) – One batch produced by the evaluation dataloader.

result_callback(avg_tracker: AverageTracker, epoch)[source]

Runs custom logic after a test epoch finishes.

Parameters:
  • avg_tracker (AverageTracker) – The epoch-level tracker.

  • epoch (int) – The current epoch number.

config_dataloader(disable_train_data=False)[source]

Builds the train and test dataloaders when enabled.

Parameters:

disable_train_data (bool) – If True, skips the training dataloader.

get_dataloader(flags)[source]

Builds one dataloader from a dataset config node.

Parameters:

flags – The dataset config node.

Returns:

The configured dataloader.

Return type:

torch.utils.data.DataLoader

config_model()[source]

Builds the model, moves it to the CUDA device, and wraps it in DistributedDataParallel (DDP) if needed.

config_optimizer()[source]

Builds the optimizer for the current model.

The base learning rate is scaled by world_size to match the effective global batch size in distributed training.
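The arithmetic behind this scaling can be sketched as follows, assuming the scaling is a plain multiplication by world_size (the linear scaling rule). The function name scale_learning_rate and the numeric values are hypothetical and chosen only for illustration.

```python
def scale_learning_rate(base_lr: float, world_size: int) -> float:
    """Linear scaling: multiply the base rate by the number of processes."""
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    return base_lr * world_size


per_process_batch = 32  # hypothetical per-process batch size
world_size = 4          # hypothetical number of processes

# Each process sees its own batch, so the effective batch grows with world_size.
global_batch = per_process_batch * world_size
lr = scale_learning_rate(0.001, world_size)
print(global_batch, lr)
```

With a single process, world_size is 1 and the base learning rate is used unchanged.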

config_lr_scheduler()[source]

Builds the learning-rate scheduler for the current optimizer.

configure_log(set_writer=True)[source]

Configures the log directory, checkpoint directory, and writers.

Parameters:

set_writer (bool) – If True, creates the TensorBoard writer.

train_epoch(epoch)[source]

Runs one full training epoch.

Parameters:

epoch (int) – The current epoch number.

test_epoch(epoch)[source]

Runs one full test epoch.

Parameters:

epoch (int) – The current epoch number.

eval_epoch(epoch)[source]

Runs one evaluation epoch in evaluate mode.

Parameters:

epoch (int) – The current epoch number.

save_best_checkpoint(tracker: AverageTracker, epoch: int)[source]

Saves the best-performing model according to SOLVER.best_val.

Parameters:
  • tracker (AverageTracker) – The tracker holding averaged test metrics.

  • epoch (int) – The current epoch number.
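A minimal sketch of best-value tracking follows. The format of SOLVER.best_val is not stated here, so the code assumes a hypothetical "mode:key" string such as "min:loss" and a "test/" prefix on metric names. The helper is_improvement and the metric history are illustrative only.

```python
def is_improvement(best_val_spec, averages, best_so_far):
    """Return (improved, value) for an assumed 'mode:key' spec like 'min:loss'."""
    mode, key = best_val_spec.split(":")
    if mode not in ("min", "max"):
        raise ValueError(f"unsupported mode: {mode!r}")
    value = averages["test/" + key]  # assumed metric naming
    if best_so_far is None:
        # The first measured value is always the best so far.
        return True, value
    if mode == "min":
        return value < best_so_far, value
    return value > best_so_far, value


best = None
saved_epochs = []
history = [{"test/loss": 0.9}, {"test/loss": 0.7}, {"test/loss": 0.8}]
for epoch, averages in enumerate(history, start=1):
    improved, value = is_improvement("min:loss", averages, best)
    if improved:
        best = value
        # A real solver would write the best-model checkpoint here.
        saved_epochs.append(epoch)
print(saved_epochs, best)
```

Only epochs that improve on the running best trigger a save, so the stored model always corresponds to the best value seen so far.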

save_checkpoint(epoch)[source]

Saves a training checkpoint for the given epoch.

Parameters:

epoch (int) – The epoch number used in the checkpoint filename.

load_checkpoint()[source]

Loads the checkpoint requested in SOLVER.ckpt, or the latest available checkpoint if none is requested.
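One way to choose between a requested checkpoint and the latest one is sketched below. The helper pick_checkpoint and the zero-padded "<epoch>.model.pth" naming scheme are assumptions made for illustration; the actual filename layout is not documented here.

```python
import re


def pick_checkpoint(requested, available):
    """Return `requested` if set, else the file with the highest epoch number."""
    if requested:
        return requested
    pattern = re.compile(r"^(\d+)\.model\.pth$")  # hypothetical naming scheme
    numbered = []
    for name in available:
        match = pattern.match(name)
        if match:
            numbered.append((int(match.group(1)), name))
    if not numbered:
        # Nothing to resume from.
        return None
    # Tuples compare by epoch number first, so max() picks the latest epoch.
    return max(numbered)[1]


files = ["00010.model.pth", "00200.model.pth", "00030.model.pth", "notes.txt"]
print(pick_checkpoint("", files))
print(pick_checkpoint("custom.pth", files))
```

Parsing the epoch as an integer avoids relying on lexicographic order, which would break if filenames were not zero-padded.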

manual_seed()[source]

Sets random seeds when SOLVER.rand_seed is positive.
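The conditional seeding can be illustrated with the standard library alone. This sketch seeds only Python's random module; a PyTorch solver would presumably also seed NumPy and torch, which is omitted here so the example stays self-contained. The Boolean return value is an addition for illustration.

```python
import random


def manual_seed(rand_seed: int) -> bool:
    """Seed Python's RNG when rand_seed is positive; report whether it did."""
    if rand_seed > 0:
        random.seed(rand_seed)
        return True
    # Zero or negative values leave the generator unseeded.
    return False


manual_seed(42)
first = [random.random() for _ in range(3)]
manual_seed(42)
second = [random.random() for _ in range(3)]
print(first == second)
```

Re-seeding with the same positive value reproduces the same sequence, while a non-positive value leaves runs non-deterministic.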

train()[source]

Runs the end-to-end training workflow.

test()[source]

Loads a checkpoint and runs the test loop once.

evaluate()[source]

Loads a checkpoint and runs the evaluation loop.

profile()[source]

Profiles a few training iterations with the PyTorch profiler.

Set DATA.train.num_workers to 0 when using this function.

run()[source]

Dispatches to the run mode configured in SOLVER.run.
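Dispatching on a mode string can be sketched with getattr. ModeDispatcher is a hypothetical stand-in, and the set of accepted values (train, test, evaluate, profile) is an assumption based on the methods documented on this page, not a statement of what SOLVER.run accepts.

```python
class ModeDispatcher:
    """Hypothetical stand-in showing string-to-method dispatch."""

    ALLOWED = ("train", "test", "evaluate", "profile")  # assumed modes

    def __init__(self, run_mode):
        self.run_mode = run_mode
        self.calls = []  # records which workflow ran, for demonstration

    def train(self):
        self.calls.append("train")

    def test(self):
        self.calls.append("test")

    def evaluate(self):
        self.calls.append("evaluate")

    def profile(self):
        self.calls.append("profile")

    def run(self):
        # Reject unknown modes instead of calling arbitrary attributes.
        if self.run_mode not in self.ALLOWED:
            raise ValueError(f"unknown run mode: {self.run_mode!r}")
        getattr(self, self.run_mode)()


dispatcher = ModeDispatcher("test")
dispatcher.run()
print(dispatcher.calls)
```

Checking the mode against an explicit whitelist keeps a mistyped config value from silently invoking an unrelated method.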

static get_world_size(FLAGS)[source]

Returns the distributed world size implied by the launch mode.

Parameters:

FLAGS – The experiment config tree.

classmethod update_configs()[source]

Updates class-specific configs before parsing command-line arguments.

classmethod worker(rank, FLAGS)[source]

Runs one solver worker process.

Parameters:
  • rank (int) – The process rank in the current launch.

  • FLAGS – The experiment config tree.

classmethod main()[source]

Parses configs and launches the solver with the configured DDP mode.