diff --git a/HISTORY.md b/HISTORY.md index 40a6fc803..90864508b 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -61,6 +61,10 @@ The only flag other than `"del"` that `Metadata` ever used was `"trans"`. Thus t The `resume_from=chn` keyword argument to `sample` has been removed; please use `initial_state=DynamicPPL.loadstate(chn)` instead. `loadstate` is exported from DynamicPPL. +### Change of default keytype of `pointwise_logdensities` + +The functions `pointwise_prior_logdensities`, `pointwise_logdensities`, and `pointwise_loglikelihoods` return dictionaries for which the keys are model variables, and the key type is either `VarName` or `String`. This release changes the default from `String` to `VarName`. + **Other changes** ### `predict(model, chain; include_all)` diff --git a/src/pointwise_logdensities.jl b/src/pointwise_logdensities.jl index 61834ab62..47ca62530 100644 --- a/src/pointwise_logdensities.jl +++ b/src/pointwise_logdensities.jl @@ -116,13 +116,13 @@ end ::Val{whichlogprob}=Val(:both), ) -Runs `model` on each sample in `chain` returning a `OrderedDict{String, Matrix{Float64}}` -with keys corresponding to symbols of the variables, and values being matrices -of shape `(num_chains, num_samples)`. +Runs `model` on each sample in `chain` returning a `OrderedDict{VarName, Matrix{Float64}}` +with keys being model variables and values being matrices of shape +`(num_chains, num_samples)`. `keytype` specifies what the type of the keys used in the returned `OrderedDict` are. -Currently, only `String` and `VarName` are supported. `whichlogprob` specifies -which log-probabilities to compute. It can be `:both`, `:prior`, or +Currently, only `String` and `VarName` are supported, with `VarName` being the default. +`whichlogprob` specifies which log-probabilities to compute. It can be `:both`, `:prior`, or `:likelihood`. See also: [`pointwise_loglikelihoods`](@ref), [`pointwise_loglikelihoods`](@ref). @@ -177,13 +177,13 @@ julia> # A chain with 3 iterations. ); julia> pointwise_logdensities(model, chain) -OrderedDict{String, Matrix{Float64}} with 6 entries: - "s" => [-0.802775; -1.38222; -2.09861;;] - "m" => [-8.91894; -7.51551; -7.46824;;] - "xs[1]" => [-5.41894; -5.26551; -5.63491;;] - "xs[2]" => [-2.91894; -3.51551; -4.13491;;] - "xs[3]" => [-1.41894; -2.26551; -2.96824;;] - "y" => [-0.918939; -1.51551; -2.13491;;] +OrderedDict{VarName, Matrix{Float64}} with 6 entries: + s => [-0.802775; -1.38222; -2.09861;;] + m => [-8.91894; -7.51551; -7.46824;;] + xs[1] => [-5.41894; -5.26551; -5.63491;;] + xs[2] => [-2.91894; -3.51551; -4.13491;;] + xs[3] => [-1.41894; -2.26551; -2.96824;;] + y => [-0.918939; -1.51551; -2.13491;;] julia> pointwise_logdensities(model, chain, String) OrderedDict{String, Matrix{Float64}} with 6 entries: @@ -225,7 +225,7 @@ julia> ℓ = pointwise_logdensities(m, VarInfo(m)); first.((ℓ[@varname(x[1])], ``` """ function pointwise_logdensities( - model::Model, chain, ::Type{KeyType}=String, ::Val{whichlogprob}=Val(:both) + model::Model, chain, ::Type{KeyType}=VarName, ::Val{whichlogprob}=Val(:both) ) where {KeyType,whichlogprob} # Get the data by executing the model once vi = VarInfo(model) @@ -283,7 +283,7 @@ including the likelihood terms. See also: [`pointwise_logdensities`](@ref), [`pointwise_prior_logdensities`](@ref). """ -function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} +function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=VarName) where {T} return pointwise_logdensities(model, chain, T, Val(:likelihood)) end @@ -301,7 +301,7 @@ including the prior terms. See also: [`pointwise_logdensities`](@ref), [`pointwise_loglikelihoods`](@ref). """ function pointwise_prior_logdensities( - model::Model, chain, keytype::Type{T}=String + model::Model, chain, keytype::Type{T}=VarName ) where {T} return pointwise_logdensities(model, chain, T, Val(:prior)) end diff --git a/test/pointwise_logdensities.jl b/test/pointwise_logdensities.jl index cfb222b66..aac59380c 100644 --- a/test/pointwise_logdensities.jl +++ b/test/pointwise_logdensities.jl @@ -60,11 +60,11 @@ end loglikelihoods_pointwise = pointwise_loglikelihoods(model, chain) # Check that they contain the correct variables. - @test all(string(vn) in keys(logjoints_pointwise) for vn in vns) - @test all(string(vn) in keys(logpriors_pointwise) for vn in vns) - @test !any(Base.Fix2(startswith, "x"), keys(logpriors_pointwise)) - @test !any(string(vn) in keys(loglikelihoods_pointwise) for vn in vns) - @test all(Base.Fix2(startswith, "x"), keys(loglikelihoods_pointwise)) + @test all(vn in keys(logjoints_pointwise) for vn in vns) + @test all(vn in keys(logpriors_pointwise) for vn in vns) + @test !any(Base.Fix1(subsumes, @varname(x)), keys(logpriors_pointwise)) + @test !any(vn in keys(loglikelihoods_pointwise) for vn in vns) + @test all(Base.Fix1(subsumes, @varname(x)), keys(loglikelihoods_pointwise)) # Get the sum of the logjoints for each of the iterations. logjoints = [