Skip to content

Commit 6920654

Browse files
Merge pull request #83 from JuliaDiffEq/size
allow arbitrary parameter sizing
2 parents bc0501d + 3f95871 commit 6920654

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/Flux/layers.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function diffeq_adjoint(p,prob,args...;u0=prob.u0,kwargs...)
5454
adapt(T, solve(_prob,args...;kwargs...))
5555
end
5656

57-
diffeq_adjoint(p::TrackedVector,prob,args...;u0=prob.u0,kwargs...) =
57+
diffeq_adjoint(p::TrackedArray,prob,args...;u0=prob.u0,kwargs...) =
5858
Tracker.track(diffeq_adjoint, p, u0, prob, args...; kwargs...)
5959

6060
@grad function diffeq_adjoint(p,u0,prob,args...;backsolve=true,
@@ -95,6 +95,7 @@ diffeq_adjoint(p::TrackedVector,prob,args...;u0=prob.u0,kwargs...) =
9595
du0, dp = adjoint_sensitivities_u0(sol,args...,df,ts;
9696
sensealg=sensealg,
9797
kwargs...)
98-
(dp', reshape(du0,size(u0)), ntuple(_->nothing, 1+length(args))...)
98+
99+
(reshape(dp,size(p)), reshape(du0,size(u0)), ntuple(_->nothing, 1+length(args))...)
99100
end
100101
end

0 commit comments

Comments
 (0)