Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 42 additions & 32 deletions python/ffsim/variational/orbital_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,39 +26,43 @@
jax.config.update("jax_enable_x64", True)


def _orbital_rotation_from_parameters_jax(
def _generator_from_parameters(
params: np.ndarray, norb: int, real: bool = False
) -> jax.Array:
"""Construct an orbital rotation from parameters.

Converts a real-valued parameter vector to an orbital rotation. The parameter vector
contains non-redundant real and imaginary parts of the elements of the matrix
logarithm of the orbital rotation matrix.

Args:
params: The real-valued parameters.
norb: The number of spatial orbitals, which gives the width and height of the
orbital rotation matrix.
real: Whether the parameter vector describes a real-valued orbital rotation.

Returns:
The orbital rotation.
"""
generator = jnp.zeros((norb, norb), dtype=float if real else complex)
) -> np.ndarray:
generator = np.zeros((norb, norb), dtype=float if real else complex)
n_triu = norb * (norb - 1) // 2
if not real:
# imaginary part
rows, cols = jnp.triu_indices(norb)
rows, cols = np.triu_indices(norb)
vals = 1j * params[n_triu:]
generator = generator.at[rows, cols].set(vals)
generator = generator.at[cols, rows].set(vals)
generator[rows, cols] = vals
generator[cols, rows] = vals
# real part
vals = params[:n_triu]
rows, cols = jnp.triu_indices(norb, k=1)
generator = generator.at[rows, cols].add(vals)
# the subtract method is only available in JAX starting with Python 3.10
generator = generator.at[cols, rows].add(-vals)
return jax.scipy.linalg.expm(generator)
rows, cols = np.triu_indices(norb, k=1)
generator[rows, cols] += vals
generator[cols, rows] -= vals
return generator


def _generator_to_parameters(mat: np.ndarray, real: bool = False) -> np.ndarray:
if real and np.iscomplexobj(mat):
raise TypeError(
"real was set to True, but the orbital rotation has a complex data type. "
"Try passing an orbital rotation with a real-valued data type, or else "
"set real=False."
)
norb, _ = mat.shape
triu_indices = np.triu_indices(norb, k=1)
n_triu = norb * (norb - 1) // 2
params = np.zeros(n_triu if real else norb**2)
# real part
params[:n_triu] = mat[triu_indices].real
# imaginary part
if not real:
triu_indices = np.triu_indices(norb)
params[n_triu:] = mat[triu_indices].imag
return params


def optimize_orbitals(
Expand Down Expand Up @@ -136,10 +140,9 @@ def optimize_orbitals(
one_body_tensor = jnp.array(hamiltonian.one_body_tensor)
two_body_tensor = jnp.array(hamiltonian.two_body_tensor)

def fun(x: np.ndarray):
orbital_rotation = _orbital_rotation_from_parameters_jax(
x, norb=norb, real=real
)
def fun_jax(generator: np.ndarray) -> float:
generator = 0.5 * (generator - generator.T.conj())
orbital_rotation = jax.scipy.linalg.expm(generator)
one_rdm_rotated = contract(
"ab,Aa,Bb->AB",
one_rdm,
Expand All @@ -162,13 +165,20 @@ def fun(x: np.ndarray):
+ 0.5 * contract("abcd,abcd->", two_body_tensor, two_rdm_rotated)
).real

value_and_grad = jax.value_and_grad(fun)
value_and_grad = jax.value_and_grad(fun_jax)

def fun(x: np.ndarray) -> tuple[float, np.ndarray]:
generator = _generator_from_parameters(x, norb=norb, real=real)
val, grad = value_and_grad(generator)
# The complex conjugate of the gradient is actually returned
# See https://github.com/jax-ml/jax/issues/4891
return val, _generator_to_parameters(grad.conj(), real=real)

if initial_orbital_rotation is None:
initial_orbital_rotation = np.eye(norb)

result = scipy.optimize.minimize(
value_and_grad,
fun,
orbital_rotation_to_parameters(initial_orbital_rotation, real=real),
method=method,
jac=True,
Expand Down
12 changes: 6 additions & 6 deletions tests/python/variational/orbital_optimization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,9 @@ def test_optimize_orbitals():
np.testing.assert_allclose(energy, -108.58613393502857)
assert np.isrealobj(orbital_rotation)
assert len(result.x) == norb * (norb - 1) // 2
assert result.nit <= 7
assert result.nfev <= 9
assert result.njev <= 9
assert result.nit <= 8
assert result.nfev <= 10
assert result.njev <= 10

# Optimize orbitals with complex rotations
orbital_rotation, result = ffsim.optimize_orbitals(
Expand All @@ -86,6 +86,6 @@ def test_optimize_orbitals():
np.testing.assert_allclose(energy, -108.58613393502857)
assert np.iscomplexobj(orbital_rotation)
assert len(result.x) == norb**2
assert result.nit <= 8
assert result.nfev <= 11
assert result.njev <= 11
assert result.nit <= 12
assert result.nfev <= 14
assert result.njev <= 14
Loading