diff --git a/src/utils.jl b/src/utils.jl index b55a2f715..452c15a7c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -777,11 +777,15 @@ end """ float_type_with_fallback(T::DataType) -Return `float(T)` if possible; otherwise return `float(Real)`. +Return `T` if it is a non-integer Real, `float(T)` for integer types, and `float(Real)` otherwise. """ float_type_with_fallback(::Type) = float(Real) float_type_with_fallback(::Type{Union{}}) = float(Real) -float_type_with_fallback(::Type{T}) where {T<:Real} = float(T) +float_type_with_fallback(::Type{Real}) = float(Real) +float_type_with_fallback(::Type{T}) where {T<:Integer} = float(T) +# This final case is responsible not only for plain old Float64, but also things like +# ForwardDiff.Dual, etc. See https://github.com/TuringLang/DynamicPPL.jl/pull/1088. +float_type_with_fallback(::Type{T}) where {T<:Real} = T """ infer_nested_eltype(x::Type) diff --git a/test/model.jl b/test/model.jl index 6ba3bca2a..a6dc41ca6 100644 --- a/test/model.jl +++ b/test/model.jl @@ -370,29 +370,27 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end - if VERSION >= v"1.8" - @testset "Type stability of models" begin - models_to_test = [ - DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) - ] - @testset "$(model.f)" for model in models_to_test - vns = DynamicPPL.TestUtils.varnames(model) - example_values = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = filter( - is_type_stable_varinfo, - DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), - ) - @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo)) - true - end - - varinfo_linked = DynamicPPL.link(varinfo, model) - @test begin - @inferred(DynamicPPL.evaluate!!(model, varinfo_linked)) - true - end + @testset "Type stability of models" begin + models_to_test = [ + DynamicPPL.TestUtils.DEMO_MODELS..., DynamicPPL.TestUtils.demo_lkjchol(2) + ] + @testset "$(model.f)" for model in models_to_test + vns = DynamicPPL.TestUtils.varnames(model) + example_values = DynamicPPL.TestUtils.rand_prior_true(model) + varinfos = filter( + is_type_stable_varinfo, + DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns), + ) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + @test begin + @inferred(DynamicPPL.evaluate!!(model, varinfo)) + true + end + + varinfo_linked = DynamicPPL.link(varinfo, model) + @test begin + @inferred(DynamicPPL.evaluate!!(model, varinfo_linked)) + true end end end