-
Notifications
You must be signed in to change notification settings - Fork 0
/
wavelet_layer.py
115 lines (93 loc) · 4.36 KB
/
wavelet_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import pywt
import torch
import torch.nn as nn
from torch.autograd import Function
class DWT_Function(Function):
    """Single-level 2D discrete wavelet transform as a custom autograd op.

    Each input channel is filtered independently (depthwise conv with
    ``groups=C``) by four 2D analysis kernels using a stride-2 convolution
    without padding. The four sub-bands are concatenated band-major along the
    channel axis: ``[LL (C ch), LH (C ch), HL (C ch), HH (C ch)]``.
    """

    @staticmethod
    def forward(ctx, x, w_ll, w_lh, w_hl, w_hh):
        """Decompose ``x`` into four downsampled sub-bands.

        Args:
            ctx: Autograd context; keeps the kernels and the input shape.
            x: Input tensor of shape ``(B, C, H, W)``.
            w_ll, w_lh, w_hl, w_hh: Analysis kernels of shape
                ``(1, 1, kH, kW)``; each is expanded to ``(C, 1, kH, kW)``.

        Returns:
            Tensor of shape ``(B, 4 * C, Ho, Wo)`` where
            ``Ho = (H - kH) // 2 + 1`` and ``Wo = (W - kW) // 2 + 1``.
        """
        x = x.contiguous()
        ctx.save_for_backward(w_ll, w_lh, w_hl, w_hh)
        ctx.shape = x.shape
        dim = x.shape[1]
        conv2d = torch.nn.functional.conv2d
        x_ll = conv2d(x, w_ll.expand(dim, -1, -1, -1), stride=2, groups=dim)
        x_lh = conv2d(x, w_lh.expand(dim, -1, -1, -1), stride=2, groups=dim)
        x_hl = conv2d(x, w_hl.expand(dim, -1, -1, -1), stride=2, groups=dim)
        x_hh = conv2d(x, w_hh.expand(dim, -1, -1, -1), stride=2, groups=dim)
        return torch.cat([x_ll, x_lh, x_hl, x_hh], dim=1)

    @staticmethod
    def backward(ctx, dx):
        """Propagate the sub-band gradient back to the input.

        The gradient of a strided convolution w.r.t. its input is the
        transposed convolution with the same kernels, so the four bands of
        each channel are folded back into one channel with
        ``conv_transpose2d``.

        Args:
            ctx: Autograd context populated by ``forward``.
            dx: Gradient w.r.t. the output, shape ``(B, 4 * C, Ho, Wo)``.

        Returns:
            ``(grad_x, None, None, None, None)``. ``grad_x`` has the input's
            shape ``(B, C, H, W)``, or is ``None`` when ``x`` needs no grad.
            The kernels never receive a gradient.
        """
        grad_x = None
        if ctx.needs_input_grad[0]:
            w_ll, w_lh, w_hl, w_hh = ctx.saved_tensors
            B, C, H, W = ctx.shape
            k_h, k_w = w_ll.shape[-2:]
            # Take the sub-band size from dx itself rather than assuming
            # H // 2: with a k-tap kernel the forward produces (H - k) // 2 + 1.
            h_out, w_out = dx.shape[-2:]
            # Band-major (B, 4, C, ...) -> channel-major (B, C, 4, ...) so each
            # group of the transposed conv sees the four bands of one channel.
            dx = dx.reshape(B, 4, C, h_out, w_out).transpose(1, 2)
            dx = dx.reshape(B, 4 * C, h_out, w_out)
            filters = torch.cat([w_ll, w_lh, w_hl, w_hh], dim=0).repeat(C, 1, 1, 1)
            # When (H - kH) is odd the strided forward conv never reads the
            # last input row (likewise for columns); output_padding appends
            # that row/column with zero gradient so grad_x matches x's shape.
            output_padding = ((H - k_h) % 2, (W - k_w) % 2)
            grad_x = torch.nn.functional.conv_transpose2d(
                dx, filters, stride=2, groups=C, output_padding=output_padding
            )
        return grad_x, None, None, None, None
class IDWT_Function(Function):
    """Single-level 2D inverse wavelet transform as a custom autograd op.

    Expects band-major input ``[LL (C ch), LH (C ch), HL (C ch), HH (C ch)]``
    and recombines the four bands of each channel with a stride-2 transposed
    convolution, one group per output channel.
    """

    @staticmethod
    def forward(ctx, x, filters):
        """Reconstruct ``C`` channels from ``4 * C`` sub-band channels.

        Args:
            ctx: Autograd context; keeps ``filters`` and the input shape.
            x: Sub-band tensor of shape ``(B, 4 * C, H, W)`` (band-major).
            filters: Synthesis kernels stacked as ``(4, 1, kH, kW)`` in
                LL, LH, HL, HH order.

        Returns:
            Tensor of shape ``(B, C, 2 * (H - 1) + kH, 2 * (W - 1) + kW)``,
            i.e. ``(B, C, 2H, 2W)`` for 2-tap (Haar) kernels.
        """
        ctx.save_for_backward(filters)
        ctx.shape = x.shape
        batch, _, height, width = x.shape
        # Band-major (B, 4, C, H, W) -> channel-major (B, C, 4, H, W).
        per_channel = x.view(batch, 4, -1, height, width).transpose(1, 2)
        channels = per_channel.shape[1]
        stacked = per_channel.reshape(batch, -1, height, width)
        # One copy of the four kernels per channel group: (4C, 1, kH, kW).
        weight = filters.repeat(channels, 1, 1, 1)
        return torch.nn.functional.conv_transpose2d(
            stacked, weight, stride=2, groups=channels
        )

    @staticmethod
    def backward(ctx, dx):
        """Map the reconstruction gradient back onto the four sub-bands.

        Args:
            ctx: Autograd context populated by ``forward``.
            dx: Gradient w.r.t. the reconstructed output ``(B, C, H2, W2)``.

        Returns:
            ``(grad_x, None)`` where ``grad_x`` is band-major
            ``(B, 4 * C, H, W)``. When ``x`` needs no gradient, ``dx`` is
            passed through unchanged instead.
        """
        if not ctx.needs_input_grad[0]:
            return dx, None
        (filters,) = ctx.saved_tensors
        channels = ctx.shape[1] // 4
        grad = dx.contiguous()
        # Adjoint of the transposed conv: a strided depthwise conv with each
        # synthesis kernel, yielding one band (C channels) per kernel.
        bands = [
            torch.nn.functional.conv2d(
                grad,
                kernel.unsqueeze(1).expand(channels, -1, -1, -1),
                stride=2,
                groups=channels,
            )
            for kernel in torch.unbind(filters, dim=0)
        ]
        return torch.cat(bands, dim=1), None
class IDWT_2D(nn.Module):
    """Single-level 2D inverse discrete wavelet transform layer.

    Args:
        wave: Wavelet name passed to ``pywt.Wavelet`` (e.g. ``"haar"``).

    The synthesis kernels are outer products of the wavelet's reconstruction
    filters. They are stored as a float32 buffer ``filters`` of shape
    ``(4, 1, k, k)`` in LL, LH, HL, HH order. ``forward`` maps band-major
    sub-bands ``(B, 4C, H, W)`` to ``(B, C, 2(H-1)+k, 2(W-1)+k)`` via
    ``IDWT_Function``.

    NOTE(review): no boundary extension is applied, so exact inversion of the
    forward DWT is presumably only guaranteed for 2-tap (Haar) filters;
    confirm before using longer wavelets.
    """

    def __init__(self, wave):
        super().__init__()
        wavelet = pywt.Wavelet(wave)
        lo = torch.Tensor(wavelet.rec_lo)
        hi = torch.Tensor(wavelet.rec_hi)
        # torch.outer(a, b)[i, j] == a[i] * b[j]: the first factor varies
        # along rows (height), the second along columns (width).
        kernels = torch.stack(
            [
                torch.outer(lo, lo),
                torch.outer(hi, lo),
                torch.outer(lo, hi),
                torch.outer(hi, hi),
            ]
        )
        self.register_buffer("filters", kernels.unsqueeze(1).to(dtype=torch.float))

    def forward(self, x):
        """Reconstruct the signal from band-major sub-bands ``x``."""
        return IDWT_Function.apply(x, self.filters)
class DWT_2D(nn.Module):
    """Single-level 2D discrete wavelet transform layer.

    Args:
        wave: Wavelet name passed to ``pywt.Wavelet`` (e.g. ``"haar"``).

    The four analysis kernels are outer products of the wavelet's
    decomposition filters. They are stored as float32 buffers ``w_ll``,
    ``w_lh``, ``w_hl`` and ``w_hh``, each of shape ``(1, 1, k, k)``.
    ``forward`` maps ``(B, C, H, W)`` to band-major
    ``(B, 4C, (H-k)//2+1, (W-k)//2+1)`` via ``DWT_Function``.

    NOTE(review): the stride-2 convolutions use no padding, so for filters
    longer than 2 taps (non-Haar) border samples are dropped; confirm this is
    only used with "haar".
    """

    def __init__(self, wave):
        super().__init__()
        wavelet = pywt.Wavelet(wave)
        # conv2d computes a cross-correlation; reversing the decomposition
        # filters makes it a true convolution with them.
        lo = torch.Tensor(wavelet.dec_lo[::-1])
        hi = torch.Tensor(wavelet.dec_hi[::-1])
        # torch.outer(a, b)[i, j] == a[i] * b[j]: rows use a, columns use b.
        kernels = {
            "w_ll": torch.outer(lo, lo),
            "w_lh": torch.outer(hi, lo),
            "w_hl": torch.outer(lo, hi),
            "w_hh": torch.outer(hi, hi),
        }
        for buffer_name, kernel in kernels.items():
            self.register_buffer(buffer_name, kernel[None, None].to(dtype=torch.float))

    def forward(self, x):
        """Split ``x`` into LL, LH, HL, HH sub-bands along the channel axis."""
        return DWT_Function.apply(x, self.w_ll, self.w_lh, self.w_hl, self.w_hh)