Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ tests = [
"numpy<2", # Required currently due to lack of Numpy v2 compatible pyssht release
"pytest",
"pytest-cov",
"pytest-rerunfailures",
"so3",
"pyssht",
]
Expand Down
86 changes: 79 additions & 7 deletions s2fft/utils/signal_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from s2fft.sampling import s2_samples as samples
from s2fft.sampling import so3_samples as wigner_samples

TYPE_CHECKING = False
if TYPE_CHECKING:
import jax


def complex_normal(
rng: np.random.Generator,
Expand Down Expand Up @@ -74,6 +78,7 @@ def generate_flm(
spin: int = 0,
reality: bool = False,
using_torch: bool = False,
size: tuple[int, ...] | int | None = None,
) -> np.ndarray | torch.Tensor:
r"""
Generate a 2D set of random harmonic coefficients.
Expand All @@ -94,29 +99,39 @@ def generate_flm(

using_torch (bool, optional): Desired frontend functionality. Defaults to False.

size (tuple[int, ...] | int | None, optional): Shape of realisations.

Returns:
np.ndarray: Random set of spherical harmonic coefficients.

"""
flm = np.zeros(samples.flm_shape(L), dtype=np.complex128)
# always turn size into a tuple of int
if size is None:
size = ()
elif isinstance(size, int):
size = (size,)
elif not (isinstance(size, tuple) and all(isinstance(_, int) for _ in size)):
raise TypeError("size must be int or tuple of int")

flm = np.zeros((*size, *samples.flm_shape(L)), dtype=np.complex128)
min_el = max(L_lower, abs(spin))
# m = 0 coefficients are always real
flm[min_el:L, L - 1] = rng.standard_normal(L - min_el)
flm[..., min_el:L, L - 1] = rng.standard_normal((*size, L - min_el))
# Construct arrays of m and el indices for entries in flm corresponding to complex-
# valued coefficients (m > 0)
el_indices, m_indices = complex_el_and_m_indices(L, min_el)
len_indices = len(m_indices)
rand_size = (*size, len(m_indices))
# Generate independent complex coefficients for positive m
flm[el_indices, L - 1 + m_indices] = complex_normal(rng, len_indices, var=2)
flm[..., el_indices, L - 1 + m_indices] = complex_normal(rng, rand_size, var=2)
if reality:
# Real-valued signal so set complex coefficients for negative m using conjugate
# symmetry such that flm[el, L - 1 - m] = (-1)**m * flm[el, L - 1 + m].conj
flm[el_indices, L - 1 - m_indices] = (-1) ** m_indices * (
flm[el_indices, L - 1 + m_indices].conj()
flm[..., el_indices, L - 1 - m_indices] = (-1) ** m_indices * (
flm[..., el_indices, L - 1 + m_indices].conj()
)
else:
# Non-real signal so generate independent complex coefficients for negative m
flm[el_indices, L - 1 - m_indices] = complex_normal(rng, len_indices, var=2)
flm[..., el_indices, L - 1 - m_indices] = complex_normal(rng, rand_size, var=2)
return torch.from_numpy(flm) if using_torch else flm


Expand Down Expand Up @@ -199,3 +214,60 @@ def generate_flmn(
rng, len_indices, var=2
)
return torch.from_numpy(flmn) if using_torch else flmn


def generate_flm_from_spectra(
rng: np.random.Generator,
spectra: np.ndarray | jax.Array,
) -> np.ndarray | jax.Array:
r"""
Generate a stack of random harmonic coefficients from power spectra.

The input power spectra must be a stack of shape *(K, K, L)* where
*K* is the number of fields to be sampled, and *L* is the harmonic
band-limit.

Args:
rng (Generator): Random number generator.

spectra (np.ndarray | jax.Array): Stack of angular power spectra.

Returns:
np.ndarray | jax.Array: A stack of random spherical harmonic
coefficients with the given power spectra.

"""
# get the Array API namespace from spectra
xp = spectra.__array_namespace__()

# check input
if spectra.ndim != 3 or spectra.shape[0] != spectra.shape[1]:
raise ValueError("shape of spectra must be (K, K, L)")

# K is the number of fields, L is the band limit
*_, K, L = spectra.shape

# permute shape (K, K, L) -> (L, K, K)
spectra = xp.permute_dims(spectra, (2, 0, 1))

# SVD for matrix square root
# not using cholesky() here because matrix may be semi-definite
# divides spectra by 2 for correct amplitude
u, s, vh = xp.linalg.svd(spectra / 2, full_matrices=False)

# compute the matrix square root for sampling
a = u @ (xp.sqrt(s[..., None]) * vh)

# permute shape (L, K, K) -> (K, K, L)
a = xp.permute_dims(a, (1, 2, 0))

# sample the random coefficients
# always use reality=True, this could be real fields or E/B modes
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# always use reality=True, this could be real fields or E/B modes
# always use reality=True, this could be real fields or E/B modes

'E/B modes' should ideally have some clarification here - from a bit of searching it looks like this may be a common term in cosmology settings, but as someone without a cosmology background I don't know what it means and I'd guess I'm representative of this in other potential users and developers from outside cosmology!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense! I added a more generic comment now.

In any case, my idea was to only tackle flm with the reality condition here. Sampling flm without that condition is trickier due to the multiple spectra that are required for each field. Even then, the easiest way is to use this "E/B" decomposition into two fields with the reality symmetry, and later assemble the complex fields from there. Ideally, there would be a generic helper function for this conversion (complex <-> E/B), at which point the functionality could be added to generate_flm_from_spectra() as well.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ntessore , I agree that sounds like a good approach, i.e. considering E/B spectra, since that is the typical convention in cosmology. Adding some more documentation on E/B fields as @matt-graham suggests is also a good idea.

# shape of flm is (K, L, M)
flm = generate_flm(rng, L, reality=True, size=K)

# compute the matrix multiplication by hand, because we have a mix of
# contraction (dim=K) and broadcasting (dim=L)
flm = (a[..., None] * flm).sum(axis=-3)

return flm
118 changes: 118 additions & 0 deletions tests/test_signal_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import jax.numpy as jnp
import numpy as np
import pytest
from jax.test_util import check_grads

import s2fft
import s2fft.sampling as smp
Expand Down Expand Up @@ -55,6 +57,14 @@ def check_flm_conjugate_symmetry(flm, L, min_el):
assert flm[el, L - 1 - m] == (-1) ** m * flm[el, L - 1 + m].conj()


def check_flm_unequal(flm1, flm2, L, min_el):
"""assert that two passed flm are elementwise unequal"""
for el in range(L):
for m in range(L):
if not (el < min_el or m > el):
assert flm1[el, L - 1 + m] != flm2[el, L - 1 - m]


@pytest.mark.parametrize("L", L_values_to_test)
@pytest.mark.parametrize("L_lower", L_lower_to_test)
@pytest.mark.parametrize("spin", spin_to_test)
Expand All @@ -76,6 +86,24 @@ def test_generate_flm(rng, L, L_lower, spin, reality):
assert np.allclose(f_complex.real, f_real)


@pytest.mark.parametrize("L", L_values_to_test)
@pytest.mark.parametrize("L_lower", L_lower_to_test)
@pytest.mark.parametrize("spin", spin_to_test)
@pytest.mark.parametrize("reality", reality_values_to_test)
def test_generate_flm_size(rng, L, L_lower, spin, reality):
if reality and spin != 0:
pytest.skip("Reality only valid for scalar fields (spin=0).")

flm = gen.generate_flm(rng, L, L_lower, spin, reality, size=2)
assert flm.shape == (2,) + smp.s2_samples.flm_shape(L)
check_flm_zeros(flm[0], L, max(L_lower, abs(spin)))
check_flm_zeros(flm[1], L, max(L_lower, abs(spin)))
check_flm_unequal(flm[0], flm[1], L, max(L_lower, abs(spin)))

flm = gen.generate_flm(rng, L, L_lower, spin, reality, size=(3, 4))
assert flm.shape == (3, 4) + smp.s2_samples.flm_shape(L)


def check_flmn_zeros(flmn, L, N, L_lower):
for n in range(-N + 1, N):
min_el = max(L_lower, abs(n))
Expand Down Expand Up @@ -117,3 +145,93 @@ def test_generate_flmn(rng, L, N, L_lower, reality):
assert np.allclose(f_complex.imag, 0)
f_real = s2fft.wigner.inverse(flmn, L, N, reality=True, L_lower=L_lower)
assert np.allclose(f_complex.real, f_real)


def gaussian_covariance(spectra):
"""Gaussian covariance for a stack of spectra.

If the shape of *spectra* is *(K, K, L)*, the shape of the
covariance is *(L, C, C)*, where ``C = K * (K + 1) // 2``
is the number of independent spectra.

"""
_, K, L = spectra.shape
row, col = np.tril_indices(K)
cov = np.zeros((L, row.size, col.size))
ell = np.arange(L)
for i, j in np.ndindex(row.size, col.size):
cov[:, i, j] = (
spectra[row[i], row[j]] * spectra[col[i], col[j]]
+ spectra[row[i], col[j]] * spectra[col[i], row[j]]
) / (2 * ell + 1)
return cov


@pytest.mark.flaky
@pytest.mark.parametrize("L", L_values_to_test)
@pytest.mark.parametrize("xp", [np, jnp])
def test_generate_flm_from_spectra(rng, L, xp):
# number of fields to generate
K = 4

# correlation matrix for fields, applied to all ell
corr = xp.asarray(
[
[1.0, 0.1, -0.1, 0.1],
[0.1, 1.0, 0.1, -0.1],
[-0.1, 0.1, 1.0, 0.1],
[0.1, -0.1, 0.1, 1.0],
],
)

ell = xp.arange(L)

# auto-spectra are power laws
powers = xp.arange(1, K + 1)
auto = 1 / (2 * ell + 1) ** powers[:, None]

# compute the spectra from auto and corr
spectra = xp.sqrt(auto[:, None, :] * auto[None, :, :]) * corr[:, :, None]
assert spectra.shape == (K, K, L)

# generate random flm from spectra
flm = s2fft.utils.signal_generator.generate_flm_from_spectra(rng, spectra)
assert flm.shape == (K, L, 2 * L - 1)

# compute the realised spectra
re, im = flm.real, flm.imag
result = (
re[None, :, :, :] * re[:, None, :, :] + im[None, :, :, :] * im[:, None, :, :]
)
result = result.sum(axis=-1) / (2 * ell + 1)

# compute covariance of sampled spectra
cov = gaussian_covariance(spectra)

# data vector, remove duplicate entries, and put L dim first
x = result - spectra
x = x[np.tril_indices(K)]
x = x.T

# compute chi2/n of realised spectra
y = xp.linalg.solve(cov, x[..., None])[..., 0]
n = x.size
chi2_n = (x * y).sum() / n

# make sure chi2/n is as expected
sigma = np.sqrt(2 / n)
assert np.fabs(chi2_n - 1.0) < 3 * sigma


@pytest.mark.parametrize("L", L_values_to_test)
def test_generate_flm_from_spectra_grads(L):
# fixed set of power spectra
ell = jnp.arange(L)
cl = 1 / (2 * ell + 1)
spectra = cl.reshape(1, 1, L)

def func(x):
rng = np.random.default_rng(42)
return s2fft.utils.signal_generator.generate_flm_from_spectra(rng, x)

check_grads(func, (spectra,), 1)
Loading