-
-
Notifications
You must be signed in to change notification settings - Fork 75
Make length of Partials known at compile time in ForwardDiff overloads #727
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
JET test it? |
using LinearSolve
using ForwardDiff
using Test
using JET
function h(p)
(A=[p[1] p[2]+1 p[2]^3;
3*p[1] p[1]+5 p[2]*p[1]-4;
p[2]^2 9*p[1] p[2]],
b=[p[1] + 1, p[2] * 2, p[1]^2])
end
A, b = h([ForwardDiff.Dual(5.0, 1.0, 0.0), ForwardDiff.Dual(5.0, 0.0, 1.0)])
prob = LinearProblem(A, b)
cache = init(prob)
solve!(cache)
ext = Base.get_extension(LinearSolve, :LinearSolveForwardDiffExt)
@test_opt ext.linearsolve_dual_solution([1.0, 1.0, 1.0], [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], cache) This passes now. Without this change there is indeed a runtime dispatch. There are also a bunch of runtime dispatches for an overall |
Okay, but you still didn't add that to the JET tests. |
What is the state of this? I am currently hanging back with old versions of LinearSolve. It would be really nice to get back the old performance in SciML/OrdinaryDiffEq.jl#2837. |
010269e
to
985a4cc
Compare
Test fails when rebased. |
985a4cc
to
010269e
Compare
f0fa5a1
to
9d37f1a
Compare
Tests pass now |
Thanks! |
Checklist
contributor guidelines, in particular the SciML Style Guide and
COLPRAC.
Additional context
Add any other context about the problem here.