Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
#include <limits>

#include "Include/arm_nnfunctions.h"
#include "Include/arm_nnsupportfunctions.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/micro/kernels/fully_connected.h"
Expand Down Expand Up @@ -270,7 +271,7 @@ TfLiteStatus CMSIS_NN_PortOpData(TfLiteContext* context, OpDataLSTM* params_ref,
}

TfLiteStatus CMSIS_NN_EvalInteger8x8_16Lstm(
const OpData& op_data, const LSTMKernelContents& kernel_content,
const OpData& op_data, LSTMKernelContents& kernel_content,
const LSTMBuffers<int16_t>& buffers) {
TFLITE_DCHECK(
kernel_content.GetInternalTensor(tflite::kLstmInputTensor)->dims->size >=
Expand All @@ -282,21 +283,74 @@ TfLiteStatus CMSIS_NN_EvalInteger8x8_16Lstm(
kernel_content.GetInternalTensor(tflite::kLstmInputTensor));
int8_t* output =
tflite::micro::GetTensorData<int8_t>(kernel_content.output_tensor);
int8_t* hidden_state =
tflite::micro::GetTensorData<int8_t>(kernel_content.HiddenStateTensor());
int16_t* cell_state =
tflite::micro::GetTensorData<int16_t>(kernel_content.CellStateTensor());

// Create lstm buffer struct
cmsis_nn_lstm_context cmsis_buffers;
cmsis_buffers.temp1 = reinterpret_cast<int16_t*>(buffers.buffer0);
cmsis_buffers.temp2 = reinterpret_cast<int16_t*>(buffers.buffer1);
cmsis_buffers.cell_state = reinterpret_cast<int16_t*>(buffers.buffer2);

arm_lstm_unidirectional_s8(input, output, &op_data.params_cmsis_nn,
&cmsis_buffers);
cmsis_buffers.cell_state = cell_state;

const auto& params = op_data.params_cmsis_nn;

#ifdef CMSIS_NN_STATEFUL_LSTM
cmsis_buffers.hidden_state = hidden_state;
arm_cmsis_nn_status status =
arm_lstm_unidirectional_s8(input, output, &params, &cmsis_buffers);
if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError;
#else
if (params.time_major) {
int8_t* step_hidden_in = hidden_state;
for (int t = 0; t < params.time_steps; t++) {
const int8_t* data_in =
input + (t * params.batch_size * params.input_size);
int8_t* hidden_out =
output + (t * params.batch_size * params.hidden_size);

arm_cmsis_nn_status status = arm_nn_lstm_step_s8(
data_in, step_hidden_in, hidden_out, &params, &cmsis_buffers, 1);
if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError;
step_hidden_in = hidden_out;
}
if (params.time_steps > 0) {
std::copy_n(step_hidden_in, params.batch_size * params.hidden_size,
hidden_state);
}
Comment on lines +318 to +321

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure why this is here. When using the greedy memory planner, the hidden_state may be overwritten by subsequent operator's output(s). See next comment for more info.

} else {
cmsis_nn_lstm_params step_params = params;
step_params.batch_size = 1;
for (int b = 0; b < params.batch_size; b++) {
int8_t* step_hidden_in = hidden_state + b * params.hidden_size;
cmsis_buffers.cell_state = cell_state + b * params.hidden_size;

for (int t = 0; t < params.time_steps; t++) {
const int8_t* data_in =
input + (b * params.time_steps + t) * params.input_size;
int8_t* hidden_out =
output + (b * params.time_steps + t) * params.hidden_size;

arm_cmsis_nn_status status =
arm_nn_lstm_step_s8(data_in, step_hidden_in, hidden_out,
&step_params, &cmsis_buffers, 1);
if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError;
step_hidden_in = hidden_out;
}
if (params.time_steps > 0) {
std::copy_n(step_hidden_in, params.hidden_size,
hidden_state + b * params.hidden_size);
}
Comment on lines +341 to +344

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as the above comment with this additional info: I have not been able to produce a Colab where the. converter will produce a stateful, fused LSTM operation with quantization. The converter (and the Colab session) crash every time. The only time I can make a stateful LSTM in Colab, always produces an unfused LSTM.

}
}
#endif

return kTfLiteOk;
}

TfLiteStatus CMSIS_NN_EvalInteger16x8_16Lstm(
const OpData& op_data, const LSTMKernelContents& kernel_content,
const OpData& op_data, LSTMKernelContents& kernel_content,
const LSTMBuffers<int16_t>& buffers) {
TFLITE_DCHECK(
kernel_content.GetInternalTensor(tflite::kLstmInputTensor)->dims->size >=
Expand All @@ -308,15 +362,63 @@ TfLiteStatus CMSIS_NN_EvalInteger16x8_16Lstm(
kernel_content.GetInternalTensor(tflite::kLstmInputTensor));
int16_t* output =
tflite::micro::GetTensorData<int16_t>(kernel_content.output_tensor);
int16_t* hidden_state =
tflite::micro::GetTensorData<int16_t>(kernel_content.HiddenStateTensor());
int16_t* cell_state =
tflite::micro::GetTensorData<int16_t>(kernel_content.CellStateTensor());

// Create lstm buffer struct
cmsis_nn_lstm_context cmsis_buffers;
cmsis_buffers.temp1 = reinterpret_cast<int16_t*>(buffers.buffer0);
cmsis_buffers.temp2 = reinterpret_cast<int16_t*>(buffers.buffer1);
cmsis_buffers.cell_state = reinterpret_cast<int16_t*>(buffers.buffer2);

arm_lstm_unidirectional_s16(input, output, &op_data.params_cmsis_nn,
&cmsis_buffers);
cmsis_buffers.cell_state = cell_state;

const auto& params = op_data.params_cmsis_nn;

#ifdef CMSIS_NN_STATEFUL_LSTM
cmsis_buffers.hidden_state = hidden_state;
arm_cmsis_nn_status status =
arm_lstm_unidirectional_s16(input, output, &params, &cmsis_buffers);
if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError;
#else
if (params.time_major) {
for (int t = 0; t < params.time_steps; t++) {
const int16_t* data_in =
input + (t * params.batch_size * params.input_size);
int16_t* hidden_out =
output + (t * params.batch_size * params.hidden_size);

arm_cmsis_nn_status status = arm_nn_lstm_step_s16(
data_in, hidden_state, hidden_out, &params, &cmsis_buffers, 1);
if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError;

// Update hidden state for next step
std::copy_n(hidden_out, params.batch_size * params.hidden_size,
hidden_state);
Comment on lines +395 to +397

@ddavis-2015 ddavis-2015 Jun 18, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't understand why this is inside the step loop. Why not just update the hidden state input pointer as was done in the s8 code?

}
} else {
cmsis_nn_lstm_params step_params = params;
step_params.batch_size = 1;
for (int b = 0; b < params.batch_size; b++) {
for (int t = 0; t < params.time_steps; t++) {
const int16_t* data_in =
input + (b * params.time_steps + t) * params.input_size;
int16_t* hidden_out =
output + (b * params.time_steps + t) * params.hidden_size;
int16_t* current_hidden = hidden_state + b * params.hidden_size;
cmsis_buffers.cell_state = cell_state + b * params.hidden_size;

arm_cmsis_nn_status status =
arm_nn_lstm_step_s16(data_in, current_hidden, hidden_out,
&step_params, &cmsis_buffers, 1);
if (status != ARM_CMSIS_NN_SUCCESS) return kTfLiteError;

// Update hidden state for next step
std::copy_n(hidden_out, params.hidden_size, current_hidden);
Comment on lines +416 to +417

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't understand why this is inside the step loop. Why not just update the hidden state input pointer as was done in the s8 code?

}
}
}
#endif

return kTfLiteOk;
}
Expand Down
Loading