-
Notifications
You must be signed in to change notification settings - Fork 8
Test metadata for untestable ops and fix test_eval.py #114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
209044e
a3fa3b0
077638c
e542c22
24c699f
9059b1d
6831c5b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,8 @@ | |
TRITON_AVAILABLE = False | ||
|
||
from BackendBench.utils import serialize_args, uses_cuda_stream, compute_errors | ||
from BackendBench.scripts.pytorch_operators import extract_operator_name | ||
from BackendBench.scripts.dataset_filters import TENSOR_CREATION_OPERATORS | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
@@ -64,6 +66,22 @@ def allclose(a, b, atol=1e-2, rtol=1e-2): | |
return False | ||
|
||
|
||
def equal_metadata(a, b): | ||
try: | ||
_allclose(a.shape, b.shape, atol=0.0, rtol=0.0) | ||
_allclose(a.stride(), b.stride(), atol=0.0, rtol=0.0) | ||
_allclose(a.dtype, b.dtype, atol=0.0, rtol=0.0) | ||
_allclose(a.device, b.device, atol=0.0, rtol=0.0) | ||
_allclose(a.is_sparse, b.is_sparse, atol=0.0, rtol=0.0) | ||
return True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd check the type string as well as per the reference There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure! I have added The type string assertion checks for dtype, device, and is_sparse. The first two are checked already, so I only add is_sparse. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. wait ... let's just use the functions / machinery from pytorch directly. I feel like that's a bit more future proof / feeds into our desire to make these generated kernels mergable into pytorch |
||
except Exception: | ||
return False | ||
|
||
|
||
def test_metadata(op): | ||
return extract_operator_name(str(op)) in TENSOR_CREATION_OPERATORS | ||
|
||
|
||
def eval_correctness_test( | ||
op, impl, test | ||
) -> Tuple[bool, Optional[str], Optional[float], Optional[float]]: | ||
|
@@ -76,12 +94,16 @@ def eval_correctness_test( | |
ref = op(*args, **kwargs) | ||
try: | ||
res = impl(*args, **kwargs) | ||
is_correct = allclose(ref, res) | ||
if test_metadata(op): | ||
is_correct = equal_metadata(ref, res) | ||
return is_correct, None, 0.0, 0.0 | ||
else: | ||
is_correct = allclose(ref, res) | ||
|
||
# Compute errors even if test passes (for verbose mode) | ||
abs_error, rel_error = compute_errors(ref, res) | ||
# Compute errors even if test passes (for verbose mode) | ||
abs_error, rel_error = compute_errors(ref, res) | ||
|
||
return is_correct, None, abs_error, rel_error | ||
return is_correct, None, abs_error, rel_error | ||
except Exception as e: | ||
error_msg = format_exception(e, op, args, kwargs) | ||
logger.warning(error_msg) | ||
|
@@ -147,11 +169,17 @@ def eval_performance(op, impl, tests, test_data: defaultdict = defaultdict(dict) | |
try: | ||
ref = op(*test.args, **test.kwargs) | ||
res = impl(*test.args, **test.kwargs) | ||
if not allclose( | ||
ref, | ||
res, | ||
): | ||
raise ValueError(f"Reference and result tensors are not close: {ref} vs {res}") | ||
if test_metadata(op): | ||
if not equal_metadata(ref, res): | ||
raise ValueError( | ||
f"Reference and result tensors metadata are not equal: {ref} vs {res}" | ||
) | ||
else: | ||
if not allclose( | ||
ref, | ||
res, | ||
): | ||
raise ValueError(f"Reference and result tensors are not close: {ref} vs {res}") | ||
test_time = bench_fn(lambda: impl(*test.args, **test.kwargs)) | ||
except Exception: | ||
pass | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,22 +8,18 @@ | |
import torch | ||
import numpy as np | ||
|
||
try: | ||
import importlib.util | ||
from BackendBench.eval import ( | ||
format_exception, | ||
allclose, | ||
eval_correctness_test, | ||
eval_correctness, | ||
eval_one_op, | ||
cpu_bench, | ||
gpu_bench, | ||
perf_at_p, | ||
) | ||
|
||
HAS_TRITON = importlib.util.find_spec("triton") is not None | ||
except ImportError: | ||
HAS_TRITON = False | ||
import importlib.util | ||
from BackendBench.eval import ( | ||
format_exception, | ||
allclose, | ||
eval_correctness_test, | ||
eval_correctness, | ||
eval_one_op, | ||
cpu_bench, | ||
perf_at_p, | ||
) | ||
|
||
HAS_TRITON = importlib.util.find_spec("triton") is not None | ||
|
||
pytestmark = pytest.mark.skipif(not HAS_TRITON, reason="triton not available") | ||
|
||
|
@@ -37,7 +33,7 @@ def test_format_exception(self): | |
|
||
formatted = format_exception(exc, op, args, kwargs) | ||
assert "relu.default" in formatted | ||
assert "torch.float32[2, 3]" in formatted | ||
assert "T([2, 3], f32)" in formatted | ||
assert "dim" in formatted | ||
assert "Test error" in formatted | ||
|
||
|
@@ -167,7 +163,25 @@ def __init__(self, args, kwargs): | |
test_data = {} | ||
score = eval_correctness(op, impl, tests, test_data) | ||
assert score == 1.0 | ||
assert len(test_data) == len(tests) # Should have data for each test | ||
# TODO: test_data is overwritten when test with same args | ||
# assert len(test_data) == len(tests) # Should have data for each test | ||
|
||
def test_eval_correctness_metadata(self): | ||
op = torch.empty_like | ||
impl = torch.empty_like # Same implementation | ||
|
||
class TestCase: | ||
def __init__(self, args, kwargs): | ||
self.args = args | ||
self.kwargs = kwargs | ||
|
||
tests = [TestCase([torch.randn(2, 3)], {})] | ||
|
||
test_data = {} | ||
score = eval_correctness(op, impl, tests, test_data) | ||
assert score == 1.0 | ||
# TODO: test_data is overwritten when test with same args | ||
# assert len(test_data) == len(tests) # Should have data for each test | ||
|
||
|
||
class TestEvalPerformance: | ||
|
@@ -185,18 +199,6 @@ def test_fn(): | |
assert counter == 20 | ||
assert time_per_run > 0 | ||
|
||
def test_gpu_bench(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. was this giving a problem or do you jus think it's a useless test? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's no |
||
counter = 0 | ||
|
||
def test_fn(): | ||
nonlocal counter | ||
counter += 1 | ||
|
||
time_per_run = gpu_bench(test_fn, num_runs=10) | ||
|
||
assert counter == 20 | ||
assert time_per_run > 0 | ||
|
||
|
||
class TestEvalOneOp: | ||
def test_eval_one_op(self): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One thing I'm not super clear on is that OpInfo this is indeed the way they test tensor creation ops, that's how we figured out this might be the right testing strategy. So why not just use OpInfo again here?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is a reference here to pytorch's testing strategy https://github.com/pytorch/pytorch/blob/332fa5b388521c05a19217649745c6edfdc2836d/test/test_tensor_creation_ops.py