Open
Labels
good first issue, improvement, pr_welcome
Description
Is your feature request related to a current problem? Please describe.
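The current parameter validation for `TorchForecastingModel`s seems to break when an existing model is subclassed. It only checks a hardcoded combination of the `TorchForecastingModel`, `PLForecastingModule` and `cls` `__init__` signatures, so parameters declared only by the parent model's `__init__` are rejected as invalid.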
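Here is a minimal sketch of the failure. It uses hypothetical stand-in classes rather than darts itself: `BaseModel` stands in for `TorchForecastingModel`, `ExistingModel` for a shipped model and `MyModel` for a user subclass, and the `PLForecastingModule` part is omitted. A check built from a fixed base class plus `cls.__init__` never sees the intermediate class's `hidden_size`:

```python
import inspect


def init_params(klass):
    """Return the parameter names of klass.__init__ (including 'self')."""
    return set(inspect.signature(klass.__init__).parameters)


class BaseModel:  # hypothetical stand-in for TorchForecastingModel
    def __init__(self, batch_size=32, n_epochs=100):
        self.batch_size = batch_size
        self.n_epochs = n_epochs


class ExistingModel(BaseModel):  # hypothetical stand-in for a shipped model
    def __init__(self, hidden_size=16, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class MyModel(ExistingModel):  # user subclass that adds one parameter
    def __init__(self, my_flag=False, **kwargs):
        super().__init__(**kwargs)
        self.my_flag = my_flag


def hardcoded_valid(cls):
    # Mirrors the current idea: a fixed base class plus the concrete class only.
    return init_params(BaseModel) | init_params(cls)


print(sorted(hardcoded_valid(MyModel)))
# ['batch_size', 'kwargs', 'my_flag', 'n_epochs', 'self']
print("hidden_size" in hardcoded_valid(MyModel))  # False
```

`MyModel(hidden_size=8)` works at runtime because `ExistingModel` accepts it. A check of this shape would still reject it.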
Describe proposed solution
Using logic that inspects the whole chain of parent classes (the method resolution order returned by `inspect.getmro(cls)`) seems to work well.
Describe potential alternatives
The only workaround is to re-declare all of the parent class's parameters in the child class's `__init__` as well. This is a pretty brittle approach, though.
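For comparison, the workaround looks roughly like this. The classes are hypothetical and darts is not imported. The subclass repeats every parent parameter, including its default, so that its own `__init__` signature lists all valid kwargs:

```python
import inspect


class ExistingModel:  # hypothetical stand-in for a shipped darts model
    def __init__(self, hidden_size=16, dropout=0.1):
        self.hidden_size = hidden_size
        self.dropout = dropout


class MyModel(ExistingModel):
    # Workaround: re-declare the parent's parameters (and duplicate their
    # defaults) so that a check on MyModel.__init__ alone accepts them.
    def __init__(self, my_flag=False, hidden_size=16, dropout=0.1):
        super().__init__(hidden_size=hidden_size, dropout=dropout)
        self.my_flag = my_flag


print(sorted(inspect.signature(MyModel.__init__).parameters))
# ['dropout', 'hidden_size', 'my_flag', 'self']
```

Any parameter later added to, or re-defaulted in, `ExistingModel` has to be copied into `MyModel` by hand. That duplication is what makes the workaround brittle.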
Additional context
Implementation that worked in my case:
```python
@classmethod
def _validate_model_params(cls, **kwargs):
    # Assumes `import inspect` plus darts' `PLForecastingModule`, `raise_if`
    # and module-level `logger`, as used by the surrounding TorchForecastingModel code.
    # Start with the parameters accepted by the Lightning module.
    valid_kwargs = set(
        inspect.signature(PLForecastingModule.__init__).parameters.keys()
    )
    # Add the __init__ parameters of every class in the MRO, so parameters of
    # intermediate parent classes are accepted too.
    for base in inspect.getmro(cls):
        if base is object:
            break
        try:
            sig = inspect.signature(base.__init__)
            valid_kwargs.update(sig.parameters.keys())
        except (ValueError, TypeError):
            # In case a built-in class throws or __init__ is not introspectable
            continue
    # Remove 'self' and 'args/kwargs' from consideration
    valid_kwargs.discard("self")
    valid_kwargs.discard("args")
    valid_kwargs.discard("kwargs")
    invalid_kwargs = [kwarg for kwarg in kwargs if kwarg not in valid_kwargs]
    raise_if(
        len(invalid_kwargs) > 0,
        f"Invalid model creation parameters. Model `{cls.__name__}` has no args/kwargs `{invalid_kwargs}`",
        logger=logger,
    )
```
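The snippet above depends on darts internals (`PLForecastingModule`, `raise_if`, `logger`). The following is a self-contained sketch of the same MRO-walking check that runs on its own. `PLModuleStub` stands in for `PLForecastingModule`, all other class names are hypothetical, and a plain `ValueError` replaces `raise_if`:

```python
import inspect


class PLModuleStub:  # hypothetical stand-in for PLForecastingModule
    def __init__(self, input_chunk_length=1, output_chunk_length=1):
        pass


class BaseModel:  # hypothetical stand-in for TorchForecastingModel
    @classmethod
    def _validate_model_params(cls, **kwargs):
        # Start with the Lightning module's parameters, then add the
        # __init__ parameters of every class in cls's MRO.
        valid_kwargs = set(inspect.signature(PLModuleStub.__init__).parameters)
        for base in inspect.getmro(cls):
            if base is object:
                break
            try:
                valid_kwargs.update(inspect.signature(base.__init__).parameters)
            except (ValueError, TypeError):
                # Skip classes whose __init__ cannot be introspected.
                continue
        valid_kwargs -= {"self", "args", "kwargs"}
        invalid_kwargs = [k for k in kwargs if k not in valid_kwargs]
        if invalid_kwargs:
            raise ValueError(
                f"Invalid model creation parameters. Model `{cls.__name__}` "
                f"has no args/kwargs `{invalid_kwargs}`"
            )

    def __init__(self, batch_size=32, **kwargs):
        self.batch_size = batch_size


class ExistingModel(BaseModel):  # hypothetical shipped model
    def __init__(self, hidden_size=16, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class MyModel(ExistingModel):  # hypothetical user subclass
    def __init__(self, my_flag=False, **kwargs):
        super().__init__(**kwargs)
        self.my_flag = my_flag


# Parameters from every level of the hierarchy are accepted.
MyModel._validate_model_params(hidden_size=8, my_flag=True, batch_size=64)

# A misspelled parameter is still rejected.
try:
    MyModel._validate_model_params(hiden_size=8)
except ValueError as err:
    print(err)
# Invalid model creation parameters. Model `MyModel` has no args/kwargs `['hiden_size']`
```

Because the walk stops at `object`, `MyModel` never has to repeat the parameters of `ExistingModel` or `BaseModel`.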
eschibli