23 changes: 21 additions & 2 deletions BackendBench/scripts/dataset_filters.py
@@ -9,7 +9,25 @@

import tqdm
from BackendBench.utils import cleanup_memory_and_gpu, deserialize_args
-from triton.testing import do_bench

+# Try to import triton; fall back to cpu_bench if it is not available
Member:
you don't have to make this work for cpu

Contributor Author:
This is mostly to make testing easier when you don't have access to a GPU (I've gotten into the habit of developing on my laptop and using a devserver for testing).

+try:
+    if torch.cuda.is_available():
+        from triton.testing import do_bench
+    else:
+        # CPU fallback using cpu_bench from the eval module
+        from BackendBench.eval import cpu_bench
+
+        # Wrap cpu_bench to match the do_bench interface
+        def do_bench(fn, warmup=None, rep=None, **kwargs):
+            # cpu_bench returns time in seconds; convert to ms to match triton
+            return cpu_bench(fn, num_runs=rep if rep else 100) * 1000
+except ImportError:
+    # Fallback if triton is not installed at all
+    from BackendBench.eval import cpu_bench
+
+    # Wrap cpu_bench to match the do_bench interface
+    def do_bench(fn, warmup=None, rep=None, **kwargs):
+        # cpu_bench returns time in seconds; convert to ms to match triton
+        return cpu_bench(fn, num_runs=rep if rep else 100) * 1000

# Operators to skip for indexing ops that need valid indices
SKIP_OPERATORS = [
@@ -62,7 +80,8 @@ def apply_skip_ops_filter(ops):

def apply_runtime_filter(ops):
    def _overhead_benchmark():
-        return torch.randn(1, device="cuda")
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        return torch.randn(1, device=device)

    runtime_threshold_ms = do_bench(_overhead_benchmark, warmup=25, rep=100)

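Note: the fallback only has to honor triton's do_bench contract, i.e. return mean milliseconds per call. A minimal sketch of an equivalent CPU timer, assuming cpu_bench(fn, num_runs) returns mean seconds per call (the real BackendBench.eval.cpu_bench may differ internally):

import time

def cpu_bench(fn, num_runs=100):
    # Hypothetical stand-in for BackendBench.eval.cpu_bench:
    # returns mean wall-clock seconds per call.
    fn()  # warm up once so one-time costs don't skew the timing
    start = time.perf_counter()
    for _ in range(num_runs):
        fn()
    return (time.perf_counter() - start) / num_runs

def do_bench(fn, warmup=None, rep=None, **kwargs):
    # Match triton's interface: warmup/rep accepted, result in milliseconds.
    return cpu_bench(fn, num_runs=rep if rep else 100) * 1000

print(do_bench(lambda: sum(range(10_000))))  # mean runtime in ms

This shared contract is what lets apply_runtime_filter call do_bench on a trivial torch.randn(1) to estimate per-launch overhead as its runtime threshold on either device.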
85 changes: 82 additions & 3 deletions BackendBench/scripts/parquet_trace_converter.py
@@ -16,6 +16,7 @@
import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq
+import torch
from BackendBench.data_loaders import _load_from_trace
from BackendBench.scripts.dataset_filters import (
apply_runtime_filter,
@@ -77,13 +78,72 @@ def setup_logging(log_level):
    )


-def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None):
+def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None, force_dtype: str = None):
    """
    Convert a trace file to a parquet file

    Args:
        trace_file: Path to trace file
        parquet_file: Output parquet file path
        limit: Max number of operations to process
+        force_dtype: Force all tensors to this dtype (e.g., 'bf16', 'f32'), or 'cpu' to move them to CPU
    """

    # Load operations using local trace parsing function
    ops = _load_from_trace(trace_file, filter=None, limit=limit)

+    # Convert tensors to the specified dtype if requested
+    conversion_failures = 0
+    converted_ops = 0
+    if force_dtype:
+        try:
+            from BackendBench.utils import deserialize_args, serialize_args, dtype_abbrs_parsing
+        except ImportError:
+            logger.error("Failed to import required utilities for tensor conversion")
+            raise
+
+        # Check whether force_dtype is 'cpu' (a device) or a dtype
+        is_device_conversion = force_dtype == 'cpu'
+        target_dtype = None if is_device_conversion else dtype_abbrs_parsing.get(force_dtype)
+
+        if not is_device_conversion and target_dtype is None:
+            raise ValueError(
+                f"Invalid dtype: {force_dtype}. Valid options: {', '.join(dtype_abbrs_parsing.keys())} or 'cpu'"
+            )
+
+        for op in ops:
+            try:
+                args, kwargs = deserialize_args(op["args"])
+
+                def convert_tensor(t):
+                    if isinstance(t, torch.Tensor):
+                        if is_device_conversion:
+                            return t.cpu()
+                        else:
+                            return t.to(dtype=target_dtype)
+                    return t
+
+                def convert_nested(obj):
+                    if isinstance(obj, torch.Tensor):
+                        return convert_tensor(obj)
+                    elif isinstance(obj, (list, tuple)):
+                        return type(obj)(convert_nested(item) for item in obj)
+                    elif isinstance(obj, dict):
+                        return {k: convert_nested(v) for k, v in obj.items()}
+                    return obj
+
+                # Convert all tensors in args and kwargs
+                converted_args = convert_nested(args)
+                converted_kwargs = convert_nested(kwargs)
+
+                # Re-serialize the converted arguments
+                op["args"] = serialize_args(converted_args, converted_kwargs)
+                converted_ops += 1
+
+            except Exception as e:
+                conversion_failures += 1
+                op["conversion_failed"] = True
+                logger.debug(f"Failed to convert tensors for {op['op_name']}: {e}")
+
+        logger.info(f"Tensor conversion to {force_dtype}: {converted_ops} successful, {conversion_failures} failed")

    # Add additional metadata fields required for the parquet format
    for op in ops:
@@ -94,6 +154,9 @@ def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None):
op["relative_runtime_to_kernel_launch"] = np.nan
op["runnable"] = True
op["is_overhead_dominated_op"] = False
if force_dtype and "conversion_failed" in op:
op["included_in_benchmark"] = False
op["why_excluded"].append(f"tensor_conversion_to_{force_dtype}_failed")

    # apply filters
    ops = apply_skip_ops_filter(ops)
@@ -103,7 +166,10 @@ def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None):
    exclusion_mapping = defaultdict(lambda: set())
    testable_ops = set()
    all_ops = set()
+    conversion_failed_count = 0
    for op in ops:
+        if "conversion_failed" in op and op["conversion_failed"]:
+            conversion_failed_count += 1
        for reason in op["why_excluded"]:
            exclusion_dict[reason] += 1
            exclusion_mapping[reason].add(op["op_name"])
@@ -134,6 +200,13 @@ def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None):
    logger.info(
        f"Found {len(overhead_dominated_op_names)} / {len(all_ops)} unique ops that are dominated by overhead"
    )
+
+    if force_dtype and conversion_failed_count > 0:
+        logger.warning(f"\n{'='*60}")
+        logger.warning(f"TENSOR CONVERSION RESULTS (to {force_dtype}):")
+        logger.warning(f"Tests that failed conversion: {conversion_failed_count} / {len(ops)}")
+        logger.warning(f"Tests that did not run due to conversion: {conversion_failed_count}")
+        logger.warning(f"{'='*60}\n")

    # Create parquet table with all metadata (formerly "dev" version)
    table = pa.Table.from_pylist(ops)
@@ -245,7 +318,13 @@ def _validate_trace_file(trace_file: str, is_input: bool = True) -> str:
    type=int,
    help="Limit the number of operators to convert. (Useful for testing)",
)
-def main(log_level, mode, trace_file, parquet_name, upload_to_hf, limit):
+@click.option(
+    "--force-dtype",
Member:
I'm suspicious about general dtype conversions, this seems like a massive footgun for correctness testing

Contributor Author:
ohh if you have time let's chat about it offline, this was more so just executing on the rec from @malfet.

Member:
Ok we chatted offline and this is specific to torchbench data prep

default="bf16",
type=str,
help="Force all tensors to specific dtype (e.g., 'bf16', 'f32', 'i32', 'cpu'). Default: bf16",
)
def main(log_level, mode, trace_file, parquet_name, upload_to_hf, limit, force_dtype):
"""Convert trace files to parquet format or vice versa."""
setup_logging(log_level)

@@ -259,7 +338,7 @@ def main(log_level, mode, trace_file, parquet_name, upload_to_hf, limit):

logger.info(f"Converting trace file {trace_file} to parquet file {parquet_name}")

-    convert_trace_to_parquet(trace_file, parquet_name, limit=limit)
+    convert_trace_to_parquet(trace_file, parquet_name, limit=limit, force_dtype=force_dtype)
    logger.info("Conversion completed successfully")

    if upload_to_hf:
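Note: to make the conversion concrete, here is a self-contained sketch of the same recursive cast applied to a nested argument structure (the in-repo path additionally round-trips through deserialize_args/serialize_args). It also shows the footgun the reviewer flags above: an unconditional .to(dtype=...) casts integer tensors too, silently breaking index semantics, which is why the flag ended up scoped to torchbench data prep.

import torch

def convert_nested(obj, target_dtype=torch.bfloat16):
    # Recursively cast tensors inside lists/tuples/dicts; leave non-tensors untouched
    if isinstance(obj, torch.Tensor):
        return obj.to(dtype=target_dtype)
    elif isinstance(obj, (list, tuple)):
        return type(obj)(convert_nested(item, target_dtype) for item in obj)
    elif isinstance(obj, dict):
        return {k: convert_nested(v, target_dtype) for k, v in obj.items()}
    return obj

args = ([torch.randn(2, 2), {"scale": torch.ones(3)}], "mean", 4)
converted = convert_nested(args)
assert converted[0][0].dtype == torch.bfloat16
assert converted[2] == 4  # non-tensor leaves pass through unchanged

idx = torch.tensor([0, 2, 5])  # an index tensor gets cast as well:
print(convert_nested(idx))  # tensor([0., 2., 5.], dtype=torch.bfloat16)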
5 changes: 3 additions & 2 deletions BackendBench/utils.py
@@ -225,5 +225,6 @@ def compute_errors(ref, res, eps=1e-10):
def cleanup_memory_and_gpu():
    """Helper function to clean up GPU memory"""
    gc.collect()
-    torch.cuda.synchronize()
-    torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        torch.cuda.empty_cache()
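Note: the guard matters because torch.cuda.synchronize() typically raises on machines where CUDA is unavailable, so the old version could not be called at all from CPU-only test runs. A hypothetical smoke test (not in the repo) exercising the fixed behavior on either kind of host:

import torch
from BackendBench.utils import cleanup_memory_and_gpu

def test_cleanup_is_safe_without_cuda():
    # On CPU-only hosts the CUDA branch is skipped; on GPU hosts it synchronizes
    # and empties the cache. Either way, no exception should escape.
    cleanup_memory_and_gpu()

if __name__ == "__main__":
    test_cleanup_is_safe_without_cuda()
    print("cleanup_memory_and_gpu is safe on", "cuda" if torch.cuda.is_available() else "cpu")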