
Modify parameter validation to better support subclassing #2843

@tRosenflanz

Description


Is your feature request related to a current problem? Please describe.

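The current parameter validation for `TorchForecastingModel`s appears to break when subclassing an existing model, because it only inspects the `__init__` signatures of a hardcoded combination of `TorchForecastingModel`, `PLForecastingModule`, and `cls`. Parameters declared by an intermediate parent class are therefore rejected: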

https://github.com/unit8co/darts/blob/2391d8bb1c8290c9b940d5658c47072fe3de0ca8/darts/models/forecasting/torch_forecasting_model.py#L375C9-L375C31
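To illustrate the failure without depending on darts, here is a minimal self-contained sketch. The classes `BaseTorchModel`, `PLModule`, `ExistingModel`, and `MyModel` and the function `validate_hardcoded` are hypothetical stand-ins for `TorchForecastingModel`, `PLForecastingModule`, a shipped model, and a user subclass; the check only approximates the fixed three-class combination described above and is not the darts source.

```python
import inspect


class PLModule:
    """Hypothetical stand-in for the Lightning module (PLForecastingModule)."""

    def __init__(self, loss_fn=None, **kwargs):
        self.loss_fn = loss_fn


class BaseTorchModel:
    """Hypothetical stand-in for TorchForecastingModel."""

    def __init__(self, batch_size=32, n_epochs=100):
        self.batch_size = batch_size
        self.n_epochs = n_epochs


class ExistingModel(BaseTorchModel):
    """Hypothetical stand-in for a shipped model."""

    def __init__(self, num_stacks=30, **kwargs):
        super().__init__(**kwargs)
        self.num_stacks = num_stacks


class MyModel(ExistingModel):
    """User subclass: adds one parameter and forwards the rest to its parent."""

    def __init__(self, my_param=1, **kwargs):
        super().__init__(**kwargs)
        self.my_param = my_param


def validate_hardcoded(cls, **kwargs):
    # Only a fixed set of signatures is inspected: the framework base,
    # the Lightning module, and `cls` itself. Intermediate parents are skipped.
    valid = set()
    for init in (BaseTorchModel.__init__, PLModule.__init__, cls.__init__):
        valid.update(inspect.signature(init).parameters)
    invalid = [k for k in kwargs if k not in valid]
    if invalid:
        raise ValueError(f"Model `{cls.__name__}` has no args/kwargs `{invalid}`")
```

Here `validate_hardcoded(ExistingModel, num_stacks=10)` passes, while `validate_hardcoded(MyModel, num_stacks=10)` raises `ValueError`, even though `MyModel` forwards `num_stacks` to its parent through `**kwargs`.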

Describe proposed solution

Inspecting the `__init__` signatures of the whole chain of parent classes (the class's method resolution order, or MRO) instead of a fixed set of classes seems to work well.
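A minimal self-contained sketch of the MRO-based idea, again using hypothetical stand-in classes rather than the darts ones. `collect_init_params` and `validate_mro` are illustrative names. Unlike the implementation further below, this sketch skips `*args`/`**kwargs` by parameter kind rather than by name, and it omits the separate Lightning-module signature because that class is not part of the model's MRO.

```python
import inspect


class BaseTorchModel:
    """Hypothetical stand-in for TorchForecastingModel."""

    def __init__(self, batch_size=32, n_epochs=100):
        self.batch_size = batch_size
        self.n_epochs = n_epochs


class ExistingModel(BaseTorchModel):
    """Hypothetical stand-in for a shipped model."""

    def __init__(self, num_stacks=30, **kwargs):
        super().__init__(**kwargs)
        self.num_stacks = num_stacks


class MyModel(ExistingModel):
    """User subclass: adds one parameter and forwards the rest."""

    def __init__(self, my_param=1, **kwargs):
        super().__init__(**kwargs)
        self.my_param = my_param


def collect_init_params(cls):
    """Union of named __init__ parameters over the whole MRO of `cls`."""
    names = set()
    for base in inspect.getmro(cls):
        if base is object:
            break
        try:
            sig = inspect.signature(base.__init__)
        except (ValueError, TypeError):
            # __init__ of some built-in or extension types is not introspectable
            continue
        for name, param in sig.parameters.items():
            # Skip `self` and the catch-all *args / **kwargs entries by kind,
            # so the check does not depend on how they are spelled.
            if name == "self" or param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue
            names.add(name)
    return names


def validate_mro(cls, **kwargs):
    valid = collect_init_params(cls)
    invalid = [k for k in kwargs if k not in valid]
    if invalid:
        raise ValueError(f"Model `{cls.__name__}` has no args/kwargs `{invalid}`")
```

With these classes, `sorted(collect_init_params(MyModel))` gives `['batch_size', 'my_param', 'n_epochs', 'num_stacks']`, so `validate_mro(MyModel, num_stacks=10)` passes while a typo such as `num_stack=10` is still rejected. One trade-off of taking the union over the MRO is that it is permissive: a parent's parameter is accepted even if a subclass's `__init__` does not actually forward it (for example, when the subclass has no `**kwargs`).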

Describe potential alternatives

The only workaround is to redeclare every parameter of the parent class in the child class's `__init__` as well. This is a rather brittle approach, though.

Additional context
An implementation that worked in my case:

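    # Module-level names the method relies on
    import inspect

    from darts.logging import get_logger, raise_if
    from darts.models.forecasting.pl_forecasting_module import PLForecastingModule

    logger = get_logger(__name__)


    # Replacement for TorchForecastingModel._validate_model_params (inside the class body)
    @classmethod
    def _validate_model_params(cls, **kwargs):
        # PLForecastingModule is not part of the model's MRO, so its
        # parameters are added explicitly.
        valid_kwargs = set(
            inspect.signature(PLForecastingModule.__init__).parameters.keys()
        )

        # Walk the full method resolution order so that the parameters of
        # every parent class (e.g. an existing model being subclassed) are accepted.
        for base in inspect.getmro(cls):
            if base is object:
                break
            try:
                sig = inspect.signature(base.__init__)
                valid_kwargs.update(sig.parameters.keys())
            except (ValueError, TypeError):
                # In case a built-in class throws or __init__ is not introspectable
                continue

        # Remove 'self' and 'args/kwargs' from consideration
        valid_kwargs.discard("self")
        valid_kwargs.discard("args")
        valid_kwargs.discard("kwargs")

        invalid_kwargs = [kwarg for kwarg in kwargs if kwarg not in valid_kwargs]

        raise_if(
            len(invalid_kwargs) > 0,
            f"Invalid model creation parameters. Model `{cls.__name__}` has no args/kwargs `{invalid_kwargs}`",
            logger=logger,
        )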
