|
1 | 1 | from functools import wraps
|
| 2 | +from itertools import chain |
2 | 3 |
|
3 | 4 | import gsw
|
4 | 5 | import xarray as xr
|
|
7 | 8 | from ._names import _names
|
8 | 9 | from ._check_funcs import _check_funcs
|
9 | 10 |
|
| 11 | +try: |
| 12 | + import pint_xarray |
| 13 | + import pint |
| 14 | + |
| 15 | +except ImportError: |
| 16 | + pint_xarray = None |
| 17 | + |
10 | 18 |
|
11 | 19 | def add_attrs(rv, attrs, name):
|
12 | 20 | if isinstance(rv, xr.DataArray):
|
13 | 21 | rv.name = name
|
14 | 22 | rv.attrs = attrs
|
15 | 23 |
|
16 | 24 |
|
| 25 | +def quantify(rv, attrs, unit_registry=None): |
| 26 | + if unit_registry is None: |
| 27 | + return rv |
| 28 | + |
| 29 | + if isinstance(rv, xr.DataArray): |
| 30 | + rv = rv.pint.quantify(unit_registry=unit_registry) |
| 31 | + else: |
| 32 | + if attrs is not None: |
| 33 | + # Necessary to use the Q_ and not simply multiplication with ureg unit because of temperature |
| 34 | + # see https://pint.readthedocs.io/en/latest/nonmult.html |
| 35 | + rv = unit_registry.Quantity(rv, attrs["units"]) |
| 36 | + return rv |
| 37 | + |
| 38 | + |
| 39 | +def pint_compat(args, kwargs): |
| 40 | + if pint_xarray is None: |
| 41 | + return args, kwargs, None |
| 42 | + |
| 43 | + using_pint = False |
| 44 | + new_args = [] |
| 45 | + new_kwargs = {} |
| 46 | + registries = [] |
| 47 | + for arg in args: |
| 48 | + if isinstance(arg, xr.DataArray): |
| 49 | + if arg.pint.units is not None: |
| 50 | + new_args.append(arg.pint.dequantify()) |
| 51 | + registries.append(arg.pint.registry) |
| 52 | + else: |
| 53 | + new_args.append(arg) |
| 54 | + elif isinstance(arg, pint.Quantity): |
| 55 | + new_args.append(arg.magnitude) |
| 56 | + registries.append(arg._REGISTRY) |
| 57 | + else: |
| 58 | + new_args.append(arg) |
| 59 | + |
| 60 | + for kw, arg in kwargs.items(): |
| 61 | + if isinstance(arg, xr.DataArray): |
| 62 | + if arg.pint.units is not None: |
| 63 | + new_kwargs[kw] = arg.pint.dequantify() |
| 64 | + registries.append(arg.pint.registry) |
| 65 | + else: |
| 66 | + new_kwargs[kw] = arg |
| 67 | + elif isinstance(arg, pint.Quantity): |
| 68 | + new_kwargs[kw] = arg.magnitude |
| 69 | + registries.append(arg._REGISTRY) |
| 70 | + else: |
| 71 | + new_kwargs[kw] = arg |
| 72 | + |
| 73 | + registries = set(registries) |
| 74 | + if len(registries) > 1: |
| 75 | + raise ValueError("Quantity arguments must all belong to the same unit registry") |
| 76 | + elif len(registries) == 0: |
| 77 | + registries = None |
| 78 | + else: |
| 79 | + (registries,) = registries |
| 80 | + return new_args, new_kwargs, registries |
| 81 | + |
| 82 | + |
17 | 83 | def cf_attrs(attrs, name, check_func):
|
18 | 84 | def cf_attrs_decorator(func):
|
19 | 85 | @wraps(func)
|
20 | 86 | def cf_attrs_wrapper(*args, **kwargs):
|
| 87 | + args, kwargs, unit_registry = pint_compat(args, kwargs) |
21 | 88 | rv = func(*args, **kwargs)
|
22 | 89 | attrs_checked = check_func(attrs, args, kwargs)
|
23 | 90 | if isinstance(rv, tuple):
|
| 91 | + rv_updated = [] |
24 | 92 | for (i, da) in enumerate(rv):
|
25 | 93 | add_attrs(da, attrs_checked[i], name[i])
|
| 94 | + rv_updated.append( |
| 95 | + quantify(da, attrs_checked[i], unit_registry=unit_registry) |
| 96 | + ) |
| 97 | + |
| 98 | + rv = tuple(rv_updated) |
| 99 | + |
26 | 100 | else:
|
27 | 101 | add_attrs(rv, attrs_checked, name)
|
| 102 | + rv = quantify(rv, attrs_checked, unit_registry=unit_registry) |
28 | 103 | return rv
|
29 | 104 |
|
30 | 105 | return cf_attrs_wrapper
|
|
0 commit comments