@@ -104,26 +104,28 @@ neural_ode_rd(dudt,u0,tspan,Tsit5(),saveat=0.1)
104
104
# Adjoint
105
105
106
106
@testset " adjoint mode trackedu0" begin
107
- Tracker. zero_grad! (dudt[1 ]. W. grad)
108
- Tracker. zero_grad! (downsample. W. grad)
109
- m1 = Chain (downsample, u0-> neural_ode (dudt,u0,tspan,Tsit5 (),save_everystep= false ,save_start= false )) # broke
110
- Flux. back! (sum (m1 (x0)))
111
- @test ! iszero (Tracker. grad (dudt[1 ]. W))
112
- @test ! iszero (Tracker. grad (downsample. W))
113
-
114
- Tracker. zero_grad! (dudt[1 ]. W. grad)
115
- Tracker. zero_grad! (downsample. W. grad)
116
- m2 = Chain (downsample, u0-> neural_ode (dudt,u0,tspan,Tsit5 (),saveat= 0.0 : 0.1 : 10.0 ))
117
- Flux. back! (sum (m2 (x0)))
118
- @test ! iszero (Tracker. grad (dudt[1 ]. W))
119
- @test ! iszero (Tracker. grad (downsample. W))
120
-
121
- Tracker. zero_grad! (dudt[1 ]. W. grad)
122
- Tracker. zero_grad! (downsample. W. grad)
123
- m3 = Chain (downsample, u0-> neural_ode (dudt,u0,tspan,Tsit5 (),saveat= 0.1 ))
124
- @test_broken Flux. back! (sum (m3 (x0)))
125
- # @test ! iszero(Tracker.grad(dudt[1].W))
126
- # @test ! iszero(Tracker.grad(downsample.W))
107
+ @test_broken
108
+ Tracker. zero_grad! (dudt[1 ]. W. grad)
109
+ Tracker. zero_grad! (downsample. W. grad)
110
+ m1 = Chain (downsample, u0-> neural_ode (dudt,u0,tspan,Tsit5 (),save_everystep= false ,save_start= false )) # broke
111
+ Flux. back! (sum (m1 (x0)))
112
+ @test ! iszero (Tracker. grad (dudt[1 ]. W))
113
+ @test ! iszero (Tracker. grad (downsample. W))
114
+
115
+ Tracker. zero_grad! (dudt[1 ]. W. grad)
116
+ Tracker. zero_grad! (downsample. W. grad)
117
+ m2 = Chain (downsample, u0-> neural_ode (dudt,u0,tspan,Tsit5 (),saveat= 0.0 : 0.1 : 10.0 ))
118
+ Flux. back! (sum (m2 (x0)))
119
+ @test ! iszero (Tracker. grad (dudt[1 ]. W))
120
+ @test ! iszero (Tracker. grad (downsample. W))
121
+
122
+ Tracker. zero_grad! (dudt[1 ]. W. grad)
123
+ Tracker. zero_grad! (downsample. W. grad)
124
+ m3 = Chain (downsample, u0-> neural_ode (dudt,u0,tspan,Tsit5 (),saveat= 0.1 ))
125
+ @test_broken Flux. back! (sum (m3 (x0)))
126
+ # @test ! iszero(Tracker.grad(dudt[1].W))
127
+ # @test ! iszero(Tracker.grad(downsample.W))
128
+ end
127
129
end ;
128
130
129
131
#= # RD =#
0 commit comments