Skip to content

Commit 9ad84d7

Browse files
committed
Sync from upstream TF.
1 parent c9f308e commit 9ad84d7

File tree

4 files changed

+90
-1
lines changed

4 files changed

+90
-1
lines changed

tensorflow/compiler/mlir/lite/schema/schema_utils.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,12 @@ limitations under the License.
1515
#include "tensorflow/compiler/mlir/lite/schema/schema_utils.h"
1616

1717
#include <algorithm>
18+
#include <complex>
19+
#include <cstddef>
20+
#include <cstdint>
1821

1922
#include "tensorflow/compiler/mlir/lite/kernels/internal/compatibility_macros.h"
23+
#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
2024

2125
namespace tflite {
2226

@@ -59,4 +63,51 @@ BuiltinOperator GetBuiltinCode(const OperatorCodeT* op_code) {
5963
op_code->deprecated_builtin_code));
6064
}
6165

66+
size_t TensorTypeGetSize(::tflite::TensorType data_type) {
67+
switch (data_type) {
68+
case ::tflite::TensorType_FLOAT32:
69+
static_assert(sizeof(float) == 4, "");
70+
return 4;
71+
case ::tflite::TensorType_FLOAT16:
72+
static_assert(sizeof(int16_t) == 2, "");
73+
return 2;
74+
case ::tflite::TensorType_INT32:
75+
static_assert(sizeof(int32_t) == 4, "");
76+
return 4;
77+
case ::tflite::TensorType_UINT8:
78+
static_assert(sizeof(uint8_t) == 1, "");
79+
return 1;
80+
case ::tflite::TensorType_INT64:
81+
static_assert(sizeof(int64_t) == 8, "");
82+
return 8;
83+
case ::tflite::TensorType_BOOL:
84+
return sizeof(bool);
85+
case ::tflite::TensorType_INT16:
86+
static_assert(sizeof(int16_t) == 2, "");
87+
return 2;
88+
case ::tflite::TensorType_COMPLEX64:
89+
static_assert(sizeof(std::complex<float>) == 8, "");
90+
return 8;
91+
case ::tflite::TensorType_INT8:
92+
static_assert(sizeof(int8_t) == 1, "");
93+
return 1;
94+
case ::tflite::TensorType_FLOAT64:
95+
static_assert(sizeof(double) == 8, "");
96+
return 8;
97+
case ::tflite::TensorType_COMPLEX128:
98+
static_assert(sizeof(std::complex<double>) == 16, "");
99+
return 16;
100+
case ::tflite::TensorType_UINT64:
101+
static_assert(sizeof(uint64_t) == 8, "");
102+
return 8;
103+
case ::tflite::TensorType_UINT32:
104+
static_assert(sizeof(uint32_t) == 4, "");
105+
return 4;
106+
case ::tflite::TensorType_UINT16:
107+
static_assert(sizeof(uint16_t) == 2, "");
108+
return 2;
109+
default:
110+
return 0;
111+
}
112+
}
62113
} // namespace tflite

tensorflow/compiler/mlir/lite/schema/schema_utils.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License.
1515
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_
1616
#define TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_
1717

18+
#include <cstddef>
19+
1820
#include "flatbuffers/flatbuffers.h"
1921
#include "tensorflow/compiler/mlir/lite/schema/schema_generated.h"
2022

@@ -28,6 +30,11 @@ BuiltinOperator GetBuiltinCode(const OperatorCode *op_code);
2830

2931
BuiltinOperator GetBuiltinCode(const OperatorCodeT *op_code);
3032

33+
// Returns the size of the given TensorType in bytes, or 0 if the TensorType is
34+
// not supported, this function should be aligned with TfLiteTypeGetSize in
35+
// lite/kernels/kernel_util.h.
36+
size_t TensorTypeGetSize(::tflite::TensorType data_type);
37+
3138
} // namespace tflite
3239

3340
#endif // TENSORFLOW_COMPILER_MLIR_LITE_SCHEMA_SCHEMA_UTILS_H_

tensorflow/lite/kernels/internal/reference/broadcast_to.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ limitations under the License.
1515
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
1616
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BROADCAST_TO_H_
1717

18+
#include <cstddef>
19+
1820
#include "tensorflow/lite/kernels/internal/common.h"
1921
#include "tensorflow/lite/kernels/kernel_util.h"
2022

@@ -83,7 +85,8 @@ inline void BroadcastTo(const RuntimeShape& unextended_input_shape,
8385
// If non-broadcasting, just copy data from input to output tensor.
8486
if (last_broadcast_dim == -1) {
8587
memcpy(output_data, input_data,
86-
unextended_input_shape.FlatSize() * TfLiteTypeGetSize(data_type));
88+
static_cast<size_t>(unextended_input_shape.FlatSize()) *
89+
static_cast<size_t>(TfLiteTypeGetSize(data_type)));
8790
return;
8891
}
8992

tensorflow/lite/kernels/internal/reference/slice.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@ limitations under the License.
1515
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_
1616
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_SLICE_H_
1717

18+
#include <cstdint>
19+
#include <vector>
20+
21+
#include "tensorflow/lite/core/c/common.h"
1822
#include "tensorflow/lite/kernels/internal/portable_tensor.h"
23+
#include "tensorflow/lite/kernels/internal/portable_tensor_utils.h"
24+
#include "tensorflow/lite/kernels/internal/runtime_shape.h"
25+
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
1926
#include "tensorflow/lite/kernels/internal/types.h"
2027

2128
namespace tflite {
@@ -74,6 +81,27 @@ inline void Slice(const tflite::SliceParams& op_params,
7481
return Slice(op_params, input_shape, output_shape, &writer);
7582
}
7683

84+
inline void SliceInt4(const tflite::SliceParams& op_params,
85+
const RuntimeShape& input_shape,
86+
const TfLiteTensor* input,
87+
const RuntimeShape& output_shape, TfLiteTensor* output) {
88+
const int num_input_elements = input_shape.FlatSize();
89+
std::vector<int8_t> unpacked_input(num_input_elements);
90+
tensor_utils::UnpackPackedIntToInt8(GetTensorData<int8_t>(input),
91+
num_input_elements, 4,
92+
unpacked_input.data());
93+
94+
const int num_output_elements = output_shape.FlatSize();
95+
std::vector<int8_t> unpacked_output(num_output_elements);
96+
97+
reference_ops::Slice<int8_t>(op_params, input_shape, unpacked_input.data(),
98+
output_shape, unpacked_output.data());
99+
100+
tensor_utils::PackInt8IntoDenseInt(unpacked_output.data(),
101+
num_output_elements, 4,
102+
GetTensorData<int8_t>(output));
103+
}
104+
77105
} // namespace reference_ops
78106
} // namespace tflite
79107

0 commit comments

Comments
 (0)