46 changes: 37 additions & 9 deletions BackendBench/eval.py
@@ -24,6 +24,8 @@
TRITON_AVAILABLE = False

from BackendBench.utils import serialize_args, uses_cuda_stream, compute_errors
+from BackendBench.scripts.pytorch_operators import extract_operator_name
+from BackendBench.scripts.dataset_filters import TENSOR_CREATION_OPERATORS

logger = logging.getLogger(__name__)

@@ -64,6 +66,22 @@ def allclose(a, b, atol=1e-2, rtol=1e-2):
return False


+def equal_metadata(a, b):

Member commented:
One thing I'm not super clear on: in OpInfo this is indeed the way they test tensor creation ops, which is how we figured out this might be the right testing strategy. So why not just use OpInfo again here?

+    try:
+        _allclose(a.shape, b.shape, atol=0.0, rtol=0.0)
+        _allclose(a.stride(), b.stride(), atol=0.0, rtol=0.0)
+        _allclose(a.dtype, b.dtype, atol=0.0, rtol=0.0)
+        _allclose(a.device, b.device, atol=0.0, rtol=0.0)
+        _allclose(a.is_sparse, b.is_sparse, atol=0.0, rtol=0.0)
+        return True
@PaliC (Contributor) commented on Aug 26, 2025:
I'd check the type string as well, as per the reference.

Author (Contributor) replied:
Sure! I have added _allclose(a.is_sparse, b.is_sparse, atol=0.0, rtol=0.0).

The type string assertion checks for dtype, device, and is_sparse. The first two are checked already, so I only added is_sparse.

Contributor commented:
Wait ... let's just use the functions / machinery from PyTorch directly. I feel like that's a bit more future-proof and feeds into our desire to make these generated kernels mergeable into PyTorch.

+    except Exception:
+        return False
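For reference, a minimal sketch (not part of this PR) of the kind of metadata-only comparison discussed in the thread above. It folds in the type-string idea from the review, since Tensor.type() encodes dtype, device kind, and sparsity in a single string, and it uses plain attribute comparisons instead of _allclose; the helper name is illustrative only.

import torch

def equal_metadata_sketch(a: torch.Tensor, b: torch.Tensor) -> bool:
    # Compare metadata only; the contents of empty_like/new_empty outputs are
    # uninitialized, so element-wise comparison is meaningless for these ops.
    return (
        a.shape == b.shape
        and a.stride() == b.stride()
        and a.dtype == b.dtype
        and a.device == b.device
        and a.type() == b.type()  # type string covers dtype / device kind / sparsity
    )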


+def test_metadata(op):
+    return extract_operator_name(str(op)) in TENSOR_CREATION_OPERATORS


def eval_correctness_test(
    op, impl, test
) -> Tuple[bool, Optional[str], Optional[float], Optional[float]]:
@@ -76,12 +94,16 @@ def eval_correctness_test(
    ref = op(*args, **kwargs)
    try:
        res = impl(*args, **kwargs)
-        is_correct = allclose(ref, res)
+        if test_metadata(op):
+            is_correct = equal_metadata(ref, res)
+            return is_correct, None, 0.0, 0.0
+        else:
+            is_correct = allclose(ref, res)

-        # Compute errors even if test passes (for verbose mode)
-        abs_error, rel_error = compute_errors(ref, res)
+            # Compute errors even if test passes (for verbose mode)
+            abs_error, rel_error = compute_errors(ref, res)

-        return is_correct, None, abs_error, rel_error
+            return is_correct, None, abs_error, rel_error
    except Exception as e:
        error_msg = format_exception(e, op, args, kwargs)
        logger.warning(error_msg)
@@ -147,11 +169,17 @@ def eval_performance(op, impl, tests, test_data: defaultdict = defaultdict(dict)
        try:
            ref = op(*test.args, **test.kwargs)
            res = impl(*test.args, **test.kwargs)
-            if not allclose(
-                ref,
-                res,
-            ):
-                raise ValueError(f"Reference and result tensors are not close: {ref} vs {res}")
+            if test_metadata(op):
+                if not equal_metadata(ref, res):
+                    raise ValueError(
+                        f"Reference and result tensors metadata are not equal: {ref} vs {res}"
+                    )
+            else:
+                if not allclose(
+                    ref,
+                    res,
+                ):
+                    raise ValueError(f"Reference and result tensors are not close: {ref} vs {res}")
            test_time = bench_fn(lambda: impl(*test.args, **test.kwargs))
        except Exception:
            pass
10 changes: 7 additions & 3 deletions BackendBench/scripts/dataset_filters.py
@@ -31,12 +31,16 @@
# https://github.com/meta-pytorch/BackendBench/issues/108
RELATIVE_RUNTIME_THRESHOLD = 1.3
UNTESTABLE_OPERATORS = [
"empty_like", # We can check using metadata
"new_empty", # We can check using metadata
"new_empty_strided", # We can check using metadata
"bernoulli", # We can write a custom test to verify this one (albeit not the randomness)
]

# Check using metadata
TENSOR_CREATION_OPERATORS = [
"empty_like",
"new_empty",
"new_empty_strided",
]
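Why these operators get a metadata-only check rather than a value comparison: their outputs are uninitialized memory, so even two correct implementations will generally disagree element-wise. A small illustration (not part of the PR):

import torch

x = torch.randn(2, 3)
a = torch.empty_like(x)
b = torch.empty_like(x)
# a and b agree on shape, stride, dtype, and device, but their contents are
# whatever happened to be in memory, so a value-based allclose check is not
# meaningful here (and can even trip over NaNs).
print(a.shape == b.shape, a.stride() == b.stride(), a.dtype == b.dtype)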


def apply_skip_ops_filter(ops):
    for op in tqdm.tqdm(ops, desc="Filtering ops by skip and synthetic ops"):
62 changes: 32 additions & 30 deletions test/test_eval.py
@@ -8,22 +8,18 @@
import torch
import numpy as np

-try:
-    import importlib.util
-    from BackendBench.eval import (
-        format_exception,
-        allclose,
-        eval_correctness_test,
-        eval_correctness,
-        eval_one_op,
-        cpu_bench,
-        gpu_bench,
-        perf_at_p,
-    )
-
-    HAS_TRITON = importlib.util.find_spec("triton") is not None
-except ImportError:
-    HAS_TRITON = False
+import importlib.util
+from BackendBench.eval import (
+    format_exception,
+    allclose,
+    eval_correctness_test,
+    eval_correctness,
+    eval_one_op,
+    cpu_bench,
+    perf_at_p,
+)
+
+HAS_TRITON = importlib.util.find_spec("triton") is not None

pytestmark = pytest.mark.skipif(not HAS_TRITON, reason="triton not available")

@@ -37,7 +33,7 @@ def test_format_exception(self):

        formatted = format_exception(exc, op, args, kwargs)
        assert "relu.default" in formatted
-        assert "torch.float32[2, 3]" in formatted
+        assert "T([2, 3], f32)" in formatted
        assert "dim" in formatted
        assert "Test error" in formatted

@@ -167,7 +163,25 @@ def __init__(self, args, kwargs):
        test_data = {}
        score = eval_correctness(op, impl, tests, test_data)
        assert score == 1.0
-        assert len(test_data) == len(tests)  # Should have data for each test
+        # TODO: test_data is overwritten when tests share the same args
+        # assert len(test_data) == len(tests)  # Should have data for each test
+
+    def test_eval_correctness_metadata(self):
+        op = torch.empty_like
+        impl = torch.empty_like  # Same implementation
+
+        class TestCase:
+            def __init__(self, args, kwargs):
+                self.args = args
+                self.kwargs = kwargs
+
+        tests = [TestCase([torch.randn(2, 3)], {})]
+
+        test_data = {}
+        score = eval_correctness(op, impl, tests, test_data)
+        assert score == 1.0
+        # TODO: test_data is overwritten when tests share the same args
+        # assert len(test_data) == len(tests)  # Should have data for each test
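To make the intent of this test concrete, here is a hedged sketch (not in the PR; the name bad_empty_like is illustrative) of the kind of implementation the metadata path is meant to reject: same shape, but the wrong dtype, which equal_metadata should flag even though there are no meaningful values to compare.

import torch

def bad_empty_like(x, **kwargs):
    # Wrong dtype (hard-coded float64 instead of x.dtype): the values are
    # uninitialized either way, but the metadata comparison should catch this.
    return torch.empty(x.shape, dtype=torch.float64, device=x.device)

Running eval_correctness with op=torch.empty_like and impl=bad_empty_like on the same tests should then yield a score below 1.0.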


class TestEvalPerformance:
@@ -185,18 +199,6 @@ def test_fn():
        assert counter == 20
        assert time_per_run > 0

-    def test_gpu_bench(self):

Member commented:
Was this giving a problem, or do you just think it's a useless test?

@jiannanWang (Contributor, Author) replied on Aug 26, 2025:
There's no gpu_bench function in eval.py and we are using triton.testing.do_bench for GPU performance. This actually causes an import error and is fixed in this PR.

-        counter = 0
-
-        def test_fn():
-            nonlocal counter
-            counter += 1
-
-        time_per_run = gpu_bench(test_fn, num_runs=10)
-
-        assert counter == 20
-        assert time_per_run > 0
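For context on the comment above, a minimal sketch of how GPU timing with triton.testing.do_bench typically looks; this is an illustration rather than code from this repository, and it assumes a CUDA device is available (do_bench reports a time in milliseconds).

import torch
import triton.testing

def time_on_gpu(fn):
    # do_bench handles warmup and repeated runs internally and returns a
    # timing in milliseconds.
    return triton.testing.do_bench(fn)

# Example usage:
# x = torch.randn(1024, 1024, device="cuda")
# ms = time_on_gpu(lambda: torch.relu(x))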


class TestEvalOneOp:
    def test_eval_one_op(self):