We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents bc0501d + 3f95871 commit 6920654Copy full SHA for 6920654
src/Flux/layers.jl
@@ -54,7 +54,7 @@ function diffeq_adjoint(p,prob,args...;u0=prob.u0,kwargs...)
54
adapt(T, solve(_prob,args...;kwargs...))
55
end
56
57
-diffeq_adjoint(p::TrackedVector,prob,args...;u0=prob.u0,kwargs...) =
+diffeq_adjoint(p::TrackedArray,prob,args...;u0=prob.u0,kwargs...) =
58
Tracker.track(diffeq_adjoint, p, u0, prob, args...; kwargs...)
59
60
@grad function diffeq_adjoint(p,u0,prob,args...;backsolve=true,
@@ -95,6 +95,7 @@ diffeq_adjoint(p::TrackedVector,prob,args...;u0=prob.u0,kwargs...) =
95
du0, dp = adjoint_sensitivities_u0(sol,args...,df,ts;
96
sensealg=sensealg,
97
kwargs...)
98
- (dp', reshape(du0,size(u0)), ntuple(_->nothing, 1+length(args))...)
+
99
+ (reshape(dp,size(p)), reshape(du0,size(u0)), ntuple(_->nothing, 1+length(args))...)
100
101
0 commit comments