Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions test/xpu/skip_list_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
"test_python_ref_executor__refs_mul_executor_aten_xpu_complex32",
# https://github.com/intel/torch-xpu-ops/issues/2254
"histogramdd",
"_vdot_",
"_dot_",
"_flash_attention_",
"_efficient_attention_",
),
"test_binary_ufuncs_xpu.py": (
"test_fmod_remainder_by_zero_integral_xpu_int64", # zero division is an undefined behavior: different handles on different backends
Expand Down
26 changes: 10 additions & 16 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import unittest

import torch
from torch import bfloat16, cuda
from torch import cuda
from torch.testing._internal import (
common_cuda,
common_device_type,
Expand Down Expand Up @@ -354,6 +354,11 @@
"_refs.div",
"test_python_ref_torch_fallback",
),
("_refs.true_div", "test_python_ref"),
(
"_refs.true_div",
"test_python_ref_torch_fallback",
),
("argsort", "test_non_standard_bool_values"),
("sort", "test_non_standard_bool_values"),
]
Expand Down Expand Up @@ -865,7 +870,6 @@ def __init__(self, patch_test_case=True) -> None:
)
self.foreach_reduce_op_db = common_methods_invocations.foreach_reduce_op_db
self.foreach_other_op_db = common_methods_invocations.foreach_other_op_db
self.python_ref_db = common_methods_invocations.python_ref_db
self.ops_and_refs = common_methods_invocations.ops_and_refs
self.largeTensorTest = common_device_type.largeTensorTest
self.TEST_CUDA = common_cuda.TEST_CUDA
Expand Down Expand Up @@ -921,19 +925,10 @@ def gen_xpu_wrappers(op_name, wrappers):

def align_supported_dtypes(self, db):
for opinfo in db:
if (
opinfo.name not in _xpu_computation_op_list
and (
opinfo.torch_opinfo.name not in _xpu_computation_op_list
if db == common_methods_invocations.python_ref_db
else True
)
) or opinfo.name in _ops_without_cuda_support:
if opinfo.name in _ops_without_cuda_support:
opinfo.dtypesIf["xpu"] = opinfo.dtypes
else:
backward_dtypes = set(opinfo.backward_dtypesIfCUDA)
if bfloat16 in opinfo.dtypesIf["xpu"]:
backward_dtypes.add(bfloat16)
opinfo.backward_dtypes = tuple(backward_dtypes)

if opinfo.name in _ops_dtype_different_cuda_support:
Expand Down Expand Up @@ -1039,13 +1034,13 @@ def __init__(self, *args):
self.align_db_decorators(db)
self.filter_fp64_sample_input(db)
self.align_db_decorators(module_db)
common_methods_invocations.python_ref_db = [
_python_ref_db = [
op
for op in self.python_ref_db
for op in common_methods_invocations.python_ref_db
if op.torch_opinfo_name in _xpu_computation_op_list
]
common_methods_invocations.ops_and_refs = (
common_methods_invocations.op_db + common_methods_invocations.python_ref_db
common_methods_invocations.op_db + _python_ref_db
)
common_methods_invocations.unary_ufuncs = [
op
Expand Down Expand Up @@ -1128,7 +1123,6 @@ def __exit__(self, exc_type, exc_value, traceback):
self.instantiate_parametrized_tests_fn
)
common_utils.TestCase = self.test_case_cls
common_methods_invocations.python_ref_db = self.python_ref_db
common_methods_invocations.ops_and_refs = self.ops_and_refs
common_device_type.largeTensorTest = self.largeTensorTest
common_cuda.TEST_CUDA = self.TEST_CUDA
Expand Down