Skip to content

Commit af2eaac

Browse files
Merge remote-tracking branch 'origin/master'
2 parents c499b75 + f174868 commit af2eaac

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

docs/src/examples/divergence.md

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ parameters. This is shown in the loss function:
1616
```julia
1717
function loss(p)
1818
tmp_prob = remake(prob, p=p)
19-
tmp_sol = Array(solve(tmp_prob,Tsit5(),saveat=0.1))
20-
if size(tmp_sol) == size(dataset)
21-
return sum(abs2,tmp_sol - dataset)
19+
tmp_sol = solve(tmp_prob,Tsit5(),saveat=0.1)
20+
if tmp_sol.retcode == :Success
21+
return sum(abs2,Array(tmp_sol) - dataset)
2222
else
2323
return Inf
2424
end
@@ -55,9 +55,9 @@ scatter!(sol.t,dataset')
5555

5656
function loss(p)
5757
tmp_prob = remake(prob, p=p)
58-
tmp_sol = Array(solve(tmp_prob,Tsit5(),saveat=0.1))
59-
if size(tmp_sol) == size(dataset)
60-
return sum(abs2,tmp_sol - dataset)
58+
tmp_sol = solve(tmp_prob,Tsit5(),saveat=0.1)
59+
if tmp_sol.retcode == :Success
60+
return sum(abs2,Array(tmp_sol) - dataset)
6161
else
6262
return Inf
6363
end
@@ -66,5 +66,26 @@ end
6666
using DiffEqFlux
6767

6868
pinit = [1.2,0.8,2.5,0.8]
69-
res = DiffEqFlux.sciml_train(loss,pinit,BFGS())
69+
res = DiffEqFlux.sciml_train(loss,pinit,ADAM(), maxiters = 1000)
70+
71+
# res = DiffEqFlux.sciml_train(loss,pinit,BFGS(), maxiters = 1000) ### errors!
72+
73+
#try Newton method of optimization
74+
res = DiffEqFlux.sciml_train(loss,pinit,Newton(), GalacticOptim.AutoForwardDiff())
7075
```
76+
77+
You might notice that `AutoZygote` (default) fails for the above `sciml_train` call with Optim's optimizers which happens because
78+
of Zygote's behaviour for zero gradients in which case it returns `nothing`. To avoid such issue you can just use a different version of the same check which compares the size of the obtained
79+
solution and the data we have, shown below, which is easier to AD.
80+
81+
```julia
82+
function loss(p)
83+
tmp_prob = remake(prob, p=p)
84+
tmp_sol = solve(tmp_prob,Tsit5(),saveat=0.1)
85+
if size(tmp_sol) == size(dataset)
86+
return sum(abs2,Array(tmp_sol) .- dataset)
87+
else
88+
return Inf
89+
end
90+
end
91+
```

0 commit comments

Comments
 (0)