# Source code for thsolver.solver

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import os
import torch
import torch.nn
import torch.optim
import torch.distributed
import torch.multiprocessing
import torch.utils.data
import random
import numpy as np
from tqdm import tqdm
from packaging import version
from torch.utils.tensorboard import SummaryWriter

from .sampler import InfSampler, DistributedInfSampler
from .tracker import AverageTracker
from .config import parse_args
from .lr_scheduler import get_lr_scheduler


class Solver:
    r''' A lightweight base class for PyTorch training and evaluation loops.

    Subclasses implement the model, dataset, and step-specific logic, while
    this class handles logging, checkpointing, distributed training, and
    scheduling.
    '''

    def __init__(self, FLAGS, is_master=True):
        r''' Initializes the solver runtime state.

        Args:
            FLAGS: The experiment config tree.
            is_master (bool): If True, enables logging and checkpoint writing.
        '''
        self.FLAGS = FLAGS
        self.is_master = is_master
        self.world_size = self.get_world_size(FLAGS)
        self.device = torch.cuda.current_device()
        self.disable_tqdm = not (is_master and FLAGS.SOLVER.progress_bar)
        self.start_epoch = 1
        self.amp_mode = FLAGS.SOLVER.amp_mode

        self.model = None           # torch.nn.Module
        self.optimizer = None       # torch.optim.Optimizer
        self.scheduler = None       # torch.optim.lr_scheduler._LRScheduler
        self.summary_writer = None  # torch.utils.tensorboard.SummaryWriter
        self.log_file = None        # str, used to save training logs
        self.eval_rst = dict()      # used to save evaluation results
        self.best_val = None        # used to save the best validation result

        # The gradient scaler is only needed for fp16 mixed precision; the
        # constructor that is used depends on the installed torch version.
        newer_than_230 = (
            version.parse(torch.__version__) > version.parse('2.3.0'))
        if self.amp_mode == 'fp16':
            if newer_than_230:
                self.scaler = torch.GradScaler()
            else:
                self.scaler = torch.cuda.amp.GradScaler()
        else:
            self.scaler = None

    def get_model(self, flags=None):
        r''' Returns the model used by the current experiment.

        Args:
            flags: The model config node. :func:`config_model` passes
                ``FLAGS.MODEL`` here, so the base stub accepts it as well.

        Raises:
            NotImplementedError: Always; subclasses must override this.
        '''
        raise NotImplementedError

    def get_dataset(self, flags):
        r''' Returns a dataset and its collate function.

        Args:
            flags: The dataset config node.
        '''
        raise NotImplementedError

    def train_step(self, batch):
        r''' Returns the outputs of one training step.

        The returned dict should contain the key ``train/loss``, which is the
        scalar loss used for back propagation and optimizer updates.

        Args:
            batch (dict): One batch produced by the training dataloader.
        '''
        raise NotImplementedError

    def test_step(self, batch):
        r''' Returns the outputs of one testing step.

        Args:
            batch (dict): One batch produced by the test dataloader.
        '''
        raise NotImplementedError

    def eval_step(self, batch):
        r''' Evaluates the model on a batch.

        Args:
            batch (dict): One batch produced by the evaluation dataloader.
        '''
        raise NotImplementedError

    def result_callback(self, avg_tracker: 'AverageTracker', epoch):
        r''' Runs custom logic after a test epoch finishes.

        Args:
            avg_tracker (AverageTracker): The epoch-level tracker.
            epoch (int): The current epoch number.
        '''
        pass  # additional operations based on the avg_tracker

    def config_dataloader(self, disable_train_data=False):
        r''' Builds the train and test dataloaders when enabled.

        Each loader that is built also gets a persistent iterator
        (``train_iter`` / ``test_iter``) which the epoch loops pull from.

        Args:
            disable_train_data (bool): If True, skips the training dataloader.
        '''
        flags_train = self.FLAGS.DATA.train
        flags_test = self.FLAGS.DATA.test
        if not disable_train_data and not flags_train.disable:
            self.train_loader = self.get_dataloader(flags_train)
            self.train_iter = iter(self.train_loader)
        if not flags_test.disable:
            self.test_loader = self.get_dataloader(flags_test)
            self.test_iter = iter(self.test_loader)

    def get_dataloader(self, flags):
        r''' Builds one dataloader from a dataset config node.

        Args:
            flags: The dataset config node.

        Returns:
            torch.utils.data.DataLoader: The configured dataloader.
        '''
        dataset, collate_fn = self.get_dataset(flags)

        # Pick the distributed sampler only when several processes are used.
        if self.world_size > 1:
            sampler = DistributedInfSampler(dataset, shuffle=flags.shuffle)
        else:
            sampler = InfSampler(dataset, shuffle=flags.shuffle)

        data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=flags.batch_size,
            num_workers=flags.num_workers, sampler=sampler,
            collate_fn=collate_fn, pin_memory=flags.pin_memory)
        return data_loader

    def config_model(self):
        r''' Builds the model, moves it to CUDA, and wraps DDP if needed. '''
        flags = self.FLAGS.MODEL
        model = self.get_model(flags)
        model.cuda(device=self.device)

        if self.world_size > 1:
            if flags.sync_bn:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                module=model, device_ids=[self.device],
                output_device=self.device, broadcast_buffers=False,
                find_unused_parameters=flags.find_unused_parameters)

        if self.is_master:
            print(model)  # print the model structure
            total_params = sum(p.numel() for p in model.parameters())
            print("Total number of parameters: %.3fM" % (total_params / 1e6))
        self.model = model

    def config_optimizer(self):
        r''' Builds the optimizer for the current model.

        The base learning rate is scaled by ``world_size`` to match the
        effective global batch size in distributed training.

        Raises:
            ValueError: If ``SOLVER.type`` is not ``sgd``, ``adam`` or
                ``adamw`` (case-insensitive).
        '''
        flags = self.FLAGS.SOLVER
        # The base learning rate `base_lr` scales with the world_size.
        base_lr = flags.lr * self.world_size
        parameters = self.model.parameters()

        # config the optimizer
        optim_type = flags.type.lower()
        if optim_type == 'sgd':
            self.optimizer = torch.optim.SGD(
                parameters, lr=base_lr, weight_decay=flags.weight_decay,
                momentum=0.9)
        elif optim_type == 'adam':
            self.optimizer = torch.optim.Adam(
                parameters, lr=base_lr, weight_decay=flags.weight_decay)
        elif optim_type == 'adamw':
            self.optimizer = torch.optim.AdamW(
                parameters, lr=base_lr, weight_decay=flags.weight_decay)
        else:
            raise ValueError(
                'Unsupported optimizer type: %r '
                '(expected sgd, adam or adamw)' % (flags.type,))

    def config_lr_scheduler(self):
        r''' Builds the learning-rate scheduler for the current optimizer. '''
        # This function must be called after :func:`config_optimizer`,
        # because the scheduler is built around ``self.optimizer``.
        self.scheduler = get_lr_scheduler(self.optimizer, self.FLAGS.SOLVER)

    def configure_log(self, set_writer=True):
        r''' Configures the log directory, checkpoint directory, and writers.

        Args:
            set_writer (bool): If True, creates the TensorBoard writer.
        '''
        self.logdir = self.FLAGS.SOLVER.logdir
        self.ckpt_dir = os.path.join(self.logdir, 'checkpoints')
        self.log_file = os.path.join(self.logdir, 'log.csv')

        if self.is_master:
            tqdm.write('Logdir: ' + self.logdir)

        # Only the master process that writes summaries creates directories.
        if self.is_master and set_writer:
            self.summary_writer = SummaryWriter(self.logdir, flush_secs=20)
            os.makedirs(self.ckpt_dir, exist_ok=True)

    def train_epoch(self, epoch):
        r''' Runs one full training epoch.

        Args:
            epoch (int): The current epoch number.

        Raises:
            ValueError: If ``SOLVER.amp_mode`` is not ``none``, ``fp16`` or
                ``bf16``.
        '''
        self.model.train()
        if self.world_size > 1:
            self.train_loader.sampler.set_epoch(epoch)

        flags = self.FLAGS.SOLVER
        avg_tracker = AverageTracker()
        rng = range(len(self.train_loader))
        for it in tqdm(rng, ncols=80, leave=False, disable=self.disable_tqdm):
            # clear the CUDA cache every `empty_cache` iterations (0 disables)
            if flags.empty_cache > 0 and it % flags.empty_cache == 0:
                torch.cuda.empty_cache()

            # load data
            batch = next(self.train_iter)
            batch['iter_num'] = it
            batch['epoch'] = epoch

            # forward and backward
            self.optimizer.zero_grad(flags.zero_grad_to_none)
            clip_grad = flags.clip_grad
            if self.amp_mode == 'none':
                output = self.train_step(batch)
                loss = output['train/loss']
                loss.backward()
                if clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), clip_grad)
                self.optimizer.step()
            elif self.amp_mode == 'fp16':
                with torch.autocast('cuda', dtype=torch.float16):
                    output = self.train_step(batch)
                    loss = output['train/loss']
                self.scaler.scale(loss).backward()
                if clip_grad > 0:
                    # gradients must be unscaled before their norm is clipped
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), clip_grad)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            elif self.amp_mode == 'bf16':
                with torch.autocast('cuda', dtype=torch.bfloat16):
                    output = self.train_step(batch)
                    loss = output['train/loss']
                loss.backward()
                if clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), clip_grad)
                self.optimizer.step()
            else:
                raise ValueError(f'Invalid amp mode: {self.amp_mode}')

            # track the averaged tensors
            avg_tracker.update(output)
            avg_tracker.record_time()

            # output intermediate logs
            log_per_iter = flags.log_per_iter
            if (self.is_master and log_per_iter > 0
                    and it % log_per_iter == 0):
                notes = 'iter: %d' % it
                avg_tracker.log(
                    epoch, msg_tag='- ', notes=notes, print_time=False)

        # save logs
        if self.world_size > 1:
            avg_tracker.average_all_gather()
        if self.is_master:
            avg_tracker.log(epoch, self.summary_writer, print_time=True)

    def test_epoch(self, epoch):
        r''' Runs one full test epoch.

        Args:
            epoch (int): The current epoch number.
        '''
        self.model.eval()
        avg_tracker = AverageTracker()
        rng = range(len(self.test_loader))
        for it in tqdm(rng, ncols=80, leave=False, disable=self.disable_tqdm):
            # clear cache every 50 iterations when `empty_cache` is enabled
            if it % 50 == 0 and self.FLAGS.SOLVER.empty_cache:
                torch.cuda.empty_cache()

            # forward
            batch = next(self.test_iter)
            batch['iter_num'] = it
            batch['epoch'] = epoch
            # NOTE(review): ``with torch.no_grad():`` is left disabled here,
            # as in the original code; presumably `test_step` manages the
            # gradient mode itself -- confirm before re-enabling.
            output = self.test_step(batch)

            # track the averaged tensors
            avg_tracker.update(output)

        if self.world_size > 1:
            avg_tracker.average_all_gather()
        if self.is_master:
            self.result_callback(avg_tracker, epoch)
            self.save_best_checkpoint(avg_tracker, epoch)
            avg_tracker.log(
                epoch, self.summary_writer, self.log_file, msg_tag='=>')

    def eval_epoch(self, epoch):
        r''' Runs one evaluation epoch in ``evaluate`` mode.

        ``SOLVER.eval_step`` caps the number of batches; a value below 1
        means "use the whole test loader".

        Args:
            epoch (int): The current epoch number.
        '''
        self.model.eval()
        eval_step = min(self.FLAGS.SOLVER.eval_step, len(self.test_loader))
        if eval_step < 1:
            eval_step = len(self.test_loader)
        for it in tqdm(range(eval_step), ncols=80, leave=False):
            batch = next(self.test_iter)
            batch['iter_num'] = it
            batch['epoch'] = epoch
            with torch.no_grad():
                self.eval_step(batch)

    def save_best_checkpoint(self, tracker: 'AverageTracker', epoch: int):
        r''' Saves the best-performing model according to ``SOLVER.best_val``.

        ``SOLVER.best_val`` has the form ``max:<key>`` or ``min:<key>``; the
        tracked metric ``test/<key>`` is compared with the best value so far.

        Args:
            tracker (AverageTracker): The tracker holding averaged test
                metrics.
            epoch (int): The current epoch number.

        Raises:
            ValueError: If the comparison mode is neither ``max`` nor ``min``.
        '''
        best_val = self.FLAGS.SOLVER.best_val
        if not (best_val and self.FLAGS.SOLVER.run == 'train'):
            return  # best_val is empty or it is not in the train mode

        compare, key = best_val.split(':')
        key = 'test/' + key
        if compare not in ('max', 'min'):
            raise ValueError(
                "SOLVER.best_val must start with 'max:' or 'min:', "
                "got %r" % (best_val,))

        if key not in tracker.value:
            return

        curr_val = (tracker.value[key] / tracker.num[key]).item()
        if self.best_val is not None:
            if compare == 'max':
                improved = curr_val > self.best_val
            else:
                improved = curr_val < self.best_val
            if not improved:
                return

        self.best_val = curr_val
        # unwrap the DDP container so the saved keys have no `module.` prefix
        model_dict = (self.model.module.state_dict() if self.world_size > 1
                      else self.model.state_dict())
        torch.save(model_dict, os.path.join(self.logdir, 'best_model.pth'))
        msg = 'epoch: %d, %s: %f' % (epoch, key, curr_val)
        with open(os.path.join(self.logdir, 'best_model.txt'), 'a') as fid:
            fid.write(msg + '\n')
        tqdm.write('=> Best model at ' + msg)

    def save_checkpoint(self, epoch):
        r''' Saves a training checkpoint for the given epoch.

        Two files are written: ``<epoch>.model.pth`` with the model weights
        and ``<epoch>.solver.tar`` with the full training state. Only the
        master process writes; other processes return immediately.

        Args:
            epoch (int): The epoch number used in the checkpoint filename.
        '''
        if not self.is_master:
            return

        # clean up
        # NOTE: `ckpt_num` counts files, not epochs, and the pruning runs
        # before the new files are written. With `ckpt_num == 0` the slice
        # `[:-0]` is empty, so nothing is removed.
        ckpt_num = self.FLAGS.SOLVER.ckpt_num
        ckpts = [ck for ck in sorted(os.listdir(self.ckpt_dir))
                 if ck.endswith(('.pth', '.tar'))]
        if len(ckpts) > ckpt_num:
            for ckpt in ckpts[:-ckpt_num]:
                os.remove(os.path.join(self.ckpt_dir, ckpt))

        # save ckpt
        model_dict = (self.model.module.state_dict() if self.world_size > 1
                      else self.model.state_dict())
        ckpt_name = os.path.join(self.ckpt_dir, '%05d' % epoch)
        torch.save(model_dict, ckpt_name + '.model.pth')
        ckpt_data = {
            'model_dict': model_dict,
            'epoch': epoch,
            'optimizer_dict': self.optimizer.state_dict(),
            'scheduler_dict': self.scheduler.state_dict(),
        }
        if self.scaler:
            ckpt_data['scaler_dict'] = self.scaler.state_dict()
        torch.save(ckpt_data, ckpt_name + '.solver.tar')

    def load_checkpoint(self):
        r''' Loads the requested checkpoint or the latest checkpoint.

        If ``SOLVER.ckpt`` is empty, the last ``*solver.tar`` file (in sorted
        order) found in the checkpoint directory is used. Nothing is loaded
        when no checkpoint can be found.
        '''
        ckpt = self.FLAGS.SOLVER.ckpt
        if not ckpt:
            # If ckpt is empty, then get the latest checkpoint from ckpt_dir
            if not os.path.exists(self.ckpt_dir):
                return
            ckpts = sorted(os.listdir(self.ckpt_dir))
            ckpts = [ck for ck in ckpts if ck.endswith('solver.tar')]
            if len(ckpts) > 0:
                ckpt = os.path.join(self.ckpt_dir, ckpts[-1])
        if not ckpt:
            return  # return if ckpt is still empty

        # load trained model
        # check: map_location = {'cuda:0' : 'cuda:%d' % self.rank}
        trained_dict = torch.load(ckpt, map_location='cuda')
        if ckpt.endswith('.solver.tar'):
            # full training state: also restore epoch/optimizer/scheduler
            model_dict = trained_dict['model_dict']
            self.start_epoch = trained_dict['epoch'] + 1  # !!! add 1
            if self.optimizer:
                self.optimizer.load_state_dict(trained_dict['optimizer_dict'])
            if self.scheduler:
                self.scheduler.load_state_dict(trained_dict['scheduler_dict'])
            if self.amp_mode == 'fp16' and 'scaler_dict' in trained_dict:
                self.scaler.load_state_dict(trained_dict['scaler_dict'])
        else:
            # any other file is treated as a bare model state dict
            model_dict = trained_dict
        model = self.model.module if self.world_size > 1 else self.model
        model.load_state_dict(model_dict)

        # print messages
        if self.is_master:
            tqdm.write('Load the checkpoint: %s' % ckpt)
            tqdm.write('The start_epoch is %d' % self.start_epoch)

    def manual_seed(self):
        r''' Sets random seeds when ``SOLVER.rand_seed`` is positive.

        Seeds ``random``, NumPy and PyTorch (CPU and CUDA), and switches
        cuDNN to its deterministic, non-benchmark mode.
        '''
        rand_seed = self.FLAGS.SOLVER.rand_seed
        if rand_seed > 0:
            random.seed(rand_seed)
            np.random.seed(rand_seed)
            torch.manual_seed(rand_seed)
            torch.cuda.manual_seed(rand_seed)
            torch.cuda.manual_seed_all(rand_seed)
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True

    def train(self):
        r''' Runs the end-to-end training workflow. '''
        self.manual_seed()
        self.config_model()
        self.config_dataloader()
        self.config_optimizer()
        self.config_lr_scheduler()
        self.configure_log()
        self.load_checkpoint()

        flags = self.FLAGS.SOLVER
        rng = range(self.start_epoch, flags.max_epoch + 1)
        for epoch in tqdm(rng, ncols=80, disable=self.disable_tqdm):
            # training epoch
            self.train_epoch(epoch)

            # update learning rate
            self.scheduler.step()
            if self.is_master:
                lr = self.scheduler.get_last_lr()  # lr is a list
                self.summary_writer.add_scalar('train/lr', lr[0], epoch)

            # checkpoint and test at specified intervals
            if epoch != 0 and epoch % flags.test_every_epoch == 0:
                self.save_checkpoint(epoch)
                self.test_epoch(epoch)

        # sync and exit
        if self.world_size > 1:
            torch.distributed.barrier()

    def test(self):
        r''' Loads a checkpoint and runs the test loop once. '''
        self.manual_seed()
        self.config_model()
        self.configure_log(set_writer=False)
        self.config_dataloader(disable_train_data=True)
        self.load_checkpoint()
        self.test_epoch(epoch=0)

    def evaluate(self):
        r''' Loads a checkpoint and runs the evaluation loop. '''
        self.manual_seed()
        self.config_model()
        self.configure_log(set_writer=False)
        self.config_dataloader(disable_train_data=True)
        self.load_checkpoint()
        for epoch in tqdm(range(self.FLAGS.SOLVER.eval_epoch), ncols=80):
            self.eval_epoch(epoch)

    def profile(self):
        r''' Profiles a few training iterations with the PyTorch profiler.

        Set ``DATA.train.num_workers 0`` when using this function. The same
        batch is run five times (forward and backward only, no optimizer
        step) and two summary tables are printed.
        '''
        self.config_model()
        self.config_dataloader()
        logdir = self.FLAGS.SOLVER.logdir

        # check
        newer_than_191 = (
            version.parse(torch.__version__) > version.parse('1.9.1'))
        if not newer_than_191:
            print('This function is only available for Pytorch>1.9.1.')
            return

        # profile
        batch = next(iter(self.train_loader))
        schedule = torch.profiler.schedule(
            wait=1, warmup=1, active=3, repeat=1)
        activities = [torch.profiler.ProfilerActivity.CPU,
                      torch.profiler.ProfilerActivity.CUDA, ]
        with torch.profiler.profile(
                activities=activities, schedule=schedule,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    logdir),
                record_shapes=True, profile_memory=True, with_stack=False,
                with_modules=True) as prof:
            for _ in range(5):
                if self.amp_mode == 'none':
                    output = self.train_step(batch)
                    loss = output['train/loss']
                    loss.backward()
                elif self.amp_mode == 'fp16':
                    with torch.autocast('cuda', dtype=torch.float16):
                        output = self.train_step(batch)
                        loss = output['train/loss']
                    self.scaler.scale(loss).backward()
                elif self.amp_mode == 'bf16':
                    with torch.autocast('cuda', dtype=torch.bfloat16):
                        output = self.train_step(batch)
                        loss = output['train/loss']
                    loss.backward()
                prof.step()

        print(prof.key_averages(group_by_input_shape=True, group_by_stack_n=10)
              .table(sort_by="cuda_time_total", row_limit=10))
        print(prof.key_averages(group_by_input_shape=True, group_by_stack_n=10)
              .table(sort_by="cuda_memory_usage", row_limit=10))

    def run(self):
        r''' Dispatches to the run mode configured in ``SOLVER.run``.

        The mode name is looked up with ``getattr`` instead of being passed
        to ``eval``, so the config string can only select a method and can
        never execute an arbitrary expression.

        Raises:
            AttributeError: If the solver has no callable attribute named
                after the configured mode.
        '''
        mode = str(self.FLAGS.SOLVER.run)
        handler = getattr(self, mode, None)
        if not callable(handler):
            raise AttributeError(
                '%s has no run mode %r' % (type(self).__name__, mode))
        handler()

    @staticmethod
    def get_world_size(FLAGS):
        r''' Returns the distributed world size implied by the launch mode.

        In ``spawn`` mode this is the number of entries in ``SOLVER.gpu``;
        in ``torchrun`` mode it is read from the ``WORLD_SIZE`` environment
        variable (1 if unset).

        Args:
            FLAGS: The experiment config tree.

        Raises:
            NotImplementedError: For any other ``SOLVER.ddp_mode``.
        '''
        ddp_mode = FLAGS.SOLVER.ddp_mode
        if ddp_mode == "spawn":
            return len(FLAGS.SOLVER.gpu)
        if ddp_mode == "torchrun":
            return int(os.environ.get("WORLD_SIZE", 1))
        raise NotImplementedError('Unsupported ddp_mode: %r' % (ddp_mode,))

    @classmethod
    def update_configs(cls):
        r''' Updates class-specific configs before parsing command-line
        arguments.
        '''
        pass

    @classmethod
    def worker(cls, rank, FLAGS):
        r''' Runs one solver worker process.

        Args:
            rank (int): The process rank in the current launch.
            FLAGS: The experiment config tree.

        Raises:
            NotImplementedError: If ``SOLVER.ddp_mode`` is neither ``spawn``
                nor ``torchrun``.
        '''
        world_size = cls.get_world_size(FLAGS)
        # `device_id` is only passed to init_process_group on torch >= 2.8.0
        newer_than_280 = (
            version.parse(torch.__version__) >= version.parse('2.8.0'))

        if FLAGS.SOLVER.ddp_mode == "spawn":
            # Set the GPU to use.
            gpu_rank = FLAGS.SOLVER.gpu[rank]
            torch.cuda.set_device(gpu_rank)
            if world_size > 1:
                # Initialize the process group. This piece of code only
                # supports the `single node + multiple GPU` mode.
                url = 'tcp://localhost:%d' % FLAGS.SOLVER.port
                param = {'backend': 'nccl', 'init_method': url,
                         'world_size': world_size, 'rank': rank}
                if newer_than_280:
                    param['device_id'] = gpu_rank
                torch.distributed.init_process_group(**param)
                torch.distributed.barrier()
        elif FLAGS.SOLVER.ddp_mode == "torchrun":
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
            torch.cuda.set_device(local_rank)
            if world_size > 1:
                # Initialize the process group. torch.distributed.run ensures
                # that this will work by exporting all the env vars needed to
                # initialize the process group. Support `multiple nodes +
                # multiple GPUs` mode.
                param = {'backend': 'nccl', 'init_method': 'env://'}
                if newer_than_280:
                    param['device_id'] = local_rank
                torch.distributed.init_process_group(**param)
                torch.distributed.barrier()
        else:
            raise NotImplementedError(
                'Unsupported ddp_mode: %r' % (FLAGS.SOLVER.ddp_mode,))

        # The master process is responsible for logging, writing and loading
        # checkpoints. In the multi-GPU setting, we assign the master role to
        # the rank 0 process.
        is_master = rank == 0
        the_solver = cls(FLAGS, is_master)
        the_solver.run()

        # clean up
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

    @classmethod
    def main(cls):
        r''' Parses configs and launches the solver with the configured DDP
        mode.

        Raises:
            NotImplementedError: If ``SOLVER.ddp_mode`` is neither ``spawn``
                nor ``torchrun``.
        '''
        cls.update_configs()
        FLAGS = parse_args()

        if FLAGS.SOLVER.ddp_mode == "spawn":
            num_gpus = len(FLAGS.SOLVER.gpu)
            if num_gpus > 1:
                # one process per GPU; each receives its rank as first arg
                torch.multiprocessing.spawn(
                    cls.worker, nprocs=num_gpus, args=(FLAGS,))
            else:
                cls.worker(0, FLAGS)
        elif FLAGS.SOLVER.ddp_mode == "torchrun":
            # FLAGS.SOLVER.gpu = os.environ.get("CUDA_VISIBLE_DEVICES", "")
            rank = int(os.environ.get("RANK", 0))
            cls.worker(rank, FLAGS)
        else:
            raise NotImplementedError(
                'Unsupported ddp_mode: %r' % (FLAGS.SOLVER.ddp_mode,))