Skip to content

Commit

Permalink
modified my_custom_dataset and augs
Browse files Browse the repository at this point in the history
edited the my_custom_dataset.py file so ssl1.py can now load data using,
rather than the opso functions i added to ssl1 previously.

I then edited my_custom_dataset to use the augmentations I used for
bomb fishing.
  • Loading branch information
BenUCL committed Aug 21, 2023
1 parent 92e9896 commit 239d4b3
Show file tree
Hide file tree
Showing 4 changed files with 1,336 additions and 56 deletions.
95 changes: 95 additions & 0 deletions code/simclr-pytorch-reefs/models/my_custom_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os
import json
import librosa
import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F

# for my transformations
#import librosa
from audiomentations import Compose, AddGaussianNoise, PitchShift, TimeStretch, ClippingDistortion, Gain, SevenBandParametricEQ


def resize_mel_spectrogram(mel_spec, desired_shape=(224, 224)):
# Convert the 2D Mel spectrogram to 4D tensor (batch, channels, height, width)
mel_spec_tensor = torch.tensor(mel_spec).unsqueeze(0).unsqueeze(0)
# Resize
resized_mel_spec = F.interpolate(mel_spec_tensor, size=desired_shape, mode='bilinear', align_corners=False)
return resized_mel_spec.squeeze(0).squeeze(0).numpy()

# augmentation
augment_raw_audio = Compose(
[
AddGaussianNoise(min_amplitude=0.0001, max_amplitude=0.0005, p=1), # good
PitchShift(min_semitones=-2, max_semitones=12, p=0.5), #set values so it doesnt shift too low, rmeoving bomb signal
TimeStretch(p = 0.5), # defaults are fine
ClippingDistortion(0, 5, p = 0.5), # tested params to make sure its good
Gain(-10, 5, p = 0.5), # defaults are fine
# throws an error, so i commented it out
#SevenBandParametricEQ(-12, 12, p = 0.5)
]
)

# Modify the load_audio_and_get_mel_spectrogram function:
def load_audio_and_get_mel_spectrogram(filename, sr=8000, n_mels=128, n_fft=1024, hop_length=64, win_length=512):
y, _ = librosa.load(filename, sr=sr)
augmented_signal = augment_raw_audio(y, sr)

mel_spectrogram = librosa.feature.melspectrogram(y=augmented_signal, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
mel_spectrogram_resized = resize_mel_spectrogram(mel_spectrogram)
return mel_spectrogram_resized



class CTDataset(Dataset):

def __init__(self, cfg, split, transform):
'''
Constructor. Here, we collect and index the dataset inputs and labels.
'''
#if split == 'unlabeled':
# print('This will not work unless you change the getitem function to have no labels for the unlabeled set')
self.data_root = cfg['data_path']
self.split = split
self.transform = transform
#

# index data from JSON file
self.data = []
with open(cfg['json_path'], 'r') as f:
json_data = json.load(f)
for sublist in json_data.values():
for entry in sublist:
#print(entry)

if entry["data_type"] == split:
path = entry["file_name"]
label = entry["class"]
self.data.append((path, label))

def __len__(self):
'''
Returns the length of the dataset.
'''
return len(self.data)

def __getitem__(self, idx):
'''
Returns a single data point at given idx.
Here's where we actually load the audio and get the Mel spectrogram.
'''
audio_path, label = self.data[idx]

# load audio and get Mel spectrogram
mel_spectrogram = load_audio_and_get_mel_spectrogram(os.path.join(self.data_root, audio_path))

# make 3 dimensions, so shape goes from [x, y] to [3, x, y]
mel_spectrogram_tensor = torch.tensor(mel_spectrogram).unsqueeze(0).repeat(3, 1, 1).float()

# the old transform call, its now ditched
#if self.transform:
# mel_spectrogram_tensor = self.transform(mel_spectrogram_tensor)

# return the objects, label is commented out for now
return mel_spectrogram_tensor#, label
311 changes: 311 additions & 0 deletions code/simclr-pytorch-reefs/models/my_custom_dataset_ben3.ipynb

Large diffs are not rendered by default.

146 changes: 90 additions & 56 deletions code/simclr-pytorch-reefs/models/ssl1.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
import pandas as pd
from opensoundscape.preprocess.preprocessors import SpectrogramPreprocessor
from opensoundscape.ml.datasets import AudioFileDataset
# helper function for displaying a sample as an image
# helper function for displaying a sample as an image
from opensoundscape.preprocess.utils import show_tensor, show_tensor_grid
from opensoundscape import Action
from opensoundscape.spectrogram import MelSpectrogram




##################
def _melspec_linear_to_db(melspec):

Expand Down Expand Up @@ -131,7 +132,24 @@ def samplers(self):

def prepare_data(self):
##self.trainset = dataset
##train_transform, test_transform = self.transforms() ### comment out later??
train_transform = self.transforms()#, test_transform = self.transforms() ### comment out later??

if self.hparams.data == 'ROV':
# not used with new dataset issue fix but can leave in for now
cfg = {'data_path': '/mnt/ssd-cluster/ben/data/full_dataset/', #############################
'json_path': '/home/ben/reef-audio-representation-learning/data/dataset.json'}#### tarun : for pretraining self.trainset is the unlabeled dataset.
print(f'Dataset path:')
print(cfg['data_path'])

self.trainset = CTDataset(cfg, split='train_data', transform=train_transform)
#### tarun : for eval or finetuning,self.trainset is the 10percent train dataset
#self.trainset = CTDataset(cfg, split='train', transform=train_transform)
#self.testset = CTDataset(cfg, split='val', transform=test_transform)

else:
raise NotImplementedError
print(f'Number of training data samples: {len(self.trainset)}')

# print('The following train transform is used:\n', train_transform)
# print('The following test transform is used:\n', test_transform)
# if self.hparams.data == 'cifar':
Expand All @@ -142,84 +160,97 @@ def prepare_data(self):
# valdir = os.path.join(self.IMAGENET_PATH, 'val')
# self.trainset = datasets.ImageFolder(traindir, transform=train_transform)
# self.testset = datasets.ImageFolder(valdir, transform=test_transform)
if self.hparams.data == 'ROV':




##############################################

# if self.hparams.data == 'ROV':

# not used with new dataset issue fix but can leave in for now
cfg = {'dataset_path': '/home/ben/data/full_dataset/', #############################
'json_path': '/home/ben/data/dataset.json'}
# # not used with new dataset issue fix but can leave in for now
# cfg = {'dataset_path': '/home/ben/data/full_dataset/', #############################
# 'json_path': '/home/ben/data/dataset.json'}

#cfg = {'data_root':'/root/all_ROV_crops_with_unknown/all_ROV_crops_with_unknown', 'train_label_file':'../10_percent_train_with_unknown.csv', 'val_label_file':'../5_percent_val_with_unknown.csv', 'test_label_file':'../10_percent_test_with_unknown.csv', 'unlabeled_file':'../75_percent_unlabeled_with_unknown.csv'}
#### tarun : for pretraining self.trainset is the unlabeled dataset.
# #cfg = {'data_root':'/root/all_ROV_crops_with_unknown/all_ROV_crops_with_unknown', 'train_label_file':'../10_percent_train_with_unknown.csv', 'val_label_file':'../5_percent_val_with_unknown.csv', 'test_label_file':'../10_percent_test_with_unknown.csv', 'unlabeled_file':'../75_percent_unlabeled_with_unknown.csv'}
# #### tarun : for pretraining self.trainset is the unlabeled dataset.


#self.trainset = CTDataset(**cfg)#, split='unlabeled', transform=train_transform) ##################################
#################################################
# fixing dataset issue we add a load of load of code here to preprocess data and make train.dataset
# Load the JSON data from the file
json_path = '/home/ben/reef-audio-representation-learning/data/dataset.json'
dataset_path = '/mnt/ssd-cluster/ben/data/full_dataset/'
with open(json_path, "r") as file:
data = json.load(file)
# #self.trainset = CTDataset(**cfg)#, split='unlabeled', transform=train_transform) ##################################


# #################################################
# # fixing dataset issue we add a load of load of code here to preprocess data and make train.dataset
# # Load the JSON data from the file
# json_path = '/home/ben/reef-audio-representation-learning/data/dataset.json'
# dataset_path = '/mnt/ssd-cluster/ben/data/full_dataset/'
# with open(json_path, "r") as file:
# data = json.load(file)

# Extract the list of dictionaries from the "audio" key
audio_data = data.get("audio", [])
# # Extract the list of dictionaries from the "audio" key
# audio_data = data.get("audio", [])

# Filter the list to only include entries where data_type = "train_data"
data = [entry for entry in audio_data if entry.get("data_type") == "train_data"]
# # Filter the list to only include entries where data_type = "train_data"
# data = [entry for entry in audio_data if entry.get("data_type") == "train_data"]

# Convert the filtered list into a DataFrame
#df = pd.DataFrame(self.data)
df = pd.DataFrame(data)#[:32]) to rig dataset size for testing
# # Convert the filtered list into a DataFrame
# #df = pd.DataFrame(self.data)
# df = pd.DataFrame(data)#[:32]) to rig dataset size for testing

# Convert the list of dictionaries (which is the value of the main dictionary) into a DataFrame
#df = pd.DataFrame(data[list(data.keys())[0]])
#self.data = {k: v for k, v in data.items() if v.get("data_type") == "train_data"}
#df = pd.DataFrame(self.data[list(self.data.keys())[0]])
# # Convert the list of dictionaries (which is the value of the main dictionary) into a DataFrame
# #df = pd.DataFrame(data[list(data.keys())[0]])
# #self.data = {k: v for k, v in data.items() if v.get("data_type") == "train_data"}
# #df = pd.DataFrame(self.data[list(self.data.keys())[0]])


# Create a dataframe with just file_path and a class column (req for AudioFileDataset)
transformed_df = df[['file_name', 'class']].copy()
# # Create a dataframe with just file_path and a class column (req for AudioFileDataset)
# transformed_df = df[['file_name', 'class']].copy()

# rename 'file_name' column to 'file'
transformed_df.rename(columns={'file_name': 'file'}, inplace=True)
# # rename 'file_name' column to 'file'
# transformed_df.rename(columns={'file_name': 'file'}, inplace=True)

# set file to be the index for AudioFileDataset
transformed_df.set_index('file', inplace=True)
# # set file to be the index for AudioFileDataset
# transformed_df.set_index('file', inplace=True)

# set all classes to 1 as AudioFileDataset requires class
transformed_df['class'] = 1
# # set all classes to 1 as AudioFileDataset requires class
# transformed_df['class'] = 1

# append dataset_path to start of file_name column
transformed_df.index = dataset_path + transformed_df.index
#transformed_df.head() # for notebook
# # append dataset_path to start of file_name column
# transformed_df.index = dataset_path + transformed_df.index
# #transformed_df.head() # for notebook

# initialize the preprocessor (forget what this does?)
pre = SpectrogramPreprocessor(sample_duration=1.92)
# # initialize the preprocessor (forget what this does?)
# pre = SpectrogramPreprocessor(sample_duration=1.92)

# initialize the dataset
dataset = AudioFileDataset(transformed_df, pre)
# # initialize the dataset
# dataset = AudioFileDataset(transformed_df, pre)

# change the bandpass from the default to 8kHz
dataset.preprocessor.pipeline.bandpass.set(min_f=0,max_f=8000)
# # change the bandpass from the default to 8kHz
# dataset.preprocessor.pipeline.bandpass.set(min_f=0,max_f=8000)

melspec_action = Action(_my_melspec)
melspec_bandpass_action = Action(MelSpectrogram.bandpass, min_f=0, max_f=8000)
# melspec_action = Action(_my_melspec)
# melspec_bandpass_action = Action(MelSpectrogram.bandpass, min_f=0, max_f=8000)

dataset.preprocessor.pipeline['to_spec'] = melspec_action
dataset.preprocessor.pipeline['bandpass'] = melspec_bandpass_action
dataset.bypass_augmentations = True ### added to stop augs
print(f'Total number of train samples found: {len(dataset)}')
#print(dataset[0].shape)
self.trainset = dataset
# dataset.preprocessor.pipeline['to_spec'] = melspec_action
# dataset.preprocessor.pipeline['bandpass'] = melspec_bandpass_action
# dataset.bypass_augmentations = True ### added to stop augs
# print(f'Total number of train samples found: {len(dataset)}')
# #print(dataset[0].shape)
# self.trainset = dataset


##############################################

#### tarun : for eval or finetuning,self.trainset is the 10percent train dataset
#self.trainset = CTDataset(cfg, split='train', transform=train_transform)
##self.testset = CTDataset(cfg, split='val', transform=test_transform)
else:
raise NotImplementedError
# else:
# raise NotImplementedError
####################################





def dataloaders(self, iters=None):
train_batch_sampler = self.samplers()#, test_batch_sampler = self.samplers()
Expand Down Expand Up @@ -401,16 +432,19 @@ def transforms(self):
from utils.datautils import GaussianBlur

im_size = 224
# could put my transformation code in here? should i??

# note in my_custom_dataset.py this is currently written over, so these transforms arent used
train_transform = transforms.Compose([
transforms.RandomResizedCrop(
im_size,
scale=(self.hparams.scale_lower, 1.0),
interpolation=PIL.Image.BICUBIC,
),
## transforms.RandomHorizontalFlip(0.5),
transforms.RandomHorizontalFlip(0.5),
## datautils.get_color_distortion(s=self.hparams.color_dist_s),
## transforms.ToTensor(),
GaussianBlur(im_size // 10, 0.5),
#GaussianBlur(im_size // 10, 0.5),
## datautils.Clip(),
])
##test_transform = train_transform
Expand Down
Loading

0 comments on commit 239d4b3

Please sign in to comment.