@@ -15,8 +15,8 @@ limitations under the License.
15
15
16
16
#include "tensorflow/lite/kernels/internal/reference/add.h"
17
17
18
+ #include "CMSIS/NN/Include/arm_nnfunctions.h"
18
19
#include "tensorflow/lite/c/builtin_op_data.h"
19
- #include "tensorflow/lite/c/common.h"
20
20
#include "tensorflow/lite/kernels/internal/quantization_util.h"
21
21
#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
22
22
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
@@ -54,10 +54,6 @@ struct OpData {
54
54
int32_t input1_offset;
55
55
int32_t input2_offset;
56
56
int32_t output_offset;
57
-
58
- // Used only for float evals:
59
- float output_activation_min_f32;
60
- float output_activation_max_f32;
61
57
};
62
58
63
59
TfLiteStatus CalculateOpData (TfLiteContext* context, TfLiteAddParams* params,
@@ -95,10 +91,6 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteAddParams* params,
95
91
TF_LITE_ENSURE_STATUS (CalculateActivationRangeQuantized (
96
92
context, params->activation , output, &data->output_activation_min ,
97
93
&data->output_activation_max ));
98
- } else if (output->type == kTfLiteFloat32 ) {
99
- CalculateActivationRange (params->activation ,
100
- &data->output_activation_min_f32 ,
101
- &data->output_activation_max_f32 );
102
94
}
103
95
104
96
return kTfLiteOk ;
@@ -107,25 +99,24 @@ TfLiteStatus CalculateOpData(TfLiteContext* context, TfLiteAddParams* params,
107
99
void EvalAdd (TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
108
100
const OpData* data, const TfLiteEvalTensor* input1,
109
101
const TfLiteEvalTensor* input2, TfLiteEvalTensor* output) {
102
+ float output_activation_min, output_activation_max;
103
+ CalculateActivationRange (params->activation , &output_activation_min,
104
+ &output_activation_max);
110
105
tflite::ArithmeticParams op_params;
111
- SetActivationParams (data->output_activation_min_f32 ,
112
- data->output_activation_max_f32 , &op_params);
106
+ SetActivationParams (output_activation_min, output_activation_max, &op_params);
107
+ #define TF_LITE_ADD(opname) \
108
+ reference_ops::opname (op_params, tflite::micro::GetTensorShape (input1), \
109
+ tflite::micro::GetTensorData<float >(input1), \
110
+ tflite::micro::GetTensorShape (input2), \
111
+ tflite::micro::GetTensorData<float >(input2), \
112
+ tflite::micro::GetTensorShape (output), \
113
+ tflite::micro::GetTensorData<float >(output))
113
114
if (data->requires_broadcast ) {
114
- reference_ops::BroadcastAdd4DSlow (
115
- op_params, tflite::micro::GetTensorShape (input1),
116
- tflite::micro::GetTensorData<float >(input1),
117
- tflite::micro::GetTensorShape (input2),
118
- tflite::micro::GetTensorData<float >(input2),
119
- tflite::micro::GetTensorShape (output),
120
- tflite::micro::GetTensorData<float >(output));
115
+ TF_LITE_ADD (BroadcastAdd4DSlow);
121
116
} else {
122
- reference_ops::Add (op_params, tflite::micro::GetTensorShape (input1),
123
- tflite::micro::GetTensorData<float >(input1),
124
- tflite::micro::GetTensorShape (input2),
125
- tflite::micro::GetTensorData<float >(input2),
126
- tflite::micro::GetTensorShape (output),
127
- tflite::micro::GetTensorData<float >(output));
117
+ TF_LITE_ADD (Add);
128
118
}
119
+ #undef TF_LITE_ADD
129
120
}
130
121
131
122
TfLiteStatus EvalAddQuantized (TfLiteContext* context, TfLiteNode* node,
@@ -150,42 +141,39 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
150
141
bool need_broadcast = reference_ops::ProcessBroadcastShapes (
151
142
tflite::micro::GetTensorShape (input1),
152
143
tflite::micro::GetTensorShape (input2), &op_params);
144
+ #define TF_LITE_ADD(type, opname, dtype) \
145
+ type::opname (op_params, tflite::micro::GetTensorShape (input1), \
146
+ tflite::micro::GetTensorData<dtype>(input1), \
147
+ tflite::micro::GetTensorShape (input2), \
148
+ tflite::micro::GetTensorData<dtype>(input2), \
149
+ tflite::micro::GetTensorShape (output), \
150
+ tflite::micro::GetTensorData<dtype>(output));
153
151
if (output->type == kTfLiteInt8 ) {
154
152
if (need_broadcast) {
155
- reference_integer_ops::BroadcastAdd4DSlow (
156
- op_params, tflite::micro::GetTensorShape (input1),
157
- tflite::micro::GetTensorData<int8_t >(input1),
158
- tflite::micro::GetTensorShape (input2),
159
- tflite::micro::GetTensorData<int8_t >(input2),
160
- tflite::micro::GetTensorShape (output),
161
- tflite::micro::GetTensorData<int8_t >(output));
153
+ TF_LITE_ADD (reference_integer_ops, BroadcastAdd4DSlow, int8_t );
162
154
} else {
163
- reference_integer_ops::Add (
164
- op_params, tflite::micro::GetTensorShape (input1),
155
+ arm_elementwise_add_s8 (
165
156
tflite::micro::GetTensorData<int8_t >(input1),
166
- tflite::micro::GetTensorShape (input2),
167
157
tflite::micro::GetTensorData<int8_t >(input2),
168
- tflite::micro::GetTensorShape (output),
169
- tflite::micro::GetTensorData<int8_t >(output));
158
+ op_params.input1_offset , op_params.input1_multiplier ,
159
+ op_params.input1_shift , op_params.input2_offset ,
160
+ op_params.input2_multiplier , op_params.input2_shift ,
161
+ op_params.left_shift , tflite::micro::GetTensorData<int8_t >(output),
162
+ op_params.output_offset , op_params.output_multiplier ,
163
+ op_params.output_shift , op_params.quantized_activation_min ,
164
+ op_params.quantized_activation_max ,
165
+ MatchingElementsSize (tflite::micro::GetTensorShape (input1),
166
+ tflite::micro::GetTensorShape (input2),
167
+ tflite::micro::GetTensorShape (output)));
170
168
}
171
169
} else {
172
170
if (need_broadcast) {
173
- reference_ops::BroadcastAdd4DSlow (
174
- op_params, tflite::micro::GetTensorShape (input1),
175
- tflite::micro::GetTensorData<uint8_t >(input1),
176
- tflite::micro::GetTensorShape (input2),
177
- tflite::micro::GetTensorData<uint8_t >(input2),
178
- tflite::micro::GetTensorShape (output),
179
- tflite::micro::GetTensorData<uint8_t >(output));
171
+ TF_LITE_ADD (reference_ops, BroadcastAdd4DSlow, uint8_t );
180
172
} else {
181
- reference_ops::Add (op_params, tflite::micro::GetTensorShape (input1),
182
- tflite::micro::GetTensorData<uint8_t >(input1),
183
- tflite::micro::GetTensorShape (input2),
184
- tflite::micro::GetTensorData<uint8_t >(input2),
185
- tflite::micro::GetTensorShape (output),
186
- tflite::micro::GetTensorData<uint8_t >(output));
173
+ TF_LITE_ADD (reference_ops, Add, uint8_t );
187
174
}
188
175
}
176
+ #undef TF_LITE_ADD
189
177
}
190
178
191
179
return kTfLiteOk ;
@@ -201,11 +189,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
201
189
TFLITE_DCHECK (node->builtin_data != nullptr );
202
190
203
191
const TfLiteTensor* input1 = GetInput (context, node, kInputTensor1 );
204
- TF_LITE_ENSURE (context, input1 != nullptr );
205
192
const TfLiteTensor* input2 = GetInput (context, node, kInputTensor2 );
206
- TF_LITE_ENSURE (context, input2 != nullptr );
207
193
TfLiteTensor* output = GetOutput (context, node, kOutputTensor );
208
- TF_LITE_ENSURE (context, output != nullptr );
209
194
210
195
OpData* data = static_cast <OpData*>(node->user_data );
211
196
auto * params = reinterpret_cast <TfLiteAddParams*>(node->builtin_data );
@@ -219,16 +204,16 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
219
204
TfLiteStatus Eval (TfLiteContext* context, TfLiteNode* node) {
220
205
auto * params = reinterpret_cast <TfLiteAddParams*>(node->builtin_data );
221
206
222
- TFLITE_DCHECK (node->user_data != nullptr );
223
- const OpData* data = static_cast <const OpData*>(node->user_data );
224
-
225
207
const TfLiteEvalTensor* input1 =
226
208
tflite::micro::GetEvalInput (context, node, kInputTensor1 );
227
209
const TfLiteEvalTensor* input2 =
228
210
tflite::micro::GetEvalInput (context, node, kInputTensor2 );
229
211
TfLiteEvalTensor* output =
230
212
tflite::micro::GetEvalOutput (context, node, kOutputTensor );
231
213
214
+ TFLITE_DCHECK (node->user_data != nullptr );
215
+ const OpData* data = static_cast <const OpData*>(node->user_data );
216
+
232
217
if (output->type == kTfLiteFloat32 ) {
233
218
EvalAdd (context, node, params, data, input1, input2, output);
234
219
} else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ) {
0 commit comments