Skip to content

Commit da8edda

Browse files
authored
Merge pull request #2 from HuangZiliAndy/master
Fix problem in local amplitude
2 parents bd3581a + 696f44b commit da8edda

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

lilfilter/local_amplitude.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,7 @@ def __init__(self,
7979

8080

8181
self.block_size = block_size
82-
if block_size > 1:
83-
# num_zeros = 4 is a lower-than-normal width for the FIR filter since there
84-
# won't be frequencies near the Nyquist and we don't need a sharp cutoff.
85-
# filter_cutoff_ratio = 9 is to avoid aliasing effects with this less-precise
86-
# filter (default is 0.95).
87-
self.resampler = resampler.Resampler(block_size, num_zeros = 4,
88-
filter_cutoff_ratio = 0.9,
89-
double_precision = double_precision)
90-
82+
assert block_size > 1
9183

9284
def compute(self,
9385
input):
@@ -142,7 +134,13 @@ def compute(self,
142134
smoothed_amplitudes = self.gaussian_filter.apply(summed_amplitudes)
143135
assert smoothed_amplitudes.shape == summed_amplitudes.shape
144136

145-
upsampled_amplitudes = self.resampler.upsample(smoothed_amplitudes)
137+
# num_zeros = 4 is a lower-than-normal width for the FIR filter since there
138+
# won't be frequencies near the Nyquist and we don't need a sharp cutoff.
139+
# filter_cutoff_ratio = 9 is to avoid aliasing effects with this less-precise
140+
# filter (default is 0.95).
141+
self.resampler = resampler.Resampler(1, self.block_size, dtype = self.dtype, num_zeros = 4,
142+
cutoff_ratio = 0.9)
143+
upsampled_amplitudes = self.resampler.resample(smoothed_amplitudes)
146144
assert upsampled_amplitudes.shape[1] >= signal_length
147145

148146

tests/test_local_amplitude.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ def test_constructor_and_compute(self):
2929
print("b.shape = {}, c.shape = {}, d.shape = {}".format(b.shape, c.shape, d.shape))
3030
print("b sum = ", b.sum())
3131
print("partial sum = ", b[0,0,5,:].sum().item())
32-
plt.plot(torch.arange(b.shape[-1]), b[0,0,5,:])
32+
plt.plot(torch.arange(b.shape[-1]).numpy(), b[0,0,5,:].numpy())
3333
print("d sum = ", d.sum())
34-
plt.plot(torch.arange(d.shape[-1]), d[0,5,:])
34+
plt.plot(torch.arange(d.shape[-1]).numpy(), d[0,5,:].numpy())
3535
plt.show()
3636

3737

0 commit comments

Comments
 (0)