Skip to content

Commit 2b4f809

Browse files
authored
Add benchmark for batch decoding (#200)
1 parent bfc5ba2 commit 2b4f809

File tree

1 file changed

+43
-2
lines changed

1 file changed

+43
-2
lines changed

benchmarks/decoders/benchmark_decoders.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import abc
88
import argparse
99
import importlib
10+
import json
1011
import os
1112
import timeit
1213

@@ -17,7 +18,10 @@
1718
from torchcodec.decoders._core import (
1819
add_video_stream,
1920
create_from_file,
21+
get_frames_at_indices,
22+
get_json_metadata,
2023
get_next_frame,
24+
scan_all_streams_to_update_metadata,
2125
seek_to_pts,
2226
)
2327

@@ -143,6 +147,39 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
143147
return frames
144148

145149

150+
class TorchCodecDecoderNonCompiledBatch(AbstractDecoder):
151+
def __init__(self, num_threads=None):
152+
self._print_each_iteration_time = False
153+
self._num_threads = num_threads
154+
155+
def get_frames_from_video(self, video_file, pts_list):
156+
decoder = create_from_file(video_file)
157+
scan_all_streams_to_update_metadata(decoder)
158+
add_video_stream(decoder, num_threads=self._num_threads)
159+
metadata = json.loads(get_json_metadata(decoder))
160+
average_fps = metadata["averageFps"]
161+
best_video_stream = metadata["bestVideoStreamIndex"]
162+
indexes_list = [int(pts * average_fps) for pts in pts_list]
163+
frames = []
164+
frames = get_frames_at_indices(
165+
decoder, stream_index=best_video_stream, frame_indices=indexes_list
166+
)
167+
return frames
168+
169+
def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
170+
decoder = create_from_file(video_file)
171+
scan_all_streams_to_update_metadata(decoder)
172+
add_video_stream(decoder, num_threads=self._num_threads)
173+
metadata = json.loads(get_json_metadata(decoder))
174+
best_video_stream = metadata["bestVideoStreamIndex"]
175+
frames = []
176+
indices_list = list(range(numFramesToDecode))
177+
frames = get_frames_at_indices(
178+
decoder, stream_index=best_video_stream, frame_indices=indices_list
179+
)
180+
return frames
181+
182+
146183
@torch.compile(fullgraph=True, backend="eager")
147184
def compiled_seek_and_next(decoder, pts):
148185
seek_to_pts(decoder, pts)
@@ -257,9 +294,9 @@ def main() -> None:
257294
)
258295
parser.add_argument(
259296
"--decoders",
260-
help="Comma-separated list of decoders to benchmark. Choices are torchcodec, torchaudio, torchvision, decord, torchcodec1, torchcodec_compiled. torchcodec1 means torchcodec with num_threads=1. torchcodec_compiled means torch.compiled torchcodec",
297+
help="Comma-separated list of decoders to benchmark. Choices are torchcodec, torchaudio, torchvision, decord, torchcodec1, torchcodec_compiled. torchcodec1 means torchcodec with num_threads=1. torchcodec_compiled means torch.compiled torchcodec. torchcodec_batch means torchcodec using batch methods.",
261298
type=str,
262-
default="decord,torchcodec,torchvision,torchaudio,torchcodec1,torchcodec_compiled",
299+
default="decord,torchcodec,torchvision,torchaudio,torchcodec1,torchcodec_compiled,torchcodec_batch",
263300
)
264301

265302
args = parser.parse_args()
@@ -291,6 +328,10 @@ def main() -> None:
291328
)
292329
if "torchaudio" in decoders:
293330
decoder_dict["TorchAudioDecoder"] = TorchAudioDecoder()
331+
if "torchcodec_batch" in decoders:
332+
decoder_dict["TorchCodecDecoderNonCompiledBatch"] = (
333+
TorchCodecDecoderNonCompiledBatch()
334+
)
294335

295336
decoder_dict["TVNewAPIDecoderWithBackendVideoReader"]
296337

0 commit comments

Comments
 (0)