Run more ops in FP16 mode
HolyWu committed Apr 23, 2023
1 parent 6890bd7 commit 2825f26
Showing 3 changed files with 98 additions and 7 deletions.
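
For context, a minimal usage sketch of what this commit enables (not part of the commit itself): the diff below threads a user-facing fp16 flag from the grlir() wrapper into the GRL model so that more of the network runs in half precision. Only the fp16 keyword is taken from the diff; the import style, clip setup, and the rest of the call signature are assumptions.

import vapoursynth as vs
from vsgrlir import grlir

core = vs.core

# Hypothetical input clip; the plugin's actual format requirements
# (e.g. RGBS vs. RGBH) are not shown in this diff.
clip = core.std.BlankClip(format=vs.RGBS, width=640, height=360, length=1)

# fp16 is the flag this commit propagates through the model;
# every other argument is left at its default.
output = grlir(clip, fp16=True)
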
16 changes: 16 additions & 0 deletions vsgrlir/__init__.py
@@ -123,6 +123,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=True,
fp16=fp16,
)
scale = 4
tile_pad = fallback(tile_pad, 16)
@@ -143,6 +144,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=True,
fp16=fp16,
)
tile_pad = fallback(tile_pad, 16)
pad_size = 96
@@ -162,6 +164,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=True,
fp16=fp16,
)
tile_pad = fallback(tile_pad, 12)
pad_size = 96
@@ -181,6 +184,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=True,
fp16=fp16,
)
tile_pad = fallback(tile_pad, 12)
pad_size = 96
@@ -200,6 +204,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=True,
fp16=fp16,
)
tile_pad = fallback(tile_pad, 12)
pad_size = 96
@@ -219,6 +224,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=False,
fp16=fp16,
)
tile_pad = fallback(tile_pad, 8)
pad_size = 32
@@ -238,6 +244,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=False,
fp16=fp16,
)
tile_pad = fallback(tile_pad, 16)
pad_size = 128
@@ -257,6 +264,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=False,
fp16=fp16,
)
tile_pad = fallback(tile_pad, 16)
pad_size = 128
@@ -276,6 +284,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=False,
fp16=fp16,
)
tile_pad = fallback(tile_pad, 16)
pad_size = 128
@@ -295,6 +304,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=False,
fp16=fp16,
)
tile_pad = fallback(tile_pad, 36)
pad_size = 144
@@ -314,6 +324,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=False,
fp16=fp16,
)
tile_pad = fallback(tile_pad, 36)
pad_size = 144
@@ -333,6 +344,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=False,
fp16=fp16,
)
tile_pad = fallback(tile_pad, 36)
pad_size = 144
@@ -352,6 +364,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=False,
fp16=fp16,
)
tile_pad = fallback(tile_pad, 36)
pad_size = 144
@@ -371,6 +384,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=False,
fp16=fp16,
)
scale = 2
tile_pad = fallback(tile_pad, 32)
@@ -391,6 +405,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=False,
fp16=fp16,
)
scale = 3
tile_pad = fallback(tile_pad, 32)
@@ -411,6 +426,7 @@ def grlir(
mlp_ratio=2,
anchor_window_down_factor=4,
local_connection=False,
fp16=fp16,
)
scale = 4
tile_pad = fallback(tile_pad, 32)
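
Every branch above is mechanical: it forwards the caller's fp16 flag unchanged into the GRL constructor and picks a variant-specific default tile padding via fallback. A small self-contained sketch of that pattern (the fallback body is a presumed re-implementation of the helper used above, and the names in build_variant are illustrative, not taken from vsgrlir):

def fallback(value, default):
    # Presumed behavior of the helper: keep the caller's value unless it is None.
    return default if value is None else value

def build_variant(tile_pad=None, fp16=False):
    # Mirrors one grlir() branch: choose a per-variant tile_pad default and
    # forward fp16 untouched into the model's keyword arguments.
    tile_pad = fallback(tile_pad, 16)
    model_kwargs = dict(anchor_window_down_factor=4, local_connection=True, fp16=fp16)
    return tile_pad, model_kwargs

print(build_variant(fp16=True))
# -> (16, {'anchor_window_down_factor': 4, 'local_connection': True, 'fp16': True})
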
78 changes: 76 additions & 2 deletions vsgrlir/grl.py
@@ -84,6 +84,7 @@ def __init__(
pretrained_stripe_size=[0, 0],
conv_type="1conv",
init_method="",
fp16=False,
args=None,
):
super().__init__()
@@ -118,12 +119,76 @@ def __init__(
pretrained_window_size=pretrained_window_size,
pretrained_stripe_size=pretrained_stripe_size,
res_scale=0.1 if init_method == "r" else 1.0,
fp16=fp16,
args=args,
)
self.blocks.append(block)

self.conv = build_last_conv(conv_type, dim)

def _apply_half(self, fn):
self.blocks.half()

def compute_should_use_set_data(tensor, tensor_applied):
if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
# If the new tensor has compatible tensor type as the existing tensor,
# the current behavior is to change the tensor in-place using `.data =`,
# and the future behavior is to overwrite the existing tensor. However,
# changing the current behavior is a BC-breaking change, and we want it
# to happen in future releases. So for now we introduce the
# `torch.__future__.get_overwrite_module_params_on_conversion()`
# global flag to let the user control whether they want the future
# behavior of overwriting the existing tensor or not.
return not torch.__future__.get_overwrite_module_params_on_conversion()
else:
return False

for key, param in self._parameters.items():
if param is None:
continue
# Tensors stored in modules are graph leaves, and we don't want to
# track autograd history of `param_applied`, so we have to use
# `with torch.no_grad():`
with torch.no_grad():
param_applied = fn(param)
should_use_set_data = compute_should_use_set_data(param, param_applied)
if should_use_set_data:
param.data = param_applied
out_param = param
else:
assert isinstance(param, nn.Parameter)
assert param.is_leaf
out_param = nn.Parameter(param_applied, param.requires_grad)
self._parameters[key] = out_param

if param.grad is not None:
with torch.no_grad():
grad_applied = fn(param.grad)
should_use_set_data = compute_should_use_set_data(param.grad, grad_applied)
if should_use_set_data:
assert out_param.grad is not None
out_param.grad.data = grad_applied
else:
assert param.grad.is_leaf
out_param.grad = grad_applied.requires_grad_(param.grad.requires_grad)

for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)

return self

def half(self):
r"""Casts all floating point parameters and buffers to ``half`` datatype.
.. note::
This method modifies the module in-place.
Returns:
Module: self
"""
return self._apply_half(lambda t: t.half() if t.is_floating_point() else t)

def _init_weights(self):
for n, m in self.named_modules():
if self.init_method == "w":
@@ -236,6 +301,7 @@ def __init__(
conv_type="1conv",
init_method="n", # initialization method of the weight parameters used to train large scale models.
euclidean_dist=False,
fp16=False,
**kwargs,
):
super(GRL, self).__init__()
@@ -272,6 +338,7 @@ def __init__(
self.pretrained_window_size = pretrained_window_size
self.pretrained_stripe_size = pretrained_stripe_size
self.anchor_window_down_factor = anchor_window_down_factor
self.fp16 = fp16

# Head of the network. First convolution.
self.conv_first = nn.Conv2d(in_channels, embed_dim, 3, 1, 1)
@@ -321,6 +388,7 @@ def __init__(
pretrained_stripe_size=pretrained_stripe_size,
conv_type=conv_type,
init_method=init_method,
fp16=fp16,
args=args,
)
self.layers.append(layer)
@@ -369,6 +437,8 @@ def _apply_half(self, fn):
self.conv_first.half()
self.norm_start.half()
self.pos_drop.half()
for layer in self.layers:
layer.half()
self.conv_after_body.half()

if self.upsampler == "pixelshuffle":
Expand Down Expand Up @@ -512,6 +582,10 @@ def get_table_index_mask(self, device=None, input_resolution=None):
}
else:
table_index_mask = self.set_table_index_mask(input_resolution, device=device)
if self.fp16:
for k, v in table_index_mask.items():
if v.is_floating_point():
table_index_mask[k] = v.half()
return table_index_mask

def _init_weights(self, m):
@@ -552,17 +626,17 @@ def check_image_size(self, x):

def forward_features(self, x):
x_size = (x.shape[2], x.shape[3])
x_dtype = x.dtype
x = bchw_to_blc(x)
x = self.norm_start(x)
x = self.pos_drop(x)

table_index_mask = self.get_table_index_mask(x.device, x_size)
x = x.float()
for layer in self.layers:
x = layer(x, x_size, table_index_mask)

x = self.norm_end(x) # B L C
if x_dtype == torch.half:
if self.fp16:
x = x.half()
x = blc_to_bchw(x, x_size)

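
The grl.py hunks above add the fp16 plumbing: BasicLayer gains a half() override whose _apply_half() casts its transformer blocks plus its own direct parameters and buffers (reusing the set_data logic from nn.Module._apply), GRL._apply_half() now also casts each layer, get_table_index_mask() casts floating-point table/mask tensors when self.fp16 is set, and forward_features() casts back to half based on the flag rather than the input dtype. The net effect is selective half-casting instead of a blanket module.half(). A toy sketch of that selective-cast idea (illustrative code under that assumption, not vsgrlir's):

import torch
import torch.nn as nn

class ToyNet(nn.Module):
    # Hypothetical network: only some submodules are meant to run in fp16.
    def __init__(self):
        super().__init__()
        self.head = nn.Conv2d(3, 8, 3, padding=1)  # cast to half
        self.body = nn.Linear(8, 8)                # deliberately left in float32
        self.tail = nn.Conv2d(8, 3, 3, padding=1)  # cast to half

    def half(self):
        # Selective cast, mirroring the custom _apply_half() overrides above:
        # touch only the chosen children instead of recursing over everything.
        self.head.half()
        self.tail.half()
        return self

net = ToyNet().half()
print(net.head.weight.dtype, net.body.weight.dtype, net.tail.weight.dtype)
# torch.float16 torch.float32 torch.float16
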
11 changes: 6 additions & 5 deletions vsgrlir/mixed_attn_block_efficient.py
@@ -443,6 +443,7 @@ def __init__(
pretrained_window_size=[0, 0],
pretrained_stripe_size=[0, 0],
res_scale=1.0,
fp16=False,
args=None,
):
super().__init__()
@@ -463,6 +464,7 @@ def __init__(
self.stripe_groups = stripe_groups
self.mlp_ratio = mlp_ratio
self.res_scale = res_scale
self.fp16 = fp16

self.attn = MixedAttention(
dim,
@@ -528,23 +530,22 @@ def _get_table_index_mask(self, all_table_index_mask):
table_index_mask["mask_w2a"] = None
return table_index_mask

@torch.cuda.amp.autocast()
def forward(self, x, x_size, all_table_index_mask):
# Mixed attention
table_index_mask = self._get_table_index_mask(all_table_index_mask)
if self.args.local_connection:
x = (
x
+ self.res_scale
* self.drop_path(self.norm1(self.attn(x, x_size, table_index_mask)))
+ self.conv(x, x_size)
* self.drop_path(self.norm1(self.attn(x.half() if self.fp16 else x, x_size, table_index_mask)))
+ self.conv(x.half() if self.fp16 else x, x_size)
)
else:
x = x + self.res_scale * self.drop_path(
self.norm1(self.attn(x, x_size, table_index_mask))
self.norm1(self.attn(x.half() if self.fp16 else x, x_size, table_index_mask))
)
# FFN
x = x + self.res_scale * self.drop_path(self.norm2(self.mlp(x)))
x = x + self.res_scale * self.drop_path(self.norm2(self.mlp(x.half() if self.fp16 else x)))

return x

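
In mixed_attn_block_efficient.py the block's forward() already runs under torch.cuda.amp.autocast() (a pre-existing line), and this commit additionally casts the activation to half before the attention, local-connection conv, and MLP calls whenever fp16 is enabled. A toy sketch of how an explicit .half() cast combines with an autocast region (illustrative only, requires a CUDA device; note that LayerNorm still runs in float32 under autocast, so the residual sum promotes to float32):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, dim=16, fp16=False):
        super().__init__()
        self.fp16 = fp16
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Linear(dim, dim)

    @torch.cuda.amp.autocast()
    def forward(self, x):
        # Mirrors the pattern above: cast the input to half only when fp16 is set,
        # then apply the usual residual of norm(mlp(x)).
        return x + self.norm(self.mlp(x.half() if self.fp16 else x))

if torch.cuda.is_available():
    block = ToyBlock(fp16=True).cuda().half()
    out = block(torch.randn(2, 16, device="cuda", dtype=torch.half))
    print(out.dtype)  # float32: layer_norm is an fp32 op under autocast
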
