# Source code for thsolver.dataset

# --------------------------------------------------------
# Octree-based Sparse Convolutional Neural Networks
# Copyright (c) 2022 Peng-Shuai Wang <wangps@hotmail.com>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------

import os
import torch
import torch.utils.data
import numpy as np
from tqdm import tqdm


def read_file(filename):
    r''' Loads a sample file as raw, undecoded bytes.

    The whole file is read with NumPy and then wrapped as a tensor without
    copying (``torch.from_numpy`` shares the NumPy buffer).

    Args:
      filename (str): Path of the file to load.

    Returns:
      torch.Tensor: A 1-D ``torch.uint8`` tensor holding every byte of the file.
    '''
    raw_bytes = np.fromfile(filename, dtype=np.uint8)
    byte_tensor = torch.from_numpy(raw_bytes)
    return byte_tensor
class Dataset(torch.utils.data.Dataset):
    r''' A lightweight dataset helper based on file lists.

    Each non-blank line of ``filelist`` holds a sample path (relative to
    ``root``) optionally followed by an integer label. A line with exactly
    two tokens uses the second token as its label; any other non-blank line
    gets label 0.

    Args:
      root (str): The dataset root directory.
      filelist (str): The text file listing the samples.
      transform (callable): The callable applied to each loaded sample. It is
          called as ``transform(sample, idx)`` and its result is indexed like
          a dict (``label`` and ``filename`` keys are written into it).
      read_file (callable): The file reader used to load raw samples.
      in_memory (bool): If True, loads all samples into memory at startup.
      take (int): Limits the number of samples used from the file list.
          Values below 1 or larger than the list length mean "use all".
    '''

    def __init__(self, root, filelist, transform, read_file=read_file,
                 in_memory=False, take: int = -1):
        r''' Initializes the dataset helper and parses the file list. '''
        super().__init__()
        self.root = root
        self.filelist = filelist
        self.transform = transform
        self.in_memory = in_memory
        self.read_file = read_file
        self.take = take

        self.filenames, self.labels = self.load_filenames()
        if self.in_memory:
            print('Load files into memory from ' + self.filelist)
            self.samples = [self.read_file(os.path.join(self.root, f))
                            for f in tqdm(self.filenames, ncols=80, leave=False)]

    def __len__(self):
        r''' Returns the number of samples in the dataset. '''
        return len(self.filenames)

    def __getitem__(self, idx):
        r''' Returns one transformed sample.

        Args:
          idx (int): The sample index.

        Returns:
          dict: The transformed sample with ``label`` and ``filename`` attached.
        '''
        if self.in_memory:
            sample = self.samples[idx]
        else:
            sample = self.read_file(os.path.join(self.root, self.filenames[idx]))
        output = self.transform(sample, idx)  # data augmentation + build octree
        output['label'] = self.labels[idx]
        output['filename'] = self.filenames[idx]
        return output

    def load_filenames(self):
        r''' Loads filenames and labels from the file list.

        Blank and whitespace-only lines are skipped. Backslashes in paths are
        converted to forward slashes. As a side effect, ``self.take`` is
        clamped to the number of parsed entries when it is out of range.

        Returns:
          tuple: A pair ``(filenames, labels)`` truncated according to
          :attr:`self.take`.

        Raises:
          ValueError: If a label token cannot be parsed by ``int``.
        '''
        filenames, labels = [], []
        with open(self.filelist) as fid:
            for line in fid:
                tokens = line.split()
                if not tokens:
                    # Previously tokens[0] raised IndexError on blank lines
                    # (e.g. a trailing empty line); skip them instead.
                    continue
                filename = tokens[0].replace('\\', '/')  # normalize Windows separators
                # NOTE(review): lines with more than two tokens also get label 0;
                # kept for compatibility — confirm this is intended.
                label = tokens[1] if len(tokens) == 2 else 0
                filenames.append(filename)
                labels.append(int(label))

        num = len(filenames)
        if self.take > num or self.take < 1:
            self.take = num  # non-positive or oversized take means "use all"
        return filenames[:self.take], labels[:self.take]