|
1 | 1 | # To be run with python3
|
2 | 2 |
|
3 | 3 | """
|
4 |
| -CAUTION: you probably want to use ./torch_resampler.py instead; it's more |
5 |
| -general. |
6 |
| -
|
7 |
| -This module defines an object that can be used for upsampling and downsampling |
8 |
| -of signals. Note: unlike ./filters.py, this object has a torch dependency. |
9 |
| -(It uses ./filters.py for initialization though.) |
| 4 | +This module defines an object that can be used for signal resampling. |
| 5 | +It has a torch dependency because it does the resampling via 1d convolution. |
10 | 6 | """
|
11 | 7 |
|
12 | 8 |
|
|
15 | 11 | import math
|
16 | 12 | import torch
|
17 | 13 |
|
18 |
| -class Resampler: |
19 | 14 |
|
20 |
| - def __init__(self, N, num_zeros = 32, |
21 |
| - filter_cutoff_ratio = 0.95, |
22 |
| - full_padding = False, |
23 |
| - double_precision = False): |
24 |
| - """ |
25 |
| - This creates an object which can be used for both upsampling and |
26 |
| - downsampling of signals. This involves creating a low-pass filter with |
27 |
| - the appropriate cutoff. |
| 15 | +def gcd(a, b): |
| 16 | + """ Return the greatest common divisor of a and b""" |
| 17 | + assert isinstance(a, int) and isinstance(b, int) |
| 18 | + if b == 0: |
| 19 | + return a |
| 20 | + else: |
| 21 | + return gcd(b, a % b) |
28 | 22 |
|
29 |
| - Args: |
30 |
| - N (int): The downsampling or upsampling ratio. For example, |
31 |
| - 4 would mean we downsample or upsample by a factor of 4. |
32 |
| - Must be > 1. |
33 |
| -
|
34 |
| - num_zeros (int): The number of zeros in the filter function.. |
35 |
| - a larger number will give a sharper cutoff, but will be |
36 |
| - slower. |
37 |
| -
|
38 |
| - filter_cutoff_ratio (float): Determines where we place the |
39 |
| - cutoff of the filter used for upsampling and |
40 |
| - downsampling, relative to the Nyquist of the lower |
41 |
| - of the two frequencies. Must be >0.5 and <1.0. |
42 |
| -
|
43 |
| - full_padding (bool): If true, will pad on each side with |
44 |
| - (filter_width - 1) which ensures that a sufficiently-low-pass |
45 |
| - signal that's upsampled and then downsampled will |
46 |
| - undergo the round trip with minimal end effects. |
47 |
| - If false, we pad with filter_width when downsampling, |
48 |
| - which will give a signal length closer to |
49 |
| - input_signal_length / N and enables easier |
50 |
| - mapping of time offsets,(without worrying about time |
51 |
| - offsets). |
52 |
| -
|
53 |
| - double_precision: If true, will use torch.float64 for the filter |
54 |
| - (and expect this for the input); else will use torch.float32. |
| 23 | +class Resampler: |
| 24 | + """ |
| 25 | + This object should ideally be initialized once and used many times, |
| 26 | + but the construction time shouldn't be excessive. |
| 27 | + Please read the documentation carefully! |
| 28 | + """ |
| 29 | + |
| 30 | + def __init__(self, |
| 31 | + input_sr, output_sr, dtype, |
| 32 | + num_zeros = 64, cutoff_ratio = 0.95): |
55 | 33 | """
|
56 |
| - self.N = N |
57 |
| - if not (isinstance(N, int) and isinstance(num_zeros, int) and isinstance(filter_cutoff_ratio, float)): |
58 |
| - raise TypeError("One of the args has the wrong type") |
59 |
| - if N <= 1 or num_zeros < 2: |
60 |
| - raise ValueError("Require N > 1 and num_zeros > 1") |
61 |
| - if not (filter_cutoff_ratio > 0.5 and filter_cutoff_ratio < 1.0): |
62 |
| - raise ValueError("Invalid number for filter_cutoff_ratio: ", |
63 |
| - filter_cutoff_ratio) |
64 |
| - |
65 |
| - self.dtype = (torch.float64 if double_precision else torch.float32) |
66 |
| - |
67 |
| - # f is a numpy array. i is its central index, not really needed. |
68 |
| - (f, i) = filters.low_pass_filter(filter_cutoff_ratio / (N * 2), |
69 |
| - num_zeros = num_zeros) |
70 |
| - |
71 |
| - |
72 |
| - f_len = f.shape[0] |
| 34 | + This creates an object that can apply a symmetric FIR filter |
| 35 | + based on torch.nn.functional.conv1d. |
73 | 36 |
|
74 |
| - # self.filter is a torch.Tensor whose dimension is interpreted |
75 |
| - # as (out_channels, in_channels, width) where out_channels and |
76 |
| - # in_channels are both 1. |
77 |
| - self.forward_filter = torch.tensor(f, dtype=self.dtype).view(1, 1, f_len) |
78 |
| - |
79 |
| - self.backward_filter = self.forward_filter * N |
| 37 | + Args: |
| 38 | + input_sr: The input sampling rate, AS A SMALL INTEGER.. |
| 39 | + does not have to be the real sampling rate but should |
| 40 | + have the correct ratio with output_sr. |
| 41 | + output_sr: The output sampling rate, AS A SMALL INTEGER. |
| 42 | + It is the ratio with the input sampling rate that is |
| 43 | + important here. |
| 44 | + dtype: The torch dtype to use for computations |
| 45 | + num_zeros: The number of zeros per side in the (sinc*hanning-window) |
| 46 | + filter function. More is more accurate, but 64 is already |
| 47 | + quite a lot. |
| 48 | +
|
| 49 | + You can think of this algorithm as dividing up the signals |
| 50 | + (input,output) into blocks where there are `input_sr` input |
| 51 | + samples and `output_sr` output samples. Then we treat it |
| 52 | + using convolutional code, imagining there are `input_sr` |
| 53 | + input channels and `output_sr` output channels per time step. |
80 | 54 |
|
81 |
| - if full_padding: |
82 |
| - self.padding = f_len - 1 |
| 55 | + """ |
| 56 | + assert isinstance(input_sr, int) and isinstance(output_sr, int) |
| 57 | + if input_sr == output_sr: |
| 58 | + self.resample_type = 'trivial' |
| 59 | + return |
| 60 | + d = gcd(input_sr, output_sr) |
| 61 | + input_sr, output_sr = input_sr // d, output_sr // d |
| 62 | + |
| 63 | + assert dtype in [torch.float32, torch.float64] |
| 64 | + assert num_zeros > 3 # a reasonable bare minimum |
| 65 | + np_dtype = np.float32 if dtype == torch.float32 else np.float64 |
| 66 | + |
| 67 | + # Define one 'block' of samples `input_sr` input samples |
| 68 | + # and `output_sr` output samples. We can divide up |
| 69 | + # the samples into these blocks and have the blocks be |
| 70 | + #in correspondence. |
| 71 | + |
| 72 | + # The sinc function will have, on average, `zeros_per_block` |
| 73 | + # zeros per block. |
| 74 | + zeros_per_block = min(input_sr, output_sr) * cutoff_ratio |
| 75 | + |
| 76 | + # The convolutional kernel size will be n = (blocks_per_side*2 + 1), |
| 77 | + # i.e. we add that many blocks on each side of the central block. The |
| 78 | + # window radius (defined as distance from center to edge) |
| 79 | + # is `blocks_per_side` blocks. This ensures that each sample in the |
| 80 | + # central block can "see" all the samples in its window. |
| 81 | + # |
| 82 | + # Assuming the following division is not exact, adding 1 |
| 83 | + # will have the same effect as rounding up. |
| 84 | + blocks_per_side = 1 + int(num_zeros / zeros_per_block) |
| 85 | + |
| 86 | + kernel_width = 2*blocks_per_side + 1 |
| 87 | + |
| 88 | + # We want the weights as used by torch's conv1d code; format is |
| 89 | + # (out_channels, in_channels, kernel_width) |
| 90 | + # https://pytorch.org/docs/stable/nn.functional.html |
| 91 | + weights = torch.tensor((output_sr, input_sr, kernel_width), dtype=dtype) |
| 92 | + |
| 93 | + # Computations involving time will be in units of 1 block. Actually this |
| 94 | + # is the same as the `canonical` time axis since each block has input_sr |
| 95 | + # input samples, so it would be one of whatever time unit we are using |
| 96 | + window_radius_in_blocks = blocks_per_side |
| 97 | + |
| 98 | + |
| 99 | + # The `times` below will end up being the args to the sinc function. |
| 100 | + # For the shapes of the things below, look at the args to `view`. The terms |
| 101 | + # below will get expanded to shape (output_sr, input_sr, kernel_width) through |
| 102 | + # broadcasting |
| 103 | + # We want it so that, assuming input_sr == output_sr, along the diagonal of |
| 104 | + # the central block we have t == 0. |
| 105 | + # The signs of the output_sr and input_sr terms need to be opposite. The |
| 106 | + # sign that the kernel_width term needs to be will depend on whether it's |
| 107 | + # convolution or correlation, and the logic is tricky.. I will just find |
| 108 | + # which sign works. |
| 109 | + |
| 110 | + |
| 111 | + times = ( |
| 112 | + np.arange(output_sr, dtype=np_dtype).reshape((output_sr, 1, 1)) / output_sr - |
| 113 | + np.arange(input_sr, dtype=np_dtype).reshape((1, input_sr, 1)) / input_sr - |
| 114 | + (np.arange(kernel_width, dtype=np_dtype).reshape((1, 1, kernel_width)) - blocks_per_side)) |
| 115 | + |
| 116 | + |
| 117 | + def window_func(a): |
| 118 | + """ |
| 119 | + window_func returns the Hann window on [-1,1], which is zero |
| 120 | + if a < -1 or a > 1, and otherwise 0.5 + 0.5 cos(a/pi). |
| 121 | + This is applied elementwise to a, which should be a NumPy array. |
| 122 | +
|
| 123 | + The heaviside function returns (a > 0 ? 1 : 0). |
| 124 | + """ |
| 125 | + return np.heaviside(1 - np.abs(a), 0.0) * (0.5 + 0.5 * np.cos(a * np.pi)) |
| 126 | + |
| 127 | + |
| 128 | + # The weights below are a sinc function times a Hann-window function. |
| 129 | + # |
| 130 | + # multiplication by zeros_per_block can be seen as correctly normalizing |
| 131 | + # the sinc function (to compensate for scaling on the x-axis), so that |
| 132 | + # its integral is 1. |
| 133 | + # |
| 134 | + # division is by input_sr can be interpreted as normalizing the input |
| 135 | + # function correctly...if we view it as a stream of dirac deltas that's |
| 136 | + # passed through a low pass filter and want that to have the same |
| 137 | + # magnitude as the original input function, we need to divide by the |
| 138 | + # number of those deltas per unit time. |
| 139 | + weights = (np.sinc(times * zeros_per_block) |
| 140 | + * window_func(times / window_radius_in_blocks) |
| 141 | + * zeros_per_block / input_sr) |
| 142 | + |
| 143 | + self.input_sr = input_sr |
| 144 | + self.output_sr = output_sr |
| 145 | + |
| 146 | + |
| 147 | + # OK, at this point the dim of the weights is (output_sr, input_sr, |
| 148 | + # kernel_width). If output_sr == 1, we can fold the input_sr into the |
| 149 | + # kernel_width (i.e. have just 1 input channel); this will make the |
| 150 | + # convolution faster and avoid unnecessary reshaping. |
| 151 | + |
| 152 | + assert weights.shape == (output_sr, input_sr, kernel_width) |
| 153 | + if output_sr == 1: |
| 154 | + self.resample_type = 'integer_downsample' |
| 155 | + self.padding = input_sr * blocks_per_side |
| 156 | + weights = torch.tensor(weights, dtype=dtype) |
| 157 | + self.weights = weights.transpose(1, 2).contiguous().view(1, 1, input_sr * kernel_width) |
| 158 | + elif input_sr == 1: |
| 159 | + # In this case we'll be doing conv_transpose, so we want the same weights that |
| 160 | + # we would have if we wer *downsampling* by this factor-- i.e. as if input_sr, |
| 161 | + # output_sr had been swapped. |
| 162 | + self.resample_type = 'integer_upsample' |
| 163 | + self.padding = output_sr * blocks_per_side |
| 164 | + weights = torch.tensor(weights, dtype=dtype) |
| 165 | + self.weights = weights.flip(2).transpose(0, 2).contiguous().view(1, 1, output_sr * kernel_width) |
83 | 166 | else:
|
84 |
| - self.padding = (f_len - 1) // 2 |
| 167 | + self.resample_type = 'general' |
| 168 | + self.reshaped = False |
| 169 | + self.padding = blocks_per_side |
| 170 | + self.weights = torch.tensor(weights, dtype=dtype) |
85 | 171 |
|
86 | 172 |
|
87 | 173 |
|
88 |
| - def downsample(self, input): |
| 174 | + def resample(self, in_data): |
89 | 175 | """
|
90 |
| - This downsamples the signal `input` and returns the result. |
91 |
| - Args: |
92 |
| - input (torch.Tensor): A Tensor with shape (minibatch_size, signal_length), |
93 |
| - and dtype torch.float64 if double_precision to constructor was true, |
94 |
| - else torch.float32. |
| 176 | + Resample the data |
95 | 177 |
|
96 |
| - Return: |
97 |
| - Returns a torch.Tensor with shape (minibatch_size, reduced_signal_length). |
98 |
| - """ |
99 |
| - if not isinstance(input, torch.Tensor): |
100 |
| - raise TypeError("Expected input to be torch.Tensor, got ", |
101 |
| - type(input)) |
102 |
| - if not (input.dtype == self.dtype): |
103 |
| - raise TypeError("Expected input tensor to have dtype {}, got {}".format( |
104 |
| - self.dtype, input.dtype)) |
105 |
| - |
106 |
| - # The squeeze and unsqueeze are to insert a dim for num_channels == 1. |
107 |
| - return torch.nn.functional.conv1d(input.unsqueeze(1), |
108 |
| - self.forward_filter, |
109 |
| - stride=self.N, |
110 |
| - padding=self.padding).squeeze(1) |
111 |
| - |
112 |
| - def upsample(self, input): |
113 |
| - """ |
114 |
| - This upsamples the signal `input`. |
| 178 | + Args: |
| 179 | + input: a torch.Tensor with the same dtype as was passed to the |
| 180 | + constructor. |
| 181 | + There must be 2 axes, interpreted as (minibatch_size, sequence_length)... |
| 182 | + the minibatch_size may in practice be the number of channels. |
| 183 | +
|
| 184 | + Return: Returns a torch.Tensor with the same dtype as the input, and |
| 185 | + dimension (minibatch_size, (sequence_length//input_sr)*output_sr), |
| 186 | + where input_sr and output_sr are the corresponding constructor args, |
| 187 | + modified to remove any common factors. |
115 | 188 | """
|
116 |
| - if not isinstance(input, torch.Tensor): |
117 |
| - raise TypeError("Expected input to be torch.Tensor, got ", |
118 |
| - type(input)) |
119 |
| - if not (input.dtype == self.dtype): |
120 |
| - raise TypeError("Expected input tensor to have dtype {}, got {}".format( |
121 |
| - self.dtype, input.dtype)) |
122 |
| - |
123 |
| - |
124 |
| - # The squeeze and unsqueeze are to insert a dim for num_channels == 1. |
125 |
| - return torch.nn.functional.conv_transpose1d(input.unsqueeze(1), |
126 |
| - self.backward_filter, |
127 |
| - stride=self.N, |
128 |
| - padding=self.padding).squeeze(1) |
| 189 | + if self.resample_type == 'trivial': |
| 190 | + return in_data |
| 191 | + elif self.resample_type == 'integer_downsample': |
| 192 | + (minibatch_size, seq_len) = in_data.shape |
| 193 | + # will be shape (minibatch_size, in_channels, seq_len) with in_channels == 1 |
| 194 | + in_data = in_data.unsqueeze(1) |
| 195 | + out = torch.nn.functional.conv1d(in_data, |
| 196 | + self.weights, |
| 197 | + stride=self.input_sr, |
| 198 | + padding=self.padding) |
| 199 | + # shape will be (minibatch_size, out_channels = 1, seq_len); |
| 200 | + # return as (minibatch_size, seq_len) |
| 201 | + return out.squeeze(1) |
| 202 | + elif self.resample_type == 'integer_upsample': |
| 203 | + out = torch.nn.functional.conv_transpose1d(in_data.unsqueeze(1), |
| 204 | + self.weights, |
| 205 | + stride=self.output_sr, |
| 206 | + padding=self.padding) |
| 207 | + return out.squeeze(1) |
| 208 | + else: |
| 209 | + assert self.resample_type == 'general' |
| 210 | + (minibatch_size, seq_len) = in_data.shape |
| 211 | + num_blocks = seq_len // self.input_sr |
| 212 | + if num_blocks == 0: |
| 213 | + # TODO: pad with zeros. |
| 214 | + raise RuntimeError("Signal is too short to resample") |
| 215 | + in_data = in_data[:, 0:(num_blocks*self.input_sr)] # Truncate input |
| 216 | + in_data = in_data.view(minibatch_size, num_blocks, self.input_sr) |
| 217 | + |
| 218 | + # Torch's conv1d actually expects input data with shape (minibatch, |
| 219 | + # in_channels, width) so we need to reshape (note: time is width). |
| 220 | + in_data = in_data.transpose(1, 2) |
| 221 | + |
| 222 | + out = torch.nn.functional.conv1d(in_data, self.weights, |
| 223 | + padding=self.padding) |
| 224 | + assert out.shape == (minibatch_size, self.output_sr, num_blocks) |
| 225 | + return out.transpose(1, 2).contiguous().view(minibatch_size, num_blocks * self.output_sr) |
0 commit comments