Improve emulator re-initialisation #872
base: main
Conversation
Just adding a note here as I ran into this when working with a GP subclass for the error quantification. This call:

`autoemulate/autoemulate/core/compare.py`, line 578 (at 76689ae)

fails since:

`autoemulate/autoemulate/emulators/__init__.py`, lines 68 to 70 (at 76689ae)

doesn't also look at:

`autoemulate/autoemulate/emulators/gaussian_process/exact.py`, lines 460 to 463 (at 76689ae)

@radka-j - adding here as it might be addressed by the upcoming changes to this API? But if not, I'm happy to open a new issue to look at this. An option could also be to revisit having a central registry class to handle this uniformly.
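As a rough illustration of the central-registry idea (class and function names below are hypothetical, not the existing autoemulate API), something like:

```python
# Hypothetical sketch of a central emulator registry; not autoemulate code.
class EmulatorRegistry:
    """Single lookup point for emulator classes, including user subclasses."""

    def __init__(self) -> None:
        self._classes: dict[str, type] = {}

    def register(self, cls: type) -> type:
        # Usable as a class decorator so custom subclasses (e.g. a GP
        # subclass) register themselves at definition time.
        self._classes[cls.__name__] = cls
        return cls

    def get(self, name: str) -> type:
        try:
            return self._classes[name]
        except KeyError as err:
            raise KeyError(f"No emulator registered under {name!r}") from err


registry = EmulatorRegistry()
```

With something like this, the lookup in `compare.py` could resolve built-in and custom emulators through a single code path.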
@sgreenbury I don't think we should ever use the …
It was the GP context (passing a GP subclass). But thinking more about it, it affects any subclass used by `AutoEmulate`:

```python
class SimpleFNN(PyTorchBackend):
    ...

ae = AutoEmulate(x, y, models=[SimpleFNN])
ae.fit_from_reinitialized(x, y)
```

since I think if the emulator becomes the entity that does the refitting in this PR, then a global emulator registry including all custom subclasses would not be needed for this, but it might still be useful?
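A rough sketch of the "emulator does its own refitting" idea discussed above; the class and method names here are assumptions for illustration, not the PR's implementation:

```python
# Hypothetical sketch only; not the autoemulate implementation.
class ReinitializableEmulator:
    """Base class that can rebuild a fresh, untrained copy of itself."""

    def __init__(self, **params):
        self._init_params = params  # remember constructor arguments

    def fit(self, x, y):
        raise NotImplementedError  # real training logic lives in subclasses

    def fit_from_reinitialized(self, x, y):
        # type(self) resolves to the concrete (possibly user-defined)
        # subclass, so no global registry lookup is needed.
        fresh = type(self)(**self._init_params)
        fresh.fit(x, y)
        return fresh
```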
… GaussianLike to match TransformedEmulator predict type
The logdet, trace and max_eigval plots in the AL documentation look wrong after the refactor here (they barely change). I started trying to figure out what's happening and have a sense that the predicted uncertainty is narrowed when using a GP wrapped inside a `TransformedEmulator`.
I don't know what the issue is yet, but my previous comment about the uncertainty from the `TransformedEmulator` may be related.
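As a toy torch illustration of how a transform with a small scale shrinks the variance reported in the transformed space (an assumption-laden example, not a claim about what `TransformedEmulator` does internally):

```python
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import AffineTransform

# Predictive distribution in a standardised output space: unit variance.
latent = Independent(Normal(torch.zeros(2), torch.ones(2)), 1)

# Map back to the data scale with a (small) standardisation scale.
to_data_scale = AffineTransform(loc=torch.zeros(2), scale=torch.tensor([0.1, 0.1]))
pred = TransformedDistribution(latent, [to_data_scale])

samples = pred.sample((10_000,))
print(samples.var(dim=0))  # ≈ 0.01, i.e. much narrower than the latent variance of 1
```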
It might be related to whether `posterior_predictive` is set. For example, on main in the dim reduction tutorial, `em = ae.fit_from_reinitialized(x[train_idx], y[train_idx])` does not have `model_params={"posterior_predictive": True}` applied, though the original emulator does.
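A quick way to check this in the tutorial context (building on the call quoted above; `ae`, `x`, `y`, and `train_idx` come from that tutorial, and the attribute name here is an assumption):

```python
# Hypothetical check; assumes the tutorial's `ae`, `x`, `y`, `train_idx` are in scope
# and that the emulator exposes a `posterior_predictive` attribute.
em_refit = ae.fit_from_reinitialized(x[train_idx], y[train_idx])
print(getattr(em_refit, "posterior_predictive", "<not set>"))
```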
Thank you for checking! In this case the …
@sgreenbury I'm also not sure if you saw my previous comment, but the uncertainty output from a GP wrapped inside a `TransformedEmulator` looks narrowed (see above).
@sgreenbury I tried running the AL notebook using a GP wrapped inside a `TransformedEmulator`, but calling it did not work as expected. I therefore decided to revert this change and leave AL as is in this PR (only updating typing). We can separately decide whether to leave the associated issue (#757) open to revisit at some later point or to close it.
Co-authored-by: Sam Greenbury <[email protected]>
```python
if isinstance(output, DistributionLike):
    assert isinstance(output.variance, TensorLike)
```
I was wondering if we might need a try/except here in general - e.g. are there cases when `model.predict()` returns a `Distribution` that does not have mean or variance? This can happen with `TransformedDistribution`. However, looking at the `TransformedEmulator`, I don't think this can happen there as either `GaussianLike` or `Independent` (base_dist `Normal`) is returned. So another option that would not need try/except could be:
```diff
-if isinstance(output, DistributionLike):
-    assert isinstance(output.variance, TensorLike)
+if isinstance(output, GaussianLike) or (
+    isinstance(output, torch.distributions.Independent)
+    and isinstance(output.base_dist, torch.distributions.Normal)
+):
+    assert isinstance(output.variance, TensorLike)
```
and we raise an exception otherwise, as it's not necessarily supported? Also happy for this to be considered in a new issue if it needs more consideration.
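For reference, a small standalone torch example of the point above: `Independent(Normal(...))` exposes `variance`, while a plain `TransformedDistribution` does not implement it analytically:

```python
import torch
from torch.distributions import Independent, Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform

# Independent(Normal) provides mean/variance directly.
gaussian_like = Independent(Normal(torch.zeros(3), torch.ones(3)), 1)
print(gaussian_like.variance)  # tensor([1., 1., 1.])

# A generic TransformedDistribution does not implement them.
transformed = TransformedDistribution(Normal(torch.zeros(3), torch.ones(3)), [ExpTransform()])
try:
    _ = transformed.variance
except NotImplementedError:
    print("variance is not implemented for this TransformedDistribution")
```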
Looks great, thanks @radka-j! As we discussed:
- it looks like choice of emulator reinitialization could be good to have in the API
- there's an issue following our discussion capturing revisiting the overall workflow (#893)
- the dimensionality reduction tutorial seems to not pick up `model_params={"posterior_predictive": True}`

There is the comment above about `DistributionLike` not always having mean/variance - I don't think we'll run into this currently, but it might be good to either restrict here with the instance matching or have an issue for it.
Otherwise looks good to merge!
Closes #748
Closes #874
Closes #878
Closes #757
This PR:

- Adds a `fit_from_reinitialised` method that is used both in `AutoEmulate.compare` and `HMW.refit_emulator`
- Replaces `**kwargs` in emulators with a `scheduler_kwargs` optional keyword argument to match usage
- Updates `Emulator.fit` to handle `InputLike` instead of expecting only `TensorLike` (a generic sketch follows below)
- Uses `DistributionLike` rather than `GaussianLike` to match `TransformedEmulator` prediction types
- `fit_from_reinitialized` …
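A generic sketch of what handling `InputLike` rather than only `TensorLike` typically involves; the helper name and exact conversions below are assumptions for illustration, not the autoemulate implementation:

```python
import numpy as np
import torch

# Hypothetical coercion helper; not the actual autoemulate code.
def as_tensor(x) -> torch.Tensor:
    """Coerce array-like input (tensor, numpy array, or sequence) to a float32 tensor."""
    if isinstance(x, torch.Tensor):
        return x.float()
    if isinstance(x, np.ndarray):
        return torch.as_tensor(x, dtype=torch.float32)
    return torch.as_tensor(np.asarray(x), dtype=torch.float32)


# A fit() accepting InputLike could then start with:
# x, y = as_tensor(x), as_tensor(y)
```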