forked from erwald/midihum
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmidi_to_df_conversion.py
323 lines (277 loc) · 15.3 KB
/
midi_to_df_conversion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
import os
from pathlib import Path
from typing import List, Dict
import click
import numpy as np
import pandas as pd
from mido import MidiFile
from sklearn import preprocessing
from tqdm import tqdm
from midi_utility import get_note_tracks, get_midi_file_hash
from chord_identifier import chord_attributes
def midi_files_to_df(midi_filepaths: List[Path], skip_suspicious: bool = True) -> pd.DataFrame:
    """Convert a collection of MIDI files into one concatenated data frame.

    Duplicate files (identical content hash) are converted only once.

    Args:
        midi_filepaths: paths of the MIDI files to convert.
        skip_suspicious: if True, drop files with fewer than 25 distinct
            velocity values (likely not genuinely human-performed).

    Returns:
        The per-file data frames from _midi_file_to_df (plus engineered
        features) concatenated into a single frame.

    Raises:
        Whatever exception a failing per-file conversion raises (it is
        logged, then re-raised).
    """
    dfs = []
    # maps a file's content hash to the first path seen with that hash,
    # so exact duplicates can be reported and skipped
    hashes_to_filenames: Dict[str, Path] = {}
    pbar = tqdm(midi_filepaths)
    for midi_filepath in pbar:
        pbar.set_description(f"midi_to_df_conversion converting {midi_filepath} to df")
        midi_file = MidiFile(midi_filepath)
        midi_file_hash = get_midi_file_hash(midi_file)
        if midi_file_hash in hashes_to_filenames:
            tqdm.write(
                f"midi_to_df_conversion skipping {midi_filepath} since an identical file exists "
                f"({hashes_to_filenames[midi_file_hash]})")
            continue
        hashes_to_filenames[midi_file_hash] = midi_filepath
        try:
            df = _midi_file_to_df(midi_file)
            if skip_suspicious and len(df.velocity.unique()) < 25:
                tqdm.write(f"midi_to_df_conversion skipping {midi_filepath} since it had few unique velocity values")
                continue
            df["name"] = os.path.split(midi_file.filename)[-1]
            df = _add_engineered_features(df)
            assert not np.any(df.index.duplicated()), (midi_filepath, df)
            # reduce size by downcasting float64 and int64 columns; the target
            # column "velocity" is always forced to float regardless of dtype
            for column in df:
                if column == "velocity" or df[column].dtype == "float64":
                    df[column] = pd.to_numeric(df[column], downcast="float")
                elif df[column].dtype == "int64":
                    df[column] = pd.to_numeric(df[column], downcast="integer")
            dfs.append(df)
        # TODO: catch more specific exception
        except Exception as e:
            tqdm.write(f"midi_to_df_conversion got exception converting midi to df: {e}")
            raise  # bare raise preserves the original traceback
    processed_count = len(dfs)
    total_count = len(midi_filepaths)
    click.echo(f"midi_to_df_conversion converted {processed_count} files out of {total_count} to dfs")
    return pd.concat(dfs)
def _midi_file_to_df(midi_file) -> pd.DataFrame:
    """Turn one parsed MIDI file into a data frame with one row per note.

    A row is created when a note is released; it combines features captured
    at press time (velocity, pitch, intervals, chord info, ...) with features
    known only at release time (sustain duration etc.).
    """
    tracks = get_note_tracks(midi_file)
    events = [(track.index, ev) for track in tracks for ev in track.note_events]
    events.sort(key=lambda pair: pair[1].time)
    # the chronologically last event marks the end of the song
    song_duration = events[-1][1].time

    rows = []    # one list per finished note (press features + release features)
    active = []  # (pitch, press time, press-row) for every currently held note

    for track_index, event in events:
        is_press = event.type == "note_on" and event.velocity > 0
        is_release = event.type == "note_off" or (event.type == "note_on" and event.velocity == 0)
        if is_press:
            # interval relative to the most recently *released* note (a row's
            # pitch lives at index 4)
            released_interval = (event.note - rows[-1][4]) if rows else 0
            # interval relative to the most recently *pressed*, still-held note;
            # falls back to the released interval when nothing is held
            pressed_interval = (event.note - active[-1][0]) if active else released_interval
            # all pitches sounding once this note joins in
            sounding = [pitch for pitch, _, _ in active] + [event.note]
            mean_pitch = np.mean(sounding)
            # chord quality of what is sounding. "character" is one of:
            # minor / major / diminished / augmented / suspended / "none";
            # "size" is one of: dyad / triad / seventh / ninth / eleventh /
            # thirteenth / "none".
            attrs = chord_attributes(sounding)
            character = attrs[0] if attrs is not None and attrs[0] is not None else "none"
            size = attrs[1] if attrs is not None and attrs[1] is not None else "none"
            press_row = [
                event.velocity,
                event.time,
                track_index,
                event.index,
                event.note,
                str(event.note % 12),
                event.note // 12,
                mean_pitch,
                event.time / song_duration,
                # parabola peaking at the song's midpoint (1.0) and hitting
                # 0.0 at both ends
                -(((event.time / song_duration) * 2 - 1) ** 2) + 1,
                pressed_interval,
                released_interval,
                len(active) + 1,
                int(len(active) == 0),
                character,
                size,
            ]
            active.append((event.note, event.time, press_row))
        elif is_release:
            match = next((entry for entry in active if entry[0] == event.note), None)
            if match is None:
                # release event for a pitch that isn't being played; ignore it
                continue
            active.remove(match)
            _, press_time, press_row = match
            sustain = event.time - press_time
            if sustain == 0:
                # zero-length note: reuse the previous row's sustain (index 16)
                if rows:
                    sustain = rows[-1][16]
                else:
                    tqdm.write("midi_to_df_conversion warning: got first note with 0 duration; defaulting to 25")
                    sustain = 25.0
            sounding = [pitch for pitch, _, _ in active] + [event.note]
            rows.append(press_row + [sustain, len(active), np.mean(sounding)])
            # keep rows ordered by press time (index 1) so the rows[-1]
            # lookups above always see the latest note in time order
            rows.sort(key=lambda row: row[1])

    # NOTE(review): each matched note consumes two events but yields one row,
    # so this difference looks inflated whenever notes match — confirm what
    # track.note_events actually contains before trusting this count.
    skipped_events = len(events) - len(rows)
    if skipped_events > 0:
        tqdm.write(
            f"midi_to_df_conversion warning: saw {skipped_events} note off events for pitches that hadn't been played")
    df = pd.DataFrame(rows)
    df.columns = [
        "velocity", "time", "midi_track_index", "midi_event_index", "pitch", "pitch_class", "octave",
        "avg_pitch_pressed", "nearness_to_end", "nearness_to_midpoint", "interval_from_pressed",
        "interval_from_released", "num_played_notes_pressed", "follows_pause", "chord_character_pressed",
        "chord_size_pressed", "sustain", "num_played_notes_released", "avg_pitch_released"]
    df["song_duration"] = song_duration
    return df
def _add_engineered_features(df: pd.DataFrame, with_extra_features: bool = False) -> pd.DataFrame:
    """Takes a data frame representing one MIDI song and adds a bunch of
    additional features to it.

    Args:
        df: output of _midi_file_to_df for a single song (expects the
            "*_pressed", "time", "sustain" and "song_duration" columns).
        with_extra_features: if True, also add pct-change, lag and
            whole-song aggregate columns.

    Returns:
        A new data frame: the original columns plus the engineered ones.
    """
    # NOTE: it's faster to create each column individually then merge them all together at the end. ("chord_character",
    # "chord_size", "time_since_last_pressed" and "time_since_last_released" are however needed in the df, so we add
    # those to the df directly.)
    new_cols: Dict[str, pd.Series] = {}
    # calculate "true" chord character and size by bunching all samples within 5 time units together and picking the
    # chord character and size of the last of each group for all of them. this makes it so that, if a chord is played
    # with not all notes perfectly at the same time, even the first notes here will get the information of the full
    # chord (hopefully).
    df["chord_character"] = df.groupby(np.floor(df.time / 5) * 5).chord_character_pressed.transform("last")
    df["chord_size"] = df.groupby(np.floor(df.time / 5) * 5).chord_size_pressed.transform("last")
    # get time elapsed since last note event(s)
    df["time_since_last_pressed"] = (df.time - df.time.shift()).fillna(0)
    df["time_since_last_released"] = (df.time - (df.time.shift() + df.sustain.shift())).fillna(0)
    # get time elapsed since various further events. since some of these happen rather rarely (resulting in some very
    # large values), we also normalise.
    for cat in ["pitch_class", "octave", "follows_pause", "chord_character", "chord_size"]:
        col_name = f"time_since_{cat}"
        col = pd.Series(preprocessing.scale((df.time - df.groupby(cat)["time"].shift()).fillna(0).values))
        new_cols[col_name] = col
        # BUG FIX: this was previously np.log(col + 1), which produces NaN for any scaled value below -1; those
        # NaN columns only survived because a second pass further down overwrote them with log10(abs(x) + 1).
        # compute the final value once here instead — same insertion order, same final values, no redundant
        # NaN-producing intermediate.
        new_cols[f"log_{col_name}"] = pd.Series(np.log10(np.abs(col) + 1))
    # add some abs cols
    for col in ["interval_from_pressed", "interval_from_released"]:
        base = new_cols[col] if col in new_cols else df[col]
        new_cols[f"abs_{col}"] = np.abs(base)
    # add some log cols
    for col in [
            "sustain", "time_since_last_pressed", "time_since_last_released", "abs_interval_from_pressed",
            "abs_interval_from_released"]:
        base = new_cols[col] if col in new_cols else df[col]
        new_cols[f"log_{col}"] = pd.Series(np.log(np.abs(base) + 1))
    # calculate some simple moving averages (backward and forward, plus
    # "oscillator" columns: the raw value's distance from its moving average)
    sma_aggs = {
        "pitch": ["mean", "min", "max", "std"],
        "log_sustain": ["mean", "min", "max", "std"],
        "interval_from_pressed": ["mean", "min", "max", "std"],
        "log_time_since_last_pressed": ["mean", "min", "max", "std"],
        "log_time_since_follows_pause": ["mean", "min", "max", "std"]}
    sma_windows = [15, 30, 75]
    for col, funcs in sma_aggs.items():
        base = new_cols[col] if col in new_cols else df[col]
        for window in sma_windows:
            for func in funcs:
                sma = base.rolling(window).agg(func).bfill()
                new_cols[f"{col}_sma_{func}_{window}"] = sma
                fwd_sma = base[::-1].rolling(window).agg(func).bfill()[::-1]
                new_cols[f"{col}_fwd_sma_{func}_{window}"] = fwd_sma
                # (a former `col != "follows_pause"` guard here was dead code — "follows_pause" is not a key of
                # sma_aggs — so the oscillators are unconditionally added, exactly as before.)
                new_cols[f"{col}_sma_{func}_{window}_oscillator"] = base - sma
                new_cols[f"{col}_fwd_sma_{func}_{window}_oscillator"] = base - fwd_sma
    # add ichimoku indicators (trading-chart style trend features over the
    # standard 9/26/52 windows)
    for col in ["pitch", "log_sustain", "interval_from_released", "interval_from_pressed"]:
        base = new_cols[col] if col in new_cols else df[col]
        tenkan_sen = (base.rolling(9).max() + base.rolling(9).min()).bfill() / 2.0
        kijun_sen = (base.rolling(26).max() + base.rolling(26).min()).bfill() / 2.0
        senkou_span_a = (tenkan_sen + kijun_sen) / 2.0
        senkou_span_b = (base.rolling(52).max() + base.rolling(52).min()).bfill() / 2.0
        new_cols[f"{col}_tenkan_sen"] = tenkan_sen
        new_cols[f"{col}_kijun_sen"] = kijun_sen
        new_cols[f"{col}_senkou_span_a"] = senkou_span_a
        new_cols[f"{col}_senkou_span_b"] = senkou_span_b
        new_cols[f"{col}_chikou_span"] = base.shift(26).bfill()
        new_cols[f"{col}_cloud_is_green"] = senkou_span_a - senkou_span_b
        new_cols[f"{col}_relative_to_tenkan_sen"] = base - tenkan_sen
        new_cols[f"{col}_relative_to_kijun_sen"] = base - kijun_sen
        new_cols[f"{col}_tenkan_sen_relative_to_kijun_sen"] = tenkan_sen - kijun_sen
        new_cols[f"{col}_relative_to_chikou_span"] = base - base.shift(26).bfill()
        new_cols[f"{col}_relative_to_cloud"] = base - (senkou_span_a + senkou_span_b) / 2.0
    if with_extra_features:
        # add percent change columns. non-pitch columns are shifted away from
        # zero (abs + 1) first so pct_change never divides by zero.
        for col in [
                "pitch", "log_sustain", "num_played_notes_pressed", "num_played_notes_released",
                "interval_from_pressed", "interval_from_released", "log_time_since_last_pressed",
                "log_time_since_last_released"]:
            base = new_cols[col] if col in new_cols else df[col]
            if col == "pitch":
                new_cols[f"{col}_pct_change"] = base.pct_change().fillna(0.0)
            else:
                new_cols[f"{col}_pct_change"] = pd.Series((np.abs(base) + 1.0).pct_change().fillna(0.0))
    # exponentially-weighted moving statistics (backward and forward), plus a
    # MACD-style pair per column
    ewm_aggs = {
        "pitch": ["mean", "std"],
        "log_sustain": ["mean", "std"],
        "num_played_notes_pressed": ["mean", "std"],
        "interval_from_pressed": ["mean", "std"],
        "log_abs_interval_from_released": ["mean", "std"],
        "log_time_since_last_pressed": ["mean", "std"],
        "log_time_since_follows_pause": ["mean", "std"]}
    for col, funcs in ewm_aggs.items():
        base = new_cols[col] if col in new_cols else df[col]
        for func in funcs:
            for span in [10, 20, 50]:
                new_cols[f"{col}_ewm_{func}_{span}"] = base.ewm(span=span).agg(func).bfill()
                new_cols[f"{col}_fwd_ewm_{func}_{span}"] = base[::-1].ewm(span=span).agg(func).bfill()[::-1]
            # actually macd uses ewms with spans 12 and 26 and a signal ewm with span 9. but 2x those works better.
            macd = base.ewm(span=24).agg(func).bfill() - base.ewm(span=52).agg(func).bfill()
            new_cols[f"{col}_ewm_{func}_macd"] = macd
            new_cols[f"{col}_ewm_{func}_macd_signal"] = base.ewm(span=18).agg(func).bfill() - macd
    if with_extra_features:
        # calculate lag values (just taking the values of the previous/next rows)
        for col in ["octave", "follows_pause", "chord_character", "chord_size"]:
            for i in range(1, 6):
                new_cols[f"{col}_lag_{i}"] = df[col].shift(i).bfill().astype(df[col].dtype)
                new_cols[f"{col}_fwd_lag_{i}"] = df[col][::-1].shift(i).bfill()[::-1].astype(df[col].dtype)
    if with_extra_features:
        # get some aggregate data of the song as a whole (broadcast to every row)
        aggregators = {
            "pitch": ["sum", "mean", "min", "max", "std"],
            "log_sustain": ["sum", "mean", "min", "max", "std"],
            "octave": ["nunique"]}
        aggregated = df.agg(aggregators)
        for col, funcs in aggregators.items():
            for func in funcs:
                new_cols[f"{col}_{func}"] = pd.Series([aggregated[col][func]] * len(df))
    if with_extra_features:
        # total number of notes in song, raw and normalised by song duration
        note_count = pd.Series([len(df)] * len(df))
        new_cols["note_count"] = note_count
        new_cols["note_count_adj_by_dur"] = note_count / df.song_duration[0]
    # sanity-check every numeric engineered column before merging
    for name, new_col in new_cols.items():
        if not pd.api.types.is_numeric_dtype(new_col):
            continue
        assert not np.any(np.isnan(new_col)), (name, new_col)
        assert np.all(np.isfinite(new_col)), (name, new_col)
    return pd.concat([df] + [col.rename(name) for name, col in new_cols.items()], axis=1)