
Commit accadd3

committed
improve performance by using BLAS-like ops (e.g. faxpy/FMA) for the GRU Hadamard product

1 parent 4a34847 commit accadd3
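Editor's note on the change: the old loops computed each output neuron as a strided dot product over the weight matrix (`input_weights[j*stride + i]`, stride N or 3*N), so consecutive loads were far apart in memory and hard to vectorize. The rewrite initializes each gate vector with its biases, then sweeps the weight matrix one contiguous row at a time with faxpy, an axpy-style kernel (a += u * b) that compilers can turn into unit-stride fused multiply-adds. Below is a minimal sketch of the two access patterns; it is an illustration, not code from the commit, and the rnn_weight typedef is an assumption standing in for the repository's quantized weight type.

/* Sketch only (not from the commit): the same matrix-vector product
   written both ways; w[j*N + i] holds the weight from input j to neuron i. */
typedef signed char rnn_weight;   /* assumed quantized weight type */

/* Old pattern: one strided dot product per output neuron. */
static void mv_strided(float *out, const float *bias, const rnn_weight *w,
                       const float *in, int N, int M)
{
   for (int i = 0; i < N; i++) {
      float sum = bias[i];
      for (int j = 0; j < M; j++)
         sum += w[j*N + i] * in[j];   /* consecutive loads are N apart */
      out[i] = sum;
   }
}

/* New pattern: biases first, then one contiguous axpy per input. */
static void mv_axpy(float *out, const float *bias, const rnn_weight *w,
                    const float *in, int N, int M)
{
   for (int i = 0; i < N; i++)
      out[i] = bias[i];
   for (int j = 0; j < M; j++, w += N)
      for (int i = 0; i < N; i++)
         out[i] += w[i] * in[j];      /* unit stride, FMA-friendly */
}

Both compute the same values; only the traversal order differs.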

File tree

1 file changed: src/rnn.c (+57 −49 lines)
@@ -76,35 +76,38 @@ static OPUS_INLINE float relu(float x)
    return x < 0 ? 0 : x;
 }
 
+static void faxpy(float *restrict a, const rnn_weight *restrict b, int k, float u)
+{
+   if (u == 0.0) return;
+   for (int idx = 0; idx < k; idx++)
+      a[idx] += b[idx] * u;
+}
+
 void compute_dense(const DenseLayer *layer, float *output, const float *input)
 {
    int i, j;
    int N, M;
-   int stride;
    M = layer->nb_inputs;
    N = layer->nb_neurons;
-   stride = N;
-   for (i=0;i<N;i++)
-   {
-      /* Compute update gate. */
-      float sum = layer->bias[i];
-      for (j=0;j<M;j++)
-         sum += layer->input_weights[j*stride + i]*input[j];
-      output[i] = WEIGHTS_SCALE*sum;
-   }
+   const rnn_weight *ip = layer->input_weights;
+   /* Compute update gate. */
+   for(i = 0; i < N; i++)
+      output[i] = layer->bias[i];
+   for (j=0;j<M;j++,ip+=N)
+      faxpy(output, ip, N, input[j]);
    switch (layer->activation) {
    case ACTIVATION_SIGMOID:
       for (i=0;i<N;i++)
-         output[i] = sigmoid_approx(output[i]);
+         output[i] = sigmoid_approx(WEIGHTS_SCALE * output[i]);
       break;
    case ACTIVATION_TANH:
       for (i=0;i<N;i++)
-         output[i] = tansig_approx(output[i]);
+         output[i] = tansig_approx(WEIGHTS_SCALE * output[i]);
       break;
    default:
    case ACTIVATION_RELU:
       for (i=0;i<N;i++)
-         output[i] = relu(output[i]);
+         output[i] = relu(WEIGHTS_SCALE * output[i]);
       break;
    }
 }
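Two details worth noting in the rewritten compute_dense (editor's note): WEIGHTS_SCALE moves from the accumulation into the activation step, since the running sums now live un-scaled in the output array until the end; and faxpy returns early when its scalar is exactly zero, so zero inputs (common after a ReLU layer) skip a whole row of multiply-adds. Under contraction-friendly flags (e.g. gcc/clang -ffp-contract=fast) the axpy body compiles to fused multiply-adds; here is a hypothetical variant, not from the commit, making that explicit with the standard fmaf from <math.h> and weights widened to float for the sketch:

#include <math.h>

/* Sketch only: faxpy with an explicit fused multiply-add.
   fmaf(x, y, z) computes x*y + z with a single rounding. */
static void faxpy_fma(float *restrict a, const float *restrict b, int k, float u)
{
   if (u == 0.0f) return;            /* skip rows scaled by exactly zero */
   for (int idx = 0; idx < k; idx++)
      a[idx] = fmaf(b[idx], u, a[idx]);
}

The GRU hunk below applies the same restructuring three times, once per gate.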
@@ -120,44 +123,49 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
    M = gru->nb_inputs;
    N = gru->nb_neurons;
    stride = 3*N;
-   for (i=0;i<N;i++)
-   {
-      /* Compute update gate. */
-      float sum = gru->bias[i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[j*stride + i]*state[j];
-      z[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
-   }
-   for (i=0;i<N;i++)
-   {
-      /* Compute reset gate. */
-      float sum = gru->bias[N + i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[N + j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[N + j*stride + i]*state[j];
-      r[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
+   const rnn_weight *ip = gru->input_weights;
+   const rnn_weight *rp = gru->recurrent_weights;
+   /* Compute update gate. */
+   for(i = 0; i < N; i++)
+      z[i] = gru->bias[i];
+   for (j=0;j<M;j++,ip+=stride)
+      faxpy(z, ip, N, input[j]);
+   for (j=0;j<N;j++,rp+=stride)
+      faxpy(z, rp, N, state[j]);
+   for(i = 0; i < N; i++)
+      z[i] = sigmoid_approx(WEIGHTS_SCALE*z[i]);
+   /* Compute reset gate. */
+   for(i = 0; i < N; i++)
+      r[i] = gru->bias[N+i];
+   ip = gru->input_weights + N;
+   rp = gru->recurrent_weights + N;
+   for (j=0;j<M;j++,ip+=stride)
+      faxpy(r, ip, N, input[j]);
+   for (j=0;j<N;j++,rp+=stride)
+      faxpy(r, rp, N, state[j]);
+   for(i = 0; i < N; i++)
+      r[i] = sigmoid_approx(WEIGHTS_SCALE*r[i]);
+
+   /* Compute output. */
+   for(i = 0; i < N; i++)
+      h[i] = gru->bias[2*N+i];
+   ip = gru->input_weights + 2*N;
+   rp = gru->recurrent_weights + 2*N;
+   for (j=0;j<M;j++,ip+=stride)
+      faxpy(h, ip, N, input[j]);
+   for (j=0;j<N;j++,rp+=stride)
+      faxpy(h, rp, N, r[j]*state[j]);
+   for (i=0;i<N;i++) {
+      switch (gru->activation) {
+      case ACTIVATION_SIGMOID: h[i] = sigmoid_approx(WEIGHTS_SCALE*h[i]); break;
+      case ACTIVATION_TANH: h[i] = tansig_approx(WEIGHTS_SCALE*h[i]); break;
+      default:
+      case ACTIVATION_RELU: h[i] = relu(WEIGHTS_SCALE*h[i]); break;
+      }
+      h[i] = z[i]*state[i] + (1-z[i])*h[i];
    }
    for (i=0;i<N;i++)
-   {
-      /* Compute output. */
-      float sum = gru->bias[2*N + i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[2*N + j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[2*N + j*stride + i]*state[j]*r[j];
-      switch (gru->activation) {
-      case ACTIVATION_SIGMOID: sum = sigmoid_approx(WEIGHTS_SCALE*sum); break;
-      case ACTIVATION_TANH: sum = tansig_approx(WEIGHTS_SCALE*sum); break;
-      default:
-      case ACTIVATION_RELU: sum = relu(WEIGHTS_SCALE*sum); break;
-      }
-      h[i] = z[i]*state[i] + (1-z[i])*sum;
-   }
-   for (i=0;i<N;i++)
-      state[i] = h[i];
+      state[i] = h[i];
 }
 
 #define INPUT_SIZE 42
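For reference (editor's note), the rewritten compute_gru implements the standard GRU equations, with the element-wise product r ∘ state (the Hadamard product named in the commit message) folded into faxpy's scalar argument as r[j]*state[j]:

\[
\begin{aligned}
z &= \sigma(W_z x + U_z h + b_z)\\
r &= \sigma(W_r x + U_r h + b_r)\\
\tilde h &= g\big(W_h x + U_h (r \odot h) + b_h\big)\\
h' &= z \odot h + (1 - z) \odot \tilde h
\end{aligned}
\]

where \(\sigma\) is the sigmoid approximation, \(g\) is the layer's configured activation, \(\odot\) is the element-wise product, \(h\) the incoming state, and \(h'\) the new state. Note the convention here: the update gate \(z\) weights the old state, so \(z \to 1\) preserves state.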
