|
1 | 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
2 | 2 | // Licensed under the MIT License. |
3 | 3 |
|
| 4 | +// DO NOT include ORT header files as this is meant to be a header-only utility that can be copied |
| 5 | +// to other projects. |
| 6 | + |
4 | 7 | /* |
5 | 8 | SUMMARY: |
6 | 9 | Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider |
|
75 | 78 | // graph_proto stores large initializers in an external file |
76 | 79 | } |
77 | 80 | ``` |
| 81 | +
|
| 82 | + EXAMPLE SNIPPET (external initializers that point to data in memory, not officially supported by ONNX spec): |
| 83 | +
|
| 84 | + This example stores initializers externally. However, instead of storing the initializers in a separate |
| 85 | + file, the onnx::TensorProto objects point directly to memory addresses. This requires setting the initializer's |
| 86 | + location to a special tag like "_MEM_ADDR_" (instead of a file path). The offset is set to the pointer to the |
| 87 | + initializer's data in memory (instead of an offset into a file). |
| 88 | +
|
| 89 | + Because this is not standard ONNX, such a onnx::GraphProto should not be saved as an ONNX file. |
| 90 | + However, it allows custom tools that operate directly on a onnx::GraphProto to get the initializer data |
| 91 | + if it has already been loaded into memory. |
| 92 | +
|
| 93 | + ```C++ |
| 94 | + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL |
| 95 | + #include "ort_graph_to_proto.h" |
| 96 | +
|
| 97 | + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, |
| 98 | + OrtEpGraphSupportInfo* graph_support_info) { |
| 99 | + auto handle_initializer_data = [](const OrtValueInfo* value_info, |
| 100 | + const void* data, size_t bytes, |
| 101 | + bool& is_external, std::string& location, |
| 102 | + int64_t& offset) -> Ort::Status { |
| 103 | + (void)value_info; |
| 104 | + (void)bytes; |
| 105 | +
|
| 106 | + offset = reinterpret_cast<int64_t>(data); |
| 107 | + location = "_MEM_ADDR_"; // Some special location tag that indicates the offset is a pointer. |
| 108 | + is_external = true; // True if is external initializer |
| 109 | + return Ort::Status{nullptr}; |
| 110 | + } |
| 111 | +
|
| 112 | + ONNX_NAMESPACE::GraphProto graph_proto; |
| 113 | + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); |
| 114 | +
|
| 115 | + // graph_proto has initializers that look like they are stored in an external file, |
| 116 | + // but they are actually pointing to the data in memory. |
| 117 | + } |
| 118 | + ``` |
78 | 119 | */ |
79 | 120 |
|
80 | 121 | #ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ |
@@ -191,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_ |
191 | 232 | /*out*/ std::vector<int64_t>& dims, |
192 | 233 | /*out*/ std::vector<std::string>& symbolic_dims); |
193 | 234 | static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); |
194 | | -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); |
| 235 | +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); |
195 | 236 |
|
196 | 237 | Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, |
197 | 238 | onnx::GraphProto& graph_proto, |
@@ -325,15 +366,20 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, |
325 | 366 | for (const OrtOpAttr* ort_attr : ort_attrs) { |
326 | 367 | OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; |
327 | 368 |
|
328 | | - Ort::Status status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; |
329 | | - if (!status.IsOK()) { |
330 | | - // This is an attribute type that ORT does not support via ReadOpAttr(), like subgraphs, so skip it. |
| 369 | + Ort::Status attr_type_status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; |
| 370 | + if (attr_type == OrtOpAttrType::ORT_OP_ATTR_GRAPH) { |
| 371 | + // ORT does not support reading subgraphs via ReadOpAttr(), so skip it. |
331 | 372 | // Can use Node_GetSubgraphs to get subgraphs. |
332 | 373 | continue; |
333 | 374 | } |
334 | 375 |
|
| 376 | + if (!attr_type_status.IsOK()) { |
| 377 | + // Unsupported attribute type. |
| 378 | + return attr_type_status; |
| 379 | + } |
| 380 | + |
335 | 381 | onnx::AttributeProto* attr_proto = node_proto->add_attribute(); |
336 | | - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); |
| 382 | + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto)); |
337 | 383 | } |
338 | 384 | } |
339 | 385 |
|
@@ -456,11 +502,14 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, |
456 | 502 | auto* ext_data_entries = tensor_proto->mutable_external_data(); |
457 | 503 | onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); |
458 | 504 | onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); |
| 505 | + onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); |
459 | 506 |
|
460 | 507 | location_entry->set_key("location"); |
461 | 508 | location_entry->set_value(ext_location); |
462 | 509 | offset_entry->set_key("offset"); |
463 | 510 | offset_entry->set_value(std::to_string(ext_offset)); |
| 511 | + length_entry->set_key("length"); |
| 512 | + length_entry->set_value(std::to_string(data_bytes)); |
464 | 513 | } else { |
465 | 514 | // User wants to store data inline the TensorProto's raw_data |
466 | 515 | tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); |
@@ -578,28 +627,32 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, |
578 | 627 | onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); |
579 | 628 | type_proto_tensor->set_elem_type(ort_elem_type); |
580 | 629 |
|
581 | | - onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); |
| 630 | + // If there are no dimensions in the shape, do not set a TensorShapeProto. Otherwise, it always looks |
| 631 | + // like a scalar value. |
| 632 | + if (!ort_dims.empty()) { |
| 633 | + onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); |
582 | 634 |
|
583 | | - for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { |
584 | | - onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); |
| 635 | + for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { |
| 636 | + onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); |
585 | 637 |
|
586 | | - if (ort_dims[dim_idx] >= 0) { |
587 | | - dim_proto->set_dim_value(ort_dims[dim_idx]); |
588 | | - } else { |
589 | | - const std::string& dim_param = ort_dim_syms[dim_idx]; |
| 638 | + if (ort_dims[dim_idx] >= 0) { |
| 639 | + dim_proto->set_dim_value(ort_dims[dim_idx]); |
| 640 | + } else { |
| 641 | + const std::string& dim_param = ort_dim_syms[dim_idx]; |
590 | 642 |
|
591 | | - // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, |
592 | | - // which represents an unknown dimension. |
593 | | - if (!dim_param.empty()) { |
594 | | - dim_proto->set_dim_param(dim_param); |
| 643 | + // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, |
| 644 | + // which represents an unknown dimension. |
| 645 | + if (!dim_param.empty()) { |
| 646 | + dim_proto->set_dim_param(dim_param); |
| 647 | + } |
595 | 648 | } |
596 | 649 | } |
597 | 650 | } |
598 | 651 |
|
599 | 652 | return Ort::Status{nullptr}; |
600 | 653 | } |
601 | 654 |
|
602 | | -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { |
| 655 | +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { |
603 | 656 | const OrtApi& ort_api = Ort::GetApi(); |
604 | 657 |
|
605 | 658 | const char* attr_name = nullptr; |
@@ -665,11 +718,11 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr |
665 | 718 | Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; |
666 | 719 | std::string* str = attr_proto.mutable_s(); |
667 | 720 |
|
668 | | - str->resize(total_attr_bytes, '\0'); |
| 721 | + str->resize(total_attr_bytes); |
669 | 722 | ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, |
670 | 723 | &total_attr_bytes)); |
671 | 724 |
|
672 | | - str->resize(total_attr_bytes - 1); // remove extra ending terminating '\0' character. |
| 725 | + str->resize(total_attr_bytes); |
673 | 726 | break; |
674 | 727 | } |
675 | 728 | case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { |
@@ -705,6 +758,103 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr |
705 | 758 |
|
706 | 759 | break; |
707 | 760 | } |
| 761 | + case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { |
| 762 | + attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); |
| 763 | + |
| 764 | + onnx::TensorProto tensor_proto; |
| 765 | + |
| 766 | + // TensorProto as an attribute value doesn't require a name. |
| 767 | + |
| 768 | + OrtValue* ort_value = nullptr; |
| 769 | + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); |
| 770 | + |
| 771 | + Ort::Value tensor(ort_value); |
| 772 | + |
| 773 | + // Get tensor type and shape info |
| 774 | + Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); |
| 775 | + |
| 776 | + // Get tensor type |
| 777 | + ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); |
| 778 | + |
| 779 | + size_t element_size = 0; |
| 780 | + switch (element_type) { |
| 781 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { |
| 782 | + tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); |
| 783 | + element_size = sizeof(float); |
| 784 | + break; |
| 785 | + } |
| 786 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { |
| 787 | + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); |
| 788 | + element_size = sizeof(uint8_t); |
| 789 | + break; |
| 790 | + } |
| 791 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { |
| 792 | + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); |
| 793 | + element_size = sizeof(int8_t); |
| 794 | + break; |
| 795 | + } |
| 796 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { |
| 797 | + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); |
| 798 | + element_size = sizeof(uint16_t); |
| 799 | + break; |
| 800 | + } |
| 801 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { |
| 802 | + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); |
| 803 | + element_size = sizeof(int16_t); |
| 804 | + break; |
| 805 | + } |
| 806 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { |
| 807 | + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); |
| 808 | + element_size = sizeof(int32_t); |
| 809 | + break; |
| 810 | + } |
| 811 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { |
| 812 | + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); |
| 813 | + element_size = sizeof(int64_t); |
| 814 | + break; |
| 815 | + } |
| 816 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { |
| 817 | + tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); |
| 818 | + element_size = sizeof(bool); |
| 819 | + break; |
| 820 | + } |
| 821 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { |
| 822 | + tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); |
| 823 | + element_size = sizeof(double); |
| 824 | + break; |
| 825 | + } |
| 826 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { |
| 827 | + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); |
| 828 | + element_size = sizeof(uint32_t); |
| 829 | + break; |
| 830 | + } |
| 831 | + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { |
| 832 | + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); |
| 833 | + element_size = sizeof(uint64_t); |
| 834 | + break; |
| 835 | + } |
| 836 | + default: { |
| 837 | + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast<int>(element_type)); |
| 838 | + return Ort::Status(err_msg.c_str(), ORT_FAIL); |
| 839 | + } |
| 840 | + } |
| 841 | + |
| 842 | + auto shape = type_shape_info.GetShape(); |
| 843 | + |
| 844 | + for (auto& dim : shape) { |
| 845 | + tensor_proto.add_dims(dim); |
| 846 | + } |
| 847 | + |
| 848 | + size_t element_count = type_shape_info.GetElementCount(); |
| 849 | + size_t data_bytes = element_count * element_size; |
| 850 | + const void* data = tensor.GetTensorData<void>(); |
| 851 | + |
| 852 | + // Copy the Ortvalue to TensorProto as raw data |
| 853 | + tensor_proto.set_raw_data(data, data_bytes); |
| 854 | + |
| 855 | + *(attr_proto.mutable_t()) = std::move(tensor_proto); |
| 856 | + break; |
| 857 | + } |
708 | 858 | default: { |
709 | 859 | std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast<int>(attr_type)); |
710 | 860 | return Ort::Status(err_msg.c_str(), ORT_FAIL); |
|
0 commit comments