diff --git a/esp-dl/dl/tensor/include/dl_tensor_base.hpp b/esp-dl/dl/tensor/include/dl_tensor_base.hpp index 80ecf017..6970e02b 100644 --- a/esp-dl/dl/tensor/include/dl_tensor_base.hpp +++ b/esp-dl/dl/tensor/include/dl_tensor_base.hpp @@ -147,10 +147,11 @@ class TensorBase { * @brief Assign tensor to this tensor * * @param tensor + * @param start_position The starting index in src_data from which to begin copying data. * * @return ture if assign successfully, otherwise false. */ - bool assign(TensorBase *tensor); + bool assign(TensorBase *tensor, int start_position = 0); /** * @brief Assign data to this tensor diff --git a/esp-dl/dl/tensor/src/dl_tensor_base.cpp b/esp-dl/dl/tensor/src/dl_tensor_base.cpp index 4c0b5402..46b77b35 100644 --- a/esp-dl/dl/tensor/src/dl_tensor_base.cpp +++ b/esp-dl/dl/tensor/src/dl_tensor_base.cpp @@ -167,7 +167,7 @@ TensorBase::TensorBase( this->caps = caps; } -bool TensorBase::assign(TensorBase *tensor) +bool TensorBase::assign(TensorBase *tensor, int start_position) { if (tensor == nullptr || this->get_size() != tensor->get_size()) { return false; @@ -181,9 +181,21 @@ bool TensorBase::assign(TensorBase *tensor) if (this->dtype == DATA_TYPE_INT8) { int8_t *data = (int8_t *)this->data; - for (int i = 0; i < this->get_size(); i++) { - data[i] = quantize(src_data[i], inv_scale); + int data_size = this->get_size(); + // Ensure start_position is within valid range [0, data_size) + if (start_position < 0 || start_position >= data_size) + return false; + + int i = 0; + // Copy from start_position to end of src_data + for (int j = start_position; j < data_size; i++, j++) { + data[i] = quantize(src_data[j], inv_scale); } + // Wrap around and copy from start of src_data to start_position - 1 + for (int j = 0; j < start_position; ++i, ++j) { + data[i] = quantize(src_data[j], inv_scale); + } + } else if (this->dtype == DATA_TYPE_INT16) { int16_t *data = (int16_t *)this->data; for (int i = 0; i < this->get_size(); i++) {