diff --git a/python/ffsim/variational/orbital_optimization.py b/python/ffsim/variational/orbital_optimization.py index 7439c1d94..b30a0e0f6 100644 --- a/python/ffsim/variational/orbital_optimization.py +++ b/python/ffsim/variational/orbital_optimization.py @@ -26,39 +26,43 @@ jax.config.update("jax_enable_x64", True) -def _orbital_rotation_from_parameters_jax( +def _generator_from_parameters( params: np.ndarray, norb: int, real: bool = False -) -> jax.Array: - """Construct an orbital rotation from parameters. - - Converts a real-valued parameter vector to an orbital rotation. The parameter vector - contains non-redundant real and imaginary parts of the elements of the matrix - logarithm of the orbital rotation matrix. - - Args: - params: The real-valued parameters. - norb: The number of spatial orbitals, which gives the width and height of the - orbital rotation matrix. - real: Whether the parameter vector describes a real-valued orbital rotation. - - Returns: - The orbital rotation. - """ - generator = jnp.zeros((norb, norb), dtype=float if real else complex) +) -> np.ndarray: + generator = np.zeros((norb, norb), dtype=float if real else complex) n_triu = norb * (norb - 1) // 2 if not real: # imaginary part - rows, cols = jnp.triu_indices(norb) + rows, cols = np.triu_indices(norb) vals = 1j * params[n_triu:] - generator = generator.at[rows, cols].set(vals) - generator = generator.at[cols, rows].set(vals) + generator[rows, cols] = vals + generator[cols, rows] = vals # real part vals = params[:n_triu] - rows, cols = jnp.triu_indices(norb, k=1) - generator = generator.at[rows, cols].add(vals) - # the subtract method is only available in JAX starting with Python 3.10 - generator = generator.at[cols, rows].add(-vals) - return jax.scipy.linalg.expm(generator) + rows, cols = np.triu_indices(norb, k=1) + generator[rows, cols] += vals + generator[cols, rows] -= vals + return generator + + +def _generator_to_parameters(mat: np.ndarray, real: bool = False) -> np.ndarray: + if real and np.iscomplexobj(mat): + raise TypeError( + "real was set to True, but the orbital rotation has a complex data type. " + "Try passing an orbital rotation with a real-valued data type, or else " + "set real=False." + ) + norb, _ = mat.shape + triu_indices = np.triu_indices(norb, k=1) + n_triu = norb * (norb - 1) // 2 + params = np.zeros(n_triu if real else norb**2) + # real part + params[:n_triu] = mat[triu_indices].real + # imaginary part + if not real: + triu_indices = np.triu_indices(norb) + params[n_triu:] = mat[triu_indices].imag + return params def optimize_orbitals( @@ -136,10 +140,9 @@ def optimize_orbitals( one_body_tensor = jnp.array(hamiltonian.one_body_tensor) two_body_tensor = jnp.array(hamiltonian.two_body_tensor) - def fun(x: np.ndarray): - orbital_rotation = _orbital_rotation_from_parameters_jax( - x, norb=norb, real=real - ) + def fun_jax(generator: np.ndarray) -> float: + generator = 0.5 * (generator - generator.T.conj()) + orbital_rotation = jax.scipy.linalg.expm(generator) one_rdm_rotated = contract( "ab,Aa,Bb->AB", one_rdm, @@ -162,13 +165,20 @@ def fun(x: np.ndarray): + 0.5 * contract("abcd,abcd->", two_body_tensor, two_rdm_rotated) ).real - value_and_grad = jax.value_and_grad(fun) + value_and_grad = jax.value_and_grad(fun_jax) + + def fun(x: np.ndarray) -> tuple[float, np.ndarray]: + generator = _generator_from_parameters(x, norb=norb, real=real) + val, grad = value_and_grad(generator) + # The complex conjugate of the gradient is actually returned + # See https://github.com/jax-ml/jax/issues/4891 + return val, _generator_to_parameters(grad.conj(), real=real) if initial_orbital_rotation is None: initial_orbital_rotation = np.eye(norb) result = scipy.optimize.minimize( - value_and_grad, + fun, orbital_rotation_to_parameters(initial_orbital_rotation, real=real), method=method, jac=True, diff --git a/tests/python/variational/orbital_optimization_test.py b/tests/python/variational/orbital_optimization_test.py index f5b966d73..34d5627ef 100644 --- a/tests/python/variational/orbital_optimization_test.py +++ b/tests/python/variational/orbital_optimization_test.py @@ -64,9 +64,9 @@ def test_optimize_orbitals(): np.testing.assert_allclose(energy, -108.58613393502857) assert np.isrealobj(orbital_rotation) assert len(result.x) == norb * (norb - 1) // 2 - assert result.nit <= 7 - assert result.nfev <= 9 - assert result.njev <= 9 + assert result.nit <= 8 + assert result.nfev <= 10 + assert result.njev <= 10 # Optimize orbitals with complex rotations orbital_rotation, result = ffsim.optimize_orbitals( @@ -86,6 +86,6 @@ def test_optimize_orbitals(): np.testing.assert_allclose(energy, -108.58613393502857) assert np.iscomplexobj(orbital_rotation) assert len(result.x) == norb**2 - assert result.nit <= 8 - assert result.nfev <= 11 - assert result.njev <= 11 + assert result.nit <= 12 + assert result.nfev <= 14 + assert result.njev <= 14