@@ -14,7 +14,7 @@ Before getting to the explanation, here's some code to start with. We will follo
```@example hamiltonian_cp
using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
-     ComponentArrays, Optimization, OptimizationOptimisers, IterTools
+     ComponentArrays, Optimization, OptimizationOptimisers, MLUtils

t = range(0.0f0, 1.0f0; length = 1024)
π_32 = Float32(π)
@@ -23,37 +23,33 @@ p_t = reshape(cos.(2π_32 * t), 1, :)
dqdt = 2π_32 .* p_t
dpdt = -2π_32 .* q_t

- data = vcat(q_t, p_t)
- target = vcat(dqdt, dpdt)
+ data = cat(q_t, p_t; dims = 1)
+ target = cat(dqdt, dpdt; dims = 1)
B = 256
- NEPOCHS = 100
- dataloader = ncycle(
-     ((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
-         selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
-     for i in 1:(size(data, 2) ÷ B)),
-     NEPOCHS)
-
- hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote())
+ NEPOCHS = 500
+ dataloader = DataLoader((data, target); batchsize = B)
+
+ hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote())
ps, st = Lux.setup(Xoshiro(0), hnn)
ps_c = ps |> ComponentArray

opt = OptimizationOptimisers.Adam(0.01f0)

- function loss_function(ps, data, target)
+ function loss_function(ps, databatch)
+     data, target = databatch
    pred, st_ = hnn(data, ps, st)
-     return mean(abs2, pred .- target), pred
+     return mean(abs2, pred .- target)
end

- function callback(ps, loss, pred)
+ function callback(state, loss)
    println("[Hamiltonian NN] Loss: ", loss)
    return false
end

- opt_func = OptimizationFunction((ps, _, data, target) -> loss_function(ps, data, target),
-     Optimization.AutoForwardDiff())
- opt_prob = OptimizationProblem(opt_func, ps_c)
+ opt_func = OptimizationFunction(loss_function, Optimization.AutoForwardDiff())
+ opt_prob = OptimizationProblem(opt_func, ps_c, dataloader)

- res = Optimization.solve(opt_prob, opt, dataloader; callback)
+ res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS)

ps_trained = res.u
@@ -75,7 +71,7 @@ The HNN predicts the gradients ``(\dot q, \dot p)`` given ``(q, p)``. Hence, we
```@example hamiltonian
using Lux, DiffEqFlux, OrdinaryDiffEq, Statistics, Plots, Zygote, ForwardDiff, Random,
-     ComponentArrays, Optimization, OptimizationOptimisers, IterTools
+     ComponentArrays, Optimization, OptimizationOptimisers, MLUtils

t = range(0.0f0, 1.0f0; length = 1024)
π_32 = Float32(π)
@@ -87,40 +83,37 @@ dpdt = -2π_32 .* q_t
data = cat(q_t, p_t; dims = 1)
target = cat(dqdt, dpdt; dims = 1)
B = 256
- NEPOCHS = 100
- dataloader = ncycle(
-     ((selectdim(data, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))),
-         selectdim(target, 2, ((i - 1) * B + 1):(min(i * B, size(data, 2)))))
-     for i in 1:(size(data, 2) ÷ B)),
-     NEPOCHS)
+ NEPOCHS = 500
+ dataloader = DataLoader((data, target); batchsize = B)
```
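For context on the `DataLoader` that replaces the manual `ncycle`/`selectdim` batching: it iterates `(data, target)` tuples batched along the last dimension. A minimal standalone sketch of that behavior (the dummy arrays here are illustrative, not part of the tutorial's pipeline):

```julia
using MLUtils

# Dummy arrays with the same 2×1024 shape as `data` and `target` above.
x = rand(Float32, 2, 1024)
y = rand(Float32, 2, 1024)

dl = DataLoader((x, y); batchsize = 256)
length(dl)          # 4 mini-batches per pass over the data
xb, yb = first(dl)  # each element is a tuple of batches
size(xb)            # (2, 256), batched along the last dimension
```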
### Training the HamiltonianNN
We parameterize the Hamiltonian with a small Multilayer Perceptron. HNNs are trained by optimizing the gradients of the neural network. Zygote currently doesn't support differentiating through itself, so we will be using ForwardDiff in the training loop to compute the gradients of the HNN layer for optimization.
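To make the inner gradient concrete: the network learns a scalar Hamiltonian ``H(q, p)``, and the predicted time derivatives are its symplectic gradient, ``\dot{q} = \partial H / \partial p`` and ``\dot{p} = -\partial H / \partial q``. Below is a minimal sketch of that computation with a hand-written stand-in for the learned ``H``; it is an illustration of the idea, not the library's implementation:

```julia
using ForwardDiff

# Stand-in for a learned Hamiltonian: H(q, p) = (q^2 + p^2) / 2.
H(x) = sum(abs2, x) / 2

# Symplectic gradient: for x = [q, p], return [∂H/∂p, -∂H/∂q] = [q̇, ṗ].
function symplectic_gradient(H, x)
    g = ForwardDiff.gradient(H, x)
    return [g[2], -g[1]]
end

symplectic_gradient(H, Float32[1.0, 2.0])  # returns Float32[2.0, -1.0]
```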
```@example hamiltonian
- hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (64, 1)); autodiff = AutoZygote())
+ hnn = Layers.HamiltonianNN{true}(Layers.MLP(2, (1028, 1)); autodiff = AutoZygote())
ps, st = Lux.setup(Xoshiro(0), hnn)
ps_c = ps |> ComponentArray
+ hnn_stateful = StatefulLuxLayer{true}(hnn, ps_c, st)

- opt = OptimizationOptimisers.Adam(0.01f0)
+ opt = OptimizationOptimisers.Adam(0.005f0)

- function loss_function(ps, data, target)
-     pred, st_ = hnn(data, ps, st)
-     return mean(abs2, pred .- target), pred
+ function loss_function(ps, databatch)
+     (data, target) = databatch
+     pred = hnn_stateful(data, ps)
+     return mean(abs2, pred .- target)
end

- function callback(ps, loss, pred)
+ function callback(state, loss)
    println("[Hamiltonian NN] Loss: ", loss)
    return false
end

- opt_func = OptimizationFunction(
-     (ps, _, data, target) -> loss_function(ps, data, target), Optimization.AutoZygote())
- opt_prob = OptimizationProblem(opt_func, ps_c)
+ opt_func = OptimizationFunction(loss_function, Optimization.AutoZygote())
+ opt_prob = OptimizationProblem(opt_func, ps_c, dataloader)

- res = solve(opt_prob, opt, dataloader; callback)
+ res = Optimization.solve(opt_prob, opt; callback, epochs = NEPOCHS)

ps_trained = res.u
```