Skip to content

Commit 76f8431

Browse files
Merge pull request #306 from SciML/u/immutablefix
Add immutable ODE Problem for GPU compilation
2 parents 6d1b08b + 5f62f6b commit 76f8431

File tree

8 files changed

+137
-13
lines changed

8 files changed

+137
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ MuladdMacro = "0.2"
5252
Parameters = "0.12"
5353
RecursiveArrayTools = "2"
5454
Requires = "1.0"
55-
SciMLBase = "1.26"
55+
SciMLBase = "1.26, 2"
5656
Setfield = "1"
5757
SimpleDiffEq = "1"
5858
StaticArrays = "1"

docs/src/tutorials/lower_level_api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ prob = ODEProblem{false}(lorenz, u0, tspan, p)
3030
3131
## Building different problems for different parameters
3232
probs = map(1:trajectories) do i
33-
remake(prob, p = (@SVector rand(Float32, 3)) .* p)
33+
DiffEqGPU.make_prob_compatible(remake(prob, p = (@SVector rand(Float32, 3)) .* p))
3434
end
3535
3636
## Move the arrays to the GPU

src/DiffEqGPU.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ include("ensemblegpukernel/tableaus/verner_tableaus.jl")
6666
include("ensemblegpukernel/tableaus/rodas_tableaus.jl")
6767
include("ensemblegpukernel/tableaus/kvaerno_tableaus.jl")
6868

69+
include("ensemblegpukernel/problems/ode_problems.jl")
70+
6971
include("utils.jl")
7072
include("algorithms.jl")
7173
include("solve.jl")

src/ensemblegpuarray/problem_generation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function generate_problem(prob::ODEProblem, u0, p, jac_prototype, colorvec)
1+
function generate_problem(prob::SciMLBase.AbstractODEProblem, u0, p, jac_prototype, colorvec)
22
_f = let f = prob.f.f, kernel = DiffEqBase.isinplace(prob) ? gpu_kernel : gpu_kernel_oop
33
function (du, u, p, t)
44
version = get_backend(u)
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import SciMLBase: @add_kwonly, AbstractODEProblem, AbstractODEFunction,
2+
FunctionWrapperSpecialize, StandardODEProblem, prepare_initial_state, promote_tspan,
3+
warn_paramtype
4+
5+
struct ImmutableODEProblem{uType, tType, isinplace, P, F, K, PT} <:
6+
AbstractODEProblem{uType, tType, isinplace}
7+
"""The ODE is `du = f(u,p,t)` for out-of-place and f(du,u,p,t) for in-place."""
8+
f::F
9+
"""The initial condition is `u(tspan[1]) = u0`."""
10+
u0::uType
11+
"""The solution `u(t)` will be computed for `tspan[1] ≤ t ≤ tspan[2]`."""
12+
tspan::tType
13+
"""Constant parameters to be supplied as the second argument of `f`."""
14+
p::P
15+
"""A callback to be applied to every solver which uses the problem."""
16+
kwargs::K
17+
"""An internal argument for storing traits about the solving process."""
18+
problem_type::PT
19+
@add_kwonly function ImmutableODEProblem{iip}(f::AbstractODEFunction{iip},
20+
u0, tspan, p = NullParameters(),
21+
problem_type = StandardODEProblem();
22+
kwargs...) where {iip}
23+
_u0 = prepare_initial_state(u0)
24+
_tspan = promote_tspan(tspan)
25+
warn_paramtype(p)
26+
new{typeof(_u0), typeof(_tspan),
27+
isinplace(f), typeof(p), typeof(f),
28+
typeof(kwargs),
29+
typeof(problem_type)}(f,
30+
_u0,
31+
_tspan,
32+
p,
33+
kwargs,
34+
problem_type)
35+
end
36+
37+
"""
38+
ImmutableODEProblem{isinplace}(f,u0,tspan,p=NullParameters(),callback=CallbackSet())
39+
40+
Define an ODE problem with the specified function.
41+
`isinplace` optionally sets whether the function is inplace or not.
42+
This is determined automatically, but not inferred.
43+
"""
44+
function ImmutableODEProblem{iip}(f,
45+
u0,
46+
tspan,
47+
p = NullParameters();
48+
kwargs...) where {iip}
49+
_u0 = prepare_initial_state(u0)
50+
_tspan = promote_tspan(tspan)
51+
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
52+
ImmutableODEProblem(_f, _u0, _tspan, p; kwargs...)
53+
end
54+
55+
@add_kwonly function ImmutableODEProblem{iip, recompile}(f, u0, tspan,
56+
p = NullParameters();
57+
kwargs...) where {iip, recompile}
58+
ImmutableODEProblem{iip}(ODEFunction{iip, recompile}(f), u0, tspan, p; kwargs...)
59+
end
60+
61+
function ImmutableODEProblem{iip, FunctionWrapperSpecialize}(f, u0, tspan,
62+
p = NullParameters();
63+
kwargs...) where {iip}
64+
_u0 = prepare_initial_state(u0)
65+
_tspan = promote_tspan(tspan)
66+
if !(f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
67+
if iip
68+
ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_iip(f,
69+
(_u0, _u0, p,
70+
_tspan[1])))
71+
else
72+
ff = ODEFunction{iip, FunctionWrapperSpecialize}(wrapfun_oop(f,
73+
(_u0, p,
74+
_tspan[1])))
75+
end
76+
end
77+
ImmutableODEProblem{iip}(ff, _u0, _tspan, p; kwargs...)
78+
end
79+
end
80+
81+
"""
82+
ImmutableODEProblem(f::ODEFunction,u0,tspan,p=NullParameters(),callback=CallbackSet())
83+
84+
Define an ODE problem from an [`ODEFunction`](@ref).
85+
"""
86+
function ImmutableODEProblem(f::AbstractODEFunction, u0, tspan, args...; kwargs...)
87+
ImmutableODEProblem{isinplace(f)}(f, u0, tspan, args...; kwargs...)
88+
end
89+
90+
function ImmutableODEProblem(f, u0, tspan, p = NullParameters(); kwargs...)
91+
iip = isinplace(f, 4)
92+
_u0 = prepare_initial_state(u0)
93+
_tspan = promote_tspan(tspan)
94+
_f = ODEFunction{iip, DEFAULT_SPECIALIZATION}(f)
95+
ImmutableODEProblem(_f, _u0, _tspan, p; kwargs...)
96+
end
97+
98+
function Base.convert(::Type{ImmutableODEProblem}, prob::T) where {T <: ODEProblem}
99+
ImmutableODEProblem(prob.f,
100+
prob.u0,
101+
prob.tspan,
102+
prob.p,
103+
prob.problem_type;
104+
prob.kwargs...)
105+
end

src/solve.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -124,19 +124,21 @@ function batch_solve(ensembleprob, alg,
124124
ensemblealg::Union{EnsembleArrayAlgorithm, EnsembleKernelAlgorithm}, I,
125125
adaptive;
126126
kwargs...)
127-
if ensembleprob.safetycopy
128-
probs = map(I) do i
129-
ensembleprob.prob_func(deepcopy(ensembleprob.prob), i, 1)
130-
end
131-
else
132-
probs = map(I) do i
133-
ensembleprob.prob_func(ensembleprob.prob, i, 1)
134-
end
135-
end
136127
@assert !isempty(I)
137128
#@assert all(p->p.f === probs[1].f,probs)
138129

139130
if ensemblealg isa EnsembleGPUKernel
131+
if ensembleprob.safetycopy
132+
probs = map(I) do i
133+
make_prob_compatible(ensembleprob.prob_func(deepcopy(ensembleprob.prob),
134+
i,
135+
1))
136+
end
137+
else
138+
probs = map(I) do i
139+
make_prob_compatible(ensembleprob.prob_func(ensembleprob.prob, i, 1))
140+
end
141+
end
140142
# Using inner saveat requires all of them to be of same size,
141143
# because the dimension of CuMatrix is decided by it.
142144
# The columns of it are accessed at each thread.
@@ -192,6 +194,15 @@ function batch_solve(ensembleprob, alg,
192194
error("We don't have solvers implemented for this algorithm yet")
193195
end
194196
else
197+
if ensembleprob.safetycopy
198+
probs = map(I) do i
199+
ensembleprob.prob_func(deepcopy(ensembleprob.prob), i, 1)
200+
end
201+
else
202+
probs = map(I) do i
203+
ensembleprob.prob_func(ensembleprob.prob, i, 1)
204+
end
205+
end
195206
u0 = reduce(hcat, Array(probs[i].u0) for i in 1:length(I))
196207

197208
if !all(Base.Fix2((prob1, prob2) -> isequal(prob1.tspan, prob2.tspan),

src/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,9 @@ function diffeqgpunorm(u::AbstractArray{<:ForwardDiff.Dual}, t)
44
sqrt.(sum(abs2 ForwardDiff.value, u) ./ length(u))
55
end
66
diffeqgpunorm(u::ForwardDiff.Dual, t) = abs(ForwardDiff.value(u))
7+
8+
make_prob_compatible(prob) = prob
9+
10+
function make_prob_compatible(prob::T) where {T <: ODEProblem}
11+
convert(ImmutableODEProblem, prob)
12+
end

test/lower_level_api.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ prob = ODEProblem{false}(func, u0, tspan, p)
4848
## Building different problems for different parameters
4949
batch = 1:trajectories
5050
probs = map(batch) do i
51-
remake(prob, p = (@SVector rand(Float32, 3)) .* p)
51+
DiffEqGPU.make_prob_compatible(remake(prob, p = (@SVector rand(Float32, 3)) .* p))
5252
end
5353

5454
## Move the arrays to the GPU

0 commit comments

Comments
 (0)