-
Notifications
You must be signed in to change notification settings - Fork 1.1k
[CMSIS-NN] Fix stateful execution and batch-major striding for CMSIS-NN LSTM #3564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,6 +21,7 @@ limitations under the License. | |
| #include <limits> | ||
|
|
||
| #include "Include/arm_nnfunctions.h" | ||
| #include "Include/arm_nnsupportfunctions.h" | ||
| #include "tensorflow/lite/kernels/internal/quantization_util.h" | ||
| #include "tensorflow/lite/kernels/kernel_util.h" | ||
| #include "tensorflow/lite/micro/kernels/fully_connected.h" | ||
|
|
@@ -270,7 +271,7 @@ TfLiteStatus CMSIS_NN_PortOpData(TfLiteContext* context, OpDataLSTM* params_ref, | |
| } | ||
|
|
||
| TfLiteStatus CMSIS_NN_EvalInteger8x8_16Lstm( | ||
| const OpData& op_data, const LSTMKernelContents& kernel_content, | ||
| const OpData& op_data, LSTMKernelContents& kernel_content, | ||
| const LSTMBuffers<int16_t>& buffers) { | ||
| TFLITE_DCHECK( | ||
| kernel_content.GetInternalTensor(tflite::kLstmInputTensor)->dims->size >= | ||
|
|
@@ -282,21 +283,74 @@ TfLiteStatus CMSIS_NN_EvalInteger8x8_16Lstm( | |
| kernel_content.GetInternalTensor(tflite::kLstmInputTensor)); | ||
| int8_t* output = | ||
| tflite::micro::GetTensorData<int8_t>(kernel_content.output_tensor); | ||
| int8_t* hidden_state = | ||
| tflite::micro::GetTensorData<int8_t>(kernel_content.HiddenStateTensor()); | ||
| int16_t* cell_state = | ||
| tflite::micro::GetTensorData<int16_t>(kernel_content.CellStateTensor()); | ||
|
|
||
| // Create lstm buffer struct | ||
| cmsis_nn_lstm_context cmsis_buffers; | ||
| cmsis_buffers.temp1 = reinterpret_cast<int16_t*>(buffers.buffer0); | ||
| cmsis_buffers.temp2 = reinterpret_cast<int16_t*>(buffers.buffer1); | ||
| cmsis_buffers.cell_state = reinterpret_cast<int16_t*>(buffers.buffer2); | ||
|
|
||
| arm_lstm_unidirectional_s8(input, output, &op_data.params_cmsis_nn, | ||
| &cmsis_buffers); | ||
| cmsis_buffers.cell_state = cell_state; | ||
|
|
||
| const auto& params = op_data.params_cmsis_nn; | ||
|
|
||
| #ifdef CMSIS_NN_STATEFUL_LSTM | ||
| cmsis_buffers.hidden_state = hidden_state; | ||
| arm_cmsis_nn_status status = | ||
| arm_lstm_unidirectional_s8(input, output, &params, &cmsis_buffers); | ||
| if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError; | ||
| #else | ||
| if (params.time_major) { | ||
| int8_t* step_hidden_in = hidden_state; | ||
| for (int t = 0; t < params.time_steps; t++) { | ||
| const int8_t* data_in = | ||
| input + (t * params.batch_size * params.input_size); | ||
| int8_t* hidden_out = | ||
| output + (t * params.batch_size * params.hidden_size); | ||
|
|
||
| arm_cmsis_nn_status status = arm_nn_lstm_step_s8( | ||
| data_in, step_hidden_in, hidden_out, &params, &cmsis_buffers, 1); | ||
| if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError; | ||
| step_hidden_in = hidden_out; | ||
| } | ||
| if (params.time_steps > 0) { | ||
| std::copy_n(step_hidden_in, params.batch_size * params.hidden_size, | ||
| hidden_state); | ||
| } | ||
| } else { | ||
| cmsis_nn_lstm_params step_params = params; | ||
| step_params.batch_size = 1; | ||
| for (int b = 0; b < params.batch_size; b++) { | ||
| int8_t* step_hidden_in = hidden_state + b * params.hidden_size; | ||
| cmsis_buffers.cell_state = cell_state + b * params.hidden_size; | ||
|
|
||
| for (int t = 0; t < params.time_steps; t++) { | ||
| const int8_t* data_in = | ||
| input + (b * params.time_steps + t) * params.input_size; | ||
| int8_t* hidden_out = | ||
| output + (b * params.time_steps + t) * params.hidden_size; | ||
|
|
||
| arm_cmsis_nn_status status = | ||
| arm_nn_lstm_step_s8(data_in, step_hidden_in, hidden_out, | ||
| &step_params, &cmsis_buffers, 1); | ||
| if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError; | ||
| step_hidden_in = hidden_out; | ||
| } | ||
| if (params.time_steps > 0) { | ||
| std::copy_n(step_hidden_in, params.hidden_size, | ||
| hidden_state + b * params.hidden_size); | ||
| } | ||
|
Comment on lines
+341
to
+344
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same as the above comment with this additional info: I have not been able to produce a Colab where the converter will produce a stateful, fused LSTM operation with quantization. The converter (and the Colab session) crash every time. The only time I can make a stateful LSTM in Colab, it always produces an unfused LSTM. |
||
| } | ||
| } | ||
| #endif | ||
|
|
||
| return kTfLiteOk; | ||
| } | ||
|
|
||
| TfLiteStatus CMSIS_NN_EvalInteger16x8_16Lstm( | ||
| const OpData& op_data, const LSTMKernelContents& kernel_content, | ||
| const OpData& op_data, LSTMKernelContents& kernel_content, | ||
| const LSTMBuffers<int16_t>& buffers) { | ||
| TFLITE_DCHECK( | ||
| kernel_content.GetInternalTensor(tflite::kLstmInputTensor)->dims->size >= | ||
|
|
@@ -308,15 +362,63 @@ TfLiteStatus CMSIS_NN_EvalInteger16x8_16Lstm( | |
| kernel_content.GetInternalTensor(tflite::kLstmInputTensor)); | ||
| int16_t* output = | ||
| tflite::micro::GetTensorData<int16_t>(kernel_content.output_tensor); | ||
| int16_t* hidden_state = | ||
| tflite::micro::GetTensorData<int16_t>(kernel_content.HiddenStateTensor()); | ||
| int16_t* cell_state = | ||
| tflite::micro::GetTensorData<int16_t>(kernel_content.CellStateTensor()); | ||
|
|
||
| // Create lstm buffer struct | ||
| cmsis_nn_lstm_context cmsis_buffers; | ||
| cmsis_buffers.temp1 = reinterpret_cast<int16_t*>(buffers.buffer0); | ||
| cmsis_buffers.temp2 = reinterpret_cast<int16_t*>(buffers.buffer1); | ||
| cmsis_buffers.cell_state = reinterpret_cast<int16_t*>(buffers.buffer2); | ||
|
|
||
| arm_lstm_unidirectional_s16(input, output, &op_data.params_cmsis_nn, | ||
| &cmsis_buffers); | ||
| cmsis_buffers.cell_state = cell_state; | ||
|
|
||
| const auto& params = op_data.params_cmsis_nn; | ||
|
|
||
| #ifdef CMSIS_NN_STATEFUL_LSTM | ||
| cmsis_buffers.hidden_state = hidden_state; | ||
| arm_cmsis_nn_status status = | ||
| arm_lstm_unidirectional_s16(input, output, &params, &cmsis_buffers); | ||
| if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError; | ||
| #else | ||
| if (params.time_major) { | ||
| for (int t = 0; t < params.time_steps; t++) { | ||
| const int16_t* data_in = | ||
| input + (t * params.batch_size * params.input_size); | ||
| int16_t* hidden_out = | ||
| output + (t * params.batch_size * params.hidden_size); | ||
|
|
||
| arm_cmsis_nn_status status = arm_nn_lstm_step_s16( | ||
| data_in, hidden_state, hidden_out, &params, &cmsis_buffers, 1); | ||
| if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError; | ||
|
|
||
| // Update hidden state for next step | ||
| std::copy_n(hidden_out, params.batch_size * params.hidden_size, | ||
| hidden_state); | ||
|
Comment on lines
+395
to
+397
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't understand why this is inside the step loop. Why not just update the hidden state input pointer as was done in the s8 code? |
||
| } | ||
| } else { | ||
| cmsis_nn_lstm_params step_params = params; | ||
| step_params.batch_size = 1; | ||
| for (int b = 0; b < params.batch_size; b++) { | ||
| for (int t = 0; t < params.time_steps; t++) { | ||
| const int16_t* data_in = | ||
| input + (b * params.time_steps + t) * params.input_size; | ||
| int16_t* hidden_out = | ||
| output + (b * params.time_steps + t) * params.hidden_size; | ||
| int16_t* current_hidden = hidden_state + b * params.hidden_size; | ||
| cmsis_buffers.cell_state = cell_state + b * params.hidden_size; | ||
|
|
||
| arm_cmsis_nn_status status = | ||
| arm_nn_lstm_step_s16(data_in, current_hidden, hidden_out, | ||
| &step_params, &cmsis_buffers, 1); | ||
| if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError; | ||
|
|
||
| // Update hidden state for next step | ||
| std::copy_n(hidden_out, params.hidden_size, current_hidden); | ||
|
Comment on lines
+416
to
+417
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't understand why this is inside the step loop. Why not just update the hidden state input pointer as was done in the s8 code? |
||
| } | ||
| } | ||
| } | ||
| #endif | ||
|
|
||
| return kTfLiteOk; | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure why this is here. When using the greedy memory planner, the hidden_state may be overwritten by a subsequent operator's output(s). See next comment for more info.