diff --git a/.gitignore b/.gitignore index c28eebef8..6777e4d97 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,5 @@ site example_eo example_mri .mypy_cache -*.req \ No newline at end of file +*.req +serializedTree \ No newline at end of file diff --git a/indexTree.capnp b/indexTree.capnp new file mode 100644 index 000000000..e85939d25 --- /dev/null +++ b/indexTree.capnp @@ -0,0 +1,18 @@ +@0xae1f1be0650fec43; + +struct Value { + # NEED TO DO THIS STILL + value :union { + strVal @0 :Text; + intVal @1 :Int64; + doubleVal @2 :Float64; + } + } + +struct Node { + axis @0 :Text; + value @1 :List(Value); + result @2 :List(Float64); + resultSize @3 :List(Int64); + children @4 :List(Node); +} \ No newline at end of file diff --git a/polytope/datacube/backends/datacube.py b/polytope/datacube/backends/datacube.py index 6b1f7107c..9b022f687 100644 --- a/polytope/datacube/backends/datacube.py +++ b/polytope/datacube/backends/datacube.py @@ -7,6 +7,7 @@ from ...utility.combinatorics import validate_axes from ..datacube_axis import DatacubeAxis from ..index_tree import DatacubePath, IndexTree +from ..transformations.datacube_mappers.datacube_mappers import DatacubeMapper from ..transformations.datacube_transformations import ( DatacubeAxisTransformation, has_transform, @@ -31,6 +32,7 @@ def __init__(self, axis_options=None, datacube_options=None): self.nearest_search = {} self._axes = None self.transformed_axes = [] + self.compressed_grid_axes = [] @abstractmethod def get(self, requests: IndexTree) -> Any: @@ -54,6 +56,9 @@ def _create_axes(self, name, values, transformation_type_key, transformation_opt ) for blocked_axis in transformation.blocked_axes(): self.blocked_axes.append(blocked_axis) + if isinstance(transformation, DatacubeMapper): + for compressed_grid_axis in transformation.compressed_grid_axes: + self.compressed_grid_axes.append(compressed_grid_axis) if len(final_axis_names) > 1: self.coupled_axes.append(final_axis_names) for axis_name in final_axis_names: diff 
--git a/polytope/datacube/backends/fdb.py b/polytope/datacube/backends/fdb.py index 30c139935..78ab10099 100644 --- a/polytope/datacube/backends/fdb.py +++ b/polytope/datacube/backends/fdb.py @@ -1,6 +1,8 @@ import logging from copy import deepcopy +from itertools import product +import numpy as np import pygribjump as pygj from ...utility.geometry import nearest_pt @@ -47,8 +49,46 @@ def get(self, requests: IndexTree): fdb_requests = [] fdb_requests_decoding_info = [] self.get_fdb_requests(requests, fdb_requests, fdb_requests_decoding_info) - output_values = self.gj.extract(fdb_requests) - self.assign_fdb_output_to_nodes(output_values, fdb_requests_decoding_info) + + # TODO: note that this doesn't exactly work as intended, it's just going to retrieve value from gribjump that + # corresponds to first value in the compressed tuples + + complete_branch_combi_sizes = [] + output_values = [] + for request in fdb_requests: + interm_branch_tuple_values = [] + for key in request[0].keys(): + # remove the tuple of the request when we ask the fdb + interm_branch_tuple_values.append(request[0][key]) + request[0][key] = request[0][key][0] + branch_tuple_combi = product(*interm_branch_tuple_values) + # TODO: now build the relevant requests from this and ask gj for them + # TODO: then group the output values together to fit back with the original compressed request and continue + new_requests = [] + for combi in branch_tuple_combi: + new_request = {} + for i, key in enumerate(request[0].keys()): + new_request[key] = combi[i] + new_requests.append((new_request, request[1])) + branch_output_values = self.gj.extract(new_requests) + branch_combi_sizes = [len(t) for t in interm_branch_tuple_values] + + all_remapped_output_values = [] + for k, req in enumerate(new_requests): + output = branch_output_values[k][0] + output_dict = {} + for i, o in enumerate(output): + output_dict[i] = o[0] + + all_remapped_output_values.append(output_dict) + + output_data_branch = [] + 
output_data_branch = np.array(all_remapped_output_values) + output_data_branch = np.reshape(output_data_branch, tuple(branch_combi_sizes)) + output_values.append([output_data_branch]) + complete_branch_combi_sizes.append([list(range(b)) for b in branch_combi_sizes]) + + self.assign_fdb_output_to_nodes(output_values, fdb_requests_decoding_info, complete_branch_combi_sizes) def get_fdb_requests(self, requests: IndexTree, fdb_requests=[], fdb_requests_decoding_info=[], leaf_path=None): if leaf_path is None: @@ -62,7 +102,7 @@ def get_fdb_requests(self, requests: IndexTree, fdb_requests=[], fdb_requests_de self.get_fdb_requests(c, fdb_requests, fdb_requests_decoding_info) # If request node has no children, we have a leaf so need to assign fdb values to it else: - key_value_path = {requests.axis.name: requests.value} + key_value_path = {requests.axis.name: requests.values} ax = requests.axis (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( key_value_path, leaf_path, self.unwanted_path @@ -112,7 +152,7 @@ def get_2nd_last_values(self, requests, leaf_path=None): found_latlon_pts = [] for lat_child in requests.children: for lon_child in lat_child.children: - found_latlon_pts.append([lat_child.value, lon_child.value]) + found_latlon_pts.append([lat_child.values, lon_child.values]) # now find the nearest lat lon to the points requested nearest_latlons = [] @@ -121,20 +161,21 @@ def get_2nd_last_values(self, requests, leaf_path=None): nearest_latlons.append(nearest_latlon) # need to remove the branches that do not fit - lat_children_values = [child.value for child in requests.children] + lat_children_values = [child.values for child in requests.children] for i in range(len(lat_children_values)): lat_child_val = lat_children_values[i] - lat_child = [child for child in requests.children if child.value == lat_child_val][0] - if lat_child.value not in [latlon[0] for latlon in nearest_latlons]: + lat_child = [child for child in requests.children if child.values 
== lat_child_val][0] + if lat_child.values not in [(latlon[0],) for latlon in nearest_latlons]: lat_child.remove_branch() else: - possible_lons = [latlon[1] for latlon in nearest_latlons if latlon[0] == lat_child.value] - lon_children_values = [child.value for child in lat_child.children] + possible_lons = [latlon[1] for latlon in nearest_latlons if (latlon[0],) == lat_child.values] + lon_children_values = [child.values for child in lat_child.children] for j in range(len(lon_children_values)): lon_child_val = lon_children_values[j] - lon_child = [child for child in lat_child.children if child.value == lon_child_val][0] - if lon_child.value not in possible_lons: - lon_child.remove_branch() + lon_child = [child for child in lat_child.children if child.values == lon_child_val][0] + for value in lon_child.values: + if value not in possible_lons: + lon_child.remove_compressed_branch(value) lat_length = len(requests.children) range_lengths = [False] * lat_length @@ -149,7 +190,7 @@ def get_2nd_last_values(self, requests, leaf_path=None): range_length = deepcopy(range_lengths[i]) current_start_idx = deepcopy(current_start_idxs[i]) fdb_range_nodes = deepcopy(fdb_node_ranges[i]) - key_value_path = {lat_child.axis.name: lat_child.value} + key_value_path = {lat_child.axis.name: lat_child.values} ax = lat_child.axis (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( key_value_path, leaf_path, self.unwanted_path @@ -160,14 +201,14 @@ def get_2nd_last_values(self, requests, leaf_path=None): ) leaf_path_copy = deepcopy(leaf_path) - leaf_path_copy.pop("values") + leaf_path_copy.pop("values", None) return (leaf_path_copy, range_lengths, current_start_idxs, fdb_node_ranges, lat_length) def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx, fdb_range_n): i = 0 for c in requests.children: # now c are the leaves of the initial tree - key_value_path = {c.axis.name: c.value} + key_value_path = {c.axis.name: c.values} ax = c.axis (key_value_path, 
leaf_path, self.unwanted_path) = ax.unmap_path_key( key_value_path, leaf_path, self.unwanted_path @@ -182,7 +223,7 @@ def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx, range_l[i] += 1 fdb_range_n[i][range_l[i] - 1] = c else: - key_value_path = {c.axis.name: c.value} + key_value_path = {c.axis.name: c.values} ax = c.axis (key_value_path, leaf_path, self.unwanted_path) = ax.unmap_path_key( key_value_path, leaf_path, self.unwanted_path @@ -193,8 +234,10 @@ def get_last_layer_before_leaf(self, requests, leaf_path, range_l, current_idx, current_idx[i] = current_start_idx return (range_l, current_idx, fdb_range_n) - def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info): + def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info, complete_branch_combi_sizes): for k in range(len(output_values)): + combi_sizes = complete_branch_combi_sizes[k] + combi_sizes_combis = list(product(*combi_sizes)) request_output_values = output_values[k] ( original_indices, @@ -215,7 +258,13 @@ def assign_fdb_output_to_nodes(self, output_values, fdb_requests_decoding_info): for i in range(len(sorted_fdb_range_nodes)): for j in range(sorted_range_lengths[i]): n = sorted_fdb_range_nodes[i][j] - n.result = request_output_values[0][i][0][j] + for size_combi in list(combi_sizes_combis): + interm_output_values = request_output_values[0] + # TODO: the result associated to nodes is still only a simple float, not an array and is not + # the right one... 
+ for val in size_combi: + interm_output_values = interm_output_values[val] + n.result = interm_output_values[i][j] def sort_fdb_request_ranges(self, range_lengths, current_start_idx, lat_length): interm_request_ranges = [] diff --git a/polytope/datacube/backends/mock.py b/polytope/datacube/backends/mock.py index e412d98c3..f34bbea9e 100644 --- a/polytope/datacube/backends/mock.py +++ b/polytope/datacube/backends/mock.py @@ -31,7 +31,7 @@ def get(self, requests: IndexTree): for r in requests.leaves: path = r.flatten() if len(path.items()) == len(self.dimensions.items()): - result = 0 + result = (0,) for k, v in path.items(): result += v * self.stride[k] diff --git a/polytope/datacube/backends/xarray.py b/polytope/datacube/backends/xarray.py index 459db272f..b00c5351e 100644 --- a/polytope/datacube/backends/xarray.py +++ b/polytope/datacube/backends/xarray.py @@ -41,23 +41,26 @@ def get(self, requests: IndexTree): for r in requests.leaves: path = r.flatten() if len(path.items()) == self.axis_counter: - # first, find the grid mapper transform + # TODO: need to undo the tuples in the path into actual paths with a single value that xarray can read unmapped_path = {} path_copy = deepcopy(path) for key in path_copy: axis = self._axes[key] key_value_path = {key: path_copy[key]} - # (path, unmapped_path) = axis.unmap_to_datacube(path, unmapped_path) (key_value_path, path, unmapped_path) = axis.unmap_path_key(key_value_path, path, unmapped_path) path.update(key_value_path) - path.update(unmapped_path) unmapped_path = {} self.refit_path(path, unmapped_path, path) + for key in path: + path[key] = list(path[key]) + for key in unmapped_path: + if isinstance(unmapped_path[key], tuple): + unmapped_path[key] = list(unmapped_path[key]) subxarray = self.dataarray.sel(path, method="nearest") subxarray = subxarray.sel(unmapped_path) - value = subxarray.item() + value = subxarray.values key = subxarray.name r.result = (key, value) else: @@ -93,6 +96,12 @@ def refit_path(self, 
path_copy, unmapped_path, path): path_copy.pop(key, None) def select(self, path, unmapped_path): + for key in path: + key_value = path[key][0] + path[key] = key_value + for key in unmapped_path: + key_value = unmapped_path[key][0] + unmapped_path[key] = key_value path_copy = deepcopy(path) self.refit_path(path_copy, unmapped_path, path) subarray = self.dataarray.sel(path_copy, method="nearest") diff --git a/polytope/datacube/datacube_axis.py b/polytope/datacube/datacube_axis.py index 4314fe7e9..938864e4a 100644 --- a/polytope/datacube/datacube_axis.py +++ b/polytope/datacube/datacube_axis.py @@ -173,6 +173,7 @@ def __init__(self): # TODO: Maybe here, store transformations as a dico instead self.transformations = [] self.type = 0 + self.can_round = True def parse(self, value: Any) -> Any: return float(value) @@ -194,6 +195,7 @@ def __init__(self): self.range = None self.transformations = [] self.type = 0.0 + self.can_round = True def parse(self, value: Any) -> Any: return float(value) @@ -215,6 +217,7 @@ def __init__(self): self.range = None self.transformations = [] self.type = pd.Timestamp("2000-01-01T00:00:00") + self.can_round = False def parse(self, value: Any) -> Any: if isinstance(value, np.str_): @@ -244,6 +247,7 @@ def __init__(self): self.range = None self.transformations = [] self.type = np.timedelta64(0, "s") + self.can_round = False def parse(self, value: Any) -> Any: if isinstance(value, np.str_): @@ -272,6 +276,7 @@ def __init__(self): self.tol = float("NaN") self.range = None self.transformations = [] + self.can_round = False def parse(self, value: Any) -> Any: return value diff --git a/polytope/datacube/index_tree.capnp b/polytope/datacube/index_tree.capnp new file mode 100644 index 000000000..745e48b0f --- /dev/null +++ b/polytope/datacube/index_tree.capnp @@ -0,0 +1,18 @@ + + +struct Value { + # NEED TO DO THIS STILL + value :union { + str_val @0 :Text; + int_val @1 :Int64; + double_val @2 :Float64; + } + } + +struct Node { + axis @0 :Text; + 
value @1 :List(Value); + result @2 :List(Float64); + result_size @3 :List(Int64); + children @4 :List(Node); +} \ No newline at end of file diff --git a/polytope/datacube/index_tree.proto b/polytope/datacube/index_tree.proto new file mode 100644 index 000000000..f4989a9a6 --- /dev/null +++ b/polytope/datacube/index_tree.proto @@ -0,0 +1,20 @@ +syntax = "proto3"; + +package index_tree; + +message Value { + oneof value { + string str_val = 1; + int64 int_val = 2; + double double_val = 3; + } + } + +message Node { + string axis = 1; + repeated Value value = 2; + repeated double result = 3; + repeated int64 result_size = 4; + repeated Node children = 5; +} + diff --git a/polytope/datacube/index_tree.py b/polytope/datacube/index_tree.py index 2afb8416a..22dbb38de 100644 --- a/polytope/datacube/index_tree.py +++ b/polytope/datacube/index_tree.py @@ -54,13 +54,6 @@ def copy_children_from_other(self, other): c.copy_children_from_other(o) return - def pprint_2(self, level=0): - if self.axis.name == "root": - print("\n") - print("\t" * level + "\u21b3" + str(self)) - for child in self.children: - child.pprint_2(level + 1) - def _collect_leaf_nodes_old(self, leaves): if len(self.children) == 0: leaves.append(self) diff --git a/polytope/datacube/index_tree_pb2.py b/polytope/datacube/index_tree_pb2.py new file mode 100644 index 000000000..dc04bc0e9 --- /dev/null +++ b/polytope/datacube/index_tree_pb2.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
class TensorIndexTree(object):
    """Tree node holding a tuple of values for a single datacube axis.

    Each node stores the axis it belongs to, a hashable tuple of values on that
    axis, a sorted collection of child nodes and an optional result. Nodes
    created for a compressed axis accumulate several values in one node instead
    of creating one child per value (see ``create_child``).
    """

    root = IntDatacubeAxis()
    root.name = "root"

    def __init__(self, axis=root, values=tuple()):
        # NOTE: values is a tuple so that the node can be hashed
        self.values = values
        self.children = SortedList()
        self._parent = None
        self.result = None
        self.axis = axis
        self.ancestors = []
        self.result_size = []

    @property
    def leaves(self):
        """Return the list of leaf nodes below (and possibly including) this node."""
        leaves = []
        self._collect_leaf_nodes(leaves)
        return leaves

    @property
    def leaves_with_ancestors(self):
        """Return the leaf nodes; each leaf's ``ancestors`` list is filled as a side effect."""
        leaves = []
        self._collect_leaf_nodes(leaves)
        return leaves

    def copy_children_from_other(self, other):
        """Recursively recreate the children of ``other`` underneath this node."""
        for o in other.children:
            # Nodes of this class only define ``values`` (there is no ``value`` attribute).
            c = TensorIndexTree(o.axis, copy.copy(o.values))
            self.add_child(c)
            c.copy_children_from_other(o)
        return

    def _collect_leaf_nodes(self, leaves):
        """Append every leaf under this node to ``leaves`` and refresh ``ancestors``.

        After the traversal, each visited child's ``ancestors`` holds the non-root
        nodes on the path leading to it, and each leaf additionally lists itself as
        the last entry. The lists are rebuilt on every traversal so that repeated
        calls do not keep growing them with duplicates.
        """
        if len(self.children) == 0:
            leaves.append(self)
            # Identity check: only add the leaf itself once, even across repeated traversals.
            if not any(ancestor is self for ancestor in self.ancestors):
                self.ancestors.append(self)
        # A node that used to be a leaf may still list itself; do not hand that entry down twice.
        inherited = [ancestor for ancestor in self.ancestors if ancestor is not self]
        for n in self.children:
            n.ancestors = list(inherited)
            if self.axis != TensorIndexTree.root:
                n.ancestors.append(self)
            n._collect_leaf_nodes(leaves)

    def __setitem__(self, key, value):
        setattr(self, key, value)

    def __getitem__(self, key):
        return getattr(self, key)

    def __delitem__(self, key):
        return delattr(self, key)

    def __hash__(self):
        return hash((self.axis.name, self.values))

    def __eq__(self, other):
        """Compare nodes by axis name and values, using the axis tolerance for sliceable axes."""
        if not isinstance(other, TensorIndexTree):
            return False
        if self.axis.name != other.axis.name:
            return False
        if other.values == self.values:
            return True
        if isinstance(self.axis, UnsliceableDatacubeAxis):
            return False
        if len(other.values) != len(self.values):
            return False
        tolerance = 2 * max(other.axis.tol, self.axis.tol)
        # Written as "not any(... > tolerance)" so that a NaN tolerance keeps comparing as equal,
        # exactly like an element-wise "if difference > tolerance: return False" loop.
        return not any(
            abs(other_val - self_val) > tolerance for other_val, self_val in zip(other.values, self.values)
        )

    def __lt__(self, other):
        return (self.axis.name, self.values) < (other.axis.name, other.values)

    def __repr__(self):
        # NOTE(review): this compares the axis object to the string "root"; unless the axis type
        # defines equality with strings, the first branch is always taken - confirm intent.
        if self.axis != "root":
            return f"{self.axis.name}={self.values}"
        else:
            return f"{self.axis}"

    def add_child(self, node):
        """Insert ``node`` among the sorted children and make this node its parent."""
        self.children.add(node)
        node._parent = self

    def find_compressed_child(self, axis):
        """Return the first child whose axis equals ``axis``, or ``None`` if there is none."""
        for c in self.children:
            if c.axis == axis:
                return c
        return None

    def create_child(self, axis, value, compressed_axes, next_nodes):
        """Find or create the child holding ``value`` on ``axis``.

        If ``axis.name`` is not in ``compressed_axes``, a child with the single-value
        tuple ``(value,)`` is looked up and created when missing. Otherwise ``value``
        is appended to the values of the existing child for that axis (which is also
        removed from ``next_nodes``), or a new single-value child is created when no
        child exists for the axis yet.

        Returns:
            A ``(child, next_nodes)`` tuple.
        """
        if axis.name not in compressed_axes:
            node = TensorIndexTree(axis, (value,))
            existing_child = self.find_child(node)
            if not existing_child:
                self.add_child(node)
                return (node, next_nodes)
            return (existing_child, next_nodes)

        existing_compressed_child = self.find_compressed_child(axis)
        if existing_compressed_child:
            # NOTE: values are kept as a tuple so the node stays hashable
            existing_compressed_child.values = existing_compressed_child.values + (value,)
            next_nodes.remove(existing_compressed_child)
            return (existing_compressed_child, next_nodes)

        node = TensorIndexTree(axis, (value,))
        self.add_child(node)
        return (node, next_nodes)

    @property
    def parent(self):
        return self._parent

    @parent.setter
    def parent(self, node):
        # The setter must carry the property's own name; otherwise ``parent`` stays read-only.
        if self.parent is not None:
            self.parent.children.remove(self)
        self._parent = node
        self._parent.children.add(self)

    # Backward-compatible alias: the setter used to be reachable only through ``set_parent``.
    set_parent = parent

    def get_root(self):
        """Walk up the parent links and return the topmost node."""
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    def is_root(self):
        return self.parent is None

    def find_child(self, node):
        """Return the child equal to ``node`` at its sorted position, or ``None``."""
        index = self.children.bisect_left(node)
        if index >= len(self.children):
            return None
        child = self.children[index]
        if not child == node:
            return None
        return child

    def merge(self, other):
        """Merge the children of ``other`` into this node, recursing into matching children."""
        for other_child in other.children:
            my_child = self.find_child(other_child)
            if not my_child:
                self.add_child(other_child)
            else:
                my_child.merge(other_child)

    def pprint(self, level=0):
        """Log the subtree at debug level, one node per line, indented by depth."""
        if self.axis.name == "root":
            logging.debug("\n")
        logging.debug("\t" * level + "\u21b3" + str(self))
        for child in self.children:
            child.pprint(level + 1)
        if len(self.children) == 0:
            logging.debug("\t" * (level + 1) + "\u21b3" + str(self.result))

    def remove_branch(self):
        """Detach this node from its parent, pruning parents that become childless."""
        if not self.is_root():
            old_parent = self._parent
            self._parent.children.remove(self)
            self._parent = None
            if len(old_parent.children) == 0:
                old_parent.remove_branch()

    def remove_compressed_branch(self, value):
        """Drop ``value`` from this node's values, removing the node if it was the only value."""
        if value in self.values:
            if len(self.values) == 1:
                self.remove_branch()
            else:
                self.values = tuple(val for val in self.values if val != value)

    def flatten(self):
        """Return a ``DatacubePath`` mapping axis names to values along the path to this node."""
        path = DatacubePath()
        ancestors = self.get_ancestors()
        for ancestor in ancestors:
            path[ancestor.axis.name] = ancestor.values
        return path

    def flatten_with_ancestors(self):
        """Like ``flatten`` but reads the cached ``ancestors`` list instead of walking parents."""
        path = DatacubePath()
        ancestors = self.ancestors
        for ancestor in ancestors:
            path[ancestor.axis.name] = ancestor.values
        return path

    def get_ancestors(self):
        """Return the nodes from just below the root axis down to this node, in that order."""
        ancestors = []
        current_node = self
        # Stop at the root axis, or when the parent chain ends on a node detached from any root.
        while current_node is not None and current_node.axis != TensorIndexTree.root:
            ancestors.append(current_node)
            current_node = current_node.parent
        return ancestors[::-1]
def mapper(cls):
    """Attach grid-mapper aware lookup functions to ``cls`` when it has a mapper.

    When ``cls.has_mapper`` is falsy, ``cls`` is returned untouched. Otherwise the
    attributes ``find_indexes``, ``unmap_to_datacube``, ``find_indices_between`` and
    ``unmap_path_key`` are set on ``cls``; the new ``unmap_to_datacube`` and
    ``unmap_path_key`` first call the versions previously found on ``cls``.

    Returns:
        ``cls`` itself.
    """
    if cls.has_mapper:

        def find_indexes(path, datacube):
            # first, find the relevant transformation object that is a mapping in cls.transformations
            for transform in cls.transformations:
                if isinstance(transform, DatacubeMapper):
                    transformation = transform
                    if cls.name == transformation._mapped_axes()[0]:
                        return transformation.first_axis_vals()
                    if cls.name == transformation._mapped_axes()[1]:
                        first_val = path[transformation._mapped_axes()[0]]
                        # The mapper types index first_val[0], so always hand them a tuple
                        # (same normalisation as DatacubeMapper.find_modified_indexes).
                        if not isinstance(first_val, tuple):
                            first_val = (first_val,)
                        return transformation.second_axis_vals(first_val)

        old_unmap_to_datacube = cls.unmap_to_datacube

        def unmap_to_datacube(path, unmapped_path):
            (path, unmapped_path) = old_unmap_to_datacube(path, unmapped_path)
            for transform in cls.transformations:
                if isinstance(transform, DatacubeMapper):
                    if cls.name == transform._mapped_axes()[0]:
                        # if we are on the first axis, then need to add the first val to unmapped_path
                        first_val = path.get(cls.name, None)
                        path.pop(cls.name, None)
                        if cls.name not in unmapped_path:
                            # if for some reason, the unmapped_path already has the first axis val, then don't update
                            unmapped_path[cls.name] = first_val
                    if cls.name == transform._mapped_axes()[1]:
                        # if we are on the second axis, then the val of the first axis is stored
                        # inside unmapped_path so can get it from there
                        second_val = path.get(cls.name, None)
                        path.pop(cls.name, None)
                        first_val = unmapped_path.get(transform._mapped_axes()[0], None)
                        unmapped_path.pop(transform._mapped_axes()[0], None)
                        # if the first_val was not in the unmapped_path, then it's still in path
                        if first_val is None:
                            first_val = path.get(transform._mapped_axes()[0], None)
                            path.pop(transform._mapped_axes()[0], None)
                        if first_val is not None and second_val is not None:
                            unmapped_idx = transform.unmap(first_val, second_val)
                            unmapped_path[transform.old_axis] = (unmapped_idx,)
            return (path, unmapped_path)

        old_unmap_path_key = cls.unmap_path_key

        def unmap_path_key(key_value_path, leaf_path, unwanted_path):
            key_value_path, leaf_path, unwanted_path = old_unmap_path_key(key_value_path, leaf_path, unwanted_path)
            value = key_value_path[cls.name]
            for transform in cls.transformations:
                if isinstance(transform, DatacubeMapper):
                    if cls.name == transform._mapped_axes()[0]:
                        # remember the first-axis value until the second axis is reached
                        unwanted_val = key_value_path[transform._mapped_axes()[0]]
                        unwanted_path[cls.name] = unwanted_val
                    if cls.name == transform._mapped_axes()[1]:
                        first_val = unwanted_path[transform._mapped_axes()[0]]
                        unmapped_idx = transform.unmap(first_val, value)
                        leaf_path.pop(transform._mapped_axes()[0], None)
                        # replace the mapped axis entry by the index on the original axis
                        key_value_path.pop(cls.name)
                        key_value_path[transform.old_axis] = unmapped_idx
            return (key_value_path, leaf_path, unwanted_path)

        def find_indices_between(index_ranges, low, up, datacube, method=None):
            # TODO: add method for snapping
            indexes_between_ranges = []
            for transform in cls.transformations:
                if isinstance(transform, DatacubeMapper):
                    transformation = transform
                    if cls.name in transformation._mapped_axes():
                        for idxs in index_ranges:
                            axis_reversed = transform._axis_reversed[cls.name]
                            if method in ("surrounding", "nearest"):
                                if not axis_reversed:
                                    start = bisect.bisect_left(idxs, low)
                                    end = bisect.bisect_right(idxs, up)
                                else:
                                    # TODO: do the custom bisect
                                    end = bisect_left_cmp(idxs, low, cmp=lambda x, y: x > y) + 1
                                    start = bisect_right_cmp(idxs, up, cmp=lambda x, y: x > y)
                                # widen the window by one index on each side, clamped to the list bounds
                                start = max(start - 1, 0)
                                end = min(end + 1, len(idxs))
                                indexes_between = idxs[start:end]
                                indexes_between_ranges.append(indexes_between)
                            else:
                                if not axis_reversed:
                                    lower_idx = bisect.bisect_left(idxs, low)
                                    upper_idx = bisect.bisect_right(idxs, up)
                                    indexes_between = idxs[lower_idx:upper_idx]
                                else:
                                    # TODO: do the custom bisect
                                    end_idx = bisect_left_cmp(idxs, low, cmp=lambda x, y: x > y) + 1
                                    start_idx = bisect_right_cmp(idxs, up, cmp=lambda x, y: x > y)
                                    indexes_between = idxs[start_idx:end_idx]
                                indexes_between_ranges.append(indexes_between)
            return indexes_between_ranges

        cls.find_indexes = find_indexes
        cls.unmap_to_datacube = unmap_to_datacube
        cls.find_indices_between = find_indices_between
        cls.unmap_path_key = unmap_path_key

    return cls
tol][0] second_idx = self.second_axis_vals(first_val).index(second_val) healpix_index = self.axes_idx_to_healpix_idx(first_idx, second_idx) return healpix_index diff --git a/polytope/datacube/transformations/datacube_mappers/mapper_types/local_regular.py b/polytope/datacube/transformations/datacube_mappers/mapper_types/local_regular.py index a1514778a..40f86bfbd 100644 --- a/polytope/datacube/transformations/datacube_mappers/mapper_types/local_regular.py +++ b/polytope/datacube/transformations/datacube_mappers/mapper_types/local_regular.py @@ -22,6 +22,7 @@ def __init__(self, base_axis, mapped_axes, resolution, local_area=[]): self._second_deg_increment = (local_area[3] - local_area[2]) / self.second_resolution self._axis_reversed = {mapped_axes[0]: True, mapped_axes[1]: False} self._first_axis_vals = self.first_axis_vals() + self.compressed_grid_axes = [self._mapped_axes[1]] def first_axis_vals(self): first_ax_vals = [self._first_axis_max - i * self._first_deg_increment for i in range(self.first_resolution + 1)] @@ -61,9 +62,9 @@ def unmap_first_val_to_start_line_idx(self, first_val): def unmap(self, first_val, second_val): tol = 1e-8 - first_val = [i for i in self._first_axis_vals if first_val - tol <= i <= first_val + tol][0] + first_val = [i for i in self._first_axis_vals if first_val[0] - tol <= i <= first_val[0] + tol][0] first_idx = self._first_axis_vals.index(first_val) - second_val = [i for i in self.second_axis_vals(first_val) if second_val - tol <= i <= second_val + tol][0] + second_val = [i for i in self.second_axis_vals(first_val) if second_val[0] - tol <= i <= second_val[0] + tol][0] second_idx = self.second_axis_vals(first_val).index(second_val) final_index = self.axes_idx_to_regular_idx(first_idx, second_idx) return final_index diff --git a/polytope/datacube/transformations/datacube_mappers/mapper_types/octahedral.py b/polytope/datacube/transformations/datacube_mappers/mapper_types/octahedral.py index 730ac9592..f48fca712 100644 --- 
a/polytope/datacube/transformations/datacube_mappers/mapper_types/octahedral.py +++ b/polytope/datacube/transformations/datacube_mappers/mapper_types/octahedral.py @@ -14,6 +14,7 @@ def __init__(self, base_axis, mapped_axes, resolution, local_area=[]): self._first_idx_map = self.create_first_idx_map() self._second_axis_spacing = {} self._axis_reversed = {mapped_axes[0]: True, mapped_axes[1]: False} + self.compressed_grid_axes = [self._mapped_axes[1]] def gauss_first_guess(self): i = 0 @@ -81,8 +82,6 @@ def gauss_first_guess(self): def get_precomputed_values_N1280(self): lats = [0] * 2560 - # lats = SortedList() - # lats = {} lats[0] = 89.946187715665616 lats[1] = 89.876478353332288 lats[2] = 89.806357319542244 @@ -2683,7 +2682,7 @@ def map_first_axis(self, lower, upper): def second_axis_vals(self, first_val): first_axis_vals = self._first_axis_vals tol = 1e-10 - first_idx = bisect_left_cmp(first_axis_vals, first_val - tol, cmp=lambda x, y: x > y) + first_idx = bisect_left_cmp(first_axis_vals, first_val[0] - tol, cmp=lambda x, y: x > y) if first_idx >= self._resolution: first_idx = (2 * self._resolution) - 1 - first_idx first_idx = first_idx + 1 @@ -2695,7 +2694,7 @@ def second_axis_vals(self, first_val): def second_axis_spacing(self, first_val): first_axis_vals = self._first_axis_vals tol = 1e-10 - _first_idx = bisect_left_cmp(first_axis_vals, first_val - tol, cmp=lambda x, y: x > y) + _first_idx = bisect_left_cmp(first_axis_vals, first_val[0] - tol, cmp=lambda x, y: x > y) first_idx = _first_idx if first_idx >= self._resolution: first_idx = (2 * self._resolution) - 1 - first_idx @@ -2741,10 +2740,10 @@ def create_first_idx_map(self): def find_second_axis_idx(self, first_val, second_val): (second_axis_spacing, first_idx) = self.second_axis_spacing(first_val) tol = 1e-8 - if second_val / second_axis_spacing > int(second_val / second_axis_spacing) + 1 - tol: - second_idx = int(second_val / second_axis_spacing) + 1 + if second_val[0] / second_axis_spacing > 
int(second_val[0] / second_axis_spacing) + 1 - tol: + second_idx = int(second_val[0] / second_axis_spacing) + 1 else: - second_idx = int(second_val / second_axis_spacing) + second_idx = int(second_val[0] / second_axis_spacing) return (first_idx, second_idx) def unmap(self, first_val, second_val): diff --git a/polytope/datacube/transformations/datacube_mappers/mapper_types/reduced_ll.py b/polytope/datacube/transformations/datacube_mappers/mapper_types/reduced_ll.py index ece09b4ac..5a76f5d10 100644 --- a/polytope/datacube/transformations/datacube_mappers/mapper_types/reduced_ll.py +++ b/polytope/datacube/transformations/datacube_mappers/mapper_types/reduced_ll.py @@ -11,6 +11,7 @@ def __init__(self, base_axis, mapped_axes, resolution, local_area=[]): self._resolution = resolution self._axis_reversed = {mapped_axes[0]: False, mapped_axes[1]: False} self._first_axis_vals = self.first_axis_vals() + self.compressed_grid_axes = [self._mapped_axes[1]] def first_axis_vals(self): resolution = 180 / (self._resolution - 1) @@ -1469,7 +1470,7 @@ def lon_spacing(self): ] def second_axis_vals(self, first_val): - first_idx = self._first_axis_vals.index(first_val) + first_idx = self._first_axis_vals.index(first_val[0]) Ny = self.lon_spacing()[first_idx] second_spacing = 360 / Ny return [i * second_spacing for i in range(Ny)] @@ -1497,9 +1498,9 @@ def find_second_idx(self, first_val, second_val): def unmap(self, first_val, second_val): tol = 1e-8 - first_val = [i for i in self._first_axis_vals if first_val - tol <= i <= first_val + tol][0] - first_idx = self._first_axis_vals.index(first_val) - second_val = [i for i in self.second_axis_vals(first_val) if second_val - tol <= i <= second_val + tol][0] + first_value = [i for i in self._first_axis_vals if first_val[0] - tol <= i <= first_val[0] + tol][0] + first_idx = self._first_axis_vals.index(first_value) + second_val = [i for i in self.second_axis_vals(first_val) if second_val[0] - tol <= i <= second_val[0] + tol][0] second_idx = 
self.second_axis_vals(first_val).index(second_val) reduced_ll_index = self.axes_idx_to_reduced_ll_idx(first_idx, second_idx) return reduced_ll_index diff --git a/polytope/datacube/transformations/datacube_mappers/mapper_types/regular.py b/polytope/datacube/transformations/datacube_mappers/mapper_types/regular.py index c8f207fca..3b40f77e4 100644 --- a/polytope/datacube/transformations/datacube_mappers/mapper_types/regular.py +++ b/polytope/datacube/transformations/datacube_mappers/mapper_types/regular.py @@ -12,6 +12,7 @@ def __init__(self, base_axis, mapped_axes, resolution, local_area=[]): self.deg_increment = 90 / self._resolution self._axis_reversed = {mapped_axes[0]: True, mapped_axes[1]: False} self._first_axis_vals = self.first_axis_vals() + self.compressed_grid_axes = [self._mapped_axes[1]] def first_axis_vals(self): first_ax_vals = [90 - i * self.deg_increment for i in range(2 * self._resolution)] @@ -49,9 +50,9 @@ def unmap_first_val_to_start_line_idx(self, first_val): def unmap(self, first_val, second_val): tol = 1e-8 - first_val = [i for i in self._first_axis_vals if first_val - tol <= i <= first_val + tol][0] + first_val = [i for i in self._first_axis_vals if first_val[0] - tol <= i <= first_val[0] + tol][0] first_idx = self._first_axis_vals.index(first_val) - second_val = [i for i in self.second_axis_vals(first_val) if second_val - tol <= i <= second_val + tol][0] + second_val = [i for i in self.second_axis_vals(first_val) if second_val[0] - tol <= i <= second_val[0] + tol][0] second_idx = self.second_axis_vals(first_val).index(second_val) final_index = self.axes_idx_to_regular_idx(first_idx, second_idx) return final_index diff --git a/polytope/datacube/transformations/datacube_merger/datacube_merger.py b/polytope/datacube/transformations/datacube_merger/datacube_merger.py index 91d680197..bd34973e2 100644 --- a/polytope/datacube/transformations/datacube_merger/datacube_merger.py +++ 
b/polytope/datacube/transformations/datacube_merger/datacube_merger.py @@ -52,21 +52,26 @@ def generate_final_transformation(self): return self def unmerge(self, merged_val): - merged_val = str(merged_val) - first_idx = merged_val.find(self._linkers[0]) - first_val = merged_val[:first_idx] - first_linker_size = len(self._linkers[0]) - second_linked_size = len(self._linkers[1]) - second_val = merged_val[first_idx + first_linker_size : -second_linked_size] + first_values = [] + second_values = [] + for merged_value in merged_val: + merged_val = str(merged_value) + first_idx = merged_val.find(self._linkers[0]) + first_val = merged_val[:first_idx] + first_linker_size = len(self._linkers[0]) + second_linked_size = len(self._linkers[1]) + second_val = merged_val[first_idx + first_linker_size : -second_linked_size] - # TODO: maybe replacing like this is too specific to time/dates? - first_val = str(first_val).replace("-", "") - second_val = second_val.replace(":", "") - logging.info( - f"Unmerged value {merged_val} to values {first_val} on axis {self.name} \ - and {second_val} on axis {self._second_axis}" - ) - return (first_val, second_val) + # TODO: maybe replacing like this is too specific to time/dates? 
+ first_val = str(first_val).replace("-", "") + second_val = second_val.replace(":", "") + logging.info( + f"Unmerged value {merged_val} to values {first_val} on axis {self.name} \ + and {second_val} on axis {self._second_axis}" + ) + first_values.append(first_val) + second_values.append(second_val) + return (tuple(first_values), tuple(second_values)) def change_val_type(self, axis_name, values): new_values = pd.to_datetime(values) diff --git a/polytope/datacube/transformations/datacube_reverse/reverse_axis_decorator.py b/polytope/datacube/transformations/datacube_reverse/reverse_axis_decorator.py new file mode 100644 index 000000000..f0a500843 --- /dev/null +++ b/polytope/datacube/transformations/datacube_reverse/reverse_axis_decorator.py @@ -0,0 +1,70 @@ +import bisect + +from .datacube_reverse import DatacubeAxisReverse + + +def reverse(cls): + if cls.reorder: + + def find_indexes(path, datacube): + # first, find the relevant transformation object that is a mapping in the cls.transformation dico + # NOTE here that we assume the subsequent axis indices will all be the same for the multiple values in the + # previous node tuples + for key in path: + path[key] = path[key][0] + subarray = datacube.dataarray.sel(path, method="nearest") + unordered_indices = datacube.datacube_natural_indexes(cls, subarray) + if cls.name in datacube.complete_axes: + ordered_indices = unordered_indices.sort_values() + else: + ordered_indices = unordered_indices + return ordered_indices + + def find_indices_between(index_ranges, low, up, datacube, method=None): + # TODO: add method for snappping + indexes_between_ranges = [] + for transform in cls.transformations: + if isinstance(transform, DatacubeAxisReverse): + transformation = transform + if cls.name == transformation.name: + for indexes in index_ranges: + if cls.name in datacube.complete_axes: + # Find the range of indexes between lower and upper + # https://pandas.pydata.org/docs/reference/api/pandas.Index.searchsorted.html + # 
Assumes the indexes are already sorted (could sort to be sure) and monotonically + # increasing + if method == "surrounding" or method == "nearest": + start = indexes.searchsorted(low, "left") + end = indexes.searchsorted(up, "right") + start = max(start - 1, 0) + end = min(end + 1, len(indexes)) + indexes_between = indexes[start:end].to_list() + indexes_between_ranges.append(indexes_between) + else: + start = indexes.searchsorted(low, "left") + end = indexes.searchsorted(up, "right") + indexes_between = indexes[start:end].to_list() + indexes_between_ranges.append(indexes_between) + else: + if method == "surrounding" or method == "nearest": + start = indexes.index(low) + end = indexes.index(up) + start = max(start - 1, 0) + end = min(end + 1, len(indexes)) + indexes_between = indexes[start:end] + indexes_between_ranges.append(indexes_between) + else: + lower_idx = bisect.bisect_left(indexes, low) + upper_idx = bisect.bisect_right(indexes, up) + indexes_between = indexes[lower_idx:upper_idx] + indexes_between_ranges.append(indexes_between) + return indexes_between_ranges + + def remap(range): + return [range] + + cls.remap = remap + cls.find_indexes = find_indexes + cls.find_indices_between = find_indices_between + + return cls diff --git a/polytope/datacube/transformations/datacube_type_change/datacube_type_change.py b/polytope/datacube/transformations/datacube_type_change/datacube_type_change.py index 137ed8a40..c0c25cdd6 100644 --- a/polytope/datacube/transformations/datacube_type_change/datacube_type_change.py +++ b/polytope/datacube/transformations/datacube_type_change/datacube_type_change.py @@ -58,7 +58,10 @@ def transform_type(self, value): return int(value) def make_str(self, value): - return str(value) + values = [] + for val in value: + values.append(str(val)) + return tuple(values) _type_to_datacube_type_change_lookup = {"int": "TypeChangeStrToInt"} diff --git a/polytope/datacube/tree_encoding.py b/polytope/datacube/tree_encoding.py new file mode 100644 
index 000000000..5324a041a --- /dev/null +++ b/polytope/datacube/tree_encoding.py @@ -0,0 +1,162 @@ +import numpy as np +import pandas as pd + +from . import index_tree_pb2 as pb2 +from .datacube_axis import IntDatacubeAxis +from .tensor_index_tree import TensorIndexTree +from copy import deepcopy + + +def encode_tree(tree: TensorIndexTree): + node = pb2.Node() + + node.axis = tree.axis.name + + # NOTE: do we need this if we parse the tree before it has values? + if tree.result is not None: + for result in tree.result: + node.result.append(result) + + # Assign the node value according to the type + # Argueably, do not need to do this since we will only encode from the root node... + # if isinstance(tree.value[0], int): + # for i, tree_val in enumerate(tree.value): + # node.value[i].int_val = tree_val + # if isinstance(tree.value[0], float): + # for i, tree_val in enumerate(tree.value): + # node.value[i].double_val = tree_val + # if isinstance(tree.value[0], str): + # for i, tree_val in enumerate(tree.value): + # node.value[i].str_val = tree_val + # if isinstance(tree.value[0], pd.Timestamp): + # for i, tree_val in enumerate(tree.value): + # node.value[i].str_val = tree_val.strftime("%Y/%m/%dT%H:%M:%S") + # if isinstance(tree.value[0], np.datetime64): + # for i, tree_val in enumerate(tree.value): + # node.value[i].str_val = pd.to_datetime(str(tree_val)).strftime("%Y/%m/%dT%H:%M:%S") + # if isinstance(tree.value[0], np.timedelta64): + # for i, tree_val in enumerate(tree.value): + # node.value[i].str_val = str(tree_val) + + # Nest children in protobuf root tree node + for c in tree.children: + encode_child(tree, c, node) + + # Write to file + + # TODO: JUST RETURN A BYTES HERE + return node.SerializeToString() + # import time + # time1 = time.time() + # with open("./serializedTree", "wb") as fd: + # fd.write(node.SerializeToString()) + # print("TIME NOW") + # print(time.time() - time1) + + +# TODO: complete the type mappings to the right value protobuf attribute and 
use as a factory? +# type_mappings = {int: "int_val", +# str: "str_val", +# float: "double_val"} + + +def encode_child(tree: TensorIndexTree, child: TensorIndexTree, node, result_size=[]): + child_node = pb2.Node() + + child_node.axis = child.axis.name + # result_size.append(len(child.values)) + + # Add the result size to the final node + # TODO: how to assign repeated fields more efficiently? + # NOTE: this will only really be efficient when we compress and have less leaves + if len(child.children) == 0: + # child_node.result_size.extend(result_size) + result_size.append(len(child.values)) + child_node.result_size.extend(result_size) + # NOTE: do we need this if we parse the tree before it has values? + # TODO: not clear if child.value is a numpy array or a simple float... + # TODO: not clear what happens if child.value is a np array since this is not a supported type by protobuf + if child.result is not None: + if isinstance(child.result, list): + child_node.result.extend(child.result) + # for result in child.result: + # child_node.result.append(result) + else: + child_node.result.append(child.result) + + # Assign the node value according to the type + if isinstance(child.values[0], int): + for i, child_val in enumerate(child.values): + child_node_val = pb2.Value() + child_node_val.int_val = child_val + child_node.value.append(child_node_val) + if isinstance(child.values[0], float): + for i, child_val in enumerate(child.values): + child_node_val = pb2.Value() + child_node_val.double_val = child_val + child_node.value.append(child_node_val) + if isinstance(child.values[0], str): + for i, child_val in enumerate(child.values): + child_node_val = pb2.Value() + child_node_val.str_val = child_val + child_node.value.append(child_node_val) + if isinstance(child.values[0], pd.Timestamp): + for i, child_val in enumerate(child.values): + child_node_val = pb2.Value() + child_node_val.str_val = child_val.strftime("%Y%m%dT%H%M%S") + child_node.value.append(child_node_val) + if 
isinstance(child.values[0], np.datetime64): + for i, child_val in enumerate(child.values): + child_node_val = pb2.Value() + child_node_val.str_val = pd.to_datetime(str(child_val)).strftime("%Y/%m/%dT%H:%M:%S") + child_node.value.append(child_node_val) + if isinstance(child.values[0], np.timedelta64): + for i, child_val in enumerate(child.values): + child_node_val = pb2.Value() + child_node_val.str_val = str(child_val) + child_node.value.append(child_node_val) + + for c in child.children: + # result_size.append(len(child.values)) + new_result_size = deepcopy(result_size) + new_result_size.append(len(child.values)) + encode_child(child, c, child_node, new_result_size) + + # NOTE: we append the children once their branch has been completed until the leaf + node.children.append(child_node) + + +def decode_tree(datacube, bytearray): + node = pb2.Node() + node.ParseFromString(bytearray) + # with open("./serializedTree", "rb") as f: + # node.ParseFromString(f.read()) + + tree = TensorIndexTree() + + if node.axis == "root": + root = IntDatacubeAxis() + root.name = "root" + tree.axis = root + else: + tree.axis = datacube._axes[node.axis] + + # Put contents of node children into tree + decode_child(node, tree, datacube) + + return tree + + +def decode_child(node, tree, datacube): + if len(node.children) == 0: + tree.result = node.result + tree.result_size = node.result_size + for child in node.children: + child_axis = datacube._axes[child.axis] + child_vals = [] + for child_val in child.value: + child_vals.append(getattr(child_val, child_val.WhichOneof("value"))) + child_vals = tuple(child_vals) + child_node = TensorIndexTree(child_axis, child_vals) + tree.add_child(child_node) + decode_child(child, child_node, datacube) diff --git a/polytope/datacube/tree_encoding_capnp.py b/polytope/datacube/tree_encoding_capnp.py new file mode 100644 index 000000000..9ad33075b --- /dev/null +++ b/polytope/datacube/tree_encoding_capnp.py @@ -0,0 +1,155 @@ +import capnp +import numpy as np 
+import pandas as pd + +from .datacube_axis import IntDatacubeAxis +from .tensor_index_tree import TensorIndexTree +from copy import deepcopy + +tree_obj = capnp.load('indexTree.capnp') + + +def encode_tree(tree: TensorIndexTree): + node = tree_obj.Node.new_message() + + node.axis = tree.axis.name + + # NOTE: do we need this if we parse the tree before it has values? + if tree.result is not None: + for result in tree.result: + node.result.append(result) + + # Nest children in protobuf root tree node + children = node.init("children", int(len(tree.children))) + for i, c in enumerate(tree.children): + encode_child(tree, c, i, node, children) + + # Write to file + # node.write("./serializedTree") + with open("./serializedTree", "wb") as fd: + fd.write(node.to_bytes()) + + +def encode_child(tree: TensorIndexTree, child: TensorIndexTree, i, node, children, result_size=[]): + child_node = tree_obj.Node.new_message() + values = child_node.init('value', int(len(child.values))) + + child_node.axis = child.axis.name + # result_size.append(len(child.values)) + + # Add the result size to the final node + # TODO: how to assign repeated fields more efficiently? + # NOTE: this will only really be efficient when we compress and have less leaves + if len(child.children) == 0: + # child_node.resultSize.extend(result_size) + result_size.append(len(child.values)) + child_node.resultSize = result_size + + # NOTE: do we need this if we parse the tree before it has values? + # TODO: not clear if child.value is a numpy array or a simple float... 
+ # TODO: not clear what happens if child.value is a np array since this is not a supported type by protobuf + if child.result is not None: + if isinstance(child.result, list): + # for result in child.result: + # child_node.result.append(result) + child_node.result = child.result + else: + # child_node.result.append(child.result) + child_node.result = [float(child.result)] + + # Assign the node value according to the type + if isinstance(child.values[0], int): + for j, child_val in enumerate(child.values): + child_node_val = tree_obj.Value.new_message() + child_node_val.value.intVal = child_val + # child_node.value.append(child_node_val) + values[j] = child_node_val + if isinstance(child.values[0], float): + for j, child_val in enumerate(child.values): + child_node_val = tree_obj.Value.new_message() + child_node_val.value.doubleVal = child_val + # child_node.value.append(child_node_val) + values[j] = child_node_val + if isinstance(child.values[0], str): + for j, child_val in enumerate(child.values): + child_node_val = tree_obj.Value.new_message() + child_node_val.value.strVal = child_val + # child_node.value.append(child_node_val) + values[j] = child_node_val + if isinstance(child.values[0], pd.Timestamp): + for j, child_val in enumerate(child.values): + child_node_val = tree_obj.Value.new_message() + child_node_val.value.strVal = child_val.strftime("%Y%m%dT%H%M%S") + # child_node.value.append(child_node_val) + values[j] = child_node_val + if isinstance(child.values[0], np.datetime64): + for j, child_val in enumerate(child.values): + child_node_val = tree_obj.Value.new_message() + child_node_val.value.strVal = pd.to_datetime(str(child_val)).strftime("%Y/%m/%dT%H:%M:%S") + # child_node.value.append(child_node_val) + values[j] = child_node_val + if isinstance(child.values[0], np.timedelta64): + for j, child_val in enumerate(child.values): + child_node_val = tree_obj.Value.new_message() + child_node_val.value.strVal = str(child_val) + # 
child_node.value.append(child_node_val) + values[j] = child_node_val + + child_children = child_node.init("children", int(len(child.children))) + for k, c in enumerate(child.children): + new_result_size = deepcopy(result_size) + new_result_size.append(len(child.values)) + encode_child(child, c, k, child_node, child_children, new_result_size) + + # NOTE: we append the children once their branch has been completed until the leaf + # node.children.append(child_node) + children[i] = child_node + + +def decode_tree(datacube): + # node = tree_obj.Node.read("./serializedTree") + with open("./serializedTree", "rb") as f: + node = tree_obj.Node.read(f) + # node_bytes = f.read() + # node_reader = tree_obj.Node.from_bytes(node_bytes) + # # node = tree_obj.Node.read(node_reader) + # node = node_reader.get_root(tree_obj.Node) + + tree = TensorIndexTree() + + if node.axis == "root": + root = IntDatacubeAxis() + root.name = "root" + tree.axis = root + else: + tree.axis = datacube._axes[node.axis] + + # Put contents of node children into tree + decode_child(node, tree, datacube) + + return tree + + +def decode_child(node, tree, datacube): + if len(node.children) == 0: + tree.result = node.result + tree.result_size = node.resultSize + for child in node.children: + # print("NOW") + # print(child) + child_axis = datacube._axes[child.axis] + child_vals = [] + for child_val in child.value: + which = child_val.value.which() + if which == "strVal": + new_child_val = child_val.value.strVal + if which == "intVal": + new_child_val = child_val.value.intVal + if which == "doubleVal": + new_child_val = child_val.value.doubleVal + # child_vals.append(getattr(child_val, child_val.WhichOneof("value"))) + child_vals.append(new_child_val) + child_vals = tuple(child_vals) + child_node = TensorIndexTree(child_axis, child_vals) + tree.add_child(child_node) + decode_child(child, child_node, datacube) diff --git a/polytope/engine/hullslicer.py b/polytope/engine/hullslicer.py index 8e14eeae2..70fdd18e2 
100644 --- a/polytope/engine/hullslicer.py +++ b/polytope/engine/hullslicer.py @@ -5,8 +5,9 @@ import scipy.spatial -from ..datacube.backends.datacube import Datacube, IndexTree +from ..datacube.backends.datacube import Datacube from ..datacube.datacube_axis import UnsliceableDatacubeAxis +from ..datacube.tensor_index_tree import TensorIndexTree from ..shapes import ConvexPolytope from ..utility.combinatorics import argmax, argmin, group, tensor_product, unique from ..utility.exceptions import UnsliceableShapeError @@ -39,6 +40,7 @@ def _build_unsliceable_child(self, polytope, ax, node, datacube, lower, next_nod raise UnsliceableShapeError(ax) path = node.flatten() + # all unsliceable children are natively 1D so can group them together in a tuple... flattened_tuple = tuple() if len(datacube.coupled_axes) > 0: if path.get(datacube.coupled_axes[0][0], None) is not None: @@ -52,7 +54,7 @@ def _build_unsliceable_child(self, polytope, ax, node, datacube, lower, next_nod datacube_has_index = self.axis_values_between[(flattened_tuple, ax.name, lower)] if datacube_has_index: - child = node.create_child(ax, lower) + (child, next_nodes) = node.create_child(ax, lower, datacube.compressed_grid_axes, next_nodes) child["unsliced_polytopes"] = copy(node["unsliced_polytopes"]) child["unsliced_polytopes"].remove(polytope) next_nodes.append(child) @@ -73,8 +75,8 @@ def _build_sliceable_child(self, polytope, ax, node, datacube, lower, upper, nex if method == "nearest": datacube.nearest_search[ax.name] = polytope.points - # TODO: this hashing doesn't work because we need to know the latitude val for finding longitude values - # TODO: Maybe create a coupled_axes list inside of datacube and add to it during axis formation, then here + # NOTE: caching + # Create a coupled_axes list inside of datacube and add to it during axis formation, then here # do something like if ax is in second place of coupled_axes, then take the flattened part of the array that # corresponds to the first place 
of cooupled_axes in the hashing # Else, if we do not need the flattened bit in the hash, can just put an empty string instead? @@ -95,30 +97,79 @@ def _build_sliceable_child(self, polytope, ax, node, datacube, lower, upper, nex if len(values) == 0: node.remove_branch() - for value in values: - # convert to float for slicing - fvalue = ax.to_float(value) - new_polytope = self.sliced_polytopes.get((polytope, ax.name, fvalue, slice_axis_idx), False) - if new_polytope is False: - new_polytope = slice(polytope, ax.name, fvalue, slice_axis_idx) - self.sliced_polytopes[(polytope, ax.name, fvalue, slice_axis_idx)] = new_polytope - - # store the native type - remapped_val = self.remapped_vals.get((value, ax.name), None) - if remapped_val is None: - remapped_val = value - if ax.is_cyclic: - remapped_val_interm = ax.remap([value, value])[0] - remapped_val = (remapped_val_interm[0] + remapped_val_interm[1]) / 2 - remapped_val = round(remapped_val, int(-math.log10(ax.tol))) - self.remapped_vals[(value, ax.name)] = remapped_val - - child = node.create_child(ax, remapped_val) - child["unsliced_polytopes"] = copy(node["unsliced_polytopes"]) - child["unsliced_polytopes"].remove(polytope) - if new_polytope is not None: - child["unsliced_polytopes"].add(new_polytope) - next_nodes.append(child) + # # check whether polytope is 1D and that the axis is not a coupled axis + # # read from the datacube which grid axes can be compressed... + # if ax.name not in datacube.compressed_grid_axes: + # ax_in_forbidden_axes = not any(ax.name in sublist for sublist in datacube.coupled_axes) + # else: + # ax_in_forbidden_axes = True + + # TODO: find which axes can be compressed here... 
+ # compressed_axes = datacube.compressed_grid_axes + compressed_axes = [] + if polytope.is_natively_1D: + compressed_axes.extend(polytope.axes()) + # if polytope.method is not None: + # compressed_axes.extend(polytope.axes()) + + # if polytope.is_natively_1D and ax_in_forbidden_axes: + # # TODO: instead of checking here whether an axis/indices can be compressed and doing a for loop, + # # do this logic of recursively adding children to the tensor index tree, so do this inside of create_child + # all_remapped_vals = [] + # for value in values: + # fvalue = ax.to_float(value) + # remapped_val = self.remapped_vals.get((value, ax.name), None) + # if remapped_val is None: + # remapped_val = value + # if ax.is_cyclic: + # remapped_val_interm = ax.remap([value, value])[0] + # remapped_val = (remapped_val_interm[0] + remapped_val_interm[1]) / 2 + # remapped_val = round(remapped_val, int(-math.log10(ax.tol))) + # self.remapped_vals[(value, ax.name)] = remapped_val + # all_remapped_vals.append(remapped_val) + # # NOTE we remove unnecessary empty branches here too + # if len(tuple(all_remapped_vals)) == 0: + # node.remove_branch() + # else: + # child = node.create_child(ax, tuple(all_remapped_vals)) + # # TODO: here, we will now recursively add values to the tuple inside the created child, and we will + # only need to assign the unsliced polytopes of the child at the end? 
+ # child["unsliced_polytopes"] = copy(node["unsliced_polytopes"]) + # child["unsliced_polytopes"].remove(polytope) + # next_nodes.append(child) + # else: + + # TODO: here add the children that are required now to the tree + + if True: + for value in values: + # convert to float for slicing + fvalue = ax.to_float(value) + new_polytope = self.sliced_polytopes.get((polytope, ax.name, fvalue, slice_axis_idx), False) + if new_polytope is False: + new_polytope = slice(polytope, ax.name, fvalue, slice_axis_idx) + self.sliced_polytopes[(polytope, ax.name, fvalue, slice_axis_idx)] = new_polytope + # store the native type + remapped_val = self.remapped_vals.get((value, ax.name), None) + if remapped_val is None: + remapped_val = value + if ax.is_cyclic: + remapped_val_interm = ax.remap([value, value])[0] + remapped_val = (remapped_val_interm[0] + remapped_val_interm[1]) / 2 + if ax.can_round: + remapped_val = round(remapped_val, int(-math.log10(ax.tol))) + self.remapped_vals[(value, ax.name)] = remapped_val + + # NOTE we remove unnecessary empty branches here too + if len(tuple([remapped_val])) == 0: + node.remove_branch() + else: + (child, next_nodes) = node.create_child(ax, remapped_val, compressed_axes, next_nodes) + child["unsliced_polytopes"] = copy(node["unsliced_polytopes"]) + child["unsliced_polytopes"].remove(polytope) + if new_polytope is not None: + child["unsliced_polytopes"].add(new_polytope) + next_nodes.append(child) def _build_branch(self, ax, node, datacube, next_nodes): for polytope in node["unsliced_polytopes"]: @@ -141,7 +192,7 @@ def extract(self, datacube: Datacube, polytopes: List[ConvexPolytope]): groups, input_axes = group(polytopes) datacube.validate(input_axes) - request = IndexTree() + request = TensorIndexTree() combinations = tensor_product(groups) # NOTE: could optimise here if we know combinations will always be for one request. 
@@ -149,35 +200,18 @@ def extract(self, datacube: Datacube, polytopes: List[ConvexPolytope]): # directly work on request and return it... for c in combinations: - cached_node = None - repeated_sub_nodes = [] - - r = IndexTree() + r = TensorIndexTree() r["unsliced_polytopes"] = set(c) current_nodes = [r] for ax in datacube.axes.values(): next_nodes = [] + interm_next_nodes = [] for node in current_nodes: - # detect if node is for number == 1 - # store a reference to that node - # skip processing the other 49 numbers - # at the end, copy that initial reference 49 times and add to request with correct number - - stored_val = None - if node.axis.name == datacube.axis_with_identical_structure_after: - stored_val = node.value - cached_node = node - elif node.axis.name == datacube.axis_with_identical_structure_after and node.value != stored_val: - repeated_sub_nodes.append(node) - del node["unsliced_polytopes"] - continue - - self._build_branch(ax, node, datacube, next_nodes) + self._build_branch(ax, node, datacube, interm_next_nodes) + next_nodes.extend(interm_next_nodes) + interm_next_nodes = [] current_nodes = next_nodes - for n in repeated_sub_nodes: - n.copy_children_from_other(cached_node) - request.merge(r) return request diff --git a/polytope/shapes.py b/polytope/shapes.py index 0c170b3fc..384a3517e 100644 --- a/polytope/shapes.py +++ b/polytope/shapes.py @@ -23,17 +23,18 @@ def axes(self) -> List[str]: class ConvexPolytope(Shape): - def __init__(self, axes, points, method=None): + def __init__(self, axes, points, method=None, is_1D=False): self._axes = list(axes) self.is_flat = False if len(self._axes) == 1: self.is_flat = True self.points = points self.method = method + self.is_natively_1D = is_1D def extents(self, axis): if self.is_flat: - slice_axis_idx = 1 + slice_axis_idx = 0 lower = min(self.points)[0] upper = max(self.points)[0] else: @@ -44,7 +45,7 @@ def extents(self, axis): return (lower, upper, slice_axis_idx) def __str__(self): - return f"Polytope in 
{self.axes} with points {self.points}" + return f"Polytope in {self.axes()} with points {self.points}" def axes(self): return self._axes @@ -66,7 +67,7 @@ def axes(self): return [self.axis] def polytope(self): - return [ConvexPolytope([self.axis], [[v]], self.method) for v in self.values] + return [ConvexPolytope([self.axis], [[v]], self.method, is_1D=True) for v in self.values] def __repr__(self): return f"Select in {self.axis} with points {self.values}" @@ -110,7 +111,7 @@ def axes(self): return [self.axis] def polytope(self): - return [ConvexPolytope([self.axis], [[self.lower], [self.upper]])] + return [ConvexPolytope([self.axis], [[self.lower], [self.upper]], is_1D=True)] def __repr__(self): return f"Span in {self.axis} with range from {self.lower} to {self.upper}" diff --git a/polytope/utility/geometry.py b/polytope/utility/geometry.py index 2906fed19..fd5936615 100644 --- a/polytope/utility/geometry.py +++ b/polytope/utility/geometry.py @@ -7,9 +7,14 @@ def lerp(a, b, value): def nearest_pt(pts_list, pt): - nearest_pt = pts_list[0] - distance = l2_norm(pts_list[0], pt) - for new_pt in pts_list[1:]: + new_pts_list = [] + for potential_pt in pts_list: + for first_val in potential_pt[0]: + for second_val in potential_pt[1]: + new_pts_list.append((first_val, second_val)) + nearest_pt = new_pts_list[0] + distance = l2_norm(new_pts_list[0], pt) + for new_pt in new_pts_list[1:]: new_distance = l2_norm(new_pt, pt) if new_distance < distance: distance = new_distance diff --git a/requirements.txt b/requirements.txt index afbacd4ca..f75807725 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ scipy sortedcontainers tripy xarray +protobuf diff --git a/tests/test_cyclic_axis_over_negative_vals.py b/tests/test_cyclic_axis_over_negative_vals.py index c11bcb098..1b893a6cc 100644 --- a/tests/test_cyclic_axis_over_negative_vals.py +++ b/tests/test_cyclic_axis_over_negative_vals.py @@ -36,27 +36,27 @@ def test_cyclic_float_axis_across_seam(self): result = 
self.API.retrieve(request) result.pprint() assert len(result.leaves) == 20 - assert [leaf.value for leaf in result.leaves] == [ - -1.1, - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, - -0.2, - -1.1, - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, - -0.2, + assert [leaf.values for leaf in result.leaves] == [ + (-1.1,), + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.2,), + (-1.1,), + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.2,), ] def test_cyclic_float_axis_inside_cyclic_range(self): @@ -66,23 +66,23 @@ def test_cyclic_float_axis_inside_cyclic_range(self): result = self.API.retrieve(request) # result.pprint() assert len(result.leaves) == 16 - assert [leaf.value for leaf in result.leaves] == [ - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, + assert [leaf.values for leaf in result.leaves] == [ + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), ] def test_cyclic_float_axis_above_axis_range(self): @@ -92,7 +92,18 @@ def test_cyclic_float_axis_above_axis_range(self): result = self.API.retrieve(request) # result.pprint() assert len(result.leaves) == 10 - assert [leaf.value for leaf in result.leaves] == [-0.7, -0.6, -0.5, -0.4, -0.3, -0.7, -0.6, -0.5, -0.4, -0.3] + assert [leaf.values for leaf in result.leaves] == [ + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + ] def test_cyclic_float_axis_two_range_loops(self): request = Request( @@ -101,27 +112,27 @@ def test_cyclic_float_axis_two_range_loops(self): result = self.API.retrieve(request) result.pprint() assert len(result.leaves) == 20 - assert [leaf.value for leaf in result.leaves] == [ - -1.1, - -1.0, - -0.9, - -0.8, - -0.7, - 
-0.6, - -0.5, - -0.4, - -0.3, - -0.2, - -1.1, - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, - -0.2, + assert [leaf.values for leaf in result.leaves] == [ + (-1.1,), + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.2,), + (-1.1,), + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.2,), ] def test_cyclic_float_axis_below_axis_range(self): @@ -130,7 +141,18 @@ def test_cyclic_float_axis_below_axis_range(self): ) result = self.API.retrieve(request) assert len(result.leaves) == 10 - assert [leaf.value for leaf in result.leaves] == [-0.7, -0.6, -0.5, -0.4, -0.3, -0.7, -0.6, -0.5, -0.4, -0.3] + assert [leaf.values for leaf in result.leaves] == [ + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + ] def test_cyclic_float_axis_below_axis_range_crossing_seam(self): request = Request( @@ -139,25 +161,25 @@ def test_cyclic_float_axis_below_axis_range_crossing_seam(self): result = self.API.retrieve(request) # result.pprint() assert len(result.leaves) == 20 - assert [leaf.value for leaf in result.leaves] == [ - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, - -0.2, - -0.1, - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, - -0.2, - -0.1, + assert [leaf.values for leaf in result.leaves] == [ + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.2,), + (-0.1,), + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.2,), + (-0.1,), ] diff --git a/tests/test_cyclic_axis_slicer_not_0.py b/tests/test_cyclic_axis_slicer_not_0.py index 526473beb..3034deade 100644 --- a/tests/test_cyclic_axis_slicer_not_0.py +++ b/tests/test_cyclic_axis_slicer_not_0.py @@ -36,27 +36,27 @@ def test_cyclic_float_axis_across_seam(self): result = self.API.retrieve(request) result.pprint() assert len(result.leaves) == 20 - assert 
[leaf.value for leaf in result.leaves] == [ - -1.1, - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, - -0.2, - -1.1, - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, - -0.2, + assert [leaf.values for leaf in result.leaves] == [ + (-1.1,), + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.2,), + (-1.1,), + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.2,), ] def test_cyclic_float_axis_inside_cyclic_range(self): @@ -65,23 +65,23 @@ def test_cyclic_float_axis_inside_cyclic_range(self): ) result = self.API.retrieve(request) assert len(result.leaves) == 16 - assert [leaf.value for leaf in result.leaves] == [ - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, + assert [leaf.values for leaf in result.leaves] == [ + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), ] def test_cyclic_float_axis_above_axis_range(self): @@ -90,7 +90,18 @@ def test_cyclic_float_axis_above_axis_range(self): ) result = self.API.retrieve(request) assert len(result.leaves) == 10 - assert [leaf.value for leaf in result.leaves] == [-0.7, -0.6, -0.5, -0.4, -0.3, -0.7, -0.6, -0.5, -0.4, -0.3] + assert [leaf.values for leaf in result.leaves] == [ + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + ] def test_cyclic_float_axis_two_range_loops(self): request = Request( @@ -98,27 +109,27 @@ def test_cyclic_float_axis_two_range_loops(self): ) result = self.API.retrieve(request) assert len(result.leaves) == 20 - assert [leaf.value for leaf in result.leaves] == [ - -1.1, - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, - -0.2, - -1.1, - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, - -0.2, + assert 
[leaf.values for leaf in result.leaves] == [ + (-1.1,), + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.2,), + (-1.1,), + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.2,), ] def test_cyclic_float_axis_below_axis_range(self): @@ -127,7 +138,18 @@ def test_cyclic_float_axis_below_axis_range(self): ) result = self.API.retrieve(request) assert len(result.leaves) == 10 - assert [leaf.value for leaf in result.leaves] == [-0.7, -0.6, -0.5, -0.4, -0.3, -0.7, -0.6, -0.5, -0.4, -0.3] + assert [leaf.values for leaf in result.leaves] == [ + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + ] def test_cyclic_float_axis_below_axis_range_crossing_seam(self): request = Request( @@ -135,25 +157,25 @@ def test_cyclic_float_axis_below_axis_range_crossing_seam(self): ) result = self.API.retrieve(request) assert len(result.leaves) == 20 - assert [leaf.value for leaf in result.leaves] == [ - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, - -0.2, - -0.1, - -1.0, - -0.9, - -0.8, - -0.7, - -0.6, - -0.5, - -0.4, - -0.3, - -0.2, - -0.1, + assert [leaf.values for leaf in result.leaves] == [ + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.2,), + (-0.1,), + (-1.0,), + (-0.9,), + (-0.8,), + (-0.7,), + (-0.6,), + (-0.5,), + (-0.4,), + (-0.3,), + (-0.2,), + (-0.1,), ] diff --git a/tests/test_cyclic_axis_slicing.py b/tests/test_cyclic_axis_slicing.py index 5a49be0b1..ae9ade1d9 100644 --- a/tests/test_cyclic_axis_slicing.py +++ b/tests/test_cyclic_axis_slicing.py @@ -36,27 +36,27 @@ def test_cyclic_float_axis_across_seam(self): result = self.API.retrieve(request) assert len(result.leaves) == 20 result.pprint() - assert [leaf.value for leaf in result.leaves] == [ - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 0.9, - 1.0, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 
0.9, - 1.0, + assert [leaf.values for leaf in result.leaves] == [ + (0.1,), + (0.2,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.8,), + (0.9,), + (1.0,), + (0.1,), + (0.2,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.8,), + (0.9,), + (1.0,), ] def test_cyclic_float_axis_across_seam_repeated(self): @@ -66,29 +66,29 @@ def test_cyclic_float_axis_across_seam_repeated(self): result = self.API.retrieve(request) # result.pprint() assert len(result.leaves) == 22 - assert [leaf.value for leaf in result.leaves] == [ - 0.0, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 0.9, - 1.0, - 0.0, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 0.9, - 1.0, + assert [leaf.values for leaf in result.leaves] == [ + (0.0,), + (0.1,), + (0.2,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.8,), + (0.9,), + (1.0,), + (0.0,), + (0.1,), + (0.2,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.8,), + (0.9,), + (1.0,), ] def test_cyclic_float_axis_across_seam_repeated_twice(self): @@ -98,29 +98,29 @@ def test_cyclic_float_axis_across_seam_repeated_twice(self): result = self.API.retrieve(request) result.pprint() assert len(result.leaves) == 22 - assert [leaf.value for leaf in result.leaves] == [ - 0.0, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 0.9, - 1.0, - 0.0, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 0.9, - 1.0, + assert [leaf.values for leaf in result.leaves] == [ + (0.0,), + (0.1,), + (0.2,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.8,), + (0.9,), + (1.0,), + (0.0,), + (0.1,), + (0.2,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.8,), + (0.9,), + (1.0,), ] def test_cyclic_float_axis_inside_cyclic_range(self): @@ -130,23 +130,23 @@ def test_cyclic_float_axis_inside_cyclic_range(self): result = self.API.retrieve(request) # result.pprint() assert len(result.leaves) == 16 - assert [leaf.value for leaf in result.leaves] == [ - 0.0, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 
0.7, - 0.0, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, + assert [leaf.values for leaf in result.leaves] == [ + (0.0,), + (0.1,), + (0.2,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.0,), + (0.1,), + (0.2,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), ] def test_cyclic_float_axis_above_axis_range(self): @@ -158,7 +158,18 @@ def test_cyclic_float_axis_above_axis_range(self): result = self.API.retrieve(request) # result.pprint() assert len(result.leaves) == 10 - assert [leaf.value for leaf in result.leaves] == [0.3, 0.4, 0.5, 0.6, 0.7, 0.3, 0.4, 0.5, 0.6, 0.7] + assert [leaf.values for leaf in result.leaves] == [ + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + ] def test_cyclic_float_axis_two_range_loops(self): request = Request( @@ -167,29 +178,29 @@ def test_cyclic_float_axis_two_range_loops(self): result = self.API.retrieve(request) # result.pprint() assert len(result.leaves) == 22 - assert [leaf.value for leaf in result.leaves] == [ - 0.0, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 0.9, - 1.0, - 0.0, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 0.9, - 1.0, + assert [leaf.values for leaf in result.leaves] == [ + (0.0,), + (0.1,), + (0.2,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.8,), + (0.9,), + (1.0,), + (0.0,), + (0.1,), + (0.2,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.8,), + (0.9,), + (1.0,), ] def test_cyclic_float_axis_below_axis_range(self): @@ -199,7 +210,18 @@ def test_cyclic_float_axis_below_axis_range(self): result = self.API.retrieve(request) # result.pprint() assert len(result.leaves) == 10 - assert [leaf.value for leaf in result.leaves] == [0.3, 0.4, 0.5, 0.6, 0.7, 0.3, 0.4, 0.5, 0.6, 0.7] + assert [leaf.values for leaf in result.leaves] == [ + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + ] def test_cyclic_float_axis_below_axis_range_crossing_seam(self): request = 
Request( @@ -208,27 +230,27 @@ def test_cyclic_float_axis_below_axis_range_crossing_seam(self): result = self.API.retrieve(request) # result.pprint() assert len(result.leaves) == 20 - assert [leaf.value for leaf in result.leaves] == [ - 0.0, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 0.9, - 0.0, - 0.1, - 0.2, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.8, - 0.9, + assert [leaf.values for leaf in result.leaves] == [ + (0.0,), + (0.1,), + (0.2,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.8,), + (0.9,), + (0.0,), + (0.1,), + (0.2,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.8,), + (0.9,), ] def test_cyclic_float_axis_reversed(self): @@ -238,34 +260,45 @@ def test_cyclic_float_axis_reversed(self): result = self.API.retrieve(request) # result.pprint() assert len(result.leaves) == 10 - assert [leaf.value for leaf in result.leaves] == [0.3, 0.4, 0.5, 0.6, 0.7, 0.3, 0.4, 0.5, 0.6, 0.7] + assert [leaf.values for leaf in result.leaves] == [ + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + ] def test_two_cyclic_axis_wrong_axis_order(self): request = Request(Box(["step", "long", "level"], [0, 1.3, 131], [3, 1.7, 132]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) result.pprint() assert len(result.leaves) == 20 - assert [leaf.value for leaf in result.leaves] == [ - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, + assert [leaf.values for leaf in result.leaves] == [ + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), ] def test_two_cyclic_axis(self): @@ -273,27 +306,27 @@ def test_two_cyclic_axis(self): result = self.API.retrieve(request) # result.pprint() assert len(result.leaves) == 20 - assert [leaf.value for leaf in result.leaves] 
== [ - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, - 0.3, - 0.4, - 0.5, - 0.6, - 0.7, + assert [leaf.values for leaf in result.leaves] == [ + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), + (0.3,), + (0.4,), + (0.5,), + (0.6,), + (0.7,), ] def test_select_cyclic_float_axis_edge(self): @@ -301,11 +334,11 @@ def test_select_cyclic_float_axis_edge(self): result = self.API.retrieve(request) # result.pprint() assert len(result.leaves) == 6 - assert [leaf.value for leaf in result.leaves] == [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] + assert [leaf.values for leaf in result.leaves] == [(0.0,), (0.0,), (0.0,), (0.0,), (0.0,), (0.0,)] def test_cyclic_int_axis(self): request = Request(Box(["step", "level"], [0, 3], [3, 5]), Select("date", ["2000-01-01"]), Select("long", [0.1])) result = self.API.retrieve(request) # result.pprint() assert len(result.leaves) == 6 - assert [leaf.value for leaf in result.leaves] == [0.1, 0.1, 0.1, 0.1, 0.1, 0.1] + assert [leaf.values for leaf in result.leaves] == [(0.1,), (0.1,), (0.1,), (0.1,), (0.1,), (0.1,)] diff --git a/tests/test_cyclic_nearest.py b/tests/test_cyclic_nearest.py index 2dcf6aabd..079c92a2d 100644 --- a/tests/test_cyclic_nearest.py +++ b/tests/test_cyclic_nearest.py @@ -76,8 +76,8 @@ def test_regular_grid(self): ) result = self.API.retrieve(request) longitude_val_1 = result.leaves[0].flatten()["longitude"] - result.pprint_2() - assert longitude_val_1 == 283.561643835616 + result.pprint() + assert longitude_val_1 == (283.561643835616,) request = Request( Select("step", [0]), @@ -93,5 +93,5 @@ def test_regular_grid(self): ) result = self.API.retrieve(request) longitude_val_1 = result.leaves[0].flatten()["longitude"] - result.pprint_2() - assert longitude_val_1 == 283.561643835616 + result.pprint() + assert longitude_val_1 == (283.561643835616,) diff --git 
a/tests/test_cyclic_simple.py b/tests/test_cyclic_simple.py index f900cac1b..d9efcd4a1 100644 --- a/tests/test_cyclic_simple.py +++ b/tests/test_cyclic_simple.py @@ -33,7 +33,7 @@ def test_cyclic_float_axis_across_seam(self): result = self.API.retrieve(request) result.pprint() assert len(result.leaves) == 4 - assert [leaf.value for leaf in result.leaves] == [0.1, 0.2, 0.9, 1.0] + assert [leaf.values for leaf in result.leaves] == [(0.1,), (0.2,), (0.9,), (1.0,)] def test_cyclic_float_surrounding(self): request = Request( @@ -47,7 +47,8 @@ def test_cyclic_float_surrounding(self): for leaf in result.leaves: path = leaf.flatten() lon_val = path["long"] - assert lon_val in [0.0, 0.1, 0.9, 1.0] + for val in lon_val: + assert val in [0.0, 0.1, 0.9, 1.0] def test_cyclic_float_surrounding_below_seam(self): request = Request( @@ -61,4 +62,5 @@ def test_cyclic_float_surrounding_below_seam(self): for leaf in result.leaves: path = leaf.flatten() lon_val = path["long"] - assert lon_val in [0.0, 0.1, 0.9, 1.0] + for val in lon_val: + assert val in [0.0, 0.1, 0.9, 1.0] diff --git a/tests/test_cyclic_snapping.py b/tests/test_cyclic_snapping.py index fa10fbbd3..3845d87ab 100644 --- a/tests/test_cyclic_snapping.py +++ b/tests/test_cyclic_snapping.py @@ -25,8 +25,8 @@ def test_cyclic_float_axis_across_seam(self): request = Request(Select("long", [-0.2], method="surrounding")) result = self.API.retrieve(request) result.pprint() - assert len(result.leaves) == 2 - assert result.leaves[0].flatten()["long"] == 0.0 - assert result.leaves[1].flatten()["long"] == 0.5 - assert result.leaves[0].result == (None, 0) - assert result.leaves[1].result == (None, 1) + assert len(result.leaves) == 1 + assert result.leaves[0].flatten()["long"] == (0.5, 0.0) + assert result.leaves[0].result[0] is None + assert result.leaves[0].result[1][0] == 1 + assert result.leaves[0].result[1][1] == 0 diff --git a/tests/test_datacube_axes_init.py b/tests/test_datacube_axes_init.py index bedaf8af7..b05c35ab9 100644 --- 
a/tests/test_datacube_axes_init.py +++ b/tests/test_datacube_axes_init.py @@ -37,7 +37,7 @@ def test_created_axes(self): 89.73614327160958, 89.6658939412157, ] - assert self.datacube._axes["longitude"].find_indexes({"latitude": 89.94618771566562}, self.datacube)[:8] == [ + assert self.datacube._axes["longitude"].find_indexes({"latitude": (89.94618771566562,)}, self.datacube)[:8] == [ 0.0, 18.0, 36.0, @@ -47,7 +47,9 @@ def test_created_axes(self): 108.0, 126.0, ] - assert len(self.datacube._axes["longitude"].find_indexes({"latitude": 89.94618771566562}, self.datacube)) == 20 + assert ( + len(self.datacube._axes["longitude"].find_indexes({"latitude": (89.94618771566562,)}, self.datacube)) == 20 + ) assert self.datacube._axes["latitude"].find_indexes({}, self.datacube)[:5] == [ 89.94618771566562, 89.87647835333229, @@ -55,7 +57,7 @@ def test_created_axes(self): 89.73614327160958, 89.6658939412157, ] - assert self.datacube._axes["longitude"].find_indexes({"latitude": 89.94618771566562}, self.datacube)[:8] == [ + assert self.datacube._axes["longitude"].find_indexes({"latitude": (89.94618771566562,)}, self.datacube)[:8] == [ 0.0, 18.0, 36.0, @@ -65,15 +67,19 @@ def test_created_axes(self): 108.0, 126.0, ] - assert len(self.datacube._axes["longitude"].find_indexes({"latitude": 89.94618771566562}, self.datacube)) == 20 + assert ( + len(self.datacube._axes["longitude"].find_indexes({"latitude": (89.94618771566562,)}, self.datacube)) == 20 + ) lon_ax = self.datacube._axes["longitude"] lat_ax = self.datacube._axes["latitude"] (path_key, path, unmapped_path) = lat_ax.unmap_path_key({"latitude": 89.94618771566562}, {}, {}) assert path == {} assert unmapped_path == {"latitude": 89.94618771566562} - (path_key, path, unmapped_path) = lon_ax.unmap_path_key({"longitude": 0.0}, {}, {"latitude": 89.94618771566562}) + (path_key, path, unmapped_path) = lon_ax.unmap_path_key( + {"longitude": (0.0,)}, {}, {"latitude": (89.94618771566562,)} + ) assert path == {} - assert unmapped_path == 
{"latitude": 89.94618771566562} + assert unmapped_path == {"latitude": (89.94618771566562,)} assert path_key == {"values": 0} assert lat_ax.find_indices_between([89.94618771566562, 89.87647835333229], 89.87, 90, self.datacube, 0) == [ 89.94618771566562, diff --git a/tests/test_datacube_xarray.py b/tests/test_datacube_xarray.py index 2074c7c8e..99827e5f5 100644 --- a/tests/test_datacube_xarray.py +++ b/tests/test_datacube_xarray.py @@ -1,3 +1,5 @@ +import datetime + import numpy as np import pandas as pd import pytest @@ -94,7 +96,7 @@ def test_create(self): assert len(idxs) == 0 # Tests on "step" axis, path is a sub-datacube at a specific date - partial_request = DatacubePath(date="2000-01-01") + partial_request["date"] = (datetime.datetime.strptime("2000-01-01", "%Y-%m-%d"),) # Check parsing a step correctly converts type to int assert type(datacube.get_mapper("step").parse(3)) == float @@ -107,18 +109,23 @@ def test_create(self): assert idxs == [0, 3, 6, 9] assert isinstance(idxs[0], int) + partial_request["date"] = (datetime.datetime.strptime("2000-01-01", "%Y-%m-%d"),) + # Check discretizing along 'step' axis at a specific step gives one value idxs = datacube.get_indices(partial_request, label, 3, 3) assert len(idxs) == 1 assert idxs[0] == 3 assert isinstance(idxs[0], int) + partial_request["date"] = (datetime.datetime.strptime("2000-01-01", "%Y-%m-%d"),) + # Check discretizing along 'step' axis at a step which does not exist in discrete space gives no values idxs = datacube.get_indices(partial_request, label, 4, 4) assert len(idxs) == 0 # Tests on "level" axis, path is a sub-datacube at a specific date/step - partial_request = DatacubePath(date="2000-01-01", step=3) + partial_request["date"] = (datetime.datetime.strptime("2000-01-01", "%Y-%m-%d"),) + partial_request["step"] = (3,) # Check parsing a level correctly converts type to int assert type(datacube.get_mapper("level").parse(3)) == float diff --git a/tests/test_fdb_datacube.py 
b/tests/test_fdb_datacube.py index 093750926..513a608e7 100644 --- a/tests/test_fdb_datacube.py +++ b/tests/test_fdb_datacube.py @@ -3,7 +3,7 @@ from polytope.engine.hullslicer import HullSlicer from polytope.polytope import Polytope, Request -from polytope.shapes import Box, Select +from polytope.shapes import Box, Select, Span # import geopandas as gpd # import matplotlib.pyplot as plt @@ -44,6 +44,10 @@ def test_fdb_datacube(self): result = self.API.retrieve(request) result.pprint() assert len(result.leaves) == 9 + assert result.leaves[1].flatten()["longitude"][0] == 0.070093457944 + assert result.leaves[4].flatten()["longitude"][0] == 0.070148090413 + assert result.leaves[7].flatten()["longitude"][0] == 0.070202808112 + assert result.leaves[0].result == 297.9250183105469 # lats = [] # lons = [] @@ -62,3 +66,23 @@ def test_fdb_datacube(self): # plt.scatter(lons, lats, s=16, c="red", cmap="YlOrRd") # plt.colorbar(label="Temperature") # plt.show() + + @pytest.mark.fdb + def test_fdb_datacube_select_grid(self): + request = Request( + Select("step", [0]), + Select("levtype", ["sfc"]), + Select("date", [pd.Timestamp("20230625T120000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["167"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["an"]), + Select("latitude", [0.035149384216]), + Span("longitude", 0, 0.070093457944), + ) + result = self.API.retrieve(request) + result.pprint() + assert len(result.leaves) == 1 + assert result.leaves[0].flatten()["longitude"] == (0.0, 0.070093457944) diff --git a/tests/test_float_type.py b/tests/test_float_type.py index 6593234b3..01906a5d5 100644 --- a/tests/test_float_type.py +++ b/tests/test_float_type.py @@ -26,7 +26,9 @@ def test_slicing_span(self): request = Request(Span("lat", 4.1, 4.3), Select("long", [4.1]), Select("alt", [4.1])) result = self.API.retrieve(request) result.pprint() - assert len(result.leaves) == 3 + assert len(result.leaves) == 1 + path = 
result.leaves[0].flatten() + assert path["lat"] == (4.1, 4.2, 4.3) def test_slicing_point(self): request = Request(Select("lat", [4.1]), Select("long", [4.1]), Select("alt", [4.1])) diff --git a/tests/test_healpix_mapper.py b/tests/test_healpix_mapper.py index 8b6121da7..60f4ff6c8 100644 --- a/tests/test_healpix_mapper.py +++ b/tests/test_healpix_mapper.py @@ -40,16 +40,17 @@ def test_healpix_grid(self): tol = 1e-8 for i in range(len(result.leaves)): cubepath = result.leaves[i].flatten() - lat = cubepath["latitude"] - lon = cubepath["longitude"] - lats.append(lat) - lons.append(lon) - nearest_points = find_nearest_latlon("./tests/data/healpix.grib", lat, lon) - eccodes_lat = nearest_points[0][0]["lat"] - eccodes_lon = nearest_points[0][0]["lon"] - eccodes_lats.append(eccodes_lat) - assert eccodes_lat - tol <= lat - assert lat <= eccodes_lat + tol - assert eccodes_lon - tol <= lon - assert lon <= eccodes_lon + tol + lat = cubepath["latitude"][0] + new_lons = cubepath["longitude"] + for lon in new_lons: + lats.append(lat) + lons.append(lon) + nearest_points = find_nearest_latlon("./tests/data/healpix.grib", lat, lon) + eccodes_lat = nearest_points[0][0]["lat"] + eccodes_lon = nearest_points[0][0]["lon"] + assert eccodes_lat - tol <= lat + assert lat <= eccodes_lat + tol + assert eccodes_lon - tol <= lon + assert lon <= eccodes_lon + tol + eccodes_lats.append(lat) assert len(eccodes_lats) == 40 diff --git a/tests/test_hullslicer_engine.py b/tests/test_hullslicer_engine.py index ded90cb84..2a43732ce 100644 --- a/tests/test_hullslicer_engine.py +++ b/tests/test_hullslicer_engine.py @@ -2,7 +2,7 @@ import xarray as xr from polytope.datacube.backends.xarray import XArrayDatacube -from polytope.datacube.index_tree import IndexTree +from polytope.datacube.tensor_index_tree import TensorIndexTree from polytope.engine.hullslicer import HullSlicer from polytope.polytope import Polytope from polytope.shapes import Box @@ -26,14 +26,14 @@ def test_extract(self): box = 
Box(["step", "level"], [3.0, 1.0], [6.0, 3.0]) polytope = box.polytope() request = self.slicer.extract(self.xarraydatacube, polytope) - assert request.axis == IndexTree.root + assert request.axis == TensorIndexTree.root assert request.parent is None - assert request.value is None + assert request.values is tuple() assert len(request.leaves) == 6 assert request.leaves[0].axis.name == "level" assert len(request.children) == 2 assert request.children[0].axis.name == "step" - assert request.children[0].value == 3.0 - assert request.children[1].value == 6.0 + assert request.children[0].values == (3.0,) + assert request.children[1].values == (6.0,) for i in range(len(request.leaves)): - assert request.leaves[i].value in [1.0, 2.0, 3.0] + assert request.leaves[i].values in [(1.0,), (2.0,), (3.0,)] diff --git a/tests/test_local_grid_cyclic.py b/tests/test_local_grid_cyclic.py index da8b71f25..f5108f861 100644 --- a/tests/test_local_grid_cyclic.py +++ b/tests/test_local_grid_cyclic.py @@ -47,10 +47,10 @@ def test_fdb_datacube(self): Point(["latitude", "longitude"], [[-20, -20]]), ) result = self.API.retrieve(request) - result.pprint_2() + result.pprint() assert len(result.leaves) == 1 - assert result.leaves[0].flatten()["latitude"] == -20 - assert result.leaves[0].flatten()["longitude"] == -20 + assert result.leaves[0].flatten()["latitude"] == (-20,) + assert result.leaves[0].flatten()["longitude"] == (-20,) @pytest.mark.fdb def test_fdb_datacube_2(self): @@ -67,7 +67,7 @@ def test_fdb_datacube_2(self): Point(["latitude", "longitude"], [[-20, 50 + 360]]), ) result = self.API.retrieve(request) - result.pprint_2() + result.pprint() assert len(result.leaves) == 1 - assert result.leaves[0].flatten()["latitude"] == -20 - assert result.leaves[0].flatten()["longitude"] == 50 + assert result.leaves[0].flatten()["latitude"] == (-20,) + assert result.leaves[0].flatten()["longitude"] == (50,) diff --git a/tests/test_local_regular_grid.py b/tests/test_local_regular_grid.py index 
62be20962..04cd4cabd 100644 --- a/tests/test_local_regular_grid.py +++ b/tests/test_local_regular_grid.py @@ -46,10 +46,10 @@ def test_fdb_datacube(self): Point(["latitude", "longitude"], [[0.16, 0.176]], method="nearest"), ) result = self.API.retrieve(request) - result.pprint_2() + result.pprint() assert len(result.leaves) == 1 - assert result.leaves[0].flatten()["latitude"] == 0 - assert result.leaves[0].flatten()["longitude"] == 0 + assert result.leaves[0].flatten()["latitude"] == (0,) + assert result.leaves[0].flatten()["longitude"] == (0,) @pytest.mark.fdb def test_point_outside_local_region(self): @@ -66,10 +66,10 @@ def test_point_outside_local_region(self): Point(["latitude", "longitude"], [[0.16, 61]], method="nearest"), ) result = self.API.retrieve(request) - result.pprint_2() + result.pprint() assert len(result.leaves) == 1 - assert result.leaves[0].flatten()["latitude"] == 0 - assert result.leaves[0].flatten()["longitude"] == 60 + assert result.leaves[0].flatten()["latitude"] == (0,) + assert result.leaves[0].flatten()["longitude"] == (60,) @pytest.mark.fdb def test_point_outside_local_region_2(self): @@ -86,10 +86,10 @@ def test_point_outside_local_region_2(self): Point(["latitude", "longitude"], [[41, 1]], method="nearest"), ) result = self.API.retrieve(request) - result.pprint_2() + result.pprint() assert len(result.leaves) == 1 - assert result.leaves[0].flatten()["latitude"] == 40 - assert result.leaves[0].flatten()["longitude"] == 1 + assert result.leaves[0].flatten()["latitude"] == (40,) + assert result.leaves[0].flatten()["longitude"] == (1,) @pytest.mark.fdb def test_point_outside_local_region_3(self): @@ -106,7 +106,7 @@ def test_point_outside_local_region_3(self): Point(["latitude", "longitude"], [[1, 61]]), ) result = self.API.retrieve(request) - result.pprint_2() + result.pprint() assert len(result.leaves) == 1 assert result.is_root() @@ -125,7 +125,7 @@ def test_point_outside_local_region_4(self): Point(["latitude", "longitude"], [[41, 
1]]), ) result = self.API.retrieve(request) - result.pprint_2() + result.pprint() assert len(result.leaves) == 1 assert result.is_root() @@ -144,7 +144,7 @@ def test_point_outside_local_region_5(self): Point(["latitude", "longitude"], [[-41, 1]]), ) result = self.API.retrieve(request) - result.pprint_2() + result.pprint() assert len(result.leaves) == 1 assert result.is_root() @@ -163,7 +163,7 @@ def test_point_outside_local_region_6(self): Point(["latitude", "longitude"], [[-30, -21]]), ) result = self.API.retrieve(request) - result.pprint_2() + result.pprint() assert len(result.leaves) == 1 assert result.is_root() @@ -182,10 +182,10 @@ def test_point_outside_local_region_7(self): Point(["latitude", "longitude"], [[-41, 1]], method="nearest"), ) result = self.API.retrieve(request) - result.pprint_2() + result.pprint() assert len(result.leaves) == 1 - assert result.leaves[0].flatten()["latitude"] == -40 - assert result.leaves[0].flatten()["longitude"] == 1 + assert result.leaves[0].flatten()["latitude"] == (-40,) + assert result.leaves[0].flatten()["longitude"] == (1,) @pytest.mark.fdb def test_point_outside_local_region_8(self): @@ -202,10 +202,10 @@ def test_point_outside_local_region_8(self): Point(["latitude", "longitude"], [[-30, -21]], method="nearest"), ) result = self.API.retrieve(request) - result.pprint_2() + result.pprint() assert len(result.leaves) == 1 - assert result.leaves[0].flatten()["latitude"] == -30 - assert result.leaves[0].flatten()["longitude"] == -20 + assert result.leaves[0].flatten()["latitude"] == (-30,) + assert result.leaves[0].flatten()["longitude"] == (-20,) @pytest.mark.fdb def test_point_outside_local_region_9(self): @@ -222,7 +222,7 @@ def test_point_outside_local_region_9(self): Point(["latitude", "longitude"], [[-30, -21]], method="surrounding"), ) result = self.API.retrieve(request) - result.pprint_2() + result.pprint() assert len(result.leaves) == 3 - assert result.leaves[0].flatten()["latitude"] == -31 - assert 
result.leaves[0].flatten()["longitude"] == -20 + assert result.leaves[0].flatten()["latitude"] == (-31.0,) + assert result.leaves[0].flatten()["longitude"] == (-20,) diff --git a/tests/test_mappers.py b/tests/test_mappers.py index 1231f19d7..fb4783a33 100644 --- a/tests/test_mappers.py +++ b/tests/test_mappers.py @@ -58,19 +58,19 @@ def test_second_axis_vals(self): base_axis = "base" resolution = 1280 octahedral_mapper = OctahedralGridMapper(base_axis, mapped_axes, resolution) - assert octahedral_mapper.second_axis_vals(0.035149384215604956)[0] == 0 - assert octahedral_mapper.second_axis_vals(10.017574499477174)[0] == 0 - assert octahedral_mapper.second_axis_vals(89.94618771566562)[10] == 180 - assert len(octahedral_mapper.second_axis_vals(89.94618771566562)) == 20 - assert len(octahedral_mapper.second_axis_vals(89.87647835333229)) == 24 - assert len(octahedral_mapper.second_axis_vals(0.035149384215604956)) == 5136 + assert octahedral_mapper.second_axis_vals((0.035149384215604956,))[0] == 0 + assert octahedral_mapper.second_axis_vals((10.017574499477174,))[0] == 0 + assert octahedral_mapper.second_axis_vals((89.94618771566562,))[10] == 180 + assert len(octahedral_mapper.second_axis_vals((89.94618771566562,))) == 20 + assert len(octahedral_mapper.second_axis_vals((89.87647835333229,))) == 24 + assert len(octahedral_mapper.second_axis_vals((0.035149384215604956,))) == 5136 def test_map_second_axis(self): mapped_axes = ["lat", "lon"] base_axis = "base" resolution = 1280 octahedral_mapper = OctahedralGridMapper(base_axis, mapped_axes, resolution) - assert octahedral_mapper.map_second_axis(89.94618771566562, 0, 90) == [0, 18, 36, 54, 72, 90] + assert octahedral_mapper.map_second_axis((89.94618771566562,), 0, 90) == [0, 18, 36, 54, 72, 90] def test_axes_idx_to_octahedral_idx(self): mapped_axes = ["lat", "lon"] @@ -102,6 +102,6 @@ def test_unmap(self): base_axis = "base" resolution = 1280 octahedral_mapper = OctahedralGridMapper(base_axis, mapped_axes, resolution) - 
assert octahedral_mapper.unmap(89.94618771566562, 0) == 0 - assert octahedral_mapper.unmap(0.035149384215604956, 0) == 3299840 - 5136 - assert octahedral_mapper.unmap(-0.035149384215604956, 0) == 3299840 + assert octahedral_mapper.unmap((89.94618771566562,), (0,)) == 0 + assert octahedral_mapper.unmap((0.035149384215604956,), (0,)) == 3299840 - 5136 + assert octahedral_mapper.unmap((-0.035149384215604956,), (0,)) == 3299840 diff --git a/tests/test_merge_octahedral_one_axis.py b/tests/test_merge_octahedral_one_axis.py index 82d95fb22..d56378d3b 100644 --- a/tests/test_merge_octahedral_one_axis.py +++ b/tests/test_merge_octahedral_one_axis.py @@ -34,6 +34,6 @@ def test_merge_axis(self): Box(["latitude", "longitude"], [0, 359.8], [0.2, 361.2]), ) result = self.API.retrieve(request) - # result.pprint() - assert result.leaves[-1].flatten()["longitude"] == 360.0 - assert result.leaves[0].flatten()["longitude"] == 0.070093457944 + result.pprint() + assert max(result.leaves[-1].flatten()["longitude"]) == 360.0 + assert min(result.leaves[0].flatten()["longitude"]) == 0.070093457944 diff --git a/tests/test_merge_transformation.py b/tests/test_merge_transformation.py index 71d6f3129..12d1f76cf 100644 --- a/tests/test_merge_transformation.py +++ b/tests/test_merge_transformation.py @@ -25,4 +25,4 @@ def setup_method(self, method): def test_merge_axis(self): request = Request(Select("date", [pd.Timestamp("20000101T060000")])) result = self.API.retrieve(request) - assert result.leaves[0].flatten()["date"] == pd.Timestamp("2000-01-01T06:00:00") + assert result.leaves[0].flatten()["date"] == (np.datetime64("2000-01-01T06:00:00"),) diff --git a/tests/test_octahedral_grid.py b/tests/test_octahedral_grid.py index e041b0e53..131b996fd 100644 --- a/tests/test_octahedral_grid.py +++ b/tests/test_octahedral_grid.py @@ -33,6 +33,7 @@ def test_octahedral_grid(self): Select("valid_time", ["2023-06-25T12:00:00"]), ) result = self.API.retrieve(request) + result.pprint() assert 
len(result.leaves) == 9 lats = [] @@ -45,12 +46,12 @@ def test_octahedral_grid(self): lon = cubepath["longitude"] lats.append(lat) lons.append(lon) - nearest_points = find_nearest_latlon("./tests/data/foo.grib", lat, lon) + nearest_points = find_nearest_latlon("./tests/data/foo.grib", lat[0], lon[0]) eccodes_lat = nearest_points[0][0]["lat"] eccodes_lon = nearest_points[0][0]["lon"] eccodes_lats.append(eccodes_lat) - assert eccodes_lat - tol <= lat - assert lat <= eccodes_lat + tol - assert eccodes_lon - tol <= lon - assert lon <= eccodes_lon + tol + assert eccodes_lat - tol <= lat[0] + assert lat[0] <= eccodes_lat + tol + assert eccodes_lon - tol <= lon[0] + assert lon[0] <= eccodes_lon + tol assert len(eccodes_lats) == 9 diff --git a/tests/test_point_nearest.py b/tests/test_point_nearest.py index 834b42d3f..d3e076d63 100644 --- a/tests/test_point_nearest.py +++ b/tests/test_point_nearest.py @@ -95,7 +95,7 @@ def test_fdb_datacube_true_point_3(self): result = self.API.retrieve(request) result.pprint() assert len(result.leaves) == 1 - assert result.leaves[0].value == 359.929906542056 + assert result.leaves[0].values == (359.929906542056,) assert result.leaves[0].axis.name == "longitude" @pytest.mark.fdb @@ -115,7 +115,7 @@ def test_fdb_datacube_true_point_5(self): result = self.API.retrieve(request) result.pprint() assert len(result.leaves) == 1 - assert result.leaves[0].value == 359.929906542056 + assert result.leaves[0].values == (359.929906542056,) assert result.leaves[0].axis.name == "longitude" @pytest.mark.fdb @@ -135,5 +135,5 @@ def test_fdb_datacube_true_point_4(self): result = self.API.retrieve(request) # result.pprint_2() assert len(result.leaves) == 1 - assert result.leaves[0].value == 359.929906542056 + assert result.leaves[0].values == (359.929906542056,) assert result.leaves[0].axis.name == "longitude" diff --git a/tests/test_protobuf_encoder_timings.py b/tests/test_protobuf_encoder_timings.py new file mode 100644 index 000000000..6857bd3d0 --- 
/dev/null +++ b/tests/test_protobuf_encoder_timings.py @@ -0,0 +1,55 @@ +import pandas as pd +import pytest + +from polytope.datacube.tree_encoding import decode_tree, encode_tree +from polytope.engine.hullslicer import HullSlicer +from polytope.polytope import Polytope, Request +from polytope.shapes import Box, Select + + +class TestEncoder: + def setup_method(self): + from polytope.datacube.backends.fdb import FDBDatacube + + # Create a dataarray with 3 labelled axes using different index types + self.options = { + "values": {"mapper": {"type": "octahedral", "resolution": 1280, "axes": ["latitude", "longitude"]}}, + "date": {"merge": {"with": "time", "linkers": ["T", "00"]}}, + "step": {"type_change": "int"}, + "latitude": {"reverse": {True}}, + } + self.config = {"class": "od", "expver": "0001", "levtype": "sfc", "stream": "oper", "type": "fc"} + self.datacube = FDBDatacube(self.config, axis_options=self.options) + self.slicer = HullSlicer() + self.API = Polytope(datacube=self.datacube, engine=self.slicer, axis_options=self.options) + request = Request( + Select("step", [0]), + Select("levtype", ["sfc"]), + Select("date", [pd.Timestamp("20240118T000000")]), + Select("domain", ["g"]), + Select("expver", ["0001"]), + Select("param", ["49", "167"]), + Select("class", ["od"]), + Select("stream", ["oper"]), + Select("type", ["fc"]), + Box(["latitude", "longitude"], [0, 0], [5, 5]), + ) + self.tree = self.API.retrieve(request) + # self.tree.pprint() + + @pytest.mark.fdb + def test_encoding(self): + import time + + time0 = time.time() + encoded_bytes = encode_tree(self.tree) + time1 = time.time() + print("TIME TO ENCODE") + print(time1 - time0) + print(len(self.tree.leaves)) + time2 = time.time() + decoded_tree = decode_tree(self.datacube, encoded_bytes) + time3 = time.time() + print("TIME TO DECODE") + print(time3 - time2) + decoded_tree.pprint() diff --git a/tests/test_reduced_ll_grid.py b/tests/test_reduced_ll_grid.py index 389af4171..fdbf5f121 100644 --- 
a/tests/test_reduced_ll_grid.py +++ b/tests/test_reduced_ll_grid.py @@ -62,8 +62,8 @@ def test_reduced_ll_grid(self): leaves = result.leaves for i in range(len(leaves)): cubepath = leaves[i].flatten() - lat = cubepath["latitude"] - lon = cubepath["longitude"] + lat = cubepath["latitude"][0] + lon = cubepath["longitude"][0] del cubepath lats.append(lat) lons.append(lon) diff --git a/tests/test_regular_grid.py b/tests/test_regular_grid.py index a28f005a8..ffb02bfc4 100644 --- a/tests/test_regular_grid.py +++ b/tests/test_regular_grid.py @@ -63,8 +63,8 @@ def test_regular_grid(self): leaves = result.leaves for i in range(len(leaves)): cubepath = leaves[i].flatten() - lat = cubepath["latitude"] - lon = cubepath["longitude"] + lat = cubepath["latitude"][0] + lon = cubepath["longitude"][0] lats.append(lat) lons.append(lon) nearest_points = find_nearest_latlon("./tests/data/era5-levels-members.grib", lat, lon) diff --git a/tests/test_request_trees_after_slicing.py b/tests/test_request_trees_after_slicing.py index e64917aa3..82c4e870c 100644 --- a/tests/test_request_trees_after_slicing.py +++ b/tests/test_request_trees_after_slicing.py @@ -28,7 +28,7 @@ def test_path_values(self): request = self.slicer.extract(self.xarraydatacube, polytope) datacube_path = request.leaves[0].flatten() # request.pprint() - assert datacube_path.values() == tuple([3.0, 1.0]) + assert datacube_path.values() == tuple([tuple([3.0]), tuple([1.0])]) assert len(datacube_path.values()) == 2 def test_path_keys(self): @@ -51,8 +51,8 @@ def test_flatten(self): polytope = box.polytope() request = self.slicer.extract(self.xarraydatacube, polytope) path = request.leaves[0].flatten() - assert path["step"] == 3.0 - assert path["level"] == 1.0 + assert path["step"] == tuple([3.0]) + assert path["level"] == tuple([1.0]) def test_add_child(self): box = Box(["step", "level"], [3.0, 1.0], [6.0, 3.0]) @@ -63,14 +63,14 @@ def test_add_child(self): # Test adding child axis1 = IntDatacubeAxis() axis1.name = "lat" - 
request2.create_child(axis1, 4.1) + request2.create_child(axis1, 4.1, [], []) assert request2.leaves[0].axis.name == "lat" - assert request2.leaves[0].value == 4.1 + assert request2.leaves[0].values == tuple([4.1]) axis2 = IntDatacubeAxis() axis2.name = "level" # Test getting child - assert request1.create_child(axis2, 3.0).axis.name == "level" - assert request1.create_child(axis2, 3.0).value == 3.0 + assert request1.create_child(axis2, 3.0, [], [])[0].axis.name == "level" + assert request1.create_child(axis2, 3.0, [], [])[0].values == tuple([3.0]) def test_pprint(self): box = Box(["step", "level"], [3.0, 1.0], [6.0, 3.0]) @@ -91,7 +91,7 @@ def test_remove_branch(self): axis1.name = "step" axis2.name = "level" # Test if remove_branch() also removes longer branches - request1 = request.create_child(axis1, 1.0) - request2 = request1.create_child(axis2, 0.0) - request2.remove_branch() - assert request1.is_root() # removed from original + request1 = request.create_child(axis1, 1.0, [], []) + request2 = request1[0].create_child(axis2, 0.0, [], []) + request2[0].remove_branch() + assert request1[0].is_root() # removed from original diff --git a/tests/test_shapes.py b/tests/test_shapes.py index b0daa8b10..593d8b260 100644 --- a/tests/test_shapes.py +++ b/tests/test_shapes.py @@ -30,12 +30,16 @@ def setup_method(self, method): def test_all(self): request = Request(Select("step", [3]), Select("date", ["2000-01-01"]), All("level"), Select("longitude", [1])) result = self.API.retrieve(request) - assert len(result.leaves) == 129 + assert len(result.leaves) == 1 + path = result.leaves[0].flatten() + assert path["level"] == tuple(range(1, 130)) def test_all_cyclic(self): request = Request(Select("step", [3]), Select("date", ["2000-01-01"]), Select("level", [1]), All("longitude")) result = self.API.retrieve(request) - assert len(result.leaves) == 360 + assert len(result.leaves) == 1 + path = result.leaves[0].flatten() + assert path["longitude"] == tuple(range(0, 360)) 
@pytest.mark.fdb def test_all_mapper_cyclic(self): @@ -70,4 +74,26 @@ def test_all_mapper_cyclic(self): ) result = self.API.retrieve(request) # result.pprint() - assert len(result.leaves) == 20 + assert len(result.leaves) == 1 + assert result.leaves[0].flatten()["longitude"] == ( + 0.0, + 18.0, + 36.0, + 54.0, + 72.0, + 90.0, + 108.0, + 126.0, + 144.0, + 162.0, + 180.0, + 198.0, + 216.0, + 234.0, + 252.0, + 270.0, + 288.0, + 306.0, + 324.0, + 342.0, + ) diff --git a/tests/test_slice_date_range_fdb_v2.py b/tests/test_slice_date_range_fdb_v2.py index 7fe005450..560649f57 100644 --- a/tests/test_slice_date_range_fdb_v2.py +++ b/tests/test_slice_date_range_fdb_v2.py @@ -1,3 +1,4 @@ +import numpy as np import pandas as pd import pytest @@ -42,4 +43,18 @@ def test_fdb_datacube(self): ) result = self.API.retrieve(request) result.pprint() - assert len(result.leaves) == 6 + assert len(result.leaves) == 2 + path1 = result.leaves[0].flatten() + assert path1["date"] == ( + np.datetime64("2017-01-01T12:00:00"), + np.datetime64("2017-01-02T00:00:00"), + np.datetime64("2017-01-02T12:00:00"), + ) + assert path1["levelist"] == ("500",) + path1 = result.leaves[1].flatten() + assert path1["date"] == ( + np.datetime64("2017-01-01T12:00:00"), + np.datetime64("2017-01-02T00:00:00"), + np.datetime64("2017-01-02T12:00:00"), + ) + assert path1["levelist"] == ("850",) diff --git a/tests/test_slicer_engine.py b/tests/test_slicer_engine.py index dd876beb4..1d498f796 100644 --- a/tests/test_slicer_engine.py +++ b/tests/test_slicer_engine.py @@ -2,7 +2,7 @@ import xarray as xr from polytope.datacube.backends.xarray import XArrayDatacube -from polytope.datacube.index_tree import IndexTree +from polytope.datacube.tensor_index_tree import TensorIndexTree from polytope.engine.hullslicer import HullSlicer from polytope.polytope import Polytope from polytope.shapes import Box @@ -26,14 +26,14 @@ def test_extract(self): box = Box(["step", "level"], [3.0, 1.0], [6.0, 3.0]) polytope = box.polytope() 
request = self.slicer.extract(self.xarraydatacube, polytope) - assert request.axis == IndexTree.root + assert request.axis == TensorIndexTree.root assert request.parent is None - assert request.value is None + assert request.values is tuple() assert len(request.leaves) == 6 assert request.leaves[0].axis.name == "level" assert len(request.children) == 2 assert request.children[0].axis.name == "step" - assert request.children[0].value == 3.0 - assert request.children[1].value == 6.0 + assert request.children[0].values == (3.0,) + assert request.children[1].values == (6.0,) for i in range(len(request.leaves)): - assert request.leaves[i].value in [1.0, 2.0, 3.0] + assert request.leaves[i].values in [(1.0,), (2.0,), (3.0,)] diff --git a/tests/test_slicing_xarray_3D.py b/tests/test_slicing_xarray_3D.py index 9742f909e..8ed7e9c6a 100644 --- a/tests/test_slicing_xarray_3D.py +++ b/tests/test_slicing_xarray_3D.py @@ -6,7 +6,7 @@ import xarray as xr from polytope.datacube.backends.xarray import XArrayDatacube -from polytope.datacube.index_tree import IndexTree +from polytope.datacube.tensor_index_tree import TensorIndexTree from polytope.engine.hullslicer import HullSlicer from polytope.polytope import Polytope, Request from polytope.shapes import ( @@ -66,7 +66,9 @@ def test_point(self): def test_segment(self): request = Request(Span("level", 10, 11), Select("date", ["2000-01-01"]), Select("step", [9])) result = self.API.retrieve(request) - assert len(result.leaves) == 2 + assert len(result.leaves) == 1 + path = result.leaves[0].flatten() + assert path["level"] == (10, 11) def test_union_line_point(self): seg1 = Span("step", 4.3, 6.2) @@ -114,49 +116,49 @@ def test_union_empty_lines(self): seg2 = Span("step", 10, 11) request = Request(Union(["step"], seg1, seg2), Select("date", ["2000-01-01"]), Select("level", [100])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def 
test_empty_box_no_level(self): # Slices non-existing level data request = Request(Box(["step", "level"], [3, 10.5], [7, 10.99]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_empty_box_no_level_step(self): # Slices non-existing level and step data request = Request(Box(["step", "level"], [4, 10.5], [5, 10.99]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_empty_box_no_step(self): # Slices non-existing step and level data request = Request(Box(["step", "level"], [4, 10], [5, 10.49]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_empty_box_floating_steps(self): # Slices through no step data and float type level data request = Request(Box(["step", "level"], [4.1, 10.3], [5.7, 11.8]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_empty_box_no_step_level_float(self): # Slices empty step and level box request = Request(Box(["step", "level"], [4.1, 10.3], [5.7, 10.8]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_empty_no_step_unordered(self): # Slice empty box because no step is available request = Request(Box(["level", "step"], [10, 4], [10, 5]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_nonexisting_date(self): # Slices non-existing date data request = 
Request(Select("date", ["2000-01-04"]), Select("level", [100]), Select("step", [3])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_two_nonexisting_close_points(self): # Slices two close points neither of which are available in the datacube @@ -164,7 +166,7 @@ def test_two_nonexisting_close_points(self): pt2 = Select("step", [3.001]) request = Request(Union(["step"], pt1, pt2), Select("level", [100]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_union_two_nonexisting_points(self): # Slices two close points neither of which are available in the datacube. @@ -173,7 +175,7 @@ def test_union_two_nonexisting_points(self): pt2 = Select("step", [3.001]) request = Request(Union(["step"], pt1, pt2), Select("level", [100]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_two_close_points_no_level(self): # Slices non-existing step points and non-existing level @@ -181,19 +183,19 @@ def test_two_close_points_no_level(self): pt2 = Select("step", [3.001]) request = Request(Union(["step"], pt1, pt2), Select("level", [100.1]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_nonexisting_point_float_level(self): # Slices non-existing level data request = Request(Select("step", [3]), Select("level", [99.1]), Select("date", ["2000-01-02"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_nonexisting_segment(self): # Slices non-existing step data request = Request(Span("step", 3.2, 3.23), 
Select("level", [99]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root # Testing edge cases @@ -230,4 +232,4 @@ def test_intersection_point_disk_polygon(self): request = Request(Disk(["level", "step"], [0, 0], [r1, r2]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) paths = [r.flatten().values() for r in result.leaves] - assert (pd.Timestamp("2000-01-01 00:00:00"), 3.0, 1.0) in paths + assert ((pd.Timestamp("2000-01-01 00:00:00"),), (3,), (1,)) in paths diff --git a/tests/test_slicing_xarray_4D.py b/tests/test_slicing_xarray_4D.py index a19c260d4..bee391ca2 100644 --- a/tests/test_slicing_xarray_4D.py +++ b/tests/test_slicing_xarray_4D.py @@ -3,7 +3,7 @@ import pytest import xarray as xr -from polytope.datacube.index_tree import IndexTree +from polytope.datacube.tensor_index_tree import TensorIndexTree from polytope.engine.hullslicer import HullSlicer from polytope.polytope import Polytope, Request from polytope.shapes import ( @@ -140,7 +140,7 @@ def test_empty_circle(self): Disk(["step", "level"], [5, 3.4], [0.5, 0.2]), Select("date", ["2000-01-01"]), Select("lat", [5.1]) ) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_float_box(self): # Slices a box with no data inside @@ -148,7 +148,7 @@ def test_float_box(self): Box(["step", "lat"], [10.1, 1.01], [10.3, 1.04]), Select("date", ["2000-01-01"]), Select("level", [10]) ) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_path_empty_box(self): # Slices the path of a box with no data inside, but gives data because the box is swept over a datacube value @@ -168,14 +168,14 @@ def test_path_empty_box_empty(self): Select("date", ["2000-01-01"]), ) result = 
self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_ellipsoid_empty(self): # Slices an empty ellipsoid which doesn't have any step value ellipsoid = Ellipsoid(["step", "level", "lat"], [5, 3, 2.1], [0, 0, 0]) request = Request(ellipsoid, Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root # Testing special properties @@ -185,7 +185,9 @@ def test_span_bounds(self): Span("level", 100, 98), Select("step", [3]), Select("lat", [5.5]), Select("date", ["2000-01-01"]) ) result = self.API.retrieve(request) - assert len(result.leaves) == 3 + assert len(result.leaves) == 1 + path = result.leaves[0].flatten() + assert path["level"] == (98, 99, 100) # Testing edge cases @@ -195,7 +197,7 @@ def test_ellipsoid_one_point(self): request = Request(ellipsoid, Select("date", ["2000-01-01"])) result = self.API.retrieve(request) assert len(result.leaves) == 1 - assert not result.leaves[0].axis == IndexTree.root + assert not result.leaves[0].axis == TensorIndexTree.root def test_flat_box_level(self): # Slices a line in the step direction @@ -213,7 +215,7 @@ def test_flat_disk_nonexisting(self): # Slices an empty disk because there is no step level request = Request(Disk(["level", "step"], [4, 5], [4, 0]), Select("lat", [6]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_flat_disk_line(self): # Slices a line in the level direction @@ -231,20 +233,20 @@ def test_flat_disk_empty(self): # Slices an empty disk because there is no step request = Request(Disk(["level", "step"], [4, 5], [0, 0.5]), Select("lat", [6]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert 
result.leaves[0].axis == TensorIndexTree.root def test_disk_point(self): # Slices a point because the origin of the disk is a datacube point request = Request(Disk(["level", "step"], [4, 6], [0, 0]), Select("lat", [6]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) assert len(result.leaves) == 1 - assert not result.leaves[0].axis == IndexTree.root + assert not result.leaves[0].axis == TensorIndexTree.root def test_empty_disk(self): # Slices an empty object because the origin of the disk is not a datacube point request = Request(Disk(["level", "step"], [4, 5], [0, 0]), Select("lat", [6]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root def test_polygon_line(self): # Slices a line defined through the polygon shape @@ -260,14 +262,14 @@ def test_polygon_point(self): request = Request(polygon, Select("lat", [4.3]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) assert len(result.leaves) == 1 - assert not result.leaves[0].axis == IndexTree.root + assert not result.leaves[0].axis == TensorIndexTree.root def test_polygon_empty(self): # Slices a point which isn't in the datacube (defined through the polygon shape) polygon = Polygon(["step", "level"], [[2, 3.1]]) request = Request(polygon, Select("lat", [4.3]), Select("date", ["2000-01-01"])) result = self.API.retrieve(request) - assert result.leaves[0].axis == IndexTree.root + assert result.leaves[0].axis == TensorIndexTree.root # Test exceptions are returned correctly diff --git a/tests/test_snapping.py b/tests/test_snapping.py index 41492f06f..8649b5a7a 100644 --- a/tests/test_snapping.py +++ b/tests/test_snapping.py @@ -26,31 +26,31 @@ def test_2D_point(self): request = Request(Select("level", [2], method="surrounding"), Select("step", [4], method="surrounding")) result = self.API.retrieve(request) result.pprint() - assert len(result.leaves) == 4 + assert 
len(result.leaves) == 1 for leaf in result.leaves: path = leaf.flatten() - assert path["level"] in [1, 3] - assert path["step"] in [3, 5] + assert path["level"] == (1, 3) + assert path["step"] == (3, 5) def test_2D_point_outside_datacube_left(self): request = Request(Select("level", [2], method="surrounding"), Select("step", [0], method="surrounding")) result = self.API.retrieve(request) result.pprint() - assert len(result.leaves) == 2 + assert len(result.leaves) == 1 for leaf in result.leaves: path = leaf.flatten() - assert path["level"] in [1, 3] - assert path["step"] == 1 + assert path["level"] == (1, 3) + assert path["step"] == (1,) def test_2D_point_outside_datacube_right(self): request = Request(Select("level", [2], method="surrounding"), Select("step", [6], method="surrounding")) result = self.API.retrieve(request) result.pprint() - assert len(result.leaves) == 2 + assert len(result.leaves) == 1 for leaf in result.leaves: path = leaf.flatten() - assert path["level"] in [1, 3] - assert path["step"] == 5 + assert path["level"] == (1, 3) + assert path["step"] == (5,) def test_1D_point_outside_datacube_right(self): request = Request(Select("level", [1]), Select("step", [6], method="surrounding")) @@ -59,8 +59,8 @@ def test_1D_point_outside_datacube_right(self): assert len(result.leaves) == 1 for leaf in result.leaves: path = leaf.flatten() - assert path["level"] == 1 - assert path["step"] == 5 + assert path["level"] == (1,) + assert path["step"] == (5,) def test_1D_nonexisting_point(self): request = Request(Select("level", [2]), Select("step", [6], method="surrounding")) @@ -83,5 +83,5 @@ def test_1D_nonexisting_point_surrounding(self): assert len(result.leaves) == 1 for leaf in result.leaves: path = leaf.flatten() - assert path["level"] == 1 - assert path["step"] == 5 + assert path["level"] == (1,) + assert path["step"] == (5,) diff --git a/tests/test_snapping_real_data.py b/tests/test_snapping_real_data.py index 1f113dfca..964fdacff 100644 --- 
a/tests/test_snapping_real_data.py +++ b/tests/test_snapping_real_data.py @@ -32,7 +32,6 @@ def test_surrounding_on_grid_point(self): request = Request( Box(["number", "isobaricInhPa"], [6, 500.0], [6, 850.0]), Select("time", ["2017-01-02T12:00:00"]), - # Box(["latitude", "longitude"], lower_corner=[0.0, 0.0], upper_corner=[10.0, 30.0]), Select("latitude", [requested_lat], method="surrounding"), Select("longitude", [requested_lon], method="surrounding"), Select("step", [np.timedelta64(0, "s")]), @@ -64,5 +63,8 @@ def test_surrounding_on_grid_point(self): # plt.scatter([requested_lon], [requested_lat], s=16, c="blue") # plt.colorbar(label="Temperature") # plt.show() + assert len(longs) == 2 for lon in longs: - assert lon in [357, 0, 3] + assert lon == (0.0, 3.0, 357.0) + for lat in lats: + assert lat == (-3.0, 0.0, 3.0) diff --git a/tests/test_tree_protobuf.py b/tests/test_tree_protobuf.py new file mode 100644 index 000000000..03fe08158 --- /dev/null +++ b/tests/test_tree_protobuf.py @@ -0,0 +1,30 @@ +import polytope.datacube.index_tree_pb2 as pb2 + + +class TestTreeProtobuf: + def test_protobuf_tree(self): + node = pb2.Node() + node2 = pb2.Node() + node3 = pb2.Node() + val1 = pb2.Value() + val2 = pb2.Value() + val3 = pb2.Value() + val1.int_val = 1 + node.value.append(val1) + val2.int_val = 2 + node2.value.append(val2) + val3.int_val = 3 + node3.value.append(val3) + node4 = pb2.Node() + val4 = pb2.Value() + val4.int_val = 4 + node4.value.append(val4) + node3.children.extend([node4]) + node5 = node.children.add() + val5 = pb2.Value() + val5.int_val = 5 + node5.value.append(val5) + node.children.extend([node2, node3]) + + assert len(node.children) == 3 + assert len(node.children[2].children) == 1 diff --git a/tests/test_tree_protobuf_encoding.py b/tests/test_tree_protobuf_encoding.py new file mode 100644 index 000000000..8f6445c10 --- /dev/null +++ b/tests/test_tree_protobuf_encoding.py @@ -0,0 +1,53 @@ +import numpy as np +import pandas as pd + +from 
polytope.datacube.backends.mock import MockDatacube +from polytope.datacube.datacube_axis import ( + FloatDatacubeAxis, + IntDatacubeAxis, + PandasTimedeltaDatacubeAxis, + PandasTimestampDatacubeAxis, + UnsliceableDatacubeAxis, +) +from polytope.datacube.tensor_index_tree import TensorIndexTree +from polytope.datacube.tree_encoding import decode_tree, encode_tree + + +class TestEncoder: + def setup_method(self): + self.fake_tree = TensorIndexTree() + child_ax1 = IntDatacubeAxis() + child_ax1.name = "ax1" + child1 = TensorIndexTree(child_ax1, (1,)) + child_ax2 = PandasTimestampDatacubeAxis() + child_ax2.name = "timestamp_ax" + child2 = TensorIndexTree(child_ax2, (pd.Timestamp("2000-01-01 00:00:00"),)) + grandchild_ax1 = FloatDatacubeAxis() + grandchild_ax1.name = "ax2" + grandchild1 = TensorIndexTree(grandchild_ax1, (2.3,)) + grandchild_ax2 = UnsliceableDatacubeAxis() + grandchild_ax2.name = "ax3" + grandchild2 = TensorIndexTree(grandchild_ax2, ("var1",)) + grandchild_ax3 = PandasTimedeltaDatacubeAxis() + grandchild_ax3.name = "timedelta_ax" + grandchild3 = TensorIndexTree(grandchild_ax3, (np.timedelta64(0, "s"),)) + child1.add_child(grandchild2) + child1.add_child(grandchild1) + child2.add_child(grandchild3) + # TODO: test the timestamp and timedelta axes too + self.fake_tree.add_child(child1) + self.fake_tree.add_child(child2) + self.datacube = MockDatacube({"ax1": 1, "ax2": 1, "ax3": 1, "timestamp_ax": 1, "timedelta_ax": 1}) + self.datacube._axes = { + "ax1": child_ax1, + "ax2": grandchild_ax1, + "ax3": grandchild_ax2, + "timestamp_ax": child_ax2, + "timedelta_ax": grandchild_ax3, + } + + def test_encoding(self): + encoded_bytes = encode_tree(self.fake_tree) + decoded_tree = decode_tree(self.datacube, encoded_bytes) + decoded_tree.pprint() + assert decoded_tree.leaves[0].result_size == [1, 1] diff --git a/tests/test_type_change_transformation.py b/tests/test_type_change_transformation.py index 5291e4180..aa8da306b 100644 --- 
a/tests/test_type_change_transformation.py +++ b/tests/test_type_change_transformation.py @@ -25,4 +25,4 @@ def test_merge_axis(self): request = Request(Select("step", [0])) result = self.API.retrieve(request) result.pprint() - assert result.leaves[0].flatten()["step"] == 0 + assert result.leaves[0].flatten()["step"] == (0,)