Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions BackendBench/suite/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
# LICENSE file in the root directory of this source tree.


import importlib


class Test:
def __init__(self, *args, **kwargs):
self._args = args
Expand All @@ -25,6 +28,23 @@ def __init__(self, op, correctness_tests, performance_tests):
self.correctness_tests = correctness_tests
self.performance_tests = performance_tests

def __getstate__(self):
# Custom serialization to handle callable op
state = self.__dict__.copy()
if callable(state.get("op")):
op = state.pop("op")
state["op_name"] = op.__name__
state["op_module"] = op.__module__
return state

def __setstate__(self, state):
if "op_name" in state and "op_module" in state:
op_name = state.pop("op_name")
op_module = state.pop("op_module")
module = importlib.import_module(op_module)
state["op"] = getattr(module, op_name)
Comment on lines +36 to +45
Copy link
Preview

Copilot AI Sep 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code assumes that callable objects always have __name__ and __module__ attributes, but some callable objects like lambda functions, functools.partial objects, or custom callable classes may not have these attributes or may have unexpected values. This could cause AttributeError during serialization.

Suggested change
state["op_name"] = op.__name__
state["op_module"] = op.__module__
return state
def __setstate__(self, state):
if "op_name" in state and "op_module" in state:
op_name = state.pop("op_name")
op_module = state.pop("op_module")
module = importlib.import_module(op_module)
state["op"] = getattr(module, op_name)
state["op_name"] = getattr(op, "__name__", None)
state["op_module"] = getattr(op, "__module__", None)
return state
def __setstate__(self, state):
if "op_name" in state and "op_module" in state:
op_name = state.pop("op_name")
op_module = state.pop("op_module")
if op_name is not None and op_module is not None:
module = importlib.import_module(op_module)
state["op"] = getattr(module, op_name)
else:
# Could not restore op; set to None or raise error
state["op"] = None

Copilot uses AI. Check for mistakes.

Comment on lines +44 to +45
Copy link
Preview

Copilot AI Sep 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The deserialization process doesn't handle cases where the module cannot be imported or the attribute doesn't exist in the module. This could cause ImportError or AttributeError during deserialization, making the object unusable.

Suggested change
module = importlib.import_module(op_module)
state["op"] = getattr(module, op_name)
try:
module = importlib.import_module(op_module)
state["op"] = getattr(module, op_name)
except (ImportError, AttributeError) as e:
raise RuntimeError(
f"Failed to deserialize 'op': could not import module '{op_module}' or find attribute '{op_name}'."
) from e

Copilot uses AI. Check for mistakes.

self.__dict__.update(state)


class TestSuite:
def __init__(self, name, optests):
Expand Down