Skip to content

Commit 921be54

Browse files
committed
Get tests working; some cleanup
1 parent 929f0be commit 921be54

9 files changed

+283
-495
lines changed

README.md

+39-6
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,44 @@
1-
# filtering
1+
## lilfilter
22

33
Utilities for resampling and filtering audio data
44

5-
This repository will contain a Python package containing certain utilities for filtering and
6-
resampling audio data.
5+
This repository exports a Python package `lilfilter` containing certain
6+
utilities for filtering and resampling audio data.
7+
8+
9+
One quite-useful thing is class Resampler:
10+
```
11+
python3
12+
>>> import lilfilter
13+
>>> # ... let a be a Torch tensor of size (num_channels, num_samples)
14+
>>> # that we want to downsample from 42.1kHz to 16kHz. Note,
15+
>>> # the sampling rates must be integers; only their ratio
16+
>>> # matters.
17+
>>> r = lilfilter.Resampler(42100, 16000, dtype=torch.float32)
18+
>>> b = r.resample(a)
19+
```
20+
21+
Another thing that's useful is class Multistreamer, which can turn a
22+
signal into multiple parallel signals at a lower sampling rate, where
23+
pairs of those signals represent the (real,complex) part of one
24+
complex frequency band of the input.
25+
```
26+
>>> import lilfilter
27+
>>> num_freq_bands = 8
28+
>>> m = lilfilter.Multistreamer(num_freq_bands)
29+
>>>
30+
>>> # ... let a be a Torch tensor of size (num_channels, num_samples)
31+
>>> # that we want to `demultiplex`.
32+
>>>
33+
>>> b = m.split(a)
34+
>>> # now b is of size (num_channels, 2, num_freq_bands, num_samples/num_freq_bands)
35+
>>> # (note: the dim of the last axis may be slightly different from that number).
36+
>>> # You can in principle manipulate b somehow, e.g. do some kind of machine
37+
>>> # learning with it, and then reconstruct to the original format:
38+
>>>
39+
>>> c = m.merge(b)
40+
>>> # now c is of size (num_channels, 8*(num_samples/8)) and will be extremely
41+
>>> # close to a.
42+
```
743

8-
The most thing exported, currently, is class Multistreamer which can be used to
9-
split an audio stream into multiple lower-frequency audio streams, each one
10-
representing one frequency band of the input.
1144

lilfilter/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11

22
from . multistreamer import Multistreamer
3-
3+
from . resampler import Resampler

lilfilter/resampler.py

+202-105
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
# To be run with python3
22

33
"""
4-
CAUTION: you probably want to use ./torch_resampler.py instead; it's more
5-
general.
6-
7-
This module defines an object that can be used for upsampling and downsampling
8-
of signals. Note: unlike ./filters.py, this object has a torch dependency.
9-
(It uses ./filters.py for initialization though.)
4+
This module defines an object that can be used for signal resampling.
5+
It has a torch dependency because it does the resampling via 1d convolution.
106
"""
117

128

@@ -15,114 +11,215 @@
1511
import math
1612
import torch
1713

18-
class Resampler:
1914

20-
def __init__(self, N, num_zeros = 32,
21-
filter_cutoff_ratio = 0.95,
22-
full_padding = False,
23-
double_precision = False):
24-
"""
25-
This creates an object which can be used for both upsampling and
26-
downsampling of signals. This involves creating a low-pass filter with
27-
the appropriate cutoff.
15+
def gcd(a, b):
16+
""" Return the greatest common divisor of a and b"""
17+
assert isinstance(a, int) and isinstance(b, int)
18+
if b == 0:
19+
return a
20+
else:
21+
return gcd(b, a % b)
2822

29-
Args:
30-
N (int): The downsampling or upsampling ratio. For example,
31-
4 would mean we downsample or upsample by a factor of 4.
32-
Must be > 1.
33-
34-
num_zeros (int): The number of zeros in the filter function..
35-
a larger number will give a sharper cutoff, but will be
36-
slower.
37-
38-
filter_cutoff_ratio (float): Determines where we place the
39-
cutoff of the filter used for upsampling and
40-
downsampling, relative to the Nyquist of the lower
41-
of the two frequencies. Must be >0.5 and <1.0.
42-
43-
full_padding (bool): If true, will pad on each side with
44-
(filter_width - 1) which ensures that a sufficiently-low-pass
45-
signal that's upsampled and then downsampled will
46-
undergo the round trip with minimal end effects.
47-
If false, we pad with filter_width when downsampling,
48-
which will give a signal length closer to
49-
input_signal_length / N and enables easier
50-
mapping of time offsets,(without worrying about time
51-
offsets).
52-
53-
double_precision: If true, will use torch.float64 for the filter
54-
(and expect this for the input); else will use torch.float32.
23+
class Resampler:
24+
"""
25+
This object should ideally be initialized once and used many times,
26+
but the construction time shouldn't be excessive.
27+
Please read the documentation carefully!
28+
"""
29+
30+
def __init__(self,
31+
input_sr, output_sr, dtype,
32+
num_zeros = 64, cutoff_ratio = 0.95):
5533
"""
56-
self.N = N
57-
if not (isinstance(N, int) and isinstance(num_zeros, int) and isinstance(filter_cutoff_ratio, float)):
58-
raise TypeError("One of the args has the wrong type")
59-
if N <= 1 or num_zeros < 2:
60-
raise ValueError("Require N > 1 and num_zeros > 1")
61-
if not (filter_cutoff_ratio > 0.5 and filter_cutoff_ratio < 1.0):
62-
raise ValueError("Invalid number for filter_cutoff_ratio: ",
63-
filter_cutoff_ratio)
64-
65-
self.dtype = (torch.float64 if double_precision else torch.float32)
66-
67-
# f is a numpy array. i is its central index, not really needed.
68-
(f, i) = filters.low_pass_filter(filter_cutoff_ratio / (N * 2),
69-
num_zeros = num_zeros)
70-
71-
72-
f_len = f.shape[0]
34+
This creates an object that can apply a symmetric FIR filter
35+
based on torch.nn.functional.conv1d.
7336
74-
# self.filter is a torch.Tensor whose dimension is interpreted
75-
# as (out_channels, in_channels, width) where out_channels and
76-
# in_channels are both 1.
77-
self.forward_filter = torch.tensor(f, dtype=self.dtype).view(1, 1, f_len)
78-
79-
self.backward_filter = self.forward_filter * N
37+
Args:
38+
input_sr: The input sampling rate, AS A SMALL INTEGER..
39+
does not have to be the real sampling rate but should
40+
have the correct ratio with output_sr.
41+
output_sr: The output sampling rate, AS A SMALL INTEGER.
42+
It is the ratio with the input sampling rate that is
43+
important here.
44+
dtype: The torch dtype to use for computations
45+
num_zeros: The number of zeros per side in the (sinc*hanning-window)
46+
filter function. More is more accurate, but 64 is already
47+
quite a lot.
48+
49+
You can think of this algorithm as dividing up the signals
50+
(input,output) into blocks where there are `input_sr` input
51+
samples and `output_sr` output samples. Then we treat it
52+
using convolutional code, imagining there are `input_sr`
53+
input channels and `output_sr` output channels per time step.
8054
81-
if full_padding:
82-
self.padding = f_len - 1
55+
"""
56+
assert isinstance(input_sr, int) and isinstance(output_sr, int)
57+
if input_sr == output_sr:
58+
self.resample_type = 'trivial'
59+
return
60+
d = gcd(input_sr, output_sr)
61+
input_sr, output_sr = input_sr // d, output_sr // d
62+
63+
assert dtype in [torch.float32, torch.float64]
64+
assert num_zeros > 3 # a reasonable bare minimum
65+
np_dtype = np.float32 if dtype == torch.float32 else np.float64
66+
67+
# Define one 'block' of samples `input_sr` input samples
68+
# and `output_sr` output samples. We can divide up
69+
# the samples into these blocks and have the blocks be
70+
#in correspondence.
71+
72+
# The sinc function will have, on average, `zeros_per_block`
73+
# zeros per block.
74+
zeros_per_block = min(input_sr, output_sr) * cutoff_ratio
75+
76+
# The convolutional kernel size will be n = (blocks_per_side*2 + 1),
77+
# i.e. we add that many blocks on each side of the central block. The
78+
# window radius (defined as distance from center to edge)
79+
# is `blocks_per_side` blocks. This ensures that each sample in the
80+
# central block can "see" all the samples in its window.
81+
#
82+
# Assuming the following division is not exact, adding 1
83+
# will have the same effect as rounding up.
84+
blocks_per_side = 1 + int(num_zeros / zeros_per_block)
85+
86+
kernel_width = 2*blocks_per_side + 1
87+
88+
# We want the weights as used by torch's conv1d code; format is
89+
# (out_channels, in_channels, kernel_width)
90+
# https://pytorch.org/docs/stable/nn.functional.html
91+
weights = torch.tensor((output_sr, input_sr, kernel_width), dtype=dtype)
92+
93+
# Computations involving time will be in units of 1 block. Actually this
94+
# is the same as the `canonical` time axis since each block has input_sr
95+
# input samples, so it would be one of whatever time unit we are using
96+
window_radius_in_blocks = blocks_per_side
97+
98+
99+
# The `times` below will end up being the args to the sinc function.
100+
# For the shapes of the things below, look at the args to `view`. The terms
101+
# below will get expanded to shape (output_sr, input_sr, kernel_width) through
102+
# broadcasting
103+
# We want it so that, assuming input_sr == output_sr, along the diagonal of
104+
# the central block we have t == 0.
105+
# The signs of the output_sr and input_sr terms need to be opposite. The
106+
# sign that the kernel_width term needs to be will depend on whether it's
107+
# convolution or correlation, and the logic is tricky.. I will just find
108+
# which sign works.
109+
110+
111+
times = (
112+
np.arange(output_sr, dtype=np_dtype).reshape((output_sr, 1, 1)) / output_sr -
113+
np.arange(input_sr, dtype=np_dtype).reshape((1, input_sr, 1)) / input_sr -
114+
(np.arange(kernel_width, dtype=np_dtype).reshape((1, 1, kernel_width)) - blocks_per_side))
115+
116+
117+
def window_func(a):
118+
"""
119+
window_func returns the Hann window on [-1,1], which is zero
120+
if a < -1 or a > 1, and otherwise 0.5 + 0.5 cos(a/pi).
121+
This is applied elementwise to a, which should be a NumPy array.
122+
123+
The heaviside function returns (a > 0 ? 1 : 0).
124+
"""
125+
return np.heaviside(1 - np.abs(a), 0.0) * (0.5 + 0.5 * np.cos(a * np.pi))
126+
127+
128+
# The weights below are a sinc function times a Hann-window function.
129+
#
130+
# multiplication by zeros_per_block can be seen as correctly normalizing
131+
# the sinc function (to compensate for scaling on the x-axis), so that
132+
# its integral is 1.
133+
#
134+
# division is by input_sr can be interpreted as normalizing the input
135+
# function correctly...if we view it as a stream of dirac deltas that's
136+
# passed through a low pass filter and want that to have the same
137+
# magnitude as the original input function, we need to divide by the
138+
# number of those deltas per unit time.
139+
weights = (np.sinc(times * zeros_per_block)
140+
* window_func(times / window_radius_in_blocks)
141+
* zeros_per_block / input_sr)
142+
143+
self.input_sr = input_sr
144+
self.output_sr = output_sr
145+
146+
147+
# OK, at this point the dim of the weights is (output_sr, input_sr,
148+
# kernel_width). If output_sr == 1, we can fold the input_sr into the
149+
# kernel_width (i.e. have just 1 input channel); this will make the
150+
# convolution faster and avoid unnecessary reshaping.
151+
152+
assert weights.shape == (output_sr, input_sr, kernel_width)
153+
if output_sr == 1:
154+
self.resample_type = 'integer_downsample'
155+
self.padding = input_sr * blocks_per_side
156+
weights = torch.tensor(weights, dtype=dtype)
157+
self.weights = weights.transpose(1, 2).contiguous().view(1, 1, input_sr * kernel_width)
158+
elif input_sr == 1:
159+
# In this case we'll be doing conv_transpose, so we want the same weights that
160+
# we would have if we wer *downsampling* by this factor-- i.e. as if input_sr,
161+
# output_sr had been swapped.
162+
self.resample_type = 'integer_upsample'
163+
self.padding = output_sr * blocks_per_side
164+
weights = torch.tensor(weights, dtype=dtype)
165+
self.weights = weights.flip(2).transpose(0, 2).contiguous().view(1, 1, output_sr * kernel_width)
83166
else:
84-
self.padding = (f_len - 1) // 2
167+
self.resample_type = 'general'
168+
self.reshaped = False
169+
self.padding = blocks_per_side
170+
self.weights = torch.tensor(weights, dtype=dtype)
85171

86172

87173

88-
def downsample(self, input):
174+
def resample(self, in_data):
89175
"""
90-
This downsamples the signal `input` and returns the result.
91-
Args:
92-
input (torch.Tensor): A Tensor with shape (minibatch_size, signal_length),
93-
and dtype torch.float64 if double_precision to constructor was true,
94-
else torch.float32.
176+
Resample the data
95177
96-
Return:
97-
Returns a torch.Tensor with shape (minibatch_size, reduced_signal_length).
98-
"""
99-
if not isinstance(input, torch.Tensor):
100-
raise TypeError("Expected input to be torch.Tensor, got ",
101-
type(input))
102-
if not (input.dtype == self.dtype):
103-
raise TypeError("Expected input tensor to have dtype {}, got {}".format(
104-
self.dtype, input.dtype))
105-
106-
# The squeeze and unsqueeze are to insert a dim for num_channels == 1.
107-
return torch.nn.functional.conv1d(input.unsqueeze(1),
108-
self.forward_filter,
109-
stride=self.N,
110-
padding=self.padding).squeeze(1)
111-
112-
def upsample(self, input):
113-
"""
114-
This upsamples the signal `input`.
178+
Args:
179+
input: a torch.Tensor with the same dtype as was passed to the
180+
constructor.
181+
There must be 2 axes, interpreted as (minibatch_size, sequence_length)...
182+
the minibatch_size may in practice be the number of channels.
183+
184+
Return: Returns a torch.Tensor with the same dtype as the input, and
185+
dimension (minibatch_size, (sequence_length//input_sr)*output_sr),
186+
where input_sr and output_sr are the corresponding constructor args,
187+
modified to remove any common factors.
115188
"""
116-
if not isinstance(input, torch.Tensor):
117-
raise TypeError("Expected input to be torch.Tensor, got ",
118-
type(input))
119-
if not (input.dtype == self.dtype):
120-
raise TypeError("Expected input tensor to have dtype {}, got {}".format(
121-
self.dtype, input.dtype))
122-
123-
124-
# The squeeze and unsqueeze are to insert a dim for num_channels == 1.
125-
return torch.nn.functional.conv_transpose1d(input.unsqueeze(1),
126-
self.backward_filter,
127-
stride=self.N,
128-
padding=self.padding).squeeze(1)
189+
if self.resample_type == 'trivial':
190+
return in_data
191+
elif self.resample_type == 'integer_downsample':
192+
(minibatch_size, seq_len) = in_data.shape
193+
# will be shape (minibatch_size, in_channels, seq_len) with in_channels == 1
194+
in_data = in_data.unsqueeze(1)
195+
out = torch.nn.functional.conv1d(in_data,
196+
self.weights,
197+
stride=self.input_sr,
198+
padding=self.padding)
199+
# shape will be (minibatch_size, out_channels = 1, seq_len);
200+
# return as (minibatch_size, seq_len)
201+
return out.squeeze(1)
202+
elif self.resample_type == 'integer_upsample':
203+
out = torch.nn.functional.conv_transpose1d(in_data.unsqueeze(1),
204+
self.weights,
205+
stride=self.output_sr,
206+
padding=self.padding)
207+
return out.squeeze(1)
208+
else:
209+
assert self.resample_type == 'general'
210+
(minibatch_size, seq_len) = in_data.shape
211+
num_blocks = seq_len // self.input_sr
212+
if num_blocks == 0:
213+
# TODO: pad with zeros.
214+
raise RuntimeError("Signal is too short to resample")
215+
in_data = in_data[:, 0:(num_blocks*self.input_sr)] # Truncate input
216+
in_data = in_data.view(minibatch_size, num_blocks, self.input_sr)
217+
218+
# Torch's conv1d actually expects input data with shape (minibatch,
219+
# in_channels, width) so we need to reshape (note: time is width).
220+
in_data = in_data.transpose(1, 2)
221+
222+
out = torch.nn.functional.conv1d(in_data, self.weights,
223+
padding=self.padding)
224+
assert out.shape == (minibatch_size, self.output_sr, num_blocks)
225+
return out.transpose(1, 2).contiguous().view(minibatch_size, num_blocks * self.output_sr)

0 commit comments

Comments
 (0)