Skip to content

Commit

Permalink
Merge pull request #111 from LoSealL/dev
Browse files Browse the repository at this point in the history
Fix the evaluation (eval.py) to benchmark dataset
  • Loading branch information
LoSealL authored May 29, 2020
2 parents 5c6d4c0 + 1b4eea3 commit 16b38ee
Show file tree
Hide file tree
Showing 11 changed files with 35 additions and 17 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ seafile-ignore.txt
.idea/
.vscode/
.pytest_cache/
.vsr/
__pycache__/

# Install info.
Expand Down
18 changes: 10 additions & 8 deletions Tools/NtireHelper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

def main():
parser = argparse.ArgumentParser(
description=r"""NTIRE Data helpers for NTIRE 2019
description=r"""NTIRE Data helpers for NTIRE 2019
Containing tasks for:
- [Real Image Super-Resolution](https://competitions.codalab.org/competitions/21439)
- [Real Image Denoising (sRGB)](https://competitions.codalab.org/competitions/21266)
Expand Down Expand Up @@ -149,7 +149,7 @@ def divide(flags):
def _divide(img: Image, stride: int, size: int) -> list:
w = img.width
h = img.height
img = img_to_array(img)
img = img_to_array(img, data_format='channels_last')
patches = []
img = np.pad(img, [[0, size - h % stride or stride],
[0, size - w % stride or stride], [0, 0]],
Expand All @@ -167,7 +167,8 @@ def _divide(img: Image, stride: int, size: int) -> list:
for f in tqdm.tqdm(files, ascii=True):
pf = _divide(Image.open(f), flags.stride, flags.patch)
for i, p in enumerate(pf):
array_to_img(p, 'RGB').save(f"{save_dir}/{f.stem}_{i:04d}.png")
array_to_img(p, 'RGB', data_format='channels_last').save(
f"{save_dir}/{f.stem}_{i:04d}.png")


def combine(flags):
Expand All @@ -191,21 +192,22 @@ def _combine(ref: Image, sub: list, stride) -> Image:
p = sub[k]
k += 1
try:
blank[i:i + p.height, j:j + p.width] += img_to_array(p)
blank[i:i + p.height, j:j + p.width] += img_to_array(
p, 'channels_last')
except ValueError:
blank[i:i + p.height, j:j + p.width] += img_to_array(p)[:h - i,
:w - j]
blank[i:i + p.height, j:j + p.width] += img_to_array(
p, 'channels_last')[:h - i, :w - j]
count[i:i + p.height, j:j + p.width] += 1
blank /= count
return array_to_img(blank, 'RGB')
return array_to_img(blank, 'RGB', 'channels_last')

save_dir = Path(flags.save_dir)
save_dir.mkdir(exist_ok=True, parents=True)
files = sorted(Path(flags.ref).glob("*.png"))
print(" [!] Combining...\n")
results = Path(flags.input_dir)
for f in tqdm.tqdm(files, ascii=True):
sub = list(results.glob("{}_????.png".format(f.stem)))
sub = list(results.glob("{}_*.png".format(f.stem)))
sub.sort(key=lambda x: int(x.stem[-4:]))
sub = [Image.open(s) for s in sub]
img = _combine(Image.open(f), sub, flags.stride)
Expand Down
2 changes: 1 addition & 1 deletion VSR/Backend/Torch/Models/Carn.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@ def export(self, export_dir):
# Sounds stupid to set a 48x48 inputs.

device = list(self.carn.parameters())[0].device
inputs = torch.randn(1, self.channel, 48, 48, device=device)
inputs = torch.randn(1, self.channel, 144, 128, device=device)
scale = torch.tensor(self.scale, device=device)
torch.onnx.export(self.carn, (inputs, scale), export_dir / 'carn.onnx')
5 changes: 5 additions & 0 deletions VSR/Backend/Torch/Models/Dbpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,8 @@ def eval(self, inputs, labels=None, **kwargs):
if labels is not None:
metrics['psnr'] = Metrics.psnr(sr.numpy(), labels[0].cpu().numpy())
return [sr.numpy()], metrics

def export(self, export_dir):
device = list(self.body.parameters())[0].device
inputs = torch.randn(1, self.channel, 144, 128, device=device)
torch.onnx.export(self.body, (inputs,), export_dir / 'dbpn.onnx')
5 changes: 5 additions & 0 deletions VSR/Backend/Torch/Models/Esrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,8 @@ def eval(self, inputs, labels=None, **kwargs):
writer.image('lr', inputs[0], step=step)
writer.image('hr', labels[0], step=step)
return [sr.numpy()], metrics

def export(self, export_dir):
device = list(self.rrdb.parameters())[0].device
inputs = torch.randn(1, self.channel, 144, 128, device=device)
torch.onnx.export(self.rrdb, (inputs,), export_dir / 'rrdb.onnx')
Empty file.
2 changes: 1 addition & 1 deletion VSR/Backend/Torch/Models/carn/carn.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(self, **kwargs):
group=group)
self.exit = nn.Conv2d(64, 3, 3, 1, 1)

def forward(self, x, scale):
def forward(self, x, scale=None):
x = self.sub_mean(x)
x = self.entry(x)
c0 = o0 = x
Expand Down
2 changes: 1 addition & 1 deletion VSR/Backend/Torch/Models/carn/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def __init__(self,

self.multi_scale = multi_scale

def forward(self, x, scale):
def forward(self, x, scale=None):
if self.multi_scale:
if scale == 2:
return self.up2(x)
Expand Down
7 changes: 4 additions & 3 deletions VSR/DataLoader/Loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ def __init__(self, loader, shape, steps, shuffle=None):
temporal_padding = not shuffle
self.index = []
for i in range(t):
idx_ = [(i, np.array([j + x for x in range(self.depth)])) for j in
range(-(self.depth // 2), frame_nums[i] - (self.depth // 2))]
d2_ = self.depth // 2
depth = self.depth if self.depth >= 0 else len(self.loader.data['hr'][i])
idx_ = [(i, np.array([j + x for x in range(depth)])) for j in
range(-(depth // 2), frame_nums[i] - (depth // 2))]
d2_ = depth // 2
self.index += idx_ if temporal_padding or d2_ == 0 else idx_[d2_ : -d2_]
self.steps = steps if steps >= 0 else len(self.index) // shape[0]
while len(self.index) < self.steps * shape[0] and self.index:
Expand Down
9 changes: 6 additions & 3 deletions VSR/Util/ImageProcess.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,23 @@
from ..Backend import DATA_FORMAT


def array_to_img(x: np.ndarray, mode='RGB', min_val=0, max_val=255):
def array_to_img(x: np.ndarray, mode='RGB', data_format=None, min_val=0,
max_val=255):
"""Convert an ndarray to PIL Image."""

x = np.squeeze(x).astype('float32')
x = (x - min_val) / (max_val - min_val)
x = x.clip(0, 1) * 255
if data_format not in ('channels_first', 'channels_last'):
data_format = DATA_FORMAT
if np.ndim(x) == 2:
return Image.fromarray(x.astype('uint8'), mode='L').convert(mode)
elif np.ndim(x) == 3:
if DATA_FORMAT == 'channels_first':
if data_format == 'channels_first':
x = x.transpose([1, 2, 0])
return Image.fromarray(x.astype('uint8'), mode=mode)
elif np.ndim(x) == 4:
if DATA_FORMAT == 'channels_first':
if data_format == 'channels_first':
x = x.transpose([0, 2, 3, 1])
ret = [Image.fromarray(np.round(i).astype('uint8'), mode=mode) for i in x]
return ret.pop() if len(ret) is 1 else ret
Expand Down
1 change: 1 addition & 0 deletions prepare_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
'Tspmc.zip': 'https://github.com/LoSealL/Model/releases/download/spmc/spmc.zip',
'Tvespcn.zip': 'https://github.com/LoSealL/Model/releases/download/vespcn/Tvespcn.zip',
'Tsrmd.zip': '1ORKH05-aLSbQaWB4qQulIm2INoRufuD_',
'Tdbpn.zip': '1PbhtuMz1zF3-d16dthurJ0xIQ9uyMvkz'
}


Expand Down

0 comments on commit 16b38ee

Please sign in to comment.