diff --git a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py
index e63484ffe208..9fddf368d23f 100644
--- a/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py
+++ b/examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py
@@ -23,35 +23,52 @@
 
 ## To evaluate a model in cache-aware streaming mode on a single audio file:
 
-python speech_to_text_streaming_infer.py \
-    --asr_model=asr_model.nemo \
-    --audio_file=audio_file.wav \
-    --compare_vs_offline \
-    --use_amp \
-    --debug_mode
+python speech_to_text_cache_aware_streaming_infer.py \
+    model_path=asr_model.nemo \
+    audio_file=audio_file.wav \
+    compare_vs_offline=true \
+    amp=true \
+    debug_mode=true
 
 ## To evaluate a model in cache-aware streaming mode on a manifest file:
 
-python speech_to_text_streaming_infer.py \
-    --asr_model=asr_model.nemo \
-    --manifest_file=manifest_file.json \
-    --batch_size=16 \
-    --compare_vs_offline \
-    --use_amp \
-    --debug_mode
-
-You may drop the '--debug_mode' and '--compare_vs_offline' to speedup the streaming evaluation.
+python speech_to_text_cache_aware_streaming_infer.py \
+    model_path=asr_model.nemo \
+    dataset_manifest=manifest_file.json \
+    batch_size=16 \
+    compare_vs_offline=true \
+    amp=true \
+    debug_mode=true
+
+## It is also possible to use phrase boosting or an external LM with cache-aware models:
+
+python speech_to_text_cache_aware_streaming_infer.py \
+    model_path=asr_model.nemo \
+    dataset_manifest=manifest_file.json \
+    batch_size=16 \
+    rnnt_decoding.greedy.boosting_tree.key_phrases_file=key_words_list.txt \
+    rnnt_decoding.greedy.boosting_tree_alpha=1.0 \
+    rnnt_decoding.greedy.ngram_lm_model=lm_model.nemo \
+    rnnt_decoding.greedy.ngram_lm_alpha=0.5 \
+    compare_vs_offline=true \
+    amp=true \
+    debug_mode=true
+
+You may drop 'debug_mode' and 'compare_vs_offline' to speed up the streaming evaluation.
 If compare_vs_offline is not used, then significantly larger batch_size can be used.
-Setting `--pad_and_drop_preencoded` would perform the caching for all steps including the first step.
+Setting `pad_and_drop_preencoded=true` would perform the caching for all steps including the first step.
 It may result in slightly different outputs from the sub-sampling module compared to offline mode for some techniques like striding and sw_striding.
 Enabling it would make it easier to export the model to ONNX.
 
+For customization details (phrase list, n-gram LM), see the documentation:
+https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/asr/asr_language_modeling_and_customization.html
+
 ## Hybrid ASR models
-For Hybrid ASR models which have two decoders, you may select the decoder by --set_decoder DECODER_TYPE, where DECODER_TYPE can be "ctc" or "rnnt".
+For Hybrid ASR models, which have two decoders, you may select the decoder with decoder_type=DECODER_TYPE, where DECODER_TYPE can be "ctc" or "rnnt".
 If decoder is not set, then the default decoder would be used which is the RNNT decoder for Hybrid ASR models.
 
 ## Multi-lookahead models
-For models which support multiple lookaheads, the default is the first one in the list of model.encoder.att_context_size. To change it, you may use --att_context_size, for example --att_context_size [70,1].
+For models which support multiple lookaheads, the default is the first one in the list of model.encoder.att_context_size. 
To change it, you may use att_context_size, for example att_context_size=ยง[70,1]. ## Evaluate a model trained with full context for offline mode @@ -66,37 +83,108 @@ The following command would simulate cache-aware streaming on a pretrained model from NGC with chunk_size of 100, shift_size of 50 and 2 left chunks as left context. The chunk_size of 100 would be 100*4*10=4000ms for a model with 4x downsampling and 10ms shift in feature extraction. -python speech_to_text_streaming_infer.py \ - --asr_model=stt_en_conformer_ctc_large \ - --chunk_size=100 \ - --shift_size=50 \ - --left_chunks=2 \ - --online_normalization \ - --manifest_file=manifest_file.json \ - --batch_size=16 \ - --compare_vs_offline \ - --use_amp \ - --debug_mode +python speech_to_text_cache_aware_streaming_infer.py \ + pretrained_name=stt_en_conformer_ctc_large \ + chunk_size=100 \ + shift_size=50 \ + left_chunks=2 \ + online_normalization=true \ + dataset_manifest=manifest_file.json \ + batch_size=16 \ + compare_vs_offline=true \ + debug_mode=true """ -import contextlib import json import os import time -from argparse import ArgumentParser +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional +import lightning.pytorch as pl import torch -from omegaconf import open_dict +from omegaconf import OmegaConf -import nemo.collections.asr as nemo_asr from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig +from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTDecodingConfig +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.asr.parts.utils.streaming_utils import CacheAwareStreamingAudioBuffer +from nemo.collections.asr.parts.utils.transcribe_utils import get_inference_device, get_inference_dtype, setup_model +from nemo.core.config import hydra_runner from nemo.utils import logging +@dataclass +class TranscriptionConfig: + """ + Transcription Configuration for cache-aware inference. + """ + + # Required configs + model_path: Optional[str] = None # Path to a .nemo file + pretrained_name: Optional[str] = None # Name of a pretrained model + # audio_dir: Optional[str] = None # Path to a directory which contains audio files + audio_file: Optional[str] = None # Path to an audio file to perform streaming + dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest + output_path: Optional[str] = None # Path to output file when manifest is used as input + + # General configs + batch_size: int = 32 + # num_workers: int = 0 + # append_pred: bool = False # Sets mode of work, if True it will add new field transcriptions. + # pred_name_postfix: Optional[str] = None # If you need to use another model name, rather than standard one. + random_seed: Optional[int] = None # seed number going to be used in seed_everything() + + # Chunked configs + chunk_size: int = -1 # The chunk_size to be used for models trained with full context and offline models + shift_size: int = -1 # The shift_size to be used for models trained with full context and offline models + left_chunks: Optional[int] = ( + 2 # The number of left chunks to be used as left context via caching for offline models + ) + online_normalization: bool = False # Perform normalization on the run per chunk. 
+ # `pad_and_drop_preencoded` enables padding the audio input and then dropping the extra steps after + # the pre-encoding for all the steps including the the first step. It may make the outputs of the downsampling + # slightly different from offline mode for some techniques like striding or sw_striding. + pad_and_drop_preencoded: bool = False + att_context_size: Optional[list] = ( + None # Sets the att_context_size for the models which support multiple lookaheads + ) + + compare_vs_offline: bool = False # Whether to compare the output of the model with the offline mode. + + # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA + # device anyway, and do inference on CPU only if CUDA device is not found. + # If `cuda` is a negative number, inference will be on CPU only. + cuda: Optional[int] = None + allow_mps: bool = False # allow to select MPS device (Apple Silicon M-series GPU) + amp: bool = False + amp_dtype: str = "float16" # can be set to "float16" or "bfloat16" when using amp + # NB: default compute_dtype is float32 since currently cache-aware models do not work with different dtype + compute_dtype: Optional[str] = ( + "float32" # "float32" (default), "bfloat16" or "float16"; if None: bfloat16 if available else float32 + ) + matmul_precision: str = "high" # Literal["highest", "high", "medium"] + + # Decoding strategy for CTC models + ctc_decoding: CTCDecodingConfig = field(default_factory=CTCDecodingConfig) + # Decoding strategy for RNNT models + rnnt_decoding: RNNTDecodingConfig = field(default_factory=lambda: RNNTDecodingConfig(fused_batch_size=-1)) + # Selects the decoder for Hybrid ASR models which has both the CTC and RNNT decoder. + decoder_type: Optional[str] = None # Literal["ctc", "rnnt"] + + # Config for word / character error rate calculation + # calculate_wer: bool = True + # clean_groundtruth_text: bool = False + # langid: str = "en" # specify this for convert_num_to_words step in groundtruth cleaning + # use_cer: bool = False + debug_mode: bool = False # Whether to print more detail in the output. + + def extract_transcriptions(hyps): """ The transcribed_texts returned by CTC and RNNT models are different. 
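Note: the `dataset_manifest` option in the config above expects a NeMo-style JSON-lines manifest, one JSON object per line with at least an `audio_filepath` field and an optional `text` field used for WER computation. A minimal sketch of writing one such line (the path, duration, and text are made-up values; `duration` is not required by this script):

import json

sample = {"audio_filepath": "/data/audio/utt_0001.wav", "duration": 4.2, "text": "hello world"}
with open("manifest_file.json", "a", encoding="utf-8") as f:
    f.write(json.dumps(sample) + "\n")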
@@ -120,28 +208,33 @@ def calc_drop_extra_pre_encoded(asr_model, step_num, pad_and_drop_preencoded): def perform_streaming( - asr_model, streaming_buffer, compare_vs_offline=False, debug_mode=False, pad_and_drop_preencoded=False + asr_model, + streaming_buffer, + compute_dtype: torch.dtype, + compare_vs_offline=False, + debug_mode=False, + pad_and_drop_preencoded=False, ): batch_size = len(streaming_buffer.streams_length) if compare_vs_offline: # would pass the whole audio at once through the model like offline mode in order to compare the results with the stremaing mode # the output of the model in the offline and streaming mode should be exactly the same with torch.inference_mode(): - with autocast: - processed_signal, processed_signal_length = streaming_buffer.get_all_audios() - with torch.no_grad(): - ( - pred_out_offline, - transcribed_texts, - cache_last_channel_next, - cache_last_time_next, - cache_last_channel_len, - best_hyp, - ) = asr_model.conformer_stream_step( - processed_signal=processed_signal, - processed_signal_length=processed_signal_length, - return_transcription=True, - ) + processed_signal, processed_signal_length = streaming_buffer.get_all_audios() + processed_signal = processed_signal.to(compute_dtype) + with torch.no_grad(): + ( + pred_out_offline, + transcribed_texts, + cache_last_channel_next, + cache_last_time_next, + cache_last_channel_len, + best_hyp, + ) = asr_model.conformer_stream_step( + processed_signal=processed_signal, + processed_signal_length=processed_signal_length, + return_transcription=True, + ) final_offline_tran = extract_transcriptions(transcribed_texts) logging.info(f" Final offline transcriptions: {final_offline_tran}") else: @@ -156,32 +249,29 @@ def perform_streaming( pred_out_stream = None for step_num, (chunk_audio, chunk_lengths) in enumerate(streaming_buffer_iter): with torch.inference_mode(): - with autocast: - # keep_all_outputs needs to be True for the last step of streaming when model is trained with att_context_style=regular - # otherwise the last outputs would get dropped - - with torch.no_grad(): - ( - pred_out_stream, - transcribed_texts, - cache_last_channel, - cache_last_time, - cache_last_channel_len, - previous_hypotheses, - ) = asr_model.conformer_stream_step( - processed_signal=chunk_audio, - processed_signal_length=chunk_lengths, - cache_last_channel=cache_last_channel, - cache_last_time=cache_last_time, - cache_last_channel_len=cache_last_channel_len, - keep_all_outputs=streaming_buffer.is_buffer_empty(), - previous_hypotheses=previous_hypotheses, - previous_pred_out=pred_out_stream, - drop_extra_pre_encoded=calc_drop_extra_pre_encoded( - asr_model, step_num, pad_and_drop_preencoded - ), - return_transcription=True, - ) + # keep_all_outputs needs to be True for the last step of streaming when model is trained with att_context_style=regular + # otherwise the last outputs would get dropped + chunk_audio = chunk_audio.to(compute_dtype) + with torch.no_grad(): + ( + pred_out_stream, + transcribed_texts, + cache_last_channel, + cache_last_time, + cache_last_channel_len, + previous_hypotheses, + ) = asr_model.conformer_stream_step( + processed_signal=chunk_audio, + processed_signal_length=chunk_lengths, + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + cache_last_channel_len=cache_last_channel_len, + keep_all_outputs=streaming_buffer.is_buffer_empty(), + previous_hypotheses=previous_hypotheses, + previous_pred_out=pred_out_stream, + drop_extra_pre_encoded=calc_drop_extra_pre_encoded(asr_model, 
step_num, pad_and_drop_preencoded), + return_transcription=True, + ) if debug_mode: logging.info(f"Streaming transcriptions: {extract_transcriptions(transcribed_texts)}") @@ -207,158 +297,89 @@ def perform_streaming( return final_streaming_tran, final_offline_tran -def main(): - parser = ArgumentParser() - parser.add_argument( - "--asr_model", - type=str, - required=True, - help="Path to an ASR model .nemo file or name of a pretrained model.", - ) - parser.add_argument( - "--device", type=str, help="The device to load the model onto and perform the streaming", default="cuda" - ) - parser.add_argument("--audio_file", type=str, help="Path to an audio file to perform streaming", default=None) - parser.add_argument( - "--manifest_file", - type=str, - help="Path to a manifest file containing audio files to perform streaming", - default=None, - ) - parser.add_argument("--use_amp", action="store_true", help="Whether to use AMP") - parser.add_argument("--debug_mode", action="store_true", help="Whether to print more detail in the output.") - parser.add_argument( - "--compare_vs_offline", - action="store_true", - help="Whether to compare the output of the model with the offline mode.", - ) - parser.add_argument( - "--batch_size", - type=int, - default=32, - help="The batch size to be used to perform streaming in batch mode with multiple streams", - ) - parser.add_argument( - "--chunk_size", - type=int, - default=-1, - help="The chunk_size to be used for models trained with full context and offline models", - ) - parser.add_argument( - "--shift_size", - type=int, - default=-1, - help="The shift_size to be used for models trained with full context and offline models", - ) - parser.add_argument( - "--left_chunks", - type=int, - default=2, - help="The number of left chunks to be used as left context via caching for offline models", - ) +@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) +def main(cfg: TranscriptionConfig): + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + torch.set_grad_enabled(False) + torch.set_float32_matmul_precision(cfg.matmul_precision) + cfg = OmegaConf.structured(cfg) + if cfg.random_seed: + pl.seed_everything(cfg.random_seed) - parser.add_argument( - "--online_normalization", - default=False, - action='store_true', - help="Perform normalization on the run per chunk.", - ) - parser.add_argument( - "--output_path", type=str, help="path to output file when manifest is used as input", default=None - ) - parser.add_argument( - "--pad_and_drop_preencoded", - action="store_true", - help="Enables padding the audio input and then dropping the extra steps after the pre-encoding for all the steps including the the first step. It may make the outputs of the downsampling slightly different from offline mode for some techniques like striding or sw_striding.", - ) + # setup device + device = get_inference_device(cuda=cfg.cuda, allow_mps=cfg.allow_mps) - parser.add_argument( - "--set_decoder", - choices=["ctc", "rnnt"], - default=None, - help="Selects the decoder for Hybrid ASR models which has both the CTC and RNNT decoder. 
Supported decoders are ['ctc', 'rnnt']", - ) + if (cfg.compute_dtype is not None and cfg.compute_dtype != "float32") and cfg.amp: + raise ValueError("amp=true is mutually exclusive with a compute_dtype other than float32") - parser.add_argument( - "--att_context_size", - type=str, - default=None, - help="Sets the att_context_size for the models which support multiple lookaheads", - ) - - parser.add_argument( - "--matmul-precision", - type=str, - default="high", - choices=["highest", "high", "medium"], - help="Set torch matmul precision", - ) + amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16 - parser.add_argument("--strategy", type=str, default="greedy_batch", help="decoding strategy to use") - - args = parser.parse_args() + compute_dtype: torch.dtype + if cfg.amp: + # with amp model weights required to be in float32 + compute_dtype = torch.float32 + else: + compute_dtype = get_inference_dtype(compute_dtype=cfg.compute_dtype, device=device) + + if compute_dtype != torch.float32: + # NB: cache-aware models do not currently work with compute_dtype != float32 + # since in some layers output is force-casted to float32 + # TODO(vbataev): implement support in future; set `compute_dtype` in config to None by default + raise NotImplementedError( + f"Compute dtype {compute_dtype} is not yet supported for cache-aware models, use float32 instead" + ) - torch.set_float32_matmul_precision(args.matmul_precision) - if (args.audio_file is None and args.manifest_file is None) or ( - args.audio_file is not None and args.manifest_file is not None + if (cfg.audio_file is None and cfg.dataset_manifest is None) or ( + cfg.audio_file is not None and cfg.dataset_manifest is not None ): - raise ValueError("One of the audio_file and manifest_file should be non-empty!") + raise ValueError("One of the audio_file and dataset_manifest should be non-empty!") - if args.asr_model.endswith('.nemo'): - logging.info(f"Using local ASR model from {args.asr_model}") - asr_model = nemo_asr.models.ASRModel.restore_from(restore_path=args.asr_model) - else: - logging.info(f"Using NGC cloud ASR model {args.asr_model}") - asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=args.asr_model) + asr_model, model_name = setup_model(cfg=cfg, map_location=device) logging.info(asr_model.encoder.streaming_cfg) - if args.set_decoder is not None: - if hasattr(asr_model, "cur_decoder"): - asr_model.change_decoding_strategy(decoder_type=args.set_decoder) - else: - raise ValueError("Decoder cannot get changed for non-Hybrid ASR models.") - - if args.att_context_size is not None: + if cfg.att_context_size is not None: if hasattr(asr_model.encoder, "set_default_att_context_size"): - asr_model.encoder.set_default_att_context_size(att_context_size=json.loads(args.att_context_size)) + asr_model.encoder.set_default_att_context_size(att_context_size=cfg.att_context_size) else: raise ValueError("Model does not support multiple lookaheads.") - global autocast - autocast = torch.amp.autocast(asr_model.device.type, enabled=args.use_amp) - - # configure the decoding config - decoding_cfg = asr_model.cfg.decoding - with open_dict(decoding_cfg): - decoding_cfg.strategy = args.strategy - decoding_cfg.preserve_alignments = False - if hasattr(asr_model, 'joint'): # if an RNNT model - decoding_cfg.fused_batch_size = -1 - if not (max_symbols := decoding_cfg.greedy.get("max_symbols")) or max_symbols <= 0: - decoding_cfg.greedy.max_symbols = 10 - if hasattr(asr_model, "cur_decoder"): - # hybrid model, explicitly pass decoder type, 
otherwise it will be set to "rnnt" - asr_model.change_decoding_strategy(decoding_cfg, decoder_type=asr_model.cur_decoder) + # Setup decoding strategy + if hasattr(asr_model, 'change_decoding_strategy') and hasattr(asr_model, 'decoding'): + if cfg.decoder_type is not None: + decoding_cfg = cfg.rnnt_decoding if cfg.decoder_type == 'rnnt' else cfg.ctc_decoding + + if hasattr(asr_model, 'cur_decoder'): + asr_model.change_decoding_strategy(decoding_cfg, decoder_type=cfg.decoder_type) + else: + asr_model.change_decoding_strategy(decoding_cfg) + + # Check if ctc or rnnt model + elif hasattr(asr_model, 'joint'): # RNNT model + cfg.rnnt_decoding.fused_batch_size = -1 + if hasattr(asr_model, 'cur_decoder'): + asr_model.change_decoding_strategy(cfg.rnnt_decoding, decoder_type=cfg.decoder_type) + else: + asr_model.change_decoding_strategy(cfg.rnnt_decoding) else: - asr_model.change_decoding_strategy(decoding_cfg) + asr_model.change_decoding_strategy(cfg.ctc_decoding) - asr_model = asr_model.to(args.device) + asr_model = asr_model.to(device=device, dtype=compute_dtype) asr_model.eval() # chunk_size is set automatically for models trained for streaming. For models trained for offline mode with full context, we need to pass the chunk_size explicitly. - if args.chunk_size > 0: - if args.shift_size < 0: - shift_size = args.chunk_size + if cfg.chunk_size > 0: + if cfg.shift_size < 0: + shift_size = cfg.chunk_size else: - shift_size = args.shift_size + shift_size = cfg.shift_size asr_model.encoder.setup_streaming_params( - chunk_size=args.chunk_size, left_chunks=args.left_chunks, shift_size=shift_size + chunk_size=cfg.chunk_size, left_chunks=cfg.left_chunks, shift_size=shift_size ) # In streaming, offline normalization is not feasible as we don't have access to the whole audio at the beginning # When online_normalization is enabled, the normalization of the input features (mel-spectrograms) are done per step # It is suggested to train the streaming models without any normalization in the input features. - if args.online_normalization: + if cfg.online_normalization: if asr_model.cfg.preprocessor.normalize not in ["per_feature", "all_feature"]: logging.warning( "online_normalization is enabled but the model has no normalization in the feature extration part, so it is ignored." 
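As a quick sanity check on the chunked setup configured above, here is a small illustrative sketch of the chunk latency implied by chunk_size; the 4x subsampling and 10 ms feature shift are the example values from the module docstring (100*4*10 = 4000 ms), not values read from the model:

chunk_size = 100         # encoder frames per chunk (chunk_size=100)
subsampling_factor = 4   # e.g. a Conformer encoder with 4x downsampling
window_stride_ms = 10    # 10 ms feature shift in the preprocessor
print(chunk_size * subsampling_factor * window_stride_ms)  # -> 4000 ms of audio per chunk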
@@ -373,57 +394,61 @@ def main(): streaming_buffer = CacheAwareStreamingAudioBuffer( model=asr_model, online_normalization=online_normalization, - pad_and_drop_preencoded=args.pad_and_drop_preencoded, + pad_and_drop_preencoded=cfg.pad_and_drop_preencoded, ) - if args.audio_file is not None: - # stream a single audio file - processed_signal, processed_signal_length, stream_id = streaming_buffer.append_audio_file( - args.audio_file, stream_id=-1 - ) - perform_streaming( - asr_model=asr_model, - streaming_buffer=streaming_buffer, - compare_vs_offline=args.compare_vs_offline, - pad_and_drop_preencoded=args.pad_and_drop_preencoded, - ) - else: - # stream audio files in a manifest file in batched mode - samples = [] - all_streaming_tran = [] - all_offline_tran = [] - all_refs_text = [] - - with open(args.manifest_file, 'r') as f: - for line in f: - item = json.loads(line) - samples.append(item) - - logging.info(f"Loaded {len(samples)} from the manifest at {args.manifest_file}.") - - start_time = time.time() - for sample_idx, sample in enumerate(samples): - processed_signal, processed_signal_length, stream_id = streaming_buffer.append_audio_file( - sample['audio_filepath'], stream_id=-1 + with torch.amp.autocast('cuda' if device.type == "cuda" else "cpu", dtype=amp_dtype, enabled=cfg.amp): + if cfg.audio_file is not None: + # stream a single audio file + _ = streaming_buffer.append_audio_file(cfg.audio_file, stream_id=-1) + perform_streaming( + asr_model=asr_model, + streaming_buffer=streaming_buffer, + compute_dtype=compute_dtype, + compare_vs_offline=cfg.compare_vs_offline, + pad_and_drop_preencoded=cfg.pad_and_drop_preencoded, ) - if "text" in sample: - all_refs_text.append(sample["text"]) - logging.info(f'Added this sample to the buffer: {sample["audio_filepath"]}') - - if (sample_idx + 1) % args.batch_size == 0 or sample_idx == len(samples) - 1: - logging.info(f"Starting to stream samples {sample_idx - len(streaming_buffer) + 1} to {sample_idx}...") - streaming_tran, offline_tran = perform_streaming( - asr_model=asr_model, - streaming_buffer=streaming_buffer, - compare_vs_offline=args.compare_vs_offline, - debug_mode=args.debug_mode, - pad_and_drop_preencoded=args.pad_and_drop_preencoded, - ) - all_streaming_tran.extend(streaming_tran) - if args.compare_vs_offline: - all_offline_tran.extend(offline_tran) - streaming_buffer.reset_buffer() + else: + # stream audio files in a manifest file in batched mode + all_streaming_tran = [] + all_offline_tran = [] + all_refs_text = [] + batch_size = cfg.batch_size + + manifest_dir = Path(cfg.dataset_manifest).parent + samples = read_manifest(cfg.dataset_manifest) + # fix relative paths + for item in samples: + audio_filepath = Path(item["audio_filepath"]) + if not audio_filepath.is_absolute(): + item["audio_filepath"] = str(manifest_dir / audio_filepath) + + logging.info(f"Loaded {len(samples)} from the manifest at {cfg.dataset_manifest}.") + + start_time = time.time() + for sample_idx, sample in enumerate(samples): + _ = streaming_buffer.append_audio_file(sample['audio_filepath'], stream_id=-1) + if "text" in sample: + all_refs_text.append(sample["text"]) + logging.info(f'Added this sample to the buffer: {sample["audio_filepath"]}') + + if (sample_idx + 1) % batch_size == 0 or sample_idx == len(samples) - 1: + logging.info( + f"Starting to stream samples {sample_idx - len(streaming_buffer) + 1} to {sample_idx}..." 
+ ) + streaming_tran, offline_tran = perform_streaming( + asr_model=asr_model, + streaming_buffer=streaming_buffer, + compute_dtype=compute_dtype, + compare_vs_offline=cfg.compare_vs_offline, + debug_mode=cfg.debug_mode, + pad_and_drop_preencoded=cfg.pad_and_drop_preencoded, + ) + all_streaming_tran.extend(streaming_tran) + if cfg.compare_vs_offline: + all_offline_tran.extend(offline_tran) + streaming_buffer.reset_buffer() - if args.compare_vs_offline and len(all_refs_text) == len(all_offline_tran): + if cfg.compare_vs_offline and len(all_refs_text) == len(all_offline_tran): offline_wer = word_error_rate(hypotheses=all_offline_tran, references=all_refs_text) logging.info(f"WER% of offline mode: {round(offline_wer * 100, 2)}") if len(all_refs_text) == len(all_streaming_tran): @@ -434,17 +459,17 @@ def main(): logging.info(f"The whole streaming process took: {round(end_time - start_time, 2)}s") # stores the results including the transcriptions of the streaming inference in a json file - if args.output_path is not None and len(all_refs_text) == len(all_streaming_tran): + if cfg.output_path is not None and len(all_refs_text) == len(all_streaming_tran): fname = ( "streaming_out_" - + os.path.splitext(os.path.basename(args.asr_model))[0] + + os.path.splitext(os.path.basename(model_name))[0] + "_" - + os.path.splitext(os.path.basename(args.manifest_file))[0] + + os.path.splitext(os.path.basename(cfg.dataset_manifest))[0] + ".json" ) - hyp_json = os.path.join(args.output_path, fname) - os.makedirs(args.output_path, exist_ok=True) + hyp_json = os.path.join(cfg.output_path, fname) + os.makedirs(cfg.output_path, exist_ok=True) with open(hyp_json, "w") as out_f: for i, hyp in enumerate(all_streaming_tran): record = { diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py index 574a7b21ce66..c5813a75220e 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_streaming_infer_rnnt.py @@ -75,7 +75,13 @@ StreamingBatchedAudioBuffer, ) from nemo.collections.asr.parts.utils.timestamp_utils import process_timestamp_outputs -from nemo.collections.asr.parts.utils.transcribe_utils import compute_output_filename, setup_model, write_transcription +from nemo.collections.asr.parts.utils.transcribe_utils import ( + compute_output_filename, + get_inference_device, + get_inference_dtype, + setup_model, + write_transcription, +) from nemo.core.config import hydra_runner from nemo.utils import logging @@ -164,34 +170,9 @@ def main(cfg: TranscriptionConfig) -> TranscriptionConfig: filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True)) manifest = None # ignore dataset_manifest if audio_dir and dataset_manifest both presents - # setup GPU - if cfg.cuda is None: - if torch.cuda.is_available(): - map_location = torch.device('cuda:0') # use 0th CUDA device - elif cfg.allow_mps and hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): - logging.warning( - "MPS device (Apple Silicon M-series GPU) support is experimental." - " Env variable `PYTORCH_ENABLE_MPS_FALLBACK=1` should be set in most cases to avoid failures." 
- ) - map_location = torch.device('mps') - else: - map_location = torch.device('cpu') - elif cfg.cuda < 0: - # negative number => inference on CPU - map_location = torch.device('cpu') - else: - map_location = torch.device(f'cuda:{cfg.cuda}') - - compute_dtype: torch.dtype - if cfg.compute_dtype is None: - can_use_bfloat16 = map_location.type == "cuda" and torch.cuda.is_bf16_supported() - if can_use_bfloat16: - compute_dtype = torch.bfloat16 - else: - compute_dtype = torch.float32 - else: - assert cfg.compute_dtype in {"float32", "bfloat16", "float16"} - compute_dtype = getattr(torch, cfg.compute_dtype) + # setup device + map_location = get_inference_device(cuda=cfg.cuda, allow_mps=cfg.allow_mps) + compute_dtype = get_inference_dtype(cfg.compute_dtype, device=map_location) logging.info(f"Inference will be done on device : {map_location} with compute_dtype: {compute_dtype}") diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index 8e9ae6befeeb..2cdc3a30b96d 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -32,6 +32,7 @@ from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.collections.asr.parts.utils.transcribe_utils import ( compute_output_filename, + get_inference_dtype, prepare_audio_data, restore_transcription_order, setup_model, @@ -276,15 +277,11 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis amp_dtype = torch.float16 if cfg.amp_dtype == "float16" else torch.bfloat16 compute_dtype: torch.dtype - if cfg.compute_dtype is None: - can_use_bfloat16 = (not cfg.amp) and map_location.type == "cuda" and torch.cuda.is_bf16_supported() - if can_use_bfloat16: - compute_dtype = torch.bfloat16 - else: - compute_dtype = torch.float32 + if cfg.amp: + # with amp model weights required to be in float32 + compute_dtype = torch.float32 else: - assert cfg.compute_dtype in {"float32", "bfloat16", "float16"} - compute_dtype = getattr(torch, cfg.compute_dtype) + compute_dtype = get_inference_dtype(compute_dtype=cfg.compute_dtype, device=map_location) asr_model.to(compute_dtype) diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index cd94565d732d..e1cae01f2fe4 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -1674,6 +1674,8 @@ def forward( - If `return_best_hypothesis` is True, returns the best hypothesis for each batch. - Otherwise, returns the N-best hypotheses for each batch. """ + if partial_hypotheses is not None: + raise NotImplementedError("Partial hypotheses feature is not yet supported in batched beam search.") # Preserve decoder and joint training state decoder_training_state = self.decoder.training joint_training_state = self.joint.training diff --git a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py index 1d76d70b2b53..1d4b422bdd72 100644 --- a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py @@ -945,6 +945,8 @@ def forward( - If `return_best_hypothesis` is True, returns the best hypothesis for each batch. - Otherwise, returns the N-best hypotheses for each batch. 
""" + if partial_hypotheses is not None: + raise NotImplementedError("Partial hypotheses feature is not yet supported in batched beam search.") # Preserve decoder and joint training state decoder_training_state = self.decoder.training joint_training_state = self.joint.training diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index ba00f09db92a..438fc7e8b42f 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -33,6 +33,73 @@ from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging, model_utils +_MPS_WARNING_TEXT = ( + "MPS device (Apple Silicon M-series GPU) support is experimental." + " Env variable `PYTORCH_ENABLE_MPS_FALLBACK=1` should be set in most cases to avoid failures." +) + + +def get_auto_inference_device(allow_mps: bool = True) -> torch.device: + """Get best available inference device. Preference: CUDA -> MPS -> CPU""" + cuda_available = torch.cuda.is_available() + mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + if cuda_available: + device = torch.device('cuda:0') # use 0th CUDA device + elif allow_mps and mps_available: + logging.warning(_MPS_WARNING_TEXT) + device = torch.device('mps') + else: + device = torch.device('cpu') + return device + + +def get_inference_device(cuda: int | None = None, allow_mps: bool = True) -> torch.device: + """ + Get the best available device for model inference + + Args: + cuda: CUDA (GPU) device ID; negative value = GPU is not allowed; if None, select device automatically. + allow_mps: allow to select MPS device (Apple Silicon) + + Returns: + device: torch.device + """ + mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + cuda_available = torch.cuda.is_available() + if cuda is None: + return get_auto_inference_device(allow_mps=allow_mps) + elif cuda < 0: + # negative number => inference on CPU or MPS + if allow_mps and mps_available: + logging.warning(_MPS_WARNING_TEXT) + device = torch.device('mps') + else: + device = torch.device('cpu') + else: + if cuda_available: + device = torch.device(f'cuda:{cuda}') + else: + raise ValueError(f"CUDA device {cuda} requested, but unavailable") + return device + + +def get_auto_inference_dtype(device: torch.device) -> torch.dtype: + """Get inference dtype automatically. Preference: bfloat16 -> float32""" + can_use_bfloat16 = device.type == "cuda" and torch.cuda.is_bf16_supported() + if can_use_bfloat16: + return torch.bfloat16 + return torch.float32 + + +def get_inference_dtype(compute_dtype: str | None, device: torch.device) -> torch.dtype: + """Get dtype for model inference. If compute_dtype is None, the best available option is selected""" + dtype: torch.dtype + if compute_dtype is None: + return get_auto_inference_dtype(device=device) + assert compute_dtype in {"float32", "bfloat16", "float16"} + dtype = getattr(torch, compute_dtype) + return dtype + def get_buffered_pred_feat_rnnt( asr: FrameBatchASR,