Skip to content

Commit 5828e10

Browse files
committed
update ort to graph util
1 parent c58130b commit 5828e10

File tree

1 file changed

+169
-19
lines changed

1 file changed

+169
-19
lines changed

plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h

Lines changed: 169 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
// DO NOT include ORT header files as this is meant to be a header-only utility that can be copied
5+
// to other projects.
6+
47
/*
58
SUMMARY:
69
Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider
@@ -75,6 +78,44 @@
7578
// graph_proto stores large initializers in an external file
7679
}
7780
```
81+
82+
EXAMPLE SNIPPET (external initializers that point to data in memory, not officially supported by ONNX spec):
83+
84+
This example stores initializers externally. However, instead of storing the initializers in a separate
85+
file, the onnx::TensorProto objects point directly to memory addresses. This requires setting the initializer's
86+
location to a special tag like "_MEM_ADDR_" (instead of a file path). The offset is set to the pointer to the
87+
initializer's data in memory (instead of an offset into a file).
88+
89+
Because this is not standard ONNX, such a onnx::GraphProto should not be saved as an ONNX file.
90+
However, it allows custom tools that operate directly on a onnx::GraphProto to get the initializer data
91+
if it has already been loaded into memory.
92+
93+
```C++
94+
#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL
95+
#include "ort_graph_to_proto.h"
96+
97+
OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph,
98+
OrtEpGraphSupportInfo* graph_support_info) {
99+
auto handle_initializer_data = [](const OrtValueInfo* value_info,
100+
const void* data, size_t bytes,
101+
bool& is_external, std::string& location,
102+
int64_t& offset) -> Ort::Status {
103+
(void)value_info;
104+
(void)bytes;
105+
106+
offset = reinterpret_cast<int64_t>(data);
107+
location = "_MEM_ADDR_"; // Some special location tag that indicates the offset is a pointer.
108+
is_external = true; // True if is external initializer
109+
return Ort::Status{nullptr};
110+
}
111+
112+
ONNX_NAMESPACE::GraphProto graph_proto;
113+
OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data);
114+
115+
// graph_proto has initializers that look like they are stored in an external file,
116+
// but they are actually pointing to the data in memory.
117+
}
118+
```
78119
*/
79120

80121
#ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_
@@ -191,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_
191232
/*out*/ std::vector<int64_t>& dims,
192233
/*out*/ std::vector<std::string>& symbolic_dims);
193234
static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto);
194-
static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto);
235+
static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto);
195236

196237
Ort::Status OrtGraphToProto(const OrtGraph& ort_graph,
197238
onnx::GraphProto& graph_proto,
@@ -325,15 +366,20 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph,
325366
for (const OrtOpAttr* ort_attr : ort_attrs) {
326367
OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED;
327368

328-
Ort::Status status{ort_api.OpAttr_GetType(ort_attr, &attr_type)};
329-
if (!status.IsOK()) {
330-
// This is an attribute type that ORT does not support via ReadOpAttr(), like subgraphs, so skip it.
369+
Ort::Status attr_type_status{ort_api.OpAttr_GetType(ort_attr, &attr_type)};
370+
if (attr_type == OrtOpAttrType::ORT_OP_ATTR_GRAPH) {
371+
// ORT does not support reading subgraphs via ReadOpAttr(), so skip it.
331372
// Can use Node_GetSubgraphs to get subgraphs.
332373
continue;
333374
}
334375

376+
if (!attr_type_status.IsOK()) {
377+
// Unsupported attribute type.
378+
return attr_type_status;
379+
}
380+
335381
onnx::AttributeProto* attr_proto = node_proto->add_attribute();
336-
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto));
382+
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto));
337383
}
338384
}
339385

@@ -456,11 +502,14 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph,
456502
auto* ext_data_entries = tensor_proto->mutable_external_data();
457503
onnx::StringStringEntryProto* location_entry = ext_data_entries->Add();
458504
onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add();
505+
onnx::StringStringEntryProto* length_entry = ext_data_entries->Add();
459506

460507
location_entry->set_key("location");
461508
location_entry->set_value(ext_location);
462509
offset_entry->set_key("offset");
463510
offset_entry->set_value(std::to_string(ext_offset));
511+
length_entry->set_key("length");
512+
length_entry->set_value(std::to_string(data_bytes));
464513
} else {
465514
// User wants to store data inline the TensorProto's raw_data
466515
tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT);
@@ -578,28 +627,32 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info,
578627
onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type();
579628
type_proto_tensor->set_elem_type(ort_elem_type);
580629

581-
onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape();
630+
// If there are no dimensions in the shape, do not set a TensorShapeProto. Otherwise, it always looks
631+
// like a scalar value.
632+
if (!ort_dims.empty()) {
633+
onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape();
582634

583-
for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) {
584-
onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim();
635+
for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) {
636+
onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim();
585637

586-
if (ort_dims[dim_idx] >= 0) {
587-
dim_proto->set_dim_value(ort_dims[dim_idx]);
588-
} else {
589-
const std::string& dim_param = ort_dim_syms[dim_idx];
638+
if (ort_dims[dim_idx] >= 0) {
639+
dim_proto->set_dim_value(ort_dims[dim_idx]);
640+
} else {
641+
const std::string& dim_param = ort_dim_syms[dim_idx];
590642

591-
// If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set,
592-
// which represents an unknown dimension.
593-
if (!dim_param.empty()) {
594-
dim_proto->set_dim_param(dim_param);
643+
// If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set,
644+
// which represents an unknown dimension.
645+
if (!dim_param.empty()) {
646+
dim_proto->set_dim_param(dim_param);
647+
}
595648
}
596649
}
597650
}
598651

599652
return Ort::Status{nullptr};
600653
}
601654

602-
static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) {
655+
static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) {
603656
const OrtApi& ort_api = Ort::GetApi();
604657

605658
const char* attr_name = nullptr;
@@ -665,11 +718,11 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr
665718
Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)};
666719
std::string* str = attr_proto.mutable_s();
667720

668-
str->resize(total_attr_bytes, '\0');
721+
str->resize(total_attr_bytes);
669722
ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes,
670723
&total_attr_bytes));
671724

672-
str->resize(total_attr_bytes - 1); // remove extra ending terminating '\0' character.
725+
str->resize(total_attr_bytes);
673726
break;
674727
}
675728
case OrtOpAttrType::ORT_OP_ATTR_STRINGS: {
@@ -705,6 +758,103 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr
705758

706759
break;
707760
}
761+
case OrtOpAttrType::ORT_OP_ATTR_TENSOR: {
762+
attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR);
763+
764+
onnx::TensorProto tensor_proto;
765+
766+
// TensorProto as an attribute value doesn't require a name.
767+
768+
OrtValue* ort_value = nullptr;
769+
ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value));
770+
771+
Ort::Value tensor(ort_value);
772+
773+
// Get tensor type and shape info
774+
Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo();
775+
776+
// Get tensor type
777+
ONNXTensorElementDataType element_type = type_shape_info.GetElementType();
778+
779+
size_t element_size = 0;
780+
switch (element_type) {
781+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: {
782+
tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT);
783+
element_size = sizeof(float);
784+
break;
785+
}
786+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: {
787+
tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8);
788+
element_size = sizeof(uint8_t);
789+
break;
790+
}
791+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: {
792+
tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8);
793+
element_size = sizeof(int8_t);
794+
break;
795+
}
796+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: {
797+
tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16);
798+
element_size = sizeof(uint16_t);
799+
break;
800+
}
801+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: {
802+
tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16);
803+
element_size = sizeof(int16_t);
804+
break;
805+
}
806+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: {
807+
tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32);
808+
element_size = sizeof(int32_t);
809+
break;
810+
}
811+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: {
812+
tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64);
813+
element_size = sizeof(int64_t);
814+
break;
815+
}
816+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: {
817+
tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL);
818+
element_size = sizeof(bool);
819+
break;
820+
}
821+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: {
822+
tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE);
823+
element_size = sizeof(double);
824+
break;
825+
}
826+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: {
827+
tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32);
828+
element_size = sizeof(uint32_t);
829+
break;
830+
}
831+
case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: {
832+
tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64);
833+
element_size = sizeof(uint64_t);
834+
break;
835+
}
836+
default: {
837+
std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast<int>(element_type));
838+
return Ort::Status(err_msg.c_str(), ORT_FAIL);
839+
}
840+
}
841+
842+
auto shape = type_shape_info.GetShape();
843+
844+
for (auto& dim : shape) {
845+
tensor_proto.add_dims(dim);
846+
}
847+
848+
size_t element_count = type_shape_info.GetElementCount();
849+
size_t data_bytes = element_count * element_size;
850+
const void* data = tensor.GetTensorData<void>();
851+
852+
// Copy the Ortvalue to TensorProto as raw data
853+
tensor_proto.set_raw_data(data, data_bytes);
854+
855+
*(attr_proto.mutable_t()) = std::move(tensor_proto);
856+
break;
857+
}
708858
default: {
709859
std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast<int>(attr_type));
710860
return Ort::Status(err_msg.c_str(), ORT_FAIL);

0 commit comments

Comments
 (0)