@@ -76,35 +76,38 @@ static OPUS_INLINE float relu(float x)
    return x < 0 ? 0 : x;
 }
 
+void faxpy(float *restrict a, const rnn_weight *restrict b, int k, float u)
+{
+   if (u == 0.0) return;
+   for (int idx = 0; idx < k; idx++)
+      a[idx] += b[idx]*u;
+}
+
 void compute_dense(const DenseLayer *layer, float *output, const float *input)
 {
    int i, j;
    int N, M;
-   int stride;
    M = layer->nb_inputs;
    N = layer->nb_neurons;
-   stride = N;
-   for (i=0;i<N;i++)
-   {
-      /* Compute update gate. */
-      float sum = layer->bias[i];
-      for (j=0;j<M;j++)
-         sum += layer->input_weights[j*stride + i]*input[j];
-      output[i] = WEIGHTS_SCALE*sum;
-   }
+   const rnn_weight *ip = layer->input_weights;
+   /* Compute update gate. */
+   for (i=0;i<N;i++)
+      output[i] = layer->bias[i];
+   for (j=0;j<M;j++,ip+=N)
+      faxpy(output, ip, N, input[j]);
    switch (layer->activation) {
    case ACTIVATION_SIGMOID:
       for (i=0;i<N;i++)
-         output[i] = sigmoid_approx(output[i]);
+         output[i] = sigmoid_approx(WEIGHTS_SCALE*output[i]);
       break;
    case ACTIVATION_TANH:
       for (i=0;i<N;i++)
-         output[i] = tansig_approx(output[i]);
+         output[i] = tansig_approx(WEIGHTS_SCALE*output[i]);
       break;
    default:
    case ACTIVATION_RELU:
       for (i=0;i<N;i++)
-         output[i] = relu(output[i]);
+         output[i] = relu(WEIGHTS_SCALE*output[i]);
       break;
    }
 }
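Note (illustrative, not part of the patch): the dense layer is now built column by column. Instead of one dot product per output neuron over row-strided weights, the output vector starts from the biases and each input j adds its weight column through faxpy, a scaled vector add that skips a column entirely when its input is exactly zero. Since WEIGHTS_SCALE is now applied inside the activation calls rather than before them, the result should match the old code up to floating-point rounding. The standalone sketch below compares the two formulations on a toy layer; the rnn_weight typedef, sizes, and values are assumptions for illustration only.

/* Illustrative only: compare the old per-neuron dot product with the
 * new column-wise faxpy accumulation for a tiny M x N dense layer.
 * Assumes weights are stored input-major, w[j*N + i], as in the patch. */
#include <stdio.h>

typedef signed char rnn_weight;   /* assumption; any small numeric type works here */

static void faxpy(float *restrict a, const rnn_weight *restrict b, int k, float u)
{
   int idx;
   if (u == 0.0) return;          /* a zero input contributes nothing: skip the column */
   for (idx=0;idx<k;idx++)
      a[idx] += b[idx]*u;
}

int main(void)
{
   enum {M = 3, N = 2};
   const rnn_weight w[M*N] = {1, -2, 3, 4, -5, 6};
   const float bias[N] = {0.5f, -0.25f};
   const float input[M] = {0.1f, 0.f, -0.3f};
   float old_out[N], new_out[N];
   const rnn_weight *ip = w;
   int i, j;

   /* Old formulation: one dot product per output neuron. */
   for (i=0;i<N;i++) {
      float sum = bias[i];
      for (j=0;j<M;j++)
         sum += w[j*N + i]*input[j];
      old_out[i] = sum;
   }

   /* New formulation: bias first, then one scaled weight column per input. */
   for (i=0;i<N;i++)
      new_out[i] = bias[i];
   for (j=0;j<M;j++,ip+=N)
      faxpy(new_out, ip, N, input[j]);

   for (i=0;i<N;i++)
      printf("neuron %d: old %f  new %f\n", i, old_out[i], new_out[i]);
   return 0;
}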
@@ -120,44 +123,49 @@ void compute_gru(const GRULayer *gru, float *state, const float *input)
    M = gru->nb_inputs;
    N = gru->nb_neurons;
    stride = 3*N;
-   for (i=0;i<N;i++)
-   {
-      /* Compute update gate. */
-      float sum = gru->bias[i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[j*stride + i]*state[j];
-      z[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
-   }
-   for (i=0;i<N;i++)
-   {
-      /* Compute reset gate. */
-      float sum = gru->bias[N + i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[N + j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[N + j*stride + i]*state[j];
-      r[i] = sigmoid_approx(WEIGHTS_SCALE*sum);
+   const rnn_weight *ip = gru->input_weights;
+   const rnn_weight *rp = gru->recurrent_weights;
+   /* Compute update gate. */
+   for (i=0;i<N;i++)
+      z[i] = gru->bias[i];
+   for (j=0;j<M;j++,ip+=stride)
+      faxpy(z, ip, N, input[j]);
+   for (j=0;j<N;j++,rp+=stride)
+      faxpy(z, rp, N, state[j]);
+   for (i=0;i<N;i++)
+      z[i] = sigmoid_approx(WEIGHTS_SCALE*z[i]);
+   /* Compute reset gate. */
+   for (i=0;i<N;i++)
+      r[i] = gru->bias[N + i];
+   ip = gru->input_weights + N;
+   rp = gru->recurrent_weights + N;
+   for (j=0;j<M;j++,ip+=stride)
+      faxpy(r, ip, N, input[j]);
+   for (j=0;j<N;j++,rp+=stride)
+      faxpy(r, rp, N, state[j]);
+   for (i=0;i<N;i++)
+      r[i] = sigmoid_approx(WEIGHTS_SCALE*r[i]);
+
+   /* Compute output. */
+   for (i=0;i<N;i++)
+      h[i] = gru->bias[2*N + i];
+   ip = gru->input_weights + 2*N;
+   rp = gru->recurrent_weights + 2*N;
+   for (j=0;j<M;j++,ip+=stride)
+      faxpy(h, ip, N, input[j]);
+   for (j=0;j<N;j++,rp+=stride)
+      faxpy(h, rp, N, r[j]*state[j]);
+   for (i=0;i<N;i++) {
+      switch (gru->activation) {
+      case ACTIVATION_SIGMOID: h[i] = sigmoid_approx(WEIGHTS_SCALE*h[i]); break;
+      case ACTIVATION_TANH: h[i] = tansig_approx(WEIGHTS_SCALE*h[i]); break;
+      default:
+      case ACTIVATION_RELU: h[i] = relu(WEIGHTS_SCALE*h[i]); break;
+      }
+      h[i] = z[i]*state[i] + (1-z[i])*h[i];
    }
    for (i=0;i<N;i++)
-   {
-      /* Compute output. */
-      float sum = gru->bias[2*N + i];
-      for (j=0;j<M;j++)
-         sum += gru->input_weights[2*N + j*stride + i]*input[j];
-      for (j=0;j<N;j++)
-         sum += gru->recurrent_weights[2*N + j*stride + i]*state[j]*r[j];
-      switch (gru->activation) {
-      case ACTIVATION_SIGMOID: sum = sigmoid_approx(WEIGHTS_SCALE*sum); break;
-      case ACTIVATION_TANH: sum = tansig_approx(WEIGHTS_SCALE*sum); break;
-      default:
-      case ACTIVATION_RELU: sum = relu(WEIGHTS_SCALE*sum); break;
-      }
-      h[i] = z[i]*state[i] + (1-z[i])*sum;
-   }
-   for (i=0;i<N;i++)
-      state[i] = h[i];
+      state[i] = h[i];
 }
 
 #define INPUT_SIZE 42
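For reference, the restructured compute_gru still evaluates the usual GRU recurrence; only the accumulation order changes (bias vector first, then one faxpy per input or recurrent weight column). In conventional notation, with W the input weights, U the recurrent weights, x the input and s the previous state, the intent is roughly:

   z  = sigmoid(WEIGHTS_SCALE * (b_z + W_z x + U_z s))            /* update gate */
   r  = sigmoid(WEIGHTS_SCALE * (b_r + W_r x + U_r s))            /* reset gate */
   h~ = activation(WEIGHTS_SCALE * (b_h + W_h x + U_h (r ∘ s)))   /* candidate state */
   s' = z ∘ s + (1 - z) ∘ h~                                      /* new state */

where ∘ is element-wise multiplication. The reset term is the one subtlety worth checking: the old code multiplied state[j]*r[j] inside the inner sum, while the new code passes r[j]*state[j] as the faxpy scale, which is the same product, so the two versions should agree up to floating-point rounding. The early return in faxpy for a zero scale additionally skips whole weight columns whenever an input or reset-gated state element is exactly zero.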