# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
import os
import torch
import torch.nn
import torch.optim
import torch.distributed
import torch.multiprocessing
import torch.utils.data
import random
import numpy as np
from tqdm import tqdm
from packaging import version
from torch.utils.tensorboard import SummaryWriter
from .sampler import InfSampler, DistributedInfSampler
from .tracker import AverageTracker
from .config import parse_args
from .lr_scheduler import get_lr_scheduler
class Solver:
    r''' A lightweight base class for PyTorch training and evaluation loops.

    Subclasses implement the model, dataset, and step-specific logic, while this
    class handles logging, checkpointing, distributed training, and scheduling.
    '''
    # NOTE(review): `configure_log`, `self.logdir` and `self.ckpt_dir` are used
    # below but are not defined in this view; presumably `configure_log` (defined
    # elsewhere) sets them together with `summary_writer` / `log_file`. Confirm.

    def __init__(self, FLAGS, is_master=True):
        r''' Initializes the solver runtime state.

        Args:
            FLAGS: The experiment config tree.
            is_master (bool): If True, enables logging and checkpoint writing.
        '''
        self.FLAGS = FLAGS
        self.is_master = is_master
        self.world_size = self.get_world_size(FLAGS)
        self.device = torch.cuda.current_device()
        # Progress bars are shown only on the master and only when requested.
        self.disable_tqdm = not (is_master and FLAGS.SOLVER.progress_bar)
        self.start_epoch = 1
        self.amp_mode = FLAGS.SOLVER.amp_mode
        self.model = None           # torch.nn.Module
        self.optimizer = None       # torch.optim.Optimizer
        self.scheduler = None       # torch.optim.lr_scheduler._LRScheduler
        self.summary_writer = None  # torch.utils.tensorboard.SummaryWriter
        self.log_file = None        # str, used to save training logs
        self.eval_rst = {}          # used to save evaluation results
        self.best_val = None        # used to save the best validation result

        # A GradScaler is created only for the fp16 AMP mode; other modes keep
        # `self.scaler` as None.
        newer_than_230 = version.parse(torch.__version__) > version.parse('2.3.0')
        if self.amp_mode == 'fp16':
            self.scaler = (torch.GradScaler() if newer_than_230
                           else torch.cuda.amp.GradScaler())
        else:
            self.scaler = None

    def get_model(self, flags=None):
        r''' Returns the model used by the current experiment.

        Args:
            flags: The model config node; :meth:`config_model` passes
                ``FLAGS.MODEL`` here.
        '''
        raise NotImplementedError

    def get_dataset(self, flags):
        r''' Returns a dataset and its collate function.

        Args:
            flags: The dataset config node.
        '''
        raise NotImplementedError

    def train_step(self, batch):
        r''' Returns the outputs of one training step.

        The returned dict should contain the key ``train/loss``, which is the
        scalar loss used for back propagation and optimizer updates.

        Args:
            batch (dict): One batch produced by the training dataloader.
        '''
        raise NotImplementedError

    def test_step(self, batch):
        r''' Returns the outputs of one testing step.

        Args:
            batch (dict): One batch produced by the test dataloader.
        '''
        raise NotImplementedError

    def eval_step(self, batch):
        r''' Evaluates the model on a batch.

        Args:
            batch (dict): One batch produced by the evaluation dataloader.
        '''
        raise NotImplementedError

    def result_callback(self, avg_tracker: 'AverageTracker', epoch):
        r''' Runs custom logic after a test epoch finishes.

        Args:
            avg_tracker (AverageTracker): The epoch-level tracker.
            epoch (int): The current epoch number.
        '''
        pass  # additional operations based on the avg_tracker

    def config_dataloader(self, disable_train_data=False):
        r''' Builds the train and test dataloaders when enabled.

        Args:
            disable_train_data (bool): If True, skips the training dataloader.
        '''
        flags_train, flags_test = self.FLAGS.DATA.train, self.FLAGS.DATA.test
        if not disable_train_data and not flags_train.disable:
            self.train_loader = self.get_dataloader(flags_train)
            self.train_iter = iter(self.train_loader)
        if not flags_test.disable:
            self.test_loader = self.get_dataloader(flags_test)
            self.test_iter = iter(self.test_loader)

    def get_dataloader(self, flags):
        r''' Builds one dataloader from a dataset config node.

        Args:
            flags: The dataset config node.

        Returns:
            torch.utils.data.DataLoader: The configured dataloader.
        '''
        dataset, collate_fn = self.get_dataset(flags)
        if self.world_size > 1:
            sampler = DistributedInfSampler(dataset, shuffle=flags.shuffle)
        else:
            sampler = InfSampler(dataset, shuffle=flags.shuffle)
        return torch.utils.data.DataLoader(
            dataset, batch_size=flags.batch_size, num_workers=flags.num_workers,
            sampler=sampler, collate_fn=collate_fn, pin_memory=flags.pin_memory)

    def config_model(self):
        r''' Builds the model, moves it to CUDA, and wraps DDP if needed. '''
        flags = self.FLAGS.MODEL
        model = self.get_model(flags)
        model.cuda(device=self.device)
        if self.world_size > 1:
            if flags.sync_bn:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                module=model, device_ids=[self.device],
                output_device=self.device, broadcast_buffers=False,
                find_unused_parameters=flags.find_unused_parameters)
        if self.is_master:
            print(model)  # print the model structure
            total_params = sum(p.numel() for p in model.parameters())
            print("Total number of parameters: %.3fM" % (total_params / 1e6))
        self.model = model

    def config_optimizer(self):
        r''' Builds the optimizer for the current model.

        The base learning rate is scaled by ``world_size`` to match the effective
        global batch size in distributed training.

        Raises:
            ValueError: If ``SOLVER.type`` is not one of sgd / adam / adamw.
        '''
        flags = self.FLAGS.SOLVER
        base_lr = flags.lr * self.world_size
        parameters = self.model.parameters()
        opt_type = flags.type.lower()
        if opt_type == 'sgd':
            self.optimizer = torch.optim.SGD(
                parameters, lr=base_lr, weight_decay=flags.weight_decay, momentum=0.9)
        elif opt_type == 'adam':
            self.optimizer = torch.optim.Adam(
                parameters, lr=base_lr, weight_decay=flags.weight_decay)
        elif opt_type == 'adamw':
            self.optimizer = torch.optim.AdamW(
                parameters, lr=base_lr, weight_decay=flags.weight_decay)
        else:
            raise ValueError('Unsupported optimizer type: %s' % flags.type)

    def config_lr_scheduler(self):
        r''' Builds the learning-rate scheduler for the current optimizer. '''
        # This function must be called after :func:`config_optimizer`.
        self.scheduler = get_lr_scheduler(self.optimizer, self.FLAGS.SOLVER)

    def _amp_forward(self, batch):
        r''' Runs :meth:`train_step` under the autocast mode set by ``amp_mode``.

        Raises:
            ValueError: If ``amp_mode`` is not one of none / fp16 / bf16.
        '''
        if self.amp_mode == 'none':
            return self.train_step(batch)
        if self.amp_mode == 'fp16':
            with torch.autocast('cuda', dtype=torch.float16):
                return self.train_step(batch)
        if self.amp_mode == 'bf16':
            with torch.autocast('cuda', dtype=torch.bfloat16):
                return self.train_step(batch)
        raise ValueError(f'Invalid amp mode: {self.amp_mode}')

    def _amp_backward(self, loss):
        r''' Back-propagates ``loss``, through the GradScaler in fp16 mode. '''
        if self.amp_mode == 'fp16':
            self.scaler.scale(loss).backward()
        else:
            loss.backward()

    def _amp_optimizer_step(self, clip_grad):
        r''' Optionally clips gradients, then steps the optimizer.

        In fp16 mode gradients are unscaled before clipping so that the clipping
        threshold applies to the true gradient values.
        '''
        if self.amp_mode == 'fp16':
            if clip_grad > 0:
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_grad)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            if clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_grad)
            self.optimizer.step()

    def train_epoch(self, epoch):
        r''' Runs one full training epoch.

        Args:
            epoch (int): The current epoch number.
        '''
        self.model.train()
        if self.world_size > 1:
            self.train_loader.sampler.set_epoch(epoch)

        flags = self.FLAGS.SOLVER
        clip_grad = flags.clip_grad
        log_per_iter = flags.log_per_iter
        avg_tracker = AverageTracker()
        rng = range(len(self.train_loader))
        for it in tqdm(rng, ncols=80, leave=False, disable=self.disable_tqdm):
            # clear the CUDA cache every `empty_cache` iterations
            if flags.empty_cache > 0 and it % flags.empty_cache == 0:
                torch.cuda.empty_cache()

            # load data
            batch = next(self.train_iter)
            batch['iter_num'] = it
            batch['epoch'] = epoch

            # forward, backward and update
            self.optimizer.zero_grad(flags.zero_grad_to_none)
            output = self._amp_forward(batch)
            self._amp_backward(output['train/loss'])
            self._amp_optimizer_step(clip_grad)

            # track the averaged tensors
            avg_tracker.update(output)
            avg_tracker.record_time()

            # output intermediate logs
            if self.is_master and log_per_iter > 0 and it % log_per_iter == 0:
                notes = 'iter: %d' % it
                avg_tracker.log(epoch, msg_tag='- ', notes=notes, print_time=False)

        # save logs
        if self.world_size > 1:
            avg_tracker.average_all_gather()
        if self.is_master:
            avg_tracker.log(epoch, self.summary_writer, print_time=True)

    def test_epoch(self, epoch):
        r''' Runs one full test epoch.

        Args:
            epoch (int): The current epoch number.
        '''
        self.model.eval()
        empty_cache = self.FLAGS.SOLVER.empty_cache
        avg_tracker = AverageTracker()
        rng = range(len(self.test_loader))
        for it in tqdm(rng, ncols=80, leave=False, disable=self.disable_tqdm):
            # clear the CUDA cache every `empty_cache` iterations, as in training
            if empty_cache > 0 and it % empty_cache == 0:
                torch.cuda.empty_cache()

            # forward
            batch = next(self.test_iter)
            batch['iter_num'] = it
            batch['epoch'] = epoch
            # NOTE: autograd is left enabled here; a test_step that does not need
            # gradients can wrap its own computation in `torch.no_grad()`.
            output = self.test_step(batch)

            # track the averaged tensors
            avg_tracker.update(output)

        if self.world_size > 1:
            avg_tracker.average_all_gather()
        if self.is_master:
            self.result_callback(avg_tracker, epoch)
            self.save_best_checkpoint(avg_tracker, epoch)
            avg_tracker.log(epoch, self.summary_writer, self.log_file, msg_tag='=>')

    def eval_epoch(self, epoch):
        r''' Runs one evaluation epoch in ``evaluate`` mode.

        At most ``SOLVER.eval_step`` batches are evaluated; a value below 1 means
        the whole test loader.

        Args:
            epoch (int): The current epoch number.
        '''
        self.model.eval()
        num_steps = min(self.FLAGS.SOLVER.eval_step, len(self.test_loader))
        if num_steps < 1:
            num_steps = len(self.test_loader)
        for it in tqdm(range(num_steps), ncols=80, leave=False):
            batch = next(self.test_iter)
            batch['iter_num'] = it
            batch['epoch'] = epoch
            with torch.no_grad():
                self.eval_step(batch)

    def save_best_checkpoint(self, tracker: 'AverageTracker', epoch: int):
        r''' Saves the best-performing model according to ``SOLVER.best_val``.

        ``SOLVER.best_val`` has the form ``'max:<key>'`` or ``'min:<key>'``; the
        metric ``test/<key>`` is read from the tracker.

        Args:
            tracker (AverageTracker): The tracker holding averaged test metrics.
            epoch (int): The current epoch number.

        Raises:
            ValueError: If the comparison part is neither ``max`` nor ``min``.
        '''
        best_val = self.FLAGS.SOLVER.best_val
        if not (best_val and self.FLAGS.SOLVER.run == 'train'):
            return  # return if best_val is empty or it is not in the train mode

        compare, key = best_val.split(':')
        if compare not in ('max', 'min'):
            raise ValueError(
                "Invalid comparison in SOLVER.best_val: %r (expected 'max' or "
                "'min')" % compare)
        key = 'test/' + key
        if compare == 'max':
            is_better = lambda x, y: x > y
        else:
            is_better = lambda x, y: x < y

        if key not in tracker.value:
            return
        curr_val = (tracker.value[key] / tracker.num[key]).item()
        if self.best_val is not None and not is_better(curr_val, self.best_val):
            return

        self.best_val = curr_val
        model_dict = (self.model.module.state_dict() if self.world_size > 1
                      else self.model.state_dict())
        torch.save(model_dict, os.path.join(self.logdir, 'best_model.pth'))
        msg = 'epoch: %d, %s: %f' % (epoch, key, curr_val)
        with open(os.path.join(self.logdir, 'best_model.txt'), 'a') as fid:
            fid.write(msg + '\n')
        tqdm.write('=> Best model at ' + msg)

    def save_checkpoint(self, epoch):
        r''' Saves a training checkpoint for the given epoch.

        Before saving, older ``.pth`` / ``.tar`` files in ``ckpt_dir`` are removed
        so that at most ``SOLVER.ckpt_num`` of them remain (each epoch writes two
        files: ``<epoch>.model.pth`` and ``<epoch>.solver.tar``).

        Args:
            epoch (int): The epoch number used in the checkpoint filename.
        '''
        if not self.is_master:
            return

        # clean up; zero-padded names make lexicographic order chronological
        ckpt_num = self.FLAGS.SOLVER.ckpt_num
        ckpts = sorted(ck for ck in os.listdir(self.ckpt_dir)
                       if ck.endswith(('.pth', '.tar')))
        if len(ckpts) > ckpt_num:
            for ckpt in ckpts[:-ckpt_num]:
                os.remove(os.path.join(self.ckpt_dir, ckpt))

        # save ckpt
        model_dict = (self.model.module.state_dict() if self.world_size > 1
                      else self.model.state_dict())
        ckpt_name = os.path.join(self.ckpt_dir, '%05d' % epoch)
        torch.save(model_dict, ckpt_name + '.model.pth')
        ckpt_data = {'model_dict': model_dict, 'epoch': epoch,
                     'optimizer_dict': self.optimizer.state_dict(),
                     'scheduler_dict': self.scheduler.state_dict()}
        if self.scaler:
            ckpt_data['scaler_dict'] = self.scaler.state_dict()
        torch.save(ckpt_data, ckpt_name + '.solver.tar')

    def load_checkpoint(self):
        r''' Loads the checkpoint in ``SOLVER.ckpt``, or the latest one in
        ``ckpt_dir`` when ``SOLVER.ckpt`` is empty.

        A ``.solver.tar`` file also restores the epoch counter, optimizer,
        scheduler and (fp16 only) GradScaler state; any other file is treated as a
        bare model state dict.
        '''
        ckpt = self.FLAGS.SOLVER.ckpt
        if not ckpt:
            # If ckpt is empty, then get the latest checkpoint from ckpt_dir
            if not os.path.exists(self.ckpt_dir):
                return
            ckpts = sorted(ck for ck in os.listdir(self.ckpt_dir)
                           if ck.endswith('solver.tar'))
            if ckpts:
                ckpt = os.path.join(self.ckpt_dir, ckpts[-1])
        if not ckpt:
            return  # return if ckpt is still empty

        # load trained model; 'cuda' refers to the current CUDA device
        trained_dict = torch.load(ckpt, map_location='cuda')
        if ckpt.endswith('.solver.tar'):
            model_dict = trained_dict['model_dict']
            self.start_epoch = trained_dict['epoch'] + 1  # resume at the next epoch
            if self.optimizer:
                self.optimizer.load_state_dict(trained_dict['optimizer_dict'])
            if self.scheduler:
                self.scheduler.load_state_dict(trained_dict['scheduler_dict'])
            if self.amp_mode == 'fp16' and 'scaler_dict' in trained_dict:
                self.scaler.load_state_dict(trained_dict['scaler_dict'])
        else:
            model_dict = trained_dict
        model = self.model.module if self.world_size > 1 else self.model
        model.load_state_dict(model_dict)

        # print messages
        if self.is_master:
            tqdm.write('Load the checkpoint: %s' % ckpt)
            tqdm.write('The start_epoch is %d' % self.start_epoch)

    def manual_seed(self):
        r''' Sets random seeds when ``SOLVER.rand_seed`` is positive. '''
        rand_seed = self.FLAGS.SOLVER.rand_seed
        if rand_seed > 0:
            random.seed(rand_seed)
            np.random.seed(rand_seed)
            torch.manual_seed(rand_seed)
            torch.cuda.manual_seed(rand_seed)
            torch.cuda.manual_seed_all(rand_seed)
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True

    def train(self):
        r''' Runs the end-to-end training workflow. '''
        self.manual_seed()
        self.config_model()
        self.config_dataloader()
        self.config_optimizer()
        self.config_lr_scheduler()
        self.configure_log()
        self.load_checkpoint()

        flags = self.FLAGS.SOLVER
        rng = range(self.start_epoch, flags.max_epoch + 1)
        for epoch in tqdm(rng, ncols=80, disable=self.disable_tqdm):
            # training epoch
            self.train_epoch(epoch)

            # update learning rate
            self.scheduler.step()
            if self.is_master:
                lr = self.scheduler.get_last_lr()  # lr is a list
                self.summary_writer.add_scalar('train/lr', lr[0], epoch)

            # checkpoint and test at specified intervals
            if epoch != 0 and epoch % flags.test_every_epoch == 0:
                self.save_checkpoint(epoch)
                self.test_epoch(epoch)

        # sync and exit
        if self.world_size > 1:
            torch.distributed.barrier()

    def test(self):
        r''' Loads a checkpoint and runs the test loop once. '''
        self.manual_seed()
        self.config_model()
        self.configure_log(set_writer=False)
        self.config_dataloader(disable_train_data=True)
        self.load_checkpoint()
        self.test_epoch(epoch=0)

    def evaluate(self):
        r''' Loads a checkpoint and runs ``SOLVER.eval_epoch`` evaluation epochs. '''
        self.manual_seed()
        self.config_model()
        self.configure_log(set_writer=False)
        self.config_dataloader(disable_train_data=True)
        self.load_checkpoint()
        for epoch in tqdm(range(self.FLAGS.SOLVER.eval_epoch), ncols=80):
            self.eval_epoch(epoch)

    def profile(self):
        r''' Profiles a few training iterations with the PyTorch profiler.

        Set ``DATA.train.num_workers 0`` when using this function.

        Raises:
            ValueError: If ``amp_mode`` is not one of none / fp16 / bf16.
        '''
        # check the version first so the model/data are not built for nothing
        newer_than_191 = version.parse(torch.__version__) > version.parse('1.9.1')
        if not newer_than_191:
            print('This function is only available for Pytorch>1.9.1.')
            return

        self.config_model()
        self.config_dataloader()
        logdir = self.FLAGS.SOLVER.logdir

        # profile: 1 wait + 1 warmup + 3 active steps = 5 iterations below
        batch = next(iter(self.train_loader))
        schedule = torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1)
        activities = [torch.profiler.ProfilerActivity.CPU,
                      torch.profiler.ProfilerActivity.CUDA]
        with torch.profiler.profile(
                activities=activities, schedule=schedule,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(logdir),
                record_shapes=True, profile_memory=True, with_stack=False,
                with_modules=True) as prof:
            for _ in range(5):
                output = self._amp_forward(batch)
                self._amp_backward(output['train/loss'])
                prof.step()

        print(prof.key_averages(group_by_input_shape=True, group_by_stack_n=10)
              .table(sort_by="cuda_time_total", row_limit=10))
        print(prof.key_averages(group_by_input_shape=True, group_by_stack_n=10)
              .table(sort_by="cuda_memory_usage", row_limit=10))

    def run(self):
        r''' Dispatches to the method named by ``SOLVER.run`` (e.g. ``train``).

        The method is looked up with ``getattr`` instead of evaluating a code
        string built from the config.
        '''
        getattr(self, self.FLAGS.SOLVER.run)()

    @staticmethod
    def get_world_size(FLAGS):
        r''' Returns the distributed world size implied by the launch mode.

        Args:
            FLAGS: The experiment config tree.

        Raises:
            NotImplementedError: If ``SOLVER.ddp_mode`` is neither ``spawn`` nor
                ``torchrun``.
        '''
        ddp_mode = FLAGS.SOLVER.ddp_mode
        if ddp_mode == "spawn":
            return len(FLAGS.SOLVER.gpu)
        if ddp_mode == "torchrun":
            return int(os.environ.get("WORLD_SIZE", 1))
        raise NotImplementedError('Unsupported ddp_mode: %s' % ddp_mode)

    @classmethod
    def update_configs(cls):
        r''' Updates class-specific configs before parsing command-line arguments. '''
        pass

    @classmethod
    def worker(cls, rank, FLAGS):
        r''' Runs one solver worker process.

        Args:
            rank (int): The process rank in the current launch.
            FLAGS: The experiment config tree.
        '''
        world_size = cls.get_world_size(FLAGS)
        newer_than_280 = version.parse(torch.__version__) >= version.parse('2.8.0')
        ddp_mode = FLAGS.SOLVER.ddp_mode
        if ddp_mode == "spawn":
            # Set the GPU to use.
            gpu_rank = FLAGS.SOLVER.gpu[rank]
            torch.cuda.set_device(gpu_rank)
            if world_size > 1:
                # Initialize the process group. This piece of code only supports
                # the `single node + multiple GPU` mode.
                url = 'tcp://localhost:%d' % FLAGS.SOLVER.port
                param = {'backend': 'nccl', 'init_method': url,
                         'world_size': world_size, 'rank': rank}
                if newer_than_280:
                    param['device_id'] = gpu_rank
                torch.distributed.init_process_group(**param)
                torch.distributed.barrier()
        elif ddp_mode == "torchrun":
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
            torch.cuda.set_device(local_rank)
            if world_size > 1:
                # Initialize the process group. torch.distributed.run exports all
                # the env vars needed by `env://`. Supports the
                # `multiple nodes + multiple GPUs` mode.
                param = {'backend': 'nccl', 'init_method': 'env://'}
                if newer_than_280:
                    param['device_id'] = local_rank
                torch.distributed.init_process_group(**param)
                torch.distributed.barrier()
        else:
            raise NotImplementedError('Unsupported ddp_mode: %s' % ddp_mode)

        # The master process is responsible for logging, writing and loading
        # checkpoints. In the multi-GPU setting, the rank 0 process is the master.
        is_master = rank == 0
        the_solver = cls(FLAGS, is_master)
        the_solver.run()

        # clean up
        if torch.distributed.is_initialized():
            torch.distributed.destroy_process_group()

    @classmethod
    def main(cls):
        r''' Parses configs and launches the solver with the configured DDP mode. '''
        cls.update_configs()
        FLAGS = parse_args()
        ddp_mode = FLAGS.SOLVER.ddp_mode
        if ddp_mode == "spawn":
            num_gpus = len(FLAGS.SOLVER.gpu)
            if num_gpus > 1:
                torch.multiprocessing.spawn(cls.worker, nprocs=num_gpus, args=(FLAGS,))
            else:
                cls.worker(0, FLAGS)
        elif ddp_mode == "torchrun":
            rank = int(os.environ.get("RANK", 0))
            cls.worker(rank, FLAGS)
        else:
            raise NotImplementedError('Unsupported ddp_mode: %s' % ddp_mode)