Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tensorflow/lite/micro/kernels/xtensa/lstm_eval.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ void LstmStepManager::UpdateBatch() {
// Multi-batch for time_major input
RuntimeShape LstmStepManager::InputShape() const {
int batch_size = 1;
if (size_info_.time_major) {
if (size_info_.time_major || ((size_info_.batch_size > 1 && size_info_.time_steps == 1))) {
batch_size = size_info_.batch_size;
}
const int dims[2] = {batch_size, size_info_.input_dimension};
Expand All @@ -485,7 +485,7 @@ RuntimeShape LstmStepManager::InputShape() const {
// Multi-batch for time_major input
RuntimeShape LstmStepManager::StateShape() const {
int batch_size = 1;
if (size_info_.time_major) {
if (size_info_.time_major || (size_info_.batch_size > 1 && size_info_.time_steps == 1)) {
batch_size = size_info_.batch_size;
}
const int dims[2] = {batch_size, size_info_.state_dimension};
Expand Down
29 changes: 20 additions & 9 deletions tensorflow/lite/micro/kernels/xtensa/lstm_eval.h
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,11 @@ void LstmStep(const LstmStepManager& step_info, const OpDataLSTM& op_data,
int input_dimension = step_info.input_dimension();
int state_dimension = step_info.state_dimension();

const auto& size_info = op_data.size_info;
if(size_info.batch_size > 1 && size_info.time_steps == 1) {
num_batches = size_info.batch_size;
}

Comment on lines +669 to +673
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe move this up to the definition of num_batches and do the logic there? I think all those values should be const int as well.

// Check offset validity to avoid memory overflow
TFLITE_DCHECK_LE(step_info.InputOffset() + num_batches * input_dimension,
tflite::micro::GetTensorShape(input).FlatSize());
Expand Down Expand Up @@ -805,16 +810,22 @@ TfLiteStatus EvalLstm(const OpDataLSTM& op_data,
}
} else {
// batch first, unable to size the input data. single batch inference
for (int b = 0; b < size_info.batch_size; b++) {
for (int t = 0; t < size_info.time_steps; t++) {
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
step_info, op_data, kernel_content, buffers);
// prepare for the next time step
step_info.UpdateTime();
if(size_info.batch_size > 1 && size_info.time_steps == 1) {
Comment on lines 811 to +813
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

} else if (...) {?

// Ramesh
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
step_info, op_data, kernel_content, buffers);
} else {
for (int b = 0; b < size_info.batch_size; b++) {
for (int t = 0; t < size_info.time_steps; t++) {
lstm_internal::LstmStep<ActivationType, WeightType, CellType, BiasType>(
step_info, op_data, kernel_content, buffers);
// prepare for the next time step
step_info.UpdateTime();
}
// prepare for the next batch
step_info.UpdateBatch();
step_info.ResetTime();
}
// prepare for the next batch
step_info.UpdateBatch();
step_info.ResetTime();
}
}
return kTfLiteOk;
Expand Down