Showing 75 changed files with 122,422 additions and 1 deletion.
@@ -0,0 +1 @@
1.2.0
@@ -0,0 +1,135 @@
# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------

import importlib
import numpy as np
import random
import torch
import torch.utils.data
from functools import partial
from os import path as osp

from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info

__all__ = ['create_dataset', 'create_dataloader']

# automatically scan and import dataset modules
# scan all files under the data folder whose names end with '_dataset.py'
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = [
    osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
    if v.endswith('_dataset.py')
]
# import all the dataset modules so their classes can be found by name
_dataset_modules = [
    importlib.import_module(f'basicsr.data.{file_name}')
    for file_name in dataset_filenames
]
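The scanner above only picks up files whose names end in '_dataset.py', so a dataset class must live in such a module to be discoverable. A minimal sketch of what such a module could look like (the file name toy_dataset.py and the class ToyDataset are hypothetical, not part of this commit):

# toy_dataset.py, placed under basicsr/data/ (hypothetical example)
import torch.utils.data


class ToyDataset(torch.utils.data.Dataset):
    """Minimal dataset; dataset_opt is the options dict from the config."""

    def __init__(self, dataset_opt):
        self.opt = dataset_opt
        self.length = dataset_opt.get('length', 8)

    def __getitem__(self, index):
        # a real dataset would load and augment image pairs here
        return {'lq': index, 'gt': index}

    def __len__(self):
        return self.length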

def create_dataset(dataset_opt):
    """Create dataset.

    Args:
        dataset_opt (dict): Configuration for dataset. It contains:
            name (str): Dataset name.
            type (str): Dataset type.
    """
    dataset_type = dataset_opt['type']

    # dynamic instantiation: look up the class named by `type` in the
    # auto-imported dataset modules
    for module in _dataset_modules:
        dataset_cls = getattr(module, dataset_type, None)
        if dataset_cls is not None:
            break
    if dataset_cls is None:
        raise ValueError(f'Dataset {dataset_type} is not found.')

    dataset = dataset_cls(dataset_opt)

    logger = get_root_logger()
    logger.info(
        f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
        'is created.')
    return dataset
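With the modules imported, create_dataset resolves the class purely by name. A usage sketch, assuming the hypothetical ToyDataset module above and made-up option values (real configs come from parsed YAML files):

dataset_opt = {
    'name': 'demo',
    'type': 'ToyDataset',  # must match a class in some *_dataset.py module
    'phase': 'train',
    'length': 8,
}
train_set = create_dataset(dataset_opt)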

def create_dataloader(dataset,
                      dataset_opt,
                      num_gpu=1,
                      dist=False,
                      sampler=None,
                      seed=None):
    """Create dataloader.

    Args:
        dataset (torch.utils.data.Dataset): Dataset.
        dataset_opt (dict): Dataset options. It contains the following keys:
            phase (str): 'train' or 'val'.
            num_worker_per_gpu (int): Number of workers for each GPU.
            batch_size_per_gpu (int): Training batch size for each GPU.
        num_gpu (int): Number of GPUs. Used only in the train phase.
            Default: 1.
        dist (bool): Whether in distributed training. Used only in the train
            phase. Default: False.
        sampler (torch.utils.data.sampler): Data sampler. Default: None.
        seed (int | None): Seed. Default: None.
    """
    phase = dataset_opt['phase']
    rank, _ = get_dist_info()
    if phase == 'train':
        if dist:  # distributed training
            batch_size = dataset_opt['batch_size_per_gpu']
            num_workers = dataset_opt['num_worker_per_gpu']
        else:  # non-distributed training
            multiplier = 1 if num_gpu == 0 else num_gpu
            batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
            num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
        dataloader_args = dict(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            sampler=sampler,
            drop_last=True,
            persistent_workers=True)
        if sampler is None:
            dataloader_args['shuffle'] = True
        dataloader_args['worker_init_fn'] = partial(
            worker_init_fn, num_workers=num_workers, rank=rank,
            seed=seed) if seed is not None else None
    elif phase in ['val', 'test']:  # validation
        dataloader_args = dict(
            dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
    else:
        raise ValueError(f'Wrong dataset phase: {phase}. '
                         "Supported ones are 'train', 'val' and 'test'.")

    dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)

    prefetch_mode = dataset_opt.get('prefetch_mode')
    if prefetch_mode == 'cpu':  # CPUPrefetcher
        num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
        logger = get_root_logger()
        logger.info(f'Use {prefetch_mode} prefetch dataloader: '
                    f'num_prefetch_queue = {num_prefetch_queue}')
        return PrefetchDataLoader(
            num_prefetch_queue=num_prefetch_queue, **dataloader_args)
    else:
        # prefetch_mode=None: normal dataloader
        # prefetch_mode='cuda': dataloader for CUDAPrefetcher
        return torch.utils.data.DataLoader(**dataloader_args)
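A sketch of the usual train-phase wiring, reusing the hypothetical train_set from above (all option values are made up for illustration):

train_loader = create_dataloader(
    train_set,
    dataset_opt={'phase': 'train', 'batch_size_per_gpu': 4,
                 'num_worker_per_gpu': 2, 'pin_memory': True},
    num_gpu=1,
    dist=False,
    sampler=None,  # shuffle is enabled automatically when no sampler is given
    seed=10)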

def worker_init_fn(worker_id, num_workers, rank, seed):
    # Set the worker seed to num_workers * rank + worker_id + seed
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
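As a concrete check of the seeding rule, with made-up values num_workers=4, rank=1, worker_id=2 and seed=10, the worker seed is 4 * 1 + 2 + 10 = 16, so every (rank, worker) pair gets a distinct but reproducible seed for numpy and random.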
@@ -0,0 +1,56 @@
# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------

import math
import torch
from torch.utils.data.sampler import Sampler


class EnlargedSampler(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    Modified from torch.utils.data.distributed.DistributedSampler.
    Supports enlarging the dataset for iteration-based training, which saves
    the time of restarting the dataloader after each epoch.

    Args:
        dataset (torch.utils.data.Dataset): Dataset used for sampling.
        num_replicas (int | None): Number of processes participating in
            the training. It is usually the world_size.
        rank (int | None): Rank of the current process within num_replicas.
        ratio (int): Enlarging ratio. Default: 1.
    """

    def __init__(self, dataset, num_replicas, rank, ratio=1):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = math.ceil(
            len(self.dataset) * ratio / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(self.total_size, generator=g).tolist()

        # wrap the enlarged indices back into the real dataset range
        dataset_size = len(self.dataset)
        indices = [v % dataset_size for v in indices]

        # subsample: each rank takes every num_replicas-th index
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
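To make the enlarging arithmetic concrete with made-up numbers: for a 100-sample dataset with ratio=2 and num_replicas=4, each rank draws num_samples = ceil(100 * 2 / 4) = 50 indices, total_size is 200, and the modulo wrap maps them back into [0, 100), so one enlarged "epoch" covers the data about twice across all ranks without restarting the dataloader. A usage sketch under those assumptions, reusing the hypothetical train_set from the previous file:

# hypothetical distributed setup: 4 processes, this one is rank 0
sampler = EnlargedSampler(train_set, num_replicas=4, rank=0, ratio=2)
train_loader = create_dataloader(
    train_set,
    dataset_opt={'phase': 'train', 'batch_size_per_gpu': 4,
                 'num_worker_per_gpu': 2},
    num_gpu=1,
    dist=True,
    sampler=sampler,
    seed=10)

for epoch in range(3):
    sampler.set_epoch(epoch)  # reshuffles deterministically each epoch
    for batch in train_loader:
        pass  # training step goes here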