
Conversation


@radka-j radka-j commented Oct 6, 2025

Closes #748
Closes #874
Closes #878
Closes #757

This PR:

  • adds a fit_from_reinitialised method that is used by both AutoEmulate.compare and HMW.refit_emulator
  • emulators now save all of their input args so that all input values can be retrieved later
  • replaces any **kwargs in emulators with an optional scheduler_kwargs keyword argument to match actual usage
  • updates HMW so that the user can pass an emulator as well as a result
  • updates Emulator.fit to handle InputLike instead of expecting only TensorLike
  • updates AL to expect emulator predictions to be DistributionLike rather than GaussianLike, to match TransformedEmulator prediction types
  • updates AL to use fit_from_reinitialized
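A minimal sketch of the pattern behind the first two bullets, assuming an emulator records its constructor args so a fresh instance can be built and refit. The names here (BaseEmulator, _init_args) are illustrative stand-ins, not the actual library API:

```python
# Hypothetical sketch: an emulator saves its constructor args so a fresh
# instance with the same hyperparameters can be built and refit later.
class BaseEmulator:
    def __init__(self, lr=0.01, epochs=50):
        # Save all input args so the initial values can be retrieved later.
        self._init_args = {"lr": lr, "epochs": epochs}
        self.lr = lr
        self.epochs = epochs
        self.is_fitted = False

    def fit(self, x, y):
        # Placeholder for actual training.
        self.is_fitted = True

    def fit_from_reinitialised(self, x, y):
        # Rebuild a fresh instance with the original hyperparameters, then fit.
        fresh = type(self)(**self._init_args)
        fresh.fit(x, y)
        return fresh
```

Because type(self) is used rather than a name lookup, this pattern also works for custom subclasses defined at runtime, with no registry involved.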


Base automatically changed from iss867/update_gp_factory to main October 6, 2025 13:19
@sgreenbury
Collaborator

Just adding a note here as I ran into this when working with a GP subclass for error quantification. This call:

model_class = get_emulator_class(result.model_name)

fails since:
    emulator_cls = EMULATOR_REGISTRY.get(
        name.lower()
    ) or EMULATOR_REGISTRY_SHORT_NAME.get(name.lower())

doesn't also look at:
    GP_REGISTRY = {
        "GaussianProcess": GaussianProcess,
        "GaussianProcessCorrelated": GaussianProcessCorrelated,
    }

@radka-j - adding this here as it might be addressed by the upcoming changes to this API? But if not, I'm happy to open a new issue to look at this. An option could also be to revisit having a central registry class to handle this uniformly.
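For illustration, a minimal sketch of the "central registry" idea: one lookup that falls back to GP_REGISTRY after the main registries. Registry names mirror the snippets in this thread, but the classes are stand-ins and the real API may differ:

```python
# Stand-in classes; in autoemulate these would be the real emulator classes.
class GaussianProcess: pass
class GaussianProcessCorrelated: pass

EMULATOR_REGISTRY = {}
EMULATOR_REGISTRY_SHORT_NAME = {}
GP_REGISTRY = {
    "GaussianProcess": GaussianProcess,
    "GaussianProcessCorrelated": GaussianProcessCorrelated,
}

def get_emulator_class(name):
    key = name.lower()
    # Check the main registries first, then fall back to GP_REGISTRY,
    # normalising its class-name keys to lowercase for the comparison.
    found = EMULATOR_REGISTRY.get(key) or EMULATOR_REGISTRY_SHORT_NAME.get(key)
    if found is None:
        found = {k.lower(): v for k, v in GP_REGISTRY.items()}.get(key)
    if found is None:
        raise KeyError(f"No emulator registered under name: {name}")
    return found
```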


radka-j commented Oct 6, 2025

@sgreenbury I don't think we should ever use the GaussianProcess or GaussianProcessCorrelated classes, so this feels like correct behaviour to me. If we want a GP class for an RBF + constant kernel, we should add that specifically to the registry.

@sgreenbury
Collaborator

It was in the GP context (passing a create_gp_subclass instance to AutoEmulate) that I ran into this issue, and a workaround might have been to also look at GP_REGISTRY, since it maintains a registry of all GPs, including the created subclasses.

But thinking more about it, it affects any subclass used by AutoEmulate currently if reinitialize is called, e.g. in the advanced tutorial:

class SimpleFNN(PyTorchBackend):
    ...
ae = AutoEmulate(x, y, models=[SimpleFNN])
ae.fit_from_reinitialized(x, y)

Since SimpleFNN is constructed at runtime, the class is not found in the registries of emulators.

I think if the emulator becomes the entity that does the refitting in this PR, then a global emulator registry including all custom subclasses would not be needed for this, but it might still be useful?

… GaussianLike to match TransformedEmulator predict type

radka-j commented Oct 13, 2025

The logdet, trace and max_eigval plots in the AL documentation look wrong after the refactor here (they barely change). I started trying to figure out what's happening and have a sense that the predicted uncertainty is narrowed when using a GP wrapped inside a TransformedEmulator (even without any transforms) vs. just a GP. I need to investigate this more formally, but we need to understand what's happening before we can merge this.


radka-j commented Oct 14, 2025

I don't know what the issue is yet, but my previous comment about the uncertainty from TransformedEmulator being narrower was wrong. I was comparing GP vs. TransformedEmulator with GP using different learning rates. Once the same learning rate was used, they look visually identical.

@sgreenbury
Collaborator

It might be related to whether posterior_predictive=True is being passed to the reinitialized GP when within the TransformedEmulator?

For example, on main in the dim reduction tutorial:
https://github.com/alan-turing-institute/autoemulate/blob/6d4a92fdcb2614b5dee5f907855e7003503c0910/docs/tutorials/emulation/02_dim_reduction.ipynb

em = ae.fit_from_reinitialized(x[train_idx], y[train_idx])

has:

print(em.model.posterior_predictive)
False

though the original AutoEmulate initialization had posterior_predictive=True.


radka-j commented Oct 14, 2025

Thank you for checking! In this case the posterior_predictive is correctly set to True after the emulator is re-initialized each time.


radka-j commented Oct 14, 2025

@sgreenbury I'm also not sure if you saw my previous comment but the uncertainty output from TransformedEmulator seems to be fine.


radka-j commented Oct 14, 2025

@sgreenbury I tried running the AL notebook using a GP wrapped inside a TransformedEmulator but calling emulator.fit instead of fit_from_reinitialized as originally implemented, and the results look very similar to the current docs. So it looks like the issue comes from re-initializing the emulator. Given the GP is refitting one data point at a time, this might be a case where calling fit with the hyperparameters fixed might actually make sense.

I therefore decided to revert this change and leave AL as is in this PR (only updating typing). We can separately decide whether to leave the associated issue (#757) open to revisit at some later point or close.

Comment on lines +199 to 200
    if isinstance(output, DistributionLike):
        assert isinstance(output.variance, TensorLike)
Collaborator

I was wondering if we might need a try/except here in general - e.g. if there are cases when model.predict() returns a Distribution that does not have mean or variance? This can happen with TransformedDistribution.

However, looking at TransformedEmulator, I don't think this can happen there, as either GaussianLike or Independent (with a Normal base_dist) is returned. So another option that would not need try/except could be:

Suggested change
    -if isinstance(output, DistributionLike):
    -    assert isinstance(output.variance, TensorLike)
    +if isinstance(output, GaussianLike) or (
    +    isinstance(output, torch.distributions.Independent)
    +    and isinstance(output.base_dist, torch.distributions.Normal)
    +):
    +    assert isinstance(output.variance, TensorLike)

and we raise an exception otherwise, as it is not necessarily supported? Also happy for this to be considered in a new issue if it needs more consideration.

@sgreenbury sgreenbury left a comment

Looks great, thanks @radka-j! As we discussed:

  • it looks like a choice of emulator reinitialization could be good to have in the API
  • there's an issue following our discussion capturing revisiting the overall workflow (#893)
  • the dimensionality reduction tutorial seems not to pick up model_params={"posterior_predictive": True}

There is the comment above about DistributionLike not always having mean/variance - I don't think we'll run into this currently, but it might be good to either restrict here with the instance matching or open an issue for it.

Otherwise looks good to merge!
