-
Notifications
You must be signed in to change notification settings - Fork 7
WIP: custom serialization #155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds custom serialization support to the OpTest
class to handle callable operations that contain non-serializable PyCapsule objects. The implementation breaks down callable operations into module and function name components during serialization and reconstructs them during deserialization.
Key changes:
- Added custom
__getstate__
and__setstate__
methods to handle callable operation serialization - Import of
importlib
module to support dynamic module loading during deserialization
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
state["op_name"] = op.__name__ | ||
state["op_module"] = op.__module__ | ||
return state | ||
|
||
def __setstate__(self, state): | ||
if "op_name" in state and "op_module" in state: | ||
op_name = state.pop("op_name") | ||
op_module = state.pop("op_module") | ||
module = importlib.import_module(op_module) | ||
state["op"] = getattr(module, op_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The code assumes that callable objects always have __name__
and __module__
attributes, but some callable objects like lambda functions, functools.partial objects, or custom callable classes may not have these attributes or may have unexpected values. This could cause AttributeError during serialization.
state["op_name"] = op.__name__ | |
state["op_module"] = op.__module__ | |
return state | |
def __setstate__(self, state): | |
if "op_name" in state and "op_module" in state: | |
op_name = state.pop("op_name") | |
op_module = state.pop("op_module") | |
module = importlib.import_module(op_module) | |
state["op"] = getattr(module, op_name) | |
state["op_name"] = getattr(op, "__name__", None) | |
state["op_module"] = getattr(op, "__module__", None) | |
return state | |
def __setstate__(self, state): | |
if "op_name" in state and "op_module" in state: | |
op_name = state.pop("op_name") | |
op_module = state.pop("op_module") | |
if op_name is not None and op_module is not None: | |
module = importlib.import_module(op_module) | |
state["op"] = getattr(module, op_name) | |
else: | |
# Could not restore op; set to None or raise error | |
state["op"] = None |
Copilot uses AI. Check for mistakes.
module = importlib.import_module(op_module) | ||
state["op"] = getattr(module, op_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The deserialization process doesn't handle cases where the module cannot be imported or the attribute doesn't exist in the module. This could cause ImportError or AttributeError during deserialization, making the object unusable.
module = importlib.import_module(op_module) | |
state["op"] = getattr(module, op_name) | |
try: | |
module = importlib.import_module(op_module) | |
state["op"] = getattr(module, op_name) | |
except (ImportError, AttributeError) as e: | |
raise RuntimeError( | |
f"Failed to deserialize 'op': could not import module '{op_module}' or find attribute '{op_name}'." | |
) from e |
Copilot uses AI. Check for mistakes.
let us know when you need a review |
|
Should be very much non-intrusive and shouldn't matter for any current stuff. I.e. when working with modal/multi-process etc, op is a PyCapsule object as it encapsulates the underlying method, which is not serializable. This just breaks it into a module + op_name and reimports when deserialized.