-
Notifications
You must be signed in to change notification settings - Fork 19
Improve emulator re-initialisation #872
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
radka-j
wants to merge
48
commits into
main
Choose a base branch
from
reinitialise
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
48 commits
Select commit
Hold shift + click to select a range
2ca9a94
rm kwargs from RF
radka-j 3ce9677
save mlp kwargs in ensembles
radka-j cb44c83
save all MLP and GP input params
radka-j 5346644
update HMW to work with Emulator as well Result object, rename result…
radka-j 23596b1
add option to pass transformed_emulator_params
radka-j 94b7029
Merge branch 'iss867/update_gp_factory' into reinitialise
radka-j d9503f1
make mlp_kwargs a keyword argument in MLP ensembles
radka-j 4e6e742
make scheduler_kwargs a keyword argument
radka-j c08f610
add scheduler_cls input keyword arg
radka-j e0ea75b
check x/y standardization from emulator object
radka-j dacefba
update correlated GP
radka-j 5e25da2
Merge branch 'main' into reinitialise
radka-j 8439e4a
fix scheduler kwarg passing to scheduler_setup
radka-j 38283a8
update scheduler_setup method
radka-j 7a925b2
fix test
radka-j 3c523a6
update scheduler tests
radka-j f6a2377
add reinitialize method
radka-j 5b0d9f2
add reinitialize method
radka-j 4102bbd
add tensor conversion and device handling to TransformedEmulator
radka-j d698bb8
fix var order
radka-j 36a8b64
update Emulator.fit to expect InputLike, not TensorLike
radka-j 12371a6
refactor fit_from_reinitialised function
radka-j 264d6a6
update learners tests
radka-j aa18dca
use fit_from_initialized in learners
radka-j dee04f5
revert changes in learners
radka-j e6d12fc
accept both emulator and result as keyword args to HMW
radka-j f8bf4ee
update test
radka-j e606532
revert doc change
radka-j fcd2827
revert Emulator data input types to TensorLike (not InputLike)
radka-j c10826f
rm now unnecessary tensor transform from TransformedEmulator
radka-j 41b0972
rm now unnecessary tensor transform from Emulator
radka-j e602188
update fit_from_reinitialize to expect TensorLike not InputLike
radka-j 3c8d3dd
ensure tensor is float
radka-j 61db6eb
ensure tensors are floats
radka-j 0c238b0
revert scheduler_params back to kwargs
radka-j b77e283
use convert_to_tensors method
radka-j 06c7ca0
rename scheduler_kwargs to scheduler_params
radka-j 39797ad
use fit_from_initialized in AL, change types to DistributionLike from…
radka-j a5fae9b
Update case_studies/patient_calibration/patient_calibration_case_stud…
radka-j ae942e5
update docstrings
radka-j 8498cb3
avoid code repetition
radka-j 85bf20d
revert to emulator.fit instead of fit_from_reinitialised in stream ba…
radka-j ce6ea04
Update case_studies/patient_calibration/patient_calibration_case_stud…
radka-j b5b902d
Update autoemulate/calibration/history_matching.py
radka-j bcd2939
Update autoemulate/calibration/history_matching.py
radka-j 704fe37
add option to change whether fit from reinitialized or not in AL
radka-j 43039f1
increase learning rate in AL tutorial, set posterior_predictive=False
radka-j 40d7530
Update autoemulate/calibration/history_matching.py
radka-j File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import inspect | ||
|
||
from autoemulate.core.types import DeviceLike, TensorLike | ||
from autoemulate.data.utils import set_random_seed | ||
from autoemulate.emulators import Emulator, TransformedEmulator, get_emulator_class | ||
|
||
|
||
def fit_from_reinitialized( | ||
x: TensorLike, | ||
y: TensorLike, | ||
emulator: Emulator, | ||
transformed_emulator_params: dict | None = None, | ||
device: DeviceLike | None = None, | ||
random_seed: int | None = None, | ||
): | ||
""" | ||
Fit a fresh model with reinitialized parameters using the best configuration. | ||
|
||
This method creates a new model instance with the same configuration as the | ||
best (or specified) model from the comparison, but with freshly initialized | ||
parameters fitted on the provided data. | ||
|
||
Parameters | ||
---------- | ||
x: TensorLike | ||
Input features for training the fresh model. | ||
y: TensorLike | ||
Target values for training the fresh model. | ||
emulator: Emulator | ||
An Emulator object containing the pre-trained emulator. | ||
transformed_emulator_params: None | TransformedEmulatorParams | ||
Parameters for the transformed emulator. When None, the same parameters as | ||
used when identifying the best model are used. Defaults to None. | ||
device: str | None | ||
Device to use for model fitting (e.g., 'cpu' or 'cuda'). If None, the default | ||
device is used. Defaults to None. | ||
random_seed: int | None | ||
Random seed for parameter initialization. Defaults to None. | ||
|
||
Returns | ||
------- | ||
TransformedEmulator | ||
A new model instance with the same configuration but fresh parameters | ||
fitted on the provided data. | ||
|
||
Notes | ||
----- | ||
Unlike TransformedEmulator.refit() which retrains an existing model, | ||
this method creates a completely new model instance with reinitialized | ||
parameters. This ensures that when fitting on new data that the same | ||
initialization conditions are applied. This can have an affect for example | ||
given kernel initialization in Gaussian Processes or weight initialization in | ||
neural networks. | ||
""" | ||
if random_seed is not None: | ||
set_random_seed(seed=random_seed) | ||
|
||
# Extract emulator and its parameters from Emulator instance | ||
if isinstance(emulator, TransformedEmulator): | ||
model = emulator.model | ||
emulator_name = emulator.untransformed_model_name | ||
x_transforms = emulator.x_transforms | ||
y_transforms = emulator.y_transforms | ||
else: | ||
model = emulator | ||
emulator_name = emulator.model_name() | ||
x_transforms = None | ||
y_transforms = None | ||
|
||
# Extract parameters from the provided emulator instance | ||
model_cls = get_emulator_class(emulator_name) | ||
init_sig = inspect.signature(model_cls.__init__) | ||
emulator_params = {} | ||
for param_name in init_sig.parameters: | ||
if param_name in ["self", "x", "y", "device"]: | ||
continue | ||
# NOTE: some emulators have standardize_x/y params option | ||
# this is different to TransformedEmulator x/y transforms | ||
if param_name == "standardize_x": | ||
emulator_params["standardize_x"] = bool(model.x_transform) | ||
if param_name == "standardize_y": | ||
emulator_params["standardize_y"] = bool(model.y_transform) | ||
if hasattr(model, param_name): | ||
emulator_params[param_name] = getattr(model, param_name) | ||
|
||
transformed_emulator_params = transformed_emulator_params or {} | ||
|
||
new_emulator = TransformedEmulator( | ||
x.float(), | ||
y.float(), | ||
model=model_cls, | ||
x_transforms=x_transforms, | ||
y_transforms=y_transforms, | ||
device=device, | ||
**emulator_params, | ||
) | ||
|
||
new_emulator.fit(x.float(), y.float()) | ||
return new_emulator |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.