Commit 20bec9b
first commit
Suwon Shon committed Feb 28, 2018 (0 parents)
Showing 31 changed files with 262,062 additions and 0 deletions.
79 changes: 79 additions & 0 deletions README.md
# End-to-end Dialect Identification (implementation on the MGB-3 Arabic dialect dataset)
Tensorflow implementation of end-to-end dialect identification in Arabic. If you are familiar with language/speaker identification or verification, it can easily be modified for other dialect, language, or even speaker identification/verification tasks.

# Requirements
* Python, tested on 2.7.6
* Tensorflow > v1.0
* sox Python library, tested on 1.3.2
* librosa Python library, tested on 0.5.1

# Data list format
Each line of the data list consists of the location of a wav file and its label as a digit.

Example) "train.txt"
```
./data/wav/EGY/EGY000001.wav 0
./data/wav/EGY/EGY000002.wav 0
./data/wav/NOR/NOR000001.wav 4
```

Labels of the dialects:
- Egyptian (EGY) : 0
- Gulf (GLF) : 1
- Levantine (LAV) : 2
- Modern Standard Arabic (MSA) : 3
- North African (NOR) : 4
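
As a minimal parsing sketch (pure Python; the helper name `load_datalist` is ours, not part of this repository):

```python
# Hypothetical helper for reading a data list such as train.txt;
# not part of this repository's code.
DIALECTS = ['EGY', 'GLF', 'LAV', 'MSA', 'NOR']

def load_datalist(path):
    """Return a list of (wav_path, integer_label) pairs."""
    pairs = []
    with open(path) as f:
        for line in f:
            wav_path, label = line.split()
            pairs.append((wav_path, int(label)))
    return pairs

# Example: load_datalist('./data/train.txt')[0] -> ('./data/wav/EGY/EGY000001.wav', 0)
```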

# Dataset Augmentation
Augmentation was done with two different methods. The first is taking a random segment of the input utterance; the other is perturbation that modifies the speed and volume of the speech, as sketched below.
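
A minimal sketch of the speed/volume perturbation step using the sox Python library listed above (the perturbation factors here are illustrative assumptions, not the settings used to generate the `*_speed*`/`*_vol*` lists):

```python
# Hypothetical augmentation sketch using pysox; the factors are
# illustrative assumptions, not this repository's settings.
import sox

def perturb(in_wav, out_wav, speed=1.1, gain=0.8):
    """Write a speed- and volume-perturbed copy of in_wav to out_wav."""
    tfm = sox.Transformer()
    tfm.speed(speed)  # change playback speed (tempo and pitch together)
    tfm.vol(gain)     # scale the amplitude
    tfm.build(in_wav, out_wav)

# Example: perturb('./data/wav/EGY/EGY000001.wav', 'EGY000001_sp.wav')
```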



# Model definition
Simple description of the DNN model:
![Image of Model](https://github.com/swshon/dialectID_e2e/blob/master/images/figure_network.png)
We used four 1-dimensional CNN (1d-CNN) layers (40x5 - 500x7 - 500x1 - 500x1 filter sizes with 1-2-1-1 strides and 500-500-500-3000 filters) and two FC layers (1500-600). They are connected by a global average pooling layer, which averages the CNN outputs over time to produce a fixed output size of 3000x1.
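
As a rough, fixed-length sketch of this stack in Tensorflow (it ignores the variable-length masking and batch normalization that the actual `models/e2e_model.py` implements; layer sizes follow the description above):

```python
# Simplified, fixed-length sketch of the 1d-CNN stack described above;
# the real model (models/e2e_model.py) adds masking and batch norm.
import tensorflow as tf

def sketch_net(x, num_dialects=5):
    """x: [batch, frames, 40] acoustic features (e.g. MFCC)."""
    h = tf.layers.conv1d(x, 500, 5, strides=1, padding='same', activation=tf.nn.relu)
    h = tf.layers.conv1d(h, 500, 7, strides=2, padding='same', activation=tf.nn.relu)
    h = tf.layers.conv1d(h, 500, 1, strides=1, padding='same', activation=tf.nn.relu)
    h = tf.layers.conv1d(h, 3000, 1, strides=1, padding='same', activation=tf.nn.relu)
    h = tf.reduce_mean(h, axis=1)             # global average pooling -> [batch, 3000]
    h = tf.layers.dense(h, 1500, activation=tf.nn.relu)
    h = tf.layers.dense(h, 600, activation=tf.nn.relu)
    return tf.layers.dense(h, num_dialects)   # logits over the 5 dialects
```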

End-to-end DID accuracy by epoch
![Image of Model](https://github.com/swshon/dialectID_e2e/blob/master/images/accuracy_aug.png)
End-to-end DID accuracy by epoch using augmented dataset
![Image of Model](https://github.com/swshon/dialectID_e2e/blob/master/images/accuracy_feat.png)
Performance comparison with and without Random Segmentation (RS)
![Image of Model](https://github.com/swshon/dialectID_e2e/blob/master/images/random_segment.png)


# Performance evaluation
Best accuracy is 73.39% (Feb. 28, 2018).

For reference:

| Model | Accuracy |
| --- | --- |
| Conventional i-vector with SVM | 60.32% |
| Conventional i-vector with LDA and cosine distance | 62.60% |
| End-to-end model without dataset augmentation (MFCC) | 65.55% |
| End-to-end model without dataset augmentation (FBANK) | 64.81% |
| End-to-end model without dataset augmentation (Spectrogram) | 57.57% |
| End-to-end model with volume perturbation (MFCC) | 67.49% |
| End-to-end model with speed perturbation (MFCC) | 70.51% |
| End-to-end model with speed and volume perturbation (MFCC) | 70.91% |
| End-to-end model with speed and volume perturbation (FBANK) | 71.92% |
| End-to-end model with speed and volume perturbation (Spectrogram) | 68.83% |
| End-to-end model with speed and volume perturbation + random segmentation (MFCC) | 71.05% |
| End-to-end model with speed and volume perturbation + random segmentation (FBANK) | 73.39% |
| End-to-end model with speed and volume perturbation + random segmentation (Spectrogram) | 70.17% |


# Offline test
An offline test can be run with the offline_test.ipynb notebook on our pretrained model. Specify the wav file whose Arabic dialect you want to identify by modifying the FILENAME variable.

```
FILENAME = ['/data/test/EGY_00001.wav']
```

The result is shown as a bar plot of the likelihoods over the 5 Arabic dialects:

![Image of offline result plot](https://github.com/swshon/dialectID_e2e/blob/master/images/offline_plot.png)
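
For illustration, a plot like the one above could be produced with matplotlib (a hedged sketch; the probability values and variable names are ours, not taken from the notebook):

```python
# Hypothetical plotting sketch; `probs` stands in for the softmax
# output of the pretrained model on one utterance.
import matplotlib.pyplot as plt

dialects = ['EGY', 'GLF', 'LAV', 'MSA', 'NOR']
probs = [0.70, 0.10, 0.08, 0.07, 0.05]  # illustrative values only

plt.bar(range(len(dialects)), probs)
plt.xticks(range(len(dialects)), dialects)
plt.ylabel('Likelihood')
plt.title('Dialect likelihood for the input wav')
plt.show()
```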


1,566 changes: 1,566 additions & 0 deletions data/dev.txt

4,698 changes: 4,698 additions & 0 deletions data/dev_speed.txt

14,094 changes: 14,094 additions & 0 deletions data/dev_speed_vol.txt

4,698 changes: 4,698 additions & 0 deletions data/dev_vol.txt

1,492 changes: 1,492 additions & 0 deletions data/test.txt

14,592 changes: 14,592 additions & 0 deletions data/train.txt

43,776 changes: 43,776 additions & 0 deletions data/train_speed.txt

131,328 changes: 131,328 additions & 0 deletions data/train_speed_vol.txt

43,776 changes: 43,776 additions & 0 deletions data/train_vol.txt

Binary file added images/accuracy_aug.png
Binary file added images/accuracy_feat.png
Binary file added images/figure_network.png
Binary file added images/offline_plot.png
Binary file added images/random_segment.png
194 changes: 194 additions & 0 deletions models/e2e_model.py
import tensorflow as tf
import numpy as np


class nn:

    # Create model
    def __init__(self, x1, y_, y_string, shapes_batch, softmax_num, is_training, input_dim, is_batchnorm):
        self.ea, self.eb, self.o1, self.res1, self.conv, self.ac1, self.ac2 = self.net(
            x1, shapes_batch, softmax_num, is_training, input_dim, is_batchnorm)

        # Create loss
        self.loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_, logits=self.o1))
        self.label = y_
        self.shape = shapes_batch
        self.true_length = x1
        self.label_string = y_string


    def net(self, x, shapes_batch, softmax_num, is_training, input_dim, is_batchnorm):
        shape_list = shapes_batch[:, 0]
        is_exclude_short = False
        if is_exclude_short:
            # randomly select the start of the sequences
            sequence_limit = tf.reduce_min(shape_list) / 2
            # sequence_limit = tf.cond(sequence_limit <= 200, lambda: sequence_limit, lambda: tf.subtract(sequence_limit, 200))
            random_start_pt = tf.random_uniform([1], minval=0, maxval=sequence_limit, dtype=tf.int32)
            end_pt = tf.reduce_max(shape_list)
            x = tf.gather(x, tf.range(tf.squeeze(random_start_pt), end_pt), axis=1)
            shape_list = shape_list - random_start_pt

        # randomly chunk sequences: sample a target length per utterance
        batch_quantity = tf.size(shape_list)
        aug_list = tf.constant([200, 300, 400], dtype=tf.float32)
        aug_quantity = tf.size(aug_list)
        rand_index = tf.random_uniform([batch_quantity], minval=0, maxval=aug_quantity - 1, dtype=tf.int32)
        rand_aug_list = tf.gather(aug_list, rand_index)

        shape_list_f = tf.cast(shape_list, tf.float32)
        temp = tf.multiply(shape_list_f, rand_aug_list / shape_list_f)  # reduces to the sampled chunk length
        aug_shape_list = tf.cast(temp, tf.int32)
        shape_list = tf.minimum(shape_list, aug_shape_list)  # clip each utterance's effective length

        # conv1: 40x5 filters, stride 1, 500 filters
        featdim = input_dim  # channels
        kernel_size = 5
        stride = 1
        depth = 500

        shape_list = shape_list / stride
        conv1 = self.conv_layer(x, kernel_size, featdim, stride, depth, 'conv1', shape_list)
        conv1_bn = self.batch_norm_wrapper_1dcnn(conv1, is_training, 'bn1', shape_list, is_batchnorm)
        conv1r = tf.nn.relu(conv1_bn)

        # conv2: 500x7 filters, stride 2, 500 filters
        featdim = depth
        kernel_size = 7
        stride = 2
        depth = 500

        shape_list = shape_list / stride
        conv2 = self.conv_layer(conv1r, kernel_size, featdim, stride, depth, 'conv2', shape_list)
        conv2_bn = self.batch_norm_wrapper_1dcnn(conv2, is_training, 'bn2', shape_list, is_batchnorm)
        conv2r = tf.nn.relu(conv2_bn)

        # conv3: 500x1 filters, stride 1, 500 filters
        featdim = depth
        kernel_size = 1
        stride = 1
        depth = 500

        shape_list = shape_list / stride
        conv3 = self.conv_layer(conv2r, kernel_size, featdim, stride, depth, 'conv3', shape_list)
        conv3_bn = self.batch_norm_wrapper_1dcnn(conv3, is_training, 'bn3', shape_list, is_batchnorm)
        conv3r = tf.nn.relu(conv3_bn)

        # conv4: 500x1 filters, stride 1, 3000 filters
        featdim = depth
        kernel_size = 1
        stride = 1
        depth = 3000

        shape_list = shape_list / stride
        conv4 = self.conv_layer(conv3r, kernel_size, featdim, stride, depth, 'conv4', shape_list)
        conv4_bn = self.batch_norm_wrapper_1dcnn(conv4, is_training, 'bn4', shape_list, is_batchnorm)
        conv4r = tf.nn.relu(conv4_bn)

        print(conv1)

        # global average pooling over the masked time axis
        shape_list = tf.cast(shape_list, tf.float32)
        shape_list = tf.reshape(shape_list, [-1, 1, 1])
        mean = tf.reduce_sum(conv4r, 1, keep_dims=True) / shape_list
        res1 = tf.squeeze(mean, axis=1)

        # two fully connected layers (1500-600) followed by the softmax layer
        fc1 = self.fc_layer(res1, 1500, "fc1")
        fc1_bn = self.batch_norm_wrapper_fc(fc1, is_training, 'bn5', is_batchnorm)
        ac1 = tf.nn.relu(fc1_bn)
        fc2 = self.fc_layer(ac1, 600, "fc2")
        fc2_bn = self.batch_norm_wrapper_fc(fc2, is_training, 'bn6', is_batchnorm)
        ac2 = tf.nn.relu(fc2_bn)

        fc3 = self.fc_layer(ac2, softmax_num, "fc3")
        return fc1, fc2, fc3, res1, conv1r, ac1, ac2

    def xavier_init(self, n_inputs, n_outputs, uniform=True):
        if uniform:
            init_range = np.sqrt(6.0 / (n_inputs + n_outputs))
            return tf.random_uniform_initializer(-init_range, init_range)
        else:
            stddev = np.sqrt(3.0 / (n_inputs + n_outputs))
            return tf.truncated_normal_initializer(stddev=stddev)

    def fc_layer(self, bottom, n_weight, name):
        print(bottom.get_shape())
        assert len(bottom.get_shape()) == 2
        n_prev_weight = bottom.get_shape()[1]

        initer = self.xavier_init(int(n_prev_weight), n_weight)
        W = tf.get_variable(name + 'W', dtype=tf.float32, shape=[n_prev_weight, n_weight], initializer=initer)
        b = tf.get_variable(name + 'b', dtype=tf.float32,
                            initializer=tf.random_uniform([n_weight], -0.001, 0.001, dtype=tf.float32))
        fc = tf.nn.bias_add(tf.matmul(bottom, W), b)
        return fc


    def conv_layer(self, bottom, kernel_size, num_channels, stride, depth, name, shape_list):
        n_prev_weight = tf.shape(bottom)[1]

        inputlayer = bottom
        initer = tf.truncated_normal_initializer(stddev=0.1)

        W = tf.get_variable(name + 'W', dtype=tf.float32, shape=[kernel_size, num_channels, depth],
                            initializer=tf.contrib.layers.xavier_initializer())
        b = tf.get_variable(name + 'b', dtype=tf.float32,
                            initializer=tf.constant(0.001, shape=[depth], dtype=tf.float32))

        conv = tf.nn.bias_add(tf.nn.conv1d(inputlayer, W, stride, padding='SAME'), b)
        # zero out frames beyond each utterance's true length
        mask = tf.sequence_mask(shape_list, tf.shape(conv)[1])  # make mask with batch x frame size
        mask = tf.where(mask, tf.ones_like(mask, dtype=tf.float32), tf.zeros_like(mask, dtype=tf.float32))
        mask = tf.tile(mask, tf.stack([tf.shape(conv)[2], 1]))  # replicate mask along the depth dimension
        mask = tf.reshape(mask, [tf.shape(conv)[2], tf.shape(conv)[0], -1])
        mask = tf.transpose(mask, [1, 2, 0])
        print(mask)
        conv = tf.multiply(conv, mask)
        return conv






    def batch_norm_wrapper_1dcnn(self, inputs, is_training, name, shape_list, is_batchnorm, decay=0.999):
        if is_batchnorm:
            shape_list = tf.cast(shape_list, tf.float32)
            epsilon = 1e-3
            scale = tf.get_variable(name + 'scale', dtype=tf.float32, initializer=tf.ones([inputs.get_shape()[-1]]))
            beta = tf.get_variable(name + 'beta', dtype=tf.float32, initializer=tf.zeros([inputs.get_shape()[-1]]))
            pop_mean = tf.get_variable(name + 'pop_mean', dtype=tf.float32,
                                       initializer=tf.zeros([inputs.get_shape()[-1]]), trainable=False)
            pop_var = tf.get_variable(name + 'pop_var', dtype=tf.float32,
                                      initializer=tf.ones([inputs.get_shape()[-1]]), trainable=False)
            if is_training:
                # batch_mean, batch_var = tf.nn.moments(inputs, [0, 1])
                batch_mean = tf.reduce_sum(inputs, [0, 1]) / tf.reduce_sum(shape_list)  # for variable-length input
                batch_var = tf.reduce_sum(tf.square(inputs - batch_mean), [0, 1]) / tf.reduce_sum(shape_list)  # for variable-length input
                # update running statistics for use at test time
                train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
                train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
                with tf.control_dependencies([train_mean, train_var]):
                    return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, scale, epsilon)
            else:
                return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, scale, epsilon)
        else:
            return inputs




    def batch_norm_wrapper_fc(self, inputs, is_training, name, is_batchnorm, decay=0.999):
        if is_batchnorm:
            epsilon = 1e-3
            scale = tf.get_variable(name + 'scale', dtype=tf.float32, initializer=tf.ones([inputs.get_shape()[-1]]))
            beta = tf.get_variable(name + 'beta', dtype=tf.float32, initializer=tf.zeros([inputs.get_shape()[-1]]))
            pop_mean = tf.get_variable(name + 'pop_mean', dtype=tf.float32,
                                       initializer=tf.zeros([inputs.get_shape()[-1]]), trainable=False)
            pop_var = tf.get_variable(name + 'pop_var', dtype=tf.float32,
                                      initializer=tf.ones([inputs.get_shape()[-1]]), trainable=False)
            if is_training:
                batch_mean, batch_var = tf.nn.moments(inputs, [0])
                train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
                train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
                with tf.control_dependencies([train_mean, train_var]):
                    return tf.nn.batch_normalization(inputs, batch_mean, batch_var, beta, scale, epsilon)
            else:
                return tf.nn.batch_normalization(inputs, pop_mean, pop_var, beta, scale, epsilon)
        else:
            return inputs


Binary file added models/e2e_model.pyc
