Commit d4c1c30

improve performance by using BLAS-like primitives (e.g. faxpy, a.k.a. fma)

1 parent 4a34847 commit d4c1c30

File tree: 1 file changed

src/rnn.c: 57 additions & 49 deletions

@@ -76,35 +76,38 @@ static OPUS_INLINE float relu(float x)
    return x < 0 ? 0 : x;
 }
 
+void faxpy(float *restrict a, const rnn_weight *restrict b, int k, float u)
+{
+   if (u == 0.0) return;
+   for (int idx = 0; idx < k; idx++)
+      a[idx] += b[idx] * u;
+}
+
 void compute_dense(const DenseLayer *layer, float *output, const float *input)
 {
    int i, j;
    int N, M;
-   int stride;
    M = layer->nb_inputs;
    N = layer->nb_neurons;
-   stride = N;
-   for (i=0;i<N;i++)
-   {
-      /* Compute update gate. */
-      float sum = layer->bias[i];
-      for (j=0;j<M;j++)
-         sum += layer->input_weights[j*stride + i]*input[j];
-      output[i] = WEIGHTS_SCALE*sum;
-   }
+   const rnn_weight *ip = layer->input_weights;
+   /* Compute update gate. */
+   for(i = 0; i < N; i++)
+      output[i] = layer->bias[i];
+   for (j=0;j<M;j++,ip+=N)
+      faxpy(output, ip, N, input[j]);
    switch (layer->activation) {
    case ACTIVATION_SIGMOID:
       for (i=0;i<N;i++)
-         output[i] = sigmoid_approx(output[i]);
+         output[i] = sigmoid_approx(WEIGHTS_SCALE*output[i]);
       break;
    case ACTIVATION_TANH:
       for (i=0;i<N;i++)
-         output[i] = tansig_approx(output[i]);
+         output[i] = tansig_approx(WEIGHTS_SCALE*output[i]);
       break;
    default:
    case ACTIVATION_RELU:
       for (i=0;i<N;i++)
-         output[i] = relu(output[i]);
+         output[i] = relu(WEIGHTS_SCALE*output[i]);
       break;
    }
 }
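
For context (not part of the patch): with the dense weights indexed as input_weights[j*N + i], the old inner loop over j read memory with a stride of N elements per step, computing one dot product per output neuron. The rewrite initializes the outputs with the biases and then accumulates one contiguous weight column per input via faxpy, which is unit-stride and maps naturally onto fused multiply-add. Below is a minimal standalone sketch of the two equivalent orderings, assuming plain float weights; the names dense_rowwise and dense_axpy are illustrative only and do not appear in rnn.c:

/* Standalone sketch, not from the patch.
 * Assumption: rnn_weight behaves like float here. */
typedef float rnn_weight;

/* Old ordering: one dot product per neuron, stride-N reads of w. */
static void dense_rowwise(float *out, const rnn_weight *w, const float *bias,
                          const float *in, int M, int N)
{
   for (int i = 0; i < N; i++) {
      float sum = bias[i];
      for (int j = 0; j < M; j++)
         sum += w[j*N + i]*in[j];      /* strided access */
      out[i] = sum;
   }
}

/* New ordering: start from the bias, then add one weight column per input.
 * Each inner loop is a unit-stride axpy: out += w_col * in[j]. */
static void dense_axpy(float *out, const rnn_weight *w, const float *bias,
                       const float *in, int M, int N)
{
   for (int i = 0; i < N; i++)
      out[i] = bias[i];
   for (int j = 0; j < M; j++, w += N)
      for (int i = 0; i < N; i++)
         out[i] += w[i]*in[j];         /* contiguous access, FMA-friendly */
}

Both orderings produce the same outputs for the same inputs; only the memory access pattern changes.
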
@@ -120,44 +123,49 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
    M = gru->nb_inputs;
    N = gru->nb_neurons;
    stride = 3*N;
-   for (i=0;i<N;i++)
-   {
-      /* Compute update gate. */
-      float sum = gru->bias[i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[j*stride + i]*state[j];
-      z[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
-   }
-   for (i=0;i<N;i++)
-   {
-      /* Compute reset gate. */
-      float sum = gru->bias[N + i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[N + j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[N + j*stride + i]*state[j];
-      r[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
+   const rnn_weight *ip = gru->input_weights;
+   const rnn_weight *rp = gru->recurrent_weights;
+   /* Compute update gate. */
+   for(i = 0; i < N; i++)
+      z[i] = gru->bias[i];
+   for (j=0;j<M;j++,ip+=stride)
+      faxpy(z, ip, N, input[j]);
+   for (j=0;j<N;j++,rp+=stride)
+      faxpy(z, rp, N, state[j]);
+   for(i = 0; i < N; i++)
+      z[i] = sigmoid_approx(WEIGHTS_SCALE*z[i]);
+   /* Compute reset gate. */
+   for(i = 0; i < N; i++)
+      r[i] = gru->bias[N+i];
+   ip = gru->input_weights + N;
+   rp = gru->recurrent_weights + N;
+   for (j=0;j<M;j++,ip+=stride)
+      faxpy(r, ip, N, input[j]);
+   for (j=0;j<N;j++,rp+=stride)
+      faxpy(r, rp, N, state[j]);
+   for(i = 0; i < N; i++)
+      r[i] = sigmoid_approx(WEIGHTS_SCALE*r[i]);
+
+   /* Compute output. */
+   for(i = 0; i < N; i++)
+      h[i] = gru->bias[2*N+i];
+   ip = gru->input_weights + 2*N;
+   rp = gru->recurrent_weights + 2*N;
+   for (j=0;j<M;j++,ip+=stride)
+      faxpy(h, ip, N, input[j]);
+   for (j=0;j<N;j++,rp+=stride)
+      faxpy(h, rp, N, r[j]*state[j]);
+   for (i=0;i<N;i++) {
+      switch (gru->activation) {
+      case ACTIVATION_SIGMOID: h[i] = sigmoid_approx(WEIGHTS_SCALE*h[i]);break;
+      case ACTIVATION_TANH: h[i] = tansig_approx(WEIGHTS_SCALE*h[i]); break;
+      default:
+      case ACTIVATION_RELU: h[i] = relu(WEIGHTS_SCALE*h[i]); break;
+      }
+      h[i] = z[i]*state[i] + (1-z[i])*h[i];
    }
    for (i=0;i<N;i++)
-   {
-      /* Compute output. */
-      float sum = gru->bias[2*N + i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[2*N + j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[2*N + j*stride + i]*state[j]*r[j];
-      switch (gru->activation) {
-      case ACTIVATION_SIGMOID: sum = sigmoid_approx(WEIGHTS_SCALE*sum);break;
-      case ACTIVATION_TANH: sum = tansig_approx(WEIGHTS_SCALE*sum); break;
-      default:
-      case ACTIVATION_RELU: sum = relu(WEIGHTS_SCALE*sum); break;
-      }
-      h[i] = z[i]*state[i] + (1-z[i])*sum;
-   }
-   for (i=0;i<N;i++)
-      state[i] = h[i];
+      state[i] = h[i];
 }
 
 #define INPUT_SIZE 42
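
A side note on the "BLAS-like" wording in the commit message: faxpy is the classic axpy kernel, y ← α·x + y with unit stride. If rnn_weight were plain float, the same step could be handed to an optimized BLAS. The sketch below assumes a CBLAS implementation is linked and that the weights promote to float; it is illustrative only, since the real weight type and build setup may differ:

#include <cblas.h>

/* Hypothetical equivalent of faxpy(a, b, k, u) when rnn_weight is float:
 * performs a[idx] += b[idx]*u for idx in [0, k). */
static void faxpy_blas(float *a, const float *b, int k, float u)
{
   if (u == 0.0f) return;          /* same early-out as the patch */
   cblas_saxpy(k, u, b, 1, a, 1);  /* a := u*b + a, unit strides */
}

The early-out also means faxpy skips an entire weight column whenever the multiplier (an input, state, or gated state value) is exactly zero.
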
