Source code for thsolver.sampler

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import torch
from torch.utils.data import Sampler, DistributedSampler, Dataset


class InfSampler(Sampler):
    r''' An infinite sampler that cycles through dataset indices.

    Each cycle visits every index of the dataset exactly once. As soon as a
    cycle is exhausted a new index order is built, so for a non-empty dataset
    the iterator never terminates by itself.

    Args:
      dataset (Dataset): The dataset to sample from. Only its length is used.
      shuffle (bool): If True, shuffles the sample order between cycles;
          otherwise the indices are visited in ascending order.
    '''

    def __init__(self, dataset: Dataset, shuffle: bool = True) -> None:
        r''' Initializes the sampler and builds the first index order. '''
        self.dataset = dataset
        self.shuffle = shuffle
        self.reset_sampler()

    def reset_sampler(self) -> None:
        r''' Resets the sampled index order for the next pass over the dataset.

        Draws a fresh random permutation when :attr:`shuffle` is True, and
        rewinds the read position to the start of the new order.
        '''
        num = len(self.dataset)
        order = torch.randperm(num) if self.shuffle else torch.arange(num)
        # Plain Python ints are what samplers are expected to yield.
        self.indices = order.tolist()
        self.iter_num = 0

    def __iter__(self):
        r''' Returns the sampler itself, which is its own iterator. '''
        return self

    def __next__(self) -> int:
        r''' Returns the next sample index and restarts when exhausted.

        Raises:
          StopIteration: If the dataset is empty, since there is no index that
              could be yielded.
        '''
        # An empty dataset has nothing to cycle through; end the iteration
        # cleanly instead of failing with an IndexError on the lookup below.
        if not self.indices:
            raise StopIteration

        value = self.indices[self.iter_num]
        self.iter_num += 1

        # Prepare the next cycle eagerly so that the following call always
        # finds a valid read position.
        if self.iter_num >= len(self.indices):
            self.reset_sampler()
        return value

    def __len__(self) -> int:
        r''' Returns the dataset size, i.e. the number of indices per cycle. '''
        return len(self.dataset)
class DistributedInfSampler(DistributedSampler):
    r''' An infinite sampler for distributed data parallel training.

    The indices assigned to the current rank are taken from
    :class:`DistributedSampler`. Once they are exhausted the epoch of the base
    sampler is advanced and the rank-local indices are rebuilt, so iteration
    never ends for a non-empty dataset.

    Args:
      dataset (Dataset): The dataset to sample from.
      shuffle (bool): If True, shuffles the sample order between cycles.
    '''

    def __init__(self, dataset: Dataset, shuffle: bool = True) -> None:
        r''' Initializes the distributed sampler and builds the first cycle. '''
        super().__init__(dataset, shuffle=shuffle)
        self.reset_sampler()

    def reset_sampler(self) -> None:
        r''' Rebuilds the local rank indices from the base distributed sampler.

        The order is the one the base sampler produces for the current epoch,
        and the read position is rewound to the start of it.
        '''
        self.indices = list(super().__iter__())
        self.iter_num = 0

    def __iter__(self):
        r''' Returns the sampler itself, which is its own iterator. '''
        return self

    def __next__(self) -> int:
        r''' Returns the next rank-local sample index and restarts when exhausted.

        Raises:
          StopIteration: If no index is assigned to this rank, since there is
              nothing that could be yielded.
        '''
        # Nothing to cycle through; end the iteration cleanly instead of
        # failing with an IndexError on the lookup below.
        if not self.indices:
            raise StopIteration

        value = self.indices[self.iter_num]
        self.iter_num += 1

        if self.iter_num >= len(self.indices):
            # The base sampler derives its permutation from `seed + epoch`.
            # Without advancing the epoch every cycle would replay the very
            # same order, i.e. the data would never be reshuffled. All ranks
            # hold the same number of samples, so they advance in lockstep.
            self.set_epoch(self.epoch + 1)
            self.reset_sampler()
        return value