@@ -308,11 +308,12 @@ function make_initial_params(
308308 initial_params,
309309)
310310 T = sampler_eltype (spl)
311- if initial_params == nothing
311+ if initial_params === nothing
312312 d = LogDensityProblems. dimension (logdensity)
313- initial_params = randn (rng, d)
313+ return randn (rng, T, d)
314+ else
315+ return T .(initial_params)
314316 end
315- return T .(initial_params)
316317end
317318
318319# ########
@@ -342,10 +343,10 @@ end
342343function make_step_size (
343344 rng:: Random.AbstractRNG ,
344345 integrator:: AbstractIntegrator ,
345- T :: Type ,
346+ :: Type{T} ,
346347 hamiltonian:: Hamiltonian ,
347348 initial_params,
348- )
349+ ) where {T}
349350 if integrator. ϵ > 0
350351 ϵ = integrator. ϵ
351352 else
@@ -358,10 +359,10 @@ end
358359function make_step_size (
359360 rng:: Random.AbstractRNG ,
360361 integrator:: Symbol ,
361- T :: Type ,
362+ :: Type{T} ,
362363 hamiltonian:: Hamiltonian ,
363364 initial_params,
364- )
365+ ) where {T}
365366 ϵ = find_good_stepsize (rng, hamiltonian, initial_params)
366367 @info string (" Found initial step size " , ϵ)
367368 return T (ϵ)
@@ -370,21 +371,33 @@ end
370371make_integrator (spl:: HMCSampler , ϵ:: Real ) = spl. κ. τ. integrator
371372make_integrator (spl:: AbstractHMCSampler , ϵ:: Real ) = make_integrator (spl. integrator, ϵ)
372373make_integrator (i:: AbstractIntegrator , ϵ:: Real ) = i
373- make_integrator (i:: Symbol , ϵ:: Real ) = make_integrator (Val (i), ϵ)
374- make_integrator (@nospecialize (i), :: Real ) = error (" Integrator $i not supported." )
375- make_integrator (i:: Val{:leapfrog} , ϵ:: Real ) = Leapfrog (ϵ)
376- make_integrator (i:: Val{:jitteredleapfrog} , ϵ:: T ) where {T<: Real } =
377- JitteredLeapfrog (ϵ, T (0.1 ϵ))
378- make_integrator (i:: Val{:temperedleapfrog} , ϵ:: T ) where {T<: Real } = TemperedLeapfrog (ϵ, T (1 ))
374+ function make_integrator (i:: Symbol , ϵ:: Real )
375+ float_ϵ = AbstractFloat (ϵ)
376+ if i === :leapfrog
377+ return Leapfrog (float_ϵ)
378+ elseif i === :jitteredleapfrog
379+ return JitteredLeapfrog (float_ϵ, float_ϵ / 10 )
380+ elseif i === :temperedleapfrog
381+ return TemperedLeapfrog (float_ϵ, oneunit (float_ϵ))
382+ else
383+ error (" Integrator $i not supported." )
384+ end
385+ end
379386
380387# ########
381388
382- make_metric (@nospecialize (i), T:: Type , d:: Int ) = error (" Metric $(typeof (i)) not supported." )
383- make_metric (i:: Symbol , T:: Type , d:: Int ) = make_metric (Val (i), T, d)
384- make_metric (i:: AbstractMetric , T:: Type , d:: Int ) = i
385- make_metric (i:: Val{:diagonal} , T:: Type , d:: Int ) = DiagEuclideanMetric (T, d)
386- make_metric (i:: Val{:unit} , T:: Type , d:: Int ) = UnitEuclideanMetric (T, d)
387- make_metric (i:: Val{:dense} , T:: Type , d:: Int ) = DenseEuclideanMetric (T, d)
389+ make_metric (i:: AbstractMetric , :: Type , :: Int ) = i
390+ function make_metric (i:: Symbol , :: Type{T} , d:: Int ) where {T}
391+ if i === :diagonal
392+ return DiagEuclideanMetric (T, d)
393+ elseif i === :unit
394+ return UnitEuclideanMetric (T, d)
395+ elseif i === :dense
396+ return DenseEuclideanMetric (T, d)
397+ else
398+ error (" Metric $i not supported." )
399+ end
400+ end
388401
389402function make_metric (spl:: AbstractHMCSampler , logdensity)
390403 d = LogDensityProblems. dimension (logdensity)
0 commit comments