
Commit f4f9047

Merge branch 'size'
2 parents: 6920654 + 3921516

File tree

3 files changed: 41 additions, 1 deletion


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "DiffEqFlux"
 uuid = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
 authors = ["Chris Rackauckas <[email protected]>"]
-version = "0.6.1"
+version = "0.7.0"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

test/runtests.jl

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ if GROUP == "All"
 @safetestset "Layers Tests" begin include("layers.jl") end
 @safetestset "Layers SDE" begin include("layers_sde.jl") end
 @safetestset "Layers DDE" begin include("layers_dde.jl") end
+@safetestset "Size Handling in Adjoint Tests" begin include("size_handling_adjoint.jl") end
 @safetestset "odenet" begin include("odenet.jl") end
 @safetestset "Neural DE Tests" begin include("neural_de.jl") end
 @safetestset "Partial Neural Tests" begin include("partial_neural.jl") end

test/size_handling_adjoint.jl

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+using OrdinaryDiffEq, Test
+
+p = [1.5 1.0;3.0 1.0]
+function lotka_volterra(du,u,p,t)
+  du[1] = p[1,1]*u[1] - p[1,2]*u[1]*u[2]
+  du[2] = -p[2,1]*u[2] + p[2,2]*u[1]*u[2]
+end
+
+u0 = [1.0,1.0]
+tspan = (0.0,10.0)
+
+prob = ODEProblem(lotka_volterra,u0,tspan,p)
+sol = solve(prob,Tsit5())
+using Plots
+plot(sol)
+
+using Flux, DiffEqFlux
+p = param([2.2 1.0;2.0 0.4]) # Tweaked Initial Parameter Array
+params = Flux.Params([p])
+
+function predict_adjoint() # Our 1-layer neural network
+  diffeq_adjoint(p,prob,Tsit5(),saveat=0.0:0.1:10.0)
+end
+
+loss_adjoint() = sum(abs2,x-1 for x in predict_adjoint())
+
+data = Iterators.repeated((), 100)
+opt = ADAM(0.1)
+cb = function () # callback function to observe training
+  display(loss_adjoint())
+end
+
+predict_adjoint()
+
+# Display the ODE with the initial parameter values.
+cb()
+Flux.train!(loss_adjoint, params, data, opt, cb = cb)
+
+@test loss_adjoint() < 1
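The new test file trains the parameter matrix against the Lotka-Volterra dynamics for 100 iterations with Flux.train! and asserts the loss drops below 1. The "size handling" being exercised is that diffeq_adjoint returns its prediction as a plain array sized by states and saved time points, rather than an ODESolution object, which is what lets loss_adjoint iterate over it elementwise. A sketch of the shape one would expect, with the caveat that the exact layout here is an assumption on my part and not something the commit itself asserts:

# Sketch only (not part of the commit): with a 2-state system saved at
# 0.0:0.1:10.0, the adjoint prediction should be states × timepoints,
# i.e. 2 × 101.
pred = predict_adjoint()
@test size(pred) == (length(u0), length(0.0:0.1:10.0))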

0 commit comments