diff --git a/README.md b/README.md index def2a7cb5f7..745713e5810 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ TensorRT-LLM [![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/) [![cuda](https://img.shields.io/badge/cuda-12.9.1-green)](https://developer.nvidia.com/cuda-downloads) [![trt](https://img.shields.io/badge/TRT-10.11.0-green)](https://developer.nvidia.com/tensorrt) -[![version](https://img.shields.io/badge/release-1.1.0rc1-green)](./tensorrt_llm/version.py) +[![version](https://img.shields.io/badge/release-1.1.0rc2-green)](./tensorrt_llm/version.py) [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE) [Architecture](./docs/source/torch/arch_overview.md)   |   [Performance](./docs/source/performance/perf-overview.md)   |   [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)   |   [Documentation](./docs/source/)   |   [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap) diff --git a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h index 394f7fb7bfa..0978905b5e2 100644 --- a/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h +++ b/cpp/include/tensorrt_llm/batch_manager/createNewDecoderRequests.h @@ -24,7 +24,6 @@ #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/modelConfig.h" -#include "tensorrt_llm/runtime/request.h" #include "tensorrt_llm/runtime/worldConfig.h" namespace tensorrt_llm::runtime @@ -88,37 +87,6 @@ class CreateNewDecoderRequests : Algorithm SizeType32 maxSequenceLength, OptionalRef medusaBuffers) const; private: - //! @brief Setups decoder internal tensors for new speculative decoding request - static void newRequestSpeculativeDecoding(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - SamplingConfig const& samplingConfig, runtime::ModelConfig const& modelConfig, - DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream, - CudaStream const& decoderStream, SpeculativeDecodingMode const& speculativeDecodingMode, - SizeType32 maxDecodingEngineTokens); - - //! @brief Setups decoder internal tensors for new request in Draft model Sps mode - static void newRequestDraftTokensExternal(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - SamplingConfig const& samplingConfig, DecodingInput& jointDecodingInput, CudaStream const& decoderStream); - - //! @brief Setups decoder internal tensors for new Medusa request - static void newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens); - - //! @brief Setups decoder internal tensors for new Lookahead request - static void newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream); - - //! @brief Setups decoder internal tensors for new Explicit draft tokens request - static void newRequestExplicitDraftTokens(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream); - - //! 
@brief Setups decoder internal tensors for new Eagle request - static void newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream); - - [[nodiscard]] std::shared_ptr retrieveDraftLogits(runtime::ModelConfig const& modelConfig, - runtime::WorldConfig const& worldConfig, std::shared_ptr const& tensor, - runtime::BufferManager const& bufferManager) const; - bool mSpeculativeDecodingFastLogits; bool mIsLeaderInOrchMode; bool mIsNormalizeLogProbs; diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index e4d13c9e17b..f069e3ac7f5 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -1110,7 +1110,7 @@ class GenericLlmRequest [[nodiscard]] SizeType32 getNumDraftTokens() const { - return mDraftTokens->size(); + return hasDraftTokens() ? mDraftTokens->size() : 0; } void discardDraftTokens(SizeType32 numTokensToDiscard) diff --git a/cpp/include/tensorrt_llm/runtime/decodingInput.h b/cpp/include/tensorrt_llm/runtime/decodingInput.h index deeb0fa0af4..4344f423ac1 100644 --- a/cpp/include/tensorrt_llm/runtime/decodingInput.h +++ b/cpp/include/tensorrt_llm/runtime/decodingInput.h @@ -102,11 +102,13 @@ class DecodingInput { public: TensorPtr draftLogits; + TensorPtr draftLogitsHost; TensorPtr draftProbs; TensorPtr targetProbs; TensorPtr numDraftTokens; TensorPtr numDraftTokensHost; TensorPtr draftTokenIds; + TensorPtr draftTokenIdsHost; TensorPtr useDraftLogits; TensorPtr useDraftLogitsHost; diff --git a/cpp/include/tensorrt_llm/runtime/request.h b/cpp/include/tensorrt_llm/runtime/request.h deleted file mode 100644 index e8f851b7d77..00000000000 --- a/cpp/include/tensorrt_llm/runtime/request.h +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "tensorrt_llm/executor/executor.h" -#include "tensorrt_llm/runtime/iTensor.h" - -#include - -namespace tensorrt_llm::runtime::decoder_batch -{ - -class Request -{ -public: - using TensorConstPtr = ITensor::SharedConstPtr; - using TensorPtr = ITensor::SharedPtr; - using BufferPtr = IBuffer::SharedPtr; - - explicit Request(SizeType32 inputLen) - : inputLen(inputLen) - { - } - - //! Mandatory parameters - SizeType32 inputLen; // Input length without draft tokens, increasing with generation steps - - // optional parameters - SizeType32 generatedTokensPerEngineStep{1}; // - - //! 
Optional parameters for speculative decoding - BufferPtr draftTokens; // [generatedTokensPerEngineStep - 1] on gpu - std::optional draftLogits; // [generatedTokensPerEngineStep - 1, vocabSize] on gpu - TensorPtr medusaPaths; // [maxDecodingTokens, maxPathLen], on gpu - TensorPtr medusaTreeIds; // [maxDecodingTokens], on gpu - std::optional lookaheadRuntimeConfig; - std::optional eagleConfig; -}; - -} // namespace tensorrt_llm::runtime::decoder_batch diff --git a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp index 16771709bb4..3335d69a015 100644 --- a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp +++ b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp @@ -20,11 +20,14 @@ #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/batch_manager/medusaBuffers.h" #include "tensorrt_llm/batch_manager/utils/logitsThread.h" +#include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/logger.h" #include "tensorrt_llm/common/nvtxUtils.h" +#include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/decoderState.h" #include "tensorrt_llm/runtime/decodingInput.h" #include "tensorrt_llm/runtime/decodingOutput.h" +#include "tensorrt_llm/runtime/iBuffer.h" #include "tensorrt_llm/runtime/runtimeKernels.h" #include "tensorrt_llm/runtime/speculativeDecodingMode.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" @@ -45,6 +48,8 @@ namespace tensorrt_llm::batch_manager using SizeType32 = CreateNewDecoderRequests::SizeType32; using TensorPtr = CreateNewDecoderRequests::TensorPtr; using SharedConstPtr = CreateNewDecoderRequests::SharedConstPtr; +template +using OptionalRef = tensorrt_llm::common::OptionalRef; namespace { @@ -320,149 +325,165 @@ void initializeOutputs(DecodingOutput& dJointOutput, SizeType32 batchSlot, SizeT TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -} // namespace - -void CreateNewDecoderRequests::newRequestSpeculativeDecoding(SizeType32 batchIdx, - runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig, - runtime::ModelConfig const& modelConfig, DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, - CudaStream const& runtimeStream, CudaStream const& decoderStream, - SpeculativeDecodingMode const& speculativeDecodingMode, SizeType32 maxDecodingEngineTokens) +void retrieveDraftLogits(TensorPtr& draftLogitsHost, std::shared_ptr const& reqDraftLogits, + ModelConfig const& modelConfig, WorldConfig const& worldConfig, bool speculativeDecodingFastLogits, + bool isLeaderInOrchMode, BufferManager const& bufferManager) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - if (speculativeDecodingMode.predictsDraftTokens()) + if (!speculativeDecodingFastLogits) { - auto const& stream = decoderStream; - BufferManager manager{std::make_shared(stream.get())}; + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); + bufferManager.copy(*reqDraftLogits, *draftLogitsHost); + return; + } - auto& dJointOutput = jointDecodingOutput; + if (isLeaderInOrchMode) + { + // reqDraftLogits contains metadata for fast-logits path; validate size. 
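+ // The buffer passed in as draft logits only carries routing metadata on this path; the actual logits are
+ // received over MPI from the draft-model rank (utils::targetModelReceiveLogits), which also reshapes
+ // draftLogitsHost to the received 2-D shape before the optional tensor-parallel broadcast below.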
+ auto constexpr fastLogitsInfoSize = sizeof(te::SpeculativeDecodingFastLogitsInfo); + TLLM_CHECK_WITH_INFO(reqDraftLogits->getSizeInBytes() >= fastLogitsInfoSize, + "Draft logits metadata buffer is too small to hold SpeculativeDecodingFastLogitsInfo."); + te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo{}; + std::memcpy(&fastLogitsInfo, reqDraftLogits->data(), fastLogitsInfoSize); + utils::targetModelReceiveLogits(draftLogitsHost, fastLogitsInfo, modelConfig.getLogitsDtype()); - TensorPtr nextDraftTokens - = ITensor::slice(dJointOutput.speculativeDecodingOutputs->nextDraftTokens, batchIdx, 1); - // FIXME: can we skip this? - manager.setZero(*nextDraftTokens); - if (speculativeDecodingMode.variableDraftLength()) + // Broadcast to other ranks if needed + if (worldConfig.isTensorParallel()) { - TensorPtr nextDraftTokensLen - = ITensor::slice(dJointOutput.speculativeDecodingOutputs->nextDraftTokensLen, batchIdx, 1); - manager.setZero(*nextDraftTokensLen); + auto const& commSession = COMM_SESSION; + auto shape = draftLogitsHost->getShape(); + commSession.bcastValue(shape.d[0], 0); + commSession.bcastValue(shape.d[1], 0); + commSession.bcast(draftLogitsHost->data(), draftLogitsHost->getSizeInBytes(), mpi::MpiType::kUINT8, 0); } } - - if (speculativeDecodingMode.isDraftTokensExternal()) - { - newRequestDraftTokensExternal(batchIdx, request, samplingConfig, jointDecodingInput, decoderStream); - } - else if (speculativeDecodingMode.isMedusa()) - { - newRequestMedusa(batchIdx, request, jointDecodingInput, decoderStream, maxDecodingEngineTokens); - } - else if (speculativeDecodingMode.isLookaheadDecoding()) - { - newRequestLookahead(batchIdx, request, jointDecodingInput, jointDecodingOutput, runtimeStream); - } - else if (speculativeDecodingMode.isExplicitDraftTokens()) - { - newRequestExplicitDraftTokens(batchIdx, request, jointDecodingOutput, runtimeStream); - } - else if (speculativeDecodingMode.isEagle()) + else { - newRequestEagle(batchIdx, request, modelConfig, jointDecodingOutput, runtimeStream); + TLLM_CHECK_WITH_INFO(worldConfig.isTensorParallel(), + "Fast logits path requires tensor-parallel broadcast for non-leader ranks."); + + // Get logits from leader rank + auto const& commSession = COMM_SESSION; + int64_t dims[2]; + commSession.bcastValue(dims[0], 0); + commSession.bcastValue(dims[1], 0); + draftLogitsHost->reshape(ITensor::makeShape({dims[0], dims[1]})); + commSession.bcast(draftLogitsHost->data(), draftLogitsHost->getSizeInBytes(), mpi::MpiType::kUINT8, 0); } + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); -} +}; -void CreateNewDecoderRequests::newRequestDraftTokensExternal(SizeType32 batchIdx, - runtime::decoder_batch::Request const& request, SamplingConfig const& samplingConfig, - DecodingInput& jointDecodingInput, CudaStream const& decoderStream) +//! 
@brief Setups decoder internal tensors for new request in Draft model Sps mode +void newRequestDraftTokensExternal(DecodingInput& jointDecodingInput, SizeType32 batchIdx, LlmRequest const& llmReq, + SizeType32 numDecodingEngineTokens, runtime::ModelConfig const& modelConfig, WorldConfig const& worldConfig, + bool speculativeDecodingFastLogits, bool isLeaderInOrchMode, CudaStream const& decoderStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - BufferManager manager{std::make_shared(decoderStream.get())}; + BufferManager decoderBufferManager{std::make_shared(decoderStream.get())}; - auto& dJointInput = jointDecodingInput; + TLLM_CHECK(jointDecodingInput.externalDraftTokensInputs); + auto& externalDraftTokensInputs = jointDecodingInput.externalDraftTokensInputs; - auto const numDraftTokens = request.generatedTokensPerEngineStep - 1; + auto const& draftTokens = llmReq.getDraftTokens(); + auto const numDraftTokens = numDecodingEngineTokens - 1; - auto const useDraftLogits = request.draftLogits.has_value(); - if (useDraftLogits) + auto numDraftTokensHostRange = runtime::BufferRange(*externalDraftTokensInputs->numDraftTokensHost); + numDraftTokensHostRange[batchIdx] = numDraftTokens; + auto numDraftTokensView = ITensor::slice(externalDraftTokensInputs->numDraftTokens, batchIdx, 1); + runtime::kernels::invokeFill(*numDraftTokensView, numDraftTokens, decoderStream); + + if (numDraftTokens > 0) { - TensorPtr draftLogitsView = ITensor::view(request.draftLogits.value()); + TensorPtr draftTokenIdsHostSlice + = ITensor::slice(externalDraftTokensInputs->draftTokenIdsHost, {batchIdx, 0}, numDraftTokens); + // Copy to pinned host memory (don't care about stream of bufferManager) + decoderBufferManager.copy(draftTokens->data(), *draftTokenIdsHostSlice); - TensorPtr draftLogitsReqBatchSlice - = ITensor::slice(dJointInput.externalDraftTokensInputs->draftLogits, batchIdx, 1); - draftLogitsReqBatchSlice->squeeze(0); - TensorPtr draftLogitsReqTokensSlice = ITensor::slice(draftLogitsReqBatchSlice, 0, numDraftTokens); - manager.copy(*draftLogitsView, *draftLogitsReqTokensSlice); + TensorPtr draftTokenIdsSlice + = ITensor::slice(externalDraftTokensInputs->draftTokenIds, {batchIdx, 0}, numDraftTokens); + decoderBufferManager.copy(*draftTokenIdsHostSlice, *draftTokenIdsSlice); } - auto* useDraftLogitsHostPtr = runtime::bufferCast(*dJointInput.externalDraftTokensInputs->useDraftLogitsHost); - useDraftLogitsHostPtr[batchIdx] = useDraftLogits; - auto useDraftLogitsView = ITensor::slice(dJointInput.externalDraftTokensInputs->useDraftLogits, batchIdx, 1); + + auto const& draftLogits = llmReq.getDraftLogits(); + auto const useDraftLogits = draftLogits.has_value(); + + auto useDraftLogitsHostRange = runtime::BufferRange(*externalDraftTokensInputs->useDraftLogitsHost); + useDraftLogitsHostRange[batchIdx] = useDraftLogits; + auto useDraftLogitsView = ITensor::slice(externalDraftTokensInputs->useDraftLogits, batchIdx, 1); runtime::kernels::invokeFill(*useDraftLogitsView, useDraftLogits, decoderStream); - if (numDraftTokens > 0) + if (useDraftLogits) { - TensorPtr draftTokensReqBatchSlice - = ITensor::slice(dJointInput.externalDraftTokensInputs->draftTokenIds, batchIdx, 1); - draftTokensReqBatchSlice->squeeze(0); - TensorPtr draftTokensReqTokensSlice = ITensor::slice(draftTokensReqBatchSlice, 0, numDraftTokens); - TensorPtr draftTokensView = ITensor::view(request.draftTokens, ITensor::makeShape({numDraftTokens})); - manager.copy(*draftTokensView, *draftTokensReqTokensSlice); + TensorPtr draftLogitsHostSlice + = 
ITensor::slice(externalDraftTokensInputs->draftLogitsHost, {batchIdx, 0}, numDraftTokens); + retrieveDraftLogits(draftLogitsHostSlice, draftLogits.value(), modelConfig, worldConfig, + speculativeDecodingFastLogits, isLeaderInOrchMode, decoderBufferManager); + + TensorPtr draftLogitsSlice + = ITensor::slice(externalDraftTokensInputs->draftLogits, {batchIdx, 0}, numDraftTokens); + decoderBufferManager.copy(*draftLogitsHostSlice, *draftLogitsSlice); } - auto* numDraftTokensHostPtr - = runtime::bufferCast(*dJointInput.externalDraftTokensInputs->numDraftTokensHost); - numDraftTokensHostPtr[batchIdx] = numDraftTokens; - auto numDraftTokensView = ITensor::slice(dJointInput.externalDraftTokensInputs->numDraftTokens, batchIdx, 1); - runtime::kernels::invokeFill(*numDraftTokensView, numDraftTokens, decoderStream); - + auto const& samplingConfig = llmReq.mSamplingConfig; bool const useRandomAcceptanceThreshold = !samplingConfig.draftAcceptanceThreshold.has_value(); float const constantThreshold = useRandomAcceptanceThreshold ? 0 : samplingConfig.draftAcceptanceThreshold.value()[0]; - dJointInput.externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold; - dJointInput.externalDraftTokensInputs->constantThreshold = constantThreshold; + externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold; + externalDraftTokensInputs->constantThreshold = constantThreshold; TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void CreateNewDecoderRequests::newRequestMedusa(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingInput& jointDecodingInput, CudaStream const& decoderStream, SizeType32 maxDecodingEngineTokens) +//! @brief Setups decoder internal tensors for new Medusa request +void newRequestMedusa(DecodingInput& jointDecodingInput, SizeType32 batchIdx, LlmRequest& llmReq, + SizeType32 numDecodingEngineTokens, SizeType32 maxDecodingEngineTokens, MedusaBuffers const& medusaBuffers, + CudaStream const& decoderStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + llmReq.mSamplingConfig.topKMedusaHeads = {medusaBuffers.mTopKs}; + // FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest? + // When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot. 
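+ // Medusa paths and tree ids are taken from the shared runtime buffers (first entry) and copied into
+ // this request's slot of the joint decoding input below.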
+ auto medusaPaths = ITensor::slice(medusaBuffers.medusaPathsDevice, 0, 1); + auto medusaTreeIds = ITensor::slice(medusaBuffers.medusaTreeIdsDevice, 0, 1); + BufferManager manager{std::make_shared(decoderStream.get())}; - auto& dJointInput = jointDecodingInput; + auto& medusaInputs = jointDecodingInput.medusaInputs; TensorPtr curTokensPerStepSlice - = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaCurTokensPerStep), batchIdx, 1); + = ITensor::slice(constPointerCast(medusaInputs->medusaCurTokensPerStep), batchIdx, 1); // Context phase Medusa processes 1 token only, new value from targetTokensPerStep will be filled at the end // of first decoder runtime::kernels::invokeFill(*curTokensPerStepSlice, 1, decoderStream); TensorPtr targetTokensPerStepSlice - = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTargetTokensPerStep), batchIdx, 1); - auto const generatedTokensPerEngineStep = request.generatedTokensPerEngineStep; - TLLM_CHECK_WITH_INFO(generatedTokensPerEngineStep <= maxDecodingEngineTokens, - "Tokens per step for (%d) is larger than maximum tokens per step (%d)", generatedTokensPerEngineStep, + = ITensor::slice(constPointerCast(medusaInputs->medusaTargetTokensPerStep), batchIdx, 1); + TLLM_CHECK_WITH_INFO(numDecodingEngineTokens <= maxDecodingEngineTokens, + "Tokens per step for (%d) is larger than maximum tokens per step (%d)", numDecodingEngineTokens, maxDecodingEngineTokens); - runtime::kernels::invokeFill(*targetTokensPerStepSlice, generatedTokensPerEngineStep, decoderStream); + runtime::kernels::invokeFill(*targetTokensPerStepSlice, numDecodingEngineTokens, decoderStream); - TensorPtr pathsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaPaths), batchIdx, 1); - manager.copy(*request.medusaPaths, *pathsSlice); + TensorPtr pathsSlice = ITensor::slice(constPointerCast(medusaInputs->medusaPaths), batchIdx, 1); + manager.copy(*medusaPaths, *pathsSlice); - TensorPtr treeIdsSlice = ITensor::slice(constPointerCast(dJointInput.medusaInputs->medusaTreeIds), batchIdx, 1); - manager.copy(*request.medusaTreeIds, *treeIdsSlice); + TensorPtr treeIdsSlice = ITensor::slice(constPointerCast(medusaInputs->medusaTreeIds), batchIdx, 1); + manager.copy(*medusaTreeIds, *treeIdsSlice); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream) +//! @brief Setups decoder internal tensors for new Lookahead request +void newRequestLookahead(DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, + CudaStream const& runtimeStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(jointDecodingOutput.lookaheadOutputs); + TLLM_CHECK(jointDecodingInput.lookaheadInputs); // The first generation step only generate 1 token. TensorPtr curTokensPerStepSlice @@ -472,65 +493,72 @@ void CreateNewDecoderRequests::newRequestLookahead(SizeType32 batchIdx, runtime: TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void CreateNewDecoderRequests::newRequestExplicitDraftTokens(SizeType32 batchIdx, - runtime::decoder_batch::Request const& request, DecodingOutput& jointDecodingOutput, - CudaStream const& runtimeStream) +//! 
@brief Setups decoder internal tensors for new Explicit draft tokens request +void newRequestExplicitDraftTokens( + DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, LlmRequest const& llmReq, CudaStream const& runtimeStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(jointDecodingOutput.explicitDraftTokensBuffers); + auto const inputLen = llmReq.getPromptLen(); + TensorPtr positionIdsBaseSlice = ITensor::slice(jointDecodingOutput.explicitDraftTokensBuffers->positionIdsBase, batchIdx, 1); - runtime::kernels::invokeFill(*positionIdsBaseSlice, request.inputLen, runtimeStream); + runtime::kernels::invokeFill(*positionIdsBaseSlice, inputLen, runtimeStream); TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } -void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::decoder_batch::Request const& request, - runtime::ModelConfig const& modelConfig, DecodingOutput& jointDecodingOutput, CudaStream const& runtimeStream) +//! @brief Setups decoder internal tensors for new Eagle request +void newRequestEagle(DecodingOutput& jointDecodingOutput, SizeType32 batchIdx, LlmRequest const& llmReq, + runtime::ModelConfig const& modelConfig, executor::DecodingConfig const& decodingConfig, + CudaStream const& runtimeStream) { TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); TLLM_CHECK(jointDecodingOutput.eagleBuffers); + auto& eagleBuffers = *jointDecodingOutput.eagleBuffers; + + auto const inputLen = llmReq.getPromptLen(); BufferManager manager{std::make_shared(runtimeStream.get())}; - TensorPtr eagleNetCtxRequestTypesHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxRequestTypesHost, batchIdx, 1); + TensorPtr eagleNetCtxRequestTypesHostSlice = ITensor::slice(eagleBuffers.eagleNetCtxRequestTypesHost, batchIdx, 1); TensorPtr eagleNetCtxContextLengthsHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxContextLengthsHost, batchIdx, 1); + = ITensor::slice(eagleBuffers.eagleNetCtxContextLengthsHost, batchIdx, 1); TensorPtr eagleNetCtxPastKeyValueLengthsHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetCtxPastKeyValueLengthsHost, batchIdx, 1); + = ITensor::slice(eagleBuffers.eagleNetCtxPastKeyValueLengthsHost, batchIdx, 1); runtime::bufferCast(*eagleNetCtxRequestTypesHostSlice)[0] = 0; - runtime::bufferCast(*eagleNetCtxContextLengthsHostSlice)[0] = request.inputLen; - runtime::bufferCast(*eagleNetCtxPastKeyValueLengthsHostSlice)[0] = request.inputLen; + runtime::bufferCast(*eagleNetCtxContextLengthsHostSlice)[0] = inputLen; + runtime::bufferCast(*eagleNetCtxPastKeyValueLengthsHostSlice)[0] = inputLen; - TensorPtr eagleNetGenRequestTypesHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenRequestTypesHost, batchIdx, 1); + TensorPtr eagleNetGenRequestTypesHostSlice = ITensor::slice(eagleBuffers.eagleNetGenRequestTypesHost, batchIdx, 1); TensorPtr eagleNetGenContextLengthsHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenContextLengthsHost, batchIdx, 1); + = ITensor::slice(eagleBuffers.eagleNetGenContextLengthsHost, batchIdx, 1); TensorPtr eagleNetGenPastKeyValueLengthsHostSlice - = ITensor::slice(jointDecodingOutput.eagleBuffers->eagleNetGenPastKeyValueLengthsHost, batchIdx, 1); + = ITensor::slice(eagleBuffers.eagleNetGenPastKeyValueLengthsHost, batchIdx, 1); runtime::bufferCast(*eagleNetGenRequestTypesHostSlice)[0] = 1; - runtime::bufferCast(*eagleNetGenContextLengthsHostSlice)[0] = request.inputLen; - 
runtime::bufferCast(*eagleNetGenPastKeyValueLengthsHostSlice)[0] = request.inputLen; + runtime::bufferCast(*eagleNetGenContextLengthsHostSlice)[0] = inputLen; + runtime::bufferCast(*eagleNetGenPastKeyValueLengthsHostSlice)[0] = inputLen; auto const eagleModule = std::dynamic_pointer_cast( modelConfig.getSpeculativeDecodingModulePtr()); std::optional eagleChoicesOpt; - if (request.eagleConfig) + auto const& eagleConfig = llmReq.getEagleConfig() ? llmReq.getEagleConfig() : decodingConfig.getEagleConfig(); + + if (eagleConfig) { - eagleChoicesOpt = request.eagleConfig->getEagleChoices(); + eagleChoicesOpt = eagleConfig->getEagleChoices(); } - if (!request.eagleConfig || !request.eagleConfig->useDynamicTree()) + if (!eagleConfig || !eagleConfig->useDynamicTree()) { - TensorPtr draftPathsHostSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPathsHost, batchIdx, 1); - TensorPtr draftPathsSlice = ITensor::slice(jointDecodingOutput.eagleBuffers->draftPaths, batchIdx, 1); + TensorPtr draftPathsHostSlice = ITensor::slice(eagleBuffers.draftPathsHost, batchIdx, 1); + TensorPtr draftPathsSlice = ITensor::slice(eagleBuffers.draftPaths, batchIdx, 1); // eagleConfig is nullptr or Eagle-1 std::vector topKs; @@ -546,6 +574,61 @@ void CreateNewDecoderRequests::newRequestEagle(SizeType32 batchIdx, runtime::dec TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); } +//! @brief Setups decoder internal tensors for new speculative decoding request +void newRequestSpeculativeDecoding(DecodingInput& jointDecodingInput, DecodingOutput& jointDecodingOutput, + SizeType32 batchIdx, LlmRequest& llmReq, SpeculativeDecodingMode const& speculativeDecodingMode, + SizeType32 numDecodingEngineTokens, SizeType32 maxDecodingEngineTokens, + OptionalRef medusaBuffers, runtime::ModelConfig const& modelConfig, + WorldConfig const& worldConfig, executor::DecodingConfig const& decodingConfig, bool speculativeDecodingFastLogits, + bool isLeaderInOrchMode, CudaStream const& runtimeStream, CudaStream const& decoderStream) +{ + TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); + + if (speculativeDecodingMode.predictsDraftTokens()) + { + BufferManager manager{std::make_shared(decoderStream.get())}; + + TLLM_CHECK(jointDecodingOutput.speculativeDecodingOutputs); + auto& speculativeDecodingOutputs = *jointDecodingOutput.speculativeDecodingOutputs; + + TensorPtr nextDraftTokens = ITensor::slice(speculativeDecodingOutputs.nextDraftTokens, batchIdx, 1); + // FIXME: can we skip this? 
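+ // Zero-initialize this slot's next draft tokens (and, for variable draft length, the draft token count)
+ // on the decoder stream as part of new-request setup.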
+ manager.setZero(*nextDraftTokens); + if (speculativeDecodingMode.variableDraftLength()) + { + TensorPtr nextDraftTokensLen = ITensor::slice(speculativeDecodingOutputs.nextDraftTokensLen, batchIdx, 1); + manager.setZero(*nextDraftTokensLen); + } + } + + if (speculativeDecodingMode.isDraftTokensExternal()) + { + newRequestDraftTokensExternal(jointDecodingInput, batchIdx, llmReq, numDecodingEngineTokens, modelConfig, + worldConfig, speculativeDecodingFastLogits, isLeaderInOrchMode, decoderStream); + } + else if (speculativeDecodingMode.isMedusa()) + { + TLLM_CHECK(medusaBuffers); + newRequestMedusa(jointDecodingInput, batchIdx, llmReq, numDecodingEngineTokens, maxDecodingEngineTokens, + medusaBuffers.value(), decoderStream); + } + else if (speculativeDecodingMode.isLookaheadDecoding()) + { + newRequestLookahead(jointDecodingInput, jointDecodingOutput, batchIdx, runtimeStream); + } + else if (speculativeDecodingMode.isExplicitDraftTokens()) + { + newRequestExplicitDraftTokens(jointDecodingOutput, batchIdx, llmReq, runtimeStream); + } + else if (speculativeDecodingMode.isEagle()) + { + newRequestEagle(jointDecodingOutput, batchIdx, llmReq, modelConfig, decodingConfig, runtimeStream); + } + TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); +} + +} // namespace + std::tuple, std::vector> CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedContextRequests, TensorPtr const& inputIds, executor::DecodingConfig const& decodingConfig, runtime::decoder::DecoderState& decoderState, @@ -563,9 +646,6 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon } inputIds->resize(decoderInputSize); - std::vector decoderRequests; - decoderRequests.reserve(finishedContextRequests.size()); - std::vector lookaheadPrompt; std::vector lookaheadAlgoConfigs; if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) @@ -597,36 +677,18 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon auto const promptLen = llmReq->getPromptLen(); - auto decoderRequest = decoder_batch::Request{promptLen}; - + SizeType32 numDecodingEngineTokens{1}; if (modelConfig.getSpeculativeDecodingMode().isDraftTokensExternal()) { - if (llmReq->hasDraftTokens()) - { - auto const& draftTokens = llmReq->getDraftTokens(); - // Copy to pinned host memory (don't care about stream of bufferManager) - decoderRequest.draftTokens = decoderBufferManager.copyFrom(*draftTokens, MemoryType::kPINNEDPOOL); - auto const& draftLogits = llmReq->getDraftLogits(); - if (draftLogits.has_value()) - { - decoderRequest.draftLogits - = retrieveDraftLogits(modelConfig, worldConfig, draftLogits.value(), decoderBufferManager); - } - decoderRequest.generatedTokensPerEngineStep = draftTokens->size() + 1; - } - else - { - decoderRequest.generatedTokensPerEngineStep = 1; - } + numDecodingEngineTokens = llmReq->getNumDraftTokens() + 1; } else if (!modelConfig.getSpeculativeDecodingMode().isNone()) { - decoderRequest.generatedTokensPerEngineStep = modelConfig.getMaxDecodingTokens(); + numDecodingEngineTokens = modelConfig.getMaxDecodingTokens(); } auto& dJointInput = decoderState.getJointDecodingInput(); - auto const numDecodingEngineTokens = decoderRequest.generatedTokensPerEngineStep; initializeInputLengths(dJointInput, batchSlot, promptLen, llmReq->mMaxNewTokens, numDecodingEngineTokens, maxSequenceLength, decoderBufferManager); decoderState.setNumDecodingEngineTokens(batchSlot, numDecodingEngineTokens); @@ -667,16 +729,7 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector 
const& finishedCon { TLLM_CHECK(beamWidth == 1); - if (modelConfig.getSpeculativeDecodingMode().isMedusa()) - { - TLLM_CHECK(medusaBuffers); - llmReq->mSamplingConfig.topKMedusaHeads = {medusaBuffers->mTopKs}; - // FIXME: we must set medusa paths and tree ids not from seq slot, but from llmRequest? - // When multiple microbatches buffers are used, runtime buffers can not be addressed with seqSlot. - decoderRequest.medusaPaths = ITensor::slice(medusaBuffers->medusaPathsDevice, 0, 1); - decoderRequest.medusaTreeIds = ITensor::slice(medusaBuffers->medusaTreeIdsDevice, 0, 1); - } - else if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) + if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding()) { lookaheadPrompt.emplace_back(requestIds); @@ -684,67 +737,17 @@ CreateNewDecoderRequests::createDecoderRequests(RequestVector const& finishedCon = llmReq->getLookaheadConfig().value_or(decodingConfig.getLookaheadDecodingConfig().value()); lookaheadAlgoConfigs.emplace_back(lookaheadRuntimeConfig); } - else if (modelConfig.getSpeculativeDecodingMode().isEagle()) - { - decoderRequest.eagleConfig - = llmReq->getEagleConfig() ? llmReq->getEagleConfig() : decodingConfig.getEagleConfig(); - } - newRequestSpeculativeDecoding(batchSlot, decoderRequest, samplingConfig, modelConfig, - decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), runtimeStream, - decoderStream, decoderState.getSpeculativeDecodingMode(), decoderState.getMaxDecodingEngineTokens()); + newRequestSpeculativeDecoding(decoderState.getJointDecodingInput(), decoderState.getJointDecodingOutput(), + batchSlot, *llmReq, decoderState.getSpeculativeDecodingMode(), numDecodingEngineTokens, + decoderState.getMaxDecodingEngineTokens(), medusaBuffers, modelConfig, worldConfig, decodingConfig, + mSpeculativeDecodingFastLogits, mIsLeaderInOrchMode, runtimeStream, decoderStream); } - decoderRequests.push_back(decoderRequest); - inputOffset += promptLen; } return {std::move(lookaheadPrompt), std::move(lookaheadAlgoConfigs)}; } -std::shared_ptr CreateNewDecoderRequests::retrieveDraftLogits(ModelConfig const& modelConfig, - WorldConfig const& worldConfig, std::shared_ptr const& tensor, - BufferManager const& bufferManager) const -{ - TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); - - if (!mSpeculativeDecodingFastLogits) - { - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return bufferManager.copyFrom(*tensor, MemoryType::kPINNEDPOOL); - } - - if (mIsLeaderInOrchMode) - { - te::SpeculativeDecodingFastLogitsInfo fastLogitsInfo; - std::memcpy(&fastLogitsInfo, tensor->data(), sizeof(fastLogitsInfo)); - auto logits = utils::targetModelReceiveLogits(fastLogitsInfo, modelConfig).value(); - - // Broadcast to other ranks if needed - if (worldConfig.isTensorParallel()) - { - auto const& commSession = COMM_SESSION; - auto shape = logits->getShape(); - commSession.bcastValue(shape.d[0], 0); - commSession.bcastValue(shape.d[1], 0); - commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 0); - } - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return logits; - } - - // Get logits from leader rank - auto const& commSession = COMM_SESSION; - int64_t dims[2]; - commSession.bcastValue(dims[0], 0); - commSession.bcastValue(dims[1], 0); - auto const logitsDtype = modelConfig.getLogitsDtype(); - auto logits = tensorrt_llm::runtime::BufferManager::pinnedPool(ITensor::makeShape({dims[0], dims[1]}), logitsDtype); - commSession.bcast(logits->data(), logits->getSizeInBytes(), mpi::MpiType::kUINT8, 
0); - - TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); - return logits; -}; - } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp index 484cd7c3c7b..7234ca9ba57 100644 --- a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp +++ b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.cpp @@ -121,8 +121,8 @@ void draftModelSendLogitsThread(int device, std::atomic* draftModelThreadS #endif // ENABLE_MULTI_DEVICE } -std::optional targetModelReceiveLogits( - executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, runtime::ModelConfig const& modelConfig) +void targetModelReceiveLogits(runtime::ITensor::SharedPtr& draftLogitsHost, + executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, nvinfer1::DataType logitsDtype) { #if ENABLE_MULTI_DEVICE auto const& worldComm = tensorrt_llm::mpi::MpiComm::world(); @@ -151,10 +151,7 @@ std::optional targetModelReceiveLogits( int64_t dims[2]; MPICHECK(MPI_Mrecv(&dims, count, MPI_INT64_T, &msg, &status)); - auto const logitsDtype = modelConfig.getLogitsDtype(); - - auto tensor = tensorrt_llm::runtime::BufferManager::pinnedPool( - runtime::ITensor::makeShape({dims[0], dims[1]}), logitsDtype); + draftLogitsHost->reshape(runtime::ITensor::makeShape({dims[0], dims[1]})); worldComm.mprobe(fastLogitsInfo.draftParticipantId, mpi::MpiTag::kSpecDecLogitsData, &msg, &status); @@ -163,11 +160,7 @@ std::optional targetModelReceiveLogits( uint64_t const expectedSize = static_cast(dims[0]) * dims[1] * tc::getDTypeSize(logitsDtype); TLLM_CHECK((uint64_t) count == expectedSize); - MPICHECK(MPI_Mrecv(tensor->data(), count, MPI_UINT8_T, &msg, &status)); - - return tensor; -#else - return std::nullopt; + MPICHECK(MPI_Mrecv(draftLogitsHost->data(), count, MPI_UINT8_T, &msg, &status)); #endif // ENABLE_MULTI_DEVICE } diff --git a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h index 6d87ebee162..f19d5f5ef30 100644 --- a/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h +++ b/cpp/tensorrt_llm/batch_manager/utils/logitsThread.h @@ -21,10 +21,8 @@ #include "tensorrt_llm/executor/executor.h" #include "tensorrt_llm/runtime/common.h" #include "tensorrt_llm/runtime/iTensor.h" -#include "tensorrt_llm/runtime/modelConfig.h" #include -#include namespace tensorrt_llm::batch_manager { @@ -52,7 +50,7 @@ void draftModelSendLogitsThread(int device, std::atomic* draftModelThreadS std::shared_ptr const& crossKvCacheManager, std::shared_ptr const& peftCacheManager); -std::optional targetModelReceiveLogits( - executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, runtime::ModelConfig const& modelConfig); +void targetModelReceiveLogits(runtime::ITensor::SharedPtr& draftLogitsHost, + executor::SpeculativeDecodingFastLogitsInfo const& fastLogitsInfo, nvinfer1::DataType logitsDtype); } // namespace tensorrt_llm::batch_manager::utils diff --git a/cpp/tensorrt_llm/runtime/decoderState.cpp b/cpp/tensorrt_llm/runtime/decoderState.cpp index abccbe60a13..b5851dc1c2d 100644 --- a/cpp/tensorrt_llm/runtime/decoderState.cpp +++ b/cpp/tensorrt_llm/runtime/decoderState.cpp @@ -131,6 +131,7 @@ void DecoderState::setupSpeculativeDecodingBuffers( mSpeculativeDecodingMode = speculativeDecodingMode; + auto constexpr nvTokenIdType = TRTDataType::value; auto constexpr nvSizeType = TRTDataType::value; auto& dInput = mJointDecodingInput; @@ -179,6 +180,7 @@ void DecoderState::setupSpeculativeDecodingBuffers( 
DecodingInput::ExternalDraftTokensInputs externalDraftTokensInputs; externalDraftTokensInputs.draftLogits = bufferManager.emptyTensor(MemoryType::kGPU, dtype); + externalDraftTokensInputs.draftLogitsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, dtype); externalDraftTokensInputs.draftProbs = bufferManager.emptyTensor(MemoryType::kGPU, dtype); externalDraftTokensInputs.targetProbs = bufferManager.emptyTensor(MemoryType::kGPU, dtype); externalDraftTokensInputs.numDraftTokens = bufferManager.emptyTensor(MemoryType::kGPU, nvSizeType); @@ -187,8 +189,8 @@ void DecoderState::setupSpeculativeDecodingBuffers( = bufferManager.emptyTensor(MemoryType::kGPU, TRTDataType::value); externalDraftTokensInputs.useDraftLogitsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType::value); - externalDraftTokensInputs.draftTokenIds - = bufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32); + externalDraftTokensInputs.draftTokenIds = bufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType); + externalDraftTokensInputs.draftTokenIdsHost = bufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvTokenIdType); dInput->externalDraftTokensInputs = externalDraftTokensInputs; } @@ -366,10 +368,16 @@ void DecoderState::reshapeSpeculativeDecodingBuffers(SpeculativeDecodingMode con {mMaxNumSequences, mMaxDecodingEngineTokens, mMaxBeamWidth, static_cast(vocabSizePadded)}); dInput.externalDraftTokensInputs->draftProbs->reshape(probsShape); dInput.externalDraftTokensInputs->targetProbs->reshape(probsShape); - dInput.externalDraftTokensInputs->draftLogits->reshape( - ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens, static_cast(vocabSizePadded)})); - dInput.externalDraftTokensInputs->draftTokenIds->reshape( - ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens})); + + auto const logitsShape = ITensor::makeShape( + {mMaxNumSequences, mMaxDecodingEngineTokens, static_cast(vocabSizePadded)}); + dInput.externalDraftTokensInputs->draftLogits->reshape(logitsShape); + dInput.externalDraftTokensInputs->draftLogitsHost->reshape(logitsShape); + + auto const tokenIdsShape = ITensor::makeShape({mMaxNumSequences, mMaxDecodingEngineTokens}); + dInput.externalDraftTokensInputs->draftTokenIds->reshape(tokenIdsShape); + dInput.externalDraftTokensInputs->draftTokenIdsHost->reshape(tokenIdsShape); + dInput.externalDraftTokensInputs->numDraftTokens->reshape(maxNumSequencesShape); dInput.externalDraftTokensInputs->numDraftTokensHost->reshape(maxNumSequencesShape); dInput.externalDraftTokensInputs->useDraftLogits->reshape(maxNumSequencesShape); diff --git a/docs/source/torch/auto_deploy/advanced/serving_with_trtllm_serve.md b/docs/source/torch/auto_deploy/advanced/serving_with_trtllm_serve.md new file mode 100644 index 00000000000..5a73d047ea4 --- /dev/null +++ b/docs/source/torch/auto_deploy/advanced/serving_with_trtllm_serve.md @@ -0,0 +1,77 @@ +# Serving with trtllm-serve + +AutoDeploy integrates with the OpenAI-compatible `trtllm-serve` CLI so you can expose AutoDeploy-optimized models over HTTP without writing server code. This page shows how to launch the server with the AutoDeploy backend, configure it via YAML, and validate with a simple request. 
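+
+If you prefer a Python client to `curl`, the same check can be scripted with the `openai` package (not shipped with TensorRT-LLM; the snippet below is a minimal sketch that assumes `pip install openai`, the default port 8000, and the server started in the Quick start section below):
+
+```python
+# Minimal OpenAI-compatible client sketch; assumes a running `trtllm-serve ... --backend _autodeploy` server.
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-used")  # placeholder key for the client
+response = client.chat.completions.create(
+    model="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": "Where is New York? Tell me in a single sentence."},
+    ],
+    max_tokens=32,
+)
+print(response.choices[0].message.content)
+```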
+ +## Quick start + +Launch `trtllm-serve` with the AutoDeploy backend by setting `--backend _autodeploy`: + +```bash +trtllm-serve \ + meta-llama/Llama-3.1-8B-Instruct \ + --backend _autodeploy +``` + +- `model`: HF name or local path +- `--backend _autodeploy`: uses AutoDeploy runtime + +Once the server is ready, test with an OpenAI-compatible request: + +```bash +curl -s http://localhost:8000/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "meta-llama/Llama-3.1-8B-Instruct", + "messages":[{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Where is New York? Tell me in a single sentence."}], + "max_tokens": 32 + }' +``` + +## Configuration via YAML + +Use `--extra_llm_api_options` to supply a YAML file that augments or overrides server/runtime settings. + +```bash +trtllm-serve \ + meta-llama/Llama-3.1-8B \ + --backend _autodeploy \ + --extra_llm_api_options autodeploy_config.yaml +``` + +Example `autodeploy_config.yaml`: + +```yaml +# Compilation backend for AutoDeploy +compile_backend: torch-opt # options: torch-simple, torch-compile, torch-cudagraph, torch-opt + +# Runtime engine +runtime: trtllm # options: trtllm, demollm + +# Model loading +skip_loading_weights: false # set true for architecture-only perf runs + +# KV cache memory +free_mem_ratio: 0.8 # fraction of free GPU mem for KV cache + +# CUDA graph optimization +cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64] + +# Attention backend +attn_backend: flashinfer # recommended for best performance +``` + +## Limitations and tips + +- KV cache block reuse is disabled automatically for AutoDeploy backend +- AutoDeploy backend doesn't yet support disaggregated serving. WIP +- For best performance: + - Prefer `compile_backend: torch-opt` + - Use `attn_backend: flashinfer` + - Set realistic `cuda_graph_batch_sizes` that match expected traffic + - Tune `free_mem_ratio` to 0.8–0.9 + +## See also + +- [AutoDeploy overview](../auto-deploy.md) +- [Benchmarking with trtllm-bench](./benchmarking_with_trtllm_bench.md) diff --git a/docs/source/torch/auto_deploy/auto-deploy.md b/docs/source/torch/auto_deploy/auto-deploy.md index fc00c0ccc3e..185e1f321ae 100644 --- a/docs/source/torch/auto_deploy/auto-deploy.md +++ b/docs/source/torch/auto_deploy/auto-deploy.md @@ -59,6 +59,7 @@ The exported graph then undergoes a series of automated transformations, includi - [Incorporating AutoDeploy into Your Own Workflow](./advanced/workflow.md) - [Expert Configurations](./advanced/expert_configurations.md) - [Performance Benchmarking](./advanced/benchmarking_with_trtllm_bench.md) +- [Serving with trtllm-serve](./advanced/serving_with_trtllm_serve.md) ## Roadmap diff --git a/examples/constraints.txt b/examples/constraints.txt index 4ce23b0de7a..8b0d1a00930 100644 --- a/examples/constraints.txt +++ b/examples/constraints.txt @@ -1,3 +1,3 @@ -tensorrt_llm==1.1.0rc1 +tensorrt_llm==1.1.0rc2 evaluate~=0.4.1 rouge_score~=0.1.2 diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py index 414039a5065..01fb0deb576 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py @@ -198,7 +198,6 @@ def prepare_flashinfer_metadata( flashinfer.get_seq_lens(paged_kv_indptr, paged_kv_last_page_len, page_size), position_ids.numel(), ) - # return metadata return ( qo_indptr, diff --git 
a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index 812dfea29cd..9811274a8bc 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -274,6 +274,16 @@ def quant_config(self, value: QuantConfig): self._quant_config = value ### VALIDATION ################################################################################# + @field_validator("max_seq_len", mode="before") + @classmethod + def ensure_max_seq_len(cls, value: Any, info: ValidationInfo) -> Any: + if value is None: + # Fallback to the AutoDeployConfig default when not provided + return AutoDeployConfig.model_fields["max_seq_len"].get_default( + call_default_factory=True + ) + return value + @field_validator("build_config", mode="before") @classmethod def ensure_no_build_config(cls, value: Any, info: ValidationInfo) -> Any: diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 8eb9acfada2..c9b9fa979fe 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -65,7 +65,7 @@ from ..modules.multi_stream_utils import maybe_execute_in_parallel from ..modules.rms_norm import RMSNorm from ..peft.lora.layer import LoraLayer -from ..speculative import MTPSpecMetadata, SpecMetadata +from ..speculative import SpecMetadata from ..utils import AuxStreamType, EventType, Fp4QuantizedTensor from .modeling_speculative import SpecDecOneEngineForCausalLM from .modeling_utils import (DecoderModel, EagerFusionConfig, filter_weights, @@ -230,7 +230,7 @@ def __init__( aux_stream: Optional[torch.cuda.Stream] = None, ): config = model_config.pretrained_config - predicted_tokens_per_seq = model_config.spec_config.num_nextn_predict_layers + 1 if model_config.spec_config is not None else 1 + predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1 super().__init__(hidden_size=config.hidden_size, num_attention_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, @@ -750,6 +750,7 @@ def forward( hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, residual: torch.Tensor, + spec_metadata: Optional[SpecMetadata] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: if residual is None: @@ -765,16 +766,24 @@ def forward( **kwargs, ) if isinstance(self.mlp, Deepseekv3MoE): + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + self.fusion_config.POST_MOE_FUSION = False return self.forward_MoE( hidden_states=hidden_states, attn_metadata=attn_metadata, residual=residual, + spec_metadata=spec_metadata, ) else: + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + self.fusion_config.POST_MLP_FUSION = False assert isinstance(self.mlp, GatedMLP) return self.forward_mlp( hidden_states=hidden_states, residual=residual, + spec_metadata=spec_metadata, ) def forward_MoE( @@ -782,6 +791,7 @@ def forward_MoE( hidden_states: torch.Tensor, attn_metadata: AttentionMetadata, residual: torch.Tensor, + spec_metadata: Optional[SpecMetadata] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: def _run_MoE(hidden_states, hidden_states_fp4, do_finalize): @@ -856,6 +866,10 @@ def _run_MoE(hidden_states, hidden_states_fp4, do_finalize): hidden_states, residual = self.moe_allreduce( fc2_output, all_reduce_params=moe_all_reduce_params) else: + if spec_metadata is not None and 
spec_metadata.is_layer_capture( + self.layer_idx): + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) if self.next_layer_layernorm is not None: hidden_states, residual = self.next_layer_layernorm( hidden_states, residual) @@ -866,6 +880,7 @@ def forward_mlp( self, hidden_states: torch.Tensor, residual: torch.Tensor, + spec_metadata: Optional[SpecMetadata] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: if self.fusion_config.PRE_MLP_FUSION: @@ -903,6 +918,10 @@ def forward_mlp( ), ) else: + if spec_metadata is not None and spec_metadata.is_layer_capture( + self.layer_idx): + spec_metadata.maybe_capture_hidden_states( + self.layer_idx, hidden_states, residual) if self.next_layer_layernorm is not None: hidden_states, residual = self.next_layer_layernorm( hidden_states, residual) @@ -1105,6 +1124,7 @@ def forward( hidden_states=hidden_states, attn_metadata=attn_metadata, residual=residual, + spec_metadata=spec_metadata, ) return hidden_states @@ -1132,7 +1152,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): model_config=model_config) self.model_nextn = 0 - if model_config.spec_config is not None: + if model_config.spec_config is not None and model_config.spec_config.spec_dec_mode.is_mtp( + ): model_nextn = model_config.spec_config.num_nextn_predict_layers ckpt_nextn = self.config.num_nextn_predict_layers self.num_hidden_layers = self.config.num_hidden_layers @@ -1167,11 +1188,10 @@ def forward( input_ids: torch.IntTensor = None, position_ids: Optional[torch.IntTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - spec_metadata: Optional[MTPSpecMetadata] = None, + spec_metadata: Optional[SpecMetadata] = None, return_context_logits: bool = False, **kwargs, ) -> torch.Tensor: - attn_metadata.num_generations_per_batch = self.model_nextn + 1 return super().forward(attn_metadata=attn_metadata, input_ids=input_ids, position_ids=position_ids, @@ -1313,7 +1333,9 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor, for name, module in tqdm(all_named_modules.items(), desc="Loading weights"): - if len(module._parameters) > 0: + if len(module._parameters) <= 0 or name.startswith("draft_model"): + continue + else: names = name.split('.') parent_module_name = '.'.join(names[:-1]) if "model.layers" in name and int( diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_h.py b/tensorrt_llm/_torch/models/modeling_nemotron_h.py index 41f870f890d..e548d09a084 100644 --- a/tensorrt_llm/_torch/models/modeling_nemotron_h.py +++ b/tensorrt_llm/_torch/models/modeling_nemotron_h.py @@ -221,7 +221,9 @@ def forward( ) if self.mamba_metadata is None or self.mamba_metadata.max_batch_size != attn_metadata.max_num_requests: - self.mamba_metadata = Mamba2Metadata(attn_metadata.max_num_requests) + self.mamba_metadata = Mamba2Metadata( + attn_metadata.max_num_requests, + chunk_size=self.model_config.pretrained_config.chunk_size) self.mamba_metadata.prepare(attn_metadata) if inputs_embeds is None: diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index f82c3b4de06..56a489c9635 100644 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -155,10 +155,12 @@ def __init__( else: self.hidden_size_in = config.hidden_size - self.fc = Linear(self.hidden_size_in * 3, - config.hidden_size, - bias=getattr(config, "bias", False), - dtype=config.torch_dtype) + if self.spec_config.num_capture_layers > 1: + self.fc = Linear(self.hidden_size_in * 
+ self.spec_config.num_capture_layers, + config.hidden_size, + bias=getattr(config, "bias", False), + dtype=config.torch_dtype) self.midlayer = Eagle3DecoderLayer(model_config, start_layer_idx) diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py b/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py index 445c288e6ff..d421cc9209d 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py @@ -13,15 +13,83 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math +from typing import Tuple + import torch from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata +def cu_seqlens_to_chunk_indices_offsets( + cu_seqlens: torch.Tensor, + chunk_size: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + cu_seqlens (torch.Tensor): 1D tensor of cumulative sequence lengths, shape (num_seqs + 1,). The first element should be 0. Each entry represents the starting index of a sequence in the flattened token array. + chunk_size (int): The size of each physical mamba chunk (number of tokens per chunk). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: A tuple containing: + - chunk_indices (torch.Tensor): 1D tensor of indices indicating the physical chunk for each logical chunk. + - chunk_offsets (torch.Tensor): 1D tensor of offsets indicating the starting index of each logical chunk within its physical chunk. + + This function computes the chunk indices and offsets for the given cu_seqlens and chunk_size. + Both are tensors of integers with length N, where N is the number of logical (pseudo) chunks. + A logical chunk is a sequence of tokens that are all part of the same sequence and are all in the same physical mamba chunk. + In other words, a logical chunk changes every time we cross a sequence boundary or a physical mamba chunk boundary. + Logical chunks are needed to handle batched requests with initial states (see _state_passing_fwd and _chunk_scan_fwd). + The chunk_indices tensor contains the index of the physical chunk for each logical chunk. + The chunk_offsets tensor contains the offset (AKA starting index) of the logical chunk in the physical chunk. + + Example: + cu_seqlens = [0, 5, 10] + chunk_size = 8 + -> chunk_indices = [0, 1, 0] + -> chunk_offsets = [0, 5, 0] + + In this example, we have 2 sequences, each with 5 tokens. The physical chunk size is 8 tokens. 
+ We have three logical chunks: + - the first logical chunk starts at token 0 in the first physical chunk and contains all 5 tokens from the first sequence + - the second logical chunk starts at token 5 in the first physical chunk and contains first 3 tokens from the second sequence + - the third logical chunk starts at token 0 in the second physical chunk and contains the remaining 2 tokens from the second sequence + """ + + total_seqlens = cu_seqlens[-1] + cu_seqlens = cu_seqlens[1:] # remove prepended 0 + + # outputs will have length expansion of chunks that do not divide + # chunk_size + N = math.ceil(total_seqlens / chunk_size) + (cu_seqlens[:-1] % chunk_size + > 0).sum() + chunk_indices = torch.arange(N, dtype=torch.int, device=cu_seqlens.device) + chunk_offsets = torch.zeros((N, ), + dtype=torch.int, + device=cu_seqlens.device) + + p = 0 # num of insertions + for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:]): + + # if does not divide chunk_size, then there is one chunk insertion + p += (s % chunk_size > 0) + + # get the dimensions + # - the + 1 for _e is to shift the boundary by one chunk + # - this shifting is not needed if chunk_size divides e + _s, _e = s // chunk_size + p, e // chunk_size + p + (e % chunk_size > 0) + + # adjust inidces and offsets + chunk_indices[_s:_e] -= p + chunk_offsets[_s] = s % chunk_size + + return chunk_indices, chunk_offsets + + class Mamba2Metadata: - def __init__(self, max_batch_size: int): + def __init__(self, max_batch_size: int, chunk_size: int): self.max_batch_size = max_batch_size + self.chunk_size = chunk_size # cumulative sequence lengths for prefill requests [batch_size+1] self.cu_seqlens = torch.zeros(max_batch_size + 1, @@ -31,9 +99,18 @@ def __init__(self, max_batch_size: int): # sequence index for prefill requests [num_prefill_tokens] - specifies which request each token belongs to self.seq_idx: torch.Tensor = None + # helper tensors for chunked prefill + self.has_initial_states = torch.zeros(max_batch_size, + dtype=torch.bool, + device="cuda") + self.use_initial_states = False + self.chunk_indices: torch.Tensor = None + self.chunk_offsets: torch.Tensor = None + def prepare(self, attn_metadata: AttentionMetadata): num_contexts = attn_metadata.num_contexts context_lens = attn_metadata.seq_lens_cuda[:num_contexts] + num_ctx_tokens = attn_metadata.num_ctx_tokens if num_contexts > 0: torch.cumsum(context_lens, dim=0, @@ -44,4 +121,17 @@ def prepare(self, attn_metadata: AttentionMetadata): dtype=torch.int, device=self.cu_seqlens.device), repeats=context_lens, - output_size=self.cu_seqlens[num_contexts]).unsqueeze(0) + output_size=num_ctx_tokens).unsqueeze(0) + + num_cached_tokens_per_seq = attn_metadata.kv_cache_params.num_cached_tokens_per_seq + self.has_initial_states[:num_contexts] = torch.tensor( + num_cached_tokens_per_seq[:num_contexts]) > 0 + # precomputed bool to avoid host<->device syncs during forward pass + self.use_initial_states = torch.any( + self.has_initial_states[:num_contexts]).item() + if self.use_initial_states: + self.chunk_indices, self.chunk_offsets = cu_seqlens_to_chunk_indices_offsets( + self.cu_seqlens[:num_contexts + 1], self.chunk_size) + else: + self.chunk_indices = None + self.chunk_offsets = None diff --git a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py index 6ea096bb6a7..d5a3e3996a3 100644 --- a/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py +++ b/tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py @@ -191,12 +191,15 @@ def forward( cu_seqlens = 
mamba_metadata.cu_seqlens[:num_prefills + 1] seq_idx = mamba_metadata.seq_idx + has_initial_states = mamba_metadata.has_initial_states[: + num_prefills] xbc_p = causal_conv1d_fn(xbc_p.transpose(0, 1), self.conv1d.weight, self.conv1d.bias, activation="silu", conv_states=conv_states, + has_initial_state=has_initial_states, query_start_loc=cu_seqlens, cache_indices=state_indices_p).transpose( 0, 1) @@ -216,6 +219,12 @@ def forward( "b l (h p) -> b l h p", h=self.tp_nheads) + initial_states = None + if mamba_metadata.use_initial_states: + initial_states = torch.where( + has_initial_states[:, None, None, None], + ssm_states[state_indices_p], 0) + y, current_ssm_states = mamba_chunk_scan_combined( x_p, dt_p, @@ -226,7 +235,9 @@ def forward( D=self.D, z=z_p, dt_bias=self.dt_bias, - initial_states=None, + initial_states=initial_states, + chunk_indices=mamba_metadata.chunk_indices, + chunk_offsets=mamba_metadata.chunk_offsets, dt_softplus=self.delta_softplus, cu_seqlens=cu_seqlens, seq_idx=seq_idx, diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py b/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py index 58615ab9238..23b55d8811d 100644 --- a/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py +++ b/tensorrt_llm/_torch/modules/mamba/ssd_chunk_scan.py @@ -314,11 +314,12 @@ def _chunk_scan_fwd_kernel( # get the cs at the offset boundary # - c_off == 0 is a passthrough + # - We need dA_cs at the boundary, defined by c_off - no need + # to increase pointer by pid_m (it is a constant offset, + # i.e. the same for all blocks) dA_cs_m_boundary = tl.load( - dA_cumsum_ptr + - (pid_m * BLOCK_SIZE_M + c_off - 1) * stride_dA_cs_csize, - mask=(((pid_m * BLOCK_SIZE_M + c_off - 1) > -1) - and ((pid_m * BLOCK_SIZE_M + c_off) < chunk_size)), + dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize, + mask=(((c_off - 1) > -1) and (c_off < chunk_size)), other=0.0).to(tl.float32) if HAS_SEQ_IDX: diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py index 0a6f18bb63b..8edbe902bd8 100644 --- a/tensorrt_llm/_torch/modules/mamba/ssd_combined.py +++ b/tensorrt_llm/_torch/modules/mamba/ssd_combined.py @@ -110,21 +110,24 @@ def _mamba_chunk_scan_combined_fwd( # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries # (middle term of factorization of off-diag blocks; A terms) # - for handling chunked prefill, this requires i) initial_states - # ii) seq_idx and iii) is_cont_batched to be all specified. + # ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified. # - When a new seq_idx is detected, we will stop passing the prev_state # and switch accordingly to the init_state corresponding to the new seq_idx. + # - We will also make sure that the dA_cumsum is taken only from the start of the + # sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries) # - this will ensure that states will be updated with the rightmost flushed seq_idx # of the previous chunk. This implies that the first chunk of states is either 0 # or equal to init_states of the first example. states, final_states = _state_passing_fwd( rearrange(states, "... p n -> ... (p n)"), - dA_cumsum[:, :, :, -1], + dA_cumsum, initial_states=(rearrange(initial_states, "... p n -> ... 
(p n)") if initial_states is not None else None), seq_idx=seq_idx, chunk_size=chunk_size, out_dtype=mamba_ssm_cache_dtype or C.dtype, - is_cont_batched=cu_seqlens is not None) + is_cont_batched=cu_seqlens is not None, + chunk_offsets=chunk_offsets) states, final_states = [ rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states] diff --git a/tensorrt_llm/_torch/modules/mamba/ssd_state_passing.py b/tensorrt_llm/_torch/modules/mamba/ssd_state_passing.py index e1c4b61eaf8..f751d4cd5f5 100644 --- a/tensorrt_llm/_torch/modules/mamba/ssd_state_passing.py +++ b/tensorrt_llm/_torch/modules/mamba/ssd_state_passing.py @@ -41,6 +41,8 @@ def _state_passing_fwd_kernel( dA_cs_ptr, initstates_ptr, seq_idx_ptr, + chunk_offsets_ptr, + chunk_meta_num, # Matrix dimensions dim, nchunks, @@ -61,6 +63,7 @@ def _state_passing_fwd_kernel( stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, + stride_dA_cs_csize, stride_initstates_batch, stride_initstates_head, stride_initstates_dim, @@ -76,7 +79,8 @@ def _state_passing_fwd_kernel( pid_h = tl.program_id(axis=2) pid_m = tl.program_id(axis=0) states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head - dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + dA_cs_ptr += pid_b * stride_dA_cs_batch + pid_h * stride_dA_cs_head + ( + chunk_size - 1) * stride_dA_cs_csize out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head final_states_ptr += (pid_b * stride_final_states_batch + pid_h * stride_final_states_head) @@ -105,35 +109,63 @@ def _state_passing_fwd_kernel( other=0.0).to(tl.float32) tl.store(out_ptrs, states, mask=offs_m < dim) out_ptrs += stride_out_chunk - seq_idx = 0 + prev_seq_idx_chunk_end = 0 + logical_chunk_idx = 0 for c in range(nchunks): new_states = tl.load(states_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) dA_cs = tl.load(dA_cs_ptr).to(tl.float32) - scale = tl.exp(dA_cs) + scale_mask = True if HAS_SEQ_IDX: # - the seq to pass forward is the one that is flushed to the right # boundary. - # - that is given by seq_idx_new below. - seq_idx_new = tl.load(seq_idx_ptr + - (min((c + 1) * chunk_size, seqlen) - 1) * - stride_seq_idx_seqlen) + # - that is given by seq_idx_chunk_end below: the sequence index at the end of the chunk. + seq_idx_chunk_end = tl.load(seq_idx_ptr + (min( + (c + 1) * chunk_size, seqlen) - 1) * stride_seq_idx_seqlen) if HAS_INITSTATES: - if IS_CONT_BATCHED and seq_idx != seq_idx_new: + if IS_CONT_BATCHED and prev_seq_idx_chunk_end != seq_idx_chunk_end: # this means in the current chunk the rightmost flushed seq # has changed. 
# - so we do not propagate the state from previous chunk # - but rather we load that sequence's init state - initstates_ptrs = initstates_ptr + seq_idx_new * stride_initstates_batch + initstates_ptrs = initstates_ptr + seq_idx_chunk_end * stride_initstates_batch # - update state with seq_idx_new's init state states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + # - we need to consider the cumsum only of the last sequence in the chunk + # - find its starting position (given by c_off of the logical chunk index) + # - and subtract the cumsum just before that position from the total cumsum + # - first, update the logical chunk index (add the number of sequences in the current physical chunk): + # sequence index at the start of the current chunk + seq_idx_chunk_start = tl.load(seq_idx_ptr + + min(c * chunk_size, seqlen) * + stride_seq_idx_seqlen) + logical_chunk_idx += (seq_idx_chunk_end - + seq_idx_chunk_start) + # - load the chunk offset: + c_off = tl.load(chunk_offsets_ptr + logical_chunk_idx, + mask=logical_chunk_idx < chunk_meta_num, + other=0) + # - if offset is 0, then the sequence starts at the beginning of the chunk, and we don't need to subtract anything + if c_off > 0: + # - dA_cs_ptr currently points to the cumsum at the end of the chunk - subtract the chunk size and add the offset + dA_cs_boundary = tl.load( + dA_cs_ptr - (chunk_size - 1) * stride_dA_cs_csize + + (c_off - 1) * stride_dA_cs_csize, + mask=(c_off - 1) > -1 and c_off < chunk_size, + other=0.0) + dA_cs -= dA_cs_boundary + + # - increment logical chunk index for every physical chunk + logical_chunk_idx += 1 else: - scale = tl.where(seq_idx_new == seq_idx, scale, 0.0) + scale_mask = seq_idx_chunk_end == prev_seq_idx_chunk_end + prev_seq_idx_chunk_end = seq_idx_chunk_end - seq_idx = seq_idx_new + scale = tl.where(scale_mask, tl.exp(dA_cs), 0.0) states = scale * states + new_states if c < nchunks - 1: tl.store(out_ptrs, states, mask=offs_m < dim) @@ -146,28 +178,36 @@ def _state_passing_fwd_kernel( def _state_passing_fwd( states, - dA_chunk_cumsum, + dA_cumsum, initial_states=None, seq_idx=None, chunk_size=None, out_dtype=None, is_cont_batched=False, + chunk_offsets=None, ): batch, nchunks, nheads, dim = states.shape - assert dA_chunk_cumsum.shape == (batch, nheads, nchunks) + if chunk_size is None: + chunk_size = dA_cumsum.shape[-1] + else: + assert chunk_size == dA_cumsum.shape[-1] + assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size) if initial_states is not None: if is_cont_batched: # - if cu_seqlens is provided, then the initial states # are used for continuous batching. In which case we # require seq_idx to be provided - assert seq_idx is not None, "" + assert seq_idx is not None, "seq_idx must be provided for continuous batching" + # - we also need chunk_offsets to be provided, to account + # for computation of dA_cumsum from the start of the + # sequence + assert chunk_offsets is not None, "chunk_offsets must be provided for continuous batching" else: # - this is the regular batching case, where initial # states are used are for each example of the batch. 
assert initial_states.shape == (batch, nheads, dim) if seq_idx is not None: - assert chunk_size is not None seqlen = seq_idx.shape[-1] assert seq_idx.shape == (batch, seqlen) out_dtype = states.dtype if out_dtype is None else out_dtype @@ -183,13 +223,15 @@ def _state_passing_fwd( states, out, final_states, - dA_chunk_cumsum, + dA_cumsum, initial_states, seq_idx, + chunk_offsets, + len(chunk_offsets) if chunk_offsets is not None else 0, dim, nchunks, seqlen if seq_idx is not None else 0, - chunk_size if seq_idx is not None else 0, + chunk_size, states.stride(0), states.stride(1), states.stride(2), @@ -201,9 +243,10 @@ def _state_passing_fwd( final_states.stride(0), final_states.stride(1), final_states.stride(2), - dA_chunk_cumsum.stride(0), - dA_chunk_cumsum.stride(2), - dA_chunk_cumsum.stride(1), + dA_cumsum.stride(0), + dA_cumsum.stride(2), + dA_cumsum.stride(1), + dA_cumsum.stride(3), *(( initial_states.stride(0), initial_states.stride(1), diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index df674a94968..0007b99ebd2 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -1,122 +1,309 @@ -from typing import Any, Callable, Dict, Optional, Tuple +import bisect +import contextlib +import weakref +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple import torch -from ..attention_backend.interface import AttentionMetadata +from ..expert_statistic import ExpertStatistic from ..modules.multi_stream_utils import with_multi_stream -from ..speculative.interface import SpecMetadata from ..utils import make_weak_ref, piecewise_cuda_graph +from .resource_manager import ResourceManager, ResourceManagerType +from .scheduler import ScheduledRequests +if TYPE_CHECKING: + from .model_engine import PyTorchModelEngine -class DecodingCUDAGraphRunner: +# A large prime number used for dummy request IDs to avoid collisions +CUDA_GRAPH_DUMMY_REQUEST_ID = (1 << 64) - 1 - def __init__( - self, - batch_size: int, - device: str, - attn_metadata: AttentionMetadata, - spec_metadata: Optional[SpecMetadata] = None, - use_mrope: bool = False, - max_beam_width: int = 1, - ) -> None: - """ - Stores a CUDA graph and its associated input buffers. - Each CUDA graph runner is associated with an AttentionMetadata object - if flashinfer is being used. Make sure to call attn_metadata.prepare() - before run()! +class CUDAGraphRunner: + """ + Manages the lifecycle and execution of CUDA graphs for the model engine. + + This unified class handles high-level orchestration (padding, eligibility) + and low-level execution (capturing, resource management, replaying) for + multiple graphs, keyed by (batch size, draft_len). 
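+
+    A rough sketch of the intended call sequence (the real call sites are in
+    PyTorchModelEngine.forward):
+
+        runner = CUDAGraphRunner(engine)
+        with runner.pad_batch(scheduled_requests, resource_manager) as batch:
+            use_graph, attn_md, spec_md = runner.maybe_get_cuda_graph(batch)
+            if use_graph:
+                bs = len(batch.generation_requests)
+                if runner.needs_capture(bs):
+                    runner.capture(bs, forward_fn, inputs)
+                outputs = runner.replay(bs, inputs)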
+ """ + WARMUP_STEPS = 2 + + def __init__(self, engine: "PyTorchModelEngine"): + self.engine_ref = weakref.ref(engine) + + # High-level configuration + config = engine.pytorch_backend_config + self.enabled = config.use_cuda_graph + self.padding_enabled = config.cuda_graph_padding_enabled + self.supported_batch_sizes = engine._cuda_graph_batch_sizes + self.max_supported_batch_size = engine._max_cuda_graph_batch_size + self.max_beam_width = engine.max_beam_width + self.spec_config = engine.spec_config + + self.graphs: Dict[Tuple[int, int], torch.cuda.CUDAGraph] = {} + self.static_inputs: Dict[Tuple[int, int], Dict[str, torch.Tensor]] = {} + self.graph_outputs: Dict[Tuple[int, int], + Callable[[], Optional[torch.Tensor]]] = {} + self.graph_metadata: Dict[Tuple[int, int], Dict[str, Any]] = {} + self.memory_pool = engine._cuda_graph_mem_pool + self.padding_dummy_request: Optional["Request"] = None + + @property + def enable_spec_decode(self): + return self._get_engine().is_spec_decode + + @property + def draft_len(self): + return self.spec_config.max_draft_len if self.enable_spec_decode else 0 + + @property + def spec_metadata(self): + return self._get_engine().spec_metadata + + @property + def draft_tokens_cuda(self): + return self._get_engine().draft_tokens_cuda + + @property + def attn_metadata(self): + return self._get_engine().attn_metadata + + def __del__(self): + self.clear() + + def _get_engine(self) -> "PyTorchModelEngine": + """Safely dereferences the weak reference to the engine.""" + engine = self.engine_ref() + if engine is None: + raise RuntimeError( + "The parent PyTorchModelEngine has been garbage collected.") + return engine + + def maybe_get_cuda_graph(self, batch: ScheduledRequests): + """ + Determines if the current batch can be run with a CUDA graph. - Note that torch.compile w/ mode reduce-overhead supports CUDA graphs - with memory pool sharing. However, we have our own manager here because, - at the time of writing this, torch.compile takes way too long to warmup - graphs compared to doing it manually (not to mention, custom ops from - e.g. FlashInfer cause graph breaks). + Returns a tuple containing: + - A boolean indicating if a graph can be used. + - The attn_metadata for the graph, if applicable. + - The spec_metadata for the graph, if applicable. 
""" - self.batch_size = batch_size - self.max_beam_width = max_beam_width + engine = self._get_engine() + + # disable when doing statistic + if hasattr(engine, 'iter_counter') and ExpertStatistic.set_iter( + engine.iter_counter): + return False, None, None + + can_run_cuda_graph = batch.can_run_cuda_graph + batch_size = batch.batch_size + if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1: + all_can_graph_batch = engine.dist.tp_allgather( + [can_run_cuda_graph, batch_size]) + is_all_gen_only = all(all_can_graph[0] + for all_can_graph in all_can_graph_batch) + all_batch_size_equal = all( + all_gen_only[1] == all_can_graph_batch[0][1] + for all_gen_only in all_can_graph_batch) + + if not is_all_gen_only or not all_batch_size_equal: + return False, None, None + + if not self.enabled or not can_run_cuda_graph: + return False, None, None + + key = (batch_size, self.draft_len) + if key in self.graphs: + return True, self.graph_metadata[key][ + "attn_metadata"], self.graph_metadata[key]["spec_metadata"] + + if batch_size not in self.supported_batch_sizes: + return False, None, None + + num_sequences_in_batch = batch_size * self.max_beam_width + attn_metadata = self.attn_metadata.create_cuda_graph_metadata( + num_sequences_in_batch, False, self.draft_len) + assert attn_metadata.is_cuda_graph + + if self.enable_spec_decode: + spec_metadata = self.spec_metadata.create_cuda_graph_metadata( + num_sequences_in_batch) + spec_metadata.draft_tokens = self.draft_tokens_cuda + else: + spec_metadata = None + return True, attn_metadata, spec_metadata + + def needs_capture(self, batch_size: int): + return (batch_size, self.draft_len) not in self.graph_outputs + + def capture(self, batch_size: int, forward_fn: Callable, + initial_inputs: Dict[str, Any]): + """Captures the forward pass for a given batch size.""" + engine = self._get_engine() + key = (batch_size, self.draft_len) + spec_metadata = initial_inputs.get("spec_metadata", None) # [CUDA graph spec decode padding] # We pad input IDs/position IDs to the maximum draft length (token per request). # We're forced to do this because we cannot reallocate inputs over many graph runs. token_per_request = spec_metadata.max_draft_len + 1 if spec_metadata is not None else 1 - # Using ones instead of zeros prevents NaNs in e.g. 
Deepseek - self.input_ids = torch.ones( - (batch_size * max_beam_width * token_per_request, ), - device=device, - dtype=torch.int32) - self.position_ids = torch.zeros( - (1, batch_size * max_beam_width * token_per_request), - device=device, - dtype=torch.int32) - self.mrope_position_deltas = torch.zeros( - (batch_size, - 1), device=device, dtype=torch.int32) if use_mrope else None - - self.attn_metadata = attn_metadata - self.spec_metadata = spec_metadata - self._output = None - self._graph = None - self.optional_extra_model_inputs = ["mrope_position_deltas"] + static_tensors = { + "input_ids": + torch.ones((batch_size * self.max_beam_width * token_per_request, ), + device="cuda", + dtype=torch.int32), + "position_ids": + torch.zeros(( + 1, + batch_size * self.max_beam_width * token_per_request, + ), + device="cuda", + dtype=torch.int32), + } + if engine.use_mrope: + static_tensors["mrope_position_deltas"] = torch.zeros( + (batch_size, 1), device="cuda", dtype=torch.int32) + self.static_inputs[key] = static_tensors - def __del__(self): - self._graph.reset() - - def capture( - self, - forward_fn: Callable[[Dict[str, Any]], torch.Tensor], - pool: Optional[Tuple[int, int]] = None, - ) -> Tuple[int, int]: - self._graph = torch.cuda.CUDAGraph() - inputs = { - "attn_metadata": self.attn_metadata, - "input_ids": self.input_ids, - "position_ids": self.position_ids, - "inputs_embeds": None, - "spec_metadata": self.spec_metadata, - "mrope_position_deltas": self.mrope_position_deltas, + capture_inputs = initial_inputs.copy() + capture_inputs.update(static_tensors) + + self.graph_metadata[key] = { + "attn_metadata": initial_inputs["attn_metadata"], + "spec_metadata": spec_metadata, } # We have to do warm up runs to initialize PyTorch's # internal states according to the docs: # https://pytorch.org/docs/stable/notes/cuda.html#cuda-graph-semantics # This also lets us initialize states in the attn_metadata. + graph = torch.cuda.CUDAGraph() with with_multi_stream(True), piecewise_cuda_graph(False): - for _ in range(2): - forward_fn(inputs) - with torch.cuda.graph(self._graph, pool=pool): - output = forward_fn(inputs) - # Mark weak ref here. The output tensor should be freed properly. 
- self._output = make_weak_ref(output) - return self._graph.pool() - - def needs_capture(self) -> bool: - return self._output is None - - def run(self, inputs: Dict[str, Any]) -> torch.Tensor: - assert "input_ids" in inputs - assert "position_ids" in inputs - assert "attn_metadata" in inputs - - attn_metadata = inputs["attn_metadata"] - assert attn_metadata is self.attn_metadata, ( - "attn_metadata does not match the attn_metadata instance that was used to " - "capture this graph.") - - if "spec_metadata" in inputs: - spec_metadata = inputs["spec_metadata"] - assert spec_metadata is self.spec_metadata, ( - "spec_metadata does not match the spec_metadata instance that was used to " - "capture this graph.") - - input_ids = inputs["input_ids"] - position_ids = inputs["position_ids"] + for _ in range(self.WARMUP_STEPS): + forward_fn(capture_inputs) + with torch.cuda.graph(graph, pool=self.memory_pool): + output = forward_fn(capture_inputs) + + self.graphs[key] = graph + self.graph_outputs[key] = make_weak_ref(output) + self.memory_pool = graph.pool() + + def replay(self, batch_size: int, + current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]: + """Replays a previously captured graph.""" + key = (batch_size, self.draft_len) + stored_meta = self.graph_metadata[key] + assert current_inputs["attn_metadata"] is stored_meta["attn_metadata"] + if stored_meta["spec_metadata"] is not None: + assert current_inputs.get( + "spec_metadata") is stored_meta["spec_metadata"] + + static_tensors = self.static_inputs[key] + + input_ids = current_inputs["input_ids"] seqlen = input_ids.shape[0] - self.input_ids[:seqlen].copy_(input_ids) - self.position_ids[:, :seqlen].copy_(position_ids) - if "mrope_position_deltas" in inputs: - self.mrope_position_deltas[:self.batch_size].copy_( - inputs["mrope_position_deltas"]) - - assert self._output is not None and self._graph is not None - self._graph.replay() - return self._output + static_tensors["input_ids"][:seqlen].copy_(input_ids) + + position_ids = current_inputs["position_ids"] + static_tensors["position_ids"][:, :seqlen].copy_(position_ids) + + if "mrope_position_deltas" in current_inputs: + assert "mrope_position_deltas" in static_tensors + static_tensors["mrope_position_deltas"][:batch_size].copy_( + current_inputs["mrope_position_deltas"]) + + self.graphs[key].replay() + output_ref = self.graph_outputs[key] + + return output_ref + + def _get_padded_batch(self, batch: ScheduledRequests, + resource_manager: ResourceManager) -> int: + engine = self._get_engine() + kv_cache_manager = resource_manager.get_resource_manager( + engine.kv_cache_manager_key) + can_run_cuda_graph = batch.can_run_cuda_graph + batch_size = batch.batch_size + new_batch_size = batch_size + + if self.enabled and engine.enable_attention_dp and engine.mapping.tp_size > 1: + graph_batch_size = engine.dist.tp_allgather( + [can_run_cuda_graph, batch_size]) + all_can_graph = all(graph_batch[0] + for graph_batch in graph_batch_size) + if all_can_graph: + new_batch_size = max(gen_only_batch[1] + for gen_only_batch in graph_batch_size) + + if (not self.enabled or not self.padding_enabled + or not can_run_cuda_graph + or new_batch_size > self.max_supported_batch_size): + return 0 + + padded_batch_size = self._round_up_batch_size(new_batch_size) + if batch_size == padded_batch_size: + return 0 + + padding_size = padded_batch_size - batch_size + if padding_size + batch.batch_size > engine.batch_size: + return 0 + + # No padding if it would create too many concurrent requests. 
+ # This is not strictly required, but we should probably + # respect the requirement just in case that changes in the future. + if self.padding_dummy_request is None: + available_blocks = kv_cache_manager.get_num_free_blocks() + # No padding if not enough KV cache space + if available_blocks < 1: + return 0 + + self.padding_dummy_request = kv_cache_manager.add_dummy_requests( + [CUDA_GRAPH_DUMMY_REQUEST_ID], + is_gen=True, + max_num_draft_tokens=engine.max_draft_len, + use_mrope=engine.use_mrope, + max_beam_width=engine.max_beam_width)[0] + self.padding_dummy_request.is_cuda_graph_dummy = True + spec_res_mgr = resource_manager.get_resource_manager( + ResourceManagerType.SPEC_RESOURCE_MANAGER) + if spec_res_mgr: + spec_res_mgr.add_dummy_requests([CUDA_GRAPH_DUMMY_REQUEST_ID]) + + batch.generation_requests.extend([self.padding_dummy_request] * + padding_size) + return padding_size + + def _round_up_batch_size(self, batch_size: int) -> int: + """Finds the smallest supported graph batch size >= the given size.""" + if not self.supported_batch_sizes: + return 0 + idx = bisect.bisect_left(self.supported_batch_sizes, batch_size) + if idx == len(self.supported_batch_sizes): + return 0 + return self.supported_batch_sizes[idx] + + @contextlib.contextmanager + def pad_batch(self, scheduled_requests: ScheduledRequests, + resource_manager: ResourceManager): + """Context manager to pad a batch to a graph-compatible size.""" + + padding_size = self._get_padded_batch(scheduled_requests, + resource_manager) + try: + yield scheduled_requests + finally: + if padding_size > 0: + scheduled_requests.generation_requests = scheduled_requests.generation_requests[: + -padding_size] + + def clear(self): + """Releases all captured graphs and the associated memory pool.""" + for graph in self.graphs.values(): + graph.reset() + self.graphs.clear() + self.static_inputs.clear() + self.graph_outputs.clear() + self.graph_metadata.clear() + del self.memory_pool + self.memory_pool = None + torch.cuda.empty_cache() diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 1b3fbfbfc4e..8f83083e167 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1,4 +1,3 @@ -import bisect import contextlib import copy import functools @@ -57,7 +56,7 @@ with_model_extra_attrs) from .config import LoadFormat, PyTorchConfig from .config_utils import is_mla -from .cuda_graph_runner import DecodingCUDAGraphRunner +from .cuda_graph_runner import CUDAGraphRunner from .layerwise_nvtx_marker import LayerwiseNvtxMarker from .llm_request import get_draft_token_length from .resource_manager import (BaseResourceManager, KVCacheManager, @@ -422,7 +421,6 @@ def __init__( self.iter_states = {} self._cuda_graphs = {} self._cuda_graph_mem_pool = self._torch_compile_backend._graph_pool_handle if self._torch_compile_enabled else None - self._run_cuda_graphs = pytorch_backend_config.use_cuda_graph self._cuda_graph_padding_enabled = pytorch_backend_config.cuda_graph_padding_enabled @@ -451,7 +449,7 @@ def __init__( # with different KV cache managers. self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER self.lora_model_config: Optional[LoraModelConfig] = None - self.cuda_graph_dummy_request = None + self.cuda_graph_runner = CUDAGraphRunner(self) # Setup the local cache indirection buffer only once and reuse it. # This way it can also be used for CUDA graphs. 
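+        # CUDA graphs replay with fixed device addresses, so every tensor a
+        # captured graph reads or writes must be allocated once and then
+        # reused; the cache indirection buffer relies on this, just like the
+        # static input buffers owned by cuda_graph_runner.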
@@ -541,12 +539,12 @@ def wrapper(self, *args, **kwargs): @contextlib.contextmanager def no_cuda_graph(self): - _run_cuda_graphs = self._run_cuda_graphs - self._run_cuda_graphs = False + _run_cuda_graphs = self.cuda_graph_runner.enabled + self.cuda_graph_runner.enabled = False try: yield finally: - self._run_cuda_graphs = _run_cuda_graphs + self.cuda_graph_runner.enabled = _run_cuda_graphs @with_warmup_flag def warmup(self, resource_manager: ResourceManager) -> None: @@ -561,7 +559,7 @@ def warmup(self, resource_manager: ResourceManager) -> None: # The lifetime of model engine and kv cache manager can be different. # Reset the global cuda graph dummy request to None in warmup. - self.cuda_graph_dummy_request = None + self.cuda_graph_runner.padding_dummy_request = None def get_cuda_graph_warmup_request(batch_size, draft_len): # Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel. @@ -756,7 +754,7 @@ def release_batch(result: ScheduledRequests | None): AutoTuner.get().print_profiling_cache() - if not (self._run_cuda_graphs + if not (self.cuda_graph_runner.enabled or self._torch_compile_piecewise_cuda_graph): return @@ -889,152 +887,6 @@ def _set_up_spec_metadata( is_draft_model=self.is_draft_model) return self.spec_metadata - def _get_padded_batch( - self, - scheduled_requests: ScheduledRequests, - kv_cache_manager, - spec_resource_manager: Optional[BaseResourceManager] = None) -> int: - can_run_cuda_graph = scheduled_requests.can_run_cuda_graph - batch_size = scheduled_requests.batch_size - new_batch_size = batch_size - - if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1: - graph_batch_size = self.dist.tp_allgather( - [can_run_cuda_graph, batch_size]) - all_can_graph = all(graph_batch[0] - for graph_batch in graph_batch_size) - if all_can_graph: - new_batch_size = max(gen_only_batch[1] - for gen_only_batch in graph_batch_size) - - if (not self._run_cuda_graphs or not self._cuda_graph_padding_enabled - or not can_run_cuda_graph - or new_batch_size > self._max_cuda_graph_batch_size): - return 0 - - padded_batch_size = self._round_up_batch_size(new_batch_size) - if batch_size == padded_batch_size: - return 0 - - padding_size = padded_batch_size - batch_size - if padding_size + scheduled_requests.batch_size > self.batch_size: - return 0 - - # No padding if it would create too many concurrent requests. - # This is not strictly required, but we should probably - # respect the requirement just in case that changes in the future. 
- if self.cuda_graph_dummy_request is None: - available_blocks = kv_cache_manager.get_num_free_blocks() - # No padding if not enough KV cache space - if available_blocks < 1: - return 0 - - cuda_graph_dummy_request_ids = [MAX_UINT64 - 1] - self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests( - cuda_graph_dummy_request_ids, - is_gen=True, - max_num_draft_tokens=self.runtime_draft_len, - use_mrope=self.use_mrope, - max_beam_width=self.max_beam_width)[0] - self.cuda_graph_dummy_request.is_cuda_graph_dummy = True - if spec_resource_manager is not None: - spec_resource_manager.add_dummy_requests( - request_ids=cuda_graph_dummy_request_ids) - - scheduled_requests.generation_requests.extend( - [self.cuda_graph_dummy_request] * padding_size) - - return padding_size - - @contextlib.contextmanager - def _maybe_pad_batch( - self, - scheduled_requests: ScheduledRequests, - kv_cache_manager, - spec_resource_manager: Optional[BaseResourceManager] = None): - """ - CUDA graphs can only be used for specific batch sizes. - - If using CUDA graphs, this method will add dummy requests to the given - batch so we can always use a CUDA graph. It is a context manager - because the padded requests will be removed from scheduled requests. - """ - padding_size = self._get_padded_batch(scheduled_requests, - kv_cache_manager, - spec_resource_manager) - try: - yield scheduled_requests - finally: - if padding_size > 0: - scheduled_requests.generation_requests = scheduled_requests.generation_requests[: - -padding_size] - - def _round_up_batch_size(self, batch_size: int) -> int: - """ - Round up the given batch size to the nearest batch size that is - associated with a CUDA graph. - """ - idx = bisect.bisect_left(self._cuda_graph_batch_sizes, batch_size) - return self._cuda_graph_batch_sizes[idx] - - def _maybe_get_cuda_graph( - self, - batch: ScheduledRequests, - ) -> Optional[DecodingCUDAGraphRunner]: - """ - Get a CUDA graph runner or return None (e.g. if CUDA graphs are disabled - or if the batch size is too big). 
- """ - # disable when doing statistic - if ExpertStatistic.set_iter(self.iter_counter): - return None - - draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0 - can_run_cuda_graph = batch.can_run_cuda_graph - batch_size = len(batch.generation_requests) - if self._run_cuda_graphs and self.enable_attention_dp and self.mapping.tp_size > 1: - all_can_graph_batch = self.dist.tp_allgather( - [can_run_cuda_graph, batch_size]) - is_all_gen_only = all(all_can_graph[0] - for all_can_graph in all_can_graph_batch) - all_batch_size_equal = all( - all_gen_only[1] == all_can_graph_batch[0][1] - for all_gen_only in all_can_graph_batch) - - if not is_all_gen_only or not all_batch_size_equal: - return None - - if not self._run_cuda_graphs or not can_run_cuda_graph: - return None - - if batch_size in self._cuda_graphs and draft_len in self._cuda_graphs[ - batch_size]: - return self._cuda_graphs[batch_size][draft_len] - - if batch_size not in self._cuda_graph_batch_sizes: - return None - - num_sequences_in_batch = batch_size * self.max_beam_width - attn_metadata = self.attn_metadata.create_cuda_graph_metadata( - num_sequences_in_batch, False, draft_len) - assert attn_metadata.is_cuda_graph - - if self.enable_spec_decode: - spec_metadata = self.spec_metadata.create_cuda_graph_metadata( - num_sequences_in_batch) - spec_metadata.draft_tokens = self.draft_tokens_cuda - else: - spec_metadata = None - - # Initialize nested dictionary if needed - if batch_size not in self._cuda_graphs: - self._cuda_graphs[batch_size] = {} - - self._cuda_graphs[batch_size][draft_len] = DecodingCUDAGraphRunner( - batch_size, "cuda", attn_metadata, spec_metadata, self.use_mrope, - self.max_beam_width) - return self._cuda_graphs[batch_size][draft_len] - def __del__(self) -> None: if getattr(self, 'ub_buffers', None): for u in self.ub_buffers: @@ -1244,13 +1096,7 @@ def _init_model_capacity(self): self._init_max_num_tokens() def _release_cuda_graphs(self): - for batch_size, draft_graphs in self._cuda_graphs.items(): - for draft_len, graph in draft_graphs.items(): - del graph - self._cuda_graphs.clear() - torch.cuda.empty_cache() - del self._cuda_graph_mem_pool - self._cuda_graph_mem_pool = None + self.cuda_graph_runner.clear() def get_max_num_sequences(self) -> int: """ @@ -2263,12 +2109,14 @@ def forward( else: return self._forward_step(inputs, gather_ids, gather_context_logits) - with self._maybe_pad_batch(scheduled_requests, kv_cache_manager, - spec_resource_manager) as scheduled_requests: - maybe_graph = self._maybe_get_cuda_graph(scheduled_requests) - if maybe_graph is not None: - attn_metadata = maybe_graph.attn_metadata - spec_metadata = maybe_graph.spec_metadata + with self.cuda_graph_runner.pad_batch( + scheduled_requests, resource_manager) as padded_requests: + + maybe_graph, maybe_attn_metadata, maybe_spec_metadata = self.cuda_graph_runner.maybe_get_cuda_graph( + padded_requests) + if maybe_graph: + attn_metadata = maybe_attn_metadata + spec_metadata = maybe_spec_metadata else: attn_metadata = self.attn_metadata if self.enable_spec_decode: @@ -2277,17 +2125,19 @@ def forward( spec_metadata = None inputs, gather_ids = self._prepare_inputs( - scheduled_requests, kv_cache_manager, attn_metadata, - spec_metadata, new_tensors_device, cache_indirection_buffer) + padded_requests, kv_cache_manager, attn_metadata, spec_metadata, + new_tensors_device, cache_indirection_buffer) self.iter_counter += 1 - if maybe_graph is None: + if not maybe_graph: + # Fallback to eager execution if graph was not used with 
MoeLoadBalancerIterContext(moe_load_balancer): outputs = self._forward_step(inputs, gather_ids, gather_context_logits) else: - if maybe_graph.needs_capture(): + batch_size = len(padded_requests.generation_requests) + if self.cuda_graph_runner.needs_capture(batch_size): def capture_forward_fn(inputs: Dict[str, Any]): with MoeLoadBalancerIterContext(moe_load_balancer): @@ -2296,18 +2146,16 @@ def capture_forward_fn(inputs: Dict[str, Any]): gather_ids=gather_ids, gather_context_logits=gather_context_logits) - pool = maybe_graph.capture( - capture_forward_fn, - self._cuda_graph_mem_pool, - ) - self._cuda_graph_mem_pool = pool + self.cuda_graph_runner.capture(batch_size, + capture_forward_fn, inputs) # here we don't need to use context since cuda graph capture didn't run kernel. # maybe we need a cleaner way to do this. - outputs = maybe_graph.run(inputs) + outputs = self.cuda_graph_runner.replay(batch_size, inputs) else: with MoeLoadBalancerIterContext(moe_load_balancer): - outputs = maybe_graph.run(inputs) + outputs = self.cuda_graph_runner.replay( + batch_size, inputs) self._execute_logit_post_processors(scheduled_requests, outputs) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 12686728cda..ac3bb7a9f53 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -23,7 +23,7 @@ get_spec_resource_manager) from ._util import (KvCacheCreator, _adjust_torch_mem_fraction, create_py_executor_instance, instantiate_sampler, is_mla) -from .config import PyTorchConfig +from .config import LoadFormat, PyTorchConfig from .config_utils import is_mla from .guided_decoder import GuidedDecoder from .model_engine import PyTorchModelEngine @@ -252,13 +252,16 @@ def create_py_executor( with mem_monitor.observe_creation_stage( _ExecutorCreationStage.MODEL_ENGINE_DRAFT): draft_spec_config = copy.copy(spec_config) + draft_pytorch_backend_config = copy.copy(pytorch_backend_config) + if spec_config.load_format == "dummy": + draft_pytorch_backend_config.load_format = LoadFormat.DUMMY # The draft model won't have any draft tokens attached to # generation requests when we invoke it autoregressively draft_spec_config.max_draft_len = 0 draft_model_engine = PyTorchModelEngine( model_path=spec_config.speculative_model_dir, - pytorch_backend_config=pytorch_backend_config, + pytorch_backend_config=draft_pytorch_backend_config, batch_size=executor_config.max_batch_size, max_beam_width=executor_config.max_beam_width, max_num_tokens=executor_config.max_num_tokens, diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 417becf12f3..2d4225641b5 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Optional, Tuple +from typing import List, Optional, Set import torch from torch import nn @@ -35,9 +35,10 @@ def __init__(self, config: "EagleDecodingConfig", dtype: torch.dtype, # empty hidden states tensor max_num_tokens = min(max_num_tokens, max_num_requests * self.max_seq_len) - self.hidden_states = torch.empty((max_num_tokens, self.hidden_size * 3), - dtype=self.dtype, - device='cuda') + self.hidden_states = torch.empty( + (max_num_tokens, self.hidden_size * config.num_capture_layers), + dtype=self.dtype, + device='cuda') # sequence length, only used for metadata preparation self.seq_lens = {i: 0 for i 
in range(max_num_requests)} # start indices of each slot @@ -79,8 +80,7 @@ def get_needed_resource_to_completion(self, request: LlmRequest): @dataclass class Eagle3SpecMetadata(SpecMetadata): hidden_states: List[torch.Tensor] = field(default_factory=list) - num_capture_layers: int = 3 - layers_to_capture: Tuple[int, ...] = field(init=False) + layers_to_capture: Optional[Set[int]] = None target_model_embed_tokens: Optional[torch.nn.Module] = None hidden_size: int = 0 max_num_tokens: int = 0 @@ -90,14 +90,19 @@ class Eagle3SpecMetadata(SpecMetadata): eagle3_resource_manager: Optional[Eagle3ResourceManager] = None def __post_init__(self): - if self.num_layers == 1: - self.layers_to_capture = (0, ) - else: - if self.num_layers <= 5: - raise ValueError("Not enough hidden layers for EAGLE") + if self.layers_to_capture is None: + if self.num_layers == 1: + self.layers_to_capture = (self.num_layers - 1, ) + else: + if self.num_layers <= 5: + raise ValueError( + "Not enough hidden layers for default EAGLE3 capture") - self.layers_to_capture = (1, self.num_layers // 2 - 1, - self.num_layers - 4) + self.layers_to_capture = (1, self.num_layers // 2 - 1, + self.num_layers - 4) + else: + self.layers_to_capture = sorted(list(self.layers_to_capture)) + self.num_capture_layers = len(self.layers_to_capture) # Initialize to 0 to avoid reading uninitialized memory during warmup self.hidden_states_read_indices = torch.zeros([self.max_num_tokens], @@ -186,7 +191,7 @@ class Eagle3OneModelSpecMetadata(SpecMetadata): # The hidden states hidden_states: Optional[torch.Tensor] = None # The layers to be captured - layers_to_capture: Tuple[int, ...] = field(init=False) + layers_to_capture: Optional[Set[int]] = None # The hidden size of the hidden states hidden_size: int = 0 # The max number of tokens @@ -197,14 +202,19 @@ class Eagle3OneModelSpecMetadata(SpecMetadata): batch_indices_cuda: Optional[torch.Tensor] = None def __post_init__(self): - if self.num_layers == 1: - self.layers_to_capture = (1, ) - else: - if self.num_layers <= 5: - raise ValueError("Not enough hidden layers for EAGLE") + if self.layers_to_capture is None: + if self.num_layers == 1: + self.layers_to_capture = (self.num_layers - 1, ) + else: + if self.num_layers <= 5: + raise ValueError( + "Not enough hidden layers for default EAGLE3 capture") - self.layers_to_capture = (1, self.num_layers // 2 - 1, - self.num_layers - 4) + self.layers_to_capture = (1, self.num_layers // 2 - 1, + self.num_layers - 4) + else: + self.layers_to_capture = sorted(list(self.layers_to_capture)) + self.num_capture_layers = len(self.layers_to_capture) self.hidden_states = torch.empty( (self.max_num_tokens, self.hidden_size * len(self.layers_to_capture)), diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index f7cdd92a561..1d306b90291 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -185,6 +185,13 @@ def create_cuda_graph_metadata(self, max_batch_size: int): cuda_graph_metadata.__post_init__() return cuda_graph_metadata + def is_layer_capture(self, layer_id: int): + """ + Whether the layer should be captured (eg for Eagle3). + By default, does nothing. 
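+        Speculative modes that record intermediate hidden states (e.g. Eagle3,
+        via its layers_to_capture set) are expected to override this; the base
+        implementation simply returns False.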
+ """ + return False + def maybe_capture_hidden_states(self, layer_id: int, hidden_states: torch.Tensor, residual: torch.Tensor) -> None: diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index c4a4ccf7e3c..16fef4862b3 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -38,6 +38,7 @@ def get_spec_metadata(spec_config, dtype=model_config.torch_dtype, is_draft_model=is_draft_model, eagle3_resource_manager=spec_resource_manager, + layers_to_capture=spec_config.eagle3_layers_to_capture, ) if spec_config.spec_dec_mode.is_eagle3_one_model(): return Eagle3OneModelSpecMetadata( @@ -47,6 +48,7 @@ def get_spec_metadata(spec_config, num_layers=model_config.num_hidden_layers, hidden_size=model_config.hidden_size, max_num_tokens=max_num_tokens, + layers_to_capture=spec_config.eagle3_layers_to_capture, ) if spec_config.spec_dec_mode.is_draft_target() or \ spec_config.spec_dec_mode.is_ngram() or \ diff --git a/tensorrt_llm/bench/dataclasses/reporting.py b/tensorrt_llm/bench/dataclasses/reporting.py index acf7f60bcbb..fd76466cd5a 100755 --- a/tensorrt_llm/bench/dataclasses/reporting.py +++ b/tensorrt_llm/bench/dataclasses/reporting.py @@ -273,6 +273,22 @@ def get_statistics_dict(self) -> Dict[str, Any]: }, } + # Retrieve KV cache information. + kv_cache_config = self.kwargs.get("kv_cache_config", KvCacheConfig()) + if isinstance(kv_cache_config, KvCacheConfig): + kv_cache_dtype = kv_cache_config.dtype + kv_cache_mem_percent = kv_cache_config.free_gpu_memory_fraction + elif isinstance(kv_cache_config, dict): + kv_cache_dtype = kv_cache_config.get("dtype", "auto") + kv_cache_mem_percent = kv_cache_config.get( + "free_gpu_memory_fraction") + else: + raise ValueError( + f"Invalid kv_cache_config type: {type(kv_cache_config)}.") + + kv_cache_mem_percent = f"{kv_cache_mem_percent * 100.0:.2f}%" \ + if kv_cache_mem_percent is not None else "None" + # Engine/Backend details if self.rt_cfg.backend not in ('pytorch', '_autodeploy'): config_path = self.rt_cfg.engine_dir / "config.json" @@ -302,15 +318,6 @@ def get_statistics_dict(self) -> Dict[str, Any]: model = self.rt_cfg.model_path or self.rt_cfg.model model_config = ModelConfig.from_pretrained(model, trust_remote_code=True) - kv_cache_config = self.kwargs.get("kv_cache_config", - KvCacheConfig()) - if isinstance(kv_cache_config, KvCacheConfig): - kv_cache_dtype = kv_cache_config.dtype - elif isinstance(kv_cache_config, dict): - kv_cache_dtype = kv_cache_config.get("dtype", "auto") - else: - raise ValueError( - f"Invalid kv_cache_config type: {type(kv_cache_config)}.") validate_and_set_kv_cache_quant(model_config, kv_cache_dtype) @@ -336,8 +343,7 @@ def get_statistics_dict(self) -> Dict[str, Any]: "max_batch_size": self.rt_cfg.settings_config.max_batch_size, "max_num_tokens": self.rt_cfg.settings_config.max_num_tokens, "scheduling_policy": self.rt_cfg.settings_config.scheduler_policy, - "kv_cache_percentage": - self.rt_cfg.settings_config.kv_cache_percent * 100.0, + "kv_cache_percentage": kv_cache_mem_percent, "issue_rate": self.convert_rate_to_s(self.statistics.issue_rate_ns) } @@ -526,7 +532,7 @@ def report_statistics(self) -> None: f"Max Runtime Batch Size: {world_info['max_batch_size']}\n" f"Max Runtime Tokens: {world_info['max_num_tokens']}\n" f"Scheduling Policy: {world_info['scheduling_policy']}\n" - f"KV Memory Percentage: {world_info['kv_cache_percentage']:.2f}%\n" + f"KV Memory Percentage: {world_info['kv_cache_percentage']}\n" f"Issue Rate (req/sec): 
{world_info['issue_rate']:.4E}\n" f"\n") diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 07eb13d7968..c1013eb3c5c 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -14,6 +14,7 @@ from tensorrt_llm import LLM as PyTorchLLM from tensorrt_llm import MultimodalEncoder from tensorrt_llm._tensorrt_engine import LLM +from tensorrt_llm._torch.auto_deploy.llm import LLM as AutoDeployLLM from tensorrt_llm._utils import mpi_rank from tensorrt_llm.executor.utils import LlmLauncherEnvs from tensorrt_llm.llmapi import (BuildConfig, CapacitySchedulerPolicy, @@ -109,7 +110,7 @@ def get_llm_args(model: str, capacity_scheduler_policy=CapacitySchedulerPolicy.GUARANTEED_NO_EVICT, dynamic_batch_config=dynamic_batch_config, ) - + backend = backend if backend in ["pytorch", "_autodeploy"] else None llm_args = { "model": model, @@ -140,7 +141,7 @@ def get_llm_args(model: str, "kv_cache_config": kv_cache_config, "backend": - backend if backend == "pytorch" else None, + backend, "num_postprocess_workers": num_postprocess_workers, "postprocess_tokenizer_dir": @@ -162,9 +163,15 @@ def launch_server(host: str, backend = llm_args["backend"] model = llm_args["model"] - if backend == 'pytorch': llm = PyTorchLLM(**llm_args) + elif backend == '_autodeploy': + # AutoDeploy does not support build_config + llm_args.pop("build_config", None) + # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/7142): + # AutoDeploy does not support cache reuse yet. + llm_args["kv_cache_config"].enable_block_reuse = False + llm = AutoDeployLLM(**llm_args) else: llm = LLM(**llm_args) @@ -204,10 +211,13 @@ def launch_mm_encoder_server( default="localhost", help="Hostname of the server.") @click.option("--port", type=int, default=8000, help="Port of the server.") -@click.option("--backend", - type=click.Choice(["pytorch", "trt"]), - default="pytorch", - help="Set to 'pytorch' for pytorch path. Default is cpp path.") +@click.option( + "--backend", + type=click.Choice(["pytorch", "trt", "_autodeploy"]), + default="pytorch", + help= + "Set to 'pytorch' for pytorch path and '_autodeploy' for autodeploy path. Default is pytorch path." +) @click.option('--log_level', type=click.Choice(severity_map.keys()), default='info', diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index da5071e3b0a..6ed4dea76c7 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -9,7 +9,7 @@ from enum import Enum, EnumMeta from pathlib import Path from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional, - Type, TypeAlias, TypeVar, Union, get_args, get_origin) + Set, Type, TypeAlias, TypeVar, Union, get_args, get_origin) import torch import yaml @@ -352,6 +352,7 @@ class DecodingBaseConfig(StrictBaseModel): # When specified, speculation will be disabled at batch sizes above # this value. Otherwise, speculation will always be on. 
max_concurrency: Optional[int] = None + load_format: Optional[str] = None @classmethod def from_dict(cls, data: dict): @@ -424,6 +425,7 @@ class EagleDecodingConfig(DecodingBaseConfig): num_eagle_layers: Optional[int] = None max_non_leaves_per_layer: Optional[int] = None eagle3_one_model: Optional[bool] = True + eagle3_layers_to_capture: Optional[Set[int]] = None @classmethod def from_dict(cls, data: dict): @@ -443,6 +445,17 @@ def spec_dec_mode(self): return TorchSpeculativeDecodingMode.EAGLE3_ONE_MODEL return TorchSpeculativeDecodingMode.EAGLE3 + @functools.cached_property + def num_capture_layers(self): + """ + Returns the number of layers to capture of the target model. + If eagle3_layers_to_capture is not None, return the length of the set. + Otherwise, assume Eagle3 base set and return 3. + """ + if self.eagle3_layers_to_capture is not None: + return len(self.eagle3_layers_to_capture) + return 3 + class UserProvidedDecodingConfig(DecodingBaseConfig): # Cannot use real type annotations due to circular imports @@ -523,7 +536,9 @@ class MTPDecodingConfig(DecodingBaseConfig): @classmethod def from_dict(cls, data: dict): - return cls(**data) + out = cls(**data) + out.max_draft_len = out.num_nextn_predict_layers + return out decoding_type: ClassVar[str] = "MTP" diff --git a/tensorrt_llm/tools/profiler/nsys_profile_tools/README.md b/tensorrt_llm/tools/profiler/nsys_profile_tools/README.md new file mode 100644 index 00000000000..b7b9f084de3 --- /dev/null +++ b/tensorrt_llm/tools/profiler/nsys_profile_tools/README.md @@ -0,0 +1,174 @@ +# gputrc2graph.py + +This script processes NVIDIA Nsight Systems (`nsys`) GPU trace files +(`.nsys-rep`) with -t cuda tracing enabled, and generates kernel-level +summaries and visualizations of GPU and non-GPU time. It is useful for +profiling and analyzing nsys profile output. + +## Usage + +### Command-line Arguments + +- `--in_file` + **(required)** + List of input files and their metadata. Each entry should be in the format: + `,,,` + - `nsys-rep`: Path to the `.nsys-rep` file. + - `engine`: Engine name (e.g., `trtllm`). + - `model`: Model name (e.g., `llama`, `gpt-oss`, `ds`). + - `elapsed_nonprofiled_sec`: Wall-clock runtime (in seconds) without + profiling. Specify `0` to use the elapsed GPU time calculated from the nsys-rep file (this may inflate non-GPU time if actual runtime without profiling is less). Multiple entries can be provided, separated by spaces. + +- `--out_dir` + Output directory for the generated CSV and HTML files. + If not specified, results are saved in the current directory. + +- `--title` + Title for the HTML chart/visualization. + +- `--nsys_cmd` + Path to the `nsys` command. + Default: `nsys` (assumes it is in your PATH). + Use this if `nsys` is not in your system PATH. + +## Notes + +- Make sure you have pandas and plotly python packages installed. +- Make sure [nsys](https://developer.nvidia.com/nsight-systems/get-started) is +installed, and specify the path to the `nsys` command with `--nsys_cmd` if it + is not in your PATH. +- For more details on available engines and models, see the help string in + the script or run: + +```bash +python3 gputrc2graph.py --help +``` + +## Example 1: analyze a single profile + +To analyze the GPU cycles of for example, a llama-3.1-8B model with trtllm: + +1. Run the following command to collect nsys profile, for trtllm serve config. 
+
+   ```bash
+   nsys profile -t cuda -o nsys_res -f true --trace-fork-before-exec=true \
+     --cuda-graph-trace=node --delay <DELAY> --duration <DURATION> \
+     python3 -m trtllm-serve meta-llama/Llama-4-Scout-17B-16E-Instruct ...
+   ```
+
+   where:
+
+   - DELAY: how many seconds to delay nsys before it starts collecting the
+     profile, needed so that the capture does not begin until the trtllm
+     server has come up and load generation starts.
+   - DURATION: how many seconds the nsys profile runs before the report is
+     generated. This should be greater than the duration of the run.
+
+2. Run again, this time without collecting the profile, and record the total
+   run time in seconds. The script uses this value to calculate the
+   CPU (non-GPU) seconds for the analysis.
+
+3. Say the elapsed time from step 2 is .35 seconds. Run the script to
+   analyze:
+
+   ```bash
+   python3 gputrc2graph.py \
+   --in_file run1.nsys-rep,trtllm,llama,.35
+   ```
+
+The command produces 2 files for analysis:
+
+- result.html: categorizes kernel names into different categories in a
+  stacked bar chart.
+- result.csv: shows how the kernel names are mapped to the different
+  categories.
+
+### HTML visualization with result.html
+
+The html file shows the number of elapsed seconds spent in the different GPU
+substages (categories). Here moe_gemm is the biggest category at .14 seconds,
+followed by the "attn" kernels. This lets the user prioritize which kernels
+to focus on for performance optimizations.
+
+![Example GPU Trace Visualization](images/html.png)
+
+There is also a data table appended underneath the bar chart for copying into
+other post-processing tools.
+
+![Example GPU Trace Visualization Table](images/html_tbl.png)
+
+### Kernel to category mapping with result.csv
+
+Suppose the user would like to focus on reducing the calls to nccl kernels.
+The next step is to use result.csv to dive into which kernels make up the
+nccl GPU cycles. The following image shows the ar_fusion all-reduce kernel to
+be the biggest contributor to GPU cycles for nccl, followed by AllGather.
+
+![Example GPU Trace csv](images/csv.png)
+
+## Example 2: analyze multiple profiles
+
+Suppose the user has multiple nsys trace files, captured for different models,
+say llama and gpt-oss in this case, and wishes to compare their GPU/non-GPU
+time. A command like the following can be used:
+
+```bash
+python3 gputrc2graph.py \
+--in_file run1.nsys-rep,trtllm,llama,100 run2.nsys-rep,trtllm,gpt-oss,102 \
+--out_dir results
+```
+
+The analysis process is similar to example 1, but now there are multiple
+stacked bar charts that can be compared. The categories for the different
+kernels remain the same, so it is easy to compare the GPU cycles for the same
+categories.
+
+Once a category is shown to have more cycles for one configuration than
+another, the next step is to use the csv file to see which kernels are mapped
+into that category and which of them account for most of the time difference
+in that category.
+
+## Example 3: add new classification for a new model
+
+To create a new engine DEF with model ABC, add another json file in the same
+directory as gputrc2graph.py with the same format as the other json files.
+The script automatically picks up all the json files in that directory as
+engine/model specifications.
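+
+Classification is a first-match regex lookup over the entries of such a json
+file (see `anno_gpu_kernname` in gputrc2graph.py). A minimal sketch of that
+rule, with made-up kernel names and a made-up mapping:
+
+```python
+import re
+
+# order matters: the first pattern that matches wins, ".*" is the catch-all
+mapping = {"nccl": "nccl", "gemm": "gemm", "CUDA mem": "non-gpu-H_D_memops", ".*": "misc"}
+
+def categorize(kernel_name: str) -> str:
+    for pattern, category in mapping.items():
+        if re.search(pattern, kernel_name):
+            return category
+
+print(categorize("ncclDevKernel_AllGather"))  # -> "nccl"
+print(categorize("some_unlisted_kernel"))     # -> "misc"
+```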
+ +Then, for this new model, suppose there are 4 kernels to be classified into +"gemm" and "attn", where the gemm kernelshave names with "*H*" or "*I*" in +them, and attn kernels have names with "*J*" or "*K*" in them, just add another + .json file in the same directory as gputrc2graph.py with the same format as + the other json files, like the following: + +```json +{ + "DEF": { + "ABC": { + "H|I": "gemm", + "J|K": "attn", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + } + } +} +``` + +Each entry in the dictionary consists of: + +- key: a regex used to classify the kernels +- value: the category to classify the kernels into. + +The last 2 entries are common for all engine/models, consisting of CUDA memory +operations and a 'misc' for anything that's leftover and can't be classified. + +When invoking gputrc2graph.py, specify a trace file with this new model/engine +like the following: + +```bash +--in_file new.nsys-rep,DEF,ABC, +``` + +If the engine_DEF.json file already exists, just add the model as a new node in + the existing engine file, after the other models. diff --git a/tensorrt_llm/tools/profiler/nsys_profile_tools/gputrc2graph.py b/tensorrt_llm/tools/profiler/nsys_profile_tools/gputrc2graph.py new file mode 100755 index 00000000000..1ca8a0ff235 --- /dev/null +++ b/tensorrt_llm/tools/profiler/nsys_profile_tools/gputrc2graph.py @@ -0,0 +1,349 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + This generates gpu kernel analysis output from nsys rep. 
Will call nsys + stats -r cuda_gpu_trace, get non-overlapped gpu cycles, then generate + csv and html output for analysis +""" + +import argparse +import logging +import os + +import regex as re + +logger = logging.getLogger(__name__) + + +# helper data class for annotating kernels +def load_engine_model(): + """returns engine_model built from all json files in the current dir""" + import glob + import json + + engine_model = {} + + json_files = glob.glob( + os.path.join(os.path.dirname(__file__) or ".", "*.json")) + for fname in json_files: + with open(fname, encoding="utf-8") as f: + engine_model.update(json.load(f)) + return engine_model + + +class GPUTrace2Graph: + """ + Parses output of nsys report, generates csv and bar chart output + """ + + def __init__(self): + import pandas as pd # avoid importing till needed + + self.pd = pd + self.pd.options.mode.copy_on_write = True + + # helper functions for generating trace->summary csvs + def gen_nonoverlapped_sum_from_gputrace(self, in_file, out_file): + logger.info("loading %s", in_file) + df = self.pd.read_csv(in_file, + usecols=["Start (ns)", "Duration (ns)", "Name"]) + if df.empty: + return + df["End (ns)"] = df["Start (ns)"] + df["Duration (ns)"] + df = self.sum_non_overlapping_intervals(df) + # get ready to print table with elapsed times per kernel + df["Instances"] = 1 + df_sum = df.groupby("Name", as_index=False).agg({ + "Elapsed Time (ns)": "sum", + "Duration (ns)": "sum", + "Instances": "size" + }) + + # generate csv + df_sum["Total Time (sec)"] = df_sum["Duration (ns)"] / 1e9 + df_sum["Elapsed Time (sec)"] = df_sum["Elapsed Time (ns)"] / 1e9 + df_sum = df_sum.sort_values(by="Elapsed Time (sec)", ascending=False) + df_sum[["Elapsed Time (sec)", "Total Time (sec)", "Instances", + "Name"]].to_csv(out_file, index=False) + + def sum_non_overlapping_intervals(self, df): + """ + returns new sorted df with Elapsed Time (ns) column using + vectorized operations + """ + logger.info("sorting %s trace records by start time", str(df.shape)) + assert not df.empty, 'empty nsys records' + # Sort by start time and reset index + df = df.sort_values(by="Start (ns)").reset_index(drop=True) + + # Initialize elapsed time as duration + df["Elapsed Time (ns)"] = df["Duration (ns)"] + + # Get numpy arrays for faster operations + starts = df["Start (ns)"].values + ends = df["End (ns)"].values + + # Keep track of current interval end + current_end = ends[0] + display_units = max(1, int(len(df) / 100)) + # Update current_end for overlapping intervals + for i in range(1, len(df)): + if i % display_units == 0: + print(f"processing trace: {int(i/len(df) * 100)} %", end="\r") + if starts[i] <= current_end: + if ends[i] > current_end: + # Partial overlap + df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = ( + ends[i] - current_end) + current_end = ends[i] + else: + # Complete overlap + df.iloc[i, df.columns.get_loc("Elapsed Time (ns)")] = 0 + else: + # No overlap + current_end = ends[i] + + return df + + # functions for generating html files + def make_html(self, df, output_dir, title): + """make html graph from df""" + import plotly.express as px + + if df.empty: + return + output_name = os.path.join(output_dir, "result") + if not title: + title = "Model_Engine" + x = "Model_Engine" + y = "Elapsed Time (sec)" + color = "Category" + """ generate kernel mapping table """ + # Sort Model_Engine categories by last field after underscore + df["Model_Engine"] = self.pd.Categorical( + df["Model_Engine"], + sorted(df["Model_Engine"].unique(), key=lambda x: 
x.split("_")[-1]), + ) + df[["Model_Engine", color, "Instances", "Name", + y]].sort_values(by=color).to_csv(f"{output_name}.csv", index=False) + graph = px.histogram( + df.round(2), + x=x, + y=y, + title=(f"{y} for {title}"), + color=color, + text_auto=True, + ) + # wrap x axis labels + graph.update_xaxes(automargin=True) + graph.write_html(f"{output_name}.html") + """ + Generate data table with columns per Model_Engine into result.html + """ + pivot_df = df.pivot_table( + values="Elapsed Time (sec)", + index="Category", + columns="Model_Engine", + aggfunc="sum", + observed=False, + ).round(2) + # Add sum row at bottom + pivot_df.loc["total_elapsed_sec"] = pivot_df.sum() + pivot_df.fillna("").to_html("temp.html") + with ( + open(f"{output_name}.html", "a", encoding="utf-8") as outfile, + open("temp.html", encoding="utf-8") as infile, + ): + outfile.write(infile.read()) + os.remove("temp.html") + + print(f"Finished generating: \n" + f" {output_name}.html for stack bar chart \n" + f" {output_name}.csv for Kernel-Category mapping") + + def anno_gpu_kernname(self, df, mapping): + """add "Category" column""" + + def anno_gpu_kernname_helper(name): + for kern_name, val in mapping.items(): + if re.search(kern_name, name): + return val + + df["Category"] = df["Name"].apply(anno_gpu_kernname_helper) + + def make_nongpu_row(self, df, nongpu_sec): + """this will append non-gpu time entry at end of df""" + nongpu_row = self.pd.DataFrame([df.iloc[-1]]) + nongpu_row["Category"] = nongpu_row["Name"] = "CPU(non-GPU)" + nongpu_row["Instances"] = 1 + nongpu_row["Elapsed Time (sec)"] = nongpu_sec + return nongpu_row + + def is_valid_file(self, base_file): + """asserts if base_file is non-existent or is empty""" + assert (os.path.isfile(base_file) and os.path.getsize(base_file) + > 0), f"{base_file} doesn't exist or is empty" + + def should_gen_file(self, new_file, base_file): + """figure out if new file should be generated from base_file""" + self.is_valid_file(base_file) + if (os.path.exists(new_file) + and (os.path.getmtime(new_file) > os.path.getmtime(base_file)) + and (os.path.getsize(base_file) > 0)): + logger.info("reusing %s", new_file) + return False + else: + logger.info("generating %s", new_file) + return True + + def gen_sum_file(self, file, nsys_cmd): + """ + generates sum file from nsys trace with times per kernel and + returns the name of the sum file + """ + import subprocess # nosec B404 + + file_dir = os.path.dirname(file) + file_name = os.path.basename(file) + + if not file_dir: + file_dir = "." 
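+        # As a concrete illustration (assuming "nsys" is on PATH and the input
+        # report is run.nsys-rep in the current directory), the command built
+        # below is equivalent to:
+        #   nsys stats -r cuda_gpu_trace run.nsys-rep -o ./run.nsys-rep
+        # which should emit ./run.nsys-rep_cuda_gpu_trace.csv; that per-kernel
+        # trace is then reduced to a non-overlapped summary csv.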
+ # Walk through trace and get the total non-overlapped time + nsys_stats_file = os.path.join(file_dir, + f"{file_name}_cuda_gpu_trace.csv") + sum_file = os.path.join(file_dir, + f"{file_name}_cuda_gpu_kernel_tracesum.csv") + if self.should_gen_file(nsys_stats_file, file): + cmd = [ + nsys_cmd, + "stats", + "-r", + "cuda_gpu_trace", + file, + "-o", + f"{file_dir}/{file_name}", + ] + cmd_str = " ".join(cmd) + logger.info("+ %s", cmd_str) + # estimate time based on calibrated 240M/min + file_size_mb = os.path.getsize(file) / 1e6 + logger.info( + "nsys stats for %.2f MB file expected to take %.2f min", + file_size_mb, + file_size_mb / 240, + ) + try: + subprocess.run(cmd) + except Exception: + logger.error("%s failed; Use --nsys_cmd to specify nsys path", + cmd_str) + exit(1) + logger.info("generating non-overalapped sum %s", sum_file) + self.gen_nonoverlapped_sum_from_gputrace(nsys_stats_file, sum_file) + self.is_valid_file(sum_file) + logger.info("Finished generating %s", sum_file) + return sum_file + + def gen_graph(self, in_file, out_dir, title, nsys_cmd, engine_model): + """generates graph and csv file from in_file into out_dir""" + # Initialize an empty DataFrame to store combined data + combined_df = self.pd.DataFrame() + for idx, (file, engine, model, total_sec) in enumerate(in_file): + file_dir = os.path.dirname(file) + file_name = os.path.basename(file) + if not file_dir: + file_dir = "." + sum_file = self.gen_sum_file(file, nsys_cmd) + # read kernel summary file + df = self.pd.read_csv(sum_file) + # annotate kernel to their categories + assert engine_model.get(engine), f"engine {engine} unknown" + assert engine_model[engine].get(model), f"model {model} unknown" + # remove nsys-rep from file_name for shorter x-label + file_name = file_name.replace(".nsys-rep", "") + df["Model_Engine"] = f"{model}_{engine}_{file_name}_{idx}" + self.anno_gpu_kernname(df, engine_model[engine][model]) + # patch in non-gpu time + gpu_sec = round(df["Elapsed Time (sec)"].sum(), 1) + total_sec = round(float(total_sec), 1) + if total_sec < gpu_sec: + logger.warning( + "Elapsed sec %.2f < GPU sec %.2f resetting Elapsed sec ", + total_sec, + gpu_sec, + ) + total_sec = gpu_sec + nongpu_row = self.make_nongpu_row(df, total_sec - gpu_sec) + df = self.pd.concat([df, nongpu_row], ignore_index=True) + combined_df = self.pd.concat([combined_df, df], ignore_index=True) + if out_dir is None: + out_dir = "." + else: + os.makedirs(out_dir, exist_ok=True) + # generate html file + self.make_html(combined_df, out_dir, title) + + +def parse_tuple(s): + return tuple(s.split(",")) + + +def main(): + logging.basicConfig(format=("%(asctime)s - %(levelname)s - %(message)s"), + level=logging.INFO) + parser = argparse.ArgumentParser( + description=( + "Process nsys rep and generate kernel non-overlapped cycles. \n" + "Example:\n" + "gputrc2graph.py --in_file d1.nsys-rep,trtllm,llama,100 \n" + "d2.nsys-rep,trtllm,gpt-oss,102 " + '--out_dir results/ --title "Model=gpt-oss TRTLLM chart"'), + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + # load supported engine_model + engine_model_supported = load_engine_model() + # Get a string representation of supported engine/model combinations + engine_model_supported_str = ", ".join( + f"{engine}:[{', '.join(models.keys())}]" + for engine, models in engine_model_supported.items()) + parser.add_argument( + "--in_file", + type=parse_tuple, + nargs="+", + help=("list of (nsys-rep, engine, model, elapsed_nonprofiled_sec) " + "separated by space. 
Elapsed_nonprofiled_sec is runtime without " + "profiling used to calculate non-gpu time. Specify 0 to use " + "elapsed time from nsys-rep but that might inflate non-gpu time. " + f"Available engine:[model] are: {engine_model_supported_str} " + f"Example: --in_file d1.nsys-rep,sglan,llama,100 " + "d2.nsys-rep,trtllm,gpt-oss,102"), + required=True, + ) + parser.add_argument("--out_dir", help=("output dir for result.csv/html")) + parser.add_argument("--title", help=("title for html chart")) + parser.add_argument( + "--nsys_cmd", + help=("nsys cmd, e.g. /usr/bin/nsys, Default: nsys"), + default="nsys", + ) + args = parser.parse_args() + gputrace = GPUTrace2Graph() + gputrace.gen_graph(args.in_file, args.out_dir, args.title, args.nsys_cmd, + engine_model_supported) + + +if __name__ == "__main__": + main() diff --git a/tensorrt_llm/tools/profiler/nsys_profile_tools/images/csv.png b/tensorrt_llm/tools/profiler/nsys_profile_tools/images/csv.png new file mode 100644 index 00000000000..3fd412f6577 Binary files /dev/null and b/tensorrt_llm/tools/profiler/nsys_profile_tools/images/csv.png differ diff --git a/tensorrt_llm/tools/profiler/nsys_profile_tools/images/html.png b/tensorrt_llm/tools/profiler/nsys_profile_tools/images/html.png new file mode 100644 index 00000000000..35992c7f896 Binary files /dev/null and b/tensorrt_llm/tools/profiler/nsys_profile_tools/images/html.png differ diff --git a/tensorrt_llm/tools/profiler/nsys_profile_tools/images/html_tbl.png b/tensorrt_llm/tools/profiler/nsys_profile_tools/images/html_tbl.png new file mode 100644 index 00000000000..cb134a014cd Binary files /dev/null and b/tensorrt_llm/tools/profiler/nsys_profile_tools/images/html_tbl.png differ diff --git a/tensorrt_llm/tools/profiler/nsys_profile_tools/trtllm_engine_model.json b/tensorrt_llm/tools/profiler/nsys_profile_tools/trtllm_engine_model.json new file mode 100644 index 00000000000..9287a6d9c6d --- /dev/null +++ b/tensorrt_llm/tools/profiler/nsys_profile_tools/trtllm_engine_model.json @@ -0,0 +1,62 @@ +{ + "trtllm": { + "llama": { + "Fused_Moe_Kernel|gemm::|fused_moe|bmm_|GemmUniversal": "moe_gemm", + "gemm|nvjet_": "gemm", + "moe|Expert|Moe": "moe", + "CatArrayBatched": "prepare_next", + "ncclDevKernel|AllReduce": "nccl_and_custom_ar", + "RMSNormKernel": "norm", + "topk": "topk", + "act_and_mul_|Activation": "activation", + "Rotary": "rope", + "SoftMax": "softmax", + "flash|splitKreduce|kernel_mha|mmha|fmha": "attn", + "elementwise": "elementwise", + "Quantize|cvt_": "quantize", + "reduce_kernel": "reduce", + "triton": "triton_kernel", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + }, + "ds": { + "fp8_blockscale_gemm": "block_fp8_gemm", + "gemm::GroupProblemShape|Fused_Moe_Kernel|bmm_": "moe_gemm", + "gemm|matmul|nvjet|gemvx": "gemm", + "moe|buildExpertMaps|Moe|Expert|Moe": "moe", + "CatArrayBatched": "prepare_next", + "ncclDevKernel|cross_device_reduce|AllReduce": "nccl_and_custom_ar", + "Norm|_norm_": "norm", + "topk": "topk", + "act_and_mul_|Activation": "activation", + "Rope": "rope", + "elementwise": "elementwise", + "fmha|flash_fwd_kernel": "attn", + "Quantize|fp8_quant|quant_fp8|cvt_": "quantize", + "reduce": "reduce", + "SoftMax": "softmax", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + }, + "gpt-oss": { + "block_fp8|gemm_fp8_blockwise": "block_fp8_gemm", + "fused_moe_kernel|_group_gemm|GroupProblemShape|GemmUniversal|bmm_|matmul_ogs_|_topk_forward|_combined_routing|_sum_bitmatrix_rows|_compute_writeback_idx": "moe_gemm", + "gemm|matmul|nvjet": "gemm", + 
"moe|sigmoid|expert|splitKreduce|Moe": "moe", + "CatArrayBatched": "prepare_next", + "ncclDevKernel|cross_device_reduce|AllReduce": "nccl_and_custom_ar", + "Norm|_norm_": "norm", + "sbtopk": "topk", + "act_and_mul_|Activation": "activation", + "Rope": "rope", + "elementwise": "elementwise", + "fp8_quant|quant_fp8|cvt_": "quantize", + "reduce": "reduce", + "SoftMax": "softmax", + "fmha|mha|flash_fwd_kernel": "attn", + "triton": "triton_kernel", + "CUDA mem": "non-gpu-H_D_memops", + ".*": "misc" + } + } +} diff --git a/tensorrt_llm/version.py b/tensorrt_llm/version.py index 603fd689b75..93b6027df5c 100644 --- a/tensorrt_llm/version.py +++ b/tensorrt_llm/version.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "1.1.0rc1" +__version__ = "1.1.0rc2" diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index da64969337e..d761ae6851d 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -30,6 +30,8 @@ def get_default_kwargs(self): return { 'skip_tokenizer_init': False, 'trust_remote_code': True, + # TODO(https://github.com/NVIDIA/TensorRT-LLM/issues/7142): + # AutoDeploy does not support cache reuse yet. 'kv_cache_config': { 'enable_block_reuse': False, }, diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index cc970b452f1..6f89b7cc433 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -325,3 +325,4 @@ full:L40S/accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_ full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] SKIP (https://nvbugs/5471106) full:L40S/accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2] SKIP (https://nvbugs/5471108) test_e2e.py::test_multi_nodes_eval[llama4-models/nvidia/Llama-4-Maverick-17B-128E-Instruct-FP8-tp8pp2-mmlu] SKIP (https://nvbugs/5473781) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] SKIP (https://nvbugs/5476580) diff --git a/tests/unittest/_torch/executor/test_pytorch_model_engine.py b/tests/unittest/_torch/executor/test_pytorch_model_engine.py index 4cfec14c750..58876350424 100644 --- a/tests/unittest/_torch/executor/test_pytorch_model_engine.py +++ b/tests/unittest/_torch/executor/test_pytorch_model_engine.py @@ -140,6 +140,8 @@ class PyTorchModelEngineTestCase(unittest.TestCase): def test_pad_generation_requests(self) -> None: model_engine, kv_cache_manager = create_model_engine_and_kvcache() + resource_manager = ResourceManager( + {ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager}) seqlens_and_batch_sizes = [ (5, 1), @@ -155,8 +157,8 @@ def test_pad_generation_requests(self) -> None: batch.generation_requests = [] pages_before = kv_cache_manager.get_num_free_blocks() - with model_engine._maybe_pad_batch( - batch, kv_cache_manager) as padded_batch: + with model_engine.cuda_graph_runner.pad_batch( + batch, resource_manager) as padded_batch: # No padding for prefill self.assertIs(batch, padded_batch) 
self.assertEqual(kv_cache_manager.get_num_free_blocks(), @@ -166,9 +168,9 @@ def test_pad_generation_requests(self) -> None: batch.context_requests = [] batch.generation_requests = requests pages_before = kv_cache_manager.get_num_free_blocks() - new_dummy_block = 1 if model_engine.cuda_graph_dummy_request is None else 0 - with model_engine._maybe_pad_batch( - batch, kv_cache_manager) as padded_batch: + new_dummy_block = 1 if model_engine.cuda_graph_runner.padding_dummy_request is None else 0 + with model_engine.cuda_graph_runner.pad_batch( + batch, resource_manager) as padded_batch: if batch_size < 8 and max_seq_len < 25: self.assertEqual( len(padded_batch.generation_requests) % 8, 0) diff --git a/tests/unittest/_torch/helpers.py b/tests/unittest/_torch/helpers.py index 4b63769735d..86580f9b94a 100644 --- a/tests/unittest/_torch/helpers.py +++ b/tests/unittest/_torch/helpers.py @@ -162,3 +162,32 @@ def block_scale_gemm(mat_a: torch.Tensor, mat_scale_a: torch.Tensor, results[batch_idx] += final_scales[batch_idx, nth_expert, None] * output return results.view_as(x) + + +class MockPytorchBackendConfig: + + def __init__(self, use_cuda_graph, cuda_graph_padding_enabled): + self.use_cuda_graph = use_cuda_graph + self.cuda_graph_padding_enabled = cuda_graph_padding_enabled + + +class MockEngine: + """A replacement for SimpleNamespace that supports weak references.""" + + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +def create_mock_engine(batch_size: int): + + return MockEngine( + pytorch_backend_config=MockPytorchBackendConfig( + use_cuda_graph=True, cuda_graph_padding_enabled=False), + _cuda_graph_batch_sizes=[batch_size], + _max_cuda_graph_batch_size=batch_size, + max_beam_width=1, + is_spec_decode=False, + spec_config=None, + _cuda_graph_mem_pool=None, + use_mrope=False, + ) diff --git a/tests/unittest/_torch/modeling/test_modeling_exaone4.py b/tests/unittest/_torch/modeling/test_modeling_exaone4.py index 48ad7d8d835..ebf496b2c14 100644 --- a/tests/unittest/_torch/modeling/test_modeling_exaone4.py +++ b/tests/unittest/_torch/modeling/test_modeling_exaone4.py @@ -22,6 +22,7 @@ class Exaone4Config(PretrainedConfig): # TODO: Remove this once we have a proper config for Exaone4 SKIP_EXAONE4_HF_ACCURACY_TEST = True +from _torch.helpers import create_mock_engine from transformers.cache_utils import HybridCache from utils.util import getSMVersion @@ -30,8 +31,7 @@ class Exaone4Config(PretrainedConfig): from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_exaone4 import Exaone4ForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \ - DecodingCUDAGraphRunner +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -338,6 +338,11 @@ def test_exaone4_allclose_to_hf(self, scenario: Scenario) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() + graph_runner = None + if scenario.use_cuda_graph: + mock_engine = create_mock_engine(1) + graph_runner = CUDAGraphRunner(mock_engine) + def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() if not scenario.use_cuda_graph: @@ -345,19 +350,20 @@ def run_forward(input_ids, position_ids, attn_metadata): position_ids=position_ids, attn_metadata=attn_metadata) else: - graph_runner = 
DecodingCUDAGraphRunner( - attn_metadata.max_num_requests, "cuda", attn_metadata) - graph_runner.capture(lambda inputs: exaone4.forward(**inputs)) + inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attn_metadata": attn_metadata, + } + graph_runner.capture(1, + lambda inputs: exaone4.forward(**inputs), + inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). attn_metadata.prepare() - logits = graph_runner.run({ - "input_ids": input_ids, - "position_ids": position_ids, - "attn_metadata": attn_metadata, - }) + logits = graph_runner.replay(1, inputs) return logits if scenario.use_cuda_graph: @@ -380,5 +386,6 @@ def run_forward(input_ids, position_ids, attn_metadata): ref.logits[:, -1].float(), atol=0.4, rtol=0.4) - + if graph_runner is not None: + graph_runner.clear() kv_cache_manager.shutdown() diff --git a/tests/unittest/_torch/modeling/test_modeling_llama.py b/tests/unittest/_torch/modeling/test_modeling_llama.py index e72957997f5..73cd4bf9bac 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama.py +++ b/tests/unittest/_torch/modeling/test_modeling_llama.py @@ -4,6 +4,7 @@ from typing import Any import torch +from _torch.helpers import create_mock_engine from parameterized import parameterized from transformers import LlamaConfig from transformers import LlamaForCausalLM as HFLlamaForCausalLM @@ -14,8 +15,7 @@ from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_llama import LlamaForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \ - DecodingCUDAGraphRunner +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -326,6 +326,11 @@ def test_llama_allclose_to_hf(self, scenario: Scenario) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() + graph_runner = None + if scenario.use_cuda_graph: + mock_engine = create_mock_engine(1) + graph_runner = CUDAGraphRunner(mock_engine) + def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() if not scenario.use_cuda_graph: @@ -333,19 +338,18 @@ def run_forward(input_ids, position_ids, attn_metadata): position_ids=position_ids, attn_metadata=attn_metadata) else: - graph_runner = DecodingCUDAGraphRunner( - attn_metadata.max_num_requests, "cuda", attn_metadata) - graph_runner.capture(lambda inputs: llama.forward(**inputs)) - + inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attn_metadata": attn_metadata, + } + graph_runner.capture(1, lambda inputs: llama.forward(**inputs), + inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.run({ - "input_ids": input_ids, - "position_ids": position_ids, - "attn_metadata": attn_metadata, - }) + logits = graph_runner.replay(1, inputs) return logits if scenario.use_cuda_graph: @@ -364,5 +368,6 @@ def run_forward(input_ids, position_ids, attn_metadata): ref.logits[:, -1].float(), atol=0.4, rtol=0.4) - + if graph_runner is not None: + graph_runner.clear() kv_cache_manager.shutdown() diff --git a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py index a0d09c18c71..2f7618cb39b 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py +++ b/tests/unittest/_torch/modeling/test_modeling_llama_min_latency.py @@ -4,6 +4,7 @@ import torch import transformers +from _torch.helpers import create_mock_engine from parameterized import parameterized from transformers import Llama4Config from transformers import \ @@ -20,8 +21,7 @@ from tensorrt_llm._torch.models.modeling_llama import \ Llama4ForConditionalGeneration from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \ - DecodingCUDAGraphRunner +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -403,6 +403,10 @@ def test_llama_allclose_to_hf(self, scenario: AllCloseScenario) -> None: input_ids.size(-1) + gen_input_ids.size(-1)) ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() + graph_runner = None + if scenario.use_cuda_graph: + mock_engine = create_mock_engine(1) + graph_runner = CUDAGraphRunner(mock_engine) def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() @@ -411,19 +415,19 @@ def run_forward(input_ids, position_ids, attn_metadata): position_ids=position_ids, attn_metadata=attn_metadata) else: - graph_runner = DecodingCUDAGraphRunner( - attn_metadata.max_num_requests, "cuda", attn_metadata) - graph_runner.capture(lambda inputs: llama.forward(**inputs)) + inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attn_metadata": attn_metadata, + } + graph_runner.capture(1, lambda inputs: llama.forward(**inputs), + inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.run({ - "input_ids": input_ids, - "position_ids": position_ids, - "attn_metadata": attn_metadata, - }) + logits = graph_runner.replay(1, inputs) return logits if scenario.use_cuda_graph: diff --git a/tests/unittest/_torch/modeling/test_modeling_mistral.py b/tests/unittest/_torch/modeling/test_modeling_mistral.py index c330126942e..f85837ddb14 100644 --- a/tests/unittest/_torch/modeling/test_modeling_mistral.py +++ b/tests/unittest/_torch/modeling/test_modeling_mistral.py @@ -7,6 +7,7 @@ import pytest import torch import transformers +from _torch.helpers import create_mock_engine from utils.util import getSMVersion import tensorrt_llm @@ -15,7 +16,8 @@ from tensorrt_llm._torch import model_config as model_config_lib from tensorrt_llm._torch.attention_backend import utils as attention_utils from tensorrt_llm._torch.models import modeling_mistral -from tensorrt_llm._torch.pyexecutor import cuda_graph_runner, resource_manager +from tensorrt_llm._torch.pyexecutor import resource_manager +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm.bindings import executor as executor_lib from tensorrt_llm.models import modeling_utils @@ -398,6 +400,11 @@ def test_mistral_3_vlm_allclose_to_hf(mistral_small_3_1_24b_config, backend, use ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() + graph_runner = None + if use_cuda_graph: + mock_engine = create_mock_engine(1) + graph_runner = CUDAGraphRunner(mock_engine) + def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() if not use_cuda_graph: @@ -405,22 +412,18 @@ def run_forward(input_ids, position_ids, attn_metadata): input_ids=input_ids, position_ids=position_ids, attn_metadata=attn_metadata ) else: - graph_runner = cuda_graph_runner.DecodingCUDAGraphRunner( - attn_metadata.max_num_requests, "cuda", attn_metadata - ) - graph_runner.capture(lambda inputs: mistral.forward(**inputs)) + inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attn_metadata": attn_metadata, + } + graph_runner.capture(1, lambda inputs: mistral.forward(**inputs), inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.run( - { - "input_ids": input_ids, - "position_ids": position_ids, - "attn_metadata": attn_metadata, - } - ) + logits = graph_runner.replay(1, inputs) return logits if use_cuda_graph: @@ -438,3 +441,5 @@ def run_forward(input_ids, position_ids, attn_metadata): ) torch.testing.assert_close(logits, ref.logits[:, -1].float(), atol=0.4, rtol=0.4) + if graph_runner is not None: + graph_runner.clear() diff --git a/tests/unittest/_torch/modeling/test_modeling_mixtral.py b/tests/unittest/_torch/modeling/test_modeling_mixtral.py index 3b9e6896e32..1637120b304 100644 --- a/tests/unittest/_torch/modeling/test_modeling_mixtral.py +++ b/tests/unittest/_torch/modeling/test_modeling_mixtral.py @@ -3,6 +3,7 @@ from dataclasses import dataclass import torch +from _torch.helpers import create_mock_engine from parameterized import parameterized from transformers import MixtralConfig from transformers import MixtralForCausalLM as HFMixtralForCausalLM @@ -15,8 +16,7 @@ from tensorrt_llm._torch.models.checkpoints.hf.mixtral_weight_mapper import \ MixtralHfWeightMapper from tensorrt_llm._torch.models.modeling_mixtral import MixtralForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \ - DecodingCUDAGraphRunner +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -310,6 +310,11 @@ def test_mixtral_allclose_to_hf(self, scenario: Scenario): ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() + graph_runner = None + if scenario.use_cuda_graph: + mock_engine = create_mock_engine(1) + graph_runner = CUDAGraphRunner(mock_engine) + def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() if not scenario.use_cuda_graph: @@ -317,19 +322,20 @@ def run_forward(input_ids, position_ids, attn_metadata): position_ids=position_ids, attn_metadata=attn_metadata) else: - graph_runner = DecodingCUDAGraphRunner( - attn_metadata.max_num_requests, "cuda", attn_metadata) - graph_runner.capture(lambda inputs: mixtral.forward(**inputs)) + inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attn_metadata": attn_metadata, + } + graph_runner.capture(1, + lambda inputs: mixtral.forward(**inputs), + inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.run({ - "input_ids": input_ids, - "position_ids": position_ids, - "attn_metadata": attn_metadata, - }) + logits = graph_runner.replay(1, inputs) return logits if scenario.use_cuda_graph: @@ -348,5 +354,6 @@ def run_forward(input_ids, position_ids, attn_metadata): ref.logits[:, -1].float(), atol=0.1, rtol=0.1) - + if graph_runner is not None: + graph_runner.clear() kv_cache_manager.shutdown() diff --git a/tests/unittest/_torch/modeling/test_modeling_mllama.py b/tests/unittest/_torch/modeling/test_modeling_mllama.py index 665e28919b5..1ecfd396612 100644 --- a/tests/unittest/_torch/modeling/test_modeling_mllama.py +++ b/tests/unittest/_torch/modeling/test_modeling_mllama.py @@ -3,6 +3,7 @@ from copy import deepcopy import torch +from _torch.helpers import create_mock_engine from parameterized import parameterized from test_modeling_llama import Scenario, reduce_llama_config from transformers import MllamaConfig @@ -15,8 +16,7 @@ from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_mllama import \ MllamaForConditionalGeneration -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \ - DecodingCUDAGraphRunner +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -417,6 +417,11 @@ def test_mllama_allclose_to_hf_text_only(self, scenario: Scenario) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() + graph_runner = None + if scenario.use_cuda_graph: + mock_engine = create_mock_engine(1) + graph_runner = CUDAGraphRunner(mock_engine) + def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() if not scenario.use_cuda_graph: @@ -424,19 +429,19 @@ def run_forward(input_ids, position_ids, attn_metadata): position_ids=position_ids, attn_metadata=attn_metadata) else: - graph_runner = DecodingCUDAGraphRunner( - attn_metadata.max_num_requests, "cuda", attn_metadata) - graph_runner.capture(lambda inputs: mllama.forward(**inputs)) + inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attn_metadata": attn_metadata, + } + graph_runner.capture(1, lambda inputs: mllama.forward(**inputs), + inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.run({ - "input_ids": input_ids, - "position_ids": position_ids, - "attn_metadata": attn_metadata, - }) + logits = graph_runner.replay(1, inputs) return logits if scenario.use_cuda_graph: @@ -455,3 +460,6 @@ def run_forward(input_ids, position_ids, attn_metadata): ref.logits[:, -1].float(), atol=0.3, rtol=0.3) + if graph_runner is not None: + graph_runner.clear() + kv_cache_manager.shutdown() diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron.py b/tests/unittest/_torch/modeling/test_modeling_nemotron.py index a17b050ec0f..11456d0f099 100644 --- a/tests/unittest/_torch/modeling/test_modeling_nemotron.py +++ b/tests/unittest/_torch/modeling/test_modeling_nemotron.py @@ -4,6 +4,7 @@ from typing import Any import torch +from _torch.helpers import create_mock_engine from parameterized import parameterized from transformers import NemotronConfig from transformers import NemotronForCausalLM as HFNemotronForCausalLM @@ -14,8 +15,7 @@ from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_nemotron import NemotronForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \ - DecodingCUDAGraphRunner +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -318,6 +318,11 @@ def test_nemotron_allclose_to_hf(self, scenario: Scenario) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() + graph_runner = None + if scenario.use_cuda_graph: + mock_engine = create_mock_engine(1) + graph_runner = CUDAGraphRunner(mock_engine) + def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() if not scenario.use_cuda_graph: @@ -325,19 +330,20 @@ def run_forward(input_ids, position_ids, attn_metadata): position_ids=position_ids, attn_metadata=attn_metadata) else: - graph_runner = DecodingCUDAGraphRunner( - attn_metadata.max_num_requests, "cuda", attn_metadata) - graph_runner.capture(lambda inputs: nemotron.forward(**inputs)) + inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attn_metadata": attn_metadata, + } + graph_runner.capture(1, + lambda inputs: nemotron.forward(**inputs), + inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.run({ - "input_ids": input_ids, - "position_ids": position_ids, - "attn_metadata": attn_metadata, - }) + logits = graph_runner.replay(1, inputs) return logits if scenario.use_cuda_graph: @@ -357,4 +363,6 @@ def run_forward(input_ids, position_ids, attn_metadata): atol=0.4, rtol=0.4) + if graph_runner is not None: + graph_runner.clear() kv_cache_manager.shutdown() diff --git a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py index 58c854931e9..3e727e654bf 100644 --- a/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py +++ b/tests/unittest/_torch/modeling/test_modeling_nemotron_h.py @@ -33,7 +33,9 @@ def extract_decode_logprobs(result: RequestOutput, def create_nemotron_h_llm(use_cuda_graph, disable_overlap_scheduler, max_batch_size, - mamba_ssm_cache_dtype=None): + mamba_ssm_cache_dtype=None, + enable_chunked_prefill=False, + max_num_tokens=None): """Create LLM with specific overlap scheduler setting""" model_dir = f"{llm_models_root(check=True)}/Nemotron-H-8B-Base-8K" return LLM( @@ -47,6 +49,8 @@ def create_nemotron_h_llm(use_cuda_graph, mamba_ssm_cache_dtype="auto" if mamba_ssm_cache_dtype is None else mamba_ssm_cache_dtype), sampler_type="TRTLLMSampler", + enable_chunked_prefill=enable_chunked_prefill, + max_num_tokens=max_num_tokens, ) @@ -336,3 +340,62 @@ def test_nemotron_h_cuda_graph_overlap_scheduler(): msg=lambda x: f"Prompt {i}: with/without overlap scheduler (with CG) logprobs for all selected tokens {x}" ) + + +def test_nemotron_h_chunked_prefill(): + # Long prompts (~100 tokens) to make sure chunked prefill is enabled + # (At the time of development, tokens_per_block isn't configurable from the LLM API, + # and max_tokens (i.e. chunk size) needs to be a multiple of tokens_per_block) + prompts = [ + "Artificial Intelligence in Healthcare: Artificial intelligence (AI) is transforming healthcare by improving diagnostics, treatment plans, and patient care. AI algorithms can analyze medical images with high accuracy, assist in early disease detection, and personalize treatment plans based on patient data. Additionally, AI-powered chatbots and virtual assistants provide support to patients, enhancing accessibility and efficiency in healthcare services. As AI technology continues to advance, its integration into healthcare systems promises to deliver better outcomes and reduce costs. With continuous research and development, AI in healthcare is poised to", + "The Role of Cloud Computing: Cloud computing has revolutionized the way businesses operate by providing scalable, on-demand access to computing resources. This technology allows organizations to store and process data remotely, reducing the need for physical infrastructure and enabling greater flexibility. Cloud services facilitate collaboration, enhance data security, and support the deployment of innovative applications. As businesses increasingly adopt cloud solutions, they benefit from improved efficiency, cost savings, and the ability to rapidly adapt to changing market conditions. Companies leveraging cloud computing are better positioned to", + "Advancements in Renewable Energy: Renewable energy technologies, such as solar and wind power, are crucial for addressing climate change and reducing dependence on fossil fuels. Advances in energy storage, grid integration, and efficiency are making renewable energy sources more viable and cost-effective. 
Innovations in materials science and engineering are also driving the development of next-generation renewable technologies. As global efforts to combat climate change intensify, the continued advancement of renewable energy will play a pivotal role in achieving a sustainable future. Governments and industries are increasingly investing in", + "The Importance of Cybersecurity: In today's digital age, cybersecurity has become essential to protect sensitive information and maintain the integrity of systems. With the rise of cyber threats such as hacking, phishing, and ransomware, organizations must implement robust security measures to safeguard their data. Cybersecurity involves a combination of technologies, processes, and practices designed to defend against unauthorized access and attacks. By staying vigilant and updating security protocols, businesses can mitigate risks and ensure the safety of their digital assets. Proactive cybersecurity strategies are crucial in", + "The Impact of Artificial Intelligence on Education: Artificial intelligence is reshaping education by providing personalized learning experiences and automating administrative tasks. AI-driven educational tools can adapt to individual student needs, offering tailored feedback and resources to enhance learning outcomes. Additionally, AI can streamline administrative processes, allowing educators to focus more on teaching and student engagement. As AI continues to evolve, its role in education will expand, offering new opportunities for innovation and efficiency. The integration of AI in classrooms promises to revolutionize how students learn and how educators manage their", + ] + sampling_config = SamplingParams(max_tokens=10, + temperature=0.0, + return_context_logits=True, + return_generation_logits=True) + + with create_nemotron_h_llm(use_cuda_graph=False, + disable_overlap_scheduler=True, + max_batch_size=16) as llm: + outputs = llm.generate(prompts, + sampling_params=sampling_config, + use_tqdm=True) + + with create_nemotron_h_llm(use_cuda_graph=False, + disable_overlap_scheduler=True, + max_batch_size=16, + enable_chunked_prefill=True, + max_num_tokens=64) as llm: + chunked_prefill_outputs = llm.generate(prompts, + sampling_params=sampling_config, + use_tqdm=True) + + for i, (output, chunked_prefill_output) in enumerate( + zip(outputs, chunked_prefill_outputs)): + assert output.outputs[0].text == chunked_prefill_output.outputs[0].text + + # assert same prefill logprobs. 
Same atol as diff between mcore and initial impl + prefill_logprobs = extract_prefill_logprobs(output) + chunked_prefill_logprobs = extract_prefill_logprobs( + chunked_prefill_output) + torch.testing.assert_close( + prefill_logprobs, + chunked_prefill_logprobs, + atol=0.3, + rtol=0.05, + msg=lambda x: f"Prompt {i} prefill logprobs {x}") + + # Decode logprobs shouldn't be affected by chunked prefill - tolerance like batching tolerance + decode_logprobs = extract_decode_logprobs(output) + chunked_decode_logprobs = extract_decode_logprobs( + chunked_prefill_output) + torch.testing.assert_close( + decode_logprobs, + chunked_decode_logprobs, + atol=0.2, + rtol=0.05, + msg=lambda x: f"Prompt {i} decode logprobs {x}") diff --git a/tests/unittest/_torch/modeling/test_modeling_phi3.py b/tests/unittest/_torch/modeling/test_modeling_phi3.py index 4a277c01ba1..7c5ffd94141 100644 --- a/tests/unittest/_torch/modeling/test_modeling_phi3.py +++ b/tests/unittest/_torch/modeling/test_modeling_phi3.py @@ -4,6 +4,7 @@ from typing import Any import torch +from _torch.helpers import create_mock_engine from transformers import Phi3Config from transformers import Phi3ForCausalLM as HFPhi3ForCausalLM from utils.util import default_dtype @@ -13,8 +14,7 @@ from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_phi3 import Phi3ForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \ - DecodingCUDAGraphRunner +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -310,6 +310,11 @@ def test_phi3_allclose_to_hf(self) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() + graph_runner = None + if scenario.use_cuda_graph: + mock_engine = create_mock_engine(1) + graph_runner = CUDAGraphRunner(mock_engine) + def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() if not scenario.use_cuda_graph: @@ -317,19 +322,19 @@ def run_forward(input_ids, position_ids, attn_metadata): position_ids=position_ids, attn_metadata=attn_metadata) else: - graph_runner = DecodingCUDAGraphRunner( - attn_metadata.max_num_requests, "cuda", attn_metadata) - graph_runner.capture(lambda inputs: phi3.forward(**inputs)) + inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attn_metadata": attn_metadata, + } + graph_runner.capture(1, lambda inputs: phi3.forward(**inputs), + inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.run({ - "input_ids": input_ids, - "position_ids": position_ids, - "attn_metadata": attn_metadata, - }) + logits = graph_runner.replay(1, inputs) return logits if scenario.use_cuda_graph: @@ -348,5 +353,6 @@ def run_forward(input_ids, position_ids, attn_metadata): ref.logits[:, -1].float(), atol=0.4, rtol=0.4) - + if graph_runner is not None: + graph_runner.clear() kv_cache_manager.shutdown() diff --git a/tests/unittest/_torch/modeling/test_modeling_qwen.py b/tests/unittest/_torch/modeling/test_modeling_qwen.py index fe8e0f5a4cd..d1d129de083 100644 --- a/tests/unittest/_torch/modeling/test_modeling_qwen.py +++ b/tests/unittest/_torch/modeling/test_modeling_qwen.py @@ -17,12 +17,12 @@ from tensorrt_llm._torch.models.modeling_qwen import ( Qwen2ForCausalLM, Qwen2ForProcessRewardModel) # yapf: enable -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \ - DecodingCUDAGraphRunner +from _torch.helpers import create_mock_engine from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from utils.llm_data import llm_models_root from utils.util import getSMVersion @@ -265,6 +265,11 @@ def test_qwen_allclose_to_hf(self, scenario: Scenario) -> None: ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() + graph_runner = None + if scenario.use_cuda_graph: + mock_engine = create_mock_engine(1) + graph_runner = CUDAGraphRunner(mock_engine) + def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() if not scenario.use_cuda_graph: @@ -272,19 +277,19 @@ def run_forward(input_ids, position_ids, attn_metadata): position_ids=position_ids, attn_metadata=attn_metadata) else: - graph_runner = DecodingCUDAGraphRunner( - attn_metadata.max_num_requests, "cuda", attn_metadata) - graph_runner.capture(lambda inputs: qwen.forward(**inputs)) + inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attn_metadata": attn_metadata, + } + graph_runner.capture(1, lambda inputs: qwen.forward(**inputs), + inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.run({ - "input_ids": input_ids, - "position_ids": position_ids, - "attn_metadata": attn_metadata, - }) + logits = graph_runner.replay(1, inputs) return logits if scenario.use_cuda_graph: @@ -303,7 +308,8 @@ def run_forward(input_ids, position_ids, attn_metadata): ref.logits[:, -1].float(), atol=0.4, rtol=0.4) - + if graph_runner is not None: + graph_runner.clear() kv_cache_manager.shutdown() @parameterized.expand( diff --git a/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py b/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py index 608a17fe1b5..8658ae0e242 100644 --- a/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py +++ b/tests/unittest/_torch/modeling/test_modeling_qwen_moe.py @@ -3,6 +3,7 @@ from dataclasses import dataclass import torch +from _torch.helpers import create_mock_engine from parameterized import parameterized from transformers import Qwen2MoeConfig from transformers import Qwen2MoeForCausalLM as HFQwen2MoeForCausalLM @@ -15,8 +16,7 @@ from tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \ Qwen2MoeHfWeightMapper from tensorrt_llm._torch.models.modeling_qwen_moe import Qwen2MoeForCausalLM -from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \ - DecodingCUDAGraphRunner +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import CUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig from tensorrt_llm.mapping import Mapping @@ -315,6 +315,11 @@ def test_qwen_moe_allclose_to_hf(self, scenario: Scenario): ] gen_position_ids = torch.cat(gen_position_ids).unsqueeze(0).cuda() + graph_runner = None + if scenario.use_cuda_graph: + mock_engine = create_mock_engine(1) + graph_runner = CUDAGraphRunner(mock_engine) + def run_forward(input_ids, position_ids, attn_metadata): attn_metadata.prepare() if not scenario.use_cuda_graph: @@ -322,19 +327,20 @@ def run_forward(input_ids, position_ids, attn_metadata): position_ids=position_ids, attn_metadata=attn_metadata) else: - graph_runner = DecodingCUDAGraphRunner( - attn_metadata.max_num_requests, "cuda", attn_metadata) - graph_runner.capture(lambda inputs: qwen_moe.forward(**inputs)) + inputs = { + "input_ids": input_ids, + "position_ids": position_ids, + "attn_metadata": attn_metadata, + } + graph_runner.capture(1, + lambda inputs: qwen_moe.forward(**inputs), + inputs) for _ in range(2): # Run it twice. This helps us catch problems if buffers are accidentally reallocated # in prepare(). 
attn_metadata.prepare() - logits = graph_runner.run({ - "input_ids": input_ids, - "position_ids": position_ids, - "attn_metadata": attn_metadata, - }) + logits = graph_runner.replay(1, inputs) return logits if scenario.use_cuda_graph: @@ -353,5 +359,6 @@ def run_forward(input_ids, position_ids, attn_metadata): ref.logits[:, -1].float(), atol=0.1, rtol=0.1) - + if graph_runner is not None: + graph_runner.clear() kv_cache_manager.shutdown() diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py index ffb8e33766a..f26fa244f1f 100644 --- a/tests/unittest/_torch/speculative/test_eagle3.py +++ b/tests/unittest/_torch/speculative/test_eagle3.py @@ -1,6 +1,9 @@ +import json import os import sys +import tempfile import unittest +from pathlib import Path import pytest import torch @@ -120,5 +123,107 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str, assert text_spec == text_ref +def test_deepseek_eagle3(): + use_cuda_graph = True + attn_backend = "TRTLLM" + disable_overlap_scheduler = False + enable_block_reuse = False + use_one_model = False + enable_chunked_prefill = False + + # Eagle3 one model works with overlap scheduler and block reuse. + total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 + if total_mem_gb < 150: + pytest.skip("Not enough memory to load target + draft model") + + models_path = llm_models_root() + eagle_config = { + 'architectures': ['LlamaForCausalLMEagle3'], + 'attention_bias': False, + 'attention_dropout': 0.0, + 'bos_token_id': 128000, + 'eos_token_id': [128001, 128008, 128009], + 'eagle_config': { + 'use_aux_hidden_state': False, + 'use_input_layernorm_in_first_layer': True, + 'use_last_layernorm': True, + 'use_mtp_layernorm': False + }, + 'head_dim': 128, + 'hidden_act': 'silu', + 'hidden_size': 2560, + 'initializer_range': 0.02, + 'intermediate_size': 16384, + 'max_position_embeddings': 4096, + 'mlp_bias': False, + 'model_type': 'llama', + 'num_attention_heads': 32, + 'num_eagle_features': 1, + 'num_hidden_layers': 1, + 'num_key_value_heads': 8, + 'pretraining_tp': 1, + 'rms_norm_eps': 1e-05, + 'rope_scaling': { + 'factor': 8.0, + 'high_freq_factor': 4.0, + 'low_freq_factor': 1.0, + 'original_max_position_embeddings': 8192, + 'rope_type': 'llama3' + }, + 'rope_theta': 500000.0, + 'tie_word_embeddings': False, + 'torch_dtype': 'bfloat16', + 'transformers_version': '4.52.4', + 'use_cache': True, + 'vocab_size': 129280, + 'draft_vocab_size': 129280 + } + with tempfile.TemporaryDirectory() as temp_dir: + eagle_model_dir = Path(temp_dir) + config_path = eagle_model_dir / "config.json" + with config_path.open("w") as f: + json.dump(eagle_config, f, indent=2) + target_model_dir = f"{models_path}/DeepSeek-V3-Lite/nvfp4_moe_only" + + # bs > 1 gives non-deterministic when doing IFB. 
There are slight chances + # that ref and spec does not match 100% + max_batch_size = 16 + max_draft_len = 3 + kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse, + free_gpu_memory_fraction=0.5) + cuda_graph_config = CudaGraphConfig( + batch_sizes=[1]) if use_cuda_graph else None + + llm_common_config = dict( + model=target_model_dir, + attn_backend=attn_backend, + disable_overlap_scheduler=disable_overlap_scheduler, + cuda_graph_config=cuda_graph_config, + max_batch_size=max_batch_size, + max_num_tokens=4096, + max_seq_len=4096, + kv_cache_config=kv_cache_config, + enable_chunked_prefill=enable_chunked_prefill, + ) + + spec_config = EagleDecodingConfig( + max_draft_len=max_draft_len, + speculative_model_dir=eagle_model_dir, + # Llama 3 does not support one model eagle. + eagle3_one_model=use_one_model, + eagle3_layers_to_capture={29}, + load_format="dummy") + + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) + + tok_ids = llm_spec.tokenizer.encode("The future of AI is") + + sampling_params = SamplingParams(max_tokens=32, temperature=0) + for output in llm_spec.generate_async(tok_ids, + sampling_params, + streaming=True): + pass + + if __name__ == "__main__": unittest.main() diff --git a/tests/unittest/_torch/thop/test_causal_conv1d_op.py b/tests/unittest/_torch/thop/test_causal_conv1d_op.py index c5e42e2618c..54793854c9a 100644 --- a/tests/unittest/_torch/thop/test_causal_conv1d_op.py +++ b/tests/unittest/_torch/thop/test_causal_conv1d_op.py @@ -26,11 +26,15 @@ @pytest.mark.parametrize( - "dim, dconv, req_type, dtype, batch_size, max_seq_len, remove_padding, apply_silu, paged_cache", + "dim, dconv, req_type, dtype, batch_size, max_seq_len, remove_padding, apply_silu, paged_cache, use_initial_state", list( product([2048], [4], ['context', 'generation'], ['float16', 'float32', 'bfloat16'], [5], [16], [False, True], - [False, True], [False, True])) + + [False, True], [False, True], [False])) + + # test with initial state + list( + product([2048], [4], ['context'], ['bfloat16'], [5], [16], + [False, True], [False], [False, True], [True])) + # long sequence tests to cover the int overflow issue list( map( @@ -42,10 +46,11 @@ "The long sequence test needs at least 33GB memory, skipping" )), product([5376], [4], ['context'], ['float16', 'bfloat16'], [2], - [131072], [False, True], [False, True], [False])))) + [131072], [False, True], [False, True], [False], [False])))) @pytest.mark.high_cuda_memory def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len, - remove_padding, apply_silu, paged_cache): + remove_padding, apply_silu, paged_cache, + use_initial_state): device = "cuda" seq_len = max_seq_len if req_type == "context" else 1 mean = 0.0 @@ -68,7 +73,7 @@ def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len, host_context_lengths = torch.ones( (batch_size, ), dtype=torch.int32) * seq_len - if req_type == "context": + if req_type == "context" and not use_initial_state: conv_state = torch.zeros([batch_size, dim, dconv - 1], dtype=torch_dtype, device=device) @@ -111,7 +116,8 @@ def test_causal_conv1d(dim, dconv, req_type, dtype, batch_size, max_seq_len, conv_weight_input = conv_weight.squeeze(1).contiguous() if req_type == "context": - has_initial_state = None + has_initial_state = None if not use_initial_state else torch.ones( + batch_size, device=device, dtype=torch.bool) torch.ops.trtllm.causal_conv1d_fwd( x_in_out, diff --git a/tests/unittest/_torch/thop/test_mamba2_chunk_ss_update.py 
b/tests/unittest/_torch/thop/test_mamba2_chunk_ss_update.py index ea3c2c2c3cd..e26fe007763 100644 --- a/tests/unittest/_torch/thop/test_mamba2_chunk_ss_update.py +++ b/tests/unittest/_torch/thop/test_mamba2_chunk_ss_update.py @@ -21,6 +21,8 @@ from utils.torch_ref import (selective_state_update_ref, ssd_chunk_scan_combined_ref) +from tensorrt_llm._torch.modules.mamba.mamba2_metadata import \ + cu_seqlens_to_chunk_indices_offsets from tensorrt_llm._torch.modules.mamba.selective_state_update import \ selective_state_update from tensorrt_llm._torch.modules.mamba.ssd_combined import \ @@ -30,51 +32,58 @@ @pytest.mark.parametrize( - "dim, headdim, ngroups, dstate, req_type, dtype, batch_size, max_seq_len, has_z, remove_padding, paged_cache", + "dim, headdim, ngroups, dstate, req_type, dtype, batch_size, max_seq_len, has_z, remove_padding, paged_cache, use_initial_states", # dim parametrization list( product([1024, 2048, 5120], [64], [1], [128], ['context', 'generation'], - ['bfloat16'], [3], [16], [False], [True], [False])) + + ['bfloat16'], [3], [16], [False], [True], [False], [False])) + # headdim parametrization list( product([2048], [32, 64, 128, 256], [1], [128], ['context', 'generation'], ['bfloat16'], [3], [16], [False], - [True], [False])) + + [True], [False], [False])) + # ngroups parametrization list( product([2048], [64], [1, 4], [128], ['context', 'generation'], - ['bfloat16'], [3], [16], [False], [True], [False])) + + ['bfloat16'], [3], [16], [False], [True], [False], [False])) + # dstate parametrization list( product([2048], [64], [1], [64, 96, 128, 256], ['context', 'generation'], ['bfloat16'], [3], [16], [False], - [True], [False])) + + [True], [False], [False])) + # dtype parametrization list( product([2048], [64], [1], [128], ['context', 'generation'], ['float16', 'bfloat16', 'float32'], [3], [16], [False], [True], - [False])) + + [False], [False])) + # batch_size parametrization list( product([2048], [64], [1], [128], ['context', 'generation'], - ['bfloat16'], [1, 2, 8, 16], [16], [False], [True], [False])) + + ['bfloat16'], [1, 2, 8, 16], [16], [False], [True], [False], + [False])) + # max_seq_len parametrization list( product([2048], [64], [1], [128], ['context', 'generation'], ['bfloat16'], [3], [32, 64, 256, 2048, 16384], [False], [True], - [False])) + + [False], [False])) + # has_z parametrization list( product([2048], [64], [1], [128], ['context', 'generation'], - ['bfloat16'], [3], [32], [True, False], [True], [False])) + + ['bfloat16'], [3], [32], [True, False], [True], [False], + [False])) + # remove_padding parametrization list( product([2048], [64], [1], [128], ['context', 'generation'], - ['bfloat16'], [3], [32], [False], [True, False], [False])) + + ['bfloat16'], [3], [32], [False], [True, False], [False], + [False])) + # paged_cache parametrization (relevant for generation only) list( product([2048], [64], [1], [128], ['generation'], ['bfloat16'], [3], - [32], [False], [False], [True, False])) + + [32], [False], [False], [True, False], [False])) + + # use_initial_states parametrization (relevant for context only and remove_padding=True) + list( + product([2048], [64], [1], [128], ['context'], ['bfloat16'], [3], [32], + [False], [True], [False], [True, False])) + # long sequence test to cover the int overflow issue [ pytest.param( @@ -89,6 +98,7 @@ False, False, False, + False, marks=pytest.mark.skipif( get_total_gpu_memory(0) < 68 * 1024**3, reason= @@ -97,7 +107,8 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, req_type, 
dtype, batch_size, max_seq_len, has_z, - remove_padding, paged_cache): + remove_padding, paged_cache, + use_initial_states): # configs device = "cuda" seq_len = max_seq_len if req_type == 'context' else 1 @@ -168,6 +179,8 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, D = torch.randn(nheads, device=device) if has_z: z = torch.randn_like(x) + if use_initial_states: + initial_states = state.clone() if req_type == 'generation': # remove the seqlen dimension @@ -193,8 +206,13 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, C_ref = C.detach().clone() D_ref = D.detach().clone() z_ref = z.detach().clone() if has_z else None + initial_states_ref = state_ref.clone() if use_initial_states else None if req_type == "context": + if use_initial_states: + assert remove_padding + chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets( + cu_seqlens, chunk_size) out, ssm_state = mamba_chunk_scan_combined( x, dt, @@ -205,6 +223,9 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, D=D, z=z if has_z else None, dt_bias=dt_bias, + initial_states=initial_states if use_initial_states else None, + chunk_indices=chunk_indices if use_initial_states else None, + chunk_offsets=chunk_offsets if use_initial_states else None, seq_idx=seq_idx if remove_padding else None, cu_seqlens=cu_seqlens if remove_padding else None, dt_softplus=delta_softplus, @@ -273,7 +294,10 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, D=D_ref, z=z_ref[:, start:end, ...] if has_z else None, dt_bias=dt_bias_ref, - dt_softplus=delta_softplus) + dt_softplus=delta_softplus, + initial_states=initial_states_ref[i:i + 1, ...] + if use_initial_states else None, + ) out_ref[0, start:end, ...] = part_out_ref.squeeze(0) state_ref[i, ...] = part_state_ref.squeeze(0) elif long_context: @@ -295,7 +319,10 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, D=D_ref, z=z_ref[i:i + 1, ...] if has_z else None, dt_bias=dt_bias_ref, - dt_softplus=delta_softplus) + dt_softplus=delta_softplus, + initial_states=initial_states_ref[i:i + 1, ...] + if use_initial_states else None, + ) out_ref[i, ...] = part_out_ref.squeeze(0) state_ref[i, ...] 
= part_state_ref.squeeze(0) else: @@ -309,7 +336,10 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, D=D_ref, z=z_ref if has_z else None, dt_bias=dt_bias_ref, - dt_softplus=delta_softplus) + dt_softplus=delta_softplus, + initial_states=initial_states_ref + if use_initial_states else None, + ) elif req_type == 'generation': out_ref = selective_state_update_ref(state_ref, x_ref, @@ -330,3 +360,229 @@ def test_mamba2_chunk_scan_selective_state_update(dim, headdim, ngroups, dstate, state_ref, rtol=1e-2, atol=atol[dtype]) + + +@pytest.mark.parametrize("mamba_chunk_size", [8, 256]) +@pytest.mark.parametrize("seqlens", [ + (16, 2, 8, 13), + (270, 88, 212, 203), + (16, 20), +]) +def test_mamba2_chunk_scan_combined_prefill_chunking(mamba_chunk_size, seqlens): + dim = 1024 + headdim = 64 + ngroups = 1 + dstate = 128 + + # test in high precision to distinguish between numeric instabilities and actual errors + dtype = 'float32' + + num_sequences = len(seqlens) + has_z = True + + device = "cuda" + nheads = dim // headdim + delta_softplus = True + mean = 0.0 + std_dev = 0.1 + + torch_dtype = str_dtype_to_torch(dtype) + + seqlens = torch.tensor(seqlens, dtype=torch.int32, device=device) + cu_seqlens = torch.cat([ + torch.tensor([0], dtype=torch.int32, device=device), + torch.cumsum(seqlens, dim=0, dtype=torch.int32) + ], + dim=0) + seq_idx = torch.repeat_interleave(torch.arange(len(seqlens), + dtype=torch.int32, + device=device), + seqlens, + output_size=cu_seqlens[-1]).unsqueeze(0) + input_batch_size = 1 + input_seq_len = cu_seqlens[-1] + + # test data + torch.random.manual_seed(0) + x = torch.empty(input_batch_size, + input_seq_len, + nheads, + headdim, + device=device, + dtype=torch_dtype) + x.normal_(mean, std_dev) + dt = torch.randn(input_batch_size, + input_seq_len, + nheads, + device=device, + dtype=torch_dtype) + dt_bias = torch.rand(nheads, device=device) - 4.0 + A = -torch.rand(nheads, device=device) - 1.0 + B = torch.randn(input_batch_size, + input_seq_len, + ngroups, + dstate, + device=device, + dtype=torch_dtype) + C = torch.randn_like(B) + D = torch.randn(nheads, device=device) + + z = torch.randn_like(x) + + ## full seqlen computation + out_ref, state_ref = mamba_chunk_scan_combined( + x, + dt, + A, + B, + C, + chunk_size=mamba_chunk_size, + D=D, + z=z if has_z else None, + dt_bias=dt_bias, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + dt_softplus=delta_softplus, + return_final_states=False, + return_varlen_states=True, + ) + + ## chunked seqlen computation + # first chunk + chunked_seqlens = seqlens // 2 + chunked_cu_seqlens = torch.cat([ + torch.tensor([0], dtype=torch.int32, device=device), + torch.cumsum(chunked_seqlens, dim=0, dtype=torch.int32) + ], + dim=0) + chunked_seq_idx = torch.repeat_interleave( + torch.arange(len(chunked_seqlens), dtype=torch.int32, device=device), + chunked_seqlens, + output_size=chunked_cu_seqlens[-1]).unsqueeze(0) + chunked_input_seq_len = chunked_cu_seqlens[-1] + x_chunked = torch.zeros_like(x)[:, :chunked_input_seq_len, ...] + dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...] + B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...] + C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...] + z_chunked = torch.zeros_like(z)[:, :chunked_input_seq_len, ...] + for i in range(num_sequences): + # yapf: disable + chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] + + x_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] 
= chunk_f(x, i) + dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) + B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) + C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) + z_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(z, i) + # yapf: enable + + partial_out, partial_state = mamba_chunk_scan_combined( + x_chunked, + dt_chunked, + A, + B_chunked, + C_chunked, + chunk_size=mamba_chunk_size, + D=D, + z=z_chunked, + dt_bias=dt_bias, + seq_idx=chunked_seq_idx, + cu_seqlens=chunked_cu_seqlens, + dt_softplus=delta_softplus, + return_final_states=False, + return_varlen_states=True, + ) + + # remaining chunk + remaining_chunked_seqlens = seqlens - chunked_seqlens + remaining_chunked_cu_seqlens = torch.cat([ + torch.tensor([0], dtype=torch.int32, device=device), + torch.cumsum(remaining_chunked_seqlens, dim=0, dtype=torch.int32) + ], + dim=0) + remaining_chunked_seq_idx = torch.repeat_interleave( + torch.arange(len(remaining_chunked_seqlens), + dtype=torch.int32, + device=device), + remaining_chunked_seqlens, + output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0) + remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1] + # yapf: disable + remaining_x_chunked = torch.zeros_like(x)[:, :remaining_chunked_input_seq_len, ...] + remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] + remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] + remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] + remaining_z_chunked = torch.zeros_like(z)[:, :remaining_chunked_input_seq_len, ...] + for i in range(num_sequences): + remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] + + remaining_x_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(x, i) + remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) + remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) + remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) + remaining_z_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] 
= remaining_chunk_f(z, i) + + # assert input chunking is correct + concat_chunk_f = lambda pt1, pt2, i: torch.cat([ + pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...], + pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...], + ], + dim=1) + concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) + + assert concat_batch_f(x_chunked, remaining_x_chunked).equal(x) + assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt) + assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B) + assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C) + assert concat_batch_f(z_chunked, remaining_z_chunked).equal(z) + # yapf: enable + + chunk_indices, chunk_offsets = cu_seqlens_to_chunk_indices_offsets( + remaining_chunked_cu_seqlens, mamba_chunk_size) + + out_chunked, state_chunked = mamba_chunk_scan_combined( + remaining_x_chunked, + remaining_dt_chunked, + A, + remaining_B_chunked, + remaining_C_chunked, + chunk_size=mamba_chunk_size, + D=D, + z=remaining_z_chunked, + dt_bias=dt_bias, + initial_states=partial_state, + chunk_indices=chunk_indices, + chunk_offsets=chunk_offsets, + seq_idx=remaining_chunked_seq_idx, + cu_seqlens=remaining_chunked_cu_seqlens, + dt_softplus=delta_softplus, + return_final_states=False, + return_varlen_states=True, + ) + out = concat_batch_f(partial_out, out_chunked) + + # kernel chunked is same as kernel overall + # tight tolerance to find subtle correctness issues + rtol = 1e-2 + atol = 2e-3 + for i in range(num_sequences): + out_seq = out[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] + out_seq_ref = out_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...] + torch.testing.assert_close(out_seq[:, :chunked_seqlens[i], ...], + out_seq_ref[:, :chunked_seqlens[i], ...], + rtol=rtol, + atol=atol, + msg=lambda x: f"seq{i} output part1 " + x) + torch.testing.assert_close(out_seq[:, chunked_seqlens[i]:, ...], + out_seq_ref[:, chunked_seqlens[i]:, ...], + rtol=rtol, + atol=atol, + msg=lambda x: f"seq{i} output part2 " + x) + + state_seq = state_chunked[i] + state_seq_ref = state_ref[i] + torch.testing.assert_close(state_seq, + state_seq_ref, + rtol=rtol, + atol=atol, + msg=lambda x: f"seq{i} state " + x) diff --git a/tests/unittest/utils/torch_ref.py b/tests/unittest/utils/torch_ref.py index 6e666bed264..d8a6b258c57 100644 --- a/tests/unittest/utils/torch_ref.py +++ b/tests/unittest/utils/torch_ref.py @@ -480,7 +480,8 @@ def ssd_chunk_scan_combined_ref(x, D=None, z=None, dt_bias=None, - dt_softplus=False): + dt_softplus=False, + initial_states=None): """ Argument: x: (batch, seqlen, nheads, headdim) @@ -492,6 +493,7 @@ def ssd_chunk_scan_combined_ref(x, D: (nheads, headdim) or (nheads,) z: (batch, seqlen, nheads, headdim) dt_bias: (nheads,) + initial_states: (batch, nheads, dstate, headdim) Return: out: (batch, seqlen, nheads, headdim) final_states: (batch, nheads, dstate, headdim) @@ -520,8 +522,16 @@ def ssd_chunk_scan_combined_ref(x, states = states.to(torch.float32) # 2. Pass the state to all the chunks by weighted cumsum. # state_passing_ref is much less numerically stable + # align initial_states shape with states shape + initial_states = rearrange( + initial_states, + "... n p -> ... p n") if initial_states is not None else None states, final_states = state_passing_ref( - rearrange(states, "... p n -> ... (p n)"), dA_cumsum[:, :, :, -1]) + rearrange(states, "... p n -> ... (p n)"), + dA_cumsum[:, :, :, -1], + rearrange(initial_states, "... p n-> ... 
(p n)") + if initial_states is not None else None, + ) states, final_states = [ rearrange(t, "... (p n) -> ... p n", n=dstate) for t in [states, final_states]