-
Notifications
You must be signed in to change notification settings - Fork 8
Convert all tests in torchbench to bf16 #118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
import numpy as np | ||
import pyarrow as pa | ||
import pyarrow.parquet as pq | ||
import torch | ||
from BackendBench.data_loaders import _load_from_trace | ||
from BackendBench.scripts.dataset_filters import ( | ||
apply_runtime_filter, | ||
|
@@ -77,13 +78,72 @@ def setup_logging(log_level): | |
) | ||
|
||
|
||
def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None): | ||
def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None, force_dtype: str = None): | ||
""" | ||
Convert a trace file to a parquet file | ||
|
||
Args: | ||
trace_file: Path to trace file | ||
parquet_file: Output parquet file path | ||
limit: Max number of operations to process | ||
force_dtype: Force all tensors to be of this dtype (e.g., 'bf16', 'f32', 'cpu') | ||
""" | ||
|
||
# Load operations using local trace parsing function | ||
ops = _load_from_trace(trace_file, filter=None, limit=limit) | ||
|
||
# Convert tensors to specified dtype if requested | ||
conversion_failures = 0 | ||
converted_ops = 0 | ||
if force_dtype: | ||
try: | ||
from BackendBench.utils import deserialize_args, serialize_args, dtype_abbrs_parsing | ||
except ImportError: | ||
logger.error("Failed to import required utilities for tensor conversion") | ||
raise | ||
|
||
# Check if force_dtype is 'cpu' (device) or a dtype | ||
is_device_conversion = force_dtype == 'cpu' | ||
target_dtype = None if is_device_conversion else dtype_abbrs_parsing.get(force_dtype) | ||
|
||
if not is_device_conversion and target_dtype is None: | ||
raise ValueError(f"Invalid dtype: {force_dtype}. Valid options: {', '.join(dtype_abbrs_parsing.keys())} or 'cpu'") | ||
|
||
for op in ops: | ||
try: | ||
args, kwargs = deserialize_args(op["args"]) | ||
|
||
def convert_tensor(t): | ||
if isinstance(t, torch.Tensor): | ||
if is_device_conversion: | ||
return t.cpu() | ||
else: | ||
return t.to(dtype=target_dtype) | ||
return t | ||
|
||
def convert_nested(obj): | ||
if isinstance(obj, torch.Tensor): | ||
return convert_tensor(obj) | ||
elif isinstance(obj, (list, tuple)): | ||
return type(obj)(convert_nested(item) for item in obj) | ||
elif isinstance(obj, dict): | ||
return {k: convert_nested(v) for k, v in obj.items()} | ||
return obj | ||
|
||
# Convert all tensors in args and kwargs | ||
converted_args = convert_nested(args) | ||
converted_kwargs = convert_nested(kwargs) | ||
|
||
# Re-serialize the converted arguments | ||
op["args"] = serialize_args(converted_args, converted_kwargs) | ||
converted_ops += 1 | ||
|
||
except Exception as e: | ||
conversion_failures += 1 | ||
op["conversion_failed"] = True | ||
logger.debug(f"Failed to convert tensors for {op['op_name']}: {e}") | ||
|
||
logger.info(f"Tensor conversion to {force_dtype}: {converted_ops} successful, {conversion_failures} failed") | ||
|
||
# Add additional metadata fields required for the parquet format | ||
for op in ops: | ||
|
@@ -94,6 +154,9 @@ def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None): | |
op["relative_runtime_to_kernel_launch"] = np.nan | ||
op["runnable"] = True | ||
op["is_overhead_dominated_op"] = False | ||
if force_dtype and "conversion_failed" in op: | ||
op["included_in_benchmark"] = False | ||
op["why_excluded"].append(f"tensor_conversion_to_{force_dtype}_failed") | ||
|
||
# apply filters | ||
ops = apply_skip_ops_filter(ops) | ||
|
@@ -103,7 +166,10 @@ def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None): | |
exclusion_mapping = defaultdict(lambda: set()) | ||
testable_ops = set() | ||
all_ops = set() | ||
conversion_failed_count = 0 | ||
for op in ops: | ||
if "conversion_failed" in op and op["conversion_failed"]: | ||
conversion_failed_count += 1 | ||
for reason in op["why_excluded"]: | ||
exclusion_dict[reason] += 1 | ||
exclusion_mapping[reason].add(op["op_name"]) | ||
|
@@ -134,6 +200,13 @@ def convert_trace_to_parquet(trace_file, parquet_file, limit: int = None): | |
logger.info( | ||
f"Found {len(overhead_dominated_op_names)} / {len(all_ops)} unique ops that are dominated by overhead" | ||
) | ||
|
||
if force_dtype and conversion_failed_count > 0: | ||
logger.warning(f"\n{'='*60}") | ||
logger.warning(f"TENSOR CONVERSION RESULTS (to {force_dtype}):") | ||
logger.warning(f"Tests that failed conversion: {conversion_failed_count} / {len(ops)}") | ||
logger.warning(f"Tests that did not run due to conversion: {conversion_failed_count}") | ||
logger.warning(f"{'='*60}\n") | ||
|
||
# Create parquet table with all metadata (formerly "dev" version) | ||
table = pa.Table.from_pylist(ops) | ||
|
@@ -245,7 +318,13 @@ def _validate_trace_file(trace_file: str, is_input: bool = True) -> str: | |
type=int, | ||
help="Limit the number of operators to convert. (Useful for testing)", | ||
) | ||
def main(log_level, mode, trace_file, parquet_name, upload_to_hf, limit): | ||
@click.option( | ||
"--force-dtype", | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm suspicious about general dtype conversions; this seems like a massive footgun for correctness testing. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh, if you have time, let's chat about it offline; this was mostly just executing on the recommendation from @malfet. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK, we chatted offline, and this is specific to torchbench data prep. |
||
default="bf16", | ||
type=str, | ||
help="Force all tensors to specific dtype (e.g., 'bf16', 'f32', 'i32', 'cpu'). Default: bf16", | ||
) | ||
def main(log_level, mode, trace_file, parquet_name, upload_to_hf, limit, force_dtype): | ||
"""Convert trace files to parquet format or vice versa.""" | ||
setup_logging(log_level) | ||
|
||
|
@@ -259,7 +338,7 @@ def main(log_level, mode, trace_file, parquet_name, upload_to_hf, limit): | |
|
||
logger.info(f"Converting trace file {trace_file} to parquet file {parquet_name}") | ||
|
||
convert_trace_to_parquet(trace_file, parquet_name, limit=limit) | ||
convert_trace_to_parquet(trace_file, parquet_name, limit=limit, force_dtype=force_dtype) | ||
logger.info("Conversion completed successfully") | ||
|
||
if upload_to_hf: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't have to make this work for CPU.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is mostly to make testing easier when you don't have access to a GPU (I've gotten into the habit of developing on my laptop and using a dev server for testing).