113 changes: 107 additions & 6 deletions ext/LinearSolveMooncakeExt.jl
@@ -1,18 +1,19 @@
module LinearSolveMooncakeExt

using Mooncake
using Mooncake: @from_chainrules, MinimalCtx, ReverseMode, NoRData, increment!!
using Mooncake: @from_chainrules, MinimalCtx, ReverseMode, NoRData, increment!!, @is_primitive, primal, zero_fcodual, CoDual, rdata, fdata
using LinearSolve: LinearSolve, SciMLLinearSolveAlgorithm, init, solve!, LinearProblem,
LinearCache, AbstractKrylovSubspaceMethod, DefaultLinearSolver,
defaultalg_adjoint_eval, solve
LinearCache, AbstractKrylovSubspaceMethod, DefaultLinearSolver, LinearSolveAdjoint,
defaultalg_adjoint_eval, solve, LUFactorization
using LinearSolve.LinearAlgebra
using LazyArrays: @~, BroadcastArray
using SciMLBase

@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve), LinearProblem, Nothing} true ReverseMode
@from_chainrules MinimalCtx Tuple{typeof(SciMLBase.solve),LinearProblem,Nothing} true ReverseMode
@from_chainrules MinimalCtx Tuple{
typeof(SciMLBase.solve), LinearProblem, SciMLLinearSolveAlgorithm} true ReverseMode
typeof(SciMLBase.solve),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode
@from_chainrules MinimalCtx Tuple{
Type{<:LinearProblem}, AbstractMatrix, AbstractVector, SciMLBase.NullParameters} true ReverseMode
Type{<:LinearProblem},AbstractMatrix,AbstractVector,SciMLBase.NullParameters} true ReverseMode

function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearProblem)
f.data.A .+= t.A
@@ -29,4 +30,104 @@ function Mooncake.to_cr_tangent(x::Mooncake.PossiblyUninitTangent{T}) where {T}
end
end

function Mooncake.increment_and_get_rdata!(f, r::NoRData, t::LinearCache)
f.fields.A .+= t.A
f.fields.b .+= t.b
f.fields.u .+= t.u

return NoRData()
end

# rrules for LinearCache
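# These forward the ChainRules `rrule`s for `init` (added in src/adjoint.jl) to Mooncake.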
@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,SciMLLinearSolveAlgorithm} true ReverseMode
@from_chainrules MinimalCtx Tuple{typeof(init),LinearProblem,Nothing} true ReverseMode

# rrules for solve!
# NOTE - Avoid Mooncake.prepare_gradient_cache; only use Mooncake.prepare_pullback_cache (and therefore Mooncake.value_and_pullback!!).
# Calling Mooncake.prepare_gradient_cache on functions that use solve! triggers the "Adjoint case currently not handled" error from the rrules below.
# This is because Mooncake.prepare_gradient_cache resets stacks + state by running the reverse pass once with a zero gradient.
# However, if one already has a valid cache, Mooncake.value_and_gradient!! can be used directly.
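#
# Sketch of the supported call pattern (the function `loss` and arrays `A`, `b` are illustrative,
# not part of this extension); this mirrors the tests in test/nopre/mooncake.jl:
#
#   rule = Mooncake.build_rrule(loss, A, b)
#   val, grads = Mooncake.value_and_pullback!!(rule, 1.0, loss, A, b)
#
# or, with a prepared cache,
#
#   cache = Mooncake.prepare_pullback_cache(loss, A, b)
#   val, grads = Mooncake.value_and_pullback!!(cache, 1.0, loss, A, b)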

@is_primitive MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,SciMLLinearSolveAlgorithm,Vararg}
@is_primitive MinimalCtx Tuple{typeof(SciMLBase.solve!),LinearCache,Nothing,Vararg}
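# `@is_primitive` stops Mooncake from deriving a rule through solve!'s internals; the hand-written
# `rrule!!` methods below are used instead.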

function Mooncake.rrule!!(sig::CoDual{typeof(SciMLBase.solve!)}, _cache::CoDual{<:LinearSolve.LinearCache}, _alg::CoDual{Nothing}, args::Vararg{Any,N}; kwargs...) where {N}
cache = primal(_cache)
assump = LinearSolve.OperatorAssumptions()
alg = zero_fcodual(LinearSolve.defaultalg(cache.A, cache.b, assump))
return Mooncake.rrule!!(sig, _cache, alg, args...; kwargs...)
end

function Mooncake.rrule!!(::CoDual{typeof(SciMLBase.solve!)}, _cache::CoDual{<:LinearSolve.LinearCache}, _alg::CoDual{<:SciMLLinearSolveAlgorithm}, args::Vararg{Any,N}; alias_A=LinearSolve.default_alias_A(
primal(_alg), primal(_cache).A, primal(_cache).b), kwargs...) where {N}

cache = primal(_cache)
alg = primal(_alg)
_args = map(primal, args)

(; A, b, sensealg) = cache
A_orig = copy(A)
b_orig = copy(b)

@assert sensealg isa LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."

# The logic for caching `A` for the reverse pass mirrors the ChainRules rrule for SciMLBase.solve.
if sensealg.linsolve === missing
if !(alg isa LinearSolve.AbstractFactorization || alg isa LinearSolve.AbstractKrylovSubspaceMethod ||
alg isa LinearSolve.DefaultLinearSolver)
A_ = alias_A ? deepcopy(A) : A
end
else
A_ = deepcopy(A)
end
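# In the remaining branches `A_` is not needed: factorization-backed and default algorithms reuse
# `cache.cacheval` / `defaultalg_adjoint_eval` in the pullback below, and Krylov methods re-solve
# with `adjoint(cache.A)` directly.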

sol = zero_fcodual(solve!(cache))
cache.A = A_orig
cache.b = b_orig

function solve!_adjoint(::NoRData)
∂∅ = NoRData()
cachenew = init(LinearProblem(cache.A, cache.b), cache.alg, _args...; kwargs...)
new_sol = solve!(cachenew)
∂u = sol.dx.data.u
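# Reverse-pass relations implemented below: for u = A \ b with cotangent ∂u, solve the adjoint
# system Aᵀ λ = ∂u, then ∂A = -λ * uᵀ and ∂b = λ.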

if sensealg.linsolve === missing
λ = if cache.cacheval isa Factorization
cache.cacheval' \ ∂u
elseif cache.cacheval isa Tuple && cache.cacheval[1] isa Factorization
first(cache.cacheval)' \ ∂u
elseif alg isa AbstractKrylovSubspaceMethod
invprob = LinearProblem(adjoint(cache.A), ∂u)
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u

Member:
since `alg` wasn't defined before my commit this clearly wasn't tested 😅 and we need to make sure this branch works.

@AstitvaAggarwal (Member Author) Nov 12, 2025:
Sorry, actually I was just testing where some mutations occur internally and had fixed the pullback to just LU for debugging. I had tested it with a generic alg locally as well and all tests passed; I just forgot to switch the LU back to the user-chosen alg.
Thanks again for looking out for this.

elseif alg isa DefaultLinearSolver
LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
else
invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
end
else
invprob = LinearProblem(adjoint(A_), ∂u) # We cached `A`
λ = solve(
invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
end

tu = adjoint(new_sol.u)
∂A = BroadcastArray(@~ .-(λ .* tu))
∂b = λ

if (iszero(∂b) || iszero(∂A)) && !iszero(tu)
error("Adjoint case currently not handled. Instead of using `solve!(cache); s1 = copy(cache.u) ...`, use `sol = solve!(cache); s1 = copy(sol.u)`.")
end

fdata(_cache.dx).fields.A .+= ∂A
fdata(_cache.dx).fields.b .+= ∂b
fdata(_cache.dx).fields.u .+= ∂u

# rdata for the cache is a struct whose fields are all NoRData
return (∂∅, rdata(_cache.dx), ∂∅, ntuple(_ -> ∂∅, length(args))...)
end

return sol, solve!_adjoint
end

end
16 changes: 16 additions & 0 deletions src/adjoint.jl
@@ -99,3 +99,19 @@ function CRC.rrule(::Type{<:LinearProblem}, A, b, p; kwargs...)
∇prob(∂prob) = (NoTangent(), ∂prob.A, ∂prob.b, ∂prob.p)
return prob, ∇prob
end
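
# For alg === nothing, resolve the default algorithm first and then reuse the generic init rrule below.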

function CRC.rrule(T::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Nothing, args...; kwargs...)
assump = OperatorAssumptions(issquare(prob.A))
alg = defaultalg(prob.A, prob.b, assump)
CRC.rrule(T, prob, alg, args...; kwargs...)
end

function CRC.rrule(::typeof(LinearSolve.init), prob::LinearSolve.LinearProblem, alg::Union{LinearSolve.SciMLLinearSolveAlgorithm,Nothing}, args...; kwargs...)
init_res = LinearSolve.init(prob, alg, args...; kwargs...)
function init_adjoint(∂init)
∂prob = LinearProblem(∂init.A, ∂init.b, NoTangent())
return (NoTangent(), ∂prob, NoTangent(), ntuple(_ -> NoTangent(), length(args))...)
end

return init_res, init_adjoint
end
145 changes: 143 additions & 2 deletions test/nopre/mooncake.jl
@@ -11,9 +11,7 @@ b1 = rand(n);

function f(A, b1; alg = LUFactorization())
prob = LinearProblem(A, b1)

sol1 = solve(prob, alg)

s1 = sol1.u
norm(s1)
end
@@ -153,3 +151,146 @@ for alg in (
@test results[1] ≈ fA(A)
@test mooncake_gradient ≈ fd_jac rtol = 1e-5
end

# Tests for solve! and init rrules.
n = 4
A = rand(n, n);
b1 = rand(n);
b2 = rand(n);

function f_(A, b1, b2; alg=LUFactorization())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

f_primal = f_(copy(A), copy(b1), copy(b2))
rule = Mooncake.build_rrule(f_, copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_pullback!!(
rule, 1.0,
f_, copy(A), copy(b1), copy(b2)
)

dA2 = ForwardDiff.gradient(x -> f_(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
db12 = ForwardDiff.gradient(x -> f_(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
db22 = ForwardDiff.gradient(x -> f_(eltype(x).(A), eltype(x).(b1), x), copy(b2))

@test value == f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f_2(A, b1, b2; alg=RFLUFactorization())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

f_primal = f_2(copy(A), copy(b1), copy(b2))
rule = Mooncake.build_rrule(f_2, copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_pullback!!(
rule, 1.0,
f_2, copy(A), copy(b1), copy(b2)
)

dA2 = ForwardDiff.gradient(x -> f_2(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
db12 = ForwardDiff.gradient(x -> f_2(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
db22 = ForwardDiff.gradient(x -> f_2(eltype(x).(A), eltype(x).(b1), x), copy(b2))

@test value == f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f_3(A, b1, b2; alg=KrylovJL_GMRES())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
s1 = copy(solve!(cache).u)
cache.b = b2
s2 = solve!(cache).u
norm(s1 + s2)
end

f_primal = f_3(copy(A), copy(b1), copy(b2))
rule = Mooncake.build_rrule(f_3, copy(A), copy(b1), copy(b2))
value, gradient = Mooncake.value_and_pullback!!(
rule, 1.0,
f_3, copy(A), copy(b1), copy(b2)
)

dA2 = ForwardDiff.gradient(x -> f_3(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
db12 = ForwardDiff.gradient(x -> f_3(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
db22 = ForwardDiff.gradient(x -> f_3(eltype(x).(A), eltype(x).(b1), x), copy(b2))

@test value == f_primal
@test gradient[2] ≈ dA2
@test gradient[3] ≈ db12
@test gradient[4] ≈ db22

function f_4(A, b1, b2; alg=LUFactorization())
prob = LinearProblem(A, b1)
cache = init(prob, alg)
solve!(cache)
s1 = copy(cache.u)
cache.b = b2
solve!(cache)
s2 = copy(cache.u)
norm(s1 + s2)
end

A = rand(n, n);
b1 = rand(n);
b2 = rand(n);
f_primal = f_4(copy(A), copy(b1), copy(b2))

rule = Mooncake.build_rrule(f_4, copy(A), copy(b1), copy(b2))
@test_throws "Adjoint case currently not handled" Mooncake.value_and_pullback!!(
rule, 1.0,
f_4, copy(A), copy(b1), copy(b2)
)

# dA2 = ForwardDiff.gradient(x -> f_4(x, eltype(x).(b1), eltype(x).(b2)), copy(A))
# db12 = ForwardDiff.gradient(x -> f_4(eltype(x).(A), x, eltype(x).(b2)), copy(b1))
# db22 = ForwardDiff.gradient(x -> f_4(eltype(x).(A), eltype(x).(b1), x), copy(b2))

# @test value == f_primal
# @test grad[2] ≈ dA2
# @test grad[3] ≈ db12
# @test grad[4] ≈ db22

A = rand(n, n);
b1 = rand(n);

function fnice(A, b, alg)
prob = LinearProblem(A, b)
sol1 = solve(prob, alg)
return sum(sol1.u)
end

@testset for alg in (
LUFactorization(),
RFLUFactorization(),
KrylovJL_GMRES()
)
# for B
fb_closure = b -> fnice(A, b, alg)
fd_jac_b = FiniteDiff.finite_difference_jacobian(fb_closure, b1) |> vec

val, en_jac = Mooncake.value_and_gradient!!(
prepare_gradient_cache(fnice, copy(A), copy(b1), alg),
fnice, copy(A), copy(b1), alg
)
@test en_jac[3] ≈ fd_jac_b rtol = 1e-5

# For A
fA_closure = A -> fnice(A, b1, alg)
fd_jac_A = FiniteDiff.finite_difference_jacobian(fA_closure, A) |> vec
A_grad = en_jac[2] |> vec
@test A_grad ≈ fd_jac_A rtol = 1e-5
end