diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml new file mode 100644 index 000000000..a25dd294c --- /dev/null +++ b/.github/workflows/windows_tensorrt.yml @@ -0,0 +1,242 @@ +name: Windows GPU TensorRT CI Pipeline + +on: + push: + branches: + - main + pull_request: + +concurrency: + group: ${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: + name: Windows GPU TensorRT CI Pipeline + runs-on: windows-2022 + steps: + - uses: actions/checkout@v5 + with: + fetch-depth: 0 + submodules: 'none' + + - uses: actions/setup-python@v6 + with: + python-version: '3.12' + architecture: x64 + + - name: Download CUDA SDK v12.2 + working-directory: ${{ runner.temp }} + run: | + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v12.2" . + dir + shell: pwsh + + - name: Download TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8 + run: 'azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" ${{ runner.temp }}' + shell: pwsh + + - name: Add CUDA to PATH + shell: powershell + run: | + Write-Host "Adding CUDA to PATH" + Write-Host "CUDA Path: $env:RUNNER_TEMP\v12.2\bin" + Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.2\bin" + Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.2\extras\CUPTI\lib64" + Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8\lib" + + - uses: actions/setup-node@v5 + with: + node-version: '20.x' + + - uses: actions/setup-java@v5 + with: + distribution: 'temurin' + java-version: '17' + architecture: x64 + + - uses: actions/cache@v4 + id: onnx-node-tests-cache + with: + path: ${{ github.workspace }}/js/test/ + key: onnxnodetests-${{ hashFiles('js/scripts/prepare-onnx-node-tests.ts') }} + + - name: API Documentation Check and generate + run: | + set ORT_DOXY_SRC=${{ github.workspace }} + set ORT_DOXY_OUT=${{ runner.temp }}\build\RelWithDebInfo\RelWithDebInfo + mkdir %ORT_DOXY_SRC% + mkdir %ORT_DOXY_OUT% + "C:\Program Files\doxygen\bin\doxygen.exe" ${{ github.workspace }}\tools\ci_build\github\Doxyfile_csharp.cfg + working-directory: ${{ github.workspace }} + shell: cmd + + - uses: actions/setup-dotnet@v5 + env: + PROCESSOR_ARCHITECTURE: x64 + with: + dotnet-version: '8.x' + + - name: Use Nuget 6.x + uses: nuget/setup-nuget@v2 + with: + nuget-version: '6.x' + + - name: NuGet restore + run: nuget restore ${{ github.workspace }}\packages.config -ConfigFile ${{ github.workspace }}\NuGet.config -PackagesDirectory ${{ runner.temp }}\build\RelWithDebInfo + shell: cmd + + - name: Set OnnxRuntimeBuildDirectory + shell: pwsh + run: | + $buildDir = Join-Path ${{ runner.temp }} "build" + echo "OnnxRuntimeBuildDirectory=$buildDir" >> $env:GITHUB_ENV + + - name: Build and Clean Binaries + working-directory: ${{ runner.temp }} + run: | + npm install -g typescript + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + # Execute the build process + python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --build --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines 
CMAKE_CUDA_ARCHITECTURES=86 + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + + # Clean up the output directory before uploading artifacts + $outputDir = "${{ runner.temp }}\build\RelWithDebInfo" + Write-Host "Cleaning up files from $outputDir..." + + Remove-Item -Path "$outputDir\onnxruntime" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\pybind11" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\models" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\vcpkg_installed" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\_deps" -Recurse -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\CMakeCache.txt" -Force -ErrorAction SilentlyContinue + Remove-Item -Path "$outputDir\CMakeFiles" -Recurse -Force -ErrorAction SilentlyContinue + # Remove intermediate object files as in the original script + Remove-Item -Path $outputDir -Include "*.obj" -Recurse + shell: pwsh + + - name: Upload build artifacts + uses: actions/upload-artifact@v4 + with: + name: build-artifacts + path: ${{ runner.temp }}\build + env: + OrtPackageId: Microsoft.ML.OnnxRuntime.Gpu + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + setVcvars: true + ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' + DocUpdateNeeded: false + ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' + AZCOPY_AUTO_LOGIN_TYPE: MSI + AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 + + test: + name: Windows GPU TensorRT CI Pipeline Test Job + needs: build + timeout-minutes: 300 + runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Win2022-GPU-A10"] + steps: + - uses: actions/checkout@v5 + with: + fetch-depth: 0 + submodules: 'none' + + - name: Download build artifacts + uses: actions/download-artifact@v5 + with: + name: build-artifacts + path: ${{ runner.temp }}\build + + - uses: actions/setup-python@v6 + with: + python-version: '3.12' + architecture: x64 + + - uses: actions/setup-node@v5 + with: + node-version: '20.x' + + - uses: actions/setup-java@v5 + with: + distribution: 'temurin' + java-version: '17' + architecture: x64 + + - name: Locate vcvarsall and Setup Env + uses: ./.github/actions/locate-vcvarsall-and-setup-env + with: + architecture: x64 + + - name: Install python modules + run: python -m pip install -r .\tools\ci_build\github\windows\python\requirements.txt + working-directory: ${{ github.workspace }} + shell: cmd + + - name: Download CUDA SDK v12.2 + working-directory: ${{ runner.temp }} + run: | + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v12.2" . 
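# azcopy picks up credentials from the job-level env settings (AZCOPY_AUTO_LOGIN_TYPE=MSI, AZCOPY_MSI_CLIENT_ID), so no explicit login step is needed here.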
+ dir + shell: pwsh + + - name: Download TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8 + run: 'azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" ${{ runner.temp }}' + shell: pwsh + + - name: Add CUDA to PATH + shell: powershell + run: | + Write-Host "Adding CUDA to PATH" + Write-Host "CUDA Path: $env:RUNNER_TEMP\v12.2\bin" + Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.2\bin" + Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.2\extras\CUPTI\lib64" + Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8\lib" + + - name: Set OnnxRuntimeBuildDirectory + shell: pwsh + run: | + $buildDir = Join-Path ${{ runner.temp }} "build" + echo "OnnxRuntimeBuildDirectory=$buildDir" >> $env:GITHUB_ENV + + - name: Install ONNX Runtime Wheel + uses: ./.github/actions/install-onnxruntime-wheel + with: + whl-directory: ${{ runner.temp }}\build\RelWithDebInfo\RelWithDebInfo\dist + + - name: Run Tests + working-directory: ${{ runner.temp }} + run: | + npm install -g typescript + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + + python.exe ${{ github.workspace }}\tools\python\update_ctest_path.py "${{ runner.temp }}\build\RelWithDebInfo\CTestTestfile.cmake" "${{ runner.temp }}\build\RelWithDebInfo" + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + + python ${{ github.workspace }}\tools\ci_build\build.py --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir build --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="${{ runner.temp }}\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8" --cuda_home="${{ runner.temp }}\v12.2" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 + if ($lastExitCode -ne 0) { + exit $lastExitCode + } + shell: pwsh + + - name: Validate C# native delegates + run: python tools\ValidateNativeDelegateAttributes.py + working-directory: ${{ github.workspace }}\csharp + shell: cmd + env: + OrtPackageId: Microsoft.ML.OnnxRuntime.Gpu + DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true + setVcvars: true + ALLOW_RELEASED_ONNX_OPSET_ONLY: '0' + DocUpdateNeeded: false + ONNXRUNTIME_TEST_GPU_DEVICE_ID: '0' + AZCOPY_AUTO_LOGIN_TYPE: MSI + AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 diff --git a/plugin_execution_providers/tensorrt/CMakeLists.txt b/plugin_execution_providers/tensorrt/CMakeLists.txt new file mode 100644 index 000000000..85e6ca9fb --- /dev/null +++ b/plugin_execution_providers/tensorrt/CMakeLists.txt @@ -0,0 +1,160 @@ +# usage: +# cd build/ +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DORT_HOME=/path/to/ort_package/onnxruntime-win-x64-gpu-1.23.0 -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/path/to/tensorrt/TensorRT-10.3.0.26 -DCMAKE_POSITION_INDEPENDENT_CODE=ON (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") +# cmake --build ./ --config Debug +cmake_minimum_required(VERSION 3.26) +project(TensorRTEp VERSION 1.0) +set(CMAKE_CXX_STANDARD 17) + +enable_language(CUDA) # via nvcc to get the CUDA tool kit +file(TO_CMAKE_PATH "/usr/local/cuda" CUDAToolkit_ROOT) +find_package(CUDAToolkit REQUIRED) + +# CMake config to force dynamic debug CRT or dynamic release CRT globally for all dependencies. 
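# (A more compact, equivalent sketch, assuming CMake >= 3.15 with policy CMP0091 set to NEW, would be a
#  single generator expression instead of the per-build-type branch below:
#    set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>DLL" CACHE STRING "" FORCE) )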
+# This is to address the issue of: +# libprotobufd.lib(common.obj) : error LNK2038: mismatch detected for 'RuntimeLibrary': value 'MTd_StaticDebug' doesn't match value 'MDd_DynamicDebug' in unary_elementwise_ops_impl.obj +if (WIN32) + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDebugDLL" CACHE STRING "" FORCE) # /MDd + set(BUILD_SHARED_LIBS OFF) # Build protobuf as static .lib, but using dynamic runtime + endif() + + if(CMAKE_BUILD_TYPE STREQUAL "Release") + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDLL" CACHE STRING "" FORCE) + set(BUILD_SHARED_LIBS OFF) # Build protobuf as static .lib, but using dynamic runtime + endif() +endif() + +add_definitions(-DONNX_NAMESPACE=onnx) +add_definitions(-DONNX_ML) +add_definitions(-DNOMINMAX) +file(GLOB tensorrt_src "./*.cc" "./utils/*.cc" "./cuda/unary_elementwise_ops_impl.cu" "./*.h") +add_library(TensorRTEp SHARED ${tensorrt_src}) + +if (NOT ORT_HOME) + message(FATAL_ERROR "Please specify ORT_HOME, e.g. -DORT_HOME=/path/to/ort/") +endif() + +if (NOT TENSORRT_HOME) + message(FATAL_ERROR "Please specify TENSORRT_HOME, e.g. -DTENSORRT_HOME=/path/to/trt/") +endif() + +# Use release mode if not specified +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release") +endif() + +# Add dependencies +include(FetchContent) + +# Add protobuf +FetchContent_Declare( + protobuf + GIT_REPOSITORY https://github.com/protocolbuffers/protobuf.git + GIT_TAG v21.12 # Use a specific tag or commit +) + +if (WIN32) + # Sometimes, protobuf ignores CMAKE_MSVC_RUNTIME_LIBRARY. To ensure it works: + set(protobuf_MSVC_STATIC_RUNTIME OFF CACHE BOOL "" FORCE) +endif() + +FetchContent_MakeAvailable(protobuf) + +# Add ONNX +FetchContent_Declare( + onnx + GIT_REPOSITORY https://github.com/onnx/onnx.git + GIT_TAG v1.18.0 # Use a specific tag or commit +) + +FetchContent_MakeAvailable(onnx) + +# Add GSL +FetchContent_Declare( + gsl + GIT_REPOSITORY https://github.com/microsoft/GSL.git + GIT_TAG v4.0.0 # Use a specific tag or commit +) + +FetchContent_MakeAvailable(gsl) + +# Add flatbuffers +FetchContent_Declare( + flatbuffers + GIT_REPOSITORY https://github.com/google/flatbuffers.git + GIT_TAG v23.5.26 # Use a specific tag or commit +) + +FetchContent_MakeAvailable(flatbuffers) + +set(DEPS_PATH "${CMAKE_BINARY_DIR}/_deps") + +if (WIN32) # Windows + set(ORT_LIB "${ORT_HOME}/lib/onnxruntime.lib") + set(TRT_LIBS "${TENSORRT_HOME}/lib/nvinfer_10.lib" + "${TENSORRT_HOME}/lib/nvinfer_plugin_10.lib" + "${TENSORRT_HOME}/lib/nvonnxparser_10.lib") + + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(DEPS_LIBS ${DEPS_LIBS} + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobufd.lib" + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotocd.lib") + else() + set(DEPS_LIBS ${DEPS_LIBS} + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobuf.lib" + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotoc.lib") + endif() + + set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/${CMAKE_BUILD_TYPE}/flatbuffers.lib" + "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx.lib" + "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib") + + set(TRT_EP_LIB_LINK_FLAG + "-DEF:${CMAKE_SOURCE_DIR}/tensorrt_execution_provider.def") + +else() # Linux + set(ORT_LIB "${ORT_HOME}/lib/libonnxruntime.so") + set(TRT_LIBS "${TENSORRT_HOME}/lib/libnvinfer.so" + "${TENSORRT_HOME}/lib/libnvinfer_plugin.so" + "${TENSORRT_HOME}/lib/libnvonnxparser.so") + set(DEPS_LIBS "${DEPS_PATH}/flatbuffers-build/libflatbuffers.a" + "${DEPS_PATH}/onnx-build/libonnx.a" + 
"${DEPS_PATH}/onnx-build/libonnx_proto.a") + + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(DEPS_LIBS ${DEPS_LIBS} + "${DEPS_PATH}/protobuf-build/libprotobufd.a" + "${DEPS_PATH}/protobuf-build/libprotocd.a") + else() + set(DEPS_LIBS ${DEPS_LIBS} + "${DEPS_PATH}/protobuf-build/libprotobuf.a" + "${DEPS_PATH}/protobuf-build/libprotoc.a") + endif() +endif() + +MESSAGE(STATUS "Looking for following dependencies ...") +MESSAGE(STATUS "ORT lib : ${ORT_LIB}") +MESSAGE(STATUS "TRT libs : ${TRT_LIBS}") +MESSAGE(STATUS "Deps libs: ${DEPS_LIBS}") + +set_property(TARGET TensorRTEp APPEND_STRING PROPERTY LINK_FLAGS + ${TRT_EP_LIB_LINK_FLAG}) + +target_include_directories(TensorRTEp PUBLIC "${ORT_HOME}/include" + "./utils" + "/usr/local/cuda/include" + "${TENSORRT_HOME}/include" + "${DEPS_PATH}/flatbuffers-src/include" + "${DEPS_PATH}/gsl-src/include" # GSL is header-only + "${DEPS_PATH}/onnx-src" + "${DEPS_PATH}/onnx-build" + "${DEPS_PATH}/protobuf-src/src" +) + +target_link_libraries(TensorRTEp PUBLIC #${DEPS_LIBS} + protobuf::libprotobuf onnx flatbuffers + ${ORT_LIB} + ${TRT_LIBS} + CUDA::cudart +) diff --git a/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh b/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh new file mode 100644 index 000000000..7d05c54b5 --- /dev/null +++ b/plugin_execution_providers/tensorrt/cuda/cu_inc/unary_elementwise_impl.cuh @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +namespace cuda { + +// We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer +// For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. 
+#ifndef CUDA_LONG +#define CUDA_LONG int32_t +#endif + +template +inline __host__ __device__ INT CeilDiv(INT a, INT2 b) // ceil(a/b) +{ + return (INT)(((size_t)a + (size_t)b - 1) / (size_t)b); // these size_t casts are necessary since b may be INT_MAX (for maxGridSize[]) +} + +struct GridDim { + enum : CUDA_LONG { + maxThreadsPerBlock = 256, // max threads per block + maxElementsPerThread = 4, // max element processed per thread + }; +}; + +template +__global__ void _UnaryElementWise( + const InT* input_data, + OutT* output_data, + const FuncT functor, + CUDA_LONG N) { + CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + InT value[NumElementsPerThread]; + + CUDA_LONG id = start; +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + value[i] = input_data[id]; + id += NumThreadsPerBlock; + } + } + + id = start; +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + output_data[id] = functor(value[i]); + id += NumThreadsPerBlock; + } + } +} + +template +void UnaryElementWiseImpl( + cudaStream_t stream, + const InT* input_data, + OutT* output_data, + const FuncT& func, + size_t count) { + if (count == 0) // special case where there's a dim value of 0 in the shape + return; + + int blocksPerGrid = static_cast(CeilDiv(count, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); + CUDA_LONG N = static_cast(count); + _UnaryElementWise + <<>>( + input_data, + output_data, + func, + N); +} + +} // namespace cuda diff --git a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu new file mode 100644 index 000000000..9d4887520 --- /dev/null +++ b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.cu @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
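// Launch configuration comes from GridDim in cu_inc/unary_elementwise_impl.cuh: each block runs 256
// threads and each thread handles up to 4 elements, so e.g. a 10,000-element tensor is covered by
// CeilDiv(10000, 256 * 4) = 10 blocks.
// The IMPL_CAST_IMPL(InT, OutT) macros below stamp out one Explicit_Impl_Cast overload per type pair;
// each overload simply forwards to cuda::UnaryElementWiseImpl with an OP_Cast functor.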
+ +#include +#include "cu_inc/unary_elementwise_impl.cuh" + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 +#include "cuda_fp8.h" +#endif +#include + +namespace cuda { + +// the postfix of means the types supported by the op: +// B: uint8_t +// W: uint16_t +// U: uint32_t +// Z: uint64_t +// C: int8_t +// S: int16_t +// I: int32_t +// L: int64_t +// H: float16 +// F: float +// D: double +// O: bool +// X: BFloat16 + +// When casting, half needs to be converted via float type from most other types +template +struct ViaTypeMap { + typedef T ViaT; +}; + +template <> +struct ViaTypeMap { + typedef float ViaT; +}; + +template +struct OP_Cast { + __device__ __inline__ OutT operator()(const InT& a) const { + const bool any_float16 = std::is_same::value || std::is_same::value; + typedef typename std::conditional::type T; + typedef typename ViaTypeMap::ViaT ViaT; + return (OutT)((ViaT)a); + } +}; + +#define IMPL_CAST_IMPL(InT, OutT) \ + void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \ + UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \ + } + +#define IMPL_CAST_IMPL_THROW(InT, OutT) \ + void Explicit_Impl_Cast(cudaStream_t /*stream*/, const InT* /*input_data*/, OutT* /*output_data*/, \ + size_t /*count*/) { \ + ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ + } + +#define IMPL_CAST_IMPL_FROM(T) \ + IMPL_CAST_IMPL(T, half) \ + IMPL_CAST_IMPL(T, float) \ + IMPL_CAST_IMPL(T, double) \ + IMPL_CAST_IMPL(T, int8_t) \ + IMPL_CAST_IMPL(T, int16_t) \ + IMPL_CAST_IMPL(T, int32_t) \ + IMPL_CAST_IMPL(T, int64_t) \ + IMPL_CAST_IMPL(T, uint8_t) \ + IMPL_CAST_IMPL(T, uint16_t) \ + IMPL_CAST_IMPL(T, uint32_t) \ + IMPL_CAST_IMPL(T, uint64_t) \ + IMPL_CAST_IMPL(T, bool) \ + // IMPL_CAST_IMPL(T, BFloat16) + +IMPL_CAST_IMPL_FROM(half) +IMPL_CAST_IMPL_FROM(float) +IMPL_CAST_IMPL_FROM(double) +IMPL_CAST_IMPL_FROM(int8_t) +IMPL_CAST_IMPL_FROM(int16_t) +IMPL_CAST_IMPL_FROM(int32_t) +IMPL_CAST_IMPL_FROM(int64_t) +IMPL_CAST_IMPL_FROM(uint8_t) +IMPL_CAST_IMPL_FROM(uint16_t) +IMPL_CAST_IMPL_FROM(uint32_t) +IMPL_CAST_IMPL_FROM(uint64_t) +IMPL_CAST_IMPL_FROM(bool) +// IMPL_CAST_IMPL_FROM(BFloat16) + +} // namespace cuda diff --git a/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h new file mode 100644 index 000000000..1bd241f7b --- /dev/null +++ b/plugin_execution_providers/tensorrt/cuda/unary_elementwise_ops_impl.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
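// Usage sketch (stream, d_src, d_dst and count are placeholders; both pointers must be device memory
// holding at least `count` elements):
//   const float* d_src = ...;  // device buffer with the source values
//   half* d_dst = ...;         // device buffer receiving the converted values
//   cuda::Impl_Cast<float, half>(stream, d_src, d_dst, count);
// Impl_Cast dispatches to the matching Explicit_Impl_Cast overload declared below.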
+ +#pragma once + +#include +#include +#include + +namespace cuda { + +// Cast + +#define DECL_IMPL_CAST(InT, OutT) \ + void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count); + +#define DECL_IMPL_CAST_FROM(T) \ + DECL_IMPL_CAST(T, half) \ + DECL_IMPL_CAST(T, float) \ + DECL_IMPL_CAST(T, double) \ + DECL_IMPL_CAST(T, int8_t) \ + DECL_IMPL_CAST(T, int16_t) \ + DECL_IMPL_CAST(T, int32_t) \ + DECL_IMPL_CAST(T, int64_t) \ + DECL_IMPL_CAST(T, uint8_t) \ + DECL_IMPL_CAST(T, uint16_t) \ + DECL_IMPL_CAST(T, uint32_t) \ + DECL_IMPL_CAST(T, uint64_t) \ + DECL_IMPL_CAST(T, bool) \ + // DECL_IMPL_CAST(T, BFloat16) + +DECL_IMPL_CAST_FROM(half) +DECL_IMPL_CAST_FROM(float) +DECL_IMPL_CAST_FROM(double) +DECL_IMPL_CAST_FROM(int8_t) +DECL_IMPL_CAST_FROM(int16_t) +DECL_IMPL_CAST_FROM(int32_t) +DECL_IMPL_CAST_FROM(int64_t) +DECL_IMPL_CAST_FROM(uint8_t) +DECL_IMPL_CAST_FROM(uint16_t) +DECL_IMPL_CAST_FROM(uint32_t) +DECL_IMPL_CAST_FROM(uint64_t) +DECL_IMPL_CAST_FROM(bool) +// DECL_IMPL_CAST_FROM(BFloat16) + +template +void Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { + Explicit_Impl_Cast(stream, input_data, output_data, count); +} + +} // namespace cuda diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.cc b/plugin_execution_providers/tensorrt/cuda_allocator.cc new file mode 100644 index 000000000..5ad74957a --- /dev/null +++ b/plugin_execution_providers/tensorrt/cuda_allocator.cc @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "cuda_allocator.h" + +namespace trt_ep { + +void CUDA_RETURN_IF_ERROR(cudaError_t res); + +void CUDAAllocator::CheckDevice(bool throw_when_fail) const { +#ifndef NDEBUG + // check device to match at debug build + // if it's expected to change, call cudaSetDevice instead of the check + int current_device; + auto cuda_err = cudaGetDevice(¤t_device); + if (cuda_err == cudaSuccess) { + assert(current_device == CUDAAllocator::GetDeviceId()); + } else if (throw_when_fail) { + CUDA_RETURN_IF_ERROR(cuda_err); + } +#endif +} + +void CUDAAllocator::SetDevice(bool throw_when_fail) const { + int current_device; + auto cuda_err = cudaGetDevice(¤t_device); + if (cuda_err == cudaSuccess) { + int allocator_device_id = CUDAAllocator::GetDeviceId(); + if (current_device != allocator_device_id) { + cuda_err = cudaSetDevice(allocator_device_id); + } + } + + if (cuda_err != cudaSuccess && throw_when_fail) { + CUDA_RETURN_IF_ERROR(cuda_err); + } +} + +void* CUDAAllocator::Alloc(size_t size) { + SetDevice(true); + CheckDevice(true); + void* p = nullptr; + if (size > 0) { + CUDA_RETURN_IF_ERROR(cudaMalloc((void**)&p, size)); + } + return p; +} + +void CUDAAllocator::Free(void* p) { + SetDevice(false); + CheckDevice(false); // ignore CUDA failure when free + cudaFree(p); // do not throw error since it's OK for cudaFree to fail during shutdown +} + +const OrtMemoryInfo* CUDAAllocator::Info() const { + return mem_info_; +} + +void* CUDAPinnedAllocator::Alloc(size_t size) { + void* p = nullptr; + if (size > 0) { + CUDA_RETURN_IF_ERROR(cudaMallocHost((void**)&p, size)); + } + return p; +} + +void CUDAPinnedAllocator::Free(void* p) { + CUDA_RETURN_IF_ERROR(cudaFreeHost(p)); +} + +const OrtMemoryInfo* CUDAPinnedAllocator::Info() const { + return mem_info_; +} + +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/cuda_allocator.h b/plugin_execution_providers/tensorrt/cuda_allocator.h new file mode 
100644 index 000000000..7d3362e58 --- /dev/null +++ b/plugin_execution_providers/tensorrt/cuda_allocator.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include "onnxruntime_c_api.h" + +using DeviceId = int16_t; + +namespace trt_ep { + +struct CUDAAllocator : OrtAllocator { + CUDAAllocator(const OrtMemoryInfo* mem_info, DeviceId device_id) : mem_info_(mem_info), device_id_(device_id) { + OrtAllocator::version = ORT_API_VERSION; + OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; + OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; + OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; + OrtAllocator::Reserve = nullptr; + OrtAllocator::GetStats = nullptr; + OrtAllocator::AllocOnStream = nullptr; // Allocate memory, handling usage across different Streams. Not used for TRT EP. + } + + void* Alloc(size_t size); + void Free(void* p); + const OrtMemoryInfo* Info() const; + DeviceId GetDeviceId() const { return device_id_; }; + + private: + CUDAAllocator(const CUDAAllocator&) = delete; + CUDAAllocator& operator=(const CUDAAllocator&) = delete; + + void CheckDevice(bool throw_when_fail) const; + void SetDevice(bool throw_when_fail) const; + + DeviceId device_id_; + const OrtMemoryInfo* mem_info_ = nullptr; +}; + +struct CUDAPinnedAllocator : OrtAllocator { + CUDAPinnedAllocator(const OrtMemoryInfo* mem_info) : mem_info_(mem_info) { + OrtAllocator::version = ORT_API_VERSION; + OrtAllocator::Alloc = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Alloc(size); }; + OrtAllocator::Free = [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; + OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; + OrtAllocator::Reserve = nullptr; + OrtAllocator::GetStats = nullptr; + OrtAllocator::AllocOnStream = nullptr; + } + + void* Alloc(size_t size); + void Free(void* p); + const OrtMemoryInfo* Info() const; + + DeviceId GetDeviceId() const { return device_id_; }; + + private: + CUDAPinnedAllocator(const CUDAPinnedAllocator&) = delete; + CUDAPinnedAllocator& operator=(const CUDAPinnedAllocator&) = delete; + + DeviceId device_id_ = 0; + const OrtMemoryInfo* mem_info_ = nullptr; +}; + +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/nv_includes.h b/plugin_execution_providers/tensorrt/nv_includes.h new file mode 100644 index 000000000..047f325f4 --- /dev/null +++ b/plugin_execution_providers/tensorrt/nv_includes.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +// File to include the required TRT headers with workarounds for warnings we can't fix or not fixed yet. +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) // Ignore warning C4100: unreferenced formal parameter +#pragma warning(disable : 4996) // Ignore warning C4996: 'nvinfer1::IPluginV2' was declared deprecated +#endif + +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc new file mode 100644 index 000000000..c1d141eff --- /dev/null +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.cc @@ -0,0 +1,360 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "ep_utils.h" +#include "path_string.h" +#include "onnx_ctx_model_helper.h" +#include "onnx/onnx_pb.h" + +namespace trt_ep { +extern TensorrtLogger& GetTensorrtLogger(bool verbose_log, const OrtLogger& ort_default_logger, + const OrtApi* ort_api); + +bool IsAbsolutePath(const std::string& path_string) { +#ifdef _WIN32 + PathString ort_path_string = ToPathString(path_string); + auto path = std::filesystem::path(ort_path_string.c_str()); + return path.is_absolute(); +#else + if (!path_string.empty() && path_string[0] == '/') { + return true; + } + return false; +#endif +} + +// Like "../file_path" +bool IsRelativePathToParentPath(const std::string& path_string) { +#ifdef _WIN32 + PathString ort_path_string = ToPathString(path_string); + auto path = std::filesystem::path(ort_path_string.c_str()); + auto relative_path = path.lexically_normal().make_preferred().wstring(); + if (relative_path.find(L"..", 0) != std::string::npos) { + return true; + } + return false; +#else + if (!path_string.empty() && path_string.find("..", 0) != std::string::npos) { + return true; + } + return false; +#endif +} + +/* + * Return the directory where the ep context model locates + */ +std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path) { + if (ep_context_file_path.empty()) { + return std::filesystem::path(); + } + std::filesystem::path ctx_path(ep_context_file_path); + if (std::filesystem::is_directory(ep_context_file_path)) { + return ctx_path; + } else { + return ctx_path.parent_path(); + } +} + +bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) { + // The weight-stripped engine cache has the naming of xxx.stripped.engine + return engine_cache_path.stem().extension().string() == ".stripped"; +} + +/* + * Create an EPContext OrtNode from a fused_node + */ +OrtStatus* EPContextNodeHelper::CreateEPContextNode(const std::string& engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + const std::string& compute_capability, + const std::string& onnx_model_path, + OrtNode** ep_context_node) { + // Helper to collect input or output names from an array of OrtValueInfo instances. 
+ auto collect_input_output_names = [&](gsl::span value_infos, + std::vector& result) -> OrtStatus* { + size_t num_values = value_infos.size(); + std::vector value_names(num_values); + + for (size_t i = 0; i < num_values; ++i) { + const OrtValueInfo* value_info = value_infos[i]; + RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &value_names[i])); + } + + result = std::move(value_names); + return nullptr; + }; + + const char* fused_node_name = nullptr; + + RETURN_IF_ERROR(ort_api.Node_GetName(fused_node_, &fused_node_name)); + + size_t num_fused_node_inputs = 0; + size_t num_fused_node_outputs = 0; + RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node_, &num_fused_node_inputs)); + RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(fused_node_, &num_fused_node_outputs)); + + std::vector fused_node_inputs(num_fused_node_inputs); + std::vector fused_node_outputs(num_fused_node_outputs); + RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node_, fused_node_inputs.data(), fused_node_inputs.size())); + RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node_, fused_node_outputs.data(), fused_node_outputs.size())); + + std::vector input_names; + std::vector output_names; + + RETURN_IF_ERROR(collect_input_output_names(fused_node_inputs, /*out*/ input_names)); + RETURN_IF_ERROR(collect_input_output_names(fused_node_outputs, /*out*/ output_names)); + + // Create node attributes. The CreateNode() function copies the attributes, so we have to release them. + std::array attributes = {}; + DeferOrtRelease defer_release_attrs(attributes.data(), attributes.size(), ort_api.ReleaseOpAttr); + + RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, sizeof(int64_t), ORT_OP_ATTR_INT, &attributes[0])); + + std::string engine_data_str = ""; + if (embed_mode) { + if (size > 0) { + engine_data_str.assign(engine_data, size); + } + RETURN_IF_ERROR( + ort_api.CreateOpAttr("ep_cache_context", engine_data_str.c_str(), engine_data_str.size(), ORT_OP_ATTR_STRING, &attributes[1])); + } else { + RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", engine_cache_path.c_str(), engine_cache_path.size(), ORT_OP_ATTR_STRING, &attributes[1])); + } + + ort_api.CreateOpAttr("hardware_architecture", compute_capability.c_str(), compute_capability.size(), ORT_OP_ATTR_STRING, &attributes[2]); + ort_api.CreateOpAttr("onnx_model_filename", std::filesystem::path(onnx_model_path).filename().string().c_str(), 1, + ORT_OP_ATTR_STRING, &attributes[3]); + + RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name, input_names.data(), + input_names.size(), output_names.data(), output_names.size(), + attributes.data(), attributes.size(), ep_context_node)); + + return nullptr; +} + +/* + * Check whether the graph has the EP context node. + * The node can contain the precompiled engine info for TRT EP to directly load the engine. 
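 *
 * For reference, the EPContext node created by EPContextNodeHelper::CreateEPContextNode() above carries
 * these attributes:
 *   embed_mode            (int)    : 1 = "ep_cache_context" holds the serialized engine bytes,
 *                                    0 = it holds an engine cache path relative to the context model
 *   ep_cache_context      (string) : engine bytes or engine cache path, depending on embed_mode
 *   hardware_architecture (string) : compute capability the engine was built for
 *   onnx_model_filename   (string) : file name of the original ONNX model (used for weight-stripped
 *                                    engine refit)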
+ * + * Note: Please see more details about "EPContext" contrib op in contrib_defs.cc + */ +bool EPContextNodeReader::GraphHasCtxNode(const OrtGraph* graph, const OrtApi& ort_api) { + size_t num_nodes = 0; + RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); + + std::vector nodes(num_nodes); + RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); + + for (size_t i = 0; i < num_nodes; ++i) { + auto node = nodes[i]; + + const char* op_type = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetOperatorType(node, &op_type)); + if (node != nullptr && std::string(op_type) == "EPContext") { + return true; + } + } + return false; +} + +/* + * The sanity check for EP context contrib op. + */ +OrtStatus* EPContextNodeReader::ValidateEPCtxNode(const OrtGraph* graph) const { + size_t num_nodes = 0; + THROW_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); + RETURN_IF_NOT(num_nodes == 1, "Graph contains more than one node."); + + std::vector nodes(num_nodes); + RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); + + const char* op_type = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetOperatorType(nodes[0], &op_type)); + RETURN_IF_NOT(std::string(op_type) == "EPContext", "Node is not an EPContext node."); + + // TODO: Check compute capability and others + + return nullptr; +} + +OrtStatus* EPContextNodeReader::GetEpContextFromGraph(const OrtGraph& graph) { + if (ValidateEPCtxNode(&graph) != nullptr) { + return ort_api.CreateStatus(ORT_EP_FAIL, "It's not a valid EPContext node"); + } + + size_t num_nodes = 0; + RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes)); + + auto ort_graph = Ort::ConstGraph(&graph); + std::vector nodes(num_nodes); + nodes = ort_graph.GetNodes(); + + // ValidateEPCtxNode() already checked ENFORCE(num_nodes == 1) + auto& node = nodes[0]; + Ort::ConstOpAttr node_attr; + + // Get "embed_mode" attribute + RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("embed_mode", node_attr)); + RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_INT, "\'embed_mode\' attribute should be integer type."); + + int64_t embed_mode = 0; + RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue(embed_mode)); + + // Only make path checks if model not provided as byte buffer + bool make_secure_path_checks = !ort_graph.GetModelPath().empty(); + + if (embed_mode) { + // Get engine from byte stream. 
+ RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("ep_cache_context", node_attr)); + RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING, "\'ep_cache_context\' attribute should be string type."); + + std::string context_binary; + RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue(context_binary)); + + *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(const_cast(context_binary.c_str()), + static_cast(context_binary.length()))); + + std::string message = "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + if (!(*trt_engine_)) { + return ort_api.CreateStatus(ORT_EP_FAIL, "TensorRT EP could not deserialize engine from binary data"); + } + + if (weight_stripped_engine_refit_) { + RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("onnx_model_filename", node_attr)); + RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING, "\'onnx_model_filename\' attribute should be string type."); + std::string onnx_model_filename; + RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue(onnx_model_filename)); + std::string placeholder; + RETURN_IF_ERROR(ep_.RefitEngine(onnx_model_filename, + onnx_model_folder_path_, + placeholder, + make_secure_path_checks, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, + (*trt_engine_).get(), + false, // serialize refitted engine to disk + detailed_build_log_)); + } + } else { + // Get engine from cache file. + RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("ep_cache_context", node_attr)); + RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING, "\'ep_cache_context\' attribute should be string type."); + std::string cache_path; + RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue(cache_path)); + + // For security purpose, in the case of running context model, TRT EP won't allow + // engine cache path to be the relative path like "../file_path" or the absolute path. + // It only allows the engine cache to be in the same directory or sub directory of the context model. + if (IsAbsolutePath(cache_path)) { + std::string message = "For security purpose, the ep_cache_context attribute should be set with a relative path, but it is an absolute path: " + cache_path; + return ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); + } + if (IsRelativePathToParentPath(cache_path)) { + std::string message = "The file path in ep_cache_context attribute has '..'. 
For security purpose, it's not allowed to point outside the directory."; + return ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); + } + + // The engine cache and context model (current model) should be in the same directory + std::filesystem::path ctx_model_dir(GetPathOrParentPathOfCtxModel(ep_context_model_path_)); + auto engine_cache_path = ctx_model_dir.append(cache_path); + + std::string message = "[TensorRT EP] GetEpContextFromGraph engine_cache_path: " + engine_cache_path.string(); + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + + // If it's a weight-stripped engine cache, it needs to be refitted even though the refit flag is not enabled + if (!weight_stripped_engine_refit_) { + weight_stripped_engine_refit_ = IsWeightStrippedEngineCache(engine_cache_path); + } + + // If the serialized refitted engine is present, use it directly without refitting the engine again + if (weight_stripped_engine_refit_) { + const std::filesystem::path refitted_engine_cache_path = GetWeightRefittedEnginePath(engine_cache_path.string()); + if (std::filesystem::exists(refitted_engine_cache_path)) { + std::string message = "[TensorRT EP] " + refitted_engine_cache_path.string() + " exists."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + engine_cache_path = refitted_engine_cache_path.string(); + weight_stripped_engine_refit_ = false; + } + } + + if (!std::filesystem::exists(engine_cache_path)) { + std::string error_msg = + "TensorRT EP can't find engine cache: " + engine_cache_path.string() + + ". Please make sure engine cache is in the same directory or sub-directory of context model."; + return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + + std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in); + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); + if (!(*trt_engine_)) { + std::string error_msg = "TensorRT EP could not deserialize engine from cache: " + engine_cache_path.string(); + return ort_api.CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + + message = "[TensorRT EP] DeSerialized " + engine_cache_path.string(); + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + + if (weight_stripped_engine_refit_) { + RETURN_IF_ORT_STATUS_ERROR(node.GetAttributeByName("onnx_model_filename", node_attr)); + RETURN_IF_NOT(node_attr.GetType() == OrtOpAttrType::ORT_OP_ATTR_STRING, "\'onnx_model_filename\' attribute should be string type."); + std::string onnx_model_filename; + RETURN_IF_ORT_STATUS_ERROR(node_attr.GetValue(onnx_model_filename)); + std::string weight_stripped_engine_cache = engine_cache_path.string(); + auto status = ep_.RefitEngine(onnx_model_filename, + onnx_model_folder_path_, + weight_stripped_engine_cache, + make_secure_path_checks, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, + (*trt_engine_).get(), + true, // serialize refitted engine to disk + 
detailed_build_log_); + if (status != nullptr) { + return ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed."); + } + } + } + return nullptr; +} + +/* + * Get the weight-refitted engine cache path from a weight-stripped engine cache path + * + * Weight-stripped engine: + * An engine whose weights have been stripped; it is smaller than a regular engine. + * The cache name of a weight-stripped engine is TensorrtExecutionProvider_TRTKernel_XXXXX.stripped.engine + * + * Weight-refitted engine: + * An engine whose weights have been refitted; it is simply a regular engine. + * The cache name of a weight-refitted engine is TensorrtExecutionProvider_TRTKernel_XXXXX.engine + */ +std::string GetWeightRefittedEnginePath(std::string stripped_engine_cache) { + std::filesystem::path stripped_engine_cache_path(stripped_engine_cache); + std::string refitted_engine_cache_path = stripped_engine_cache_path.stem().stem().string() + ".engine"; + return refitted_engine_cache_path; +} +} // namespace trt_ep \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h new file mode 100644 index 000000000..fcf7b6e59 --- /dev/null +++ b/plugin_execution_providers/tensorrt/onnx_ctx_model_helper.h @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "tensorrt_execution_provider.h" +#include "ep_utils.h" +// #include "nv_includes.h" + +#include +#include +#include +#include + +namespace trt_ep { +bool IsAbsolutePath(const std::string& path_string); +bool IsRelativePathToParentPath(const std::string& path_string); +std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); + +// Class to create an EPContext node from an ORT fused_node. +// +// Note: The class can be instantiated many times during the EP's Compile() to generate EPContext nodes from fused_nodes/subgraphs and return them to ORT via Compile(); +// ORT will then create the EPContext model. +class EPContextNodeHelper : public ApiPtrs { + public: + EPContextNodeHelper(TensorrtExecutionProvider& ep, + const OrtGraph* graph, + const OrtNode* fused_node) + : ApiPtrs{static_cast(ep)}, graph_(graph), fused_node_(fused_node) {} + + OrtStatus* CreateEPContextNode(const std::string& engine_cache_path, + char* engine_data, + size_t size, + const int64_t embed_mode, + const std::string& compute_capability, + const std::string& onnx_model_path, + OrtNode** ep_context_node); + + private: + const OrtGraph* graph_ = nullptr; + const OrtNode* fused_node_ = nullptr; +}; + +// Class to read an OrtGraph that contains an EPContext node and get the engine binary accordingly.
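// A minimal usage sketch (variable names below are placeholders for the EP's members):
//   if (EPContextNodeReader::GraphHasCtxNode(graph, ort_api)) {
//     EPContextNodeReader reader(ep, logger, &trt_engine, trt_runtime, ep_context_model_path,
//                                compute_capability, weight_stripped_engine_refit, onnx_model_folder_path,
//                                onnx_model_bytestream, onnx_model_bytestream_size,
//                                onnx_external_data_bytestream, onnx_external_data_bytestream_size,
//                                detailed_build_log);
//     RETURN_IF_ERROR(reader.GetEpContextFromGraph(*graph));
//   }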
+class EPContextNodeReader : public ApiPtrs { + public: + EPContextNodeReader(TensorrtExecutionProvider& ep, + const OrtLogger& logger, + std::unique_ptr* trt_engine, + nvinfer1::IRuntime* trt_runtime, + std::string ep_context_model_path, + std::string compute_capability, + bool weight_stripped_engine_refit, + std::string onnx_model_folder_path, + const void* onnx_model_bytestream, + size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, + bool detailed_build_log) + : ApiPtrs{static_cast(ep)}, + ep_(ep), + logger_(logger), + trt_engine_(trt_engine), + trt_runtime_(trt_runtime), + ep_context_model_path_(ep_context_model_path), + compute_capability_(compute_capability), + weight_stripped_engine_refit_(weight_stripped_engine_refit), + onnx_model_folder_path_(onnx_model_folder_path), + onnx_model_bytestream_(onnx_model_bytestream), + onnx_model_bytestream_size_(onnx_model_bytestream_size), + onnx_external_data_bytestream_(onnx_external_data_bytestream), + onnx_external_data_bytestream_size_(onnx_external_data_bytestream_size), + detailed_build_log_(detailed_build_log) { + } + + static bool GraphHasCtxNode(const OrtGraph* graph, const OrtApi& ort_api); + + OrtStatus* ValidateEPCtxNode(const OrtGraph* graph) const; + + OrtStatus* GetEpContextFromGraph(const OrtGraph& graph); + + private: + TensorrtExecutionProvider& ep_; + const OrtLogger& logger_; + std::unique_ptr* trt_engine_; + nvinfer1::IRuntime* trt_runtime_; + std::string ep_context_model_path_; // If using context model, it implies context model and engine cache is in the same directory + std::string compute_capability_; + bool weight_stripped_engine_refit_; + std::string onnx_model_folder_path_; + const void* onnx_model_bytestream_; + size_t onnx_model_bytestream_size_; + const void* onnx_external_data_bytestream_; + size_t onnx_external_data_bytestream_size_; + bool detailed_build_log_; +}; // TRTCacheModelHandler +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/ort_trt_int8_cal_table.fbs.h b/plugin_execution_providers/tensorrt/ort_trt_int8_cal_table.fbs.h new file mode 100644 index 000000000..9e4324fb9 --- /dev/null +++ b/plugin_execution_providers/tensorrt/ort_trt_int8_cal_table.fbs.h @@ -0,0 +1,144 @@ +// automatically generated by the FlatBuffers compiler, do not modify + +#ifndef FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_ +#define FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace CalTableFlatBuffers { + +struct KeyValue; +struct KeyValueBuilder; + +struct TrtTable; +struct TrtTableBuilder; + +struct KeyValue FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef KeyValueBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_KEY = 4, + VT_VALUE = 6 + }; + const flatbuffers::String* key() const { + return GetPointer(VT_KEY); + } + bool KeyCompareLessThan(const KeyValue* o) const { + return *key() < *o->key(); + } + int KeyCompareWithValue(const char* val) const { + return strcmp(key()->c_str(), val); + } + const flatbuffers::String* value() const { + return GetPointer(VT_VALUE); + } + bool Verify(flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && + VerifyOffsetRequired(verifier, VT_KEY) && + verifier.VerifyString(key()) && + VerifyOffset(verifier, VT_VALUE) && + verifier.VerifyString(value()) && + verifier.EndTable(); + } +}; + +struct KeyValueBuilder { + typedef KeyValue Table; + 
flatbuffers::FlatBufferBuilder& fbb_; + flatbuffers::uoffset_t start_; + void add_key(flatbuffers::Offset key) { + fbb_.AddOffset(KeyValue::VT_KEY, key); + } + void add_value(flatbuffers::Offset value) { + fbb_.AddOffset(KeyValue::VT_VALUE, value); + } + explicit KeyValueBuilder(flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + KeyValueBuilder& operator=(const KeyValueBuilder&); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + fbb_.Required(o, KeyValue::VT_KEY); + return o; + } +}; + +inline flatbuffers::Offset CreateKeyValue( + flatbuffers::FlatBufferBuilder& _fbb, + flatbuffers::Offset key = 0, + flatbuffers::Offset value = 0) { + KeyValueBuilder builder_(_fbb); + builder_.add_value(value); + builder_.add_key(key); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateKeyValueDirect( + flatbuffers::FlatBufferBuilder& _fbb, + const char* key = nullptr, + const char* value = nullptr) { + auto key__ = key ? _fbb.CreateString(key) : 0; + auto value__ = value ? _fbb.CreateString(value) : 0; + return CalTableFlatBuffers::CreateKeyValue( + _fbb, + key__, + value__); +} + +struct TrtTable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TrtTableBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DICT = 4 + }; + const flatbuffers::Vector>* dict() const { + return GetPointer>*>(VT_DICT); + } + bool Verify(flatbuffers::Verifier& verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DICT) && + verifier.VerifyVector(dict()) && + verifier.VerifyVectorOfTables(dict()) && + verifier.EndTable(); + } +}; + +struct TrtTableBuilder { + typedef TrtTable Table; + flatbuffers::FlatBufferBuilder& fbb_; + flatbuffers::uoffset_t start_; + void add_dict(flatbuffers::Offset>> dict) { + fbb_.AddOffset(TrtTable::VT_DICT, dict); + } + explicit TrtTableBuilder(flatbuffers::FlatBufferBuilder& _fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TrtTableBuilder& operator=(const TrtTableBuilder&); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTrtTable( + flatbuffers::FlatBufferBuilder& _fbb, + flatbuffers::Offset>> dict = 0) { + TrtTableBuilder builder_(_fbb); + builder_.add_dict(dict); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateTrtTableDirect( + flatbuffers::FlatBufferBuilder& _fbb, + std::vector>* dict = nullptr) { + auto dict__ = dict ? 
_fbb.CreateVectorOfSortedTables(dict) : 0; + return CalTableFlatBuffers::CreateTrtTable( + _fbb, + dict__); +} + +} // namespace CalTableFlatBuffers + +#endif // FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_ diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc new file mode 100644 index 000000000..090413398 --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.cc @@ -0,0 +1,3801 @@ +#include +#include +#include +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" + +#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL +#include "ort_graph_to_proto.h" + +#include "tensorrt_execution_provider_utils.h" +#include "tensorrt_execution_provider.h" +#include "cuda_allocator.h" +#include "onnx_ctx_model_helper.h" +#include "tensorrt_execution_provider_stream_support.h" +#include "onnx/onnx_pb.h" +#include "cuda/unary_elementwise_ops_impl.h" +#include "ep_utils.h" + +#ifdef _WIN32 +#include +#define LIBTYPE HINSTANCE +#define OPENLIB(libname) LoadLibrary(libname) +#define LIBFUNC(lib, fn) GetProcAddress((lib), (fn)) +#else +#include +#define LIBTYPE void* +#define OPENLIB(libname) dlopen((libname), RTLD_LAZY) +#define LIBFUNC(lib, fn) dlsym((lib), (fn)) +#endif + +const OrtApi* g_ort_api = nullptr; +const OrtEpApi* g_ep_api = nullptr; +const OrtModelEditorApi* g_model_editor_api = nullptr; + +namespace ONNX_NAMESPACE { +using int64s = google::protobuf::RepeatedField; +using float32s = google::protobuf::RepeatedField; +using StringStringEntryProtos = google::protobuf::RepeatedPtrField; +using TensorProtos = google::protobuf::RepeatedPtrField; +using TensorShapeProto_Dimensions = google::protobuf::RepeatedPtrField; +using ValueInfoProtos = google::protobuf::RepeatedPtrField; +using FunctionProtos = google::protobuf::RepeatedPtrField; +} // namespace ONNX_NAMESPACE + +namespace trt_ep { + +void CUDA_RETURN_IF_ERROR(cudaError_t res) { + if (res != cudaSuccess) abort(); +} + +#if NV_TENSORRT_MAJOR >= 10 +void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, + uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { + // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr + // even for empty tensors, so allocate a dummy byte. + size = std::max(size, static_cast(1)); + if (size > allocated_size) { + cudaFree(outputPtr); + outputPtr = nullptr; + allocated_size = 0; + if (cudaMalloc(&outputPtr, size) == cudaSuccess) { + allocated_size = size; + } + } + // if cudaMalloc fails, returns nullptr. + return outputPtr; +} +#else +// Only override this method when TensorRT <= 8.6 +void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, + uint64_t /*alignment*/) noexcept { + // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr + // even for empty tensors, so allocate a dummy byte. + size = std::max(size, static_cast(1)); + if (size > allocated_size) { + cudaFree(outputPtr); + outputPtr = nullptr; + allocated_size = 0; + if (cudaMalloc(&outputPtr, size) == cudaSuccess) { + allocated_size = size; + } + } + // if cudaMalloc fails, returns nullptr. 
+ return outputPtr; +} +#endif + +void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& dims) noexcept { + output_shapes.clear(); + output_shapes.reserve(dims.nbDims); + for (int i = 0; i < dims.nbDims; i++) { + output_shapes.push_back(dims.d[i]); + } +} + +TensorrtLogger& GetTensorrtLogger(bool verbose_log, + const OrtLogger& ort_default_logger, + const OrtApi* ort_api) { + const auto log_level = verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING; + static TensorrtLogger trt_logger(ort_default_logger, ort_api, log_level); + if (log_level != trt_logger.get_level()) { + trt_logger.set_level(verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING); + } + return trt_logger; +} + +std::unique_lock TensorrtExecutionProvider::GetApiLock() const { + static std::mutex singleton; + return std::unique_lock(singleton); +} + +nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) const { + if (!builder_) { + { + auto lock = GetApiLock(); + builder_ = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + } + } + return builder_.get(); +} + +template +void GetShapeOfShapeTensor(Ort::ConstValue& input_tensor, + void* shape_values, + int shape_size, + cudaStream_t stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(shape_values, + input_tensor.GetTensorData(), + shape_size * sizeof(T), + cudaMemcpyDeviceToHost, + stream)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); +} + +bool ApplyProfileShapesFromProviderOptions(std::vector& trt_profiles, + nvinfer1::ITensor* input, + std::unordered_map>>& profile_min_shapes, + std::unordered_map>>& profile_max_shapes, + std::unordered_map>>& profile_opt_shapes, + ShapeRangesMap& input_explicit_shape_ranges, + const OrtLogger* logger) { + if (trt_profiles.size() == 0) { + std::string message = "[TensorRT EP] Number of optimization profiles should be greater than 0, but it's 0."; + Ort::ThrowOnError(g_ort_api->Logger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + return false; + } + + const std::string& input_name = input->getName(); + if (profile_min_shapes.find(input_name) == profile_min_shapes.end()) { + return false; + } + + if (input_explicit_shape_ranges.find(input_name) == input_explicit_shape_ranges.end()) { + std::unordered_map>> inner_map; + input_explicit_shape_ranges[input_name] = inner_map; + } + + std::string message = "[TensorRT EP] Begin to apply profile shapes ...\n" + + std::string("[TensorRT EP] Input tensor name is '") + input_name + std::string("', number of profiles found is ") + std::to_string(trt_profiles.size()); + Ort::ThrowOnError(g_ort_api->Logger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + + for (size_t i = 0; i < trt_profiles.size(); i++) { + nvinfer1::Dims dims = input->getDimensions(); + int nb_dims = dims.nbDims; + + auto trt_profile = trt_profiles[i]; + + // Shape tensor + if (input->isShapeTensor()) { + int shape_size = nb_dims == 0 ? 
1 : static_cast(profile_min_shapes[input_name][i].size()); + std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); + + std::string message = "[TensorRT EP] shape size of this shape tensor is " + std::to_string(shape_size); + Ort::ThrowOnError(g_ort_api->Logger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + + for (int j = 0; j < shape_size; j++) { + auto min_value = profile_min_shapes[input_name][i][j]; + auto max_value = profile_max_shapes[input_name][i][j]; + auto opt_value = profile_opt_shapes[input_name][i][j]; + shapes_min[j] = static_cast(min_value); + shapes_max[j] = static_cast(max_value); + shapes_opt[j] = static_cast(opt_value); + std::string message = "[TensorRT EP] shapes_min.d[" + std::to_string(j) + std::string("] is ") + std::to_string(shapes_min[j]) + std::string("\n") + + std::string("[TensorRT EP] shapes_max.d[") + std::to_string(j) + std::string("] is ") + std::to_string(shapes_max[j]) + std::string("\n") + + std::string("[TensorRT EP] shapes_opt.d[") + std::to_string(j) + std::string("] is ") + std::to_string(shapes_opt[j]); + Ort::ThrowOnError(g_ort_api->Logger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + + if (input_explicit_shape_ranges[input_name].find(j) == input_explicit_shape_ranges[input_name].end()) { + std::vector> profile_vector(trt_profiles.size()); + input_explicit_shape_ranges[input_name][j] = profile_vector; + } + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(min_value); + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(max_value); + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(opt_value); + } + + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); + } + // Execution tensor + else { + nvinfer1::Dims dims_min, dims_opt, dims_max; + dims_min.nbDims = nb_dims; + dims_max.nbDims = nb_dims; + dims_opt.nbDims = nb_dims; + + std::string message = "[TensorRT EP] number of dimension of this execution tensor is " + std::to_string(nb_dims); + Ort::ThrowOnError(g_ort_api->Logger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + + for (int j = 0; j < nb_dims; j++) { + if (dims.d[j] == -1) { + auto min_value = profile_min_shapes[input_name][i][j]; + auto max_value = profile_max_shapes[input_name][i][j]; + auto opt_value = profile_opt_shapes[input_name][i][j]; + dims_min.d[j] = static_cast(min_value); + dims_max.d[j] = static_cast(max_value); + dims_opt.d[j] = static_cast(opt_value); + + std::string message = "[TensorRT EP] dims_min.d[" + std::to_string(j) + std::string("] is ") + std::to_string(dims_min.d[j]) + std::string("\n") + + std::string("[TensorRT EP] dims_max.d[") + std::to_string(j) + std::string("] is ") + std::to_string(dims_max.d[j]) + std::string("\n") + + std::string("[TensorRT EP] dims_opt.d[") + std::to_string(j) + std::string("] is ") + std::to_string(dims_opt.d[j]); + Ort::ThrowOnError(g_ort_api->Logger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + + if 
(input_explicit_shape_ranges[input_name].find(j) == input_explicit_shape_ranges[input_name].end()) { + std::vector> profile_vector(trt_profiles.size()); + input_explicit_shape_ranges[input_name][j] = profile_vector; + } + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(min_value); + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(max_value); + input_explicit_shape_ranges[input_name][static_cast(j)][i].push_back(opt_value); + } else { + dims_min.d[j] = dims.d[j]; + dims_max.d[j] = dims.d[j]; + dims_opt.d[j] = dims.d[j]; + } + } + + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt); + } + } + return true; +} + +OrtStatusPtr ApplyProfileShapesFromInputTensorValue(std::vector& trt_profiles, + Ort::KernelContext ctx, + nvinfer1::ITensor* input, + ShapeRangesMap& shape_ranges, + const std::unordered_map& input_indexes, + std::unordered_map>& shape_tensor_values, + std::unordered_map>& shape_tensor_values_int64, + cudaStream_t stream, + bool* engine_update) { + for (size_t i = 0; i < trt_profiles.size(); i++) { + const std::string& input_name = input->getName(); + nvinfer1::Dims dims = input->getDimensions(); + int nb_dims = dims.nbDims; + + size_t input_index = 0; + const auto& iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + auto& shape_ranges_per_input = shape_ranges[input_name]; + + auto trt_profile = trt_profiles[i]; + + // If there are multiple profiles, for second and rest of profiles, simply copy the min/max/opt profile values from the first profile. + // Following "if statement" won't be executed since TRT EP currently only allows single profile for non-explicit profiles case. + if (i > 0) { + if (input->isShapeTensor()) { + // shape tensor + int shape_size = nb_dims == 0 ? 
1 : static_cast(tensor_shapes[0]); + std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); + for (int j = 0; j < shape_size; j++) { + shapes_min[j] = *(trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN)); + shapes_max[j] = *(trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX)); + shapes_opt[j] = *(trt_profiles[0]->getShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT)); + } + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); + } else { + // execution tensor + nvinfer1::Dims dims_min, dims_opt, dims_max; + dims_min = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN); + dims_max = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX); + dims_opt = trt_profiles[0]->getDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt); + } + continue; + } + + // Create shape profile + if (input->isShapeTensor()) { + // Get shape values for shape tensor input + const auto tensor_type = tensor_info.GetElementType(); + // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension + int shape_size = dims.nbDims == 0 ? 1 : static_cast(tensor_shapes[0]); + // For setting TRT optimization profile. 
(Note: the min/opt/max profile values are still int32 even though int64 is supported after TRT 10) + std::vector values(shape_size); + + switch (tensor_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + auto buffer = std::make_unique(shape_size); + GetShapeOfShapeTensor(input_tensor, buffer.get(), shape_size, stream); + shape_tensor_values[input_name].resize(shape_size); + for (int j = 0; j < shape_size; ++j) { + shape_tensor_values[input_name][j] = buffer[j]; + values[j] = buffer[j]; + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + auto buffer = std::make_unique(shape_size); + GetShapeOfShapeTensor(input_tensor, buffer.get(), shape_size, stream); + shape_tensor_values_int64[input_name].resize(shape_size); + for (int j = 0; j < shape_size; ++j) { + shape_tensor_values_int64[input_name][j] = buffer[j]; + values[j] = static_cast(buffer[j]); + } + break; + } + default: { + return g_ort_api->CreateStatus(OrtErrorCode::ORT_EP_FAIL, std::string("TensorRT shape tensor data type: " + std::to_string(tensor_type) + " not supported.").c_str()); + } + } + + // Update shape ranges + std::vector shapes_min(shape_size), shapes_opt(shape_size), shapes_max(shape_size); + int shape_range_size = static_cast(shape_ranges_per_input.size()); + if (shape_size == shape_range_size) { + // If shape size matches, check/update shape range + for (int j = 0; j < shape_size; ++j) { + auto& shape_range = shape_ranges_per_input[j][0]; // only has one profile + shapes_min[j] = static_cast(shape_range[0]); + shapes_max[j] = static_cast(shape_range[1]); + shapes_opt[j] = static_cast(shape_range[2]); + + const auto& tensor_shape_value = values[j]; + // Update shape range lower bound + if (tensor_shape_value < shape_range[0]) { + shape_range[0] = tensor_shape_value; + shapes_min[j] = tensor_shape_value; + *engine_update = true; + } + // Update shape range upper bound + if (tensor_shape_value > shape_range[1]) { + shape_range[1] = tensor_shape_value; + shape_range[2] = tensor_shape_value; + shapes_max[j] = tensor_shape_value; + shapes_opt[j] = tensor_shape_value; + *engine_update = true; + } + } + } else { + // If shape size doesn't match, initialize shape_range with the new shape value + shape_ranges_per_input.clear(); + for (int j = 0; j < shape_size; ++j) { + const auto& tensor_shape_value = values[j]; + std::vector> profile_vector; + std::vector shape_vector{tensor_shape_value, tensor_shape_value, tensor_shape_value}; + profile_vector.push_back(shape_vector); // only one profile needed + shape_ranges_per_input[j] = profile_vector; + shapes_min[j] = tensor_shape_value; + shapes_opt[j] = tensor_shape_value; + shapes_max[j] = tensor_shape_value; + } + *engine_update = true; + } + + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, &shapes_min[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, &shapes_max[0], shape_size); + trt_profile->setShapeValues(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, &shapes_opt[0], shape_size); + } else { // Execution tensor + nvinfer1::Dims dims_min(dims), dims_opt(dims), dims_max(dims); + for (int j = 0, end = nb_dims; j < end; ++j) { + const auto& tensor_shape = tensor_shapes[j]; + if (shape_ranges_per_input.find(j) != shape_ranges_per_input.end()) { + auto& shape_range = shape_ranges_per_input[j][0]; // only has one profile + dims_min.d[j] = static_cast(shape_range[0]); + dims_max.d[j] = static_cast(shape_range[1]); + dims_opt.d[j] = static_cast(shape_range[2]); + + // Update 
minimum dimension + if (tensor_shape < shape_range[0]) { + shape_range[0] = tensor_shape; + dims_min.d[j] = static_cast(tensor_shape); + *engine_update = true; + } + // Update maximum dimension + if (tensor_shape > shape_range[1]) { + shape_range[1] = tensor_shape; + shape_range[2] = tensor_shape; + dims_max.d[j] = static_cast(tensor_shape); + dims_opt.d[j] = static_cast(tensor_shape); + *engine_update = true; + } + } + } + + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMIN, dims_min); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kMAX, dims_max); + trt_profile->setDimensions(input_name.c_str(), nvinfer1::OptProfileSelector::kOPT, dims_opt); + } + } + return nullptr; +} + +#define CASE_GET_INPUT_TENSOR(DATA_TYPE, SrcT) \ + case DATA_TYPE: { \ + auto input_tensor_ptr = input_tensor.GetTensorData(); \ + if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ + data = const_cast(input_tensor_ptr); \ + } else { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + data = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_CAST_INPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto input_tensor_ptr = input_tensor.GetTensorData(); \ + if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + data = scratch_buffers.back().get(); \ + cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); \ + } else { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + data = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + buffers[output_name] = output_tensor_ptr; \ + } else { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + buffers[output_name] = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_CAST_OUTPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = static_cast(elem_cnt); \ + } else { \ + scratch_buffers.push_back(MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = 1; \ + } \ + break; \ + } + +#define CASE_COPY_TENSOR(DATA_TYPE, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(DstT), cudaMemcpyDeviceToDevice, stream)); \ + } \ + break; \ + } + +#define CASE_CAST_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), elem_cnt); \ + } \ + break; \ + } + +OrtStatusPtr BindContextInput(Ort::KernelContext& ctx, + nvinfer1::ICudaEngine* trt_engine, + nvinfer1::IExecutionContext* trt_context, + const char* input_name, + 
size_t input_index, + std::unordered_map>& shape_tensor_values, + std::unordered_map>& shape_tensor_values_int64, + std::vector>& scratch_buffers, + OrtAllocator* alloc, + cudaStream_t stream) { + try { + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + const auto tensor_type = tensor_info.GetElementType(); + /* + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+ * [] = 1
+ * [1,3,4] = 12
+ * [2,0,4] = 0
+ * [-1,3,4] = -1
+ */ + const auto elem_cnt = tensor_info.GetElementCount(); + + if (trt_engine->isShapeInferenceIO(input_name)) { + // Bind "shape tensor" input buffer + + // The shape of the "shape tensor" is either zero dimension (scalar) or 1-dimension + int shape_size = trt_engine->getTensorShape(input_name).nbDims == 0 ? 1 : static_cast(tensor_shapes[0]); + switch (tensor_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + // get shape tensor value if not present + if (shape_tensor_values.find(input_name) == shape_tensor_values.end()) { + auto input = std::make_unique(shape_size); + GetShapeOfShapeTensor(input_tensor, input.get(), shape_size, stream); + shape_tensor_values[input_name].resize(shape_size); + for (int i = 0; i < shape_size; ++i) { + shape_tensor_values[input_name][i] = input[i]; + } + } + + if (!trt_context->setTensorAddress(input_name, &shape_tensor_values[input_name][0])) { + std::string error_input_name = input_name; + std::string error_msg = + "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + + error_input_name + "'"; + return g_ort_api->CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + // get shape tensor value if not present + if (shape_tensor_values_int64.find(input_name) == shape_tensor_values_int64.end()) { + auto input = std::make_unique(shape_size); + GetShapeOfShapeTensor(input_tensor, input.get(), shape_size, stream); + shape_tensor_values_int64[input_name].resize(shape_size); + for (int i = 0; i < shape_size; ++i) { + shape_tensor_values_int64[input_name][i] = input[i]; + } + } + + if (!trt_context->setTensorAddress(input_name, &shape_tensor_values_int64[input_name][0])) { + std::string error_input_name = input_name; + std::string error_msg = + "TensorRT EP failed to call nvinfer1::IExecutionContext::setTensorAddress() for shape input '" + + error_input_name + "'"; + return g_ort_api->CreateStatus(ORT_EP_FAIL, error_msg.c_str()); + } + break; + } + default: { + std::string error_input_name = input_name; + return g_ort_api->CreateStatus(ORT_EP_FAIL, std::string("The data type of shape tensor should be INT32 or INT64. Please check the data type of " + error_input_name).c_str()); + } + } + } else { + // Set shape for input tensor which is execution tensor + nvinfer1::Dims dims = trt_context->getTensorShape(input_name); + int nb_dims = dims.nbDims; + for (int j = 0, end = nb_dims; j < end; ++j) { + dims.d[j] = static_cast(tensor_shapes[j]); + } + if (!trt_context->setInputShape(input_name, dims)) { + std::string error_input_name = input_name; + return g_ort_api->CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'").c_str()); + } + + // Bind "execution tensor" input buffer + // + // Note: If an engine binding is an empty tensor, it still needs a non-null memory address, and different tensors should have different addresses. + // Therefore, in the case of empty tensor, TRT EP always allocates a dummy byte. 
+ // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#empty-tensors + void* data = nullptr; + switch (tensor_type) { + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) +#if NV_TENSORRT_MAJOR >= 10 + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) +#else + // Cast int64 input to int32 input because TensorRT < 10 doesn't support int64 + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) +#endif + // Cast double input to float because TensorRT doesn't support double + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) + default: { + return g_ort_api->CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported.").c_str()); + } + } + trt_context->setTensorAddress(input_name, data); + } + } catch (const Ort::Exception& e) { + return g_ort_api->CreateStatus(ORT_EP_FAIL, e.what()); + } + return nullptr; +} + +OrtStatusPtr BindContextOutput(Ort::KernelContext& ctx, + nvinfer1::IExecutionContext* trt_context, + const char* output_name, + size_t output_index, + size_t output_type, + size_t i, + std::unordered_map& output_tensors, + std::unordered_map& output_dim_sizes, + DDSOutputAllocatorMap& dds_output_allocator_map, + std::vector>& scratch_buffers, + OrtAllocator* alloc, + std::unordered_map& buffers) { + // Get output shape + nvinfer1::Dims dims = trt_context->getTensorShape(output_name); + int nb_dims = dims.nbDims; + bool is_DDS = false; + std::vector output_shapes(nb_dims); + for (int j = 0, end = nb_dims; j < end; ++j) { + // data-dependent shape + if (dims.d[j] == -1) { + is_DDS = true; + break; + } + output_shapes[j] = dims.d[j]; + } + + auto known_DDS = dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end(); + + // If the output tensor has data-dependent shape, TRT EP will provide an IOutputAllocator for enqueueV3 to dynamically allocate memory buffer. + // Once enqueueV3 returns, TRT EP will then bind the output allocation to ORT kernel context output. + // (Please note that we take strategy A mentioned in https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#dynamic-shaped-output, + // which we defer allocation until the size is known and don't call IExecution::setTensorAddress) + // + // Otherwise, if the shape of the output tensor is known prior to the runtime, ORT will pre-allocate memory buffer for the output tensor for enqueueV3. 
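+ // Illustrative note (not from the original source): an output whose shape is reported with a
+ // -1 dimension by getTensorShape() (e.g. the result of a data-dependent op such as NonZero)
+ // takes the IOutputAllocator path below, while an output with a fully known shape is
+ // pre-allocated via ctx.GetOutput() and bound directly with setTensorAddress().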
+ if (is_DDS || known_DDS) { + if (!known_DDS) { + auto allocatorPtr = std::make_unique(); + trt_context->setOutputAllocator(output_name, allocatorPtr.get()); + dds_output_allocator_map[output_name] = std::move(allocatorPtr); + } + } else { + try { + output_tensors[i] = ctx.GetOutput(output_index, output_shapes); + auto& output_tensor = output_tensors[i]; + const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + + switch (output_type) { + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) +#if NV_TENSORRT_MAJOR >= 10 + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) +#else + // Allocate int32 CUDA memory for int64 output type because TensorRT < 10 doesn't support int64 + CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) +#endif + // Allocate float CUDA memory for double output type because TensorRT doesn't support double + CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) + default: { + return g_ort_api->CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.").c_str()); + } + } + trt_context->setTensorAddress(output_name, buffers[output_name]); + } catch (const Ort::Exception& e) { + return g_ort_api->CreateStatus(ORT_EP_FAIL, e.what()); + } + } + + return nullptr; +} + +OrtStatusPtr BindKernelOutput(Ort::KernelContext& ctx, + const OrtMemoryInfo* /*mem_info*/, + DDSOutputAllocatorMap& allocator_map, + char const* output_name, + size_t output_index, + size_t output_type, + cudaStream_t stream) { + try { + auto allocator = allocator_map[output_name].get(); + auto& shape = allocator->getOutputShape(); + auto output_tensor = ctx.GetOutput(output_index, shape); + + /* + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+ * [] = 1
+ * [1,3,4] = 12
+ * [2,0,4] = 0
+ * [-1,3,4] = -1
+ */ + auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + + /* + * Copy output data from allocation buffer to ORT kernel context output location or + * cast (int32 or float) -> (int64 or double) to ORT kernel context output location. + * + * Note: + * 1. If the output tensor is empty tensor (i.e. any of the dimension is 0) which means element count is 0, + * TRT EP does not perform cuda memory copy nor cuda cast to prevent overwriting other location that might belong to other tensors. + * 2. The cudaMemcpyAsync() and cuda::Impl_Cast() (implemented as _UnaryElementWise() in cuda ep) are all async, but we + * don't need to explicitly call cudaStreamSynchronize() after those APIs due to CUDA EP and TRT EP uses same stream, + * and within the same stream, operations are guaranteed to be executed in order. + */ + switch (output_type) { + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) +#if NV_TENSORRT_MAJOR >= 10 + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) +#else + // The allocation buffer holds the int32 output data since TRT doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output. +// CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int32_t, int64_t) +#endif + // The allocation buffer holds the float output data since TRT doesn't support double. So, we need to cast the data (float -> double) for ORT kernel output. + // CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) + default: { + return g_ort_api->CreateStatus(ORT_EP_FAIL, std::string("TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported.").c_str()); + } + } + } catch (const Ort::Exception& e) { + return g_ort_api->CreateStatus(ORT_EP_FAIL, e.what()); + } + return nullptr; +} + +bool TensorrtExecutionProvider::AllNodesAssignedToSpecificEP(const OrtGraph* graph, const std::string& provider_type) const { + size_t num_nodes = 0; + THROW_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); + + // Get all the nodes from the graph + std::vector nodes(num_nodes); + THROW_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); + + for (const auto node : nodes) { + const char* ep_name; + THROW_IF_ERROR(ort_api.Node_GetEpName(node, &ep_name)); + + if (std::string(ep_name) != provider_type) { + return false; + } + } + + return num_nodes != 0; +} + +// Check the graph is the subgraph of control flow op +bool TensorrtExecutionProvider::IsSubGraphOfControlFlowOp(const OrtGraph* graph) const { + const OrtNode* parent_node = nullptr; + THROW_IF_ERROR(ort_api.Graph_GetParentNode(graph, &parent_node)); + if (parent_node) { + const char* op_type = nullptr; + THROW_IF_ERROR(ort_api.Node_GetOperatorType(parent_node, &op_type)); + + if (control_flow_op_set_.find(std::string(op_type)) != control_flow_op_set_.end()) { + return true; + } + } + return false; +} + +// Check whether all the nodes of subgraph are supported +bool TensorrtExecutionProvider::IsSubGraphFullySupported(const OrtGraph* graph, SubGraphCollection_t supported_nodes_vector) const { + size_t num_nodes = 0; + THROW_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); + + int number_of_trt_nodes = 0; + for 
(const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + number_of_trt_nodes += static_cast(group.first.size()); + } + } + + return number_of_trt_nodes == num_nodes; +} + +SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollection_t nodes_vector_input, + int iterations, const int max_iterations, + const OrtGraph* graph, bool* early_termination) const { + // Temporarily make all nodes supported + SubGraphCollection_t nodes_list_output = nodes_vector_input; + + return nodes_list_output; +} + +OrtStatus* ORT_API_CALL TensorrtExecutionProvider::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + TensorrtExecutionProvider* ep = static_cast(this_ptr); + const OrtApi& ort_api = ep->ort_api; + auto ort_graph = Ort::ConstGraph(graph); + + size_t num_nodes = 0; + RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(graph, &num_nodes)); + + // Get all the nodes from the graph + std::vector nodes(num_nodes); + RETURN_IF_ERROR(ort_api.Graph_GetNodes(graph, nodes.data(), nodes.size())); + + SubGraphCollection_t parser_nodes_vector, supported_nodes_vector; + bool new_subgraph = true; + + std::unordered_set control_flow_op_set = {"If", "Loop", "Scan"}; + + // Get pre-excluded op list from provider options + auto get_exclude_ops_set = [&](std::string node_list_to_exclude) -> std::set { + std::set set; + if (!node_list_to_exclude.empty()) { + std::stringstream node_list(node_list_to_exclude); + std::string node; + while (std::getline(node_list, node, ',')) { + set.insert(node); + } + } + return set; + }; + + auto exclude_ops_set = get_exclude_ops_set(ep->op_types_to_exclude_); + + /* Iterate all the nodes and exclude the node if: + * 1. It's a control flow op and its subgraph(s) is not fully TRT eligible. + * 2. Its op type is in the exclusion list. + */ + for (size_t index = 0; index < nodes.size(); index++) { + const OrtNode* node = nodes[index]; + bool supported_node = true; + + /* If current node is control flow op, we take different approach based on following four cases: + * + * (1) control flow op is supported by TRT, and its subgraphs are all supported by TRT. Assign this node to TRT. + * (2) control flow op is supported by TRT, but not all its subgraphs supported by TRT. Don't assign this node to TRT. + * (3) control flow op is not supported by TRT, but its subgraphs all supported by TRT. Don't assign this node to TRT. + * (4) control flow op is not supported by TRT, and not all its subgraphs supported by TRT. Don't assign this node to TRT. + * + * For cases 2, 3, 4, even though the control flow op is not assigned to TRT, any portion of its subgraphs that can run in TRT will be still fused and assigned to TRT EP. 
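+ * For example (illustrative): a "Loop" op whose body contains a node TRT cannot take is not
+ * assigned to TRT as a whole (case 2/3/4), but the supported portion of the body can still be
+ * fused into TRT subgraphs when the body graph itself is partitioned.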
+ */ + const char* op_type = nullptr; + RETURN_IF_ERROR(ep->ort_api.Node_GetOperatorType(node, &op_type)); + + if (control_flow_op_set.find(op_type) != control_flow_op_set.end()) { + auto supported_control_flow_op = [&](const OrtNode* node) { + OrtStatus* status = nullptr; + size_t num_subgraphs = 0; + RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetNumSubgraphs(node, &num_subgraphs)); + + std::vector node_subgraphs(num_subgraphs); + RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Node_GetSubgraphs(node, node_subgraphs.data(), node_subgraphs.size(), nullptr)); + + // Iterate the node's subgraphs + for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { + const OrtGraph* subgraph = node_subgraphs[subgraph_idx]; + + // Get number of subgraph's nodes + size_t num_subgraph_nodes = 0; + RETURN_FALSE_AND_PRINT_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes)); + + // TRT EP should consider the empty subgraph is fully supported by TRT. + if (num_subgraph_nodes == 0) { + continue; + } + + if (!ep->AllNodesAssignedToSpecificEP(subgraph, ep->name_)) { + // if not all its subgraphs are supported, we need to exclude this control flow op + return false; + } + } + return true; + }; + supported_node = supported_control_flow_op(node); + } + + // Exclude any ops, if applicable + if (exclude_ops_set.find(op_type) != exclude_ops_set.end()) { + supported_node = false; + } + + if (supported_node) { + if (new_subgraph) { + parser_nodes_vector.emplace_back(); + // Mark all new graphs as "UnKnown" which will later be parsed by TRT parser + parser_nodes_vector.back().second = false; + new_subgraph = false; + } + parser_nodes_vector.back().first.emplace_back(index); + } else { + new_subgraph = true; + } + } + + // Use this local definitions for now + // TODO: Use provider option + int max_partition_iterations = 1000; + int min_subgraph_size = 1; + + bool early_termination = false; + supported_nodes_vector = ep->GetSupportedList(parser_nodes_vector, 0, max_partition_iterations, graph, &early_termination); + if (early_termination) { + supported_nodes_vector.clear(); + } + + // Remove subgraphs if its size is less than the predefined minimal size + for (auto it = supported_nodes_vector.begin(); it != supported_nodes_vector.end(); ++it) { + const size_t subgraph_size = it->first.size(); + if (subgraph_size < min_subgraph_size) { + supported_nodes_vector.erase(it--); + } + } + + // TODO: Detect and remove cycles from supported node list + + // TODO: Consolidate supported node list + + // Handle the case where the graph is subgraph of control flow op. + // The purpose is to make control flow op as well as its subgraphs run on TRT. + // Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level. + if (ep->IsSubGraphOfControlFlowOp(graph) && ep->IsSubGraphFullySupported(graph, supported_nodes_vector)) { + // const std::vector& node_index = graph.GetNodesInTopologicalOrder(1); + bool all_subgraphs_are_supported = true; + + // "If" control flow op has two subgraph bodies, "then" body and "else" body respectively. + // Check its parent node's another subgraph to see whether that subgraph is also fully supported by TRT. 
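+ // Illustrative note: if this graph is the "then" branch of an "If" node, the loop below
+ // inspects the other branch (the subgraph whose pointer differs from this graph) and only
+ // reports full support when that branch is empty, already fully assigned to TRT, or parses
+ // here as fully supported.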
+ Ort::ConstNode parent_node = ort_graph.GetParentNode(); + if (parent_node.GetOperatorType() == "If") { + all_subgraphs_are_supported = false; + SubGraphCollection_t subgraph_supported_nodes_vector; + + std::vector attr_name_subgraphs = parent_node.GetSubgraphs(); + for (auto attr_name_subgraph : attr_name_subgraphs) { + auto subgraph = attr_name_subgraph.sub_graph; + const OrtGraph* subgraph_raw_pointer = subgraph; + if (subgraph_raw_pointer != graph) { + size_t num_subgraph_nodes = 0; + RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(subgraph, &num_subgraph_nodes)); + + // Another subgraph of "If" control flow op has no nodes. + // In this case, TRT EP should consider this empty subgraph is fully supported by TRT. + if (num_subgraph_nodes == 0) { + all_subgraphs_are_supported = true; + break; + } + // Another subgraph of "If" control flow op has been parsed by GetCapability before and all subgraph's nodes assigned to TRT EP. + else if (ep->AllNodesAssignedToSpecificEP(subgraph, ep->name_)) { + all_subgraphs_are_supported = true; + break; + } + // Another subgraph of "If" control flow has been parsed by GetCapability and not all subgraph's nodes assigned to TRT EP. + // (Note: GetExecutionProviderType() returns "" meaning node has not yet been assigned to any EPs) + else if (!ep->AllNodesAssignedToSpecificEP(subgraph, "")) { + all_subgraphs_are_supported = false; + break; + } + + std::vector subgraph_nodes_vector(num_subgraph_nodes); + std::iota(std::begin(subgraph_nodes_vector), std::end(subgraph_nodes_vector), 0); + SubGraphCollection_t parser_subgraph_nodes_vector = {{subgraph_nodes_vector, false}}; + bool subgraph_early_termination = false; + + // Another subgraph of "If" control flow has not yet been parsed by GetCapability. + subgraph_supported_nodes_vector = ep->GetSupportedList(parser_subgraph_nodes_vector, 0, ep->max_partition_iterations_, subgraph, &subgraph_early_termination); + all_subgraphs_are_supported = ep->IsSubGraphFullySupported(subgraph, subgraph_supported_nodes_vector); + break; + } + } + } + + if (all_subgraphs_are_supported) { + // We want the subgraph nodes to be assigned to TRT EP but don't want them to be fused until later at the control flow op level. + // Simply request the subgraph nodes with a single ComputeCapability for each with no MetaDef (i.e. what the default implementation for IExecutionProvider::GetCapability does). + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + for (const auto& index : group.first) { + const OrtNode* supported_node = nodes[index]; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_node)); + } + } + } + std::string message = "[TensorRT EP] Whole graph will run on TensorRT execution provider"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + + return nullptr; + } + } + + int number_of_trt_nodes = 0; + for (const auto& group : supported_nodes_vector) { + if (!group.first.empty()) { + std::vector supported_nodes; + supported_nodes.reserve(group.first.size()); + + for (const auto& index : group.first) { + const OrtNode* supported_node = nodes[index]; + + supported_nodes.push_back(supported_node); + } + + // Create (optional) fusion options for the supported nodes to fuse. 
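+ // Note (added for clarity): each non-empty group in supported_nodes_vector is requested as a
+ // single fused node below, so each group is expected to be compiled into its own TensorRT
+ // engine by the later compile step.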
+ OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + + // Set "drop constant initializers" to true as TRT doesn't need ORT to provide constant initializers + // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. + node_fusion_options.drop_constant_initializers = true; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, supported_nodes.data(), + supported_nodes.size(), &node_fusion_options)); + number_of_trt_nodes += static_cast(group.first.size()); + } + } + + const size_t number_of_subgraphs = supported_nodes_vector.size(); + if (number_of_trt_nodes == 0) { + std::string message = "[TensorRT EP] No graph will run on TensorRT execution provider"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } else if (number_of_trt_nodes == nodes.size()) { + std::string message = "[TensorRT EP] Whole graph will run on TensorRT execution provider"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } else { + std::string message = "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " + std::to_string(number_of_subgraphs); + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + + return nullptr; +} + +OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, + const OrtGraph* graph, + const OrtNode* fused_node, + std::unordered_map& input_map, + std::unordered_map& output_map, + /* out */ OrtNodeComputeInfo** node_compute_info, + /* out */ OrtNode** ep_context_node) { + TensorrtExecutionProvider* ep = static_cast(this_ptr); + + // Comment out following code if you want the "large" initializers to be saved to a external file. + /* + //Save initializers to external file + std::string ext_ini_file_path = "model_serialized.bin"; + std::filesystem::remove(ext_ini_file_path); + std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path]( + const OrtValueInfo* value_info, const void* data, size_t bytes, bool& is_external, + std::string& location, int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, + // node consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = ext_ini_ofs.tellp(); + location = ext_ini_file_path; + ext_ini_ofs.write(static_cast(data), bytes); + ext_ini_ofs.flush(); + is_external = true; // True if is external initializer. 
+ + return Ort::Status{nullptr}; + }; + */ + + // Construct ModelProto from OrtGraph + ONNX_NAMESPACE::ModelProto model_proto; + + // add back handle_initializer_data to save initializer to external file + OrtEpUtils::OrtGraphToProto(*graph, model_proto /*, handle_initializer_data */); + + std::string string_buf; + model_proto.SerializeToString(&string_buf); + + if (dump_subgraphs_) { + // Dump TensorRT subgraphs + const char* name = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &name)); + std::string subgraph_name = name; + subgraph_name += ".onnx"; + std::fstream dump(subgraph_name, std::ios::out | std::ios::trunc | std::ios::binary); + model_proto.SerializeToOstream(&dump); + } + + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_, logger_, &ort_api); + auto trt_builder = GetBuilder(trt_logger); + auto network_flags = 0; +#if NV_TENSORRT_MAJOR > 8 + network_flags |= (fp16_enable_ || int8_enable_) + ? 0 + : 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED); +#else + network_flags |= 1U << static_cast(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH); +#endif + auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); + auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); + auto trt_parser = + tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); + trt_parser->parse(string_buf.data(), string_buf.size(), model_path_); + if (max_workspace_size_ > 0) { + trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); + } + + // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if (fp16_enable_ && layer_norm_fp32_fallback_) { + for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) { + auto layer = trt_network->getLayer(idx); + auto next_layer = trt_network->getLayer(idx + 1); + if (layer->getType() == nvinfer1::LayerType::kELEMENTWISE && + next_layer->getType() == nvinfer1::LayerType::kREDUCE && + (static_cast(layer))->getOperation() == nvinfer1::ElementWiseOperation::kPOW) { + std::string message = "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + layer->setPrecision(nvinfer1::DataType::kFLOAT); + next_layer->setPrecision(nvinfer1::DataType::kFLOAT); + layer->setOutputType(0, nvinfer1::DataType::kFLOAT); + next_layer->setOutputType(0, nvinfer1::DataType::kFLOAT); + } + } + } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + + int num_inputs = trt_network->getNbInputs(); + int num_outputs = trt_network->getNbOutputs(); + std::unordered_map input_indexes(num_inputs); + std::unordered_map output_indexes(num_outputs); + std::unordered_map output_types(num_outputs); + + /* + * Initialize shape range for each dynamic shape input tensor: + * 1) If user explicitly specifies optimization profiles via provider options, TRT EP will create those profiles + * during EP compile time. It won't make adjustment for profile values during EP compute time. + * + * 2) If no explicit optimization profiles provided by user, TRT EP will firstly set min/max/opt shape to [INT_MAX, + * INT_MIN, INT_MIN]. 
Later in EP compute time, the shape will be adjusted to [min_input_value, max_input_value, + * max_input_value] based on input tensor value. + * + * + * Once the TRT profiles are created: + * 1) If all the dynamic shape input tensors have associated profiles explicitly provided by user, those profiles + * will be applied to TRT builder config and the engine will be built at EP compile time. + * + * 2) As long as one of the dynamic shape input tensors has no explicitly associated profile, TRT EP will create + * default shape as described above, and all the profiles won't be applied and engine won't be built until EP compute + * time. + */ + bool has_dynamic_shape = + false; // True if input tensor has dynamic shape and no explicit profile is specified, otherwise false. + bool has_explicit_profile = false; + bool apply_explicit_profile = false; + int num_profiles = 0; + std::vector trt_profiles; + + // Following c++ map data structure is used to help serialize/deserialize profiles where it saves dynamic shape + // dimension(s) and min/max/opt values for dynamic shape input tensor. + // + // (1) Single profile case: + // For example, assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, and tensor_b + // has one dynamic shape dimension: dim_1. The data will be: + // { + // tensor_a: { + // dim_0: [[min_shape, max_shape, opt_shape]], + // dim_2: [[min_shape, max_shape, opt_shape]] + // }, + // tensor_b: { + // dim_1: [[min_shape, max_shape, opt_shape]] + // } + // } + // + // (2) Multiple profiles case: + // For example, assume tensor_a has one dynamic shap dimension: dim 0, and tensor_b has one dynamic shape dimension: + // dim_1, and both of the tensors have two profiles. The data will be: + // { + // tensor_a: { + // dim_0: [[min_shape_0, max_shape_0, opt_shape_0], [min_shape_1, max_shape_1, opt_shape_1]] + // }, + // tensor_b: { + // dim_1: [[min_shape_2, max_shape_2, opt_shape_2], [min_shape_3, max_shape_3, opt_shape_3]] + // } + // } + ShapeRangesMap input_explicit_shape_ranges; + ShapeRangesMap input_implicit_shape_ranges; + + if ((!profile_min_shapes_.empty()) && (!profile_max_shapes_.empty()) && (!profile_opt_shapes_.empty())) { + has_explicit_profile = true; + num_profiles = GetNumProfiles(profile_min_shapes_); + for (int i = 0; i < num_profiles; i++) { + trt_profiles.push_back(trt_builder->createOptimizationProfile()); + } + } + + // Iterate all input tensors to check dynamic shape + for (unsigned int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_network->getInput(i); + const std::string& input_name = input->getName(); + nvinfer1::Dims dims = input->getDimensions(); + int nb_dims = dims.nbDims; + + // Apply explicit optimization profiles provided by user + if (has_explicit_profile) { + apply_explicit_profile = + ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, + profile_opt_shapes_, input_explicit_shape_ranges, &ep->logger_); + } + + // If no explicit optimization profile is being applied, TRT EP will later set min/max/opt shape values based on + // input tensor values at EP compute time + if (!apply_explicit_profile) { + if (input->isShapeTensor()) { + // Shape tensor + std::vector> profile_vector; + std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; + profile_vector.push_back(shape_vector); // only one profile needed + input_implicit_shape_ranges[input_name][0] = profile_vector; + has_dynamic_shape = true; + } else { + // Execution tensor + for (int j = 0, end = nb_dims; j < end; ++j) { + if 
(dims.d[j] == -1) { + std::vector> profile_vector; + std::vector shape_vector{INT_MAX, INT_MIN, INT_MIN}; + profile_vector.push_back(shape_vector); // only one profile needed + input_implicit_shape_ranges[input_name][j] = profile_vector; + has_dynamic_shape = true; + } + } + } + apply_explicit_profile = false; + } + } + + // Set explicit profiles in TRT config if all dynamic shape inputs have associated profiles provided by user + if (has_explicit_profile) { + // TRT EP has a constraint here. + // Users need to provide all the dynamic shape inputs with associated profiles if they want to explicitly specify + // profiles through provider options. + if (has_dynamic_shape) { + std::ostringstream msg; + msg << "User needs to provide all the dynamic shape inputs with associated profiles if they want to explicitly " + "set profiles through provider options.\n"; + msg << "Please note that main graph could be partitioned into TRT/CUDA/CPU subgraphs, in this case, user also " + "needs to provide shape profiles for the TRT subgraph's input if it's dynamic shape input.\n"; + msg << "Following input(s) has no associated shape profiles provided: "; + auto begin = input_implicit_shape_ranges.begin(); + auto end = input_implicit_shape_ranges.end(); + auto it = begin; + if (it != end) { + msg << it->first; + ++it; + } + for (; it != end; ++it) { + msg << "," << it->first; + } + // return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, msg.str()); + } else { + for (auto trt_profile : trt_profiles) { + trt_config->addOptimizationProfile(trt_profile); + } + } + } + // If no explicit profile is applied and the input has dynamic shape, TRT EP simply creates one profile by default. + // It will later set proper min/max/opt shape values duing EP compute time. + else if (!has_explicit_profile && has_dynamic_shape) { + trt_profiles.push_back(trt_builder->createOptimizationProfile()); + } + + // Check platform availability for low precision + if (fp16_enable_) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if (!trt_builder->platformHasFastFp16()) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + fp16_enable_ = false; + std::string message = "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE or ORT_TENSORRT_BF16_ENABLE is set, but platform doesn't support fast native fp16/bf16"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + + if (int8_enable_) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if (!trt_builder->platformHasFastInt8()) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + int8_enable_ = false; + std::string message = "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + + // Load INT8 calibration table + std::unordered_map dynamic_range_map; + if (int8_enable_ && int8_calibration_cache_available_) { + const std::string calibration_cache_path = GetCachePath(cache_path_, int8_calibration_cache_name_); + if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map)) { + throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path); + } + } + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 
4996) +#endif + const char* name = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &name)); + std::string fused_node_name = name; + + // Set precision flags + std::string trt_node_name_with_precision = fused_node_name; + if (fp16_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + trt_node_name_with_precision += "_fp16"; + std::string message = "[TensorRT EP] FP16 mode is enabled"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + if (int8_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + trt_node_name_with_precision += "_int8"; + std::string message = "[TensorRT EP] INT8 mode is enabled"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + // Set DLA + if (fp16_enable_ || int8_enable_) { + if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8 + int number_of_dla_core = trt_builder->getNbDLACores(); + if (number_of_dla_core == 0) { + std::string message = "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + dla_enable_ = false; + } else { + if (dla_core_ >= number_of_dla_core) { + std::string message = "[TensorRT EP] Try to use DLA core #" + std::to_string(dla_core_) + + std::string(", but it exceeds platform's maximum DLA core number ") + std::to_string(number_of_dla_core) + + std::string(". Use DLA core 0 instead."); + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + dla_core_ = 0; + } + std::string message = "[TensorRT EP] use DLA core " + dla_core_; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(dla_core_); + trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_); + } + } + } + + // enable sparse weights + if (sparsity_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); + std::string message = "[TensorRT EP] Sparse weights are allowed"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 + if (build_heuristics_enable_) { + trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); + std::string message = "[TensorRT EP] Builder heuristics are enabled." 
+ + std::string(" For TRT > 8.5, trt_build_heuristics_enable is deprecated, please set builder ") + + std::string("optimization level as 2 to enable builder heuristics."); + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // for TRT 8.6 onwards, heuristic-based tactic option is automatically enabled by setting builder optimization level 2 + if (build_heuristics_enable_) { + if (builder_optimization_level_ == 2) { + std::string message = "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization " + std::string("level 2. trt_build_heuristics_enable is deprecated on TRT 8.6 onwards."); + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } else { + std::string message = "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. Please set " + std::string("builder optimization level as 2 to enable builder heuristics."); + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } +#endif + +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // switch optimizaion level + if (builder_optimization_level_ != 3) { + trt_config->setBuilderOptimizationLevel(builder_optimization_level_); + std::string message = "[TensorRT EP] Builder optimization level is set to " + builder_optimization_level_; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + + // limit auxiliary streams + if (auxiliary_streams_ >= 0) { + trt_config->setMaxAuxStreams(auxiliary_streams_); + std::string message = "[TensorRT EP] Auxiliary streams are se to " + auxiliary_streams_; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#else + if (builder_optimization_level_ != 3) { + std::string message = "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + if (auxiliary_streams_ >= 0) { + std::string message = "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#endif + + if (weight_stripped_engine_enable_) { +#if NV_TENSORRT_MAJOR >= 10 + trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); + std::string message = "[TensorRT EP] STRIP_PLAN is enabled"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); + message = "[TensorRT EP] REFIT_IDENTICAL is enabled"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); +#else + std::string message = 
"[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); +#endif + } + + // limit used tactic sources + if (!tactic_sources_.empty()) { + nvinfer1::TacticSources tactics = trt_config->getTacticSources(); + tactics |= GetTacticSourceFromString(tactic_sources_); + trt_config->setTacticSources(tactics); + std::string message = "[TensorRT EP] Tactic sources are limited using " + tactic_sources_; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + + // Build TRT engine (if needed) and load TRT engine if: + // (1) Graph has no dynamic shape input + // (2) All the dynamic shape inputs have associated explicit profiles specified by user + // + // Otherwise engine will be handled at inference time. + std::unique_ptr trt_engine; + std::unique_ptr trt_context; + + std::string cache_path = ""; + std::string cache_suffix = ""; + // Customize cache prefix if assigned + if (!cache_prefix_.empty()) { + // Generate cache suffix in case user would like to customize cache prefix + cache_suffix = "_" + GetCacheSuffix(fused_node_name, trt_node_name_with_precision); + cache_path = GetCachePath(cache_path_, cache_prefix_) + cache_suffix; + } else { + cache_path = GetCachePath(cache_path_, trt_node_name_with_precision); + } + + std::string cache_hw_compat = "_sm" + compute_capability_; +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // Enable hardware compatility mode if assigned + if (engine_cache_enable_ && engine_hw_compatible_) { + trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); + cache_hw_compat = "_sm80+"; + std::string message = "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#endif + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if + // they share the same compute capacity + const std::string cache_path_prefix = cache_path + cache_hw_compat; + std::string engine_cache_path = cache_path_prefix + ".engine"; + const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; + const std::string profile_cache_path = cache_path_prefix + ".profile"; + + // If weight-stripped engine is enabled and refitted engine cache is not present, + // TRT EP will use the engine cache with ".stripped.engine" appended to the end. 
+ const std::filesystem::path engine_cache_fs_path = engine_cache_path; + if (weight_stripped_engine_enable_ && !std::filesystem::exists(engine_cache_fs_path)) { + engine_cache_path = cache_path_prefix + ".stripped.engine"; + weight_stripped_engine_refit_ = true; + } + + std::unique_ptr serialized_engine; + + if (!has_dynamic_shape) { + std::string timing_cache_path = ""; + bool engine_update = false; + if (timing_cache_enable_) { + timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_); + } + { + // ifstream file check, engine serialization/deserialization and engine build are in critical section. It needs + // lock protection to prevent race condition when inferencing with multithreading. + auto lock = GetApiLock(); + + // If explicit profile flag is on and engine cache enable flag is on, + // we need to compare explicit profiles and profiles used to build the engine in order to decide whether to + // rebuild the engine. + if (has_explicit_profile && engine_cache_enable_) { + engine_update = + CompareProfiles(profile_cache_path, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); + if (engine_update) { + std::string message = "[TensorRT EP] Engine will be built"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } else { + std::string message = "[TensorRT EP] Engine won't be rebuilt"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + if (engine_cache_enable_ && !engine_decryption_enable_ && engine_file && !engine_update) { + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + trt_engine = + std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); + std::string message = "[TensorRT EP] DeSerialized " + engine_cache_path; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + if (trt_engine == nullptr) { + std::string err_msg = "TensorRT EP could not deserialize engine from cache: " + engine_cache_path; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + } else if (engine_decryption_enable_ && engine_cache_enable_ && + std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) { + // Decrypt engine + size_t engine_size = 0; + if (!engine_decryption_(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { + std::string err_msg = "TensorRT EP could not get engine buffer size"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + std::unique_ptr engine_buf{new char[engine_size]}; + if (!engine_decryption_(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + std::string err_msg = "TensorRT EP could not call engine decryption function decrypt"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + // Deserialize engine + trt_engine = + std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); + std::string message = "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + 
Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + if (trt_engine == nullptr) { + std::string err_msg = "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } else { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + // Set INT8 per tensor dynamic range + if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) { + trt_config->setInt8Calibrator(nullptr); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (!SetDynamicRange(*trt_network, dynamic_range_map)) { + std::string err_msg = "TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node_name; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } + + // Load timing cache from file. Create a fresh cache if the file doesn't exist + std::unique_ptr timing_cache = nullptr; + if (timing_cache_enable_) { + std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); + timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), + loaded_timing_cache.size())); + if (timing_cache == nullptr) { + std::string err_msg = "TensorRT EP could not create timing cache: " + timing_cache_path; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); + if (detailed_build_log_) { + std::string message = "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + + // Build engine + std::chrono::steady_clock::time_point engine_build_start; + if (detailed_build_log_) { + engine_build_start = std::chrono::steady_clock::now(); + } + + serialized_engine = + std::unique_ptr(trt_builder->buildSerializedNetwork(*trt_network, *trt_config)); + + if (serialized_engine == nullptr) { + std::string err_msg = "TensorRT EP failed to create engine from network for fused node: " + fused_node_name; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + trt_engine = std::unique_ptr( + runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); + if (trt_engine == nullptr) { + std::string err_msg = "TensorRT EP failed to deserialize engine for fused node: " + fused_node_name; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + if (detailed_build_log_) { + auto engine_build_stop = std::chrono::steady_clock::now(); + std::string message = "TensorRT engine build for " + trt_node_name_with_precision + std::string(" took: ") + + std::to_string(std::chrono::duration_cast(engine_build_stop - engine_build_start).count()) + std::string("ms"); + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + if (engine_cache_enable_) { + // Serialize engine profile if it has explicit profiles + if (has_explicit_profile) { + SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges); + std::string message = "[TensorRT EP] Serialized " + profile_cache_path; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, 
__LINE__, __FUNCTION__)); + } + + if (engine_decryption_enable_) { + // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available + // first. + if (engine_encryption_ != nullptr) { + if (!engine_encryption_(encrypted_engine_cache_path.c_str(), + reinterpret_cast(serialized_engine->data()), serialized_engine->size())) { + std::string err_msg = "TensorRT EP call to engine encryption library failed"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + std::string message = "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } else { + std::string message = "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } else { + std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); + file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + std::string message = "[TensorRT EP] Serialized engine " + engine_cache_path; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + // serialize and save timing cache + if (timing_cache_enable_) { + auto timing_cache = trt_config->getTimingCache(); + std::unique_ptr timingCacheHostData{timing_cache->serialize()}; + if (timingCacheHostData == nullptr) { + std::string err_msg = "TensorRT EP could not serialize timing cache: " + timing_cache_path; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); + if (detailed_build_log_) { + std::string message = "[TensorRT EP] Serialized timing cache " + timing_cache_path; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + } + } + + if (weight_stripped_engine_refit_) { + std::string message = "[TensorRT EP] Refit engine from main ONNX file after engine build"; + Ort::ThrowOnError(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + auto status = RefitEngine(model_path_, + onnx_model_folder_path_, + engine_cache_path, + false /* path check for security */, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, + trt_engine.get(), + true /* serialize refitted engine to disk */, detailed_build_log_); + if (status != nullptr) { + return ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed."); + } + } + + // Build context + // Note: Creating an execution context from an engine is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + if (context_memory_sharing_enable_) { + // Reset the max_ctx_mem_size_ and context_memory_ since we don't have access to the allocator here. 
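+ // With context memory sharing the execution context is created without its own device memory
+ // (createExecutionContextWithoutDeviceMemory / kUSER_MANAGED). The shared buffer tracked by
+ // context_memory_ is expected to be allocated later, at compute time, and grown to the largest
+ // requirement (max_ctx_mem_size_) among the contexts that share it.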
+ max_ctx_mem_size_ = 0; + context_memory_ = nullptr; +#if NV_TENSORRT_MAJOR < 10 + trt_context = + std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); +#else + trt_context = std::unique_ptr( + trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +#endif + } else { + trt_context = std::unique_ptr(trt_engine->createExecutionContext()); + } + if (!trt_context) { + std::string err_msg = "TensorRT EP could not build execution context for fused node: " + fused_node_name; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } + + // Create input to index map + // TRT network input -> ORT fused_node input index + for (int i = 0; i < num_inputs; ++i) { + auto input = trt_network->getInput(i); + const std::string& input_name = input->getName(); + const auto& iter = input_map.find(input_name); + if (iter != input_map.end()) { + input_indexes[input_name] = iter->second; + } + } + + // Create output to index and type maps + // TRT network output -> ORT fused_node output index + const auto& graph_output = model_proto.graph().output(); + for (int i = 0; i < num_outputs; ++i) { + const std::string& output_name = trt_network->getOutput(i)->getName(); + const auto& iter = output_map.find(output_name); + if (iter != output_map.end()) { + output_indexes[output_name] = iter->second; + } + const auto& tensor_type = graph_output[i].type().tensor_type(); + output_types[output_name] = tensor_type.elem_type(); + } + + // Save TRT engine, other TRT objects and input/output info to map + parsers_.emplace(fused_node_name, std::move(trt_parser)); + engines_.emplace(fused_node_name, std::move(trt_engine)); + contexts_.emplace(fused_node_name, std::move(trt_context)); + networks_.emplace(fused_node_name, std::move(trt_network)); + input_info_[fused_node_name].push_back(input_indexes); + output_info_[fused_node_name].push_back(output_indexes); + output_info_[fused_node_name].push_back(output_types); + input_shape_ranges_[fused_node_name] = input_implicit_shape_ranges; + profiles_.emplace(fused_node_name, std::move(trt_profiles)); + + // Create EP Context nodes + std::unique_ptr ep_ctx_node_helper = std::make_unique(*ep, graph, fused_node); + if (dump_ep_context_model_) { + std::string compute_capability_hw_compat = compute_capability_; + if (engine_cache_enable_ && engine_hw_compatible_) { + compute_capability_hw_compat = "80+"; + } + + char* serialized_engine_pointer = nullptr; + size_t serialized_engine_size = 0; + + if (serialized_engine) { + serialized_engine_pointer = reinterpret_cast(serialized_engine->data()); + serialized_engine_size = serialized_engine->size(); + } else if (!serialized_engine && ep_context_embed_mode_ && engine_cache_enable_) { + serialized_engine = std::unique_ptr(trt_engine->serialize()); + serialized_engine_pointer = reinterpret_cast(serialized_engine->data()); + serialized_engine_size = serialized_engine->size(); + } + + ep_ctx_node_helper->CreateEPContextNode(engine_cache_path, + serialized_engine_pointer, + serialized_engine_size, + ep_context_embed_mode_, + compute_capability_hw_compat, + model_path_, + ep_context_node); + } + + std::unique_ptr compute_state = std::make_unique(); + + // translate tactic sources string to nvinfer1::TacticSources + nvinfer1::TacticSources tactics = 0; + if (!tactic_sources_.empty()) { + tactics = GetTacticSourceFromString(tactic_sources_); + } + *compute_state = { + static_cast(device_id_), + fused_node_name, + builder_.get(), + &parsers_[fused_node_name], + 
&engines_[fused_node_name], + &contexts_[fused_node_name], + &networks_[fused_node_name], + input_info_[fused_node_name], + output_info_[fused_node_name], + input_shape_ranges_[fused_node_name], + &tensorrt_mu_, + compute_capability_, + max_workspace_size_, + fp16_enable_, + int8_enable_, + int8_calibration_cache_available_, + dla_enable_, + dla_core_, + trt_node_name_with_precision, + engine_cache_enable_, + cache_path_, + runtime_.get(), + profiles_[fused_node_name], + context_memory_sharing_enable_, + &max_ctx_mem_size_, + &context_memory_, + dynamic_range_map, + engine_decryption_enable_, + engine_decryption_, + engine_encryption_, + timing_cache_enable_, + global_cache_path_, + force_timing_cache_match_, + detailed_build_log_, + build_heuristics_enable_, + sparsity_enable_, + builder_optimization_level_, + auxiliary_streams_, + !tactic_sources_.empty(), + tactics, + cuda_graph_enable_, + weight_stripped_engine_enable_, + weight_stripped_engine_refit_, + model_path_, + onnx_model_folder_path_, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, + cache_prefix_, + cache_suffix, + engine_hw_compatible_, + sync_stream_after_enqueue_}; + + ep->compute_states_[fused_node_name] = std::move(compute_state); + + // Update the OrtNodeComputeInfo associated with the graph. + auto ep_node_compute_info = std::make_unique(*ep); + *node_compute_info = ep_node_compute_info.release(); + + return nullptr; +} + +OrtStatus* TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(OrtEp* this_ptr, const OrtGraph* graph, + const OrtNode* fused_node, + std::unordered_map& input_map, + std::unordered_map& output_map, + OrtNodeComputeInfo** node_compute_info) { + TensorrtExecutionProvider* ep = static_cast(this_ptr); + + const char* name = nullptr; + RETURN_IF_ERROR(ort_api.Node_GetName(fused_node, &name)); + std::string fused_node_name = name; + + std::unique_ptr trt_engine; + std::unique_ptr trt_context; + std::unordered_map input_indexes; // TRT engine input name -> ORT kernel context input index + std::unordered_map output_indexes; // TRT engine output name -> ORT kernel context output index + std::unordered_map output_types; // TRT engine output name -> ORT output tensor type + + // Get engine binary data and deserialize it + std::unique_ptr ep_context_node_reader = std::make_unique(*ep, + logger_, + &trt_engine, + runtime_.get(), + model_path_, + compute_capability_, + weight_stripped_engine_enable_, + onnx_model_folder_path_, + onnx_model_bytestream_, + onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, + detailed_build_log_); + RETURN_IF_ERROR(ep_context_node_reader->GetEpContextFromGraph(*graph)); + + // Build context + // Note: Creating an execution context from an engine is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + if (context_memory_sharing_enable_) { + // Reset the max_ctx_mem_size_ and context_memory_ since we don't have access to the allocator here. 
+ max_ctx_mem_size_ = 0; + context_memory_ = nullptr; +#if NV_TENSORRT_MAJOR < 10 + trt_context = + std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory()); +#else + trt_context = std::unique_ptr( + trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +#endif + } else { + trt_context = std::unique_ptr(trt_engine->createExecutionContext()); + } + if (!trt_context) { + std::string err_msg = "TensorRT EP could not build execution context for fused node: " + fused_node_name; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + // Create input/output to index maps + // TRT engine input -> ORT fused_node input index + // TRT engine output -> ORT fused_node output index + for (int32_t i = 0; i < trt_engine->getNbIOTensors(); ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + const auto& iter = input_map.find(name); + if (iter != input_map.end()) { + input_indexes[name] = iter->second; + } + } else { + const auto& iter = output_map.find(name); + if (iter != output_map.end()) { + output_indexes[name] = iter->second; + } + } + } + + // Create output to type map + size_t num_graph_outputs = 0; + RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(graph, &num_graph_outputs)); + + std::vector graph_outputs(num_graph_outputs); + RETURN_IF_ERROR(ort_api.Graph_GetOutputs(graph, graph_outputs.data(), graph_outputs.size())); + + for (size_t i = 0; i < graph_outputs.size(); i++) { + const OrtValueInfo* value_info = graph_outputs[i]; + + const char* value_info_name = nullptr; + RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &value_info_name)); + + const OrtTypeInfo* type_info = nullptr; + RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(value_info, &type_info)); + + const OrtTensorTypeAndShapeInfo* type_shape = nullptr; + RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(type_info, &type_shape)); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + RETURN_IF_ERROR(ort_api.GetTensorElementType(type_shape, &elem_type)); + + output_types[value_info_name] = elem_type; + } + + // Save TRT engine, TRT context and input/output info to map + engines_.emplace(fused_node_name, std::move(trt_engine)); + contexts_.emplace(fused_node_name, std::move(trt_context)); + input_info_[fused_node_name].push_back(input_indexes); + output_info_[fused_node_name].push_back(output_indexes); + output_info_[fused_node_name].push_back(output_types); + + std::unique_ptr compute_state = std::make_unique(); + + *compute_state = { + static_cast(device_id_), + fused_node_name, + &engines_[fused_node_name], + &contexts_[fused_node_name], + input_info_[fused_node_name], + output_info_[fused_node_name], + context_memory_sharing_enable_, + &max_ctx_mem_size_, + &context_memory_, + &tensorrt_mu_, + sync_stream_after_enqueue_}; + + ep->compute_states_for_ep_context_[fused_node_name] = std::move(compute_state); + + // Update the OrtNodeComputeInfo associated with the graph. 
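+ // Ownership of the heap-allocated OrtNodeComputeInfo is handed to ORT via release(); it is
+ // freed later in ReleaseNodeComputeInfosImpl when ORT returns the compute infos to the EP.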
+ auto ep_node_compute_info = std::make_unique(*ep); + *node_compute_info = ep_node_compute_info.release(); + + return nullptr; +} + +OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CompileImpl(_In_ OrtEp* this_ptr, + _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, + _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes) noexcept { + TensorrtExecutionProvider* ep = static_cast(this_ptr); + const OrtApi& ort_api = ep->ort_api; + + gsl::span node_compute_infos_result(node_compute_infos, count); + gsl::span ep_context_nodes_result(ep_context_nodes, count); + + for (size_t fused_node_idx = 0; fused_node_idx < count; fused_node_idx++) { + auto fused_node = fused_nodes[fused_node_idx]; + + // Gets number of node's inputs and outputs + size_t num_node_inputs = 0; + RETURN_IF_ERROR(ort_api.Node_GetNumInputs(fused_node, &num_node_inputs)); + + std::vector node_inputs(num_node_inputs); + RETURN_IF_ERROR(ort_api.Node_GetInputs(fused_node, node_inputs.data(), node_inputs.size())); + + // Builds map from input name to its index in input list + std::unordered_map input_map; + input_map.reserve(num_node_inputs); + for (size_t i = 0; i < num_node_inputs; i++) { + const OrtValueInfo* value_info = node_inputs[i]; + const char* name = nullptr; + RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &name)); + + input_map.emplace(name, i); + } + + // Gets number of node's outputs + size_t num_node_outputs = 0; + RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(fused_node, &num_node_outputs)); + + std::vector node_outputs(num_node_outputs); + RETURN_IF_ERROR(ort_api.Node_GetOutputs(fused_node, node_outputs.data(), node_outputs.size())); + + // Builds map from output name to its index in output list + std::unordered_map output_map; + output_map.reserve(num_node_outputs); + for (size_t i = 0; i < num_node_outputs; i++) { + const OrtValueInfo* value_info = node_outputs[i]; + const char* name = nullptr; + RETURN_IF_ERROR(ort_api.GetValueInfoName(value_info, &name)); + + output_map.emplace(name, i); + } + + OrtStatus* status; + if (EPContextNodeReader::GraphHasCtxNode(graphs[fused_node_idx], ort_api)) { + RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromPrecompiledEngine(this_ptr, graphs[fused_node_idx], fused_node, + input_map, output_map, + &node_compute_infos_result[fused_node_idx])); + } else { + RETURN_IF_ERROR(ep->CreateNodeComputeInfoFromGraph(this_ptr, graphs[fused_node_idx], fused_node, input_map, + output_map, &node_compute_infos_result[fused_node_idx], + &ep_context_nodes_result[fused_node_idx])); + } + } + + return nullptr; +} + +const char* ORT_API_CALL TensorrtExecutionProvider::GetNameImpl(const OrtEp* this_ptr) noexcept { + const auto* ep = static_cast(this_ptr); + return ep->name_.c_str(); +} + +OrtStatus* ORT_API_CALL TensorrtExecutionProvider::CreateSyncStreamForDeviceImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _Outptr_ OrtSyncStreamImpl** stream) noexcept { + // A per-session OrtSyncStreamImpl can be created here if the session options affect the implementation. + // Logging of any issues should use logger_ which is the session logger. + + TensorrtExecutionProvider* ep = static_cast(this_ptr); + + // we only create streams for the default device memory. + if (auto mem_type = ep->factory_.ep_api.MemoryDevice_GetMemoryType(memory_device); + mem_type != OrtDeviceMemoryType_DEFAULT) { + std::string error = "Invalid OrtMemoryDevice. Expected OrtDeviceMemoryType_DEFAULT(0). 
Got "; + error += std::to_string(mem_type); + return ep->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, error.c_str()); + } + + auto device_id = ep->factory_.ep_api.MemoryDevice_GetDeviceId(memory_device); + + auto sync_stream = std::make_unique(ep->factory_, ep, device_id, nullptr); + *stream = sync_stream.release(); + + return nullptr; +} + +/** + * Refit the weight-stripped engine + */ +OrtStatus* TensorrtExecutionProvider::RefitEngine(std::string onnx_model_filename, + std::string& onnx_model_folder_path, + std::string& weight_stripped_engine_cath_path, + bool path_check, + const void* onnx_model_bytestream, + size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, + nvinfer1::ICudaEngine* trt_engine, + bool serialize_refitted_engine, + bool detailed_build_log) { +#if NV_TENSORRT_MAJOR >= 10 + bool refit_from_file = onnx_model_bytestream == nullptr && onnx_model_bytestream_size == 0; + bool refit_with_external_data = onnx_external_data_bytestream != nullptr && onnx_external_data_bytestream_size != 0; + bool refit_complete = false; + std::filesystem::path onnx_model_path{onnx_model_folder_path}; + if (refit_from_file) { + if (!onnx_model_filename.empty()) { + onnx_model_path.append(onnx_model_filename); + } + if (onnx_model_path.empty()) { + std::string err_msg = "The ONNX model was not provided as path. Please use provide an ONNX bytestream to enable refitting the weightless engine."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } else { + // check if file path to ONNX is legal + if (path_check && IsAbsolutePath(onnx_model_path.string())) { + std::string err_msg = + "For security purpose, the ONNX model path should be set with a relative path, but it is an absolute path: " + onnx_model_path.string(); + "weightless engine."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + if (path_check && IsRelativePathToParentPath(onnx_model_path.string())) { + std::string err_msg = + "The ONNX model path has '..'. For security purpose, it's not allowed to point outside the directory."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + if (!(std::filesystem::exists(onnx_model_path) && std::filesystem::is_regular_file(onnx_model_path))) { + std::string err_msg = "The ONNX model " + onnx_model_path.string() + " does not exist."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } + } + + // weight-stripped engine refit logic + TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log, logger_, &ort_api); + auto refitter = std::unique_ptr(nvinfer1::createInferRefitter(*trt_engine, trt_logger)); + auto parser_refitter = + std::unique_ptr(nvonnxparser::createParserRefitter(*refitter, trt_logger)); + +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 12) || NV_TENSORRT_MAJOR > 10 + // New refit APIs + if (refit_with_external_data) { + // A valid model bytestream must be passed. + if (refit_from_file) { + std::string err_msg = "TensorRT EP's refit with external data must be called with a valid ONNX model bytestream"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + if (!parser_refitter->loadModelProto(onnx_model_bytestream, onnx_model_bytestream_size, nullptr)) { + std::string err_msg = "TensorRT EP's IParserRefitter could not load model from provided onnx_model_bytestream"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + // Extract weight information from the Refitter. 
+ int required_weights = refitter->getAllWeights(0, nullptr); + std::vector refit_names(required_weights); + refitter->getAllWeights(required_weights, refit_names.data()); + std::string message = "[TensorRT EP] Refitter requires " + std::to_string(required_weights) + " weights"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + + // Vectors to keep track of data pointers. + std::vector names; + names.reserve(required_weights); + std::vector bytes; + bytes.reserve(required_weights); + std::vector sizes; + sizes.reserve(required_weights); + + auto onnx_model = std::make_unique(); + ONNX_NAMESPACE::TensorProtos* allInitializers_byte_stream; + + // Reconstruct onnx model view. + const auto onnx_model_view = std::string((const char*)onnx_model_bytestream, + onnx_model_bytestream_size); + if (!onnx_model->ParseFromString(onnx_model_view)) { + std::string err_msg = "The provided ONNX bytestream to refit could not be parsed."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + // Extract graph and initializer information. + auto const& graph = onnx_model->mutable_graph(); + allInitializers_byte_stream = graph->mutable_initializer(); + message = "[TensorRT EP] Initializers that were found " + std::to_string(allInitializers_byte_stream->size()); + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + + // Loop through all initializers + int missing_initializer_data = 0; + for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) { + auto& proto = allInitializers_byte_stream->at(initializer_idx); + auto& proto_name = proto.name(); + bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end(); + if (weight_is_refittable) { + if (proto.has_data_location()) { + if (proto.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { + // Default values for reading into external_data blob. 
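+ // An ONNX initializer stored externally carries key/value entries such as
+ //   location: "weights.bin", offset: "4096", length: "1048576"   (all values are strings),
+ // which is why offset/length are parsed from strings below. "location" is ignored here
+ // because the bytes come from the caller-provided onnx_external_data_bytestream.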
+ int64_t offset = 0; + size_t length = 0; + auto external_data = proto.mutable_external_data(); + const std::string kOffset = "offset", kLength = "length"; + for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) { + auto current_key = external_data->at(entry_idx).mutable_key(); + auto current_value = external_data->at(entry_idx).mutable_value(); + if (*current_key == kOffset && !current_value->empty()) { + offset = std::stoll(*current_value); + } else if (*current_key == kLength && !current_value->empty()) { + length = std::stoul(*current_value); + } + } + names.push_back(proto.name()); + bytes.push_back(static_cast(onnx_external_data_bytestream) + offset); + sizes.push_back(length); + } else { + std::string err_msg = "[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } else if (proto.has_raw_data()) { + auto& raw_data = proto.raw_data(); + names.push_back(proto.name()); + bytes.push_back(raw_data.c_str()); + sizes.push_back(raw_data.size()); + } else { + message = "[TensorRT EP] Proto: " + proto_name + " has no raw nor external data."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + ++missing_initializer_data; + } + } else { + message = "[TensorRT EP] Initializer with name: " + proto_name + " was not marked as refittable"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + if (missing_initializer_data) { + std::string err_msg = "[TensorRT EP] RefitEngine is missing " + std::to_string(missing_initializer_data) + " initializers."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + // Load extracted initializers into the parser + if (!names.empty()) { + message = "[TensorRT EP] Number of initializers submitted to refitter " + std::to_string(names.size()); + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + for (size_t i = 0; i < names.size(); i++) { + bool refloadInit = parser_refitter->loadInitializer(names[i].c_str(), bytes[i], sizes[i]); + if (!refloadInit) { + std::string err_msg = "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestream"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } + } + // Perform refit. + if (!parser_refitter->refitModelProto()) { + std::string err_msg = "TensorRT EP's IParserRefitter refitModelProto() failed with the provided external data bytestream."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + refit_complete = true; + } +#else + // Refitting with external data is not supported prior to TensorRT 10.13. Log a warning in this case for the user. + if (refit_with_external_data) { + message = "[TensorRT EP] Refitting with an onnx_external_data_bytestream is only supported on TensorRT versions >= 10.13! 
This parameter will be ignored for refitting, and the resulting refitted engine may be incorrect."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#endif // (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 12) || NV_TENSORRT_MAJOR > 10 + // If new refit flow was not completed, then fallback to refit_from_file. + if (!refit_complete) { + if (refit_from_file) { + std::string message = "[TensorRT EP] Refitting from file on disk: " + onnx_model_path.string(); + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + std::string err_msg = "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string(); + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } else { + std::string message = "[TensorRT EP] Refitting from byte array"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { + std::string err_msg = "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestream"; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } + } + + if (refitter->refitCudaEngine()) { + std::string message = "[TensorRT EP] Successfully refitted the weight-stripped engine."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } else { + std::string err_msg = + "TensorRT EP's IRefitter could not refit deserialized weight-stripped engine with weights contained in: " + + onnx_model_path.string(); + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + // serialize the refitted engine to disk + if (serialize_refitted_engine) { + std::string refitted_engine_cache = GetWeightRefittedEnginePath(weight_stripped_engine_cath_path); + nvinfer1::IHostMemory* serialized_engine = trt_engine->serialize(); + std::ofstream engine_file(refitted_engine_cache, std::ios::binary | std::ios::out); + engine_file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + std::string message = "[TensorRT EP] Serialize the refitted engine to " + refitted_engine_cache; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + return nullptr; +#else + std::string err_msg = "TensorRT EP's IParserRefitter can only be used on TRT 10.0 onwards."; + return ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); +#endif +} + +TensorrtExecutionProvider::~TensorrtExecutionProvider() { + if (alloc_ != nullptr) { + ort_api.ReleaseAllocator(alloc_); + } +} + +/// +/// +/// Plugin TensorRT EP implementing OrtEp +/// +/// +TensorrtExecutionProvider::TensorrtExecutionProvider(TensorrtExecutionProviderFactory& factory, + const std::string& name, + const OrtSessionOptions& session_options, + const OrtLogger& logger) + : OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized + ApiPtrs{static_cast(factory)}, + factory_(factory), + 
name_{name}, + session_options_{session_options}, + logger_{logger} { + // Implementation of OrtEp interfaces + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. + GetName = GetNameImpl; + GetCapability = GetCapabilityImpl; + Compile = CompileImpl; + ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; + + // Initialize the execution provider. + + Ort::Status ort_status(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + ("Plugin EP has been created with name " + name_).c_str(), + ORT_FILE, __LINE__, __FUNCTION__)); + + // populate apis as global for utility functions + g_ort_api = &ort_api; + g_ep_api = &ep_api; + g_model_editor_api = &model_editor_api; + + // The implementation of the SessionOptionsAppendExecutionProvider C API function automatically adds EP options to + // the session option configurations with the key prefix "ep..". + // We extract those EP options to create a new "provider options" key-value map. + std::string lowercase_ep_name = name_.c_str(); + std::transform(lowercase_ep_name.begin(), lowercase_ep_name.end(), lowercase_ep_name.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + + // The implementation of the SessionOptionsAppendExecutionProvider C API function automatically adds EP options to + // the session option configurations with the key prefix "ep..". + std::string key_prefix = "ep." + lowercase_ep_name + "."; + + // Get all the provider options as session config from sesson + ProviderOptions provider_options; + + // Get the provider options from all the config entries in session option + OrtKeyValuePairs* key_value_pairs = nullptr; + ort_api.GetSessionOptionsConfigEntries(&session_options, &key_value_pairs); + + const char* const* keys = nullptr; + const char* const* values = nullptr; + size_t num_entries = 0; + ort_api.GetKeyValuePairs(key_value_pairs, &keys, &values, &num_entries); + + for (size_t i = 0; i < num_entries; ++i) { + const char* key = keys[i]; + + // only gets ep provider options + if (strncmp(key, key_prefix.c_str(), key_prefix.size()) == 0) { + std::string key_str = key; + const char* value = values[i]; + provider_options[key_str.substr(key_prefix.size())] = value; + } + } + + ort_api.ReleaseKeyValuePairs(key_value_pairs); + + // Provider options to TensorrtExecutionProviderInfo + info_ = TensorrtExecutionProviderInfo::FromProviderOptions(provider_options); + info_.has_trt_options = true; + device_id_ = info_.device_id; + + std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes; + + // incase the EP context is dumped the engine cache has to be enabled + auto enable_engine_cache_for_ep_context_model = [this]() { + if (dump_ep_context_model_ && ep_context_embed_mode_ == 0) { + engine_cache_enable_ = true; + } + }; + + // get provider options + if (info_.has_trt_options) { + max_partition_iterations_ = info_.max_partition_iterations; + min_subgraph_size_ = info_.min_subgraph_size; + max_workspace_size_ = info_.max_workspace_size; + fp16_enable_ = info_.fp16_enable; + int8_enable_ = info_.int8_enable; + if (int8_enable_) { + int8_calibration_cache_name_ = info_.int8_calibration_table_name; + int8_use_native_tensorrt_calibration_table_ = info_.int8_use_native_calibration_table; + } + if (fp16_enable_ || int8_enable_) { // DLA can only be enabled with FP16 or INT8 + dla_enable_ = info_.dla_enable; + dla_core_ = info_.dla_core; + } + dump_subgraphs_ = info_.dump_subgraphs; + 
engine_cache_enable_ = info_.engine_cache_enable; + weight_stripped_engine_enable_ = info_.weight_stripped_engine_enable; + onnx_model_folder_path_ = info_.onnx_model_folder_path; + onnx_model_bytestream_ = info_.onnx_bytestream; + onnx_model_bytestream_size_ = info_.onnx_bytestream_size; + onnx_external_data_bytestream_ = info_.external_data_bytestream; + onnx_external_data_bytestream_size_ = info_.external_data_bytestream_size; + timing_cache_enable_ = info_.timing_cache_enable; + force_timing_cache_match_ = info_.force_timing_cache; + detailed_build_log_ = info_.detailed_build_log; + dump_ep_context_model_ = info_.dump_ep_context_model; + ep_context_file_path_ = info_.ep_context_file_path; + ep_context_embed_mode_ = info_.ep_context_embed_mode; + enable_engine_cache_for_ep_context_model(); + if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { + cache_path_ = info_.engine_cache_path; + cache_prefix_ = info_.engine_cache_prefix; + } + // use a more global cache if given + if (timing_cache_enable_) { + if (!info_.timing_cache_path.empty()) { + global_cache_path_ = info_.timing_cache_path; + } else { + global_cache_path_ = cache_path_; + } + } + engine_decryption_enable_ = info_.engine_decryption_enable; + if (engine_decryption_enable_) { + engine_decryption_lib_path_ = info_.engine_decryption_lib_path; + } + force_sequential_engine_build_ = info_.force_sequential_engine_build; + context_memory_sharing_enable_ = info_.context_memory_sharing_enable; + if (fp16_enable_) { + layer_norm_fp32_fallback_ = info_.layer_norm_fp32_fallback; + } + build_heuristics_enable_ = info_.build_heuristics_enable; + sparsity_enable_ = info_.sparsity_enable; + builder_optimization_level_ = info_.builder_optimization_level; + auxiliary_streams_ = info_.auxiliary_streams; + tactic_sources_ = info_.tactic_sources; + profile_min_shapes = info_.profile_min_shapes; + profile_max_shapes = info_.profile_max_shapes; + profile_opt_shapes = info_.profile_opt_shapes; + cuda_graph_enable_ = info_.cuda_graph_enable; + engine_hw_compatible_ = info_.engine_hw_compatible; + op_types_to_exclude_ = info_.op_types_to_exclude; + } else { + // deprecate env provider option + } + + // Validate setting + if (max_partition_iterations_ <= 0) { + std::string message = "[TensorRT EP] TensorRT option trt_max_partition_iterations must be a positive integer value. Set it to 1000"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + max_partition_iterations_ = 1000; + } + if (min_subgraph_size_ <= 0) { + std::string message = "[TensorRT EP] TensorRT option trt_min_subgraph_size must be a positive integer value. Set it to 1"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + min_subgraph_size_ = 1; + } + if (max_workspace_size_ <= 0) { + std::string message = "[TensorRT EP] TensorRT option trt_max_workspace_size must be a positive integer value. Set it to 1073741824 (1GB)"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + max_workspace_size_ = 1 << 30; + } + if (dla_core_ < 0) { + std::string message = "[TensorRT EP] TensorRT option trt_dla_core must be a non-negative integer value. 
Set it to 0"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + dla_core_ = 0; + } + + // If ep_context_file_path_ is provided as a directory, create it if it's not existed + if (dump_ep_context_model_ && !ep_context_file_path_.empty() && std::filesystem::path(ep_context_file_path_).extension().empty() && !std::filesystem::is_directory(ep_context_file_path_)) { + if (!std::filesystem::create_directory(ep_context_file_path_)) { + throw std::runtime_error("Failed to create directory " + ep_context_file_path_); + } + } + + // If dump_ep_context_model_ is enable, TRT EP forces cache_path_ to be the relative path of ep_context_file_path_. + // For example, + // - original cache path = "engine_cache_dir" -> new cache path = "./context_model_dir/engine_cache_dir" + // - original cache path = "" -> new cache path = "./context_model_dir" + // The new cache path will be saved as the "ep_cache_context" node attritue of the EP context node. + // For security reason, it needs to make sure the engine cache is saved inside context model directory. + if (dump_ep_context_model_ && engine_cache_enable_) { + if (IsAbsolutePath(cache_path_)) { + std::string message = "In the case of dumping context model and for security purpose, the trt_engine_cache_path should be set with a relative path, but it is an absolute path: " + cache_path_; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + if (IsRelativePathToParentPath(cache_path_)) { + std::string message = "In the case of dumping context model and for security purpose, The trt_engine_cache_path has '..', it's not allowed to point outside the directory."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + + // Engine cache relative path to context model directory. + // It's used when dumping the "ep_cache_context" node attribute. + engine_cache_relative_path_to_context_model_dir_ = cache_path_; + + // Make cache_path_ to be the relative path of ep_context_file_path_ + cache_path_ = GetPathOrParentPathOfCtxModel(ep_context_file_path_).append(cache_path_).string(); + } + + // Hardware compatibility: pre-check on environment + if (engine_cache_enable_ && engine_hw_compatible_) { +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + if (std::stoi(compute_capability_) < 80) { + std::string message = "Engine hardware compatibility cannot be enabled as GPU arch < 80. "; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + engine_hw_compatible_ = false; + } else if (std::stoi(compute_capability_) == 87) { + std::string message = "Engine hardware compatibility cannot be enabled on Jetson Orin. "; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + engine_hw_compatible_ = false; + } +#else + std::string message = "Engine hardware compatibility cannot be enabled as TRT < 8.6. 
"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + engine_hw_compatible_ = false; +#endif + } + + if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) { + if (!cache_path_.empty() && !fs::is_directory(cache_path_)) { + if (!fs::create_directory(cache_path_)) { + throw std::runtime_error("Failed to create directory " + cache_path_); + } + } + if (!global_cache_path_.empty() && !fs::is_directory(global_cache_path_)) { + if (!fs::create_directory(global_cache_path_)) { + throw std::runtime_error("Failed to create directory " + global_cache_path_); + } + } + } + + if (engine_decryption_enable_) { + LIBTYPE handle = OPENLIB(engine_decryption_lib_path_.c_str()); + if (handle == nullptr) { + std::string message = "TensorRT EP could not open shared library from " + engine_decryption_lib_path_; + THROW_IF_ERROR(ort_api.CreateStatus(ORT_EP_FAIL, message.c_str())); + } + engine_decryption_ = (int (*)(const char*, char*, size_t*))LIBFUNC(handle, "decrypt"); + engine_encryption_ = (int (*)(const char*, char*, size_t))LIBFUNC(handle, "encrypt"); + if (engine_decryption_ == nullptr) { + std::string message = "TensorRT EP could not find decryption function in shared library from " + engine_decryption_lib_path_; + THROW_IF_ERROR(ort_api.CreateStatus(ORT_EP_FAIL, message.c_str())); + } + } + + if (int8_enable_) { + int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); + } + + /* + * Parse explicit min/max/opt profile shapes from provider options. + * + * The format of min/max/opt profile shapes is defined as below: + * "input1:dim1xdim2...,input2:dim1xdim2...,...,input1:dim3xdim4...,input2:dim3xdim4...,..." + * + * (Note: if multiple shapes with same input name are specified, TRT EP will consider them as multiple profiles. 
+ * Please refer to ParserProfileShapes() for more details) + * + */ + bool status = true; + if (status) { + status = ParseProfileShapes(profile_min_shapes, profile_min_shapes_); + if (!status) { + profile_min_shapes_.clear(); + std::string message = "[TensorRT EP] The format of provider option 'trt_profile_min_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + + if (status) { + status = ParseProfileShapes(profile_max_shapes, profile_max_shapes_); + if (!status) { + profile_max_shapes_.clear(); + std::string message = "[TensorRT EP] The format of provider option 'trt_profile_max_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + + if (status) { + status = ParseProfileShapes(profile_opt_shapes, profile_opt_shapes_); + if (!status) { + profile_opt_shapes_.clear(); + std::string message = "[TensorRT EP] The format of provider option 'trt_profile_opt_shapes' is wrong, please follow the format of 'input1:dim1xdimd2...,input2:dim1xdim2...,...'"; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + + if (status) { + status = ValidateProfileShapes(profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_); + if (!status) { + std::string message = "[TensorRT EP] Profile shapes validation failed. Make sure the provider options 'trt_profile_min_shapes', 'trt_profile_max_shapes' and 'trt_profile_opt_shapes' have same input name and number of profile."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + message = "[TensorRT EP] TRT EP will implicitly create optimization profiles based on input tensor for you."; + Ort::ThrowOnError(ort_api.Logger_LogMessage(&logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + profile_min_shapes_.clear(); + profile_max_shapes_.clear(); + profile_opt_shapes_.clear(); + } + } + + // cuda graph: + // cudaStreamSynchronize() is not allowed in cuda graph capture. + // + // external stream: + // If user provides "external" cuda stream, only this cuda stream will be used even if multiple threads are running InferenceSession.Run() concurrently. + // So, no need to synchronize different streams after enqueueV3. 
+ if (cuda_graph_enable_ || external_stream_) { + sync_stream_after_enqueue_ = false; + } + + { + auto lock = GetApiLock(); + runtime_ = std::unique_ptr(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_, logger_, &ort_api))); + } + + // EP Context setting + if (dump_ep_context_model_) { + extra_attr_keys_.push_back(k_ep_ctx_hardware_architecture.c_str()); + extra_attr_keys_.push_back(k_ep_ctx_onnx_model_filename.c_str()); + + if (engine_cache_enable_ && engine_hw_compatible_) { + extra_attr_values_.push_back(k_cc_hw_compatible.c_str()); + } else { + extra_attr_values_.push_back(compute_capability_.c_str()); + } + extra_attr_values_.push_back(model_path_); + } +} + +void ORT_API_CALL TensorrtExecutionProvider::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, + size_t num_node_compute_infos) noexcept { + (void)this_ptr; + for (size_t i = 0; i < num_node_compute_infos; i++) { + delete node_compute_infos[i]; + } +} + +// +// Implementation of TRTEpNodeComputeInfo +// +TRTEpNodeComputeInfo::TRTEpNodeComputeInfo(TensorrtExecutionProvider& ep) : ep(ep) { + ort_version_supported = ORT_API_VERSION; + CreateState = CreateStateImpl; + Compute = ComputeImpl; + ReleaseState = ReleaseStateImpl; +} + +OrtStatus* TRTEpNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context, + void** compute_state) { + auto* node_compute_info = static_cast(this_ptr); + TensorrtExecutionProvider& ep = node_compute_info->ep; + + std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context); + auto state_it = ep.compute_states_.find(fused_node_name); + if (state_it == ep.compute_states_.end()) { + std::string message = "Unable to TensorRT EP's compute state for fused node with name " + fused_node_name; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); + } + + TensorrtComputeState& trt_ep_compute_state = *state_it->second; + *compute_state = &trt_ep_compute_state; + return nullptr; +} + +OrtStatus* TRTEpNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context) { + auto* node_compute_info = static_cast(this_ptr); + TensorrtExecutionProvider& ep = node_compute_info->ep; + + TensorrtComputeState* trt_state = reinterpret_cast(compute_state); + Ort::KernelContext ctx(kernel_context); + + // The whole compute_function should be considered the critical section where multiple threads may update kernel + // function state, access one builder, create/serialize/save engine, save profile and serialize/save timing cache. + // Therefore, those operations should be synchronized across different threads when ORT is using multithreading. + // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + const std::unordered_map& input_indexes = (trt_state->input_info)[0]; + const std::unordered_map& output_indexes = (trt_state->output_info)[0]; + const std::unordered_map& output_types = (trt_state->output_info)[1]; + auto fused_node_name = trt_state->fused_node_name; + // This map "shape_ranges" contains the shape range info for setting TRT optimization profiles. 
+ // The info is used for both shape tensor and execution tensor: + // tensor name->(dimension->[min, max, opt]) + auto& shape_ranges = trt_state->input_shape_ranges; + std::unordered_map> + shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this + // inference run + std::unordered_map> + shape_tensor_values_int64; // same as above but for int64 shape tensor input + + uint16_t device_id = trt_state->device_id; + auto max_workspace_size = trt_state->max_workspace_size; + auto trt_builder = trt_state->builder; + auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); + auto trt_profiles = trt_state->profiles; + auto context_memory = trt_state->context_memory; + auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; + auto cache_prefix = trt_state->cache_prefix; + auto compute_capability = trt_state->compute_capability; + auto engine_cache_enable = trt_state->engine_cache_enable; + auto engine_hw_compatible = trt_state->engine_hw_compatible; + auto timing_cache_enable = trt_state->timing_cache_enable; + auto force_timing_cache_match = trt_state->force_timing_cache; + auto global_cache_path = trt_state->timing_cache_path; + auto detailed_build_log = trt_state->detailed_build_log; + + auto weight_stripped_engine_enable = trt_state->weight_stripped_engine_enable; + auto weight_stripped_engine_refit = trt_state->weight_stripped_engine_refit; + auto model_path = trt_state->model_path; + auto onnx_model_folder_path = trt_state->onnx_model_folder_path; + auto onnx_model_bytestream = trt_state->onnx_model_bytestream; + auto onnx_model_bytestream_size = trt_state->onnx_model_bytestream_size; + auto onnx_external_data_bytestream = trt_state->onnx_external_data_bytestream; + auto onnx_external_data_bytestream_size = trt_state->onnx_external_data_bytestream_size; + + auto sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; + + int num_inputs = static_cast(input_indexes.size()); + int num_outputs = static_cast(output_indexes.size()); + bool engine_update = false; + bool context_update = false; + std::unordered_set input_names; + + std::unordered_map& dds_output_allocator_maps = ep.GetDDSOutputAllocators(); + auto& dds_output_allocator_map = dds_output_allocator_maps[fused_node_name]; + + // Get default OrtMemoryInfo from factory + const OrtMemoryInfo* mem_info = nullptr; + if (ep.factory_.cuda_gpu_memory_infos.find(device_id) != + ep.factory_.cuda_gpu_memory_infos.end()) { + mem_info = ep.factory_.cuda_gpu_memory_infos[device_id].get(); + } + + // Get allocator from OrtKernelContext + if (ep.alloc_ == nullptr) { + Ort::ThrowOnError(ep.ort_api.KernelContext_GetAllocator(kernel_context, mem_info, &ep.alloc_)); + } + OrtAllocator* alloc = ep.alloc_; + + void* cuda_stream; + Ort::ThrowOnError(ep.ort_api.KernelContext_GetGPUComputeStream(kernel_context, &cuda_stream)); + cudaStream_t stream = static_cast(cuda_stream); + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even + // if they share the same compute capacity Prepare cache name + std::string cache_path = ""; + // Customize cache prefix if assigned + if (!cache_prefix.empty()) { + cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->cache_prefix) + trt_state->cache_suffix; + } else { + cache_path = GetCachePath(trt_state->engine_cache_path, 
trt_state->trt_node_name_with_precision); + } + + // Enable hardware compatility mode if assigned + std::string cache_hw_compat = "_sm" + compute_capability; +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + if (engine_cache_enable && engine_hw_compatible) { + cache_hw_compat = "_sm80+"; + std::string message = "[TensorRT EP] Hardware compatibility is enabled when loading and capturing engine cache."; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#endif + + // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache + // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even + // if they share the same compute capacity + const std::string cache_path_prefix = cache_path + cache_hw_compat; + std::string engine_cache_path = cache_path_prefix + ".engine"; + const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted"; + const std::string profile_cache_path = cache_path_prefix + ".profile"; + std::string timing_cache_path = ""; + if (timing_cache_enable) { + timing_cache_path = GetTimingCachePath(global_cache_path, compute_capability); + } + + // If weight-stripped engine is enabled and refitted engine cache is not present, + // TRT EP will use the engine cache with ".stripped.engine" appended to the end. + const std::filesystem::path engine_cache_fs_path = engine_cache_path; + if (weight_stripped_engine_enable && !std::filesystem::exists(engine_cache_fs_path)) { + engine_cache_path = cache_path_prefix + ".stripped.engine"; + weight_stripped_engine_refit = true; + } + + // Load serialized engine + if (trt_state->engine_cache_enable && trt_engine == nullptr) { + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); + if (engine_file && !trt_state->engine_decryption_enable && profile_file) { + // Deserialize profile + shape_ranges = DeserializeProfileV2(profile_file); + std::string message = "[TensorRT EP] DeSerialized " + profile_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + + // Prepare buffer + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); + + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); + if (!(*(trt_state->engine))) { + std::string err_msg = "TensorRT EP Failed to Build Engine."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + message = "[TensorRT EP] DeSerialized " + engine_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + trt_engine = trt_state->engine->get(); + context_update = true; + + } else if (trt_state->engine_decryption_enable && 
std::filesystem::exists(encrypted_engine_cache_path) && + profile_file) { + shape_ranges = DeserializeProfileV2(profile_file); + std::string message = "[TensorRT EP] DeSerialized " + profile_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + // Decrypt engine + size_t engine_size = 0; + if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) { + std::string err_msg = "TensorRT EP could not get engine buffer size"; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + std::unique_ptr engine_buf{new char[engine_size]}; + if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + std::string err_msg = "TensorRT EP could not call engine decryption function decrypt"; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size)); + if (!(*(trt_state->engine))) { + std::string err_msg = "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + message = "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + trt_engine = trt_state->engine->get(); + context_update = true; + } + } + + // Check and update shape ranges for dynamic shape inputs. + for (int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_state->network->get()->getInput(i); + const std::string& input_name = input->getName(); + input_names.insert(input_name); + + // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile + // shape values have not yet resolved. TRT EP will help determine the min/max/opt profile values based on current + // input tensor value. + if (shape_ranges.find(input_name) != shape_ranges.end()) { + auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, + shape_tensor_values, shape_tensor_values_int64, stream, + &engine_update); + if (status != nullptr) { + std::string err_msg = "TensorRT EP failed to parse input tensor and generate optimization profiles."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } + } + + // Regenerate engine + if (engine_update) { + // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined + // behavior. 
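The cache-hit branches above reduce to reading the serialized engine bytes from disk and handing them to the shared nvinfer1::IRuntime. A condensed sketch of that step, using an illustrative helper name and omitting the profile deserialization and logging done above:

    #include <fstream>
    #include <memory>
    #include <string>
    #include <vector>
    #include <NvInfer.h>

    // Illustrative helper: load a serialized engine from the cache file and deserialize it.
    // Returns nullptr on a cache miss so the caller can fall back to building the engine.
    std::unique_ptr<nvinfer1::ICudaEngine> LoadEngineFromCache(nvinfer1::IRuntime& runtime,
                                                               const std::string& engine_cache_path) {
      std::ifstream engine_file(engine_cache_path, std::ios::binary);
      if (!engine_file) return nullptr;

      engine_file.seekg(0, std::ios::end);
      const size_t engine_size = static_cast<size_t>(engine_file.tellg());
      engine_file.seekg(0, std::ios::beg);

      std::vector<char> engine_buf(engine_size);
      engine_file.read(engine_buf.data(), static_cast<std::streamsize>(engine_size));

      // Deserializing through a single IRuntime is thread safe per the TensorRT developer guide.
      return std::unique_ptr<nvinfer1::ICudaEngine>(
          runtime.deserializeCudaEngine(engine_buf.data(), engine_size));
    }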
+ trt_state->context->reset(); + trt_state->engine->reset(); + auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); + if (max_workspace_size > 0) { + trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size); + } + for (auto trt_profile : trt_profiles) { + trt_config->addOptimizationProfile(trt_profile); + } +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + // Set INT8 Per Tensor Dynamic range + if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { + trt_config->setInt8Calibrator(nullptr); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { + std::string err_msg = "TensorRT EP failed to set INT8 dynamic range."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + } +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + // Set precision + if (trt_state->int8_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + std::string message = "[TensorRT EP] INT8 mode is enabled"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + if (trt_state->fp16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + std::string message = "[TensorRT EP] FP16 mode is enabled"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + // Set DLA (DLA can only run with FP16 or INT8) + if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { + std::string message = "[TensorRT EP] use DLA core " + trt_state->dla_core; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(trt_state->dla_core); + } + + // enable sparse weights + if (trt_state->sparsity_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); + std::string message = "[TensorRT EP] Sparse weights are allowed"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5 + // enable builder heuristics + if (trt_state->build_heuristics_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); + std::string message = "[TensorRT EP] Builder heuristics are enabled"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // switch optimizaion level + if (trt_state->builder_optimization_level != 3) { + trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); + std::string message = "[TensorRT EP] Builder optimization level is set to " + trt_state->builder_optimization_level; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + 
message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + + // limit auxiliary streams + if (trt_state->auxiliary_streams >= 0) { + trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); + std::string message = "[TensorRT EP] Auxiliary streams are se to " + trt_state->auxiliary_streams; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#else + if (trt_state->builder_optimization_level != 3) { + std::string message = "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + if (trt_state->auxiliary_streams >= 0) { + std::string message = "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#endif + if (weight_stripped_engine_enable) { +#if NV_TENSORRT_MAJOR >= 10 + trt_config->setFlag(nvinfer1::BuilderFlag::kSTRIP_PLAN); + std::string message = "[TensorRT EP] STRIP_PLAN is enabled"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + trt_config->setFlag(nvinfer1::BuilderFlag::kREFIT_IDENTICAL); + message = "[TensorRT EP] REFIT_IDENTICAL is enabled"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); +#else + std::string message = "[TensorRT EP] weight-stripped engines can only be used on TRT 10.0 onwards!"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); +#endif + } + // limit used tactic sources + if (trt_state->filter_tactic_sources) { + nvinfer1::TacticSources tactics = trt_config->getTacticSources(); + tactics |= trt_state->tactic_sources; + trt_config->setTacticSources(tactics); + std::string message = "[TensorRT EP] Tactic sources are limited using bitmask " + tactics; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + + // Load timing cache from file. 
Create a fresh cache if the file doesn't exist + std::unique_ptr timing_cache = nullptr; + if (trt_state->timing_cache_enable) { + std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); + timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), + loaded_timing_cache.size())); + if (timing_cache == nullptr) { + std::string err_msg = "TensorRT EP could not create timing cache: " + timing_cache_path; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + trt_config->setTimingCache(*timing_cache, force_timing_cache_match); + if (detailed_build_log) { + std::string message = "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + +#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8 + // Enable hardware compatility mode if assigned + if (trt_state->engine_hw_compatible) { + trt_config->setHardwareCompatibilityLevel(nvinfer1::HardwareCompatibilityLevel::kAMPERE_PLUS); + std::string message = "[TensorRT EP] Re-generate engine with hardware compatibility enabled."; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } +#endif + + // Build engine + std::unique_ptr serialized_engine; + { + auto lock = ep.GetApiLock(); + std::chrono::steady_clock::time_point engine_build_start; + if (detailed_build_log) { + engine_build_start = std::chrono::steady_clock::now(); + } + serialized_engine = std::unique_ptr( + trt_builder->buildSerializedNetwork(*trt_state->network->get(), *trt_config)); + if (!serialized_engine) { + std::string err_msg = "TensorRT EP failed to create engine from network."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); + if (!(*(trt_state->engine))) { + std::string err_msg = "TensorRT EP failed to deserialize engine."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + if (detailed_build_log) { + auto engine_build_stop = std::chrono::steady_clock::now(); + std::string message = "TensorRT engine build for " + trt_state->trt_node_name_with_precision + " took: " + std::to_string(std::chrono::duration_cast(engine_build_stop - engine_build_start).count()) + "ms"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + if (!(*(trt_state->engine))) { + std::string err_msg = "TensorRT EP Failed to Build Engine."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + trt_engine = trt_state->engine->get(); + if (trt_state->engine_cache_enable) { + // Serialize engine profile + SerializeProfileV2(profile_cache_path, shape_ranges); + std::string message = "[TensorRT EP] Serialized " + profile_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + + // Serialize engine + if (trt_state->engine_decryption_enable) { + // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available + // first. 
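The timing-cache handling above follows the standard TensorRT pattern: read any previously saved cache bytes (an empty buffer yields a fresh cache), create an ITimingCache from the builder config, and attach it so tactic timings are reused across builds. A sketch under those assumptions, with a plain file read standing in for the EP's loadTimingCacheFile helper:

    #include <fstream>
    #include <iterator>
    #include <memory>
    #include <string>
    #include <vector>
    #include <NvInfer.h>

    // Attach a (possibly empty) timing cache to the builder config. The returned cache must
    // stay alive until the engine build finishes so it can be serialized back to disk.
    std::unique_ptr<nvinfer1::ITimingCache> AttachTimingCache(nvinfer1::IBuilderConfig& config,
                                                              const std::string& timing_cache_path,
                                                              bool force_match) {
      std::vector<char> blob;
      std::ifstream in(timing_cache_path, std::ios::binary);
      if (in) {
        blob.assign(std::istreambuf_iterator<char>(in), std::istreambuf_iterator<char>());
      }
      std::unique_ptr<nvinfer1::ITimingCache> cache(
          config.createTimingCache(blob.data(), blob.size()));  // empty blob -> fresh cache
      if (cache) {
        // The EP forwards its force_timing_cache option as TensorRT's ignoreMismatch flag here.
        config.setTimingCache(*cache, force_match);
      }
      return cache;
    }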
+ if (trt_state->engine_encryption != nullptr) { + if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), + reinterpret_cast(serialized_engine->data()), + serialized_engine->size())) { + std::string err_msg = "TensorRT EP could not call engine encryption function encrypt"; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + std::string message = "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } else { + std::string message = "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk"; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } else { + std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); + file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size()); + std::string message = "[TensorRT EP] Serialized " + engine_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + + // serialize and save timing cache + if (trt_state->timing_cache_enable) { + auto timing_cache = trt_config->getTimingCache(); + std::unique_ptr timingCacheHostData{timing_cache->serialize()}; + if (timingCacheHostData == nullptr) { + std::string err_msg = "TensorRT EP could not serialize timing cache: " + timing_cache_path; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); + if (detailed_build_log) { + std::string message = "[TensorRT EP] Serialized timing cache " + timing_cache_path; + Ort::ThrowOnError(ep.ort_api.Logger_LogMessage(&ep.logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + + // TODO: In current ORT's EPContext design, there is no way TRT EP can update the engine cache binary in EPContext node with the rebuilt engine. + // The hacky way is to directly modify the EPContext model that graph_partitioner generates in session initialization. 
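In the unencrypted branch above, persisting the new engine is simply streaming the IHostMemory blob returned by buildSerializedNetwork into the cache file. A minimal sketch with error handling reduced to a boolean result:

    #include <fstream>
    #include <string>
    #include <NvInfer.h>

    // Write the serialized engine produced by IBuilder::buildSerializedNetwork to the cache path.
    bool WriteEngineCache(const nvinfer1::IHostMemory& serialized_engine,
                          const std::string& engine_cache_path) {
      std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
      if (!file) {
        return false;  // caller can log a warning and continue without an on-disk cache
      }
      file.write(static_cast<const char*>(serialized_engine.data()),
                 static_cast<std::streamsize>(serialized_engine.size()));
      return file.good();
    }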
+ + context_update = true; + + if (weight_stripped_engine_refit) { + auto status = + ep.RefitEngine(model_path, + onnx_model_folder_path, + engine_cache_path, + false /* path check for security */, + onnx_model_bytestream, + onnx_model_bytestream_size, + onnx_external_data_bytestream, + onnx_external_data_bytestream_size, + trt_engine, + true /* serialize refitted engine to disk */, detailed_build_log); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "RefitEngine failed."); + } + } + } + + if (context_update) { + if (trt_state->context_memory_sharing_enable) { +#if NV_TENSORRT_MAJOR < 10 + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); +#else + *(trt_state->context) = + std::unique_ptr(trt_state->engine->get()->createExecutionContext( + nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); +#endif + } else { + *(trt_state->context) = + std::unique_ptr(trt_state->engine->get()->createExecutionContext()); + } + if (!(*(trt_state->context))) { + std::string err_msg = "TensorRT EP failed to create context."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + trt_context = trt_state->context->get(); + } + + // Check before using trt_engine + if (trt_engine == nullptr) { + std::string err_msg = "No engine is found."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + // Get input and output binding names + int total_bindings = trt_engine->getNbIOTensors(); + std::vector input_binding_names, output_binding_names; + for (int i = 0, end = total_bindings; i < end; ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + input_binding_names.push_back(name); + } else { + output_binding_names.push_back(name); + } + } + + /* + * Set input shapes and bind input buffers + */ + std::vector> scratch_buffers; + for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { + char const* input_name = input_binding_names[i]; + + size_t input_index = 0; + const auto iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + + auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, + shape_tensor_values_int64, scratch_buffers, alloc, stream); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "BindContextInput failed."); + } + } + + /* + * Set output shapes and bind output buffers + */ + std::unordered_map buffers; + buffers.reserve(num_outputs); + using OutputOrtValue = Ort::UnownedValue; + std::unordered_map output_tensors; + output_tensors.reserve(num_outputs); + std::unordered_map output_dim_sizes; + output_dim_sizes.reserve(num_outputs); + + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + + size_t output_type = 0; + const auto type_iter = output_types.find(output_name); + if (type_iter != output_types.end()) { + output_type = type_iter->second; + } + + auto status = BindContextOutput(ctx, trt_context, output_name, 
output_index, output_type, i, output_tensors, + output_dim_sizes, dds_output_allocator_map, scratch_buffers, alloc, buffers); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "BindContextOutput failed."); + } + } + + // Set execution context memory + if (trt_state->context_memory_sharing_enable) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + size_t mem_size = trt_engine->getDeviceMemorySize(); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (mem_size > *max_context_mem_size_ptr) { + *max_context_mem_size_ptr = mem_size; + *context_memory = MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr, true); + } + trt_context->setDeviceMemory((*context_memory).get()); + } + + // TODO: Add support for CUDA graph for plugin ep. + /* + // Start CUDA graph capture. + // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because + // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. + if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { + // LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + cuda_graph_.SetStream(stream); + CaptureBegin(0); + } + */ + + // Run TRT inference + if (!trt_context->enqueueV3(stream)) { + std::string err_msg = "TensorRT EP execution context enqueue failed."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + /* + * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this + * function concurrently, TRT EP needs to carefully take care of concurrency here, if not, following concurrent + * issue might happen: + * + * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per + * stream. In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling + * InferenceSession::Run() concurrently, the trt execution context instance is shared by all the threads and each + * thread aquires different stream from ORT. So TRT EP will end up having one trt execution context using multiple + * streams which is not suggested. But, since the whole compute_func() is protected by the lock and if + * cudaStreamSynchronize() is enforced here, one trt execution context per stream is guaranteed. + * + * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all + * operations to prevent the concurrent issue mentioned above. However, if cuda graph is enabled, TRT EP won't call + * cudaStreamSynchronize() since it's not allowed during graph capture. + */ + if (sync_stream_after_enqueue) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + } + + // Assign TRT output back to ORT output + // (1) Bind TRT DDS output to ORT kernel context output. 
(It needs to wait until enqueueV3 is finished) + // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_type = 0; + const auto& iter = output_types.find(output_name); + if (iter != output_types.end()) { + output_type = iter->second; + } + + if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + auto status = BindKernelOutput(ctx, mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "BindKernelOutput failed."); + } + } else { + auto& output_tensor = output_tensors[i]; +#if NV_TENSORRT_MAJOR < 10 + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, + output_dim_sizes[i]); + } + } +#endif + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, + output_dim_sizes[i]); + } + } + } + } + + // TODO: Add support for CUDA graph for plugin ep. + /* + // End CUDA graph capture. + // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream + // mentioned in graph capture above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and + // ExecuteGraph() per inference_session.cc. It's safe to start/end CUDA graph capture in compute_func() here since + // cuda graph object is maintained by a per thread basis. + if (cuda_graph_enable_ && !IsGraphCaptured(0)) { + if (IsGraphCaptureAllowed()) { + CaptureEnd(0); + // CUDA work issued to a capturing stream doesn't actually run on the GPU, + // so run the captured graph here to actually execute the work. + ORT_RETURN_IF_ERROR(ReplayGraph(0)); + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + */ + + return nullptr; +} + +void TRTEpNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { + (void)this_ptr; + TensorrtComputeState& trt_ep_compute_state = *reinterpret_cast(compute_state); + (void)trt_ep_compute_state; + // Do nothing for here. 
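For outputs the model declares as INT64 (with TensorRT older than 10) or double, the loop above binds INT32/float TensorRT buffers and then widens the results on the GPU with the EP's cast kernel. The host-side loop below only illustrates the value-preserving widening involved, not how the EP actually performs it:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Host-side illustration only: widen TensorRT INT32 output values into the INT64 buffer
    // that ORT exposes to the caller. The EP does the equivalent in a CUDA kernel because the
    // data lives in device memory.
    std::vector<int64_t> WidenInt32ToInt64(const std::vector<int32_t>& trt_output) {
      std::vector<int64_t> ort_output(trt_output.size());
      for (size_t i = 0; i < trt_output.size(); ++i) {
        ort_output[i] = static_cast<int64_t>(trt_output[i]);
      }
      return ort_output;
    }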
+} + +// +// Implementation of TRTEpEpContextNodeComputeInfo +// +TRTEpEpContextNodeComputeInfo::TRTEpEpContextNodeComputeInfo(TensorrtExecutionProvider& ep) : ep(ep) { + ort_version_supported = ORT_API_VERSION; + CreateState = CreateStateImpl; + Compute = ComputeImpl; + ReleaseState = ReleaseStateImpl; +} + +OrtStatus* TRTEpEpContextNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context, + void** compute_state) { + auto* node_compute_info = static_cast(this_ptr); + TensorrtExecutionProvider& ep = node_compute_info->ep; + + std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context); + auto state_it = ep.compute_states_for_ep_context_.find(fused_node_name); + if (state_it == ep.compute_states_for_ep_context_.end()) { + std::string message = "Unable to TensorRT EP's compute state for fused node with name " + fused_node_name; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); + } + + TensorrtComputeStateForEPContext& trt_ep_compute_state = *state_it->second; + *compute_state = &trt_ep_compute_state; + return nullptr; +} + +OrtStatus* TRTEpEpContextNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context) { + auto* node_compute_info = static_cast(this_ptr); + TensorrtExecutionProvider& ep = node_compute_info->ep; + + TensorrtComputeStateForEPContext* trt_state = reinterpret_cast(compute_state); + Ort::KernelContext ctx(kernel_context); + + // The whole compute_function should be considered the critical section. + // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); + + const std::unordered_map& input_indexes = (trt_state->input_info)[0]; + const std::unordered_map& output_indexes = (trt_state->output_info)[0]; + const std::unordered_map& output_types = (trt_state->output_info)[1]; + uint16_t device_id = trt_state->device_id; + auto fused_node_name = trt_state->fused_node_name; + std::unordered_map& dds_output_allocator_maps = ep.GetDDSOutputAllocators(); + auto& dds_output_allocator_map = dds_output_allocator_maps[fused_node_name]; + auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); + auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; + auto context_memory = trt_state->context_memory; + auto sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; + int num_outputs = static_cast(output_indexes.size()); + std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run + std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input + + // Get default OrtMemoryInfo from factory + const OrtMemoryInfo* mem_info = nullptr; + if (ep.factory_.cuda_gpu_memory_infos.find(device_id) != + ep.factory_.cuda_gpu_memory_infos.end()) { + mem_info = ep.factory_.cuda_gpu_memory_infos[device_id].get(); + } + + // Get allocator from OrtKernelContext + if (ep.alloc_ == nullptr) { + Ort::ThrowOnError(ep.ort_api.KernelContext_GetAllocator(kernel_context, mem_info, &ep.alloc_)); + } + OrtAllocator* alloc = ep.alloc_; + + void* cuda_stream; + Ort::ThrowOnError(ep.ort_api.KernelContext_GetGPUComputeStream(kernel_context, &cuda_stream)); + cudaStream_t stream = static_cast(cuda_stream); + + // cudaStream_t stream; + cudaStreamCreate(&stream); + + // Check before using trt_engine + if (trt_engine 
== nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "No engine is found."); + } + + // Get input and output binding names + int total_bindings = trt_engine->getNbIOTensors(); + std::vector input_binding_names, output_binding_names; + for (int i = 0, end = total_bindings; i < end; ++i) { + auto const& name = trt_engine->getIOTensorName(i); + auto const& mode = trt_engine->getTensorIOMode(name); + if (mode == nvinfer1::TensorIOMode::kINPUT) { + input_binding_names.push_back(name); + } else { + output_binding_names.push_back(name); + } + } + + /* + * Set input shapes and bind input buffers + */ + std::vector> scratch_buffers; + for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) { + char const* input_name = input_binding_names[i]; + + size_t input_index = 0; + const auto iter = input_indexes.find(input_name); + if (iter != input_indexes.end()) { + input_index = iter->second; + } + auto input_tensor = ctx.GetInput(input_index); + auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); + const auto tensor_shapes = tensor_info.GetShape(); + + auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_tensor_values, + shape_tensor_values_int64, scratch_buffers, alloc, stream); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "BindContextInput failed."); + } + } + + /* + * Set output shapes and bind output buffers + */ + std::unordered_map buffers; + buffers.reserve(num_outputs); + using OutputOrtValue = Ort::UnownedValue; + std::unordered_map output_tensors; + output_tensors.reserve(num_outputs); + std::unordered_map output_dim_sizes; + output_dim_sizes.reserve(num_outputs); + + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + + size_t output_type = 0; + const auto type_iter = output_types.find(output_name); + if (type_iter != output_types.end()) { + output_type = type_iter->second; + } + + auto status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, + output_dim_sizes, dds_output_allocator_map, scratch_buffers, alloc, buffers); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "BindContextOutput failed."); + } + } + + // Set execution context memory + if (trt_state->context_memory_sharing_enable) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + size_t mem_size = trt_engine->getDeviceMemorySize(); +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + if (mem_size > *max_context_mem_size_ptr) { + *max_context_mem_size_ptr = mem_size; + *context_memory = MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr, true); + } + trt_context->setDeviceMemory((*context_memory).get()); + } + + // TODO: Add support for CUDA graph for plugin ep. + /* + // Start CUDA graph capture. + // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because + // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. 
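Both compute paths grow the shared execution-context memory through MakeUniquePtrFromOrtAllocator, the EP's helper that ties an OrtAllocator allocation to a smart pointer so the buffer is freed through the same allocator. A generic sketch of that idea with illustrative names (not the EP's actual helper):

    #include <cstddef>
    #include <memory>
    #include "onnxruntime_c_api.h"

    // Free the buffer through the OrtAllocator that produced it.
    struct OrtAllocatorDeleter {
      OrtAllocator* allocator = nullptr;
      void operator()(void* p) const {
        if (allocator != nullptr && p != nullptr) allocator->Free(allocator, p);
      }
    };

    using AllocatorUniquePtrSketch = std::unique_ptr<void, OrtAllocatorDeleter>;

    // Allocate `bytes` from the given allocator (device memory when it is the CUDA allocator)
    // and return it wrapped so the release path cannot be forgotten.
    AllocatorUniquePtrSketch AllocateFromOrtAllocator(OrtAllocator* allocator, size_t bytes) {
      void* p = allocator->Alloc(allocator, bytes);
      return AllocatorUniquePtrSketch(p, OrtAllocatorDeleter{allocator});
    }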
+ if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { + // LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; + cuda_graph_.SetStream(stream); + CaptureBegin(0); + } + */ + + // Run TRT inference + if (!trt_context->enqueueV3(stream)) { + std::string err_msg = "TensorRT EP execution context enqueue failed."; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, err_msg.c_str()); + } + + /* + * Given that InferenceSession::Run() is guaranteed to be thread-safe meaning multiple threads can call this + * function concurrently, TRT EP needs to carefully take care of concurrency here, if not, following concurrent + * issue might happen: + * + * It's suggested that to perform inference concurrently in multiple streams, use one trt execution context per + * stream. In the design of TRT EP (Not apply per-thread context implementation) and if multiple threads are calling + * InferenceSession::Run() concurrently, the trt execution context instance is shared by all the threads and each + * thread aquires different stream from ORT. So TRT EP will end up having one trt execution context using multiple + * streams which is not suggested. But, since the whole compute_func() is protected by the lock and if + * cudaStreamSynchronize() is enforced here, one trt execution context per stream is guaranteed. + * + * Therefore, TRT EP needs to call cudaStreamSynchronize() which means to wait until stream has completed all + * operations to prevent the concurrent issue mentioned above. However, if cuda graph is enabled, TRT EP won't call + * cudaStreamSynchronize() since it's not allowed during graph capture. + */ + if (sync_stream_after_enqueue) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + } + + // Assign TRT output back to ORT output + // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished) + // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output + for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { + char const* output_name = output_binding_names[i]; + + size_t output_type = 0; + const auto& iter = output_types.find(output_name); + if (iter != output_types.end()) { + output_type = iter->second; + } + + if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { + size_t output_index = 0; + const auto& index_iter = output_indexes.find(output_name); + if (index_iter != output_indexes.end()) { + output_index = index_iter->second; + } + auto status = BindKernelOutput(ctx, mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream); + if (status != nullptr) { + return ep.ort_api.CreateStatus(ORT_EP_FAIL, "BindKernelOutput failed."); + } + } else { + auto& output_tensor = output_tensors[i]; +#if NV_TENSORRT_MAJOR < 10 + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, + output_dim_sizes[i]); + } + } +#endif + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); + if (output_tensor_ptr != nullptr) { + cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, + output_dim_sizes[i]); + } + } + } + } + + // TODO: Add support for CUDA graph for plugin ep. + /* + // End CUDA graph capture. 
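Both compute paths reduce to the same two calls shown above: enqueueV3 on the stream ORT hands to the kernel context, then an optional cudaStreamSynchronize so a single shared execution context is never driven by several in-flight streams. A condensed sketch of that pattern, assuming all I/O tensor addresses were already bound to the context (as BindContextInput/BindContextOutput do above):

    #include <cuda_runtime_api.h>
    #include <NvInfer.h>

    // Launch the engine on the given stream and, if requested, block until the work completes.
    // Returns false on failure; the EP converts such failures into an OrtStatus.
    bool RunInference(nvinfer1::IExecutionContext& context, cudaStream_t stream,
                      bool sync_stream_after_enqueue) {
      if (!context.enqueueV3(stream)) {
        return false;
      }
      if (sync_stream_after_enqueue && cudaStreamSynchronize(stream) != cudaSuccess) {
        return false;
      }
      return true;
    }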
+ // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream + // mentioned in graph capture above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and + // ExecuteGraph() per inference_session.cc. It's safe to start/end CUDA graph capture in compute_func() here since + // cuda graph object is maintained by a per thread basis. + if (cuda_graph_enable_ && !IsGraphCaptured(0)) { + if (IsGraphCaptureAllowed()) { + CaptureEnd(0); + // CUDA work issued to a capturing stream doesn't actually run on the GPU, + // so run the captured graph here to actually execute the work. + ORT_RETURN_IF_ERROR(ReplayGraph(0)); + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + */ + + return nullptr; +} + +void TRTEpEpContextNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { + (void)this_ptr; + TensorrtComputeStateForEPContext& trt_ep_compute_state = *reinterpret_cast(compute_state); + (void)trt_ep_compute_state; + // Do nothing for here. +} +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.def b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.def new file mode 100644 index 000000000..ae83cb71f --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.def @@ -0,0 +1,5 @@ +LIBRARY "TensorRTEp.dll" +EXPORTS + CreateEpFactories @1 + ReleaseEpFactory @2 + \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h new file mode 100644 index 000000000..953b2b051 --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.h @@ -0,0 +1,416 @@ +#pragma once + +#include "tensorrt_provider_factory.h" +#include "utils/provider_options.h" +#include "tensorrt_execution_provider_info.h" +#include "nv_includes.h" + +#include +#include +#include +#include +#include + +#ifdef _WIN32 +#define EXPORT_API __declspec(dllexport) +#else +#define EXPORT_API +#endif + +using HashValue = uint64_t; +using AllocateFunc = void* (*)(void*, size_t, size_t); +using DestroyFunc = void (*)(void*, void*); + +namespace trt_ep { + +class TensorrtLogger : public nvinfer1::ILogger { + nvinfer1::ILogger::Severity verbosity_; + const OrtLogger& ort_default_logger_; + const OrtApi* ort_api_ = nullptr; + + public: + TensorrtLogger(const OrtLogger& ort_default_logger, + const OrtApi* ort_api, + Severity verbosity = Severity::kWARNING) + : ort_default_logger_{ort_default_logger}, ort_api_{ort_api}, verbosity_(verbosity) {} + void log(Severity severity, const char* msg) noexcept override { + if (severity <= verbosity_) { + time_t rawtime = std::time(0); + struct tm stm; +#ifdef _MSC_VER + gmtime_s(&stm, &rawtime); +#else + gmtime_r(&rawtime, &stm); +#endif + char buf[256]; + strftime(&buf[0], 256, + "%Y-%m-%d %H:%M:%S", + &stm); + const char* sevstr = (severity == Severity::kINTERNAL_ERROR ? " BUG" : severity == Severity::kERROR ? " ERROR" + : severity == Severity::kWARNING ? "WARNING" + : severity == Severity::kINFO ? 
" INFO" + : "UNKNOWN"); + OrtLoggingLevel ort_severity; + if (severity <= Severity::kERROR) { + ort_severity = ORT_LOGGING_LEVEL_ERROR; + } else { + ort_severity = ORT_LOGGING_LEVEL_WARNING; + } + + std::string message = "[" + std::string(buf) + " " + std::string(sevstr) + "] " + std::string(msg); + + Ort::ThrowOnError(ort_api_->Logger_LogMessage(&ort_default_logger_, + ort_severity, + message.c_str(), ORT_FILE, __LINE__, __FUNCTION__)); + } + } + void set_level(Severity verbosity) { + verbosity_ = verbosity; + } + Severity get_level() const { + return verbosity_; + } +}; + +namespace tensorrt_ptr { + +template +using unique_pointer = std::unique_ptr>; +}; // namespace tensorrt_ptr + +class OutputAllocator : public nvinfer1::IOutputAllocator { + public: +#if NV_TENSORRT_MAJOR >= 10 + void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment, cudaStream_t stream) noexcept override; +#else + void* reallocateOutput(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment) noexcept override; +#endif + void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override; + + void* getBuffer() { + return outputPtr; + } + + std::vector& getOutputShape() { + return output_shapes; + } + + uint64_t getSize() { + return allocated_size; + } + + ~OutputAllocator() override { + cudaFree(outputPtr); + } + + private: + void* outputPtr{nullptr}; + uint64_t allocated_size = 0; + std::vector output_shapes; +}; + +struct TensorrtComputeState { + uint32_t device_id; + std::string fused_node_name; + nvinfer1::IBuilder* builder; + tensorrt_ptr::unique_pointer* parser = nullptr; + std::unique_ptr* engine = nullptr; + std::unique_ptr* context = nullptr; + std::unique_ptr* network = nullptr; + std::vector> input_info; + std::vector> output_info; + std::unordered_map>>> input_shape_ranges; + std::mutex* tensorrt_mu_ptr = nullptr; + std::string compute_capability; + size_t max_workspace_size = 1 << 30; // 1GB; + bool fp16_enable = false; + bool int8_enable = false; + bool int8_calibration_cache_available = false; + bool dla_enable = false; + int dla_core = 0; + std::string trt_node_name_with_precision; + bool engine_cache_enable = false; + std::string engine_cache_path; + nvinfer1::IRuntime* runtime = nullptr; + std::vector profiles; + bool context_memory_sharing_enable = false; + size_t* max_context_mem_size_ptr = nullptr; + AllocatorUniquePtr* context_memory = nullptr; + std::unordered_map dynamic_range_map; + bool engine_decryption_enable = false; + int (*engine_decryption)(const char*, char*, size_t*) = nullptr; + int (*engine_encryption)(const char*, char*, size_t) = nullptr; + bool timing_cache_enable = true; + std::string timing_cache_path; + bool force_timing_cache = false; + bool detailed_build_log = false; + bool build_heuristics_enable = false; + bool sparsity_enable = false; + int builder_optimization_level = 3; + int auxiliary_streams = -1; + bool filter_tactic_sources = false; + nvinfer1::TacticSources tactic_sources; + bool cuda_graph_enable = false; + bool weight_stripped_engine_enable = false; + bool weight_stripped_engine_refit = false; + char* model_path; + std::string onnx_model_folder_path; + const void* onnx_model_bytestream; + size_t onnx_model_bytestream_size; + const void* onnx_external_data_bytestream; + size_t onnx_external_data_bytestream_size; + std::string cache_prefix; + std::string cache_suffix; + bool engine_hw_compatible = false; + bool sync_stream_after_enqueue = true; +}; + +// Minimum 
information to construct kernel function state for EPContext workflow +struct TensorrtComputeStateForEPContext { + uint32_t device_id; + std::string fused_node_name; + std::unique_ptr* engine = nullptr; + std::unique_ptr* context = nullptr; + std::vector> input_info; + std::vector> output_info; + bool context_memory_sharing_enable = false; + size_t* max_context_mem_size_ptr = nullptr; + AllocatorUniquePtr* context_memory = nullptr; + std::mutex* tensorrt_mu_ptr = nullptr; + bool sync_stream_after_enqueue = true; +}; + +using ShapeRangesMap = std::unordered_map>>>; +using DDSOutputAllocatorMap = std::unordered_map>; +std::string GetWeightRefittedEnginePath(std::string engine_cache_path); + +static const std::string k_cc_hw_compatible = "80+"; +static const std::string k_ep_ctx_hardware_architecture = "hardware_architecture"; +static const std::string k_ep_ctx_onnx_model_filename = "onnx_model_filename"; + +/// +/// +/// Plugin TensorRT EP implementing OrtEp. +/// +/// +struct TensorrtExecutionProvider : public OrtEp, public ApiPtrs { + TensorrtExecutionProvider(TensorrtExecutionProviderFactory& factory, const std::string& name, + const OrtSessionOptions& session_options, + const OrtLogger& logger); + ~TensorrtExecutionProvider(); + + TensorrtExecutionProviderFactory& factory_; + std::string name_; + const OrtSessionOptions& session_options_; + const OrtLogger& logger_; + + std::unordered_map> compute_states_; + std::unordered_map> compute_states_for_ep_context_; + + SubGraphCollection_t GetSupportedList(SubGraphCollection_t supported_nodes_list, int iterations, + const int max_iterations, const OrtGraph* graph, bool* early_termination) const; + + OrtStatus* CreateNodeComputeInfoFromPrecompiledEngine(OrtEp* this_ptr, const OrtGraph* graph, + const OrtNode* fused_node, + std::unordered_map& input_map, + std::unordered_map& output_map, + OrtNodeComputeInfo** node_compute_info); + + OrtStatus* CreateNodeComputeInfoFromGraph(OrtEp* this_ptr, const OrtGraph* graph, const OrtNode* fused_node, + std::unordered_map& input_map, + std::unordered_map& output_map, + OrtNodeComputeInfo** node_compute_info, + OrtNode** ep_context_node); + + OrtStatus* RefitEngine(std::string onnx_model_filename, + std::string& onnx_model_folder_path, + std::string& weight_stripped_engine_cath_path, + bool path_check, + const void* onnx_model_bytestream, + size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, + nvinfer1::ICudaEngine* trt_engine, + bool serialize_refitted_engine, + bool detailed_build_log); + + std::unordered_map& GetDDSOutputAllocators() { + return dds_output_allocator_maps_; + } + + /** + Get a unique_lock object to control the concurrency behavior. + Every api call not in the thread-safe operations(https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading) + should be protected by a lock when invoked by multiple threads concurrently. + */ + std::unique_lock GetApiLock() const; + + std::unordered_map trt_node_name_with_precision_; + std::unordered_map> dynamic_range_map_; + std::unordered_map cache_suffix_; + bool external_stream_ = false; + cudaStream_t stream_ = nullptr; + + // The OrtAllocator object will be get during ep compute time + // and should be kept for the lifetime of TRT EP object. 
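GetApiLock above hands callers a scoped lock so that TensorRT calls outside the documented thread-safe set are serialized across sessions. Its definition is not part of this header; one plausible shape for it is a process-wide mutex returned through a std::unique_lock, sketched here as an illustrative free function:

    #include <mutex>

    // Sketch of the idea behind GetApiLock(): every caller holds this scoped lock while invoking
    // non-thread-safe TensorRT APIs (builder and parser calls, for example).
    std::unique_lock<std::mutex> AcquireTrtApiLock() {
      static std::mutex trt_api_mutex;  // shared by all EP instances in the process
      return std::unique_lock<std::mutex>(trt_api_mutex);
    }

    // Usage: auto lock = AcquireTrtApiLock();  // held for the duration of buildSerializedNetwork, etc.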
+ OrtAllocator* alloc_ = nullptr; + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept; + static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, + _In_ const OrtNode** fused_nodes, _In_ size_t count, + _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, + _Out_writes_(count) OrtNode** ep_context_nodes) noexcept; + static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, + size_t num_node_compute_infos) noexcept; + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _Outptr_ OrtSyncStreamImpl** stream) noexcept; + + mutable TensorrtExecutionProviderInfo info_; + int max_partition_iterations_ = 1000; + size_t min_subgraph_size_ = 1; + size_t max_workspace_size_ = 1 << 30; // 1GB + bool fp16_enable_ = false; + bool int8_enable_ = false; + bool dla_enable_ = false; + int dla_core_ = 0; + bool force_sequential_engine_build_ = false; + std::string int8_calibration_cache_name_; + bool int8_calibration_cache_available_ = false; + bool int8_use_native_tensorrt_calibration_table_ = false; + bool dump_subgraphs_ = false; + bool engine_cache_enable_ = false; + bool weight_stripped_engine_enable_ = false; + bool weight_stripped_engine_refit_ = false; + std::string onnx_model_folder_path_; + const void* onnx_model_bytestream_; + size_t onnx_model_bytestream_size_; + const void* onnx_external_data_bytestream_ = nullptr; + size_t onnx_external_data_bytestream_size_ = 0; + bool build_heuristics_enable_ = false; + bool sparsity_enable_ = false; + int builder_optimization_level_ = 3; + int auxiliary_streams_ = -1; + std::string tactic_sources_; + std::string global_cache_path_, cache_path_, engine_decryption_lib_path_; + std::unique_ptr runtime_ = nullptr; + std::mutex tensorrt_mu_; + int device_id_; + std::string compute_capability_; + bool context_memory_sharing_enable_ = false; + bool layer_norm_fp32_fallback_ = false; + size_t max_ctx_mem_size_ = 0; + AllocatorUniquePtr context_memory_ = nullptr; + mutable char model_path_[4096] = {}; // Reserved for max path length + bool engine_decryption_enable_ = false; + int (*engine_decryption_)(const char*, char*, size_t*) = nullptr; + int (*engine_encryption_)(const char*, char*, size_t) = nullptr; + bool timing_cache_enable_ = false; + bool force_timing_cache_match_ = false; + bool detailed_build_log_ = false; + bool cuda_graph_enable_ = false; + std::string cache_prefix_; + bool engine_hw_compatible_ = false; + std::string op_types_to_exclude_; + + // For create/dump EP context node model + bool dump_ep_context_model_ = false; + std::string ep_context_file_path_; + int ep_context_embed_mode_ = 0; + std::string ctx_model_path_; + std::string ep_cache_context_attr_; + std::string engine_cache_relative_path_to_context_model_dir_; + + OrtGraph* ep_ctx_graph_ = nullptr; + std::vector extra_attr_keys_; + std::vector extra_attr_values_; + + std::unordered_set control_flow_op_set_ = {"If", "Loop", "Scan"}; + + mutable std::unique_ptr builder_; + + // Following maps that hold TRT objects will be accessible by different threads if ORT is using multithreading. + // In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client. 
+ // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization. + std::unordered_map> parsers_; + std::unordered_map> engines_; + std::unordered_map> contexts_; + std::unordered_map> builders_; + std::unordered_map> networks_; + std::unordered_map>> input_info_; + std::unordered_map>> output_info_; + std::unordered_map>> profile_min_shapes_; + std::unordered_map>> profile_max_shapes_; + std::unordered_map>> profile_opt_shapes_; + std::unordered_map input_shape_ranges_; // The profile shape ranges that the engine is built with + std::unordered_map> profiles_; + std::unordered_map dds_output_allocator_maps_; + + // TODO: Add support for external cudnn and cublas. + // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture + // cudnnHandle_t external_cudnn_handle_ = nullptr; + // cublasHandle_t external_cublas_handle_ = nullptr; + + // Call cudaStreamSynchronize() after TRT enqueueV3() + mutable bool sync_stream_after_enqueue_ = true; + + // TODO: Add support for CUDA graph for plugin ep. + /* + CUDAGraph cuda_graph_; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: + // (1) memory pattern is enabled. (2) arena allocation for stream. + // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs + // to allocate enough memory in Arena before graph capturing. + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + */ + + bool IsGraphCaptureAllowed() const { return false; }; + + nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const; + + /**Check whether all the nodes of the graph are assigned to specific ep*/ + bool AllNodesAssignedToSpecificEP(const OrtGraph* graph, const std::string& provider_type) const; + + /**Check the graph is the subgraph of control flow op*/ + bool IsSubGraphOfControlFlowOp(const OrtGraph* graph) const; + + /**Check whether all the nodes of subgraph are supported*/ + bool IsSubGraphFullySupported(const OrtGraph* graph, SubGraphCollection_t supported_nodes_vector) const; +}; + +/// +/// +/// Plugin TensorRT EP OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. 
+/// +/// +struct TRTEpNodeComputeInfo : OrtNodeComputeInfo { + explicit TRTEpNodeComputeInfo(TensorrtExecutionProvider& ep); + + static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context, + void** compute_state); + static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context); + static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state); + + TensorrtExecutionProvider& ep; +}; + +struct TRTEpEpContextNodeComputeInfo : OrtNodeComputeInfo { + explicit TRTEpEpContextNodeComputeInfo(TensorrtExecutionProvider& ep); + + static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, OrtNodeComputeContext* compute_context, + void** compute_state); + static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context); + static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state); + + TensorrtExecutionProvider& ep; +}; +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider.lds b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.lds new file mode 100644 index 000000000..a6d2ef09a --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider.lds @@ -0,0 +1,7 @@ +VERS_1.0.0 { + global: + CreateEpFactories; + ReleaseEpFactory; + local: + *; +}; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc new file mode 100644 index 000000000..b0716b041 --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.cc @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "tensorrt_execution_provider_data_transfer.h" + +#include +#include +#include + +namespace trt_ep { + +void CUDA_RETURN_IF_ERROR(cudaError_t res); + +/*static*/ +bool ORT_API_CALL TRTEpDataTransfer::CanCopyImpl(const OrtDataTransferImpl* this_ptr, + const OrtMemoryDevice* src_memory_device, + const OrtMemoryDevice* dst_memory_device) noexcept { + auto& impl = *static_cast(this_ptr); + + auto it = std::find_if(impl.cuda_gpu_mem_devices_.begin(), impl.cuda_gpu_mem_devices_.end(), + [&impl, &src_memory_device, &dst_memory_device](const OrtMemoryDevice* memory_device) { + bool src_is_our_device = impl.ep_api.MemoryDevice_AreEqual(src_memory_device, memory_device); + bool dst_is_our_device = impl.ep_api.MemoryDevice_AreEqual(dst_memory_device, memory_device); + return src_is_our_device || dst_is_our_device; + }); + + if (it != impl.cuda_gpu_mem_devices_.end()) { + return true; + } + return false; +} + +// function to copy one or more tensors. +// implementation can optionally use async copy if a stream is available for the input. 
+/*static*/ +OrtStatus* ORT_API_CALL TRTEpDataTransfer::CopyTensorsImpl(OrtDataTransferImpl* this_ptr, + const OrtValue** src_tensors_ptr, + OrtValue** dst_tensors_ptr, + OrtSyncStream** streams_ptr, + size_t num_tensors) noexcept { + auto& impl = *static_cast(this_ptr); + + auto src_tensors = gsl::make_span(src_tensors_ptr, num_tensors); + auto dst_tensors = gsl::make_span(dst_tensors_ptr, num_tensors); + auto streams = gsl::make_span(streams_ptr, num_tensors); + + for (size_t i = 0; i < num_tensors; ++i) { + // NOTE: Stream support will be a separate PR. ignore teh streams_ptr values for now + + const OrtMemoryDevice* src_device = nullptr; + const OrtMemoryDevice* dst_device = nullptr; + src_device = impl.ep_api.Value_GetMemoryDevice(src_tensors[i]); + dst_device = impl.ep_api.Value_GetMemoryDevice(dst_tensors[i]); + + OrtMemoryInfoDeviceType src_device_type = impl.ep_api.MemoryDevice_GetDeviceType(src_device); + OrtMemoryInfoDeviceType dst_device_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_device); + OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_device); + OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_device); + bool copy_involves_pinned_memory = src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE || + dst_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE; + + const void* src_data = nullptr; + void* dst_data = nullptr; + RETURN_IF_ERROR(impl.ort_api.GetTensorData(src_tensors[i], &src_data)); + RETURN_IF_ERROR(impl.ort_api.GetTensorMutableData(dst_tensors[i], &dst_data)); + + size_t bytes = 0; + RETURN_IF_ERROR(impl.ort_api.GetTensorSizeInBytes(src_tensors[i], &bytes)); + + // for the sync version of memcpy, launch to cuda default stream + if (dst_device_type == OrtMemoryInfoDeviceType_GPU) { + if (src_device_type == OrtMemoryInfoDeviceType_GPU) { + // GPU -> GPU + // Copy only if the two addresses are different and bytes > 0. + if (dst_data != src_data && bytes > 0) { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice)); + // For device memory to device memory copy, no host-side synchronization is performed by cudaMemcpy. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + } else { + // CPU -> GPU, this is blocking + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); + if (src_mem_type != OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not have completed. 
+ // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + } + } else if (src_device_type == OrtMemoryInfoDeviceType_GPU) { + // GPU -> CPU, this is blocking + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost)); + } else { + // CPU -> CPU involves copy to/from pinned memory and a synchronize may be required first + // ORT_ENFORCE(dst_data != src_data); + memcpy(dst_data, src_data, bytes); + } + } + + return nullptr; +} + +/*static*/ +void ORT_API_CALL TRTEpDataTransfer::ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept { + // In our setup the factory owns a shared ExampleDataTransfer instance so it will do the cleanup, and we ignore + // the call to Release from the plugin_ep::DataTransfer dtor (see /onnxruntime/core/framework/plugin_data_transfer.h) + // + // If you create a new instance on each call to OrtEpFactory::CreateDataTransfer you call `delete` here + // delete static_cast(this_ptr); + ; +} +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h new file mode 100644 index 000000000..34221f3a8 --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_data_transfer.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "ep_utils.h" +#include "onnxruntime_c_api.h" + +namespace trt_ep { + +struct TRTEpDataTransfer : OrtDataTransferImpl, ApiPtrs { + TRTEpDataTransfer(ApiPtrs api_ptrs, std::vector& device_mem_infos, + std::vector& shared_mem_infos) + : ApiPtrs(api_ptrs), cuda_gpu_mem_devices_{device_mem_infos}, cuda_pinned_mem_devices_{shared_mem_infos} { + CanCopy = CanCopyImpl; + CopyTensors = CopyTensorsImpl; + Release = ReleaseImpl; + } + + static bool ORT_API_CALL CanCopyImpl(const OrtDataTransferImpl* this_ptr, const OrtMemoryDevice* src_memory_device, + const OrtMemoryDevice* dst_memory_device) noexcept; + + // function to copy one or more tensors. + // implementation can optionally use async copy if a stream is available for the input. + static OrtStatus* ORT_API_CALL CopyTensorsImpl(OrtDataTransferImpl* this_ptr, const OrtValue** src_tensors_ptr, + OrtValue** dst_tensors_ptr, OrtSyncStream** streams_ptr, + size_t num_tensors) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept; + + private: + std::vector& cuda_gpu_mem_devices_; + std::vector& cuda_pinned_mem_devices_; +}; +} // namespace trt_ep \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc new file mode 100644 index 000000000..17c65ef4c --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.cc @@ -0,0 +1,157 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
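The copy routine above is synchronous by design for now; the note above flags stream support as a follow-up. Once a stream is honored, the same copies could be issued with cudaMemcpyAsync. The sketch below is an assumption about that future variant, not code from this change:

    #include <cstddef>
    #include <cuda_runtime_api.h>

    // Hypothetical async variant of the tensor copy. Overlap with compute only happens when the
    // host side uses pinned (host-accessible) memory; with pageable host memory cudaMemcpyAsync
    // behaves synchronously with respect to the host.
    cudaError_t CopyTensorAsync(void* dst, const void* src, size_t bytes,
                                cudaMemcpyKind kind, cudaStream_t stream) {
      if (dst == src || bytes == 0) {
        return cudaSuccess;  // nothing to copy
      }
      // Completion is owned by the stream: the caller must synchronize the stream (or record an
      // event) before the destination tensor is consumed.
      return cudaMemcpyAsync(dst, src, bytes, kind, stream);
    }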
+ +#include //#incldue "core/providers/cuda/cuda_pch.h" + +#include "tensorrt_execution_provider_info.h" +#include "provider_options_utils.h" +#include "cuda/cuda_common.h" +#include "ep_utils.h" + +namespace tensorrt { +namespace provider_option_names { +constexpr const char* kDeviceId = "device_id"; +constexpr const char* kHasUserComputeStream = "has_user_compute_stream"; +constexpr const char* kUserComputeStream = "user_compute_stream"; +constexpr const char* kMaxPartitionIterations = "trt_max_partition_iterations"; +constexpr const char* kMinSubgraphSize = "trt_min_subgraph_size"; +constexpr const char* kMaxWorkspaceSize = "trt_max_workspace_size"; +constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kInt8Enable = "trt_int8_enable"; +constexpr const char* kInt8CalibTable = "trt_int8_calibration_table_name"; +constexpr const char* kInt8UseNativeCalibTable = "trt_int8_use_native_calibration_table"; +constexpr const char* kDLAEnable = "trt_dla_enable"; +constexpr const char* kDLACore = "trt_dla_core"; +constexpr const char* kDumpSubgraphs = "trt_dump_subgraphs"; +constexpr const char* kEngineCacheEnable = "trt_engine_cache_enable"; +constexpr const char* kEngineCachePath = "trt_engine_cache_path"; +constexpr const char* kWeightStrippedEngineEnable = "trt_weight_stripped_engine_enable"; +constexpr const char* kOnnxModelFolderPath = "trt_onnx_model_folder_path"; +constexpr const char* kEngineCachePrefix = "trt_engine_cache_prefix"; +constexpr const char* kDecryptionEnable = "trt_engine_decryption_enable"; +constexpr const char* kDecryptionLibPath = "trt_engine_decryption_lib_path"; +constexpr const char* kForceSequentialEngineBuild = "trt_force_sequential_engine_build"; +// add new provider option name here. +constexpr const char* kContextMemorySharingEnable = "trt_context_memory_sharing_enable"; +constexpr const char* kLayerNormFP32Fallback = "trt_layer_norm_fp32_fallback"; +constexpr const char* kTimingCacheEnable = "trt_timing_cache_enable"; +constexpr const char* kTimingCachePath = "trt_timing_cache_path"; +constexpr const char* kForceTimingCacheMatch = "trt_force_timing_cache"; +constexpr const char* kDetailedBuildLog = "trt_detailed_build_log"; +constexpr const char* kBuildHeuristics = "trt_build_heuristics_enable"; +constexpr const char* kSparsityEnable = "trt_sparsity_enable"; +constexpr const char* kBuilderOptimizationLevel = "trt_builder_optimization_level"; +constexpr const char* kAuxiliaryStreams = "trt_auxiliary_streams"; +constexpr const char* kTacticSources = "trt_tactic_sources"; +constexpr const char* kExtraPluginLibPaths = "trt_extra_plugin_lib_paths"; +constexpr const char* kProfilesMinShapes = "trt_profile_min_shapes"; +constexpr const char* kProfilesMaxShapes = "trt_profile_max_shapes"; +constexpr const char* kProfilesOptShapes = "trt_profile_opt_shapes"; +constexpr const char* kCudaGraphEnable = "trt_cuda_graph_enable"; +constexpr const char* kEpContextEmbedMode = "trt_ep_context_embed_mode"; +constexpr const char* kEpContextFilePath = "trt_ep_context_file_path"; +constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model"; +constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible"; +constexpr const char* kONNXBytestream = "trt_onnx_bytestream"; +constexpr const char* kONNXBytestreamSize = "trt_onnx_bytestream_size"; +constexpr const char* kExternalDataBytestream = "trt_external_data_bytestream"; +constexpr const char* kExternalDataBytestreamSize = "trt_external_data_bytestream_size"; +constexpr const char* 
kOpTypesToExclude = "trt_op_types_to_exclude"; + +} // namespace provider_option_names +} // namespace tensorrt + +TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { + TensorrtExecutionProviderInfo info{}; + + void* user_compute_stream = nullptr; + void* onnx_bytestream = nullptr; + void* external_data_bytestream = nullptr; + THROW_IF_ERROR( + ProviderOptionsParser{} + .AddValueParser( + tensorrt::provider_option_names::kDeviceId, + [&info](const std::string& value_str) -> OrtStatus* { + RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); + int num_devices{}; + CUDA_RETURN_IF_ERROR(cudaGetDeviceCount(&num_devices)); + RETURN_IF_NOT( + 0 <= info.device_id && info.device_id < num_devices, + "Invalid device ID: ", info.device_id, + ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); + return nullptr; + }) + .AddAssignmentToReference(tensorrt::provider_option_names::kMaxPartitionIterations, info.max_partition_iterations) + .AddAssignmentToReference(tensorrt::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream) + .AddValueParser( + tensorrt::provider_option_names::kUserComputeStream, + [&user_compute_stream](const std::string& value_str) -> OrtStatus* { + size_t address; + RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + user_compute_stream = reinterpret_cast(address); + return nullptr; + }) + .AddAssignmentToReference(tensorrt::provider_option_names::kMinSubgraphSize, info.min_subgraph_size) + .AddAssignmentToReference(tensorrt::provider_option_names::kMaxWorkspaceSize, info.max_workspace_size) + .AddAssignmentToReference(tensorrt::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kInt8Enable, info.int8_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kInt8CalibTable, info.int8_calibration_table_name) + .AddAssignmentToReference(tensorrt::provider_option_names::kInt8UseNativeCalibTable, info.int8_use_native_calibration_table) + .AddAssignmentToReference(tensorrt::provider_option_names::kDLAEnable, info.dla_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kDLACore, info.dla_core) + .AddAssignmentToReference(tensorrt::provider_option_names::kDumpSubgraphs, info.dump_subgraphs) + .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCacheEnable, info.engine_cache_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePath, info.engine_cache_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kWeightStrippedEngineEnable, info.weight_stripped_engine_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kOnnxModelFolderPath, info.onnx_model_folder_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kEngineCachePrefix, info.engine_cache_prefix) + .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionEnable, info.engine_decryption_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kDecryptionLibPath, info.engine_decryption_lib_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kForceSequentialEngineBuild, info.force_sequential_engine_build) + .AddAssignmentToReference(tensorrt::provider_option_names::kContextMemorySharingEnable, info.context_memory_sharing_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kLayerNormFP32Fallback, info.layer_norm_fp32_fallback) + 
.AddAssignmentToReference(tensorrt::provider_option_names::kTimingCacheEnable, info.timing_cache_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kTimingCachePath, info.timing_cache_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kForceTimingCacheMatch, info.force_timing_cache) + .AddAssignmentToReference(tensorrt::provider_option_names::kDetailedBuildLog, info.detailed_build_log) + .AddAssignmentToReference(tensorrt::provider_option_names::kBuildHeuristics, info.build_heuristics_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kSparsityEnable, info.sparsity_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kBuilderOptimizationLevel, info.builder_optimization_level) + .AddAssignmentToReference(tensorrt::provider_option_names::kAuxiliaryStreams, info.auxiliary_streams) + .AddAssignmentToReference(tensorrt::provider_option_names::kTacticSources, info.tactic_sources) + .AddAssignmentToReference(tensorrt::provider_option_names::kExtraPluginLibPaths, info.extra_plugin_lib_paths) + .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMinShapes, info.profile_min_shapes) + .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes) + .AddAssignmentToReference(tensorrt::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes) + .AddAssignmentToReference(tensorrt::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) + .AddAssignmentToReference(tensorrt::provider_option_names::kDumpEpContextModel, info.dump_ep_context_model) + .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextFilePath, info.ep_context_file_path) + .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode) + .AddAssignmentToReference(tensorrt::provider_option_names::kEngineHwCompatible, info.engine_hw_compatible) + .AddValueParser( + tensorrt::provider_option_names::kONNXBytestream, + [&onnx_bytestream](const std::string& value_str) -> OrtStatus* { + size_t address; + RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + onnx_bytestream = reinterpret_cast(address); + return nullptr; + }) + .AddAssignmentToReference(tensorrt::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size) + .AddValueParser( + tensorrt::provider_option_names::kExternalDataBytestream, + [&external_data_bytestream](const std::string& value_str) -> OrtStatus* { + size_t address; + RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + external_data_bytestream = reinterpret_cast(address); + return nullptr; + }) + .AddAssignmentToReference(tensorrt::provider_option_names::kExternalDataBytestreamSize, info.external_data_bytestream_size) + .AddAssignmentToReference(tensorrt::provider_option_names::kOpTypesToExclude, info.op_types_to_exclude) + .Parse(options)); // add new provider option here. 
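+
+  // Note: pointer-valued options (the user compute stream and the ONNX/external-data byte streams)
+  // are passed in as integer addresses, so they are parsed into local variables by the value parsers
+  // above and copied into `info` below once Parse() has succeeded.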
+ + info.user_compute_stream = user_compute_stream; + info.has_user_compute_stream = (user_compute_stream != nullptr); + info.onnx_bytestream = onnx_bytestream; + info.external_data_bytestream = external_data_bytestream; + return info; +} \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h new file mode 100644 index 000000000..df315cf9a --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_info.h @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "provider_options.h" + +#include + +// Information needed to construct trt execution providers. +struct TensorrtExecutionProviderInfo { + int device_id{0}; + bool has_user_compute_stream{false}; + void* user_compute_stream{nullptr}; + bool has_trt_options{false}; + int max_partition_iterations{1000}; + int min_subgraph_size{1}; + size_t max_workspace_size{1 << 30}; + bool fp16_enable{false}; + bool int8_enable{false}; + std::string int8_calibration_table_name{""}; + bool int8_use_native_calibration_table{false}; + bool dla_enable{false}; + int dla_core{0}; + bool dump_subgraphs{false}; + bool engine_cache_enable{false}; + std::string engine_cache_path{""}; + bool weight_stripped_engine_enable{false}; + std::string onnx_model_folder_path{""}; + const void* onnx_bytestream{nullptr}; + size_t onnx_bytestream_size{0}; + const void* external_data_bytestream{nullptr}; + size_t external_data_bytestream_size{0}; + bool engine_decryption_enable{false}; + std::string engine_decryption_lib_path{""}; + bool force_sequential_engine_build{false}; + bool context_memory_sharing_enable{false}; + bool layer_norm_fp32_fallback{false}; + bool timing_cache_enable{false}; + std::string timing_cache_path{""}; + bool force_timing_cache{false}; + bool detailed_build_log{false}; + bool build_heuristics_enable{false}; + bool sparsity_enable{false}; + int builder_optimization_level{3}; + int auxiliary_streams{-1}; + std::string tactic_sources{""}; + std::string extra_plugin_lib_paths{""}; + std::string profile_min_shapes{""}; + std::string profile_max_shapes{""}; + std::string profile_opt_shapes{""}; + bool cuda_graph_enable{false}; + bool dump_ep_context_model{false}; + std::string ep_context_file_path{""}; + int ep_context_embed_mode{0}; + std::string engine_cache_prefix{""}; + bool engine_hw_compatible{false}; + std::string op_types_to_exclude{""}; + + static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); +}; diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc new file mode 100644 index 000000000..a6a95451f --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.cc @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
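+
+// Stream and notification support for the plugin EP, implemented on top of CUDA streams and events.
+// Rough flow, as implemented below:
+//   TrtSyncNotificationImpl::Activate     -> cudaEventRecord(event_, stream_)
+//   TrtSyncNotificationImpl::WaitOnDevice -> cudaStreamWaitEvent(<waiting stream's handle>, event_)
+//   TrtSyncNotificationImpl::WaitOnHost   -> cudaEventSynchronize(event_)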
+ +#include "tensorrt_execution_provider_stream_support.h" +#include "tensorrt_provider_factory.h" +#include "tensorrt_execution_provider.h" + +#include "cuda/cuda_common.h" +#include "cuda/cuda_call.h" + +namespace trt_ep { + +// +// TrtSyncStreamImpl implementation +// + +TrtSyncStreamImpl::TrtSyncStreamImpl(TensorrtExecutionProviderFactory& factory, const OrtEp* ep, uint32_t device_id, const OrtKeyValuePairs* /*stream_options*/) + : ApiPtrs(factory), ep_{ep}, factory_{&factory} { + ort_version_supported = ORT_API_VERSION; + CreateNotification = CreateNotificationImpl; + GetHandle = GetHandleImpl; + Flush = FlushImpl; + OnSessionRunEnd = OnSessionRunEndImpl; + Release = ReleaseImpl; + + const TensorrtExecutionProvider* trt_ep = static_cast(ep_); + if (trt_ep->external_stream_) { + stream_ = trt_ep->stream_; + own_stream_ = false; + } else { + CUDA_CALL_THROW(cudaSetDevice(static_cast(device_id))); + cudaStream_t stream = nullptr; + CUDA_CALL_THROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + stream_ = stream; + own_stream_ = true; + } +} + +/*static*/ +OrtStatus* ORT_API_CALL TrtSyncStreamImpl::CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** notification) noexcept { + auto& impl = *static_cast(this_ptr); + + std::unique_ptr trt_sync_notification; + RETURN_IF_ERROR(TrtSyncNotificationImpl::Create(impl.stream_, impl, trt_sync_notification)); + + *notification = trt_sync_notification.release(); + return nullptr; +} + +/*static*/ +void* ORT_API_CALL TrtSyncStreamImpl::GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + return static_cast(impl.stream_); +} + +/*static*/ +OrtStatus* ORT_API_CALL TrtSyncStreamImpl::FlushImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + + // only flush when we own the stream, not external + if (impl.own_stream_) CUDA_CALL_THROW(cudaStreamSynchronize(static_cast(impl.stream_))); + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL TrtSyncStreamImpl::OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + return nullptr; +} + +// callback for EP library to release any internal state +/*static*/ +void ORT_API_CALL TrtSyncStreamImpl::ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +// +// Notification support +// + +/*static*/ +OrtStatus* TrtSyncNotificationImpl::Create(cudaStream_t stream, const ApiPtrs& apis, + std::unique_ptr& notification) { + auto trt_sync_notification = std::make_unique(stream, apis); + CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(&trt_sync_notification->event_, cudaEventDisableTiming)); + + notification = std::move(trt_sync_notification); + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL TrtSyncNotificationImpl::ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + CUDA_RETURN_IF_ERROR(cudaEventRecord(impl.event_, impl.stream_)); + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL TrtSyncNotificationImpl::WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, + _In_ OrtSyncStream* stream) noexcept { + auto& impl = *static_cast(this_ptr); + void* handle = impl.ort_api.SyncStream_GetHandle(stream); + CUDA_RETURN_IF_ERROR(cudaStreamWaitEvent(static_cast(handle), impl.event_)); + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL TrtSyncNotificationImpl::WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = 
*static_cast(this_ptr); + CUDA_RETURN_IF_ERROR(cudaEventSynchronize(impl.event_)); + + return nullptr; +} + +/*static*/ +void ORT_API_CALL TrtSyncNotificationImpl::ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h new file mode 100644 index 000000000..7242c247b --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_stream_support.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "onnxruntime_c_api.h" +#include "tensorrt_provider_factory.h" +#include "ep_utils.h" + +#include + +namespace trt_ep { +// +// Class implementing Stream support for synchronization. +// +struct TrtSyncStreamImpl : public OrtSyncStreamImpl, public ApiPtrs { + TrtSyncStreamImpl(TensorrtExecutionProviderFactory& factory, + const OrtEp* ep, + uint32_t device_id, + const OrtKeyValuePairs* /*stream_options*/); + + private: + static OrtStatus* ORT_API_CALL CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** sync_notification) noexcept; + static void* ORT_API_CALL GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL FlushImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + + // EP instance if the stream is being created internally for inferencing. + // nullptr when the stream is created outside of an inference session for data copies. + const OrtEp* ep_; + TensorrtExecutionProviderFactory* factory_{nullptr}; + + cudaStream_t stream_{nullptr}; + bool own_stream_{true}; +}; + +// +// Class implementing synchronization notification support. 
+// +struct TrtSyncNotificationImpl : public OrtSyncNotificationImpl, public ApiPtrs { + static OrtStatus* Create(cudaStream_t stream, const ApiPtrs& apis, + std::unique_ptr& notification); + + TrtSyncNotificationImpl(cudaStream_t stream, const ApiPtrs& apis) : stream_(stream), ApiPtrs(apis) { + ort_version_supported = ORT_API_VERSION; + Activate = ActivateImpl; + Release = ReleaseImpl; + WaitOnDevice = WaitOnDeviceImpl; + WaitOnHost = WaitOnHostImpl; + } + + private: + static OrtStatus* ORT_API_CALL ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, + _In_ OrtSyncStream* stream) noexcept; + static OrtStatus* ORT_API_CALL WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept; + + cudaStream_t& stream_; + cudaEvent_t event_; +}; +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h new file mode 100644 index 000000000..091a7a160 --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_execution_provider_utils.h @@ -0,0 +1,1110 @@ +#pragma once + +#include "onnxruntime_cxx_api.h" + +#include "ep_utils.h" +#include "flatbuffers/idl.h" +#include "ort_trt_int8_cal_table.fbs.h" +#include "make_string.h" + +#include "nv_includes.h" +#include "gsl/narrow" + +#include +#include +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +namespace trt_ep { + +bool CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept { + size_t alloc_size = size; + if (alignment == 0) { + *out = alloc_size * nmemb; + } else { + size_t alignment_mask = alignment - 1; + *out = (alloc_size * nmemb + alignment_mask) & ~static_cast(alignment_mask); + } + return true; +} + +template +AllocatorUniquePtr MakeUniquePtrFromOrtAllocator(OrtAllocator* ort_allocator, size_t count_or_bytes, + bool use_reserve = false) { + size_t alloc_size = count_or_bytes; + // if T is not void, 'count_or_bytes' == number of items so allow for that + if constexpr (!std::is_void::value) { + // sizeof(void) isn't valid, but the compiler isn't smart enough to ignore that this line isn't + // reachable if T is void. use std::conditional to 'use' void* in the sizeof call + constexpr auto size = sizeof(typename std::conditional::value, void*, T>::type); + CalcMemSizeForArrayWithAlignment(count_or_bytes, size, 0, &alloc_size); + } + + T* p = nullptr; + if (use_reserve) { + p = static_cast(ort_allocator->Reserve(ort_allocator, alloc_size)); + } else { + p = static_cast(ort_allocator->Alloc(ort_allocator, alloc_size)); + } + + return AllocatorUniquePtr{p, [ort_allocator](T* p) { ort_allocator->Free(ort_allocator, p); }}; +} + +// Following helper functions/struct, GetNodeInputEdgeCount, GetOutputNodes, KahnsTopologicalSort, VisitorPriorityQueue, PriorityNodeCompare are added but are not used for now. +// TODO: They will be used for graph partition in the following PR. 
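+// How these pieces fit together (sketch based on the definitions below): KahnsTopologicalSort walks an
+// OrtGraph with Kahn's algorithm, using GetNodeInputEdgeCount/GetOutputNodes to maintain per-node
+// in-degrees, while VisitorPriorityQueue ordered by PriorityNodeCompare picks the next zero-in-degree
+// node to visit (Shape/Size ops first, then by node id).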
+ +template +struct VisitorPriorityQueue { + using ComparatorType = std::function; + std::list list_; + const ComparatorType comparator_ = nullptr; + VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} + + void push(T node) { + list_.insert( + std::upper_bound(list_.begin(), list_.end(), node, comparator_), + node); + } + bool empty() { return list_.empty(); } + T top() { return list_.back(); } + void pop() { list_.pop_back(); } +}; + +// Get the number of input edges that come from another node upstream. +static OrtStatus* GetNodeInputEdgeCount(const OrtNode* node, size_t& num_input_edges) { + const OrtApi& ort_api = Ort::GetApi(); + + size_t num_inputs = 0; + RETURN_IF_ERROR(ort_api.Node_GetNumInputs(node, &num_inputs)); + + std::vector inputs(num_inputs); + RETURN_IF_ERROR(ort_api.Node_GetInputs(node, inputs.data(), inputs.size())); + + // Sum the number of inputs with a producer node. + num_input_edges = 0; + + for (const OrtValueInfo* ort_input : inputs) { + Ort::ConstValueInfo input{ort_input}; + if (input == nullptr) continue; // Skip missing optional input + + auto producer_info = input.GetProducerNode(); + num_input_edges += static_cast(producer_info.node != nullptr); + } + + return nullptr; +} + +// Get all output nodes that consume an output from the given node. +static OrtStatus* GetOutputNodes(const OrtNode* node, std::vector& result) { + const OrtApi& ort_api = Ort::GetApi(); + + size_t num_outputs = 0; + RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(node, &num_outputs)); + + std::vector outputs(num_outputs); + RETURN_IF_ERROR(ort_api.Node_GetOutputs(node, outputs.data(), outputs.size())); + + std::vector output_nodes; + output_nodes.reserve(num_outputs); // May have more than `num_outputs` + + // Gather the OrtNode consumers of every output. + for (const OrtValueInfo* ort_output : outputs) { + Ort::ConstValueInfo output{ort_output}; + if (output == nullptr) continue; // Skip missing optional output + + auto consumers_info = output.GetConsumers(); + for (const auto& consumer : consumers_info) { + output_nodes.push_back(consumer.node); + } + } + + result = std::move(output_nodes); + return nullptr; +} + +// Kahn's topological sort. +// Adapted from onnxruntime/core/graph/graph.cc to use public C API graph types. +static OrtStatus* KahnsTopologicalSort(const OrtGraph& graph, + const std::function& enter, + const std::function& comp) { + const OrtApi& ort_api = Ort::GetApi(); + + try { + // Get all nodes + size_t num_nodes = 0; + RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes)); + + if (num_nodes == 0) { + return Ort::Status{nullptr}; // Nothing to sort. + } + + std::vector nodes(num_nodes); + RETURN_IF_ERROR(ort_api.Graph_GetNodes(&graph, nodes.data(), nodes.size())); + + // Get the maximum node ID. Not really required if we chose to represent the `in_degree` as a map instead of vector. + size_t max_node_id = 0; + for (const OrtNode* node : nodes) { + size_t node_id = 0; + RETURN_IF_ERROR(ort_api.Node_GetId(node, &node_id)); + max_node_id = std::max(max_node_id, node_id); + } + + std::vector in_degree(max_node_id + 1, 0); + std::vector topo_order; + VisitorPriorityQueue to_visit(comp); + + topo_order.reserve(num_nodes); + + // Initialize in_degree and initial nodes to visit first. 
+ for (const OrtNode* node : nodes) { + size_t input_edge_count = 0; + RETURN_IF_ERROR(GetNodeInputEdgeCount(node, input_edge_count)); + + size_t node_id = 0; + RETURN_IF_ERROR(ort_api.Node_GetId(node, &node_id)); + + in_degree[node_id] = input_edge_count; + if (input_edge_count == 0) { + to_visit.push(node); + } + } + + while (!to_visit.empty()) { + const OrtNode* current_node = to_visit.top(); + to_visit.pop(); + + if (!current_node) continue; + + if (enter) { + enter(current_node); + } + + std::vector output_nodes; + RETURN_IF_ERROR(GetOutputNodes(current_node, output_nodes)); + + for (const auto& output_node : output_nodes) { + size_t output_node_id = 0; + RETURN_IF_ERROR(ort_api.Node_GetId(output_node, &output_node_id)); + + auto& node_in_degree = in_degree[output_node_id]; + node_in_degree--; + + if (node_in_degree == 0) { + to_visit.push(output_node); + } + } + + size_t current_node_id = 0; + RETURN_IF_ERROR(ort_api.Node_GetId(current_node, ¤t_node_id)); + topo_order.push_back(current_node_id); + } + + if (num_nodes != topo_order.size()) { + return Ort::Status("Some nodes are not included in the topological sort: graph has a cycle", ORT_FAIL); + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } + + return nullptr; +} + +// Node comparison functor copied from onnxruntime/core/graph/graph.cc +struct PriorityNodeCompare { + inline bool IsHighPri(const OrtNode* n) const { + // local statics so we can compare std::strings in the checks + static constexpr std::string_view shape_op("Shape"); + static constexpr std::string_view size_op("Size"); + + const char* op_type = nullptr; + Ort::Status status(Ort::GetApi().Node_GetOperatorType(n, &op_type)); + ENFORCE(status.IsOK()); + + return shape_op == op_type || size_op == op_type; + } + + // Used for std::priority_queue + // If return false, n1 will be output first + // If return true, n2 will be output first + bool operator()(const OrtNode* n1, const OrtNode* n2) const { + // nodes in global high priority list will be output first + const bool isN1HighPri = IsHighPri(n1); + const bool isN2HighPri = IsHighPri(n2); + if (isN1HighPri != isN2HighPri) { + return isN2HighPri; + } + + // nodes with lower priority value will be output first + const auto n1_priority = 0; // n1->Priority(); // Looks to always be 0 inside ORT? + const auto n2_priority = 0; // n2->Priority(); // Looks to always be 0 inside ORT? 
+ if (n1_priority != n2_priority) { + return n1_priority > n2_priority; + } + + // otherwise, nodes with lower index will be output first + size_t n1_id = 0; + Ort::Status status1(Ort::GetApi().Node_GetId(n1, &n1_id)); + ENFORCE(status1.IsOK()); + + size_t n2_id = 0; + Ort::Status status2(Ort::GetApi().Node_GetId(n2, &n2_id)); + ENFORCE(status2.IsOK()); + + return n1_id > n2_id; + } +}; + +bool SetDynamicRange(nvinfer1::INetworkDefinition& network, std::unordered_map& dynamic_range_map) { + // Set dynamic range for input tensors + for (int i = 0; i < network.getNbInputs(); ++i) { + const std::string tensor_name = network.getInput(i)->getName(); + auto dynamic_range_iter = dynamic_range_map.find(tensor_name); + if (dynamic_range_iter != dynamic_range_map.end()) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if (!network.getInput(i)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + // LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for network input " << tensor_name; + return false; + } + } + } + + // Set dynamic range for activations and weights + for (int i = 0; i < network.getNbLayers(); ++i) { + auto trt_layer = network.getLayer(i); + for (int j = 0, e = trt_layer->getNbOutputs(); j < e; ++j) { + const std::string tensor_name = trt_layer->getOutput(j)->getName(); + auto dynamic_range_iter = dynamic_range_map.find(tensor_name); + if (dynamic_range_iter != dynamic_range_map.end()) { +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if (!trt_layer->getOutput(j)->setDynamicRange(-dynamic_range_iter->second, dynamic_range_iter->second)) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + // LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for tensor " << tensor_name; + return false; + } + } else if (trt_layer->getType() == nvinfer1::LayerType::kCONSTANT) { + nvinfer1::IConstantLayer* const_layer = static_cast(trt_layer); + const std::string const_layer_name = const_layer->getName(); + auto trt_weights = const_layer->getWeights(); + double max_weight = std::numeric_limits::min(); + for (int64_t k = 0, end = trt_weights.count; k < end; ++k) { + double weight{}; + switch (trt_weights.type) { + case nvinfer1::DataType::kFLOAT: + weight = static_cast(trt_weights.values)[k]; + break; + case nvinfer1::DataType::kBOOL: + weight = static_cast(trt_weights.values)[k]; + break; + case nvinfer1::DataType::kINT8: + weight = static_cast(trt_weights.values)[k]; + break; + case nvinfer1::DataType::kHALF: + weight = static_cast(trt_weights.values)[k]; + break; + case nvinfer1::DataType::kINT32: + weight = static_cast(trt_weights.values)[k]; + break; +#if NV_TENSORRT_MAJOR >= 10 + case nvinfer1::DataType::kINT64: + weight = static_cast(static_cast(trt_weights.values)[k]); + break; +#endif // NV_TENSORRT_MAJOR >= 10 + default: + // LOGS_DEFAULT(ERROR) << "Found unsupported datatype for layer " << const_layer_name; + return false; + } + max_weight = std::max(max_weight, std::abs(weight)); + } +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + if (!trt_layer->getOutput(j)->setDynamicRange(static_cast(-max_weight), + static_cast(max_weight))) { +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + // LOGS_DEFAULT(ERROR) << "Failed to set dynamic range for layer " << const_layer_name; + return false; + } + } + } + } + return true; +} + +std::vector SplitToStringVec(std::string const& s, char separator) { 
+ std::vector splitted; + + for (size_t start = 0; start < s.length();) { + size_t separatorIndex = s.find(separator, start); + if (separatorIndex == std::string::npos) { + separatorIndex = s.length(); + } + splitted.emplace_back(s.substr(start, separatorIndex - start)); + start = separatorIndex + 1; + } + + return splitted; +} + +nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { + nvinfer1::TacticSources disabledTactics = 0; + nvinfer1::TacticSources enabledTactics = 0; + std::vector tacticList = SplitToStringVec(tactic_string, ','); + for (auto& t : tacticList) { + bool enable{false}; + if (t.front() == '+') { + enable = true; + } else if (t.front() != '-') { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic source must be prefixed with + or - skipping: " << t; + } + t.erase(0, 1); + + const auto toUpper = [](std::string& sourceName) { + std::transform(sourceName.begin(), sourceName.end(), sourceName.begin(), + [](char c) { return gsl::narrow(std::toupper(c)); }); + return sourceName; + }; + + nvinfer1::TacticSource source{}; + t = toUpper(t); + if (t == "CUBLAS") { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS is deprecated in TensorRT 10.0"; +#if NV_TENSORRT_MAJOR < 10 + source = nvinfer1::TacticSource::kCUBLAS; +#endif + } else if (t == "CUBLASLT" || t == "CUBLAS_LT") { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS_LT is deprecated in TensorRT 9.0"; +#if NV_TENSORRT_MAJOR < 9 + source = nvinfer1::TacticSource::kCUBLAS_LT; +#endif + } else if (t == "CUDNN") { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUDNN is deprecated in TensorRT 10.0"; +#if NV_TENSORRT_MAJOR < 10 + source = nvinfer1::TacticSource::kCUDNN; +#endif + } else if (t == "EDGE_MASK_CONVOLUTIONS") { + source = nvinfer1::TacticSource::kEDGE_MASK_CONVOLUTIONS; + } else if (t == "JIT_CONVOLUTIONS") { + source = nvinfer1::TacticSource::kJIT_CONVOLUTIONS; + } else { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic source was not found with name: " << t; + } + + uint32_t sourceBit = 1U << static_cast(source); + + if (enable) { + enabledTactics |= sourceBit; + } else { + disabledTactics |= sourceBit; + } + } + return enabledTactics & ~disabledTactics; +} + +inline std::vector loadTimingCacheFile(const std::string inFileName) { + std::ifstream iFile(inFileName, std::ios::in | std::ios::binary); + if (!iFile) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not read timing cache from: " << inFileName + // << ". A new timing cache will be generated and written."; + return std::vector(); + } + iFile.seekg(0, std::ifstream::end); + size_t fsize = iFile.tellg(); + iFile.seekg(0, std::ifstream::beg); + std::vector content(fsize); + iFile.read(content.data(), fsize); + iFile.close(); + return content; +} + +inline void saveTimingCacheFile(const std::string outFileName, const nvinfer1::IHostMemory* blob) { + std::ofstream oFile(outFileName, std::ios::out | std::ios::binary); + if (!oFile) { + // LOGS_DEFAULT(WARNING) << "[TensorRT EP] Could not write timing cache to: " << outFileName; + return; + } + oFile.write((char*)blob->data(), blob->size()); + oFile.close(); +} + +float ConvertSinglePrecisionIEEE754ToFloat(unsigned long input) { + int s = (input >> 31) & 0x01; + int e = ((input & 0x7f800000) >> 23) - 127; + int p = -1; + double m = 0.0; + for (int i = 0; i < 23; ++i) { + m += ((input >> (23 - i - 1)) & 0x01) * pow(2.0, p--); + } + return static_cast((s ? 
-1 : 1) * pow(2.0, e) * (m + 1.0)); +} + +/* + * Read calibration table for INT8 quantization + * Two kind of calibration tables are supported, + * 1. ORT generated calibration table + * The table is pre-serialized by flatbuffers. + * Each entry in the table is a key-value pair, + * key: tensor name, value: maximum absolute value in floating point + * For example, + * data_0 2.008338 + * ... + * 2. Native TensorRT generated calibration table + * Data format is defined by TensorRT as, + * tensor name : scale in 32-bit single precision IEEE754 format + * For example, + * TRT-7103-EntropyCalibration2 + * data_0: 4000889d + * ... + */ +bool ReadDynamicRange(const std::string file_name, const bool is_trt_calibration_table, std::unordered_map& dynamic_range_map) { + std::ifstream infile(file_name, std::ios::binary | std::ios::in); + if (!infile) { + return false; + } + + if (is_trt_calibration_table) { + // Native TensorRT generated calibration table + std::string line; + char delim = ':'; + if (std::getline(infile, line)) { + std::istringstream first_line(line); + std::string version; + std::getline(first_line, version, delim); + std::size_t found = version.find("TRT-"); + if (found != std::string::npos) { + while (std::getline(infile, line)) { + std::istringstream in_line(line); + std::string str; + std::getline(in_line, str, delim); + std::string tensor_name = str; + std::getline(in_line, str, delim); + unsigned long scale_int = std::strtoul(str.c_str(), nullptr, 16); + float scale_float = ConvertSinglePrecisionIEEE754ToFloat(scale_int); + float dynamic_range = scale_float * 127.0f; + dynamic_range_map[tensor_name] = dynamic_range; + } + } else { + throw std::runtime_error("This is not a TensorRT generated calibration table " + file_name); + } + } + } else { + // ORT generated calibration table + infile.seekg(0, std::ios::end); + size_t length = infile.tellg(); + infile.seekg(0, std::ios::beg); + std::unique_ptr data{new char[length]}; + infile.read((char*)data.get(), length); + infile.close(); + auto flat_table = flatbuffers::GetRoot((const uint8_t*)data.get()); + auto flat_dict = flat_table->dict(); + for (size_t i = 0, end = flat_dict->size(); i < end; ++i) { + flatbuffers::uoffset_t idx = static_cast(i); + dynamic_range_map[flat_dict->Get(idx)->key()->str()] = std::stof(flat_dict->Get(idx)->value()->str()); + } + } + return true; +} + +/* + * Get number of profile setting. + * + * profile_min_shapes/profile_max_shapes/profile_opt_shapes may contain multiple profile settings. + * Note: TRT EP currently only supports one profile setting. + * + * { + * tensor_a: [[dim_0_value_0, dim_1_value_1, dim_2_value_2]], + * tensor_b: [[dim_0_value_3, dim_1_value_4, dim_2_value_5]] + * } + * + */ +int GetNumProfiles(std::unordered_map>>& profile_shapes) { + int num_profile = 0; + for (auto it = profile_shapes.begin(); it != profile_shapes.end(); it++) { + num_profile = static_cast(it->second.size()); + if (num_profile > 0) { + break; + } + } + return num_profile; +} + +/* + * Seralize engine profile + * The profile contains min/max shape ranges of dynamic shape dimensions of each input tensor + * For example, assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, and tensor_b + * has one dynamic shape dimension: dim_1. 
The data in profile will be, + * key: tensor_a, value: dim_0 min_shape max_shape dim_2 min_shape max_shape + * key: tensor_b, value: dim_1 min_shape max_shape + * + * [Deprecated] Use SerializeProfileV2 + */ +void SerializeProfile(const std::string& file_name, std::unordered_map>>& shape_ranges) { + // Serialize profile + flexbuffers::Builder builder; + auto profile_start = builder.StartMap(); + for (auto outer_it = shape_ranges.begin(); outer_it != shape_ranges.end(); ++outer_it) { + builder.TypedVector(outer_it->first.c_str(), [&] { + for (auto inner_it = outer_it->second.begin(); inner_it != outer_it->second.end(); ++inner_it) { + builder.Int(inner_it->first); + builder.Int(inner_it->second.first); + builder.Int(inner_it->second.second); + } + }); + } + builder.EndMap(profile_start); + builder.Finish(); + + // Save flexbuffer + std::ofstream file(file_name, std::ios::binary | std::ios::out); + auto buf = builder.GetBuffer(); + size_t size = builder.GetSize(); + file.write(reinterpret_cast(&buf[0]), size); + file.close(); +} + +// Deserialize engine profile +// [Deprecated] Use DeserializeProfileV2 +std::unordered_map>> DeserializeProfile(std::ifstream& infile) { + // Load flexbuffer + infile.seekg(0, std::ios::end); + size_t length = infile.tellg(); + infile.seekg(0, std::ios::beg); + std::unique_ptr data{new char[length]}; + infile.read((char*)data.get(), length); + infile.close(); + + // Deserialize profile + std::unordered_map>> shape_ranges; + auto tensors_range_entries = flexbuffers::GetRoot((const uint8_t*)data.get(), length).AsMap(); + auto keys = tensors_range_entries.Keys(); + auto values = tensors_range_entries.Values(); + for (size_t i = 0, i_end = keys.size(); i < i_end; ++i) { + auto dim_range_vectors = values[i].AsTypedVector(); + std::unordered_map> inner_map; + for (size_t j = 0, j_end = dim_range_vectors.size() / 3; j < j_end; ++j) { + size_t idx = 3 * j; + inner_map[dim_range_vectors[idx].AsInt64()] = std::make_pair(dim_range_vectors[idx + 1].AsInt64(), dim_range_vectors[idx + 2].AsInt64()); + } + shape_ranges[keys[i].AsString().c_str()] = inner_map; + } + return shape_ranges; +} + +/* + * Seralize engine profile. (This function starts from ORT 1.15) + * + * + * (1) Single profile case: + * Assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, + * and tensor_b has one dynamic shape dimension: dim_1. 
+ * + * The data before serialization will be: + * { + * tensor_a: { + * dim_0: [[min_shape_0, max_shape_0, opt_shape_0]], + * dim_2: [[min_shape_2, max_shape_2, opt_shape_2]] + * }, + * tensor_b: { + * dim_1: [[min_shape_1, max_shape_1, opt_shape_1]] + * } + * } + * + * The data after serialization will be: + * { + * tensor_a: [dim_0, min_shape_0, max_shape_0, opt_shape_0, dim_2, min_shape_2, max_shape_2, opt_shape_2] + * tensor_b: [dim_1, min_shape_1, max_shape_1, opt_shape_1] + * } + * + * + * (2) Multiple profiles case: + * For example, if the data before serialization is: + * { + * tensor_a: { + * dim_0: [[min_shape_0, max_shape_0, opt_shape_0], [min_shape_1, max_shape_1, opt_shape_1]] + * }, + * tensor_b: { + * dim_1: [[min_shape_2, max_shape_2, opt_shape_2], [min_shape_3, max_shape_3, opt_shape_3]] + * } + * } + * + * The data after serialization will be: + * { + * tensor_a: [dim_0, min_shape_0, max_shape_0, opt_shape_0, dim_0, min_shape_1, max_shape_1, opt_shape_1] + * | | | | + * ---------------- profile 0 ----------------- ---------------- profile 1 ----------------- + * + * tensor_b: [dim_1, min_shape_2, max_shape_2, opt_shape_2, dim_1, min_shape_3, max_shape_3, opt_shape_3] + * | | | | + * ---------------- profile 0 ----------------- ---------------- profile 1 ----------------- + * } + * + */ +void SerializeProfileV2(const std::string& file_name, std::unordered_map>>>& shape_ranges) { + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] In SerializeProfileV2()"; + // Serialize profile + flexbuffers::Builder builder; + auto tensor_map_start = builder.StartMap(); + for (auto tensor_it = shape_ranges.begin(); tensor_it != shape_ranges.end(); tensor_it++) { // iterate tensors + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] input tensor is '" << tensor_it->first.c_str() << "'"; + builder.TypedVector(tensor_it->first.c_str(), [&] { + for (auto dim_it = tensor_it->second.begin(); dim_it != tensor_it->second.end(); dim_it++) { + size_t num_profiles = dim_it->second.size(); + for (size_t i = 0; i < num_profiles; i++) { + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] profile #" << i << ", dim is " << dim_it->first; + builder.Int(dim_it->first); + builder.Int(dim_it->second[i][0]); + builder.Int(dim_it->second[i][1]); + builder.Int(dim_it->second[i][2]); + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << dim_it->first << ", " << dim_it->second[i][0] << ", " << dim_it->second[i][1] << ", " << dim_it->second[i][2]; + } + } + }); + } + builder.EndMap(tensor_map_start); + builder.Finish(); + + // Save flexbuffer + std::ofstream file(file_name, std::ios::binary | std::ios::out); + auto buf = builder.GetBuffer(); + size_t size = builder.GetSize(); + file.write(reinterpret_cast(&buf[0]), size); + file.close(); +} + +/* + * Deserialize engine profile. (This function starts from ORT 1.15) + * + * + * (1) Single profile case: + * Assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, + * and tensor_b has one dynamic shape dimension: dim_1. 
+ * + * The data in profile file will be: + * { + * tensor_a: [dim_0, min_shape_0, max_shape_0, opt_shape_0, dim_2, min_shape_2, max_shape_2, opt_shape_2] + * tensor_b: [dim_1, min_shape_1, max_shape_1, opt_shape_1] + * } + * + * The data after deserialization will be: + * { + * tensor_a: { + * dim_0: [[min_shape_0, max_shape_0, opt_shape_0]], + * dim_2: [[min_shape_2, max_shape_2, opt_shape_2]] + * }, + * tensor_b: { + * dim_1: [[min_shape_1, max_shape_1, opt_shape_1]] + * } + * } + * + * + * (2) Multiple profiles case: + * For example, if the data in profile file is: + * { + * tensor_a: [dim_0, min_shape_0, max_shape_0, opt_shape_0, dim_0, min_shape_1, max_shape_1, opt_shape_1] + * | | | | + * ---------------- profile 0 ----------------- ---------------- profile 1 ----------------- + * + * tensor_b: [dim_1, min_shape_2, max_shape_2, opt_shape_2, dim_1, min_shape_3, max_shape_3, opt_shape_3] + * | | | | + * ---------------- profile 0 ----------------- ---------------- profile 1 ----------------- + * } + * + * The data after deserialization will be: + * { + * tensor_a: { + * dim_0: [[min_shape_0, max_shape_0, opt_shape_0], [min_shape_1, max_shape_1, opt_shape_1]] + * }, + * tensor_b: { + * dim_1: [[min_shape_2, max_shape_2, opt_shape_2], [min_shape_3, max_shape_3, opt_shape_3]] + * } + * } + */ +std::unordered_map>>> DeserializeProfileV2(std::ifstream& infile) { + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] In DeserializeProfileV2()"; + // Load flexbuffer + infile.seekg(0, std::ios::end); + size_t length = infile.tellg(); + infile.seekg(0, std::ios::beg); + std::unique_ptr data{new char[length]}; + infile.read((char*)data.get(), length); + infile.close(); + + // Deserialize profile + std::unordered_map>>> shape_ranges; + auto tensors_range_entries = flexbuffers::GetRoot((const uint8_t*)data.get(), length).AsMap(); + auto keys = tensors_range_entries.Keys(); + auto values = tensors_range_entries.Values(); + for (size_t i = 0, end = keys.size(); i < end; ++i) { // iterate tensors + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] input tensor is '" << keys[i].AsString().c_str() << "'"; + auto dim_range_vector = values[i].AsTypedVector(); + std::unordered_map>> inner_map; + std::vector> profile_vector; + + for (size_t k = 0; k < (dim_range_vector.size() / 4); k++) { // iterate dim, min, max, opt for all profiles + std::vector shape_vector; + auto idx = 4 * k; + auto dim = dim_range_vector[idx].AsInt64(); + shape_vector.push_back(dim_range_vector[idx + 1].AsInt64()); // min shape + shape_vector.push_back(dim_range_vector[idx + 2].AsInt64()); // max shape + shape_vector.push_back(dim_range_vector[idx + 3].AsInt64()); // opt shape + + if (inner_map.find(dim) == inner_map.end()) { + inner_map[dim] = profile_vector; + } + inner_map[dim].push_back(shape_vector); + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << dim << ", " << shape_vector[0] << ", " << shape_vector[1] << ", " << shape_vector[2]; + } + shape_ranges[keys[i].AsString().c_str()] = inner_map; + } + return shape_ranges; +} + +/* + * Compare profile shapes from profile file (.profile) with explicit profile min/max/opt shapes. + * Return false meaning no need to rebuild engine if everything is same. + * Otherwise return true and engine needs to be rebuilt. 
+ */ +bool CompareProfiles(const std::string& file_name, + std::unordered_map>>& profile_min_shapes, + std::unordered_map>>& profile_max_shapes, + std::unordered_map>>& profile_opt_shapes) { + std::ifstream profile_file(file_name, std::ios::binary | std::ios::in); + if (!profile_file) { + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << file_name << " doesn't exist."; + return true; + } + + std::unordered_map>>> shape_ranges; + shape_ranges = DeserializeProfileV2(profile_file); + + /* The format of the two data structures are below, for example: + * + * shape_ranges: + * { + * tensor_a: { + * dim_0: [[min_shape, max_shape, opt_shape]], + * dim_2: [[min_shape, max_shape, opt_shape]] + * }, + * tensor_b: { + * dim_1: [[min_shape, max_shape, opt_shape]] + * } + * } + * + * profile_min_shapes: + * { + * tensor_a: [[dim_0_value_0, dim_1_value_1, dim_2_value_2]], + * tensor_b: [[dim_0_value_3, dim_1_value_4, dim_2_value_5]] + * } + * + */ + + // Check number of dynamic shape inputs + if (profile_min_shapes.size() != shape_ranges.size()) { + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Numbers of dynamic shape inputs are not the same."; + return true; + } + + // Iterate through shape_ranges map + for (auto tensor_it = shape_ranges.begin(); tensor_it != shape_ranges.end(); tensor_it++) { // iterate tensors + auto tensor_name = tensor_it->first; + if (profile_min_shapes.find(tensor_name) == profile_min_shapes.end()) { + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tensor name '" << tensor_name << "' doesn't exist in trt_profile_min_shapes."; + return true; + } + + for (auto dim_it = tensor_it->second.begin(); dim_it != tensor_it->second.end(); dim_it++) { // iterate dimensions + auto dim = dim_it->first; + auto num_profiles = GetNumProfiles(profile_min_shapes); + + if (dim_it->second.size() != static_cast(num_profiles)) { + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Numbers of profiles are not the same."; + return true; + } + + for (size_t i = 0; i < dim_it->second.size(); i++) { // iterate (multiple) profile(s) + auto shape_values = dim_it->second[i]; + if (dim > (profile_min_shapes[tensor_name][i].size() - 1)) { + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] dimension " << dim << " of '" << tensor_name << "' in " << file_name << " exceeds the total dimension of trt_profile_min_shapes."; + return true; + } + + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_min_shapes[tensor_name][i][dim]; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[0] << " in " << file_name; + if (profile_min_shapes[tensor_name][i][dim] != shape_values[0]) { + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] min shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + return true; + } + + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << profile_max_shapes[tensor_name][i][dim]; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[1] << " in " << file_name; + if (profile_max_shapes[tensor_name][i][dim] != shape_values[1]) { + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] max shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + return true; + } + + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape value of dimension " << dim << " of '" << 
tensor_name << "' is " << profile_opt_shapes[tensor_name][i][dim]; + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape value of dimension " << dim << " of '" << tensor_name << "' is " << shape_values[2] << " in " << file_name; + if (profile_opt_shapes[tensor_name][i][dim] != shape_values[2]) { + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] opt shape values of dimension " << dim << " of '" << tensor_name << "' are not the same"; + return true; + } + } + } + } + return false; +} + +/* + * Get cache by name + * + */ +std::string GetCachePath(const std::string& root, const std::string& name) { + if (root.empty()) { + return name; + } else { + fs::path path = root; + path.append(name); + return path.string(); + } +} + +/* + * Get compute capability + * + */ +std::string GetComputeCapacity(const cudaDeviceProp& prop) { + const std::string compute_capability = std::to_string(prop.major * 10 + prop.minor); + return compute_capability; +} + +/* + * Get Timing by compute capability + * + */ +std::string GetTimingCachePath(const std::string& root, std::string& compute_cap) { + // append compute capability of the GPU as this invalidates the cache and TRT will throw when loading the cache + const std::string timing_cache_name = "TensorrtExecutionProvider_cache_sm" + + compute_cap + ".timing"; + return GetCachePath(root, timing_cache_name); +} + +/* + * Get cache by type + * + * \param root root path of the cache + * \param file_extension It could be ".engine", ".profile" or ".timing" + */ +std::vector GetCachesByType(const std::string& root, std::string file_extension) { + std::vector cache_files; + for (const auto& entry : fs::directory_iterator(root)) { + if (fs::path(file_extension) == fs::path(entry).extension()) { + cache_files.push_back(fs::path(entry)); + } + } + return cache_files; +} + +bool IsCacheExistedByType(const std::string& root, std::string file_extension) { + auto cache_files = GetCachesByType(root, file_extension); + if (cache_files.size() == 0) { + return false; + } + return true; +} + +void RemoveCachesByType(const std::string& root, std::string file_extension) { + auto cache_files = GetCachesByType(root, file_extension); + for (const auto& entry : cache_files) { + fs::remove(entry); + } +} + +bool ValidateProfileShapes(std::unordered_map>>& profile_min_shapes, + std::unordered_map>>& profile_max_shapes, + std::unordered_map>>& profile_opt_shapes) { + if (profile_min_shapes.empty() && profile_max_shapes.empty() && profile_opt_shapes.empty()) { + return true; + } + + if ((profile_min_shapes.size() != profile_max_shapes.size()) && + (profile_min_shapes.size() != profile_opt_shapes.size()) && + (profile_max_shapes.size() != profile_opt_shapes.size())) { + return false; + } + + std::unordered_map>>::iterator it; + for (it = profile_min_shapes.begin(); it != profile_min_shapes.end(); it++) { + auto input_name = it->first; + auto num_profile = it->second.size(); + + // input_name must also be in max/opt profile + if ((profile_max_shapes.find(input_name) == profile_max_shapes.end()) || + (profile_opt_shapes.find(input_name) == profile_opt_shapes.end())) { + return false; + } + + // number of profiles should be the same + if ((num_profile != profile_max_shapes[input_name].size()) || + (num_profile != profile_opt_shapes[input_name].size())) { + return false; + } + } + + return true; +} + +/* + * Make input-name and shape as a pair. + * This helper function is being used by ParseProfileShapes(). 
+ * + * For example: + * The input string is "input_id:32x1", + * after the string is being parsed, the pair object is returned as below. + * pair("input_id", [32, 1]) + * + * Return true if string can be successfully parsed or false if string has wrong format. + */ +bool MakeInputNameShapePair(std::string pair_string, std::pair>& pair) { + if (pair_string.empty()) { + return true; + } + + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << pair_string; + + std::stringstream input_string_stream(pair_string); + char first_delim = ':'; + char second_delim = 'x'; + std::string input_name; + std::string shape; + std::getline(input_string_stream, input_name, first_delim); + std::getline(input_string_stream, shape, first_delim); + + std::vector shapes; + std::stringstream shape_string_stream(shape); + std::string value; + while (std::getline(shape_string_stream, value, second_delim)) { + shapes.push_back(std::stoi(value)); + } + + // wrong input string + if (input_name.empty() || shapes.empty()) { + return false; + } + + pair.first = input_name; + pair.second = shapes; + + return true; +} + +/* + * Parse explicit profile min/max/opt shapes from TensorRT EP provider options. + * + * For example: + * The provider option is --trt_profile_min_shapes="input_id:32x1,attention_mask:32x1,input_id:32x41,attention_mask:32x41", + * after string is being parsed, the profile shapes has two profiles and is being represented as below. + * {"input_id": [[32, 1], [32, 41]], "attention_mask": [[32, 1], [32, 41]]} + * + * Return true if string can be successfully parsed or false if string has wrong format. + */ +bool ParseProfileShapes(std::string profile_shapes_string, std::unordered_map>>& profile_shapes) { + if (profile_shapes_string.empty()) { + return true; + } + + std::stringstream input_string_stream(profile_shapes_string); + char delim = ','; + std::string input_name_with_shape; // input_name:shape, ex: "input_id:32x1" + while (std::getline(input_string_stream, input_name_with_shape, delim)) { + std::pair> pair; + if (!MakeInputNameShapePair(input_name_with_shape, pair)) { + return false; + } + + std::string input_name = pair.first; + if (profile_shapes.find(input_name) == profile_shapes.end()) { + std::vector> profile_shape_vector; + profile_shapes[input_name] = profile_shape_vector; + } + profile_shapes[input_name].push_back(pair.second); + + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << input_name; + std::string shape_string = ""; + for (auto v : pair.second) { + shape_string += std::to_string(v); + shape_string += ", "; + } + ////LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << shape_string; + } + + return true; +} + +std::vector split(const std::string& str, char delimiter) { + std::vector tokens; + std::string token; + std::istringstream tokenStream(str); + while (std::getline(tokenStream, token, delimiter)) { + tokens.push_back(token); + } + return tokens; +} + +std::string join(const std::vector& vec, const std::string& delimiter) { + std::string result; + for (size_t i = 0; i < vec.size(); ++i) { + result += vec[i]; + if (i < vec.size() - 1) { + result += delimiter; + } + } + return result; +} + +/* + * Parse engine cache name suffix when user customizes prefix for engine cache name + * + * For example: + * When default subgraph name is "TensorrtExecutionProvider_TRTKernel_graph_torch-jit-export_2068723788287043730_189_189_fp16" + * This func will generate the suffix "2068723788287043730_189_fp16" + * + */ +std::string GetCacheSuffix(const std::string& fused_node_name, const std::string& 
trt_node_name_with_precision) { + std::vector split_fused_node_name = split(fused_node_name, '_'); + if (split_fused_node_name.size() >= 3) { + // Get index of model hash from fused_node_name + std::string model_hash = split_fused_node_name[split_fused_node_name.size() - 3]; + size_t index = fused_node_name.find(model_hash); + // Parse suffix from trt_node_name_with_precision, as it has additional precision info + std::vector suffix_group = split(trt_node_name_with_precision.substr(index), '_'); + if (suffix_group.size() > 2) { + suffix_group.erase(suffix_group.begin() + 2); + } + return join(suffix_group, "_"); + } + return ""; +} + +} // namespace trt_ep diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc new file mode 100644 index 000000000..6ff43e783 --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.cc @@ -0,0 +1,312 @@ +#include "onnxruntime_cxx_api.h" +#include "tensorrt_provider_factory.h" +#include "tensorrt_execution_provider.h" +#include "cuda_allocator.h" + +#include +#include +#include +#include +#include +#include +#include + +namespace trt_ep { + +TensorrtExecutionProviderFactory::TensorrtExecutionProviderFactory(const char* ep_name, const OrtLogger& default_logger, ApiPtrs apis) + : ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetVersion = GetVersionImpl; + + GetSupportedDevices = GetSupportedDevicesImpl; + + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + + CreateDataTransfer = CreateDataTransferImpl; + + IsStreamAware = IsStreamAwareImpl; +} + +const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_name_.c_str(); +} + +const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_.c_str(); +} + +const char* ORT_API_CALL TensorrtExecutionProviderFactory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->ep_version_.c_str(); +} + +OrtStatus* TensorrtExecutionProviderFactory::CreateMemoryInfoForDevices(int num_devices) { + cuda_gpu_memory_infos.reserve(num_devices); + cuda_pinned_memory_infos.reserve(num_devices); + + for (int device_id = 0; device_id < num_devices; ++device_id) { + OrtMemoryInfo* mem_info = nullptr; + RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("Cuda", OrtMemoryInfoDeviceType_GPU, + /*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE, + /* device_id */ device_id, OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info)); + + cuda_gpu_memory_infos[device_id] = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); + + // HOST_ACCESSIBLE memory should use the non-CPU device type + mem_info = nullptr; + RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("CudaPinned", OrtMemoryInfoDeviceType_GPU, + /*vendor OrtDevice::VendorIds::NVIDIA*/ 0x10DE, + /* device_id */ device_id, OrtDeviceMemoryType_HOST_ACCESSIBLE, + /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info)); + + cuda_pinned_memory_infos[device_id] = 
MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); + } + + return nullptr; +} + +OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::GetSupportedDevicesImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + auto* factory = static_cast(this_ptr); + + // Create two memory infos per device. + // The memory info is required to create allocator and gpu data transfer. + int num_cuda_devices = 0; + cudaGetDeviceCount(&num_cuda_devices); + RETURN_IF_ERROR(factory->CreateMemoryInfoForDevices(num_cuda_devices)); + + int32_t device_id = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + // C API + const OrtHardwareDevice& device = *devices[i]; + + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // These can be returned as nullptr if you have nothing to add. + OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + factory->ort_api.CreateKeyValuePairs(&ep_metadata); + factory->ort_api.CreateKeyValuePairs(&ep_options); + + // The ep options can be provided here as default values. + // Users can also call SessionOptionsAppendExecutionProvider_V2 C API with provided ep options to override. + factory->ort_api.AddKeyValuePair(ep_metadata, "gpu_type", "data center"); // random example using made up values + factory->ort_api.AddKeyValuePair(ep_options, "trt_builder_optimization_level", "3"); + + // OrtEpDevice copies ep_metadata and ep_options. + OrtEpDevice* ep_device = nullptr; + auto* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, + &ep_device); + + factory->ort_api.ReleaseKeyValuePairs(ep_metadata); + factory->ort_api.ReleaseKeyValuePairs(ep_options); + + if (status != nullptr) { + return status; + } + + const OrtMemoryInfo* cuda_gpu_mem_info = factory->cuda_gpu_memory_infos[device_id].get(); + const OrtMemoryInfo* cuda_pinned_mem_info = factory->cuda_pinned_memory_infos[device_id].get(); + + // Register the allocator info required by TRT EP. + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_gpu_mem_info)); + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, cuda_pinned_mem_info)); + + // Get memory device from memory info for gpu data transfer + factory->cuda_gpu_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_gpu_mem_info)); + factory->cuda_pinned_mem_devices.push_back(factory->ep_api.MemoryInfo_GetMemoryDevice(cuda_pinned_mem_info)); + + ep_devices[num_ep_devices++] = ep_device; + ++device_id; + } + + // C++ API equivalent. Throws on error. 
+ //{ + // Ort::ConstHardwareDevice device(devices[i]); + // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // Ort::KeyValuePairs ep_metadata; + // Ort::KeyValuePairs ep_options; + // ep_metadata.Add("version", "0.1"); + // ep_options.Add("trt_builder_optimization_level", "3"); + // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; + // ep_devices[num_ep_devices++] = ep_device.release(); + // } + //} + } + + // Create gpu data transfer + auto data_transfer_impl = std::make_unique(static_cast(*factory), + factory->cuda_gpu_mem_devices, // device memory + factory->cuda_pinned_mem_devices // shared memory + ); + + factory->data_transfer_impl = std::move(data_transfer_impl); + + return nullptr; +} + +OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateEpImpl( + OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, _Out_ OrtEp** ep) noexcept { + auto* factory = static_cast(this_ptr); + *ep = nullptr; + + if (num_devices != 1) { + // we only registered for GPU and only expected to be selected for one GPU + // if you register for multiple devices (e.g. CPU, GPU and maybe NPU) you will get an entry for each device + // the EP has been selected for. + return factory->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "TensorRT EP only supports selection for one device."); + } + + // Create the execution provider + RETURN_IF_ERROR(factory->ort_api.Logger_LogMessage(logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + "Creating TensorRT EP", ORT_FILE, __LINE__, __FUNCTION__)); + + // use properties from the device and ep_metadata if needed + // const OrtHardwareDevice* device = devices[0]; + // const OrtKeyValuePairs* ep_metadata = ep_metadata[0]; + + auto trt_ep = std::make_unique(*factory, factory->ep_name_, *session_options, *logger); + + *ep = trt_ep.release(); + return nullptr; +} + +void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept { + TensorrtExecutionProvider* trt_ep = static_cast(ep); + delete trt_ep; +} + +OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + auto& factory = *static_cast(this_ptr); + + // NOTE: The factory implementation is free to return a shared OrtAllocator* instance instead of creating a new + // allocator on each call. To do this have an allocator instance as an OrtEpFactory class member and make + // ReleaseAllocatorImpl a no-op. + + // NOTE: EP should implement its own arena logic. ep_arena.cc/h is provided as a reference and we use it here for + // device memory. `allocator_options` can be used for arena configuration and there is a helper in ep_arena.h + // to convert from OrtKeyValuePairs to the same arena config settings that ORT uses. + // You are of course free to have completely different settings. 
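+  // The dispatch below keys off the memory type behind `memory_info`:
+  //   - OrtDeviceMemoryType_DEFAULT         -> CUDA device memory allocator
+  //   - OrtDeviceMemoryType_HOST_ACCESSIBLE -> CUDA pinned (host-accessible) allocator
+  // Each allocator is created once per device id and cached in the factory so later calls (and sessions) reuse it.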
+ + const OrtMemoryDevice* mem_device = factory.ep_api.MemoryInfo_GetMemoryDevice(memory_info); + uint32_t device_id = factory.ep_api.MemoryDevice_GetDeviceId(mem_device); + + if (factory.ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_DEFAULT) { + // use the one that previously created + if (factory.cuda_gpu_allocators.find(device_id) != factory.cuda_gpu_allocators.end()) { + *allocator = factory.cuda_gpu_allocators[device_id].get(); + return nullptr; + } + + // create a CUDA allocator + auto cuda_allocator = std::make_unique(memory_info, static_cast(device_id)); + + *allocator = cuda_allocator.get(); + factory.cuda_gpu_allocators[device_id] = std::move(cuda_allocator); + + } else if (factory.ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // use the one that previously created + if (factory.cuda_pinned_allocators.find(device_id) != factory.cuda_pinned_allocators.end()) { + *allocator = factory.cuda_pinned_allocators[device_id].get(); + return nullptr; + } + + // create a CUDA PINNED allocator + auto cuda_pinned_allocator = std::make_unique(memory_info); + + *allocator = cuda_pinned_allocator.get(); + factory.cuda_pinned_allocators[device_id] = std::move(cuda_pinned_allocator); + + } else { + return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "INTERNAL ERROR! Unknown memory info provided to CreateAllocator. " + "Value did not come directly from an OrtEpDevice returned by this factory."); + } + + return nullptr; +} + +void ORT_API_CALL TensorrtExecutionProviderFactory::ReleaseAllocatorImpl(OrtEpFactory* /*this*/, + OrtAllocator* allocator) noexcept { + // no-op. The allocators will be shared across sessions. + // delete static_cast(allocator); +} + +OrtStatus* ORT_API_CALL TensorrtExecutionProviderFactory::CreateDataTransferImpl( + OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept { + auto& factory = *static_cast(this_ptr); + *data_transfer = factory.data_transfer_impl.get(); + + return nullptr; +} + +bool ORT_API_CALL TensorrtExecutionProviderFactory::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return true; +} + +} // namespace trt_ep + +// To make symbols visible on macOS/iOS +#ifdef __APPLE__ +#define EXPORT_SYMBOL __attribute__((visibility("default"))) +#else +#define EXPORT_SYMBOL +#endif + +extern "C" { +// +// Public symbols +// +EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const OrtApiBase* ort_api_base, + const OrtLogger* default_logger, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + const OrtEpApi* ort_ep_api = ort_api->GetEpApi(); + const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); + + // Factory could use registration_name or define its own EP name. + std::unique_ptr factory = std::make_unique(registration_name, *default_logger, ApiPtrs{*ort_api, *ort_ep_api, *model_editor_api}); + + if (max_factories < 1) { + return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. 
Need at least one."); + } + + factories[0] = factory.release(); + *num_factories = 1; + + return nullptr; +} + +EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete static_cast(factory); + return nullptr; +} + +} // extern "C" diff --git a/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h new file mode 100644 index 000000000..fcb0eba14 --- /dev/null +++ b/plugin_execution_providers/tensorrt/tensorrt_provider_factory.h @@ -0,0 +1,69 @@ +#pragma once + +#include "ep_utils.h" +#include "tensorrt_execution_provider_data_transfer.h" +#include "cuda_allocator.h" + +using MemoryInfoUniquePtr = std::unique_ptr>; + +namespace trt_ep { + +/// +/// Plugin TensorRT EP factory that can create an OrtEp and return information about the supported hardware devices. +/// +struct TensorrtExecutionProviderFactory : public OrtEpFactory, public ApiPtrs { + public: + TensorrtExecutionProviderFactory(const char* ep_name, const OrtLogger& default_logger, ApiPtrs apis); + + OrtStatus* CreateMemoryInfoForDevices(int num_devices); + + // CUDA gpu memory and CUDA pinned memory are required for allocator and data transfer, these are the OrtMemoryInfo + // instance required for that. + // Current TRT EP implementation uses one default OrtMemoryInfo and one host accessible OrtMemoryInfo per ep device. + std::unordered_map cuda_gpu_memory_infos; // device id -> memory info + std::unordered_map cuda_pinned_memory_infos; + + // Keeps allocators per ep device in factory so they can be shared across sessions. + std::unordered_map> cuda_gpu_allocators; // device id -> allocator + std::unordered_map> cuda_pinned_allocators; + + std::vector cuda_gpu_mem_devices; + std::vector cuda_pinned_mem_devices; + std::unique_ptr data_transfer_impl; // data transfer implementation for this factory + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; + + static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, size_t num_devices, + OrtEpDevice** ep_devices, size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept; + + static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, size_t num_devices, + const OrtSessionOptions* session_options, const OrtLogger* logger, + OrtEp** ep) noexcept; + + static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* ep) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept; + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept; + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept; + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept; + + const std::string ep_name_; // EP name + const std::string vendor_{"Nvidia"}; // EP vendor name + const std::string ep_version_{"0.1.0"}; // EP version + const OrtLogger& default_logger_; +}; +} // namespace trt_ep \ No newline at end of file diff 
--git a/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h b/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h
new file mode 100644
index 000000000..05eff33c0
--- /dev/null
+++ b/plugin_execution_providers/tensorrt/utils/cuda/cuda_call.h
@@ -0,0 +1,48 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+// -----------------------------------------------------------------------
+// Error handling
+// -----------------------------------------------------------------------
+//
+template <typename ERRTYPE>
+const char* CudaErrString(ERRTYPE) {
+  THROW();
+}
+
+template <typename ERRTYPE, bool THRW>
+std::conditional_t<THRW, void, OrtStatus*> CudaCall(
+    ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg, const char* file, const int line) {
+  if (retCode != successCode) {
+    try {
+      int currentCudaDevice = -1;
+      cudaGetDevice(&currentCudaDevice);
+      cudaGetLastError();  // clear last CUDA error
+      static char str[1024];
+      snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=? ; file=%s ; line=%d ; expr=%s; %s",
+               libName, (int)retCode, CudaErrString(retCode), currentCudaDevice,
+               // hostname,
+               file, line, exprString, msg);
+      if constexpr (THRW) {
+        // throw an exception with the error info
+        THROW(str);
+      } else {
+        return MAKE_STATUS(ORT_EP_FAIL, str);
+      }
+    } catch (const std::exception& e) {  // catch, log, and rethrow since CUDA code sometimes hangs in destruction, so we'd never get to see the error
+      if constexpr (THRW) {
+        THROW(e.what());
+      } else {
+        return MAKE_STATUS(ORT_EP_FAIL, e.what());
+      }
+    }
+  }
+  if constexpr (!THRW) {
+    return nullptr;
+  }
+}
+
+#define CUDA_CALL(expr) (CudaCall<cudaError, false>((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__))
+#define CUDA_CALL_THROW(expr) (CudaCall<cudaError, true>((expr), #expr, "CUDA", cudaSuccess, "", __FILE__, __LINE__))
diff --git a/plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h b/plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h
new file mode 100644
index 000000000..38f9d147e
--- /dev/null
+++ b/plugin_execution_providers/tensorrt/utils/cuda/cuda_common.h
@@ -0,0 +1,12 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "cuda_call.h"
+
+namespace cuda {
+
+#define CUDA_RETURN_IF_ERROR(expr) RETURN_IF_ERROR(CUDA_CALL(expr))
+
+} // namespace cuda
\ No newline at end of file
diff --git a/plugin_execution_providers/tensorrt/utils/ep_utils.h b/plugin_execution_providers/tensorrt/utils/ep_utils.h
new file mode 100644
index 000000000..90022ef57
--- /dev/null
+++ b/plugin_execution_providers/tensorrt/utils/ep_utils.h
@@ -0,0 +1,118 @@
+#pragma once
+
+#include "onnxruntime_cxx_api.h"
+
+// #include "flatbuffers/idl.h"
+// #include "ort_trt_int8_cal_table.fbs.h"
+#include "make_string.h"
+// #include "core/providers/cuda/cuda_pch.h"
+// #include "core/common/path_string.h"
+// #include "core/framework/murmurhash3.h"
+
+// #include"nv_includes.h"
+#include "gsl/narrow"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+struct ApiPtrs {
+  const OrtApi& ort_api;
+  const OrtEpApi& ep_api;
+  const OrtModelEditorApi& model_editor_api;
+};
+
+namespace trt_ep {
+
+#define ENFORCE(condition, ...)                          \
+  do {                                                   \
+    if (!(condition)) {                                  \
+      throw std::runtime_error(MakeString(__VA_ARGS__)); \
+    }                                                    \
+  } while (false)
+
+#define THROW(...)
\ + throw std::runtime_error(MakeString(__VA_ARGS__)); + +#define RETURN_IF_ORTSTATUS_ERROR(fn) RETURN_IF_ERROR(fn) + +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return _status; \ + } \ + } while (0) + +#define RETURN_IF_ORT_STATUS_ERROR(fn) \ + do { \ + auto _status = (fn); \ + if (!_status.IsOK()) { \ + return _status; \ + } \ + } while (0) + +#define RETURN_IF(cond, ...) \ + do { \ + if ((cond)) { \ + return Ort::GetApi().CreateStatus(ORT_EP_FAIL, MakeString(__VA_ARGS__).c_str()); \ + } \ + } while (0) + +#define RETURN_IF_NOT(condition, ...) RETURN_IF(!(condition), __VA_ARGS__) + +#define MAKE_STATUS(error_code, msg) \ + Ort::GetApi().CreateStatus(error_code, (msg)); + +#define THROW_IF_ERROR(expr) \ + do { \ + auto _status = (expr); \ + if (_status != nullptr) { \ + std::ostringstream oss; \ + oss << Ort::GetApi().GetErrorMessage(_status); \ + Ort::GetApi().ReleaseStatus(_status); \ + throw std::runtime_error(oss.str()); \ + } \ + } while (0) + +#define RETURN_FALSE_AND_PRINT_IF_ERROR(fn) \ + do { \ + OrtStatus* status = (fn); \ + if (status != nullptr) { \ + std::cerr << Ort::GetApi().GetErrorMessage(status) << std::endl; \ + return false; \ + } \ + } while (0) + +// Helper to release Ort one or more objects obtained from the public C API at the end of their scope. +template +struct DeferOrtRelease { + DeferOrtRelease(T** object_ptr, std::function release_func) + : objects_(object_ptr), count_(1), release_func_(release_func) {} + + DeferOrtRelease(T** objects, size_t count, std::function release_func) + : objects_(objects), count_(count), release_func_(release_func) {} + + ~DeferOrtRelease() { + if (objects_ != nullptr && count_ > 0) { + for (size_t i = 0; i < count_; ++i) { + if (objects_[i] != nullptr) { + release_func_(objects_[i]); + objects_[i] = nullptr; + } + } + } + } + T** objects_ = nullptr; + size_t count_ = 0; + std::function release_func_ = nullptr; +}; + +template +using AllocatorUniquePtr = std::unique_ptr>; + +} // namespace trt_ep \ No newline at end of file diff --git a/plugin_execution_providers/tensorrt/utils/helper.cc b/plugin_execution_providers/tensorrt/utils/helper.cc new file mode 100644 index 000000000..8bf59d2f5 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/helper.cc @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
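+//
+// String conversion helpers shared by the plugin TensorRT EP utilities. On Windows they convert between
+// UTF-8 std::string and UTF-16 std::wstring, for example:
+//
+//   std::wstring wide = ToWideString("trt_engine_cache");   // UTF-8 -> UTF-16
+//   std::string utf8  = ToUTF8String(L"trt_engine_cache");  // UTF-16 -> UTF-8
+//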
+ +#ifdef _WIN32 +#include +#include +#endif + +#ifdef ORT_NO_EXCEPTIONS +#if defined(__ANDROID__) +#include +#else +#include +#endif +#endif + +#include +#include "ep_utils.h" + +#ifdef _WIN32 +std::string ToUTF8String(std::wstring_view s) { + if (s.size() >= static_cast(std::numeric_limits::max())) + THROW("length overflow"); + + const int src_len = static_cast(s.size() + 1); + const int len = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, nullptr, 0, nullptr, nullptr); + assert(len > 0); + std::string ret(static_cast(len) - 1, '\0'); +#pragma warning(disable : 4189) + const int r = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, (char*)ret.data(), len, nullptr, nullptr); + assert(len == r); +#pragma warning(default : 4189) + return ret; +} + +std::wstring ToWideString(std::string_view s) { + if (s.size() >= static_cast(std::numeric_limits::max())) + THROW("length overflow"); + + const int src_len = static_cast(s.size() + 1); + const int len = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, nullptr, 0); + assert(len > 0); + std::wstring ret(static_cast(len) - 1, '\0'); +#pragma warning(disable : 4189) + const int r = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, (wchar_t*)ret.data(), len); + assert(len == r); +#pragma warning(default : 4189) + return ret; +} +#endif // #ifdef _WIN32 + +#ifdef NO_EXCEPTIONS +void PrintFinalMessage(const char* msg) { +#if defined(__ANDROID__) + __android_log_print(ANDROID_LOG_ERROR, "onnxruntime", "%s", msg); +#else + // TODO, consider changing the output of the error message from std::cerr to logging when the + // exceptions are disabled, since using std::cerr might increase binary size, and std::cerr output + // might not be easily accessible on some systems such as mobile + // TODO, see if we need to change the output of the error message from std::cerr to NSLog for iOS + std::cerr << msg << std::endl; +#endif +} +#endif // #ifdef NO_EXCEPTIONS diff --git a/plugin_execution_providers/tensorrt/utils/make_string.h b/plugin_execution_providers/tensorrt/utils/make_string.h new file mode 100644 index 000000000..a21be30ba --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/make_string.h @@ -0,0 +1,122 @@ +/** + * Copyright (c) 2016-present, Facebook, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// Portions Copyright (c) Microsoft Corporation + +#pragma once + +#include +#include +#include + +namespace detail { + +inline void MakeStringImpl(std::ostringstream& /*ss*/) noexcept { +} + +template +inline void MakeStringImpl(std::ostringstream& ss, const T& t) noexcept { + ss << t; +} + +template +inline void MakeStringImpl(std::ostringstream& ss, const T& t, const Args&... args) noexcept { + MakeStringImpl(ss, t); + MakeStringImpl(ss, args...); +} + +// see MakeString comments for explanation of why this is necessary +template +inline std::string MakeStringImpl(const Args&... 
args) noexcept { + std::ostringstream ss; + MakeStringImpl(ss, args...); + return ss.str(); +} + +// +// Infrastructure to convert char[n] to char* to reduce binary size +// + +// default is to leave the type as is +template +struct if_char_array_make_ptr { + using type = T; +}; + +// specialization that matches an array reference, which is what the char array from a string literal +// used in a call to MakeString will be. +// if the type is a char[n] array we 'decay' it to a char* so that the usages can be folded. +template +struct if_char_array_make_ptr { + // remove a single extent (T[x] -> T, but T[x][y] -> T[y]) so we only match char[x], + // and get the type name without the 'const' so both 'const char (&)[n]' and 'char (&)[n]' are matched. + using element_type = typename std::remove_const::type>::type; + using type = typename std::conditional::value, T*, T (&)[N]>::type; +}; + +// helper to make usage simpler in MakeString +template +using if_char_array_make_ptr_t = typename if_char_array_make_ptr::type; +} // namespace detail + +/** + * Makes a string by concatenating string representations of the arguments. + * This version uses the current locale. + */ +template +std::string MakeString(const Args&... args) { + // We need to update the types from the MakeString template instantiation to decay any char[n] to char*. + // e.g. MakeString("in", "out") goes from MakeString to MakeStringImpl + // so that MakeString("out", "in") will also match MakeStringImpl instead of requiring + // MakeStringImpl. + // + // We have to do the type processing before any actual work, so this function purely implements the type processing. + // If we do not do it this way we do not get the full binary size reduction. + // + // See https://stackoverflow.com/a/29418212/684911 for overall details of the approach, but note it does not cover + // the need to do the type processing as a separate step. + + return detail::MakeStringImpl(detail::if_char_array_make_ptr_t(args)...); +} + +/** + * Makes a string by concatenating string representations of the arguments. + * This version uses std::locale::classic(). + */ +template +std::string MakeStringWithClassicLocale(const Args&... args) { + std::ostringstream ss; + ss.imbue(std::locale::classic()); + detail::MakeStringImpl(ss, args...); + return ss.str(); +} + +// MakeString versions for already-a-string types. + +inline std::string MakeString(const std::string& str) { + return str; +} + +inline std::string MakeString(const char* cstr) { + return cstr; +} + +inline std::string MakeStringWithClassicLocale(const std::string& str) { + return str; +} + +inline std::string MakeStringWithClassicLocale(const char* cstr) { + return cstr; +} diff --git a/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h b/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h new file mode 100644 index 000000000..6f07c67ad --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/ort_graph_to_proto.h @@ -0,0 +1,868 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// DO NOT include ORT header files as this is meant to be a header-only utility that can be copied +// to other projects. + +/* + SUMMARY: + Utilities to serialize an OrtGraph into an ONNX GraphProto or ModelProto. Can be used by execution provider + implementations that need to convert an OrtGraph instance into an ONNX protobuf model. + + Users may copy this file and modify as needed. 
+ + USAGE: + This is a header-only implementation that includes both the function declarations and definitions. Copy this file + into a project that links with both ONNX Runtime and ONNX. + + Define the ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL preprocessor macro before the #include statement in exactly one C++ + file to define the implementation. Example: + + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + Other compilation units that depend on these utilities should include this file without defining the + preprocessor macro. + + Example program snippets are shown below. Refer to the function declarations for detailed usage information. + + EXAMPLE SNIPPET (initializers stored within TensorProto): + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + onnx::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto); + + // graph_proto stores initializers internally + } + ``` + + EXAMPLE SNIPPET (large initializers stored in external file): + + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + std::string external_file_path = "weights.bin"; + std::ofstream out_file(external_file_path, std::ios::binary); + + auto handle_initializer_data = [&external_file_path, &out_file](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, consumers, etc. + (void)value_info; + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = out_file.tellp(); + location = external_file_path; + out_file.write(static_cast(data), bytes); + out_file.flush(); + is_external = true; // True if is external initializer + return Ort::Status{nullptr}; + } + + ONNX_NAMESPACE::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); + + // graph_proto stores large initializers in an external file + } + ``` + + EXAMPLE SNIPPET (external initializers that point to data in memory, not officially supported by ONNX spec): + + This example stores initializers externally. However, instead of storing the initializers in a separate + file, the onnx::TensorProto objects point directly to memory addresses. This requires setting the initializer's + location to a special tag like "_MEM_ADDR_" (instead of a file path). The offset is set to the pointer to the + initializer's data in memory (instead of an offset into a file). + + Because this is not standard ONNX, such a onnx::GraphProto should not be saved as an ONNX file. + However, it allows custom tools that operate directly on a onnx::GraphProto to get the initializer data + if it has already been loaded into memory. 
+ + ```C++ + #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + #include "ort_graph_to_proto.h" + + OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* ort_graph, + OrtEpGraphSupportInfo* graph_support_info) { + auto handle_initializer_data = [](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + (void)value_info; + (void)bytes; + + offset = reinterpret_cast(data); + location = "_MEM_ADDR_"; // Some special location tag that indicates the offset is a pointer. + is_external = true; // True if is external initializer + return Ort::Status{nullptr}; + } + + ONNX_NAMESPACE::GraphProto graph_proto; + OrtEpUtils::OrtGraphToProto(*ort_graph, graph_proto, handle_initializer_data); + + // graph_proto has initializers that look like they are stored in an external file, + // but they are actually pointing to the data in memory. + } + ``` +*/ + +#ifndef INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ +#define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ + +#include +#include "onnxruntime_cxx_api.h" +#include "onnx/onnx_pb.h" + +namespace OrtEpUtils { + +/// +/// Signature of user-provided function to handle initializer data. Called by OrtGraphToProto() for every initializer. +/// +/// If the function sets the `is_external` output parameter to false, OrtGraphToProto() stores initializer data +/// within the TensorProto as raw_data. +/// +/// Otherwise, if the function sets `is_external` to true, OrtGraphToProto() assumes that this function stores the +/// initializer data in a file. In this case, OrtGraphToProto() configures the corresponding TensorProto to point the +/// location and offset returned via the `location` and `offset` output parameters. +/// +/// It is recommended to keep small initializers with byte size <= 127 stored inline the TensorProto to ensure +/// ONNX shape inference works correctly with the serialized ONNX model. +/// +/// OrtValueInfo for the initializer. Can be used to query name, type, shape, +/// and consumer nodes. +/// Opaque pointer to the initializer data. +/// Size in bytes of the initializer data. +/// Output parameter set to true if the initializer data is stored externally. The +/// implementer is responsible for writing the initializer data to file. If set to false, +/// the initializer will be stored within the TensorProto. +/// Output parameter set to the location (e.g., file) into which the initializer is stored +/// by the implementer of this function. Ignored if `is_external` is set to false. +/// Output parameter set to the offset (e.g., file offset) into which the initializer is stored +/// by the implementer of this function. Ignored if `is_external` is set to false. +/// An Ort::Status indicating success or an error. Serialization exits if this returns an error. +using HandleInitializerDataFunc = std::function; + +/// +/// Serializes the provided OrtGraph to a onnx::GraphProto. +/// Allows the caller to provide a function that specifies whether an initializer should be stored +/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). +/// +/// OrtGraph instance to serialize. +/// Destination GraphProto into which to serialize the input OrtGraph. +/// Optional function called to allow the user to determine +/// where the initializer data is stored. +/// An Ort::Status indicating success or an error. 
+Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::GraphProto& graph_proto, + HandleInitializerDataFunc handle_initializer_data_func = nullptr); + +/// +/// Serializes the provided top-level OrtGraph to a onnx::ModelProto. +/// Allows the caller to provide a function that specifies whether an initializer should be stored +/// within a TensorProto, written to a file, or remain as an in-memory external initializer (not valid ONNX). +/// +/// OrtGraph instance to serialize. +/// Destination ModelProto into which to serialize the input OrtGraph. +/// Optional function called to allow the user to determine +/// where the initializer data is stored. +/// An Ort::Status indicating success or an error. +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::ModelProto& model_proto, + HandleInitializerDataFunc handle_initializer_data_func = nullptr); +} // namespace OrtEpUtils + +// End of header +#endif // INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ + +// +// IMPLEMENTATION BELOW +// +#ifdef ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL + +#include +#include +#include +#include +#include +#include + +#define ORT_EP_UTILS_C_RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return Ort::Status{_status}; \ + } \ + } while (0) + +#define ORT_EP_UTILS_CXX_RETURN_IF_ERROR(fn) \ + do { \ + Ort::Status _status = (fn); \ + if (!_status.IsOK()) { \ + return _status; \ + } \ + } while (0) + +#define ORT_EP_UTILS_C_RETURN_IF(cond, ort_api, msg) \ + do { \ + if ((cond)) { \ + return Ort::Status{(ort_api).CreateStatus(ORT_FAIL, (msg))}; \ + } \ + } while (0) + +namespace OrtEpUtils { + +static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, + bool get_symbolic_dims, + /*out*/ ONNXTensorElementDataType& elem_type, + /*out*/ std::vector& dims, + /*out*/ std::vector& symbolic_dims); +static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); + +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::GraphProto& graph_proto, + HandleInitializerDataFunc handle_initializer_data_func) { + const OrtApi& ort_api = Ort::GetApi(); + + // + // Set GraphProto metadata + // + const char* graph_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetName(&ort_graph, &graph_name)); + graph_proto.set_name(graph_name); + graph_proto.set_doc_string("Serialized from OrtGraph"); + + // + // Set GraphProto inputs and outputs + // + size_t num_graph_inputs = 0; + size_t num_graph_outputs = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumInputs(&ort_graph, &num_graph_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOutputs(&ort_graph, &num_graph_outputs)); + + std::vector graph_inputs(num_graph_inputs); + std::vector graph_outputs(num_graph_outputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetInputs(&ort_graph, graph_inputs.data(), graph_inputs.size())); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOutputs(&ort_graph, graph_outputs.data(), graph_outputs.size())); + + for (const OrtValueInfo* ort_value_info : graph_inputs) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_input()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); + } + + for (const OrtValueInfo* ort_value_info : graph_outputs) { + onnx::ValueInfoProto* value_info_proto = 
graph_proto.mutable_output()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*ort_value_info, *value_info_proto)); + } + + // + // Set GraphProto nodes, value_infos, and initializers. + // + + // Use std::maps to store OrtValueInfos for GraphProto.value_info and GraphProto.initializer. + // A std::map maintains its elements in a stable ordering. + std::map value_infos; // For GraphProto.value_info + std::map initializer_value_infos; // For GraphProto.initializer + + // Helper function to collect an OrtValueInfo into `value_infos` or `initializer_value_infos`. + // Optionally returns the OrtValueInfo name to the caller. + auto collect_value_info = [&ort_api, &value_infos, + &initializer_value_infos](const OrtValueInfo& ort_value_info, + /*out*/ const char** value_name_out = nullptr) -> Ort::Status { + const char* value_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); + + if (value_name_out != nullptr) { + *value_name_out = value_name; + } + + if (value_infos.count(value_name) != 0 || initializer_value_infos.count(value_name) != 0) { + return Ort::Status{nullptr}; // Already processed this OrtValueInfo. + } + + bool is_required_graph_input = false; + bool is_optional_graph_input = false; + bool is_graph_output = false; + bool is_constant_initializer = false; + bool is_from_outer_scope = false; + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsRequiredGraphInput(&ort_value_info, &is_required_graph_input)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsOptionalGraphInput(&ort_value_info, &is_optional_graph_input)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsGraphOutput(&ort_value_info, &is_graph_output)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsConstantInitializer(&ort_value_info, &is_constant_initializer)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_IsFromOuterScope(&ort_value_info, &is_from_outer_scope)); + + // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. + // Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. + // For values defined in an outer scope, just add the value info but not the initializer. + if (is_from_outer_scope) { + value_infos.emplace(value_name, &ort_value_info); + } else if (is_optional_graph_input) { + initializer_value_infos.emplace(value_name, &ort_value_info); + } else if (is_constant_initializer) { + value_infos.emplace(value_name, &ort_value_info); + initializer_value_infos.emplace(value_name, &ort_value_info); + } else if (!is_required_graph_input && !is_graph_output) { + value_infos.emplace(value_name, &ort_value_info); // This is an internal OrtValueInfo. + } + + return Ort::Status{nullptr}; + }; + + size_t num_nodes = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); + + std::vector nodes(num_nodes); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); + + // Loop through all nodes (topological order): add NodeProto instances to GraphProto and track OrtValueInfos + // that will be stored in GraphProto.value_info and GraphProto.initializer. 
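+  // For each node, the loop below copies the name/domain/op type into a NodeProto, serializes its attributes
+  // (graph-valued attributes are obtained via Node_GetSubgraphs and serialized recursively), and records the
+  // names of its inputs, implicit inputs, and outputs so the corresponding value_info and initializer entries
+  // can be emitted once the node list is complete.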
+ for (size_t i = 0; i < num_nodes; i++) { + const OrtNode* ort_node = nodes[i]; + onnx::NodeProto* node_proto = graph_proto.add_node(); + + const char* node_name = nullptr; + const char* node_domain = nullptr; + const char* node_op_type = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetName(ort_node, &node_name)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetDomain(ort_node, &node_domain)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOperatorType(ort_node, &node_op_type)); + + node_proto->set_name(node_name); + node_proto->set_domain(node_domain); + node_proto->set_op_type(node_op_type); + + size_t num_inputs = 0; + size_t num_implicit_inputs = 0; + size_t num_outputs = 0; + size_t num_attrs = 0; + size_t num_subgraphs = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumInputs(ort_node, &num_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumImplicitInputs(ort_node, &num_implicit_inputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumOutputs(ort_node, &num_outputs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumAttributes(ort_node, &num_attrs)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetNumSubgraphs(ort_node, &num_subgraphs)); + + // Handle node attributes + if (num_attrs > 0) { + std::vector ort_attrs(num_attrs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetAttributes(ort_node, ort_attrs.data(), ort_attrs.size())); + + for (const OrtOpAttr* ort_attr : ort_attrs) { + OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + + Ort::Status attr_type_status{ort_api.OpAttr_GetType(ort_attr, &attr_type)}; + if (attr_type == OrtOpAttrType::ORT_OP_ATTR_GRAPH) { + // ORT does not support reading subgraphs via ReadOpAttr(), so skip it. + // Can use Node_GetSubgraphs to get subgraphs. + continue; + } + + if (!attr_type_status.IsOK()) { + // Unsupported attribute type. + return attr_type_status; + } + + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + } + } + + // Handle node subgraphs + if (num_subgraphs > 0) { + std::vector ort_subgraphs(num_subgraphs); + std::vector subgraph_attr_names(num_subgraphs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetSubgraphs(ort_node, ort_subgraphs.data(), ort_subgraphs.size(), + subgraph_attr_names.data())); + + for (size_t subgraph_idx = 0; subgraph_idx < num_subgraphs; subgraph_idx++) { + const OrtGraph* ort_subgraph = ort_subgraphs[subgraph_idx]; + const char* subgraph_attr_name = subgraph_attr_names[subgraph_idx]; + + onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::GraphProto* subgraph_proto = attr_proto->mutable_g(); + + attr_proto->set_name(subgraph_attr_name); + attr_proto->set_type(onnx::AttributeProto_AttributeType_GRAPH); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(*ort_subgraph, *subgraph_proto)); + } + } + + // Handle node inputs + if (num_inputs > 0) { + std::vector ort_inputs(num_inputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetInputs(ort_node, ort_inputs.data(), ort_inputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_inputs) { + if (ort_value_info == nullptr) { + // missing optional input. + node_proto->add_input(""); + continue; + } + + const char* value_name = nullptr; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); + + node_proto->add_input(value_name); + } + } + + // Handle implicit inputs to this node. 
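+    // Implicit inputs are outer-scope values consumed by this node's subgraphs (e.g. If/Loop bodies).
+    // They are not added to NodeProto.input, but their value infos still need to be collected so the
+    // enclosing GraphProto declares them (as value_info or initializer entries).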
+ if (num_implicit_inputs > 0) { + std::vector ort_implicit_inputs(num_implicit_inputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetImplicitInputs(ort_node, ort_implicit_inputs.data(), + ort_implicit_inputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_implicit_inputs) { + assert(ort_value_info != nullptr); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, /*value_name_out*/ nullptr)); + } + } + + // Handle node outputs + if (num_outputs > 0) { + std::vector ort_outputs(num_outputs); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetOutputs(ort_node, ort_outputs.data(), ort_outputs.size())); + + for (const OrtValueInfo* ort_value_info : ort_outputs) { + if (ort_value_info == nullptr) { + // missing optional output. + node_proto->add_output(""); + continue; + } + + const char* value_name = nullptr; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(collect_value_info(*ort_value_info, &value_name)); + + node_proto->add_output(value_name); + } + } + } + + // Add value_infos to GraphProto as ValueInfoProto objects. + for (const std::pair& entry : value_infos) { + onnx::ValueInfoProto* value_info_proto = graph_proto.mutable_value_info()->Add(); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(*entry.second, *value_info_proto)); + } + + // Add initializers to GraphProto as TensorProto objects. + for (const std::pair& entry : initializer_value_infos) { + const OrtValueInfo* initializer_value_info = entry.second; + std::string initializer_name = std::string{entry.first}; // Need a null-terminated string. + std::vector initializer_dims; + std::vector initializer_sym_dims; + ONNXTensorElementDataType initializer_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(*initializer_value_info, /*get_sym_dims*/ false, + initializer_elem_type, initializer_dims, + initializer_sym_dims)); + + onnx::TensorProto* tensor_proto = graph_proto.add_initializer(); + tensor_proto->set_name(initializer_name); + tensor_proto->set_data_type(initializer_elem_type); + + auto* tensor_proto_dims = tensor_proto->mutable_dims(); + for (int64_t dim : initializer_dims) { + tensor_proto_dims->Add(dim); + } + + const OrtValue* ort_value = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ValueInfo_GetInitializerValue(initializer_value_info, &ort_value)); + + const void* data = nullptr; + size_t data_bytes = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorData(ort_value, &data)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorSizeInBytes(ort_value, &data_bytes)); + + std::string ext_location; + int64_t ext_offset = 0; + bool is_external = false; + + if (handle_initializer_data_func != nullptr) { + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(handle_initializer_data_func(initializer_value_info, data, data_bytes, + is_external, ext_location, ext_offset)); + } + + if (is_external) { + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_EXTERNAL); + auto* ext_data_entries = tensor_proto->mutable_external_data(); + onnx::StringStringEntryProto* location_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* offset_entry = ext_data_entries->Add(); + onnx::StringStringEntryProto* length_entry = ext_data_entries->Add(); + + location_entry->set_key("location"); + location_entry->set_value(ext_location); + offset_entry->set_key("offset"); + offset_entry->set_value(std::to_string(ext_offset)); + length_entry->set_key("length"); + length_entry->set_value(std::to_string(data_bytes)); + } else { + // User wants to store data inline 
the TensorProto's raw_data + tensor_proto->set_data_location(onnx::TensorProto_DataLocation_DEFAULT); + tensor_proto->set_raw_data(data, data_bytes); + } + } + + return Ort::Status{nullptr}; +} + +Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, + onnx::ModelProto& model_proto, + HandleInitializerDataFunc handle_initializer_data_func) { + const OrtApi& ort_api = Ort::GetApi(); + + // Check that OrtGraph is a top-level graph (no parent node). + const OrtNode* parent_node = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetParentNode(&ort_graph, &parent_node)); + ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, ort_api, "Cannot serialize nested OrtGraph into a ModelProto"); + + // Set model description. + model_proto.set_doc_string("Serialized from OrtGraph"); + model_proto.set_producer_name("ort_ep_utils::OrtGraphToProto"); + + // Set ir version. + int64_t ir_version = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOnnxIRVersion(&ort_graph, &ir_version)); + model_proto.set_ir_version(ir_version); + + // Set operator sets. + size_t num_operator_sets = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetNumOperatorSets(&ort_graph, &num_operator_sets)); + ORT_EP_UTILS_C_RETURN_IF(num_operator_sets == 0, ort_api, "OrtGraph should have at least one operator set."); + + std::vector domains(num_operator_sets, nullptr); + std::vector opset_versions(num_operator_sets); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Graph_GetOperatorSets(&ort_graph, domains.data(), opset_versions.data(), + num_operator_sets)); + + auto* operator_sets = model_proto.mutable_opset_import(); + + for (size_t i = 0; i < num_operator_sets; ++i) { + onnx::OperatorSetIdProto* operator_set = operator_sets->Add(); + operator_set->set_domain(domains[i]); + operator_set->set_version(opset_versions[i]); + } + + model_proto.clear_graph(); + onnx::GraphProto* graph_proto = model_proto.mutable_graph(); + + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtGraphToProto(ort_graph, *graph_proto, handle_initializer_data_func)); + + return Ort::Status{nullptr}; +} + +static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_info, + bool get_symbolic_dims, + /*out*/ ONNXTensorElementDataType& elem_type, + /*out*/ std::vector& dims, + /*out*/ std::vector& symbolic_dims) { + const OrtApi& ort_api = Ort::GetApi(); + + const OrtTypeInfo* ort_type_info = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoTypeInfo(&ort_value_info, &ort_type_info)); + + ONNXType ort_onnx_type = ONNX_TYPE_UNKNOWN; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetOnnxTypeFromTypeInfo(ort_type_info, &ort_onnx_type)); + ORT_EP_UTILS_C_RETURN_IF(ort_onnx_type != ONNX_TYPE_TENSOR, ort_api, "Expected OrtValueInfo to represent a Tensor"); + + const OrtTensorTypeAndShapeInfo* ort_type_shape = nullptr; + ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.CastTypeInfoToTensorInfo(ort_type_info, &ort_type_shape)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetTensorElementType(ort_type_shape, &ort_elem_type)); + + size_t num_dims = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensionsCount(ort_type_shape, &num_dims)); + + std::vector ort_dims(num_dims, 0); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetDimensions(ort_type_shape, ort_dims.data(), ort_dims.size())); + + elem_type = ort_elem_type; + dims = std::move(ort_dims); + + if (get_symbolic_dims) { + std::vector ort_dim_syms(num_dims, nullptr); + 
ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetSymbolicDimensions(ort_type_shape, ort_dim_syms.data(), + ort_dim_syms.size())); + + symbolic_dims.reserve(num_dims); + for (const char* sym_dim : ort_dim_syms) { + symbolic_dims.push_back(sym_dim); + } + } + + return Ort::Status{nullptr}; +} + +// Create an onnx::ValueInfoProto from an OrtValueInfo (name, type, shape). +static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, + onnx::ValueInfoProto& value_info_proto) { + const OrtApi& ort_api = Ort::GetApi(); + + std::vector ort_dims; + std::vector ort_dim_syms; + ONNXTensorElementDataType ort_elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + // We currently only support ONNX tensors. Support for other types (e.g., ONNX_TYPE_SEQUENCE) can be added later. + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(GetOrtValueInfoTensorTypeShape(ort_value_info, /*get_sym_dims*/ true, + ort_elem_type, ort_dims, ort_dim_syms)); + + const char* value_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.GetValueInfoName(&ort_value_info, &value_name)); + value_info_proto.set_name(value_name); + + onnx::TypeProto_Tensor* type_proto_tensor = value_info_proto.mutable_type()->mutable_tensor_type(); + type_proto_tensor->set_elem_type(ort_elem_type); + + // If there are no dimensions in the shape, do not set a TensorShapeProto. Otherwise, it always looks + // like a scalar value. + if (!ort_dims.empty()) { + onnx::TensorShapeProto* shape_proto = type_proto_tensor->mutable_shape(); + + for (size_t dim_idx = 0; dim_idx < ort_dims.size(); dim_idx++) { + onnx::TensorShapeProto_Dimension* dim_proto = shape_proto->add_dim(); + + if (ort_dims[dim_idx] >= 0) { + dim_proto->set_dim_value(ort_dims[dim_idx]); + } else { + const std::string& dim_param = ort_dim_syms[dim_idx]; + + // If dim_param is empty, leave dim_proto with neither the dim_value or dim_param set, + // which represents an unknown dimension. + if (!dim_param.empty()) { + dim_proto->set_dim_param(dim_param); + } + } + } + } + + return Ort::Status{nullptr}; +} + +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { + const OrtApi& ort_api = Ort::GetApi(); + + const char* attr_name = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetName(&ort_attr, &attr_name)); + attr_proto.set_name(attr_name); + + size_t total_attr_bytes = 0; + OrtOpAttrType attr_type = OrtOpAttrType::ORT_OP_ATTR_UNDEFINED; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetType(&ort_attr, &attr_type)); + + switch (attr_type) { + case OrtOpAttrType::ORT_OP_ATTR_INT: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_INT); + + int64_t i_val = 0; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &i_val, sizeof(i_val), &total_attr_bytes)); + attr_proto.set_i(i_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_INTS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_INTS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. 
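+      // Passing a null buffer makes the sizing call report the required byte count in total_attr_bytes;
+      // its return value is wrapped in Ort::Status only so any error object is released rather than checked.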
+ Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector i_vals(total_attr_bytes / sizeof(int64_t)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, i_vals.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* ints = attr_proto.mutable_ints(); + for (int64_t val : i_vals) { + ints->Add(val); + } + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOAT); + + float f_val = 0.0f; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, &f_val, sizeof(f_val), &total_attr_bytes)); + attr_proto.set_f(f_val); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_FLOATS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector f_vals(total_attr_bytes / sizeof(float)); + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, f_vals.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* floats = attr_proto.mutable_floats(); + for (float val : f_vals) { + floats->Add(val); + } + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRING: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRING); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::string* str = attr_proto.mutable_s(); + + str->resize(total_attr_bytes); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, + &total_attr_bytes)); + + str->resize(total_attr_bytes); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_STRINGS); + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; + std::vector chars(total_attr_bytes, '\0'); + + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, chars.data(), total_attr_bytes, + &total_attr_bytes)); + + auto* strs = attr_proto.mutable_strings(); + + // Strings are all in a single buffer, each separated with a '\0'. + // Extract each string and add it to the STRINGS attribute array. + char* at = chars.data(); + char* end = at + chars.size(); + + while (at < end) { + char* str_begin = at; + + while (*at && at < end) { + at++; + } + + strs->Add()->assign(str_begin, at - str_begin); + if (at < end) { + assert(*at == '\0'); + at++; // Skip '\0' to get to the beginning of the next string. + } + } + + break; + } + case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); + + onnx::TensorProto tensor_proto; + + // TensorProto as an attribute value doesn't require a name. 
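+      // The tensor attribute is materialized as an OrtValue below; wrapping it in Ort::Value transfers
+      // ownership so it is released at the end of this case. Its element type, shape, and raw bytes are
+      // then copied into the TensorProto.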
+ + OrtValue* ort_value = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetTensorAttributeAsOrtValue(&ort_attr, &ort_value)); + + Ort::Value tensor(ort_value); + + // Get tensor type and shape info + Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); + + // Get tensor type + ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); + + size_t element_size = 0; + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); + element_size = sizeof(float); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); + element_size = sizeof(uint8_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); + element_size = sizeof(int8_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); + element_size = sizeof(uint16_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); + element_size = sizeof(int16_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); + element_size = sizeof(int32_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); + element_size = sizeof(int64_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); + element_size = sizeof(bool); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); + element_size = sizeof(double); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); + element_size = sizeof(uint32_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); + element_size = sizeof(uint64_t); + break; + } + default: { + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } + + auto shape = type_shape_info.GetShape(); + + for (auto& dim : shape) { + tensor_proto.add_dims(dim); + } + + size_t element_count = type_shape_info.GetElementCount(); + size_t data_bytes = element_count * element_size; + const void* data = tensor.GetTensorData(); + + // Copy the Ortvalue to TensorProto as raw data + tensor_proto.set_raw_data(data, data_bytes); + + *(attr_proto.mutable_t()) = std::move(tensor_proto); + break; + } + default: { + std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } + + return Ort::Status{nullptr}; +} + +} // namespace OrtEpUtils +#endif // ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL diff --git a/plugin_execution_providers/tensorrt/utils/parse_string.h b/plugin_execution_providers/tensorrt/utils/parse_string.h new file mode 100644 index 000000000..b10d0dfc8 --- /dev/null +++ b/plugin_execution_providers/tensorrt/utils/parse_string.h @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
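+//
+// Illustrative usage (a sketch; the literal inputs below are examples, not values defined in this header):
+//
+//   size_t max_workspace = 0;
+//   if (!TryParseStringWithClassicLocale("1073741824", max_workspace)) { /* handle bad input */ }
+//   bool fp16_enable = ParseStringWithClassicLocale<bool>("true");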
+
+#pragma once
+
+#include <locale>
+#include <sstream>
+#include <string>
+#include <string_view>
+#include <type_traits>
+#include <utility>
+
+/**
+ * Tries to parse a value from an entire string.
+ */
+template <typename T>
+bool TryParseStringWithClassicLocale(std::string_view str, T& value) {
+  if constexpr (std::is_integral<T>::value && std::is_unsigned<T>::value) {
+    // if T is an unsigned integral type, reject negative values which will wrap
+    if (!str.empty() && str[0] == '-') {
+      return false;
+    }
+  }
+
+  // don't allow leading whitespace
+  if (!str.empty() && std::isspace(str[0], std::locale::classic())) {
+    return false;
+  }
+
+  std::istringstream is{std::string{str}};
+  is.imbue(std::locale::classic());
+  T parsed_value{};
+
+  const bool parse_successful =
+      is >> parsed_value &&
+      is.get() == std::istringstream::traits_type::eof();  // don't allow trailing characters
+  if (!parse_successful) {
+    return false;
+  }
+
+  value = std::move(parsed_value);
+  return true;
+}
+
+inline bool TryParseStringWithClassicLocale(std::string_view str, std::string& value) {
+  value = str;
+  return true;
+}
+
+inline bool TryParseStringWithClassicLocale(std::string_view str, bool& value) {
+  if (str == "0" || str == "False" || str == "false") {
+    value = false;
+    return true;
+  }
+
+  if (str == "1" || str == "True" || str == "true") {
+    value = true;
+    return true;
+  }
+
+  return false;
+}
+
+/**
+ * Parses a value from an entire string.
+ */
+template <typename T>
+OrtStatus* ParseStringWithClassicLocale(std::string_view s, T& value) {
+  RETURN_IF_NOT(TryParseStringWithClassicLocale(s, value), "Failed to parse value: \"", value, "\"");
+  return nullptr;
+}
+
+/**
+ * Parses a value from an entire string.
+ */
+template <typename T>
+T ParseStringWithClassicLocale(std::string_view s) {
+  T value{};
+  ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(s, value));
+  return value;
+}
diff --git a/plugin_execution_providers/tensorrt/utils/path_string.h b/plugin_execution_providers/tensorrt/utils/path_string.h
new file mode 100644
index 000000000..ec2e8fcc8
--- /dev/null
+++ b/plugin_execution_providers/tensorrt/utils/path_string.h
@@ -0,0 +1,96 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
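+//
+// Illustrative usage (a sketch): PathString is std::wstring on Windows and std::string elsewhere,
+// so narrow literals are converted before being handed to path-based APIs:
+//
+//   PathString model_path = ToPathString("model.onnx");
+//   std::string printable = PathToUTF8String(model_path);  // always UTF-8, e.g. for logging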
+
+#pragma once
+
+#include <string>
+#include <string_view>
+#include <type_traits>
+
+// for std::tolower or std::towlower
+#ifdef _WIN32
+#include <cwctype>
+#else
+#include <cctype>
+#endif
+
+#include "onnxruntime_c_api.h"
+
+// char type for filesystem paths
+using PathChar = ORTCHAR_T;
+// string type for filesystem paths
+using PathString = std::basic_string<PathChar>;
+
+inline std::string ToUTF8String(const std::string& s) { return s; }
+#ifdef _WIN32
+/**
+ * Convert a wide character string to a UTF-8 string
+ */
+std::string ToUTF8String(std::wstring_view s);
+inline std::string ToUTF8String(const wchar_t* s) {
+  return ToUTF8String(std::wstring_view{s});
+}
+inline std::string ToUTF8String(const std::wstring& s) {
+  return ToUTF8String(std::wstring_view{s});
+}
+std::wstring ToWideString(std::string_view s);
+inline std::wstring ToWideString(const char* s) {
+  return ToWideString(std::string_view{s});
+}
+inline std::wstring ToWideString(const std::string& s) {
+  return ToWideString(std::string_view{s});
+}
+inline std::wstring ToWideString(const std::wstring& s) { return s; }
+inline std::wstring ToWideString(std::wstring_view s) { return std::wstring{s}; }
+#else
+inline std::string ToWideString(const std::string& s) { return s; }
+inline std::string ToWideString(const char* s) { return s; }
+inline std::string ToWideString(std::string_view s) { return std::string{s}; }
+#endif
+
+inline PathString ToPathString(const PathString& s) {
+  return s;
+}
+
+#ifdef _WIN32
+
+static_assert(std::is_same<PathString, std::wstring>::value, "PathString is not std::wstring!");
+
+inline PathString ToPathString(std::string_view s) {
+  return ToWideString(s);
+}
+inline PathString ToPathString(const char* s) {
+  return ToWideString(s);
+}
+inline PathString ToPathString(const std::string& s) {
+  return ToWideString(s);
+}
+
+inline PathChar ToLowerPathChar(PathChar c) {
+  return std::towlower(c);
+}
+
+inline std::string PathToUTF8String(const PathString& s) {
+  return ToUTF8String(s);
+}
+
+#else
+
+static_assert(std::is_same<PathString, std::string>::value, "PathString is not std::string!");
+
+inline PathString ToPathString(const char* s) {
+  return s;
+}
+
+inline PathString ToPathString(std::string_view s) {
+  return PathString{s};
+}
+
+inline PathChar ToLowerPathChar(PathChar c) {
+  return std::tolower(c);
+}
+
+inline std::string PathToUTF8String(const PathString& s) {
+  return s;
+}
+
+#endif
diff --git a/plugin_execution_providers/tensorrt/utils/provider_options.h b/plugin_execution_providers/tensorrt/utils/provider_options.h
new file mode 100644
index 000000000..33beba2f4
--- /dev/null
+++ b/plugin_execution_providers/tensorrt/utils/provider_options.h
@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+// data types for execution provider options
+
+using ProviderOptions = std::unordered_map<std::string, std::string>;
+using ProviderOptionsVector = std::vector<ProviderOptions>;
+using ProviderOptionsMap = std::unordered_map<std::string, ProviderOptions>;
diff --git a/plugin_execution_providers/tensorrt/utils/provider_options_utils.h b/plugin_execution_providers/tensorrt/utils/provider_options_utils.h
new file mode 100644
index 000000000..4bf6b37ce
--- /dev/null
+++ b/plugin_execution_providers/tensorrt/utils/provider_options_utils.h
@@ -0,0 +1,162 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
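+//
+// Illustrative usage (a sketch; the option names and destination struct are examples only):
+//
+//   struct TrtOptions { bool fp16_enable = false; size_t max_workspace_size = 0; };
+//   TrtOptions opts;
+//   ProviderOptionsParser parser;
+//   parser.AddAssignmentToReference("trt_fp16_enable", opts.fp16_enable)
+//         .AddAssignmentToReference("trt_max_workspace_size", opts.max_workspace_size);
+//   OrtStatus* status = parser.Parse(provider_options);  // nullptr on success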
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "onnxruntime_c_api.h" +#include "ep_utils.h" +#include "parse_string.h" +#include "provider_options.h" + +template +using EnumNameMapping = std::vector>; + +/** + * Given a mapping and an enumeration value, gets the corresponding name. + */ +template +OrtStatus* EnumToName(const EnumNameMapping& mapping, TEnum value, std::string& name) { + const auto it = std::find_if( + mapping.begin(), mapping.end(), + [&value](const std::pair& entry) { + return entry.first == value; + }); + RETURN_IF( + it == mapping.end(), + "Failed to map enum value to name: ", static_cast::type>(value)); + name = it->second; + return nullptr; +} + +template +std::string EnumToName(const EnumNameMapping& mapping, TEnum value) { + std::string name; + THROW_IF_ERROR(EnumToName(mapping, value, name)); + return name; +} + +/** + * Given a mapping and a name, gets the corresponding enumeration value. + */ +template +OrtStatus* NameToEnum( + const EnumNameMapping& mapping, const std::string& name, TEnum& value) { + const auto it = std::find_if( + mapping.begin(), mapping.end(), + [&name](const std::pair& entry) { + return entry.second == name; + }); + RETURN_IF( + it == mapping.end(), + "Failed to map enum name to value: ", name); + value = it->first; + return nullptr; +} + +template +TEnum NameToEnum(const EnumNameMapping& mapping, const std::string& name) { + TEnum value; + THROW_IF_ERROR(NameToEnum(mapping, name, value)); + return value; +} + +class ProviderOptionsParser { + public: + /** + * Adds a parser for a particular provider option value. + * + * @param name The provider option name. + * @param value_parser An object that parses the option value. + * It should be callable with the following signature and return + * whether the parsing was successful: + * Status value_parser(const std::string&) + * + * @return The current ProviderOptionsParser instance. + */ + template + ProviderOptionsParser& AddValueParser( + const std::string& name, ValueParserType value_parser) { + ENFORCE( + value_parsers_.emplace(name, ValueParser{value_parser}).second, + "Provider option \"", name, "\" already has a value parser."); + return *this; + } + + /** + * Adds a parser for a particular provider option value which converts a + * value to the right type and assigns it to the given reference. + * + * IMPORTANT: This function stores a reference to the destination variable. + * The caller must ensure that the reference is valid when Parse() is called! + * + * @param name The provider option name. + * @param dest The destination variable reference. + * + * @return The current ProviderOptionsParser instance. + */ + template + ProviderOptionsParser& AddAssignmentToReference( + const std::string& name, ValueType& dest) { + return AddValueParser( + name, + [&dest](const std::string& value_str) -> OrtStatus* { + return ParseStringWithClassicLocale(value_str, dest); + }); + } + + /** + * Adds a parser for a particular provider option value which maps an + * enumeration name to a value and assigns it to the given reference. + * + * IMPORTANT: This function stores references to the mapping and destination + * variables. The caller must ensure that the references are valid when + * Parse() is called! + * + * @param name The provider option name. + * @param mapping The enumeration value to name mapping. + * @param dest The destination variable reference. + * + * @return The current ProviderOptionsParser instance. 
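+   *
+   * Illustrative usage (the mapping and member names are hypothetical):
+   *   parser.AddAssignmentToEnumReference("trt_engine_cache_mode", kCacheModeMapping, opts.cache_mode);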
+ */ + template + ProviderOptionsParser& AddAssignmentToEnumReference( + const std::string& name, const EnumNameMapping& mapping, EnumType& dest) { + return AddValueParser( + name, + [&mapping, &dest](const std::string& value_str) -> OrtStatus* { + return NameToEnum(mapping, value_str, dest); + }); + } + + /** + * Parses the given provider options. + */ + OrtStatus* Parse(const ProviderOptions& options) const { + for (const auto& option : options) { + const auto& name = option.first; + const auto& value_str = option.second; + const auto value_parser_it = value_parsers_.find(name); + RETURN_IF( + value_parser_it == value_parsers_.end(), + "Unknown provider option: \"", name, "\"."); + + const auto parse_status = value_parser_it->second(value_str); + RETURN_IF_NOT( + (parse_status == nullptr), + "Failed to parse provider option \"", name, "\": "); + //"Failed to parse provider option \"", name, "\": ", parse_status.ErrorMessage()); + } + + return nullptr; + } + + private: + using ValueParser = std::function; + std::unordered_map value_parsers_; +}; diff --git a/plugin_execution_providers/test/tensorrt/CMakeLists.txt b/plugin_execution_providers/test/tensorrt/CMakeLists.txt new file mode 100644 index 000000000..253bf91d4 --- /dev/null +++ b/plugin_execution_providers/test/tensorrt/CMakeLists.txt @@ -0,0 +1,136 @@ +# usage: +# cd build/ +# cmake -S ../ -B ./ -DCMAKE_BUILD_TYPE=Debug -DORT_HOME=/path/to/ort_package/onnxruntime-win-x64-gpu-1.23.0 -DCMAKE_CUDA_ARCHITECTURES=80 -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc -DTENSORRT_HOME=/path/to/tensorrt/TensorRT-10.3.0.26 -DCMAKE_POSITION_INDEPENDENT_CODE=ON (see the result of "nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits") +# cmake --build ./ --config Debug +cmake_minimum_required(VERSION 3.26) +project(tensorrt_ep_test VERSION 1.0) +set(CMAKE_CXX_STANDARD 17) + +# CMake config to force dynamic debug CRT or dynamic release CRT globally for all dependencies. +# This is to address the issue of: +# libprotobufd.lib(common.obj) : error LNK2038: mismatch detected for 'RuntimeLibrary': value 'MTd_StaticDebug' doesn't match value 'MDd_DynamicDebug' in unary_elementwise_ops_impl.obj +if (WIN32) + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDebugDLL" CACHE STRING "" FORCE) # /MDd + set(BUILD_SHARED_LIBS OFF) # Build protobuf as static .lib, but using dynamic runtime + endif() + + if(CMAKE_BUILD_TYPE STREQUAL "Release") + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreadedDLL" CACHE STRING "" FORCE) + set(BUILD_SHARED_LIBS OFF) # Build protobuf as static .lib, but using dynamic runtime + endif() +endif() + +add_definitions(-DONNX_NAMESPACE=onnx) +add_definitions(-DONNX_ML) +add_definitions(-DNOMINMAX) +file(GLOB tensorrt_ep_test_src "./*.cc" "./*.h") +add_executable(tensorrt_ep_test ${tensorrt_ep_test_src}) + +if (NOT ORT_HOME) + message(FATAL_ERROR "Please specify ORT_HOME, e.g. -DORT_HOME=/path/to/ort/") +endif() + +# Use release mode if not specified +if (NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE "Release") +endif() + +# Add dependencies +include(FetchContent) + +# Add protobuf +FetchContent_Declare( + protobuf + GIT_REPOSITORY https://github.com/protocolbuffers/protobuf.git + GIT_TAG v21.12 # Use a specific tag or commit +) + +if (WIN32) + # Sometimes, protobuf ignores CMAKE_MSVC_RUNTIME_LIBRARY. 
To ensure it works: + set(protobuf_MSVC_STATIC_RUNTIME OFF CACHE BOOL "" FORCE) +endif() + +set(protobuf_BUILD_TESTS OFF CACHE BOOL "" FORCE) + +FetchContent_MakeAvailable(protobuf) + +# Add ONNX +FetchContent_Declare( + onnx + GIT_REPOSITORY https://github.com/onnx/onnx.git + GIT_TAG v1.18.0 # Use a specific tag or commit +) + +FetchContent_MakeAvailable(onnx) + +# Add GSL +FetchContent_Declare( + gsl + GIT_REPOSITORY https://github.com/microsoft/GSL.git + GIT_TAG v4.0.0 # Use a specific tag or commit +) + +FetchContent_MakeAvailable(gsl) + +# Add GoogleTest +FetchContent_Declare( + googletest + URL https://github.com/google/googletest/archive/refs/heads/main.zip +) +# For Windows: prevents overriding parent project's runtime library settings +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) +FetchContent_MakeAvailable(googletest) + +set(DEPS_PATH "${CMAKE_BINARY_DIR}/_deps") + +if (WIN32) # Windows + set(ORT_LIB "${ORT_HOME}/lib/onnxruntime.lib") + + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(DEPS_LIBS ${DEPS_LIBS} + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobufd.lib" + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotocd.lib") + else() + set(DEPS_LIBS ${DEPS_LIBS} + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotobuf.lib" + "${DEPS_PATH}/protobuf-build/${CMAKE_BUILD_TYPE}/libprotoc.lib") + endif() + + set(DEPS_LIBS "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx.lib" + "${DEPS_PATH}/onnx-build/${CMAKE_BUILD_TYPE}/onnx_proto.lib") + +else() # Linux + set(ORT_LIB "${ORT_HOME}/lib/libonnxruntime.so") + + set(DEPS_LIBS "${DEPS_PATH}/onnx-build/libonnx.a" + "${DEPS_PATH}/onnx-build/libonnx_proto.a") + + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + set(DEPS_LIBS ${DEPS_LIBS} + "${DEPS_PATH}/protobuf-build/libprotobufd.a" + "${DEPS_PATH}/protobuf-build/libprotocd.a") + else() + set(DEPS_LIBS ${DEPS_LIBS} + "${DEPS_PATH}/protobuf-build/libprotobuf.a" + "${DEPS_PATH}/protobuf-build/libprotoc.a") + endif() +endif() + +MESSAGE(STATUS "Looking for following dependencies ...") +MESSAGE(STATUS "ORT lib : ${ORT_LIB}") +MESSAGE(STATUS "Deps libs: ${DEPS_LIBS}") + + +target_include_directories(tensorrt_ep_test PUBLIC "${ORT_HOME}/include" + "${DEPS_PATH}/gsl-src/include" # GSL is header-only + "${DEPS_PATH}/onnx-src" + "${DEPS_PATH}/onnx-build" + "${DEPS_PATH}/protobuf-src/src" +) + +target_link_libraries(tensorrt_ep_test PUBLIC #${DEPS_LIBS} + GTest::gtest GTest::gtest_main + protobuf::libprotobuf onnx + ${ORT_LIB} +) diff --git a/plugin_execution_providers/test/tensorrt/build_and_test_trt_ep.ps1 b/plugin_execution_providers/test/tensorrt/build_and_test_trt_ep.ps1 new file mode 100644 index 000000000..b943fac0e --- /dev/null +++ b/plugin_execution_providers/test/tensorrt/build_and_test_trt_ep.ps1 @@ -0,0 +1,65 @@ +# build_and_test_trt_ep.ps1 + +# Change to the directory where this script is located +Set-Location -Path $PSScriptRoot +Write-Host "Current directory set to: $PSScriptRoot" + +# Stop on first error +$ErrorActionPreference = "Stop" + +# Variables +$SourceDir = "../" +$BuildDir = "./" +$BuildType = "Debug" + +# ORT settings +$OrtVersion = "1.23.1" +#$OrtZipUrl = "https://github.com/microsoft/onnxruntime/releases/download/v$OrtVersion/onnxruntime-win-x64-gpu-$OrtVersion.zip" +$OrtZipUrl = "https://github.com/microsoft/onnxruntime/releases/download/v$OrtVersion/onnxruntime-win-x64-$OrtVersion.zip" +$OrtZipPath = "onnxruntime.zip" +$OrtHome = ".\ort_package" + +# Step 1: Download ONNX Runtime package +if (!(Test-Path $OrtHome)) { + Write-Host "=== Downloading ONNX 
Runtime $OrtVersion ===" + Invoke-WebRequest -Uri $OrtZipUrl -OutFile $OrtZipPath + + Write-Host "=== Extracting ONNX Runtime to $OrtHome ===" + Expand-Archive -Path $OrtZipPath -DestinationPath $OrtHome -Force + + # Clean up zip file + Remove-Item $OrtZipPath +} else { + Write-Host "ONNX Runtime directory already exists. Skipping download." +} + +# Step 2: Configure CMake +$buildDir = "build" +if (-Not (Test-Path $buildDir)) { + Write-Host "Creating build directory..." + New-Item -ItemType Directory -Path $buildDir | Out-Null +} +Set-Location -Path $buildDir + +Write-Host "=== Running CMake configure step ===" +$OrtHome = "$PSScriptRoot\ort_package\onnxruntime-win-x64-$OrtVersion" +cmake "-S" $SourceDir ` + "-B" $BuildDir ` + "-DCMAKE_BUILD_TYPE=$BuildType" ` + "-DORT_HOME=$OrtHome" + +if ($LASTEXITCODE -ne 0) { + Write-Error "CMake configuration failed!" + exit 1 +} + +# Step 3: Build +Write-Host "=== Building project ===" +cmake --build $BuildDir --config $BuildType --parallel + +if ($LASTEXITCODE -ne 0) { + Write-Error "Build failed!" + exit 1 +} + +Write-Host "=== Build completed successfully! ===" diff --git a/plugin_execution_providers/test/tensorrt/helper.cc b/plugin_execution_providers/test/tensorrt/helper.cc new file mode 100644 index 000000000..b66d8ade8 --- /dev/null +++ b/plugin_execution_providers/test/tensorrt/helper.cc @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifdef _WIN32 +#include +#include +#endif + +#include + +#ifdef ORT_NO_EXCEPTIONS +#if defined(__ANDROID__) +#include +#else +#include +#endif +#endif + +#include + +#define THROW(...) throw std::runtime_error(std::string(__VA_ARGS__)); + +#ifdef _WIN32 +std::string ToUTF8String(std::wstring_view s) { + if (s.size() >= static_cast(std::numeric_limits::max())) + THROW("length overflow"); + + const int src_len = static_cast(s.size() + 1); + const int len = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, nullptr, 0, nullptr, nullptr); + assert(len > 0); + std::string ret(static_cast(len) - 1, '\0'); +#pragma warning(disable : 4189) + const int r = WideCharToMultiByte(CP_UTF8, 0, s.data(), src_len, (char*)ret.data(), len, nullptr, nullptr); + assert(len == r); +#pragma warning(default : 4189) + return ret; +} + +std::wstring ToWideString(std::string_view s) { + if (s.size() >= static_cast(std::numeric_limits::max())) + THROW("length overflow"); + + const int src_len = static_cast(s.size() + 1); + const int len = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, nullptr, 0); + assert(len > 0); + std::wstring ret(static_cast(len) - 1, '\0'); +#pragma warning(disable : 4189) + const int r = MultiByteToWideChar(CP_UTF8, 0, s.data(), src_len, (wchar_t*)ret.data(), len); + assert(len == r); +#pragma warning(default : 4189) + return ret; +} +#endif // #ifdef _WIN32 + +#ifdef NO_EXCEPTIONS +void PrintFinalMessage(const char* msg) { +#if defined(__ANDROID__) + __android_log_print(ANDROID_LOG_ERROR, "onnxruntime", "%s", msg); +#else + // TODO, consider changing the output of the error message from std::cerr to logging when the + // exceptions are disabled, since using std::cerr might increase binary size, and std::cerr output + // might not be easily accessible on some systems such as mobile + // TODO, see if we need to change the output of the error message from std::cerr to NSLog for iOS + std::cerr << msg << std::endl; +#endif +} +#endif // #ifdef NO_EXCEPTIONS diff --git a/plugin_execution_providers/test/tensorrt/path_string.h 
b/plugin_execution_providers/test/tensorrt/path_string.h new file mode 100644 index 000000000..ec2e8fcc8 --- /dev/null +++ b/plugin_execution_providers/test/tensorrt/path_string.h @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +// for std::tolower or std::towlower +#ifdef _WIN32 +#include +#else +#include +#endif + +#include "onnxruntime_c_api.h" + +// char type for filesystem paths +using PathChar = ORTCHAR_T; +// string type for filesystem paths +using PathString = std::basic_string; + +inline std::string ToUTF8String(const std::string& s) { return s; } +#ifdef _WIN32 +/** + * Convert a wide character string to a UTF-8 string + */ +std::string ToUTF8String(std::wstring_view s); +inline std::string ToUTF8String(const wchar_t* s) { + return ToUTF8String(std::wstring_view{s}); +} +inline std::string ToUTF8String(const std::wstring& s) { + return ToUTF8String(std::wstring_view{s}); +} +std::wstring ToWideString(std::string_view s); +inline std::wstring ToWideString(const char* s) { + return ToWideString(std::string_view{s}); +} +inline std::wstring ToWideString(const std::string& s) { + return ToWideString(std::string_view{s}); +} +inline std::wstring ToWideString(const std::wstring& s) { return s; } +inline std::wstring ToWideString(std::wstring_view s) { return std::wstring{s}; } +#else +inline std::string ToWideString(const std::string& s) { return s; } +inline std::string ToWideString(const char* s) { return s; } +inline std::string ToWideString(std::string_view s) { return std::string{s}; } +#endif + +inline PathString ToPathString(const PathString& s) { + return s; +} + +#ifdef _WIN32 + +static_assert(std::is_same::value, "PathString is not std::wstring!"); + +inline PathString ToPathString(std::string_view s) { + return ToWideString(s); +} +inline PathString ToPathString(const char* s) { + return ToWideString(s); +} +inline PathString ToPathString(const std::string& s) { + return ToWideString(s); +} + +inline PathChar ToLowerPathChar(PathChar c) { + return std::towlower(c); +} + +inline std::string PathToUTF8String(const PathString& s) { + return ToUTF8String(s); +} + +#else + +static_assert(std::is_same::value, "PathString is not std::string!"); + +inline PathString ToPathString(const char* s) { + return s; +} + +inline PathString ToPathString(std::string_view s) { + return PathString{s}; +} + +inline PathChar ToLowerPathChar(PathChar c) { + return std::tolower(c); +} + +inline std::string PathToUTF8String(const PathString& s) { + return s; +} + +#endif diff --git a/plugin_execution_providers/test/tensorrt/tensorrt_basic_test.cc b/plugin_execution_providers/test/tensorrt/tensorrt_basic_test.cc new file mode 100644 index 000000000..78b046158 --- /dev/null +++ b/plugin_execution_providers/test/tensorrt/tensorrt_basic_test.cc @@ -0,0 +1,185 @@ +#include +#include +#include + +#include "onnxruntime_cxx_api.h" +#include "test_trt_ep_utils.h" +#include "path_string.h" + +namespace test { +namespace trt_ep { + +// char type for filesystem paths +using PathChar = ORTCHAR_T; +// string type for filesystem paths +using PathString = std::basic_string; + +template +void VerifyOutptus(const std::vector& fetches, + const std::vector& expected_dims, + const std::vector& expected_values) { + ASSERT_EQ(1, fetches.size()); + const Ort::Value& actual_output = fetches[0]; + Ort::TensorTypeAndShapeInfo type_shape_info = actual_output.GetTensorTypeAndShapeInfo(); + ONNXTensorElementDataType 
element_type = type_shape_info.GetElementType();
+  auto shape = type_shape_info.GetShape();
+
+  ASSERT_EQ(element_type, ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
+  ASSERT_EQ(shape, expected_dims);
+
+  size_t element_cnt = type_shape_info.GetElementCount();
+  const T* actual_values = actual_output.GetTensorData<T>();
+
+  ASSERT_EQ(element_cnt, expected_values.size());
+
+  for (size_t i = 0; i != element_cnt; ++i) {
+    ASSERT_EQ(actual_values[i], expected_values[i]);
+  }
+}
+
+static OrtStatus* CreateOrtSession(Ort::Env& env,
+                                   PathString model_name,
+                                   std::string ep_name,
+                                   OrtSession** session) {
+  const OrtApi* ort_api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
+
+  {
+    std::vector<Ort::ConstEpDevice> ep_devices = env.GetEpDevices();
+
+    // Find the Ort::EpDevice for "TensorRTEp".
+    std::vector<Ort::ConstEpDevice> selected_ep_devices = {};
+    for (Ort::ConstEpDevice ep_device : ep_devices) {
+      // EP name should match the name assigned by the EP factory when creating the EP (i.e., in the implementation of
+      // OrtEP::CreateEp())
+      if (std::string(ep_device.EpName()) == ep_name) {
+        selected_ep_devices.push_back(ep_device);
+        break;
+      }
+    }
+
+    if (selected_ep_devices.empty()) {
+      // Did not find the EP. Report an application error.
+      std::string message = "Did not find EP: " + ep_name;
+      return ort_api->CreateStatus(ORT_EP_FAIL, message.c_str());
+    }
+
+    std::unordered_map<std::string, std::string> ep_options;  // Optional EP options.
+    Ort::SessionOptions session_options;
+    session_options.AppendExecutionProvider_V2(env, selected_ep_devices, ep_options);
+
+    Ort::Session ort_session(env, model_name.c_str(), session_options);
+    *session = ort_session.release();
+  }
+
+  return nullptr;
+}
+
+static OrtStatus* RunInference(Ort::Session& session,
+                               std::vector<Ort::Value>& outputs) {
+  // Get the default ORT allocator
+  Ort::AllocatorWithDefaultOptions allocator;
+
+  RETURN_IF_NOT(session.GetInputCount() == 3, "Expected the test model to have 3 inputs");
+
+  // Get input names
+  Ort::AllocatedStringPtr input_name_ptr =
+      session.GetInputNameAllocated(0, allocator);  // Keep the smart pointer alive to avoid a dangling pointer
+  const char* input_name = input_name_ptr.get();
+
+  Ort::AllocatedStringPtr input_name2_ptr = session.GetInputNameAllocated(1, allocator);
+  const char* input_name2 = input_name2_ptr.get();
+
+  Ort::AllocatedStringPtr input_name3_ptr = session.GetInputNameAllocated(2, allocator);
+  const char* input_name3 = input_name3_ptr.get();
+
+  // Input data.
+  std::vector<float> input_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+
+  // Input shape: (1, 3, 2)
+  std::vector<int64_t> input_shape{1, 3, 2};
+
+  // Create tensor
+  Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+
+  // Create input data as an OrtValue.
+  // Make input2 data and input3 data the same as input1 data.
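+  // Note: CreateTensor below wraps the caller-owned `input_values` buffer without copying it,
+  // so the vector must stay alive until session.Run() has completed.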
+ Ort::Value input_tensor = Ort::Value::CreateTensor(memory_info, input_values.data(), input_values.size(), + input_shape.data(), input_shape.size()); + Ort::Value input2_tensor = Ort::Value::CreateTensor(memory_info, input_values.data(), input_values.size(), + input_shape.data(), input_shape.size()); + Ort::Value input3_tensor = Ort::Value::CreateTensor(memory_info, input_values.data(), input_values.size(), + input_shape.data(), input_shape.size()); + + std::vector input_tensors; + input_tensors.reserve(3); + input_tensors.push_back(std::move(input_tensor)); + input_tensors.push_back(std::move(input2_tensor)); + input_tensors.push_back(std::move(input3_tensor)); + + // Get output name + Ort::AllocatedStringPtr output_name_ptr = + session.GetOutputNameAllocated(0, allocator); // Keep the smart pointer alive to avoid dangling pointer + const char* output_name = output_name_ptr.get(); + + // Run session + std::vector input_names{input_name, input_name2, input_name3}; + std::vector output_names{output_name}; + + auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_names.data(), input_tensors.data(), + input_tensors.size(), output_names.data(), 1); + outputs = std::move(output_tensors); + + return nullptr; +} + + + +TEST(TensorrtExecutionProviderTest, CreateSessionAndRunInference) { + Ort::Env env; + std::string lib_registration_name = "TensorRTEp"; + std::string& ep_name = lib_registration_name; + PathString lib_path = ORT_TSTR("TensorRTEp.dll"); + + // Register plugin TRT EP library with ONNX Runtime. + env.RegisterExecutionProviderLibrary( + lib_registration_name.c_str(), // Registration name can be anything the application chooses. + lib_path // Path to the plugin TRT EP library. + ); + + // Unregister the library using the application-specified registration name. + // Must only unregister a library after all sessions that use the library have been released. 
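+  // gsl::finally returns an object that invokes the lambda when destroyed. It is declared before the
+  // Ort::Session below, so destructors run in reverse order and unregistration happens only after the
+  // session has been released.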
+ auto unregister_plugin_eps_at_scope_exit = + gsl::finally([&]() { env.UnregisterExecutionProviderLibrary(lib_registration_name.c_str()); }); + + + std::string model_name = "basic_model_for_test.onnx"; + std::string graph_name = "basic_model"; + std::vector dims = {1, 3, 2}; + CreateBaseModel(model_name, graph_name, dims); + + OrtSession* session = nullptr; + ASSERT_EQ(CreateOrtSession(env, ToPathString(model_name), ep_name, &session), nullptr); + ASSERT_NE(session, nullptr); + Ort::Session ort_session{session}; + + std::vector output_tensors; + ASSERT_EQ(RunInference(ort_session, output_tensors), nullptr); + + // Extract output data + float* output_data = output_tensors.front().GetTensorMutableData(); + + std::cout << "Output:" << std::endl; + for (int i = 0; i < 6; i++) { + std::cout << output_data[i] << " "; + } + std::cout << std::endl; + + std::vector expected_values = {3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}; + std::vector expected_shape{1, 3, 2}; + VerifyOutptus(output_tensors, expected_shape, expected_values); + + +} + +} // namespace trt_ep +} // namespace test \ No newline at end of file diff --git a/plugin_execution_providers/test/tensorrt/test_trt_ep_utils.cc b/plugin_execution_providers/test/tensorrt/test_trt_ep_utils.cc new file mode 100644 index 000000000..2d7c319b5 --- /dev/null +++ b/plugin_execution_providers/test/tensorrt/test_trt_ep_utils.cc @@ -0,0 +1,104 @@ +#include "onnx/onnx_pb.h" +#include +#include + +namespace test { +namespace trt_ep { + +void CreateBaseModel(const std::string& model_path, const std::string& graph_name, const std::vector& dims, + bool add_non_zero_node = false) { + using namespace onnx; + + // --- Create a ModelProto --- + ModelProto model; + model.set_ir_version(onnx::IR_VERSION); + model.set_producer_name("onnx-example"); + model.set_producer_version("1.0"); + + // (Optionally) add an opset import for the standard domain + auto* opset_import = model.add_opset_import(); + opset_import->set_domain(""); // empty string = "ai.onnx" domain + opset_import->set_version(18); // Opset version + + // --- Create a GraphProto --- + GraphProto* graph = model.mutable_graph(); + graph->set_name(graph_name); + + // --- Define a FLOAT tensor type --- + TypeProto float_tensor; + float_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + for (auto d : dims) { + float_tensor.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(d); + } + + // --- Input X --- + ValueInfoProto* X = graph->add_input(); + X->set_name("X"); + *X->mutable_type() = float_tensor; + + // --- Input Y --- + ValueInfoProto* Y = graph->add_input(); + Y->set_name("Y"); + *Y->mutable_type() = float_tensor; + + // --- Node 1: Add(X, Y) -> node_1_out_1 --- + NodeProto* node1 = graph->add_node(); + node1->set_name("node_1"); + node1->set_op_type("Add"); + node1->add_input("X"); + node1->add_input("Y"); + node1->add_output("node_1_out_1"); + + // --- Input Z --- + ValueInfoProto* Z = graph->add_input(); + Z->set_name("Z"); + *Z->mutable_type() = float_tensor; + + // --- Node 2 (and maybe Node 3) --- + if (add_non_zero_node) { + // Node 2: Add(node_1_out_1, Z) -> node_2_out_1 + NodeProto* node2 = graph->add_node(); + node2->set_name("node_2"); + node2->set_op_type("Add"); + node2->add_input("node_1_out_1"); + node2->add_input("Z"); + node2->add_output("node_2_out_1"); + + // Node 3: NonZero(node_2_out_1) -> M + NodeProto* node3 = graph->add_node(); + node3->set_name("node_3"); + node3->set_op_type("NonZero"); + node3->add_input("node_2_out_1"); + 
node3->add_output("M"); + + // Output M is int64 tensor + TypeProto int_tensor; + int_tensor.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT64); + ValueInfoProto* M = graph->add_output(); + M->set_name("M"); + *M->mutable_type() = int_tensor; + + } else { + // Node 2: Add(node_1_out_1, Z) -> M + NodeProto* node2 = graph->add_node(); + node2->set_name("node_2"); + node2->set_op_type("Add"); + node2->add_input("node_1_out_1"); + node2->add_input("Z"); + node2->add_output("M"); + + // Output M is float tensor + ValueInfoProto* M = graph->add_output(); + M->set_name("M"); + *M->mutable_type() = float_tensor; + } + + // --- Serialize to disk --- + std::ofstream out(model_path, std::ios::binary); + if (!model.SerializeToOstream(&out)) { + throw std::runtime_error("Failed to write model to " + model_path); + } + out.close(); +} +} // namespace trt_ep +} // namespace test diff --git a/plugin_execution_providers/test/tensorrt/test_trt_ep_utils.h b/plugin_execution_providers/test/tensorrt/test_trt_ep_utils.h new file mode 100644 index 000000000..b5bbc8533 --- /dev/null +++ b/plugin_execution_providers/test/tensorrt/test_trt_ep_utils.h @@ -0,0 +1,73 @@ +#include +#include + +namespace test { +namespace trt_ep { + +std::string ToUTF8String(std::wstring_view s); +std::wstring ToWideString(std::string_view s); + +#define ENFORCE(condition, ...) \ + do { \ + if (!(condition)) { \ + throw std::runtime_error(std::string(__VA_ARGS__)); \ + } \ + } while (false) + +#define THROW(...) throw std::runtime_error(std::string(__VA_ARGS__)); + +#define RETURN_IF_ORTSTATUS_ERROR(fn) RETURN_IF_ERROR(fn) + +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return _status; \ + } \ + } while (0) + +#define RETURN_IF_ORT_STATUS_ERROR(fn) \ + do { \ + auto _status = (fn); \ + if (!_status.IsOK()) { \ + return _status; \ + } \ + } while (0) + +#define RETURN_IF(cond, ...) \ + do { \ + if ((cond)) { \ + return Ort::GetApi().CreateStatus(ORT_EP_FAIL, std::string(__VA_ARGS__).c_str()); \ + } \ + } while (0) + +#define RETURN_IF_NOT(condition, ...) RETURN_IF(!(condition), __VA_ARGS__) + +#define MAKE_STATUS(error_code, msg) Ort::GetApi().CreateStatus(error_code, (msg)); + +#define THROW_IF_ERROR(expr) \ + do { \ + auto _status = (expr); \ + if (_status != nullptr) { \ + std::ostringstream oss; \ + oss << Ort::GetApi().GetErrorMessage(_status); \ + Ort::GetApi().ReleaseStatus(_status); \ + throw std::runtime_error(oss.str()); \ + } \ + } while (0) + +#define RETURN_FALSE_AND_PRINT_IF_ERROR(fn) \ + do { \ + OrtStatus* status = (fn); \ + if (status != nullptr) { \ + std::cerr << Ort::GetApi().GetErrorMessage(status) << std::endl; \ + return false; \ + } \ + } while (0) + +void CreateBaseModel(const std::string& model_path, + const std::string& graph_name, + const std::vector& dims, + bool add_non_zero_node = false); +} +} \ No newline at end of file
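A minimal sketch (assuming the macros above and a live Ort::Session; the helper name is hypothetical) of how the OrtStatus-based macros in test_trt_ep_utils.h compose:

    static OrtStatus* CheckOutputCount(Ort::Session& session, size_t expected) {
      RETURN_IF_NOT(session.GetOutputCount() == expected, "Unexpected output count");
      return nullptr;  // nullptr means success, matching the OrtStatus* convention used throughout
    }

    // Inside a test body:
    //   THROW_IF_ERROR(CheckOutputCount(ort_session, 1));  // turns a failed OrtStatus into a test-visible exception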