Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
bdc477b
sqgturb tracer error fix: delta_t shouldn't be traced, messes up even…
kysolvik May 5, 2025
d5a43d0
Remove static B covariance arg from etkf
kysolvik May 12, 2025
b7bfc2a
Move var3d analysis computation to separate function
kysolvik May 12, 2025
96526aa
New hybrid gain DA method
kysolvik May 12, 2025
e60ec44
Enable nearest obs selection and fix multi-dimensional selection (wit…
kysolvik May 14, 2025
fbdd867
Fixing multi-dim sampling for nonstationary sampler
kysolvik May 14, 2025
a9c5d0a
Observer separate case for non-stationary but regular (same num per t…
kysolvik May 14, 2025
d074eb4
observer fix: use dims not coords att for finding nontime dimensions,…
kysolvik May 21, 2025
a8fc313
Add _rebuild_dataset to properly reconscutrct multi-variable datasets…
kysolvik Jun 5, 2025
75674fc
3D-Var and ETKF: fix H shape, obs_error_sd array support, and use reb…
kysolvik Jun 5, 2025
6c42136
4D-Var and Backprop: Support array obs_error_sds (warning message for…
kysolvik Jun 5, 2025
c941404
Fixes for dimension issues in xarray accessors
kysolvik Jun 5, 2025
c6a3410
Pyqg jax with xarray output
kysolvik Jun 5, 2025
e7ca104
sqturb with proper xarray outputs
kysolvik Jun 5, 2025
86e32e7
Remove defaults from obs_error_sd and analysis_window
kysolvik Jun 5, 2025
943fce4
Var3d test fix to include obs_error_sd, which is now required arg for…
kysolvik Jun 5, 2025
6eb5aeb
Observation sampling was updated to fix multi-dim problems, broke the…
kysolvik Jun 5, 2025
b2930fe
Sqgturb generate produces exactly n_steps now, not n_steps+1, in orde…
kysolvik Jun 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dabench/dacycler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from ._etkf import ETKF
from ._var4d_backprop import Var4DBackprop
from ._var4d import Var4D
from ._hybrid_gain import HybridGain

__all__ = [
'DACycler',
'Var3D',
'ETKF',
'Var4DBackprop',
'Var4D',
'HybridGain'
]
36 changes: 31 additions & 5 deletions dabench/dacycler/_dacycler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ def _calc_default_B(self) -> jax.Array:
"""If B is not provided, identity matrix with shape (system_dim, system_dim."""
return jnp.identity(self.system_dim)

def _rebuild_dataset(self,
xb: XarrayDatasetLike,
xa: ArrayLike,
) -> XarrayDatasetLike:
xb_as_array = xb.to_array()
xa = xa.reshape(tuple(xb_as_array.sizes[s] for s in xb_as_array.sizes))
xb_as_array.values = xa
xa_ds = xb_as_array.to_dataset(dim='variable')
return xa_ds

def _step_forecast(self,
xa: XarrayDatasetLike,
n_steps: int = 1
Expand Down Expand Up @@ -188,8 +198,8 @@ def cycle(self,
start_time: float | np.datetime64,
obs_vector: XarrayDatasetLike,
n_cycles: int,
obs_error_sd: float | ArrayLike | None = None,
analysis_window: float = 0.2,
obs_error_sd: float | ArrayLike,
analysis_window: float,
analysis_time_in_window: float | None = None,
return_forecast: bool = False
) -> XarrayDatasetLike:
Expand All @@ -201,6 +211,10 @@ def cycle(self,
obs_vector: Observations vector.
n_cycles: Number of analysis cycles to run, each of length
analysis_window.
obs_error_sd: Estimate observation error standard deviation,
used for calculating observation covariance matrix (R).
If float, all observations will have same estimated error.
If ArrayLike, must be of size system_dim.
analysis_window: Time window from which to gather
observations for DA Cycle.
analysis_time_in_window: Where within analysis_window
Expand All @@ -217,8 +231,21 @@ def cycle(self,
self._observed_vars = obs_vector['variable'].values
self._data_vars = list(input_state.data_vars)

if obs_error_sd is None:
obs_error_sd = obs_vector.error_sd
# NOTE: Consider removing this. It may cause problems if the obs_vector
# error_sd is provided as a array of size obs_dim.
# if obs_error_sd is None:
# obs_error_sd = obs_vector.error_sd
# Check if obs_error_sd is array
if jnp.isscalar(obs_error_sd):
self._scalar_obs_error = True
elif len(obs_error_sd) == self.system_dim:
obs_error_sd = jnp.array(obs_error_sd)
self._scalar_obs_error = False
else:
raise ValueError((
'obs_error_sd must be either scalar or array with length'
'system_dim. Currently is: {}'.format(obs_error_sd)
))

self.analysis_window = analysis_window

Expand All @@ -241,7 +268,6 @@ def cycle(self,
start_time,
analysis_window,
n_cycles)


if self.steps_per_window is None:
self.steps_per_window = round(analysis_window/self.delta_t) + 1
Expand Down
23 changes: 9 additions & 14 deletions dabench/dacycler/_etkf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ class ETKF(dacycler.DACycler):
system_dim: System dimension.
delta_t: The timestep of the model (assumed uniform)
model_obj: Forecast model object.
B: Initial / static background error covariance. Shape:
(system_dim, system_dim). If not provided, will be calculated
automatically.
R: Observation error covariance matrix. Shape
(obs_dim, obs_dim). If not provided, will be calculated
automatically.
Expand All @@ -45,7 +42,6 @@ def __init__(self,
system_dim: int,
delta_t: float,
model_obj: Model,
B: ArrayLike | None = None,
R: ArrayLike | None = None,
H: ArrayLike | None = None,
h: Callable | None = None,
Expand All @@ -59,7 +55,7 @@ def __init__(self,
super().__init__(system_dim=system_dim,
delta_t=delta_t,
model_obj=model_obj,
B=B, R=R, H=H, h=h)
R=R, H=H, h=h)

def _step_forecast(self,
Xa: XarrayDatasetLike,
Expand Down Expand Up @@ -170,8 +166,7 @@ def _cycle_obsop(self,
obs_loc_mask: ArrayLike,
H: ArrayLike | None = None,
h: Callable | None = None,
R: ArrayLike | None = None,
B: ArrayLike | None = None
R: ArrayLike | None = None
) -> XarrayDatasetLike:
if H is None and h is None:
if self.H is None:
Expand All @@ -183,14 +178,14 @@ def _cycle_obsop(self,
H = self.H
if R is None:
if self.R is None:
R = self._calc_default_R(obs_values, self.obs_error_sd)
if self._scalar_obs_error:
R = self._calc_default_R(obs_values, self.obs_error_sd)
else:
R = self._calc_default_R(
obs_values,
self.obs_error_sd[obs_loc_indices.flatten()])
else:
R = self.R
if B is None:
if self.B is None:
B = self._calc_default_B()
else:
B = self.B

Xb = Xb_ds.to_stacked_array('system',['ensemble']).data.T
n_sys, n_ens = Xb.shape
Expand All @@ -210,4 +205,4 @@ def _cycle_obsop(self,
R=R,
rho=self.multiplicative_inflation)

return Xb_ds.assign(x=(['ensemble','i'], Xa.T))
return self._rebuild_dataset(Xb_ds, Xa.T)
156 changes: 156 additions & 0 deletions dabench/dacycler/_hybrid_gain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""Class for Hybrid Gain (ETKF + 3DVar) Data Assimilation"""

import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy import linalg
import xarray as xr
import xarray_jax as xj
from typing import Callable

from dabench import dacycler
from dabench.model import Model


# For typing
ArrayLike = np.ndarray | jax.Array
XarrayDatasetLike = xr.Dataset | xj.XjDataset

class HybridGain(dacycler.DACycler):
"""HybridGain DA, combining ETKF with 3DVar

Args:
system_dim: System dimension.
delta_t: The timestep of the model (assumed uniform)
model_obj: Forecast model object.
B: Initial / static background error covariance. Shape:
(system_dim, system_dim). If not provided, will be calculated
automatically.
R: Observation error covariance matrix. Shape
(obs_dim, obs_dim). If not provided, will be calculated
automatically.
H: Observation operator with shape: (obs_dim, system_dim).
If not provided will be calculated automatically.
h: Optional observation operator as function. More flexible
(allows for more complex observation operator). Default is None.
alpha: Weight for 3DVar DA analysis. If 0.0, runs pure ETKF. If 1.0,
runs pure 3DVar. Default is 0.2.
ensemble_dim: Number of ensemble instances for ETKF. Default is
4. Higher ensemble_dim increases accuracy but has performance cost.
multiplicative_inflation: Scaling factor by which to multiply ensemble
deviation. Default is 1.0 (no inflation).
"""
_in_4d: bool = False
_uses_ensemble: bool = True

def __init__(self,
system_dim: int,
delta_t: float,
model_obj: Model,
B: ArrayLike | None = None,
R: ArrayLike | None = None,
H: ArrayLike | None = None,
h: Callable | None = None,
alpha: float = 0.2,
ensemble_dim: int = 4,
multiplicative_inflation: float = 1.0
):

self.ensemble_dim = ensemble_dim
self.multiplicative_inflation = multiplicative_inflation
self.alpha = alpha

# Create ETKF DA Cycler
self._etkf_da = dacycler.ETKF(
system_dim=system_dim,
delta_t=delta_t,
model_obj=model_obj,
R=R,
H=H,
h=h,
ensemble_dim=ensemble_dim,
multiplicative_inflation=multiplicative_inflation
)
# Create 3D-Var DA Cycler
self._var3d_da = dacycler.Var3D(
system_dim=system_dim,
delta_t=delta_t,
model_obj=model_obj,
R=R,
H=H,
h=h,
B=B
)

super().__init__(system_dim=system_dim,
delta_t=delta_t,
model_obj=model_obj,
B=B, R=R, H=H, h=h)

def _step_forecast(self,
Xa: XarrayDatasetLike,
n_steps: int = 1
) -> XarrayDatasetLike:
"""Ensemble method needs a slightly different _step_forecast method"""
return self._etkf_da._step_forecast(Xa, n_steps)

def _cycle_obsop(self,
Xb_ds: XarrayDatasetLike,
obs_values: ArrayLike,
obs_loc_indices: ArrayLike,
obs_time_mask: ArrayLike,
obs_loc_mask: ArrayLike,
H: ArrayLike | None = None,
h: Callable | None = None,
R: ArrayLike | None = None,
B: ArrayLike | None = None
) -> XarrayDatasetLike:
if H is None and h is None:
if self.H is None:
if self.h is None:
H = self._calc_default_H(obs_values, obs_loc_indices)
else:
h = self.h
else:
H = self.H
if R is None:
if self.R is None:
R = self._calc_default_R(obs_values, self.obs_error_sd)
else:
R = self.R
if B is None:
if self.B is None:
B = self._calc_default_B()
else:
B = self.B

Xb = Xb_ds.to_stacked_array('system',['ensemble']).data.T
n_sys, n_ens = Xb.shape
assert n_ens == self.ensemble_dim, (
'cycle:: model_forecast must have dimension {}x{}').format(
self.ensemble_dim, self.system_dim)

# Apply obs masks to H
H = jnp.where(obs_time_mask.flatten(), H.T, 0).T
H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T

# Compute ETKF analysis
Xa_etkf = self._etkf_da._compute_analysis(Xb=Xb,
Y=obs_values,
H=H,
h=h,
R=R,
rho=self.multiplicative_inflation)

# Compute Var3D Analysis
xa_var3d = self._var3d_da._compute_analysis(xb=jnp.mean(Xb, axis=1).flatten(),
y=obs_values.flatten(),
H=H,
B=B,
Rinv=jnp.linalg.inv(R))

xa_etkf_mean = jnp.mean(Xa_etkf, axis=1)
xa_final = self.alpha*xa_var3d + (1-self.alpha)*xa_etkf_mean
Xa_final = Xa_etkf.T - (xa_etkf_mean - xa_final)

return Xb_ds.assign(x=(['ensemble','i'], Xa_final))
51 changes: 37 additions & 14 deletions dabench/dacycler/_var3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,27 @@ def __init__(self,
model_obj=model_obj,
B=B, R=R, H=H, h=h)


def _compute_analysis(self,
xb,
y,
B,
H,
Rinv
):
# 'preconditioning with B'
xdim = xb.size
I = jnp.identity(xdim)
BHt = jnp.dot(B, H.T)
BHtRinv = jnp.dot(BHt, Rinv)
A = I + jnp.dot(BHtRinv, H)
b1 = xb + jnp.dot(BHtRinv, y)

# Use minimization algorithm to minimize cost function:
xa, ierr = jscipy.sparse.linalg.cg(A, b1, x0=xb.astype(float), tol=1e-05,
maxiter=1000)
return xa

def _cycle_obsop(self,
xb_ds: XarrayDatasetLike,
obs_values: ArrayLike,
Expand All @@ -72,7 +93,12 @@ def _cycle_obsop(self,
H = self.H
if R is None:
if self.R is None:
R = self._calc_default_R(obs_values, self.obs_error_sd)
if self._scalar_obs_error:
R = self._calc_default_R(obs_values, self.obs_error_sd)
else:
R = self._calc_default_R(
obs_values,
self.obs_error_sd[obs_loc_indices.flatten()])
else:
R = self.R
if B is None:
Expand All @@ -85,22 +111,19 @@ def _cycle_obsop(self,
y = obs_values.flatten()

# Apply masks to H
H = jnp.where(obs_time_mask.flatten(), H.T, 0).T
H = jnp.where(jnp.tile(obs_time_mask.flatten(), obs_loc_mask.shape[0]), H.T, 0).T
H = jnp.where(obs_loc_mask.flatten(), H.T, 0).T

# Set parameters
xdim = xb.size # Size or get one of the shape params?
Rinv = jnp.linalg.inv(R)

# 'preconditioning with B'
I = jnp.identity(xdim)
BHt = jnp.dot(B, H.T)
BHtRinv = jnp.dot(BHt, Rinv)
A = I + jnp.dot(BHtRinv, H)
b1 = xb + jnp.dot(BHtRinv, y)

# Use minimization algorithm to minimize cost function:
xa, ierr = jscipy.sparse.linalg.cg(A, b1, x0=xb, tol=1e-05,
maxiter=1000)
xa = self._compute_analysis(
xb,
y,
B,
H,
Rinv,
)

return xb_ds.assign(x=(xb_ds.dims, xa.T))
# Reshape
return self._rebuild_dataset(xb_ds, xa)
Loading