diff --git a/tests/unit/test_ask_tell_optimization.py b/tests/unit/test_ask_tell_optimization.py
index c3aeab899..037d60dd3 100644
--- a/tests/unit/test_ask_tell_optimization.py
+++ b/tests/unit/test_ask_tell_optimization.py
@@ -189,7 +189,7 @@ def test_ask_tell_optimizer_returns_complete_state(
     assert_datasets_allclose(state.record.dataset, init_dataset)
     assert isinstance(state.record.model, type(model))
     assert state.record.acquisition_state is None
-    assert state.local_data_ixs is not None
+    assert isinstance(state.local_data_ixs, Sequence)
     assert state.local_data_len == 2
     npt.assert_array_equal(
         state.local_data_ixs,
@@ -229,8 +229,8 @@ def test_ask_tell_optimizer_loads_from_state(
     assert_datasets_allclose(new_state.record.dataset, old_state.record.dataset)
     assert old_state.record.model is new_state.record.model
 
-    assert new_state.local_data_ixs is not None
-    assert old_state.local_data_ixs is not None
+    assert isinstance(new_state.local_data_ixs, Sequence)
+    assert isinstance(old_state.local_data_ixs, Sequence)
     npt.assert_array_equal(new_state.local_data_ixs, old_state.local_data_ixs)
     assert old_state.local_data_len == new_state.local_data_len == len(init_dataset.query_points)
 
@@ -948,15 +948,13 @@ def test_ask_tell_optimizer_dataset_len_variables(
     assert AskTellOptimizer.dataset_len({"tag1": dataset, "tag2": dataset}) == 2
 
 
-def test_ask_tell_optimizer_dataset_len_raises_on_inconsistently_sized_datasets(
+def test_ask_tell_optimizer_dataset_len_returns_dict_on_inconsistently_sized_datasets(
     init_dataset: Dataset,
 ) -> None:
-    with pytest.raises(ValueError):
-        AskTellOptimizer.dataset_len(
-            {"tag": init_dataset, "empty": Dataset(tf.zeros([0, 2]), tf.zeros([0, 2]))}
-        )
-    with pytest.raises(ValueError):
-        AskTellOptimizer.dataset_len({})
+    assert AskTellOptimizer.dataset_len(
+        {"tag": init_dataset, "empty": Dataset(tf.zeros([0, 2]), tf.zeros([0, 2]))}
+    ) == {"tag": 2, "empty": 0}
+    assert AskTellOptimizer.dataset_len({}) == {}
 
 
 @pytest.mark.parametrize("optimizer", OPTIMIZERS)
diff --git a/trieste/acquisition/utils.py b/trieste/acquisition/utils.py
index 91ffae7cb..8d6034c34 100644
--- a/trieste/acquisition/utils.py
+++ b/trieste/acquisition/utils.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
 import copy
 import functools
 from typing import Dict, Mapping, Optional, Sequence, Tuple, Union
@@ -162,7 +164,9 @@ def copy_to_local_models(
 def with_local_datasets(
     datasets: Mapping[Tag, Dataset],
     num_local_datasets: int,
-    local_dataset_indices: Optional[Sequence[TensorType]] = None,
+    local_dataset_indices: Optional[
+        Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]
+    ] = None,
 ) -> Dict[Tag, Dataset]:
     """
     Helper method to add local datasets if they do not already exist, by copying global datasets
@@ -174,17 +178,22 @@ def with_local_datasets(
         the global datasets should be copied. If None then the entire datasets are copied.
     :return: The updated mapping of datasets.
""" - if local_dataset_indices is not None and len(local_dataset_indices) != num_local_datasets: - raise ValueError( - f"local_dataset_indices should have {num_local_datasets} entries, " - f"has {len(local_dataset_indices)}" - ) + if isinstance(local_dataset_indices, Sequence): + local_dataset_indices = {tag: local_dataset_indices for tag in datasets} updated_datasets = {} for tag in datasets: updated_datasets[tag] = datasets[tag] ltag = LocalizedTag.from_tag(tag) if not ltag.is_local: + if local_dataset_indices is not None: + if tag not in local_dataset_indices: + raise ValueError(f"local_dataset_indices missing tag {tag}") + elif len(local_dataset_indices[tag]) != num_local_datasets: + raise ValueError( + f"local_dataset_indices for tag {tag} should have {num_local_datasets} " + f"entries, but has {len(local_dataset_indices[tag])}" + ) for i in range(num_local_datasets): target_ltag = LocalizedTag(ltag.global_tag, i) if target_ltag not in datasets: @@ -194,10 +203,10 @@ def with_local_datasets( # TODO: use sparse tensors instead updated_datasets[target_ltag] = Dataset( query_points=tf.gather( - datasets[tag].query_points, local_dataset_indices[i] + datasets[tag].query_points, local_dataset_indices[tag][i] ), observations=tf.gather( - datasets[tag].observations, local_dataset_indices[i] + datasets[tag].observations, local_dataset_indices[tag][i] ), ) diff --git a/trieste/ask_tell_optimization.py b/trieste/ask_tell_optimization.py index f2c2da6d5..858076e76 100644 --- a/trieste/ask_tell_optimization.py +++ b/trieste/ask_tell_optimization.py @@ -82,7 +82,7 @@ class AskTellOptimizerState(Generic[StateType, ProbabilisticModelType]): record: Record[StateType, ProbabilisticModelType] """ A record of the current state of the optimization. """ - local_data_ixs: Optional[Sequence[TensorType]] + local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] """ Indices to the local data, for LocalDatasetsAcquisitionRule rules when `track_data` is `False`. """ @@ -108,7 +108,7 @@ def __init__( *, fit_model: bool = True, track_data: bool = True, - local_data_ixs: Optional[Sequence[TensorType]] = None, + local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None, local_data_len: Optional[int] = None, ): ... @@ -122,7 +122,7 @@ def __init__( *, fit_model: bool = True, track_data: bool = True, - local_data_ixs: Optional[Sequence[TensorType]] = None, + local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None, local_data_len: Optional[int] = None, ): ... @@ -139,7 +139,7 @@ def __init__( *, fit_model: bool = True, track_data: bool = True, - local_data_ixs: Optional[Sequence[TensorType]] = None, + local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None, local_data_len: Optional[int] = None, ): ... @@ -152,7 +152,7 @@ def __init__( *, fit_model: bool = True, track_data: bool = True, - local_data_ixs: Optional[Sequence[TensorType]] = None, + local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None, local_data_len: Optional[int] = None, ): ... @@ -166,7 +166,7 @@ def __init__( *, fit_model: bool = True, track_data: bool = True, - local_data_ixs: Optional[Sequence[TensorType]] = None, + local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None, local_data_len: Optional[int] = None, ): ... 
@@ -183,7 +183,7 @@ def __init__(
         *,
         fit_model: bool = True,
         track_data: bool = True,
-        local_data_ixs: Optional[Sequence[TensorType]] = None,
+        local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
         local_data_len: Optional[int] = None,
     ): ...
 
@@ -204,7 +204,7 @@ def __init__(
         *,
         fit_model: bool = True,
         track_data: bool = True,
-        local_data_ixs: Optional[Sequence[TensorType]] = None,
+        local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
         local_data_len: Optional[int] = None,
     ):
         """
@@ -225,9 +225,12 @@
             updates to the global datasets (optionally using `local_data_ixs` and indices passed
             in to `tell`).
         :param local_data_ixs: Indices to the local data in the initial datasets. If unspecified,
-            assumes that the initial datasets are global.
+            assumes that the initial datasets are global. Can be a single sequence for all
+            datasets, or a mapping with separate values for each dataset.
         :param local_data_len: Optional length of the data when the passed in `local_data_ixs`
             were measured. If the data has increased since then, the indices are extended.
+            (Note that this is only supported when all datasets have the same length. If not,
+            then it is up to the caller to update the indices before initialization.)
         :raise ValueError: If any of the following are true:
             - the keys in ``datasets`` and ``models`` do not match
             - ``datasets`` or ``models`` are empty
@@ -287,12 +290,41 @@
         if self.track_data:
             datasets = self._datasets = with_local_datasets(self._datasets, num_local_datasets)
         else:
-            self._dataset_len = self.dataset_len(self._datasets)
-            if local_data_ixs is not None:
+            dataset_len = self.dataset_len(self._datasets)
+            self._dataset_len = dataset_len if isinstance(dataset_len, int) else None
+            self._dataset_ixs: list[TensorType] | Mapping[Tag, list[TensorType]]
+
+            if local_data_ixs is None:
+                # assume that the initial datasets are global
+                if isinstance(dataset_len, int):
+                    self._dataset_ixs = [
+                        tf.range(dataset_len) for _ in range(num_local_datasets)
+                    ]
+                else:
+                    self._dataset_ixs = {
+                        t: [tf.range(l) for _ in range(num_local_datasets)]
+                        for t, l in dataset_len.items()
+                    }
+
+            elif isinstance(local_data_ixs, Mapping):
+                self._dataset_ixs = {t: list(ixs) for t, ixs in local_data_ixs.items()}
+                if local_data_len is not None:
+                    raise ValueError(
+                        "Cannot infer new data points for datasets with different "
+                        "local data indices. Pass in full indices instead."
+                    )
+
+            else:
                 self._dataset_ixs = list(local_data_ixs)
+
                 if local_data_len is not None:
                     # infer new dataset indices from change in dataset sizes
-                    num_new_points = self._dataset_len - local_data_len
+                    if isinstance(dataset_len, Mapping):
+                        raise ValueError(
+                            "Cannot infer new data points for datasets with different "
+                            "lengths. Pass in full indices instead."
+                        )
+                    num_new_points = dataset_len - local_data_len
                     if num_new_points < 0 or (
                         num_local_datasets > 0 and num_new_points % num_local_datasets != 0
                     ):
@@ -310,10 +342,6 @@
                         ],
                         -1,
                     )
-            else:
-                self._dataset_ixs = [
-                    tf.range(self._dataset_len) for _ in range(num_local_datasets)
-                ]
 
             datasets = with_local_datasets(
                 self._datasets, num_local_datasets, self._dataset_ixs
@@ -375,7 +403,7 @@ def dataset(self) -> Dataset:
             raise ValueError(f"Expected a single dataset, found {len(datasets)}")
 
     @property
-    def local_data_ixs(self) -> Optional[Sequence[TensorType]]:
+    def local_data_ixs(self) -> Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]]:
         """Indices to the local data. Only stored for LocalDatasetsAcquisitionRule rules
         when `track_data` is `False`."""
         if isinstance(self._acquisition_rule, LocalDatasetsAcquisitionRule) and not self.track_data:
@@ -433,8 +461,8 @@ def acquisition_state(self) -> StateType | None:
         return self._acquisition_state
 
     @classmethod
-    def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int:
-        """Helper method for inferring the global dataset size."""
+    def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int | Mapping[Tag, int]:
+        """Helper method for inferring the global dataset size(s)."""
         dataset_lens = {
             tag: int(tf.shape(dataset.query_points)[0])
             for tag, dataset in datasets.items()
@@ -444,9 +472,7 @@
         if len(unique_lens) == 1:
             return int(unique_lens[0])
         else:
-            raise ValueError(
-                f"Expected unique global dataset size, got {unique_lens}: {dataset_lens}"
-            )
+            return dataset_lens
 
     @classmethod
     def from_record(
@@ -465,7 +491,7 @@
             | None
         ) = None,
         track_data: bool = True,
-        local_data_ixs: Optional[Sequence[TensorType]] = None,
+        local_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
         local_data_len: Optional[int] = None,
     ) -> AskTellOptimizerType:
         """Creates new :class:`~AskTellOptimizer` instance from provided optimization state.
@@ -634,14 +660,15 @@ def ask(self) -> TensorType:
     def tell(
         self,
         new_data: Mapping[Tag, Dataset] | Dataset,
-        new_data_ixs: Optional[Sequence[TensorType]] = None,
+        new_data_ixs: Optional[Sequence[TensorType] | Mapping[Tag, Sequence[TensorType]]] = None,
     ) -> None:
         """Updates optimizer state with new data.
 
         :param new_data: New observed data. If `track_data` is `False`, this refers to all
             the data.
         :param new_data_ixs: Indices to the new observed local data, if `track_data` is `False`.
-            If unspecified, inferred from the change in dataset sizes.
+            If unspecified, inferred from the change in dataset sizes (as long as all the
+            datasets have the same size).
         :raise ValueError: If keys in ``new_data`` do not match those in already built dataset.
""" if isinstance(new_data, Dataset): @@ -670,10 +697,45 @@ def tell( elif not isinstance(self._acquisition_rule, LocalDatasetsAcquisitionRule): datasets = new_data else: - num_local_datasets = len(self._dataset_ixs) - if new_data_ixs is None: + num_local_datasets = ( + len(self._dataset_ixs) + if isinstance(self._dataset_ixs, Sequence) + else len(next(iter(self._dataset_ixs.values()))) + ) + + if new_data_ixs is not None: + # use explicit indices + def update_ixs(ixs: list[TensorType], new_ixs: Sequence[TensorType]) -> None: + if len(ixs) != len(new_ixs): + raise ValueError( + f"new_data_ixs has {len(new_ixs)} entries, expected {len(ixs)}" + ) + for i in range(len(ixs)): + ixs[i] = tf.concat([ixs[i], new_ixs[i]], -1) + + if isinstance(new_data_ixs, Sequence) and isinstance(self._dataset_ixs, Mapping): + raise ValueError("separate new_data_ixs required for each dataset") + if isinstance(new_data_ixs, Mapping) and isinstance(self._dataset_ixs, Sequence): + self._dataset_ixs = {tag: list(self._dataset_ixs) for tag in self._datasets} + if isinstance(new_data_ixs, Mapping): + assert isinstance(self._dataset_ixs, Mapping) + for tag in self._datasets: + update_ixs(self._dataset_ixs[tag], new_data_ixs[tag]) + else: + assert isinstance(self._dataset_ixs, list) + update_ixs(self._dataset_ixs, new_data_ixs) + + else: # infer dataset indices from change in dataset sizes + if isinstance(self._dataset_ixs, Mapping) or not isinstance(self._dataset_len, int): + raise NotImplementedError( + "new data indices cannot be inferred for datasets with different sizes" + ) new_dataset_len = self.dataset_len(new_data) + if not isinstance(new_dataset_len, int): + raise NotImplementedError( + "new data indices cannot be inferred for new data with different sizes" + ) num_new_points = new_dataset_len - self._dataset_len if num_new_points < 0 or ( num_local_datasets > 0 and num_new_points % num_local_datasets != 0 @@ -690,17 +752,10 @@ def tell( ], -1, ) - else: - # use explicit indices - if len(new_data_ixs) != num_local_datasets: - raise ValueError( - f"new_data_ixs has {len(new_data_ixs)} entries, " - f"expected {num_local_datasets}" - ) - for i in range(num_local_datasets): - self._dataset_ixs[i] = tf.concat([self._dataset_ixs[i], new_data_ixs[i]], -1) + datasets = with_local_datasets(new_data, num_local_datasets, self._dataset_ixs) - self._dataset_len = self.dataset_len(datasets) + dataset_len = self.dataset_len(datasets) + self._dataset_len = dataset_len if isinstance(dataset_len, int) else None filtered_datasets = self._acquisition_rule.filter_datasets(self._models, datasets) if callable(filtered_datasets):