# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
import time
import torch
import torch.distributed
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from tqdm import tqdm
from typing import Dict, Optional
class AverageTracker:
    r''' Tracks and logs averaged scalar tensors across iterations and epochs.

    Every value passed to :meth:`update` is summed per key, together with a
    per-key count, so :meth:`average` returns the running mean of each key.
    The special key ``'time/iter'`` is maintained by :meth:`record_time`.

    Args:
        max_len (int): The maximum number of characters per printed line in
            :meth:`log`; longer messages are wrapped into chunks of this size.
    '''

    def __init__(self, max_len: int = 76):
        r''' Initializes the tracker state and starts the elapsed-time clock. '''
        self.value = dict()  # key -> accumulated sum (tensor or number)
        self.num = dict()    # key -> number of accumulated entries for that key
        self.max_len = max_len
        self.start_time = self.get_time()

    def get_time(self):
        r''' Returns the current wall-clock time in seconds.

        When CUDA is available, pending GPU work is synchronized first so the
        timestamp reflects completed kernels. On CPU-only machines the
        synchronization is skipped, because CUDA cannot be initialized there.
        '''
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.time()

    def update(self, value: Dict[str, torch.Tensor]):
        r''' Accumulates one iteration's values. Called at the end of each
        iteration.

        Args:
            value (dict): Maps a metric name to its value for this iteration.
                Tensors are detached before being accumulated, so no autograd
                graph is kept alive by the tracker. Plain Python numbers are
                accepted as well.
        '''
        for key, val in value.items():
            if isinstance(val, torch.Tensor):
                val = val.detach()
            self.value[key] = self.value.get(key, 0) + val
            self.num[key] = self.num.get(key, 0) + 1

    def record_time(self, num_iters: int = 1):
        r''' Roughly records the elapsed time per iteration.

        ``'time/iter'`` stores the total time elapsed since construction (it
        is overwritten, not summed), while its count accumulates
        ``num_iters``. :meth:`average` therefore reports the elapsed time
        divided by the total number of recorded iterations.

        Args:
            num_iters (int): The number of iterations represented by the update.
        '''
        self.value['time/iter'] = self.get_time() - self.start_time
        self.num['time/iter'] = self.num.get('time/iter', 0) + num_iters

    def average(self):
        r''' Returns a dict mapping each key to its accumulated sum divided by
        its count, as Python floats. '''
        return {key: float(val) / self.num[key] for key, val in self.value.items()}

    @torch.no_grad()
    def average_all_gather(self):
        r''' Averages the accumulated CUDA tensors across all processes using
        ``all_gather``. Called at the end of each epoch.

        Only CUDA tensors are gathered. Python numbers (such as
        ``'time/iter'``) and CPU tensors keep their local values, and the
        per-key counts stay local. When ``torch.distributed`` is unavailable
        or no process group has been initialized, this does nothing.
        '''
        dist = torch.distributed
        if not (dist.is_available() and dist.is_initialized()):
            # Single-process run: there is nothing to average across.
            return
        world_size = dist.get_world_size()
        for key, tensor in self.value.items():
            if isinstance(tensor, torch.Tensor) and tensor.is_cuda:
                gathered = [torch.ones_like(tensor) for _ in range(world_size)]
                dist.all_gather(gathered, tensor, async_op=False)
                # Reassigning an existing key does not resize the dict, so this
                # is safe while iterating over items().
                self.value[key] = torch.mean(torch.stack(gathered, dim=0))

    def log(self, epoch: int, summary_writer: Optional['SummaryWriter'] = None,
            log_file: Optional[str] = None, msg_tag: str = '->', notes: str = '',
            print_time: bool = True, print_memory: bool = False):
        r''' Logs the averaged values to the console, TensorBoard, and a log file.

        Args:
            epoch (int): The current epoch index.
            summary_writer (SummaryWriter or None): The TensorBoard writer. If
                given, each averaged value is added as a scalar at ``epoch``.
            log_file (str or None): The CSV-like log file path. If given, the
                epoch and averaged values are appended as one line.
            msg_tag (str): The prefix printed before the console message.
            notes (str): Extra notes appended to the console message.
            print_time (bool): If True, prints the timestamp and elapsed time.
            print_memory (bool): If True, prints the reserved CUDA memory.
        '''
        avg = self.average()
        msg = 'Epoch: %d' % epoch
        for key, val in avg.items():
            msg += ', %s: %.3f' % (key, val)
            if summary_writer is not None:
                summary_writer.add_scalar(key, val, epoch)

        # The log file receives only the epoch and metric values. Memory,
        # timing and notes below are appended to the console message only.
        if log_file is not None:
            with open(log_file, 'a') as fid:
                fid.write(msg + '\n')

        memory = ''
        if print_memory and torch.cuda.is_available():
            size = torch.cuda.memory_reserved()
            memory = ', memory: {:.3f}GB'.format(size / 2**30)

        time_str = ''
        if print_time:
            curr_time = self.get_time()
            time_str += ', time: ' + datetime.now().strftime("%Y/%m/%d %H:%M:%S")
            time_str += ', duration: {:.2f}s'.format(curr_time - self.start_time)

        if notes:
            notes = ', ' + notes

        msg += memory + time_str + notes
        # Wrap into max_len-sized chunks. Continuation lines are indented so
        # they line up after the msg_tag prefix.
        chunks = [msg[i:i + self.max_len]
                  for i in range(0, len(msg), self.max_len)]
        msg = (msg_tag + ' ') + ('\n' + len(msg_tag) * ' ' + ' ').join(chunks)
        tqdm.write(msg)