|
7 | 7 | import abc
|
8 | 8 | import argparse
|
9 | 9 | import importlib
|
| 10 | +import json |
10 | 11 | import os
|
11 | 12 | import timeit
|
12 | 13 |
|
|
17 | 18 | from torchcodec.decoders._core import (
|
18 | 19 | add_video_stream,
|
19 | 20 | create_from_file,
|
| 21 | + get_frames_at_indices, |
| 22 | + get_json_metadata, |
20 | 23 | get_next_frame,
|
| 24 | + scan_all_streams_to_update_metadata, |
21 | 25 | seek_to_pts,
|
22 | 26 | )
|
23 | 27 |
|
@@ -143,6 +147,39 @@ def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
|
143 | 147 | return frames
|
144 | 148 |
|
145 | 149 |
|
class TorchCodecDecoderNonCompiledBatch(AbstractDecoder):
    """Benchmark decoder that fetches frames with one batched core-API call.

    Each public method opens a fresh decoder for the given file and requests
    all frames with a single ``get_frames_at_indices`` call.
    """

    def __init__(self, num_threads=None):
        """Store the decoding configuration.

        Args:
            num_threads: Forwarded unchanged as ``num_threads`` to
                ``add_video_stream``.
        """
        self._print_each_iteration_time = False
        self._num_threads = num_threads

    def _open_decoder(self, video_file):
        """Create a decoder for ``video_file`` and load its metadata.

        The call order matches what both public methods did inline before:
        create the decoder, scan all streams, add the video stream, then
        read the JSON metadata.

        Returns:
            A ``(decoder, metadata)`` tuple, where ``metadata`` is the dict
            parsed from ``get_json_metadata(decoder)``.
        """
        decoder = create_from_file(video_file)
        scan_all_streams_to_update_metadata(decoder)
        add_video_stream(decoder, num_threads=self._num_threads)
        metadata = json.loads(get_json_metadata(decoder))
        return decoder, metadata

    def get_frames_from_video(self, video_file, pts_list):
        """Return the frames for the timestamps in ``pts_list``.

        Each timestamp is converted to a frame index as
        ``int(pts * averageFps)``; ``int()`` truncates toward zero.
        NOTE(review): this presumably expects ``pts`` in seconds and a
        roughly constant frame rate -- confirm against the callers.

        Args:
            video_file: Passed to ``create_from_file``.
            pts_list: Iterable of numeric timestamps.

        Returns:
            Whatever ``get_frames_at_indices`` returns for those indices on
            the stream named by the ``bestVideoStreamIndex`` metadata entry.
        """
        decoder, metadata = self._open_decoder(video_file)
        average_fps = metadata["averageFps"]
        frame_indices = [int(pts * average_fps) for pts in pts_list]
        return get_frames_at_indices(
            decoder,
            stream_index=metadata["bestVideoStreamIndex"],
            frame_indices=frame_indices,
        )

    def get_consecutive_frames_from_video(self, video_file, numFramesToDecode):
        """Return frames ``0 .. numFramesToDecode - 1`` in one batched call.

        Args:
            video_file: Passed to ``create_from_file``.
            numFramesToDecode: Number of leading frame indices to request.

        Returns:
            Whatever ``get_frames_at_indices`` returns for those indices on
            the stream named by the ``bestVideoStreamIndex`` metadata entry.
        """
        decoder, metadata = self._open_decoder(video_file)
        return get_frames_at_indices(
            decoder,
            stream_index=metadata["bestVideoStreamIndex"],
            frame_indices=list(range(numFramesToDecode)),
        )
| 181 | + |
| 182 | + |
146 | 183 | @torch.compile(fullgraph=True, backend="eager")
|
147 | 184 | def compiled_seek_and_next(decoder, pts):
|
148 | 185 | seek_to_pts(decoder, pts)
|
@@ -257,9 +294,9 @@ def main() -> None:
|
257 | 294 | )
|
258 | 295 | parser.add_argument(
|
259 | 296 | "--decoders",
|
260 |
| - help="Comma-separated list of decoders to benchmark. Choices are torchcodec, torchaudio, torchvision, decord, torchcodec1, torchcodec_compiled. torchcodec1 means torchcodec with num_threads=1. torchcodec_compiled means torch.compiled torchcodec", |
| 297 | + help="Comma-separated list of decoders to benchmark. Choices are torchcodec, torchaudio, torchvision, decord, torchcodec1, torchcodec_compiled. torchcodec1 means torchcodec with num_threads=1. torchcodec_compiled means torch.compiled torchcodec. torchcodec_batch means torchcodec using batch methods.", |
261 | 298 | type=str,
|
262 |
| - default="decord,torchcodec,torchvision,torchaudio,torchcodec1,torchcodec_compiled", |
| 299 | + default="decord,torchcodec,torchvision,torchaudio,torchcodec1,torchcodec_compiled,torchcodec_batch", |
263 | 300 | )
|
264 | 301 |
|
265 | 302 | args = parser.parse_args()
|
@@ -291,6 +328,10 @@ def main() -> None:
|
291 | 328 | )
|
292 | 329 | if "torchaudio" in decoders:
|
293 | 330 | decoder_dict["TorchAudioDecoder"] = TorchAudioDecoder()
|
| 331 | + if "torchcodec_batch" in decoders: |
| 332 | + decoder_dict["TorchCodecDecoderNonCompiledBatch"] = ( |
| 333 | + TorchCodecDecoderNonCompiledBatch() |
| 334 | + ) |
294 | 335 |
|
295 | 336 | decoder_dict["TVNewAPIDecoderWithBackendVideoReader"]
|
296 | 337 |
|
|
0 commit comments