Skip to content

Commit 3965bf2

Browse files
[AvroTensorDataset] Add AvroTensorDataset to allow data conversion from avro to Tensorflow tensors (#1784)
* open source AvroTensorDataset and make it compatible with TF2.12 * Just add basic tests in first PR * Remove additional tests * fix all lint errors * fix header macro * fix pyupgrade lint error * copyright change * copyright to all files to avro/atds * move test utils from tensorflow_io/python/experimental/benchmark to tests/test_atds_avro/utils * fix black error and error on mac * resolve comments * update tests * fix linter
1 parent cc33429 commit 3965bf2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+9871
-3
lines changed

tensorflow_io/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ cc_binary(
1515
"//tensorflow_io/core:bigtable_ops",
1616
"//tensorflow_io/core:audio_video_ops",
1717
"//tensorflow_io/core:avro_ops",
18+
"//tensorflow_io/core:avro_atds",
1819
"//tensorflow_io/core:orc_ops",
1920
"//tensorflow_io/core:cpuinfo",
2021
"//tensorflow_io/core:file_ops",

tensorflow_io/core/BUILD

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,61 @@ cc_library(
484484
alwayslink = 1,
485485
)
486486

487+
cc_library(
488+
name = "avro_atds",
489+
srcs = [
490+
"kernels/avro/atds/atds_decoder.cc",
491+
"kernels/avro/atds/errors.cc",
492+
"kernels/avro/atds_dataset_kernels.cc",
493+
],
494+
hdrs = [
495+
"kernels/avro/atds/atds_decoder.h",
496+
"kernels/avro/atds/avro_block_reader.h",
497+
"kernels/avro/atds/avro_decoder_template.h",
498+
"kernels/avro/atds/decoder_base.h",
499+
"kernels/avro/atds/decompression_handler.h",
500+
"kernels/avro/atds/dense_feature_decoder.h",
501+
"kernels/avro/atds/errors.h",
502+
"kernels/avro/atds/opaque_contextual_feature_decoder.h",
503+
"kernels/avro/atds/shuffle_handler.h",
504+
"kernels/avro/atds/sparse_feature_decoder.h",
505+
"kernels/avro/atds/sparse_feature_internal_decoder.h",
506+
"kernels/avro/atds/sparse_value_buffer.h",
507+
"kernels/avro/atds/varlen_feature_decoder.h",
508+
"kernels/avro/atds_dataset_kernels.h",
509+
],
510+
copts = tf_io_copts(),
511+
linkstatic = True,
512+
deps = [
513+
":avro_ops",
514+
"@avro",
515+
"@local_config_tf//:libtensorflow_framework",
516+
"@local_config_tf//:tf_header_lib",
517+
],
518+
alwayslink = 1,
519+
)
520+
521+
cc_library(
522+
name = "avro_atds_tests",
523+
srcs = [
524+
"kernels/avro/atds/atds_decoder_test.cc",
525+
"kernels/avro/atds/avro_block_reader_test.cc",
526+
"kernels/avro/atds/decoder_test_util.cc",
527+
"kernels/avro/atds/decoder_test_util.h",
528+
"kernels/avro/atds/dense_feature_decoder_test.cc",
529+
"kernels/avro/atds/shuffle_handler_test.cc",
530+
"kernels/avro/atds/sparse_feature_decoder_test.cc",
531+
"kernels/avro/atds/sparse_value_buffer_test.cc",
532+
"kernels/avro/atds/varlen_feature_decoder_test.cc",
533+
],
534+
copts = tf_io_copts(),
535+
deps = [
536+
":avro_atds",
537+
"//tensorflow_io/core:avro_ops",
538+
"@com_google_googletest//:gtest_main",
539+
],
540+
)
541+
487542
cc_library(
488543
name = "orc_ops",
489544
srcs = [
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow_io/core/kernels/avro/atds/atds_decoder.h"
17+
18+
#include "api/Generic.hh"
19+
#include "api/Specific.hh"
20+
#include "tensorflow_io/core/kernels/avro/atds/dense_feature_decoder.h"
21+
#include "tensorflow_io/core/kernels/avro/atds/errors.h"
22+
#include "tensorflow_io/core/kernels/avro/atds/opaque_contextual_feature_decoder.h"
23+
#include "tensorflow_io/core/kernels/avro/atds/sparse_feature_decoder.h"
24+
#include "tensorflow_io/core/kernels/avro/atds/varlen_feature_decoder.h"
25+
26+
namespace tensorflow {
27+
namespace atds {
28+
29+
Status ATDSDecoder::Initialize(const avro::ValidSchema& schema) {
30+
auto& root_node = schema.root();
31+
if (root_node->type() != avro::AVRO_RECORD) {
32+
return ATDSNotRecordError(avro::toString(root_node->type()),
33+
schema.toJson());
34+
}
35+
36+
size_t num_of_columns = root_node->leaves();
37+
feature_names_.resize(num_of_columns, "");
38+
decoder_types_.resize(num_of_columns, FeatureType::opaque_contextual);
39+
decoders_.resize(num_of_columns);
40+
41+
for (size_t i = 0; i < dense_features_.size(); i++) {
42+
TF_RETURN_IF_ERROR(
43+
InitializeFeatureDecoder(schema, root_node, dense_features_[i]));
44+
}
45+
46+
for (size_t i = 0; i < sparse_features_.size(); i++) {
47+
TF_RETURN_IF_ERROR(
48+
InitializeFeatureDecoder(schema, root_node, sparse_features_[i]));
49+
}
50+
51+
for (size_t i = 0; i < varlen_features_.size(); i++) {
52+
TF_RETURN_IF_ERROR(
53+
InitializeFeatureDecoder(schema, root_node, varlen_features_[i]));
54+
}
55+
56+
size_t opaque_contextual_index = 0;
57+
for (size_t i = 0; i < num_of_columns; i++) {
58+
if (decoder_types_[i] == FeatureType::opaque_contextual) {
59+
decoders_[i] = std::unique_ptr<DecoderBase>(
60+
new opaque_contextual::FeatureDecoder(opaque_contextual_index++));
61+
62+
auto& opaque_contextual_node = root_node->leafAt(i);
63+
skipped_data_.emplace_back(opaque_contextual_node);
64+
if (opaque_contextual_node->hasName()) {
65+
feature_names_[i] = root_node->leafAt(i)->name();
66+
LOG(WARNING) << "Column '" << feature_names_[i] << "' from input data"
67+
<< " is not used. Cost of parsing an unused column is "
68+
"prohibitive!! "
69+
<< "Consider dropping it to improve I/O performance.";
70+
}
71+
}
72+
}
73+
74+
// Decoder requires unvaried schema in all input files.
75+
// Copy the schema to validate other input files.
76+
schema_ = schema;
77+
78+
return OkStatus();
79+
}
80+
81+
} // namespace atds
82+
} // namespace tensorflow
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_IO_CORE_KERNELS_AVRO_ATDS_DECODER_H_
17+
#define TENSORFLOW_IO_CORE_KERNELS_AVRO_ATDS_DECODER_H_
18+
19+
#include "api/Decoder.hh"
20+
#include "api/GenericDatum.hh"
21+
#include "api/ValidSchema.hh"
22+
#include "tensorflow/core/framework/tensor.h"
23+
#include "tensorflow/core/framework/tensor_shape.h"
24+
#include "tensorflow/core/framework/types.pb.h"
25+
#include "tensorflow/core/platform/status.h"
26+
#include "tensorflow_io/core/kernels/avro/atds/decoder_base.h"
27+
#include "tensorflow_io/core/kernels/avro/atds/dense_feature_decoder.h"
28+
#include "tensorflow_io/core/kernels/avro/atds/errors.h"
29+
#include "tensorflow_io/core/kernels/avro/atds/sparse_feature_decoder.h"
30+
#include "tensorflow_io/core/kernels/avro/atds/varlen_feature_decoder.h"
31+
32+
namespace tensorflow {
33+
namespace atds {
34+
35+
class NullableFeatureDecoder : public DecoderBase {
36+
public:
37+
explicit NullableFeatureDecoder(std::unique_ptr<DecoderBase>& decoder,
38+
size_t non_null_index)
39+
: decoder_(std::move(decoder)), non_null_index_(non_null_index) {}
40+
41+
Status operator()(avro::DecoderPtr& decoder,
42+
std::vector<Tensor>& dense_tensors,
43+
sparse::ValueBuffer& buffer,
44+
std::vector<avro::GenericDatum>& skipped_data,
45+
size_t offset) {
46+
auto index = decoder->decodeUnionIndex();
47+
if (index != non_null_index_) {
48+
return NullValueError();
49+
}
50+
return decoder_->operator()(decoder, dense_tensors, buffer, skipped_data,
51+
offset);
52+
}
53+
54+
private:
55+
std::unique_ptr<DecoderBase> decoder_;
56+
const size_t non_null_index_;
57+
};
58+
59+
class ATDSDecoder {
60+
public:
61+
explicit ATDSDecoder(const std::vector<dense::Metadata>& dense_features,
62+
const std::vector<sparse::Metadata>& sparse_features,
63+
const std::vector<varlen::Metadata>& varlen_features)
64+
: dense_features_(dense_features),
65+
sparse_features_(sparse_features),
66+
varlen_features_(varlen_features) {}
67+
68+
Status Initialize(const avro::ValidSchema&);
69+
70+
Status DecodeATDSDatum(avro::DecoderPtr& decoder,
71+
std::vector<Tensor>& dense_tensors,
72+
sparse::ValueBuffer& buffer,
73+
std::vector<avro::GenericDatum>& skipped_data,
74+
size_t offset) {
75+
// LOG(INFO) << "Decode atds from offset: " << offset;
76+
for (size_t i = 0; i < decoders_.size(); i++) {
77+
Status status = decoders_[i]->operator()(decoder, dense_tensors, buffer,
78+
skipped_data, offset);
79+
if (TF_PREDICT_FALSE(!status.ok())) {
80+
return FeatureDecodeError(feature_names_[i], status.error_message());
81+
}
82+
}
83+
// LOG(INFO) << "Decode atds from offset Done: " << offset;
84+
return OkStatus();
85+
}
86+
87+
const std::vector<avro::GenericDatum>& GetSkippedData() {
88+
return skipped_data_;
89+
}
90+
91+
const avro::ValidSchema& GetSchema() { return schema_; }
92+
93+
private:
94+
template <typename Metadata>
95+
Status InitializeFeatureDecoder(const avro::ValidSchema& schema,
96+
const avro::NodePtr& root_node,
97+
const Metadata& metadata) {
98+
size_t pos;
99+
if (!root_node->nameIndex(metadata.name, pos)) {
100+
return FeatureNotFoundError(metadata.name, schema.toJson());
101+
}
102+
decoder_types_[pos] = metadata.type;
103+
feature_names_[pos] = metadata.name;
104+
105+
auto& feature_node = root_node->leafAt(pos);
106+
if (feature_node->type() == avro::AVRO_UNION) {
107+
size_t non_null_index = 0;
108+
size_t num_union_types = feature_node->leaves();
109+
110+
if (num_union_types == 2 &&
111+
feature_node->leafAt(0)->type() == avro::AVRO_NULL) {
112+
non_null_index = 1;
113+
}
114+
115+
if (num_union_types == 1 || num_union_types == 2) {
116+
auto& non_null_feature_node = feature_node->leafAt(non_null_index);
117+
TF_RETURN_IF_ERROR(ValidateSchema(non_null_feature_node, metadata));
118+
std::unique_ptr<DecoderBase> decoder_base =
119+
CreateFeatureDecoder(non_null_feature_node, metadata);
120+
decoders_[pos] = std::unique_ptr<DecoderBase>(
121+
new NullableFeatureDecoder(decoder_base, non_null_index));
122+
} else {
123+
std::ostringstream oss;
124+
feature_node->printJson(oss, 0);
125+
return InvalidUnionTypeError(metadata.name, oss.str());
126+
}
127+
} else {
128+
TF_RETURN_IF_ERROR(ValidateSchema(feature_node, metadata));
129+
decoders_[pos] = CreateFeatureDecoder(feature_node, metadata);
130+
}
131+
132+
return OkStatus();
133+
}
134+
135+
const std::vector<dense::Metadata>& dense_features_;
136+
const std::vector<sparse::Metadata>& sparse_features_;
137+
const std::vector<varlen::Metadata>& varlen_features_;
138+
139+
std::vector<string> feature_names_;
140+
std::vector<std::unique_ptr<DecoderBase>> decoders_;
141+
std::vector<FeatureType> decoder_types_;
142+
143+
std::vector<avro::GenericDatum> skipped_data_;
144+
avro::ValidSchema schema_;
145+
};
146+
147+
} // namespace atds
148+
} // namespace tensorflow
149+
150+
#endif // TENSORFLOW_IO_CORE_KERNELS_AVRO_ATDS_DECODER_H_

0 commit comments

Comments
 (0)