Set img_size to padded input size
HolyWu committed Apr 23, 2023
1 parent 2825f26 commit 240369a
Showing 2 changed files with 41 additions and 54 deletions.
93 changes: 40 additions & 53 deletions vsgrlir/__init__.py
@@ -104,13 +104,10 @@ def grlir(
stream = [torch.cuda.Stream(device=device) for _ in range(num_streams)]
stream_lock = [Lock() for _ in range(num_streams)]

scale = 1

match model:
case 0:
model_name = "bsr_grl_base.ckpt"
module = GRL(
img_size=128,
model_args = dict(
embed_dim=180,
upscale=4,
upsampler="nearest+conv",
@@ -125,13 +122,11 @@ def grlir(
local_connection=True,
fp16=fp16,
)
scale = 4
tile_pad = fallback(tile_pad, 16)
pad_size = 64
case 1:
model_name = "db_defocus_single_pixel_grl_base.ckpt"
module = GRL(
img_size=480,
model_args = dict(
embed_dim=180,
upscale=1,
upsampler="",
@@ -150,8 +145,7 @@ def grlir(
pad_size = 96
case 2:
model_name = "db_motion_grl_base_gopro.ckpt"
module = GRL(
img_size=480,
model_args = dict(
embed_dim=180,
upscale=1,
upsampler="",
@@ -170,8 +164,7 @@ def grlir(
pad_size = 96
case 3:
model_name = "db_motion_grl_base_realblur_j.ckpt"
module = GRL(
img_size=480,
model_args = dict(
embed_dim=180,
upscale=1,
upsampler="",
@@ -190,8 +183,7 @@ def grlir(
pad_size = 96
case 4:
model_name = "db_motion_grl_base_realblur_r.ckpt"
module = GRL(
img_size=480,
model_args = dict(
embed_dim=180,
upscale=1,
upsampler="",
@@ -210,8 +202,7 @@ def grlir(
pad_size = 96
case 5:
model_name = "dm_grl_small.ckpt"
module = GRL(
img_size=64,
model_args = dict(
embed_dim=128,
upscale=1,
upsampler="",
@@ -230,8 +221,7 @@ def grlir(
pad_size = 32
case 6:
model_name = "dn_grl_small_c3s15.ckpt"
module = GRL(
img_size=256,
model_args = dict(
embed_dim=128,
upscale=1,
upsampler="",
@@ -250,8 +240,7 @@ def grlir(
pad_size = 128
case 7:
model_name = "dn_grl_small_c3s25.ckpt"
module = GRL(
img_size=256,
model_args = dict(
embed_dim=128,
upscale=1,
upsampler="",
@@ -270,8 +259,7 @@ def grlir(
pad_size = 128
case 8:
model_name = "dn_grl_small_c3s50.ckpt"
module = GRL(
img_size=256,
model_args = dict(
embed_dim=128,
upscale=1,
upsampler="",
@@ -290,8 +278,7 @@ def grlir(
pad_size = 128
case 9:
model_name = "jpeg_grl_small_c3q10.ckpt"
module = GRL(
img_size=288,
model_args = dict(
embed_dim=128,
upscale=1,
upsampler="",
@@ -310,8 +297,7 @@ def grlir(
pad_size = 144
case 10:
model_name = "jpeg_grl_small_c3q20.ckpt"
module = GRL(
img_size=288,
model_args = dict(
embed_dim=128,
upscale=1,
upsampler="",
@@ -330,8 +316,7 @@ def grlir(
pad_size = 144
case 11:
model_name = "jpeg_grl_small_c3q30.ckpt"
module = GRL(
img_size=288,
model_args = dict(
embed_dim=128,
upscale=1,
upsampler="",
@@ -350,8 +335,7 @@ def grlir(
pad_size = 144
case 12:
model_name = "jpeg_grl_small_c3q40.ckpt"
module = GRL(
img_size=288,
model_args = dict(
embed_dim=128,
upscale=1,
upsampler="",
@@ -370,8 +354,7 @@ def grlir(
pad_size = 144
case 13:
model_name = "sr_grl_small_c3x2.ckpt"
module = GRL(
img_size=256,
model_args = dict(
embed_dim=128,
upscale=2,
upsampler="pixelshuffle",
@@ -386,13 +369,11 @@ def grlir(
local_connection=False,
fp16=fp16,
)
scale = 2
tile_pad = fallback(tile_pad, 32)
pad_size = 64
case 14:
model_name = "sr_grl_small_c3x3.ckpt"
module = GRL(
img_size=256,
model_args = dict(
embed_dim=128,
upscale=3,
upsampler="pixelshuffle",
@@ -407,13 +388,11 @@ def grlir(
local_connection=False,
fp16=fp16,
)
scale = 3
tile_pad = fallback(tile_pad, 32)
pad_size = 64
case 15:
model_name = "sr_grl_small_c3x4.ckpt"
module = GRL(
img_size=256,
model_args = dict(
embed_dim=128,
upscale=4,
upsampler="pixelshuffle",
@@ -428,31 +407,35 @@ def grlir(
local_connection=False,
fp16=fp16,
)
scale = 4
tile_pad = fallback(tile_pad, 32)
pad_size = 64

model_path = os.path.join(model_dir, model_name)
if tile_w > 0 and tile_h > 0:
pad_w = math.ceil(min(tile_w + 2 * tile_pad, clip.width) / pad_size) * pad_size
pad_h = math.ceil(min(tile_h + 2 * tile_pad, clip.height) / pad_size) * pad_size
else:
pad_w = math.ceil(clip.width / pad_size) * pad_size
pad_h = math.ceil(clip.height / pad_size) * pad_size

model_args |= dict(img_size=(pad_h, pad_w))

state_dict = torch.load(model_path, map_location="cpu")
state_dict = torch.load(os.path.join(model_dir, model_name), map_location="cpu")
if "state_dict" in state_dict:
state_dict = state_dict["state_dict"]
state_dict = {k.replace("model_g.", ""): v for k, v in state_dict.items() if "model_g." in k}
state_dict = {
k.removeprefix("model_g."): v
for k, v in state_dict.items()
if "model_g." in k and "table_" not in k and "index_" not in k and "mask_" not in k
}
else:
state_dict = {k.replace("model.", ""): v for k, v in state_dict.items() if "model." in k}
state_dict = {k.removeprefix("model."): v for k, v in state_dict.items() if "model." in k}

module.load_state_dict(state_dict, strict=False)
module = GRL(**model_args)
module.load_state_dict(state_dict)
module.eval().to(device, memory_format=torch.channels_last)
if fp16:
module.half()

if tile_w > 0 and tile_h > 0:
pad_w = math.ceil(min(tile_w + 2 * tile_pad, clip.width) / pad_size) * pad_size
pad_h = math.ceil(min(tile_h + 2 * tile_pad, clip.height) / pad_size) * pad_size
else:
pad_w = math.ceil(clip.width / pad_size) * pad_size
pad_h = math.ceil(clip.height / pad_size) * pad_size

if cuda_graphs:
graph: list[torch.cuda.CUDAGraph] = []
static_input: list[torch.Tensor] = []
@@ -492,7 +475,9 @@ def inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame:
img = frame_to_tensor(f[0], device)

if tile_w > 0 and tile_h > 0:
output = tile_process(img, scale, tile_w, tile_h, tile_pad, pad_w, pad_h, backend, local_index)
output = tile_process(
img, model_args["upscale"], tile_w, tile_h, tile_pad, pad_w, pad_h, backend, local_index
)
else:
h, w = img.shape[2:]
img = F.pad(img, (0, pad_w - w, 0, pad_h - h), "reflect")
@@ -504,11 +489,13 @@ def inference(n: int, f: list[vs.VideoFrame]) -> vs.VideoFrame:
else:
output = module(img)

output = output[:, :, : h * scale, : w * scale]
output = output[:, :, : h * model_args["upscale"], : w * model_args["upscale"]]

return tensor_to_frame(output, f[1].copy())

new_clip = clip.std.BlankClip(width=clip.width * scale, height=clip.height * scale, keep=True)
new_clip = clip.std.BlankClip(
width=clip.width * model_args["upscale"], height=clip.height * model_args["upscale"], keep=True
)
return new_clip.std.FrameEval(
lambda n: new_clip.std.ModifyFrame([clip, new_clip], inference), clip_src=[clip, new_clip]
)
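In short, `img_size` is no longer hard-coded per model but derived from the padded input (or padded tile), and the removed `scale` variable is read back from `model_args["upscale"]`. Below is a minimal sketch of that flow, not the exact plugin code; the `GRL` import path is assumed, and `clip`, `tile_w`, `tile_h`, `tile_pad`, `pad_size` and `fp16` are taken to behave as in the diff above.

```python
# Illustration only: how the padded size is derived and fed to the model.
import math

import torch

from vsgrlir.grl import GRL  # assumed import path


def build_module(clip, model_args: dict, pad_size: int, tile_w: int, tile_h: int,
                 tile_pad: int, device: torch.device, fp16: bool):
    # Round the working size up to a multiple of pad_size; with tiling enabled
    # the working size is the padded tile rather than the whole frame.
    if tile_w > 0 and tile_h > 0:
        pad_w = math.ceil(min(tile_w + 2 * tile_pad, clip.width) / pad_size) * pad_size
        pad_h = math.ceil(min(tile_h + 2 * tile_pad, clip.height) / pad_size) * pad_size
    else:
        pad_w = math.ceil(clip.width / pad_size) * pad_size
        pad_h = math.ceil(clip.height / pad_size) * pad_size

    # img_size now follows the padded input instead of a fixed training
    # resolution, so the model is built for the shape it will actually see.
    model_args |= dict(img_size=(pad_h, pad_w))

    module = GRL(**model_args)
    module.eval().to(device, memory_format=torch.channels_last)
    if fp16:
        module.half()

    # The separate `scale` variable is gone; callers read model_args["upscale"].
    return module, pad_w, pad_h, model_args["upscale"]
```

Because `img_size` is fixed at construction time, the padding computation has to run before `GRL(**model_args)` is instantiated, which is why the diff moves that block above the model creation.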
2 changes: 1 addition & 1 deletion vsgrlir/grl.py
@@ -358,7 +358,7 @@ def __init__(
}
)
for k, v in self.set_table_index_mask(self.input_resolution).items():
self.register_buffer(k, v)
self.register_buffer(k, v, persistent=False)

self.layers = nn.ModuleList()
for i in range(len(depths)):
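The `persistent=False` flag keeps the precomputed table/index/mask buffers out of `state_dict()`, which is why the loader in `__init__.py` can drop `table_`, `index_` and `mask_` keys from older checkpoints and call `load_state_dict` without `strict=False`. A toy sketch of that behaviour follows; the module and buffer names are hypothetical, not the real GRL internals.

```python
import torch
import torch.nn as nn


class TinyAttnBlock(nn.Module):
    """Toy module with a shape-dependent lookup table (hypothetical, not GRL)."""

    def __init__(self, window: int) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)
        # persistent=False: the buffer is rebuilt from `window` at construction
        # time and never serialized, so it is absent from state_dict().
        self.register_buffer("table_bias", torch.zeros(window, window), persistent=False)


m = TinyAttnBlock(window=4)
assert "table_bias" not in m.state_dict()  # non-persistent buffer is not saved
assert "proj.weight" in m.state_dict()     # learned weights still are

# Loading an old checkpoint that still carries such buffers: strip them first,
# mirroring the key filtering added in vsgrlir/__init__.py.
old_ckpt = {
    "model_g.proj.weight": torch.zeros(8, 8),
    "model_g.proj.bias": torch.zeros(8),
    "model_g.table_bias": torch.zeros(4, 4),
}
state_dict = {
    k.removeprefix("model_g."): v
    for k, v in old_ckpt.items()
    if "model_g." in k and "table_" not in k and "index_" not in k and "mask_" not in k
}
m.load_state_dict(state_dict)  # strict load now succeeds
```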
