Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
1e9d545
Add new 'deviceinterface' registration key to support device variants
NicolasHug Jul 31, 2025
97af12e
Add dummy custom nvdec interface
NicolasHug Jul 31, 2025
0fb3e9a
Let SingleStreamDecoder accept a variant
NicolasHug Jul 31, 2025
2aeda55
Let decodeAVFrame() call into custom decode packet function if it exists
NicolasHug Jul 31, 2025
15c49aa
properly import headers
NicolasHug Jul 31, 2025
f1c7083
linter
NicolasHug Jul 31, 2025
7c387c9
WIP
NicolasHug Aug 1, 2025
f7fb8a8
Basic cache
NicolasHug Aug 1, 2025
3b7427f
Expose variant, fix libcuda dep
NicolasHug Aug 1, 2025
92531f0
remove cache for now
NicolasHug Aug 2, 2025
d6ede53
simplify some stuff
NicolasHug Aug 2, 2025
b56e039
Debug
NicolasHug Aug 2, 2025
ffbfecc
DEBUG
NicolasHug Aug 2, 2025
aff2c1d
Fix: add bitstream filtering
NicolasHug Aug 2, 2025
3719e39
Fix stride handling
NicolasHug Aug 2, 2025
dec3175
remove prints
NicolasHug Aug 2, 2025
de83894
Merge branch 'main' of github.com:pytorch/torchcodec into nvdec-interace
NicolasHug Aug 13, 2025
17bf28e
Keep device variant private and usable via cuda:0:custom_nvdec
NicolasHug Aug 28, 2025
f00fd79
Add TODO
NicolasHug Aug 28, 2025
c240c70
Fix confusing name
NicolasHug Aug 28, 2025
7edf916
Simplify Cmake, vendor nvcuvid headers
NicolasHug Aug 28, 2025
52e2817
Merge branch 'main' of github.com:pytorch/torchcodec into nvdecclean
NicolasHug Aug 29, 2025
666b592
remove prints
NicolasHug Aug 29, 2025
82f9764
Simplify main decoding loop by duplicating logic
NicolasHug Aug 29, 2025
315811b
Work through BSF
NicolasHug Aug 29, 2025
02b7cfc
Pass over decoder impl
NicolasHug Aug 29, 2025
2971aac
Remove call to getCudaCtx - probably a leftover
NicolasHug Aug 29, 2025
3c2aee3
cleanups, comments
NicolasHug Sep 1, 2025
ca3769c
receive / send based implem
NicolasHug Sep 4, 2025
ded8562
Remove display callback, WIP
NicolasHug Sep 4, 2025
3acdc67
Remove display callback
NicolasHug Sep 4, 2025
ea84696
Fix decode surface bug
NicolasHug Sep 4, 2025
380afa7
Fix some color-space stuff
NicolasHug Sep 4, 2025
4370fda
Fix some ordering and set duration
NicolasHug Sep 4, 2025
e8de1c2
remove test stuff
NicolasHug Sep 4, 2025
cf32675
Merge branch 'main' of github.com:pytorch/torchcodec into nvdec_recei…
NicolasHug Sep 4, 2025
013b3f7
Fix some stuff
NicolasHug Sep 4, 2025
74c0db6
Remove printf stuff
NicolasHug Sep 5, 2025
30ae69a
Add todo
NicolasHug Sep 5, 2025
169671d
Add debug stuff
NicolasHug Sep 5, 2025
b8c3f9d
More debug stuff
NicolasHug Sep 5, 2025
a3abacd
Use priority queue
NicolasHug Sep 5, 2025
d823e8a
Fix some tests
NicolasHug Sep 5, 2025
dac6ccb
Cache... and fix????
NicolasHug Sep 5, 2025
3d338b8
Fix some avFrame duration setting
NicolasHug Sep 6, 2025
7d8fbad
Add psnr check
NicolasHug Sep 6, 2025
13edb83
fixed AV1 frame order, I think
NicolasHug Sep 6, 2025
29edb40
properly set av1 pts?
NicolasHug Sep 6, 2025
f58816a
ops tests
NicolasHug Sep 8, 2025
ac7b387
Merge branch 'main' of github.com:pytorch/torchcodec into nvdec_recei…
NicolasHug Sep 8, 2025
756e362
Add debug stuff, try to create symlink
NicolasHug Sep 8, 2025
c1c54e7
mongolo
NicolasHug Sep 8, 2025
27a1574
CI debug
NicolasHug Sep 8, 2025
d99a225
Try to build only ffmpeg7
NicolasHug Sep 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions packaging/pre_build_script.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,109 @@ set -ex
# and PyTorch actually has it included. PyTorch, however, does not have the
# CMake helpers.
conda install -y pybind11 -c conda-forge


# Debug aid for CI build issues: report where (if anywhere) the nvcuvid
# library lives on the build machine.
echo "[NVCUVID-SEARCH] === Searching for nvcuvid library ==="

# Start from the CUDA nppi libraries (which CMake is expected to find): the
# directories that contain them are potential nvcuvid locations.
echo "[NVCUVID-SEARCH] Looking for CUDA nppi libraries to find potential nvcuvid location..."
NPPI_LOCATIONS=()

for nppi_lib in "libnppi.so*" "libnppicc.so*" "nppi.so*" "nppicc.so*" "libnppi*" "libnppicc*"; do
    # Keep at most 5 hits per name pattern; permission errors are discarded.
    found_nppi=$(find /usr -name "$nppi_lib" 2>/dev/null | head -5)
    if [ -z "$found_nppi" ]; then
        continue
    fi
    echo "[NVCUVID-SEARCH] Found CUDA nppi library: $found_nppi"

    while IFS= read -r lib_path; do
        lib_dir=$(dirname "$lib_path")

        # Record each directory once. Compare element by element (exact
        # match) instead of regex-matching against the space-joined array,
        # which mis-detects directories whose names contain spaces.
        already_known=false
        for known_dir in "${NPPI_LOCATIONS[@]}"; do
            if [ "$known_dir" = "$lib_dir" ]; then
                already_known=true
                break
            fi
        done
        if [ "$already_known" = false ]; then
            NPPI_LOCATIONS+=("$lib_dir")
        fi
    done <<< "$found_nppi"
done

# Log the directories that were collected above.
for nppi_dir in "${NPPI_LOCATIONS[@]}"; do
    echo "[NVCUVID-SEARCH] Adding CUDA library directory to search: $nppi_dir"
done

# Standard library search paths.
# The versioned CUDA entries are intentionally unquoted so that the shell
# expands the glob into the matching directories. With bash's default
# settings an unmatched glob is kept as a literal string and is then simply
# reported as "Directory not found" by the loop below.
SEARCH_PATHS=(
    "/usr/lib"
    "/usr/lib64"
    "/usr/lib/x86_64-linux-gnu"
    "/usr/local/lib"
    "/usr/local/lib64"
    "/lib"
    "/lib64"
    "/opt/cuda/lib64"
    "/usr/local/cuda/lib64"
    "/usr/local/cuda/lib"
    /usr/local/cuda-*/lib64
    /usr/local/cuda-*/lib
)

# Add the CUDA nppi library directories to our search paths
for nppi_dir in "${NPPI_LOCATIONS[@]}"; do
    SEARCH_PATHS+=("$nppi_dir")
done

# Library name variations to search for (passed to `find -name`, so they
# must stay quoted here).
LIB_PATTERNS=(
    "libnvcuvid.so*"
    "nvcuvid.so*"
    "libnvcuvid.a"
    "nvcuvid.a"
    "libnvcuvid*"
    "nvcuvid*"
)

# Unique paths of every nvcuvid file found; consumed by the summary below.
found_libraries=()

for search_path in "${SEARCH_PATHS[@]}"; do
    if [ ! -d "$search_path" ]; then
        echo "[NVCUVID-SEARCH] Directory not found: $search_path"
        continue
    fi

    echo "[NVCUVID-SEARCH] Searching in: $search_path"
    for pattern in "${LIB_PATTERNS[@]}"; do
        # Use find with error suppression to avoid permission errors; the
        # `|| true` keeps a failing find from aborting the script under `set -e`.
        found_files=$(find "$search_path" -maxdepth 3 -name "$pattern" 2>/dev/null || true)
        if [ -z "$found_files" ]; then
            continue
        fi
        echo "[NVCUVID-SEARCH] Found: $found_files"

        # Read the results line by line so paths are never word-split or
        # glob-expanded, and skip files already recorded: the name patterns
        # and the search directories overlap, so the same file is usually
        # matched more than once.
        while IFS= read -r found_file; do
            is_duplicate=false
            for known_file in "${found_libraries[@]}"; do
                if [ "$known_file" = "$found_file" ]; then
                    is_duplicate=true
                    break
                fi
            done
            if [ "$is_duplicate" = false ]; then
                found_libraries+=("$found_file")
            fi
        done <<< "$found_files"
    done
done

# Ask the dynamic linker cache whether it lists nvcuvid (skipped when
# ldconfig is not installed).
echo "[NVCUVID-SEARCH] Checking ldconfig cache for nvcuvid..."
if command -v ldconfig >/dev/null 2>&1; then
    if ! ldconfig_result=$(ldconfig -p 2>/dev/null | grep -i nvcuvid); then
        ldconfig_result="Not found in ldconfig cache"
    fi
    echo "[NVCUVID-SEARCH] ldconfig result: $ldconfig_result"
fi

# List any CUDA-related packages known to pkg-config (skipped when
# pkg-config is not installed).
echo "[NVCUVID-SEARCH] Checking pkg-config for cuda libraries..."
if command -v pkg-config >/dev/null 2>&1; then
    if ! pkg_result=$(pkg-config --list-all 2>/dev/null | grep -i cuda); then
        pkg_result="No CUDA packages found in pkg-config"
    fi
    echo "[NVCUVID-SEARCH] pkg-config cuda packages: $pkg_result"
fi

# Final report: print every recorded nvcuvid path, with `ls -la` details for
# the ones that are regular files, or state that nothing was found.
library_count=${#found_libraries[@]}
if [ "$library_count" -eq 0 ]; then
    echo "[NVCUVID-SEARCH] === SUMMARY: No nvcuvid libraries found ==="
else
    echo "[NVCUVID-SEARCH] === SUMMARY: Found ${library_count} nvcuvid library files ==="
    for lib in "${found_libraries[@]}"; do
        echo "[NVCUVID-SEARCH] $lib"
        # File details are best-effort: a failing ls must not abort the build.
        if [ -f "$lib" ]; then
            ls -la "$lib" 2>/dev/null || true
        fi
    done
fi

echo "[NVCUVID-SEARCH] === End nvcuvid library search ==="
41 changes: 34 additions & 7 deletions src/torchcodec/_core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ else()
# TODO set warnings as errors on Windows as well.
# set(TORCHCODEC_WERROR_OPTION "/WX")
else()
set(TORCHCODEC_WERROR_OPTION "-Werror")
# set(TORCHCODEC_WERROR_OPTION "-Werror")
endif()
endif()

Expand All @@ -31,10 +31,10 @@ endif()
if (WIN32)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /W4 ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
# set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Wextra -pedantic ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCHCODEC_WERROR_OPTION} ${TORCH_CXX_FLAGS}")
endif()


function(make_torchcodec_sublibrary
library_name
type
Expand Down Expand Up @@ -98,7 +98,7 @@ function(make_torchcodec_libraries
)

if(ENABLE_CUDA)
list(APPEND core_sources CudaDeviceInterface.cpp)
list(APPEND core_sources CudaDeviceInterface.cpp CustomNvdecDeviceInterface.cpp)
endif()

set(core_library_dependencies
Expand All @@ -111,6 +111,33 @@ function(make_torchcodec_libraries
${CUDA_nppi_LIBRARY}
${CUDA_nppicc_LIBRARY}
)

# Try the normal way first
find_library(NVCUVID_LIBRARY NAMES nvcuvid)

# If not found, try with version suffix
if(NOT NVCUVID_LIBRARY)
find_library(NVCUVID_LIBRARY NAMES nvcuvid.1 PATHS /usr/lib64 /usr/lib)
endif()

# Or specify the full path directly
if(NOT NVCUVID_LIBRARY)
set(NVCUVID_LIBRARY "/usr/lib64/libnvcuvid.so.1")
endif()

if(NVCUVID_LIBRARY)
message(STATUS "Found NVCUVID: ${NVCUVID_LIBRARY}")
else()
message(FATAL_ERROR "Could not find NVCUVID library")
endif()
# find_library(NVCUVID_LIBRARY NAMES nvcuvid REQUIRED)
# message(STATUS "Found NVCUVID library: ${NVCUVID_LIBRARY}")

# Add CUDA Driver library (needed for cuCtxGetCurrent, etc.)
find_library(CUDA_DRIVER_LIBRARY NAMES cuda REQUIRED)
message(STATUS "Found CUDA Driver library: ${CUDA_DRIVER_LIBRARY}")

list(APPEND core_library_dependencies ${NVCUVID_LIBRARY} ${CUDA_DRIVER_LIBRARY})
endif()

make_torchcodec_sublibrary(
Expand Down Expand Up @@ -239,9 +266,9 @@ if(DEFINED ENV{BUILD_AGAINST_ALL_FFMPEG_FROM_S3})
)

make_torchcodec_libraries(7 ffmpeg7)
make_torchcodec_libraries(6 ffmpeg6)
make_torchcodec_libraries(4 ffmpeg4)
make_torchcodec_libraries(5 ffmpeg5)
# make_torchcodec_libraries(6 ffmpeg6)
# make_torchcodec_libraries(4 ffmpeg4)
# make_torchcodec_libraries(5 ffmpeg5)
else()
message(
STATUS
Expand Down
84 changes: 47 additions & 37 deletions src/torchcodec/_core/CudaDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ extern "C" {
namespace facebook::torchcodec {
namespace {

static bool g_cuda =
static bool g_cuda_default =
registerDeviceInterface(torch::kCUDA, [](const torch::Device& device) {
return new CudaDeviceInterface(device);
});
Expand Down Expand Up @@ -171,7 +171,7 @@ std::unique_ptr<NppStreamContext> getNppStreamContext(

CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
: DeviceInterface(device) {
TORCH_CHECK(g_cuda, "CudaDeviceInterface was not registered!");
TORCH_CHECK(g_cuda_default, "CudaDeviceInterface was not registered!");
TORCH_CHECK(
device_.type() == torch::kCUDA, "Unsupported device: ", device_.str());
}
Expand Down Expand Up @@ -205,6 +205,8 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
// printf("In default's CUDA interface convertAVFrameToFrameOutput\n");
fflush(stdout);
if (avFrame->format != AV_PIX_FMT_CUDA) {
// The frame's format is AV_PIX_FMT_CUDA if and only if its content is on
// the GPU. In this branch, the frame is on the CPU: this is what NVDEC
Expand All @@ -229,29 +231,35 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
return;
}

// Above we checked that the AVFrame was on GPU, but that's not enough, we
// also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
// because this is what the NPP color conversion routines expect.
// TODO: we should investigate how we can perform color conversion for
// non-8bit videos. This is supported on CPU.
TORCH_CHECK(
avFrame->hw_frames_ctx != nullptr,
"The AVFrame does not have a hw_frames_ctx. "
"That's unexpected, please report this to the TorchCodec repo.");

auto hwFramesCtx =
reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
AVPixelFormat actualFormat = hwFramesCtx->sw_format;
TORCH_CHECK(
actualFormat == AV_PIX_FMT_NV12,
"The AVFrame is ",
(av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
: "unknown"),
", but we expected AV_PIX_FMT_NV12. This typically happens when "
"the video isn't 8bit, which is not supported on CUDA at the moment. "
"Try using the CPU device instead. "
"If the video is 10bit, we are tracking 10bit support in "
"https://github.com/pytorch/torchcodec/issues/776");
// TODONVDEC: We're currently calling this function from within the CNI
// (Custom NVDEC Interface). But the AVFrame's hw_frames_ctx doesn't exist,
// so we error. Not sure how to solve this: either set the field in a
// meaningful way, or allow bypassing the check, but then how do we know the
// pix format?

// // Above we checked that the AVFrame was on GPU, but that's not enough, we
// // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits),
// // because this is what the NPP color conversion routines expect.
// // TODO: we should investigate how we can perform color conversion for
// // non-8bit videos. This is supported on CPU.
// TORCH_CHECK(
// avFrame->hw_frames_ctx != nullptr,
// "The AVFrame does not have a hw_frames_ctx. "
// "That's unexpected, please report this to the TorchCodec repo.");

// auto hwFramesCtx =
// reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
// AVPixelFormat actualFormat = hwFramesCtx->sw_format;
// TORCH_CHECK(
// actualFormat == AV_PIX_FMT_NV12,
// "The AVFrame is ",
// (av_get_pix_fmt_name(actualFormat) ? av_get_pix_fmt_name(actualFormat)
// : "unknown"),
// ", but we expected AV_PIX_FMT_NV12. This typically happens when "
// "the video isn't 8bit, which is not supported on CUDA at the moment. "
// "Try using the CPU device instead. "
// "If the video is 10bit, we are tracking 10bit support in "
// "https://github.com/pytorch/torchcodec/issues/776");

auto frameDims =
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
Expand Down Expand Up @@ -285,19 +293,19 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
// arbitrary, but unfortunately we know it's hardcoded to be the default
// stream by FFmpeg:
// https://github.com/FFmpeg/FFmpeg/blob/66e40840d15b514f275ce3ce2a4bf72ec68c7311/libavutil/hwcontext_cuda.c#L387-L388
TORCH_CHECK(
hwFramesCtx->device_ctx != nullptr,
"The AVFrame's hw_frames_ctx does not have a device_ctx. ");
auto cudaDeviceCtx =
static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
at::cuda::CUDAEvent nvdecDoneEvent;
at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
nvdecDoneEvent.record(nvdecStream);

// Don't start NPP work before NVDEC is done decoding the frame!
// TORCH_CHECK(
// hwFramesCtx->device_ctx != nullptr,
// "The AVFrame's hw_frames_ctx does not have a device_ctx. ");
// auto cudaDeviceCtx =
// static_cast<AVCUDADeviceContext*>(hwFramesCtx->device_ctx->hwctx);
// at::cuda::CUDAEvent nvdecDoneEvent;
// at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
// c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex);
// nvdecDoneEvent.record(nvdecStream);

// // Don't start NPP work before NVDEC is done decoding the frame!
at::cuda::CUDAStream nppStream = at::cuda::getCurrentCUDAStream(deviceIndex);
nvdecDoneEvent.block(nppStream);
// nvdecDoneEvent.block(nppStream);

// Create the NPP context if we haven't yet.
nppCtx_->hStream = nppStream.stream();
Expand All @@ -316,6 +324,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
// For background, see
// Note [YUV -> RGB Color Conversion, color space and color range]
if (avFrame->colorspace == AVColorSpace::AVCOL_SPC_BT709) {

if (avFrame->color_range == AVColorRange::AVCOL_RANGE_JPEG) {
// NPP provides a pre-defined color conversion function for BT.709 full
// range: nppiNV12ToRGB_709HDTV_8u_P2C3R_Ctx. But it's not closely
Expand Down Expand Up @@ -352,6 +361,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
*nppCtx_);
}
} else {

// TODO we're assuming BT.601 color space (and probably limited range) by
// calling nppiNV12ToRGB_8u_P2C3R_Ctx. We should handle BT.601 full range,
// and other color-spaces like 2020.
Expand Down
Loading
Loading