Getting Started
The normal workflow is:
1. Define a model factory and optionally register it with thsolver.registry.register_model().
2. Define a dataset factory and optionally register it with thsolver.registry.register_dataset().
3. Subclass thsolver.solver.Solver.
4. Implement the hooks for your task.
5. Launch the run with thsolver.solver.Solver.main().
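Steps 1 and 2 follow a common registry pattern: factories are stored under a string name and later looked up by that name, typically from a config value. The sketch below is a standalone illustration of the pattern, not thsolver's implementation. The Registry class, its register and build methods, and the make_mlp factory are hypothetical, and the real thsolver.registry.register_model() and thsolver.registry.register_dataset() signatures may differ.

```python
from typing import Any, Callable, Dict

Factory = Callable[..., Any]


class Registry:
    """Hypothetical name -> factory table (illustration only, not thsolver's code)."""

    def __init__(self, kind: str) -> None:
        self.kind = kind
        self._factories: Dict[str, Factory] = {}

    def register(self, name: str) -> Callable[[Factory], Factory]:
        """Return a decorator that stores a factory under `name`."""
        def decorator(factory: Factory) -> Factory:
            if name in self._factories:
                raise KeyError(f"{self.kind} {name!r} is already registered")
            self._factories[name] = factory
            return factory  # returned unchanged so it stays callable directly
        return decorator

    def build(self, name: str, **kwargs: Any) -> Any:
        """Call the factory registered under `name` with keyword arguments."""
        if name not in self._factories:
            known = sorted(self._factories)
            raise KeyError(f"unknown {self.kind} {name!r}; registered: {known}")
        return self._factories[name](**kwargs)


MODELS = Registry("model")


@MODELS.register("mlp")
def make_mlp(hidden: int = 64) -> Dict[str, Any]:
    # A real factory would return a torch.nn.Module; a dict keeps the sketch dependency-free.
    return {"arch": "mlp", "hidden": hidden}


model = MODELS.build("mlp", hidden=128)
print(model)  # {'arch': 'mlp', 'hidden': 128}
```

Raising on a duplicate name catches two factories silently claiming the same key, which is easy to do when registration runs at import time.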
Minimal Solver Skeleton
The base solver owns the training loop. Your subclass supplies the pieces that are task-specific.
import torch
from thsolver import Solver, build_dataset, build_model


class DemoSolver(Solver):

    def get_model(self, flags):
        return build_model(flags)

    def get_dataset(self, flags):
        return build_dataset(flags)

    def train_step(self, batch):
        output = self.model(batch)
        loss = output["loss"]
        return {
            "train/loss": loss,
            "train/metric": output["metric"],
        }

    @torch.no_grad()
    def test_step(self, batch):
        output = self.model(batch)
        return {
            "test/loss": output["loss"],
            "test/metric": output["metric"],
        }

    def eval_step(self, batch):
        output = self.model(batch)
        # write predictions or accumulate task-specific results here


if __name__ == "__main__":
    DemoSolver.main()
Note
The example above is intentionally schematic. thsolver does not impose a
fixed batch structure or model signature. Your dataloader decides what a
batch looks like, and your hooks decide how to move it to CUDA, run the
model, and compute losses or metrics.
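Since each hook moves its own batch to the GPU, a small recursive helper can keep that logic in one place. The sketch below assumes a batch is built from nested dicts, lists, and tuples whose tensors expose a .to(device) method, as torch.Tensor does. to_device is a hypothetical helper, not part of thsolver; duck typing on .to() lets the sketch run without importing torch.

```python
from typing import Any


def to_device(batch: Any, device: Any) -> Any:
    """Recursively move everything with a .to() method (e.g. torch.Tensor) to `device`.

    Dicts, lists, and tuples are rebuilt with moved contents; any other value
    (ints, strings, file names, ...) is returned unchanged.
    """
    if hasattr(batch, "to"):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, list):
        return [to_device(item, device) for item in batch]
    if isinstance(batch, tuple):
        return tuple(to_device(item, device) for item in batch)
    return batch


# Possible use inside a hook (the "cuda" device string is an assumption):
#     def train_step(self, batch):
#         batch = to_device(batch, "cuda")
#         output = self.model(batch)
```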
Required Hooks
The following methods are expected in most projects:
thsolver.solver.Solver.get_model() returns a torch.nn.Module.
thsolver.solver.Solver.get_dataset() returns (dataset, collate_fn).
thsolver.solver.Solver.train_step() must return a dict containing the key "train/loss".
thsolver.solver.Solver.test_step() returns test losses or metrics that should be averaged over the epoch.
thsolver.solver.Solver.eval_step() is used for evaluation-only jobs such as dumping predictions.
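The output contracts for train_step and test_step can be sketched in plain Python. This assumes a test_step-style hook returns a flat dict of scalar values per batch and that the epoch value is a plain mean over batches; thsolver's own reduction (for example, weighting by batch size or reducing across DDP ranks) may differ. check_train_output and average_step_outputs are hypothetical names.

```python
from collections import defaultdict
from typing import Dict, Iterable, Mapping


def check_train_output(step_output: Mapping[str, float]) -> None:
    """Enforce the documented contract: train_step returns a "train/loss" entry."""
    if "train/loss" not in step_output:
        raise KeyError('train_step must return a dict containing "train/loss"')


def average_step_outputs(outputs: Iterable[Mapping[str, float]]) -> Dict[str, float]:
    """Plain mean of each key over the per-batch dicts from a test_step-style hook.

    A key missing from some batches is averaged over the batches that report it.
    float() also accepts 0-dim tensors, so scalar torch values work unchanged.
    """
    totals: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for step_output in outputs:
        for key, value in step_output.items():
            totals[key] += float(value)
            counts[key] += 1
    return {key: totals[key] / counts[key] for key in totals}


per_batch = [
    {"test/loss": 0.5, "test/metric": 0.75},
    {"test/loss": 0.25, "test/metric": 1.0},
]
epoch_result = average_step_outputs(per_batch)
print(epoch_result)  # {'test/loss': 0.375, 'test/metric': 0.875}
```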
Run Modes
SOLVER.run selects which entry point the base class will execute:
train runs the full training loop.
test loads a checkpoint and evaluates on the test loader.
evaluate calls thsolver.solver.Solver.eval_step() for evaluation jobs whose useful output is a side effect, such as writing predictions to disk.
profile runs a short PyTorch profiler session on the training step.
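Because SOLVER.run is a plain string, one way to turn it into a call is an explicit table from mode name to method, which also gives a clear error for a mistyped mode. The sketch below is hypothetical: ModeDispatcher and its start method only mirror the four modes listed above and are not thsolver's code.

```python
from typing import Callable, Dict, List


class ModeDispatcher:
    """Hypothetical sketch: map a run-mode string, like SOLVER.run, to a method."""

    def __init__(self, run: str) -> None:
        self.run = run
        self.calls: List[str] = []  # records which entry point ran, for illustration

    def train(self) -> None:
        self.calls.append("train")

    def test(self) -> None:
        self.calls.append("test")

    def evaluate(self) -> None:
        self.calls.append("evaluate")

    def profile(self) -> None:
        self.calls.append("profile")

    def start(self) -> None:
        # An explicit table rejects typos instead of calling an arbitrary attribute.
        entry_points: Dict[str, Callable[[], None]] = {
            "train": self.train,
            "test": self.test,
            "evaluate": self.evaluate,
            "profile": self.profile,
        }
        if self.run not in entry_points:
            raise ValueError(
                f"unknown run mode {self.run!r}; expected one of {sorted(entry_points)}")
        entry_points[self.run]()


dispatcher = ModeDispatcher("evaluate")
dispatcher.start()
print(dispatcher.calls)  # ['evaluate']
```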
What the Base Class Handles
Once the hooks are in place, thsolver.solver.Solver takes care of:
dataloader creation with infinite samplers
optimizer and learning-rate scheduler setup
mixed precision with none, fp16, or bf16
checkpoint save and restore
TensorBoard logging and CSV logging
best-checkpoint tracking through SOLVER.best_val
single-node spawn-based DDP and torchrun-based launches
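An infinite sampler lets the training loop count iterations rather than epochs: it keeps producing indices, one full pass after another. The sketch below shows the idea in plain Python. infinite_indices is a hypothetical name, and the seeding and per-pass reshuffling are assumptions about reasonable behavior, not a description of thsolver's sampler.

```python
import random
from itertools import islice
from typing import Iterator, Optional


def infinite_indices(num_samples: int, shuffle: bool = True,
                     seed: Optional[int] = None) -> Iterator[int]:
    """Yield dataset indices forever, reshuffling at the start of every pass.

    Each pass visits every index exactly once, so epoch boundaries stay well
    defined even though the iterator never raises StopIteration.
    """
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")
    rng = random.Random(seed)  # private RNG so the order is reproducible per seed
    order = list(range(num_samples))
    while True:
        if shuffle:
            rng.shuffle(order)
        yield from order


# Draw three passes' worth of indices from a 4-sample dataset.
stream = infinite_indices(4, seed=0)
three_passes = list(islice(stream, 12))
```

To use this with PyTorch, one option is a torch.utils.data.Sampler subclass whose __iter__ returns this generator, passed as sampler= to torch.utils.data.DataLoader; the training loop can then keep calling next() on a single loader iterator across epochs. thsolver's own sampler may differ, for example in how it shards indices across DDP ranks.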