-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathPCS_coeffs_generate.py
214 lines (180 loc) · 7.82 KB
/
PCS_coeffs_generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import freqz, firls
import os
import sys
import argparse
from utils import *
from original_PCS_spectral import process_wav
from tqdm import tqdm
# Number of FIR taps handed to scipy.signal.firls when designing a filter.
TAPS = 127
# Length every measured spectrum is resampled to; also the denominator used
# for the band edges of the statistically derived filter.
BINS = 257
# Fraction by which a band-edge gain is pulled toward the neighbouring band's
# gain in smooth_gains.
GAIN_SMOOTHING = 0.2
# Directory receiving the manually designed coefficients and the
# frequency-response plot.
OUTDIR = 'generated_freq_response'
# Directory receiving the measured gains and the coefficients derived from them.
STAT_DIR = 'statistical'
def init_PCS_params():
    """Build the Perceptual Contrast Stretching (PCS) gains in two layouts.

    Returns:
        tuple: ``(PCS_curve, PCS_params)``.

        ``PCS_curve`` is a length-257 NumPy array with one gain per bin.
        ``PCS_params`` maps ``'Band0'`` .. ``'Band8'`` to
        ``{'band': [low, high], 'gain': g}`` where the edges are the bin
        boundaries divided by 256 (first edge ``0``, last edge ``1``).
    """
    # Bin index at which each band starts, followed by the closing edge.
    bin_edges = (0, 3, 6, 9, 12, 138, 166, 200, 241, 256)
    band_gains = (
        1.0,
        1.070175439,
        1.182456140,
        1.287719298,
        1.4,  # Pre Set
        1.322807018,
        1.238596491,
        1.161403509,
        1.077192982,
    )

    # Bin 256 is never overwritten below, so it keeps the initial value 1.
    PCS_curve = np.ones(257)
    for start, stop, gain in zip(bin_edges, bin_edges[1:], band_gains):
        PCS_curve[start:stop] = gain

    # The outermost edges stay the integer literals 0 and 1; the inner ones
    # are fractions of 256.
    norm_edges = [0] + [edge / 256 for edge in bin_edges[1:-1]] + [1]
    PCS_params = {
        'Band{}'.format(i): {
            'band': [norm_edges[i], norm_edges[i + 1]],
            'gain': gain,
        }
        for i, gain in enumerate(band_gains)
    }
    return PCS_curve, PCS_params
def smooth_gains(gains, smoothing=None):
    """Expand per-band gains into per-band-edge gains with softened steps.

    Every band contributes two values: the gain at its lower edge and the
    gain at its upper edge. An edge that touches a neighbouring band is moved
    toward that neighbour's gain by the fraction ``smoothing``; the outermost
    edges (lower edge of the first band, upper edge of the last band) keep
    the band's own gain.

    Args:
        gains: Sequence of per-band gains, indexable and with a length.
        smoothing: Fraction of the gap to the neighbouring gain that is
            applied at a shared edge. ``None`` (the default) uses the
            module-level ``GAIN_SMOOTHING``.

    Returns:
        list: ``2 * len(gains)`` values ordered
        ``[low_0, high_0, low_1, high_1, ...]``. An empty input gives ``[]``
        and a single gain ``g`` gives ``[g, g]``.
    """
    if smoothing is None:
        smoothing = GAIN_SMOOTHING

    last = len(gains) - 1
    desired = []
    for idx, gain in enumerate(gains):
        # A band without a neighbour on one side keeps its own gain there;
        # this also covers the single-band case, which used to raise
        # IndexError when looking up the non-existent next band.
        if idx == 0:
            low = gain
        else:
            low = (gains[idx - 1] - gain) * smoothing + gain
        if idx == last:
            high = gain
        else:
            high = (gains[idx + 1] - gain) * smoothing + gain
        desired.append(low)
        desired.append(high)
    return desired
def get_multiband_filter(PCS_params, numtaps):
    """Design a least-squares FIR filter from a table of band gains.

    Args:
        PCS_params: Mapping whose values are dicts with a ``'band'`` entry
            (``[low, high]`` edges, normalised so that 1 is Nyquist because
            the design uses ``fs=2``) and a ``'gain'`` entry. Bands are used
            in the mapping's iteration order.
        numtaps: Number of filter taps forwarded to ``scipy.signal.firls``.

    Returns:
        numpy.ndarray: The FIR coefficients returned by ``firls``.

    Raises:
        ValueError: If ``PCS_params`` contains no bands.
    """
    if not PCS_params:
        raise ValueError('PCS_params must contain at least one band')

    bands = []
    gains = []
    for params in PCS_params.values():
        bands.extend(params['band'])
        # Gains are exponentiated before being normalised below.
        gains.append(np.exp(params['gain']))

    # Normalise so that the weakest band has unit gain.
    min_gain = min(gains)
    gains = [gain / min_gain for gain in gains]

    desired = smooth_gains(gains)
    # The `nyq` keyword is intentionally not passed: it was deprecated and
    # later removed from SciPy, so supplying it (even as None) raises
    # TypeError on current releases. `fs=2` already fixes the normalisation.
    return firls(numtaps, bands, desired, weight=None, fs=2)
def plot_fir_response(w, h):
    """Plot magnitude and unwrapped phase of a frequency response.

    The figure is written to ``PCS_coeffs_freqz.png`` inside ``OUTDIR``,
    shown with ``plt.show()`` and then closed.

    Args:
        w: Frequencies at which the response was evaluated (x axis).
        h: Complex frequency response sampled at ``w``.
    """
    fig = plt.figure()
    # Create the axes first and title it directly. Calling plt.title() on an
    # empty figure makes pyplot create an implicit axes, which ends up as a
    # stray second axes underneath the real one.
    ax1 = fig.add_subplot(111)
    ax1.set_title('Digital filter frequency response')
    ax1.plot(w, abs(h), 'b')
    ax1.set_ylabel('Amplitude [linear]', color='b')
    ax1.set_xlabel('Frequency [rad/sample]')

    # Phase shares the x axis but gets its own y axis on the right.
    ax2 = ax1.twinx()
    angles = np.unwrap(np.angle(h))
    ax2.plot(w, angles, 'g')
    ax2.set_ylabel('Angle (radians)', color='g')
    ax2.grid()
    ax2.axis('tight')

    # savefig does not create missing directories.
    os.makedirs(OUTDIR, exist_ok=True)
    fig.savefig(os.path.join(OUTDIR, 'PCS_coeffs_freqz.png'))
    plt.show()
    # Release the figure; clearing the axes alone keeps it registered with
    # pyplot for the lifetime of the process.
    plt.close(fig)
def load_and_filter(audio_path=None):
    """Obtain a measuring signal and its ``process_wav`` output.

    Args:
        audio_path: Path handed to ``load_wav``. When ``None`` a noise signal
            of ``22050 * 10`` samples is generated instead, with a randomly
            drawn standard deviation.

    Returns:
        tuple: ``(audio, filtered_audio)``, both converted with
        ``torch.FloatTensor`` and squeezed.
    """
    if audio_path is None:
        sr = 22050
        unit = torch.ones(1)
        # |N(0.2, 0.1)| -- the noise level itself is random per call.
        noise_std = torch.abs(torch.normal(mean=unit * 0.2, std=unit * 0.1))
        audio = generate_noise(sr * 10, std=noise_std)
    else:
        # The second value returned by load_wav is not needed here.
        audio, _ = load_wav(audio_path)

    filtered_audio = process_wav(audio)

    original = torch.FloatTensor(audio)
    processed = torch.FloatTensor(filtered_audio)
    return original.squeeze(), processed.squeeze()
def adaptive_smoothing(curve, target_length):
    """Resample a 1-D curve to ``target_length`` points by linear interpolation.

    Args:
        curve: Floating-point tensor; it is flattened into a single channel
            of a single batch item before resampling.
        target_length: Number of points in the result.

    Returns:
        torch.Tensor: The resampled curve with singleton dimensions removed.
    """
    # torch.nn.functional.upsample is deprecated and only forwards to
    # interpolate. align_corners=False is the behaviour upsample fell back to
    # when the argument was left unset, so results are unchanged.
    resized = torch.nn.functional.interpolate(
        curve.reshape(1, 1, -1),
        size=target_length,
        mode='linear',
        align_corners=False,
    )
    return resized.squeeze()
def moving_avg_spectra(spectrum_avg, spectrum, count):
    """Fold one more spectrum into a running mean.

    The incoming spectrum is first resampled to ``BINS`` points.

    Args:
        spectrum_avg: Mean of the spectra seen so far, or ``None`` before the
            first one.
        spectrum: New spectrum to include.
        count: Number of spectra already contained in ``spectrum_avg``.

    Returns:
        The updated mean over ``count + 1`` spectra.
    """
    resampled = adaptive_smoothing(spectrum, BINS)
    if spectrum_avg is not None:
        weighted_sum = spectrum_avg * count + resampled
        return weighted_sum / (count + 1)
    # No running mean yet: this must be the very first sample.
    assert count == 0
    return resampled
def record_spectrum_avg(spectrum_avg, audio, count):
    """Update a running mean spectrum with the spectrum of ``audio``.

    Args:
        spectrum_avg: Current running mean, or ``None`` for the first signal.
        audio: Signal passed to ``onesided_magnitude_spectrum``.
        count: Number of signals already averaged.

    Returns:
        The running mean including ``audio``.
    """
    return moving_avg_spectra(
        spectrum_avg,
        onesided_magnitude_spectrum(audio),
        count,
    )
def statistical_response(mode='gaussian', num_samples=100, wav_dir=None):
    """Measure per-bin gains by comparing filtered and unfiltered spectra.

    Each measuring signal goes through ``load_and_filter``; the spectra of
    the original and of the filtered signal are averaged separately and the
    ratio of the two averages is plotted and returned.

    Args:
        mode: ``'gaussian'`` to use generated noise, ``'wav'`` to use the
            ``.wav`` files found in ``wav_dir``.
        num_samples: Number of measuring signals. In ``'wav'`` mode it is
            capped at the number of files found, and ``None`` means "all
            files". It is required in ``'gaussian'`` mode.
        wav_dir: Directory searched (non-recursively) for ``.wav`` files in
            ``'wav'`` mode.

    Returns:
        The averaged filtered spectrum divided element-wise by the averaged
        original spectrum.

    Raises:
        ValueError: On an unknown ``mode``, a missing ``wav_dir`` or
            ``num_samples``, or when there is nothing to measure.
    """
    # Resolve the list of inputs up front so every bad configuration fails
    # here with a clear message instead of deep inside the loop.
    if mode == 'gaussian':
        if num_samples is None:
            raise ValueError("num_samples is required when mode == 'gaussian'")
        # load_and_filter(None) generates a fresh noise signal per call.
        sources = [None] * max(num_samples, 0)
    elif mode == 'wav':
        if wav_dir is None:
            raise ValueError("wav_dir is required when mode == 'wav'")
        # os.listdir order is arbitrary; sorting makes the chosen subset
        # reproducible when num_samples is smaller than the file count.
        sources = sorted(
            os.path.join(wav_dir, name)
            for name in os.listdir(wav_dir)
            if name.endswith('.wav')
        )
        if num_samples is not None:
            sources = sources[:max(num_samples, 0)]
    else:
        raise ValueError('unknown mode: {!r}'.format(mode))

    if not sources:
        # Without this the averages stay None and the division below fails
        # with an unrelated TypeError.
        raise ValueError('no measuring signals available (mode={!r})'.format(mode))

    original_spectrum_avg = None
    filtered_spectrum_avg = None
    for idx, audio_path in enumerate(tqdm(sources)):
        audio, filtered_audio = load_and_filter(audio_path)
        original_spectrum_avg = record_spectrum_avg(original_spectrum_avg, audio, idx)
        filtered_spectrum_avg = record_spectrum_avg(filtered_spectrum_avg, filtered_audio, idx)

    pointwise_gains = filtered_spectrum_avg / original_spectrum_avg
    plot_response_curves([pointwise_gains])
    return pointwise_gains
def manual_PCS():
    """Design the filter from the preset PCS band gains, then save and plot it.

    The coefficients are written to ``PCS_coeffs.npy`` inside ``OUTDIR`` and
    the frequency response is handed to ``plot_fir_response``.
    """
    _, PCS_params = init_PCS_params()
    cascade_coeffs = get_multiband_filter(PCS_params, TAPS)

    # np.save does not create missing directories.
    os.makedirs(OUTDIR, exist_ok=True)
    np.save(os.path.join(OUTDIR, 'PCS_coeffs.npy'), cascade_coeffs)

    w, h = freqz(cascade_coeffs)
    plot_fir_response(w, h)
def statistical_PCS(args):
    """Design the filter from measured gains, then save and plot it.

    The gains come from ``statistical_response``. They are saved as
    ``stat_gains.npy`` and the resulting FIR coefficients as
    ``PCS_coeffs.npy``, both inside ``STAT_DIR``.

    Args:
        args: Parsed command-line namespace providing ``stat_mode``,
            ``num_samples`` and ``wav_dir``.
    """
    pointwise_gains = statistical_response(
        mode=args.stat_mode,
        num_samples=args.num_samples,
        wav_dir=args.wav_dir,
    )

    # np.save does not create missing directories.
    os.makedirs(STAT_DIR, exist_ok=True)
    np.save(os.path.join(STAT_DIR, 'stat_gains.npy'), pointwise_gains)

    # One band per measured gain, each 1/BINS wide on the normalised axis.
    bands = []
    gains = []
    for idx, gain in enumerate(pointwise_gains):
        bands.extend([idx / BINS, (idx + 1) / BINS])
        gains.append(gain)
    desired = smooth_gains(gains)

    # The `nyq` keyword is intentionally not passed: it was deprecated and
    # later removed from SciPy, so supplying it (even as None) raises
    # TypeError on current releases. `fs=2` already fixes the normalisation.
    cascade_coeffs = firls(TAPS, bands, desired, weight=None, fs=2)
    np.save(os.path.join(STAT_DIR, 'PCS_coeffs.npy'), cascade_coeffs)

    w, h = freqz(cascade_coeffs)
    plot_fir_response(w, h)
def main(argv=None):
    """Parse the command line and run the selected coefficient generator.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-m', '--mode', type=str, default='statistical',
        help='`statistical` or `manual`.\n'
             '`statistical` uses specified signals to measure the original spectral PCS as a LTI system to obtain gains.\n'
             '`manual` uses default gains in the original spectral PCS.')
    parser.add_argument(
        '-stm', '--stat_mode', type=str, default='gaussian', required=False,
        help='`gaussian` or `wav`.\n specifies the measuring signal if mode==`statistical`.\n'
             'if `gaussian`, generates Gaussian noise as measuring signals.\n'
             'if `wav` load .wav files from specified directory as measuring signals')
    parser.add_argument(
        '-wd', '--wav_dir', type=str, default=None, required=False,
        help='specifies where the .wav files are located if mode==`statistical` and --stat_mode==wav')
    parser.add_argument(
        '-n', '--num_samples', type=int, default=1000, required=False,
        help='if mode==`statistical`, the measuring process will be performed num_samples times.\n'
             'if --stat_mode==`wav`, the process will be performed min(num_samples, num_wavs_loaded) times')
    args = parser.parse_args(argv)

    if args.mode == 'manual':
        manual_PCS()
    elif args.mode == 'statistical':
        # parser.error instead of assert: asserts vanish under `python -O`,
        # and os.path.isdir(None) raises TypeError when --wav_dir is omitted.
        if args.stat_mode not in ('gaussian', 'wav'):
            parser.error('args.stat_mode: {}'.format(args.stat_mode))
        if args.stat_mode == 'wav' and (
                args.wav_dir is None or not os.path.isdir(args.wav_dir)):
            parser.error('Error, is not dir; args.wav_dir: {}'.format(args.wav_dir))
        statistical_PCS(args)
    else:
        parser.print_help()


if __name__ == '__main__':
    main()