From 204251e2ff680a470ab6053a324cc34c5eb27ca8 Mon Sep 17 00:00:00 2001 From: Sahan Paliskara Date: Tue, 26 Aug 2025 10:15:56 -0700 Subject: [PATCH] Convert all tests in torchbench to bf16 --- BackendBench/scripts/dataset_filters.py | 23 ++++- .../scripts/parquet_trace_converter.py | 85 ++++++++++++++++++- BackendBench/utils.py | 5 +- 3 files changed, 106 insertions(+), 7 deletions(-) diff --git a/BackendBench/scripts/dataset_filters.py b/BackendBench/scripts/dataset_filters.py index ef2f5655..6cbcd61b 100644 --- a/BackendBench/scripts/dataset_filters.py +++ b/BackendBench/scripts/dataset_filters.py @@ -9,7 +9,25 @@ import tqdm from BackendBench.utils import cleanup_memory_and_gpu, deserialize_args -from triton.testing import do_bench + +# Try to import triton, fallback to cpu_bench if not available +try: + if torch.cuda.is_available(): + from triton.testing import do_bench + else: + # CPU fallback using cpu_bench from eval module + from BackendBench.eval import cpu_bench + # Wrap cpu_bench to match do_bench interface + def do_bench(fn, warmup=None, rep=None, **kwargs): + # cpu_bench returns time in seconds, convert to ms to match triton + return cpu_bench(fn, num_runs=rep if rep else 100) * 1000 +except ImportError: + # Fallback if triton is not installed at all + from BackendBench.eval import cpu_bench + # Wrap cpu_bench to match do_bench interface + def do_bench(fn, warmup=None, rep=None, **kwargs): + # cpu_bench returns time in seconds, convert to ms to match triton + return cpu_bench(fn, num_runs=rep if rep else 100) * 1000 # Operators to skip for indexing ops that need valid indices SKIP_OPERATORS = [ @@ -62,7 +80,8 @@ def apply_skip_ops_filter(ops): def apply_runtime_filter(ops): def _overhead_benchmark(): - return torch.randn(1, device="cuda") + device = "cuda" if torch.cuda.is_available() else "cpu" + return torch.randn(1, device=device) runtime_threshold_ms = do_bench(_overhead_benchmark, warmup=25, rep=100) diff --git a/BackendBench/scripts/parquet_trace_converter.py b/BackendBench/scripts/parquet_trace_converter.py index 6160424d..bd6685ba 100644 --- a/BackendBench/scripts/parquet_trace_converter.py +++ b/BackendBench/scripts/parquet_trace_converter.py @@ -16,6 +16,7 @@ import numpy as np import pyarrow as pa import pyarrow.parquet as pq +import torch from BackendBench.data_loaders import _load_from_trace from BackendBench.scripts.dataset_filters import ( apply_runtime_filter, @@ -77,13 +78,72 @@ def setup_logging(log_level): ) -def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None): +def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None, force_dtype: str = None): """ Convert a trace file to a parquet file + + Args: + trace_file: Path to trace file + parquet_file: Output parquet file path + limit: Max number of operations to process + force_dtype: Force all tensors to be of this dtype (e.g., 'bf16', 'f32', 'cpu') """ # Load operations using local trace parsing function ops = _load_from_trace(trace_file, filter=None, limit=limit) + + # Convert tensors to specified dtype if requested + conversion_failures = 0 + converted_ops = 0 + if force_dtype: + try: + from BackendBench.utils import deserialize_args, serialize_args, dtype_abbrs_parsing + except ImportError: + logger.error("Failed to import required utilities for tensor conversion") + raise + + # Check if force_dtype is 'cpu' (device) or a dtype + is_device_conversion = force_dtype == 'cpu' + target_dtype = None if is_device_conversion else dtype_abbrs_parsing.get(force_dtype) + + if not is_device_conversion and target_dtype is None: + raise ValueError(f"Invalid dtype: {force_dtype}. Valid options: {', '.join(dtype_abbrs_parsing.keys())} or 'cpu'") + + for op in ops: + try: + args, kwargs = deserialize_args(op["args"]) + + def convert_tensor(t): + if isinstance(t, torch.Tensor): + if is_device_conversion: + return t.cpu() + else: + return t.to(dtype=target_dtype) + return t + + def convert_nested(obj): + if isinstance(obj, torch.Tensor): + return convert_tensor(obj) + elif isinstance(obj, (list, tuple)): + return type(obj)(convert_nested(item) for item in obj) + elif isinstance(obj, dict): + return {k: convert_nested(v) for k, v in obj.items()} + return obj + + # Convert all tensors in args and kwargs + converted_args = convert_nested(args) + converted_kwargs = convert_nested(kwargs) + + # Re-serialize the converted arguments + op["args"] = serialize_args(converted_args, converted_kwargs) + converted_ops += 1 + + except Exception as e: + conversion_failures += 1 + op["conversion_failed"] = True + logger.debug(f"Failed to convert tensors for {op['op_name']}: {e}") + + logger.info(f"Tensor conversion to {force_dtype}: {converted_ops} successful, {conversion_failures} failed") # Add additional metadata fields required for the parquet format for op in ops: @@ -94,6 +154,9 @@ def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None): op["relative_runtime_to_kernel_launch"] = np.nan op["runnable"] = True op["is_overhead_dominated_op"] = False + if force_dtype and "conversion_failed" in op: + op["included_in_benchmark"] = False + op["why_excluded"].append(f"tensor_conversion_to_{force_dtype}_failed") # apply filters ops = apply_skip_ops_filter(ops) @@ -103,7 +166,10 @@ def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None): exclusion_mapping = defaultdict(lambda: set()) testable_ops = set() all_ops = set() + conversion_failed_count = 0 for op in ops: + if "conversion_failed" in op and op["conversion_failed"]: + conversion_failed_count += 1 for reason in op["why_excluded"]: exclusion_dict[reason] += 1 exclusion_mapping[reason].add(op["op_name"]) @@ -134,6 +200,13 @@ def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None): logger.info( f"Found {len(overhead_dominated_op_names)} / {len(all_ops)} unique ops that are dominated by overhead" ) + + if force_dtype and conversion_failed_count > 0: + logger.warning(f"\n{'='*60}") + logger.warning(f"TENSOR CONVERSION RESULTS (to {force_dtype}):") + logger.warning(f"Tests that failed conversion: {conversion_failed_count} / {len(ops)}") + logger.warning(f"Tests that did not run due to conversion: {conversion_failed_count}") + logger.warning(f"{'='*60}\n") # Create parquet table with all metadata (formerly "dev" version) table = pa.Table.from_pylist(ops) @@ -245,7 +318,13 @@ def _validate_trace_file(trace_file: str, is_input: bool = True) -> str: type=int, help="Limit the number of operators to convert. (Useful for testing)", ) -def main(log_level, mode, trace_file, parquet_name, upload_to_hf, limit): +@click.option( + "--force-dtype", + default="bf16", + type=str, + help="Force all tensors to specific dtype (e.g., 'bf16', 'f32', 'i32', 'cpu'). Default: bf16", +) +def main(log_level, mode, trace_file, parquet_name, upload_to_hf, limit, force_dtype): """Convert trace files to parquet format or vice versa.""" setup_logging(log_level) @@ -259,7 +338,7 @@ def main(log_level, mode, trace_file, parquet_name, upload_to_hf, limit): logger.info(f"Converting trace file {trace_file} to parquet file {parquet_name}") - convert_trace_to_parquet(trace_file, parquet_name, limit=limit) + convert_trace_to_parquet(trace_file, parquet_name, limit=limit, force_dtype=force_dtype) logger.info("Conversion completed successfully") if upload_to_hf: diff --git a/BackendBench/utils.py b/BackendBench/utils.py index 7f6c2aa9..33bc3637 100644 --- a/BackendBench/utils.py +++ b/BackendBench/utils.py @@ -225,5 +225,6 @@ def compute_errors(ref, res, eps=1e-10): def cleanup_memory_and_gpu(): """Helper function to clean up GPU memory""" gc.collect() - torch.cuda.synchronize() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache()