# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
# autopep8: off
import argparse
import os
import shlex
import shutil
import sys
from datetime import datetime

from yacs.config import CfgNode as CN
# Root of the default config tree. With `new_allowed=True`, config files merged
# into it (see _update_config) may introduce keys that are not declared here.
_C = CN(new_allowed=True)
# Parent config files to inherit from. It must be a list, and empty entries such
# as the default '' are skipped when loading (see _load_from_file).
_C.BASE = ['']
# SOLVER related parameters
_C.SOLVER = CN(new_allowed=True)
_C.SOLVER.alias = '' # Experiment alias, lower-cased and appended to `logdir`; 'time' is replaced by a timestamp
_C.SOLVER.gpu = (0,) # The GPU ids to use
_C.SOLVER.run = 'train' # Choose from train or test
_C.SOLVER.logdir = 'logs' # Directory for event logs and the config backup (suffixed with the alias)
_C.SOLVER.ckpt = '' # Checkpoint file to restore weights from
_C.SOLVER.ckpt_num = 10 # Number of checkpoints to keep
_C.SOLVER.type = 'sgd' # Optimizer type: sgd or adam
_C.SOLVER.weight_decay = 0.0005 # The weight decay on model weights
_C.SOLVER.clip_grad = -1.0 # Clip the gradient norm (-1: disabled)
_C.SOLVER.max_epoch = 300 # Maximum number of training epochs
_C.SOLVER.warmup_epoch = 20 # Number of warmup epochs
_C.SOLVER.warmup_init = 0.001 # Initial ratio at the start of warmup (presumably relative to `lr` -- confirm)
_C.SOLVER.eval_epoch = 1 # Maximum number of evaluation epochs
_C.SOLVER.eval_step = -1 # Maximum number of evaluation steps
_C.SOLVER.test_every_epoch = 10 # Run testing every n training epochs
_C.SOLVER.log_per_iter = -1 # Output logs every k training iterations
_C.SOLVER.best_val = 'min:loss' # Validation metric used to track the best model, e.g. 'min:loss'
_C.SOLVER.zero_grad_to_none = False # Value for optimizer.zero_grad(set_to_none=...)
_C.SOLVER.amp_mode = 'none' # Automatic mixed precision mode
_C.SOLVER.ddp_mode = 'spawn' # DistributedDataParallel launch mode
_C.SOLVER.lr_type = 'step' # Learning rate schedule type, e.g. step or cos
_C.SOLVER.lr = 0.1 # Initial learning rate
_C.SOLVER.lr_min = 0.0001 # The minimum learning rate
_C.SOLVER.gamma = 0.1 # Decay factor of the step-wise learning rate schedule
_C.SOLVER.milestones = (120,180,) # Milestones of the step-wise learning rate schedule
_C.SOLVER.lr_power = 0.9 # Exponent used by the poly learning rate schedule
# _C.SOLVER.dist_url = 'tcp://localhost:10001'
_C.SOLVER.port = 10001 # The port number for distributed training
_C.SOLVER.progress_bar = True # Whether to show a progress bar
_C.SOLVER.rand_seed = -1 # Fix the random seed if it is larger than 0
_C.SOLVER.empty_cache = 50 # Empty the CUDA cache periodically (presumably every n iterations -- confirm)
# DATA related parameters
_C.DATA = CN(new_allowed=True)
# Training dataset settings; the test dataset below starts as a copy of these.
_C.DATA.train = CN(new_allowed=True)
_C.DATA.train.name = '' # The name of the dataset
_C.DATA.train.disable = False # Whether to disable this dataset
_C.DATA.train.pin_memory = True # Pin host memory when loading (presumably forwarded to the DataLoader)
# For octree building
_C.DATA.train.depth = 5 # The octree depth
_C.DATA.train.full_depth = 2 # The full depth of the octree
# _C.DATA.train.adaptive = False # Build the adaptive octree
# For transformation
_C.DATA.train.orient_normal = '' # Used to re-orient normal directions
# For data augmentation
_C.DATA.train.distort = False # Whether to apply data augmentation
_C.DATA.train.scale = 0.0 # Scale the points
_C.DATA.train.uniform = False # Generate uniform scales
_C.DATA.train.jitter = 0.0 # Jitter the points
_C.DATA.train.interval = (1, 1, 1) # Used together with `angle` to generate random rotation angles
_C.DATA.train.angle = (180, 180, 180) # Angle range paired with `interval` (one value per axis, presumably)
_C.DATA.train.flip = (0.0, 0.0, 0.0) # Flip settings, one value per axis (presumably probabilities -- confirm)
# For data loading
_C.DATA.train.location = '' # The data location
_C.DATA.train.filelist = '' # The data filelist
_C.DATA.train.batch_size = 32 # Training data batch size
_C.DATA.train.take = -1 # Number of samples used for training
_C.DATA.train.num_workers = 4 # Number of workers to load the data
_C.DATA.train.shuffle = False # Whether to shuffle the input data
# _C.DATA.train.in_memory = False # Load the training data into memory
# The test dataset starts as an independent copy of the train settings above;
# only the number of loader workers is overridden here.
_C.DATA.test = _C.DATA.train.clone()
_C.DATA.test.num_workers = 2
# MODEL related parameters
_C.MODEL = CN(new_allowed=True)
_C.MODEL.name = '' # The name of the model
_C.MODEL.feature = 'ND' # The input features
_C.MODEL.channel = 3 # Number of input feature channels
_C.MODEL.nempty = False # Perform octree convolution on non-empty octree nodes only
_C.MODEL.sync_bn = False # Use synchronized batch normalization when training the network
_C.MODEL.use_checkpoint = False # Use (presumably gradient) checkpointing to save memory
_C.MODEL.find_unused_parameters = False # Used in DistributedDataParallel
# LOSS related parameters
_C.LOSS = CN(new_allowed=True)
_C.LOSS.name = '' # The name of the loss
# _C.LOSS.num_class = 40 # The class number for the cross-entropy loss
# _C.LOSS.label_smoothing = 0.0 # The factor of label smoothing
# SYS: bookkeeping values filled in while parsing the command line
_C.SYS = CN(new_allowed=True)
_C.SYS.cmds = '' # The launching command line, recorded by _update_config
# The global config object: returned by get_config() and merged/frozen in place
# by parse_args().
FLAGS = _C
def _load_from_file(filename):
    r''' Loads a config file together with all config files listed in ``BASE``.

    Config files are visited breadth-first, starting from :attr:`filename`
    and following each file's ``BASE`` list. They are returned in reverse
    visiting order. Merging them one after another therefore lets the
    experiment file override its bases. Among siblings in one ``BASE`` list,
    an earlier entry ends up later in the merge order and takes precedence.

    Each file is loaded at most once, keyed by its absolute path. A file
    reachable through several ``BASE`` chains keeps only its first visit,
    which yields the same merged values as loading it repeatedly. A cyclic
    reference such as ``a.yaml -> b.yaml -> a.yaml``, or a file listing
    itself, previously made this loop forever.

    Args:
      filename (str): The path to the config file. Entries of ``BASE`` are
          opened as given, i.e. relative to the current working directory.
          Empty entries are skipped.

    Returns:
      list: A list of :class:`yacs.config.CfgNode` objects in merge order.
    '''

    cfgs = []
    visited = set()   # absolute paths already loaded
    pending = [filename]
    while pending:
        path = pending.pop(0)
        if not path:  # e.g. the default ``BASE = ['']``
            continue

        key = os.path.abspath(path)
        if key in visited:  # diamond inheritance or a BASE cycle
            continue
        visited.add(key)

        with open(path, 'r') as fid:
            cfg = CN.load_cfg(fid)
        cfgs.append(cfg)

        if 'BASE' in cfg:
            assert isinstance(cfg.BASE, list), 'BASE should be a list'
            pending.extend(cfg.BASE)

    cfgs.reverse()
    return cfgs
def _update_config(FLAGS, args):
    r''' Updates :attr:`FLAGS` from the parsed command-line arguments.

    The file given by ``--config`` is merged first, together with its
    ``BASE`` files. The trailing command-line options are merged after it,
    so they take precedence. The launching command line is recorded in
    ``FLAGS.SYS.cmds``. The lower-cased ``SOLVER.alias`` is appended to
    ``SOLVER.logdir``, with any ``time`` in it replaced by a timestamp.

    Args:
      FLAGS (CfgNode): The global config tree; it is defrosted while being
          updated and frozen again before returning.
      args (argparse.Namespace): The parsed command-line arguments; only
          ``args.config`` and ``args.opts`` are read.
    '''

    FLAGS.defrost()
    if args.config:
        # Configs come back ordered base-first, so the experiment file wins.
        for cfg in _load_from_file(args.config):
            FLAGS.merge_from_other_cfg(cfg)
    if args.opts:
        FLAGS.merge_from_list(args.opts)

    # Shell-quote every argument so the recorded command can be replayed
    # verbatim, e.g. values such as '(0,1)' or strings containing spaces.
    FLAGS.SYS.cmds = 'python ' + ' '.join(shlex.quote(a) for a in sys.argv)

    # update logdir
    alias = FLAGS.SOLVER.alias.lower()
    if 'time' in alias:  # 'time' is a special keyword -> month/day/hour/minute
        alias = alias.replace('time', datetime.now().strftime('%m%d%H%M'))
    if alias != '':
        FLAGS.SOLVER.logdir += '_' + alias
    FLAGS.freeze()
def _backup_config(FLAGS, args):
    r''' Saves the experiment configuration into the logging directory.

    Creates ``FLAGS.SOLVER.logdir`` if it does not exist yet. If a config
    file was passed via ``--config``, it is copied there with its metadata.
    The fully merged tree is then written to ``all_configs.yaml``.

    Args:
      FLAGS (CfgNode): The final merged config tree.
      args (argparse.Namespace): The parsed command-line arguments.
    '''

    target_dir = FLAGS.SOLVER.logdir
    os.makedirs(target_dir, exist_ok=True)

    source_cfg = args.config
    if source_cfg:
        shutil.copy2(source_cfg, target_dir)

    dump_path = os.path.join(target_dir, 'all_configs.yaml')
    with open(dump_path, 'w') as out:
        out.write(FLAGS.dump())
def _set_env_var(FLAGS):
    r''' Exports ``FLAGS.SOLVER.gpu`` as the ``CUDA_VISIBLE_DEVICES`` variable.

    The ids are converted to strings and joined with commas, e.g. ``(0, 2)``
    becomes ``'0,2'``. The result is stored in this process's environment.

    Args:
      FLAGS (CfgNode): The config tree containing ``SOLVER.gpu``.
    '''

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, FLAGS.SOLVER.gpu))
def get_config():
    r''' Returns the global config tree :data:`FLAGS`.

    This is the module-level object itself, not a copy. :func:`parse_args`
    merges into and freezes this same object, so every caller sees the same
    config.
    '''
    return FLAGS
def parse_args(backup=True):
    r''' Parses command-line arguments and returns the merged config.

    Recognized arguments:

    * ``--config FILE``: the experiment config file, merged together with
      the files listed in its ``BASE`` entry.
    * ``--local-rank`` / ``--local_rank``: accepted for distributed
      launchers; parsed but not otherwise used in this function.
    * all remaining tokens: config overrides handed to
      ``CfgNode.merge_from_list``, e.g. ``SOLVER.lr 0.05``.

    Args:
      backup (bool): If True, saves the experiment config into the log
          directory.

    Returns:
      CfgNode: The final merged and frozen config tree (the global FLAGS).
    '''

    parser = argparse.ArgumentParser(description='The configs')
    parser.add_argument('--config', type=str,
                        help='experiment configure file name')
    parser.add_argument('--local-rank', '--local_rank', type=int, default=0,
                        help='local rank for distributed training')
    parser.add_argument('opts', nargs=argparse.REMAINDER,
                        help="Modify config options using the command-line")

    args = parser.parse_args()
    _update_config(FLAGS, args)
    if backup:
        _backup_config(FLAGS, args)
    # _set_env_var(FLAGS)
    return FLAGS
if __name__ == '__main__':
    # Debug entry point: print the merged config without writing any backup
    # into the log directory.
    print(parse_args(backup=False))