Skip to content

Commit 1a381db

Browse files
committed
gradient is actually complex conjugate of gradient
1 parent 5a394d7 commit 1a381db

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

python/ffsim/variational/orbital_optimization.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ def fun_jax(generator: np.ndarray) -> float:
170170
def fun(x: np.ndarray) -> tuple[float, np.ndarray]:
171171
generator = _generator_from_parameters(x, norb=norb, real=real)
172172
val, grad = value_and_grad(generator)
173-
return val, _generator_to_parameters(grad, real=real)
173+
# The complex conjugate of the gradient is actually returned
174+
# See https://github.com/jax-ml/jax/issues/4891
175+
return val, _generator_to_parameters(grad.conj(), real=real)
174176

175177
if initial_orbital_rotation is None:
176178
initial_orbital_rotation = np.eye(norb)

tests/python/variational/orbital_optimization_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,6 @@ def test_optimize_orbitals():
8686
np.testing.assert_allclose(energy, -108.58613393502857)
8787
assert np.iscomplexobj(orbital_rotation)
8888
assert len(result.x) == norb**2
89-
assert result.nit <= 8
90-
assert result.nfev <= 11
91-
assert result.njev <= 11
89+
assert result.nit <= 12
90+
assert result.nfev <= 14
91+
assert result.njev <= 14

0 commit comments

Comments
 (0)