-
Notifications
You must be signed in to change notification settings - Fork 29
Prior/posterior predictive check plots #319
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Probably 4.16.0 since #310 is a bigger thing and probably won't have too much effect here. |
Backup of Paulina's implementation of PPC Plots: src/plot.jl
# Register plotting entry points for the series types handled in this file.
# `@shorthands` is applied to the symbol-based series types below.
@shorthands meanplot
@shorthands autocorplot
@shorthands mixeddensity
@shorthands pooleddensity
@shorthands traceplot
@shorthands corner
# `@userplot` is applied to the plot types whose recipes read their inputs from `p.args`.
@userplot RidgelinePlot
@userplot ForestPlot
# Prior/posterior predictive check plots.
@shorthands ppcplot
# Lightweight wrapper types: each one tags the data of a series so that the
# matching `@recipe` method can be selected by dispatch.

# Trace of sampled values; `c` is the chain object and `val` the values to draw.
struct _TracePlot
    c
    val
end

# Running mean of the sampled values.
struct _MeanPlot
    c
    val
end

# Density of the sampled values.
struct _DensityPlot
    c
    val
end

# Histogram of the sampled values.
struct _HistogramPlot
    c
    val
end

# Autocorrelation values `val` at the given `lags`.
struct _AutocorPlot
    lags
    val
end

# Predictive check data: observed values, predicted values and mean prediction.
struct _PPCPlot
    y_obs
    y_pred
    ymean_pred
end
# Map each series-type symbol of the old syntax to the wrapper type whose
# recipe draws it. Note that `:pooleddensity` shares `_DensityPlot` with `:density`.
const translationdict = Dict{Symbol,DataType}(
    :traceplot     => _TracePlot,
    :meanplot      => _MeanPlot,
    :density       => _DensityPlot,
    :histogram     => _HistogramPlot,
    :autocorplot   => _AutocorPlot,
    :pooleddensity => _DensityPlot,
    :ppcplot       => _PPCPlot,
)

# Every series type accepted by the recipes: the translated ones plus the two
# that have no wrapper type of their own.
const supportedplots = vcat(collect(keys(translationdict)), [:mixeddensity, :corner])
# A single parameter name is forwarded as a one-element vector of names.
@recipe function f(c::Chains, s::Symbol)
    c, [s]
end
# Recipe for a single series group of a `Chains` object, selected by index `i`.
#
# With `colordim = :chain`, `i` indexes the second dimension of `c.value` and the
# series are labelled per chain; with `colordim = :parameter`, `i` indexes the
# third dimension and the series are labelled with `names(c)`.
# The series type is read from `plotattributes` (default `:traceplot`).
@recipe function f(
    chains::Chains, i::Int;
    colordim = :chain,
    barbounds = (-Inf, Inf),
    maxlag = nothing,
    append_chains = false
)
    st = get(plotattributes, :seriestype, :traceplot)
    pool = append_chains || st == :pooleddensity
    c = pool ? pool_chain(chains) : chains

    if colordim == :parameter
        title --> "Chain $(MCMCChains.chains(c)[i])"
        label --> string.(names(c))
        val = c.value[:, :, i]
    elseif colordim == :chain
        title --> string(names(c)[i])
        label --> ["Chain $x" for x in MCMCChains.chains(c)]
        val = c.value[:, i, :]
    else
        throw(ArgumentError("`colordim` must be one of `:chain` or `:parameter`"))
    end

    # `:mixeddensity` and `:pooleddensity` resolve to a histogram or a density.
    if st in (:mixeddensity, :pooleddensity)
        discrete = indiscretesupport(c, barbounds)
        # NOTE: It might make sense to overlay histograms and density plots
        # when `colordim` is not `:chain`.
        st = (colordim == :chain && discrete[i]) ? :histogram : :density
        seriestype := st
    end

    if st == :autocorplot
        last_lag = maxlag === nothing ? round(Int, 10 * log10(length(range(c)))) : maxlag
        lags = 0:last_lag
        ac_mat = convert(Array, autocor(c; sections = nothing, lags = lags))
        val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :]
        _AutocorPlot(lags, val)
    elseif st ∈ supportedplots
        translationdict[st](c, val)
    else
        # Any other series type is handed over as plain (x, y) data.
        range(c), val
    end
end
# Density series: one vector per column of `p.val`, with `missing` entries dropped.
@recipe function f(p::_DensityPlot)
    n_series = size(p.val, 2)
    xaxis --> "Sample value"
    yaxis --> "Density"
    trim --> true
    [collect(skipmissing(p.val[:, col])) for col in 1:n_series]
end
# Histogram series: one vector per column of `p.val`, with `missing` entries dropped.
@recipe function f(p::_HistogramPlot)
    n_series = size(p.val, 2)
    xaxis --> "Sample value"
    yaxis --> "Frequency"
    bins --> 25
    fillalpha --> 0.7
    trim --> true
    [collect(skipmissing(p.val[:, col])) for col in 1:n_series]
end
# Running-mean series: `cummean(p.val)` drawn as a path against `range(p.c)`.
@recipe function f(p::_MeanPlot)
    iterations = range(p.c)
    running_mean = cummean(p.val)
    seriestype := :path
    xaxis --> "Iteration"
    yaxis --> "Mean"
    iterations, running_mean
end
# Autocorrelation series: the stored values drawn as a path against the lags.
@recipe function f(p::_AutocorPlot)
    lags, values = p.lags, p.val
    seriestype := :path
    xaxis --> "Lag"
    yaxis --> "Autocorrelation"
    lags, values
end
# Trace series: the sampled values drawn as a path against `range(p.c)`.
@recipe function f(p::_TracePlot)
    iterations = range(p.c)
    seriestype := :path
    xaxis --> "Iteration"
    yaxis --> "Sample value"
    iterations, p.val
end
# Resolve parameter names to their positions in `names(chains)` and forward to
# the integer-indexed recipe. Only meaningful with `colordim = :chain`.
@recipe function f(
    chains::Chains,
    parameters::AbstractVector{Symbol};
    colordim = :chain
)
    if colordim != :chain
        error("Symbol names are interpreted as parameter names, only compatible with ",
              "`colordim = :chain`")
    end
    positions = indexin(parameters, names(chains))
    # `indexin` yields `nothing` for every name that is absent.
    if any(pos === nothing for pos in positions)
        error("Parameter not found")
    end
    return chains, Int.(positions)
end
# Top-level recipe for a `Chains` object: lays out one row per selected
# parameter (or chain, depending on `colordim`) and one column per series type,
# or delegates to `Corner` when `:corner` is requested.
#
# When `parameters` is empty, the chain is first restricted to `sections` and
# every variable (or chain) is plotted.
@recipe function f(
    chains::Chains,
    parameters::AbstractVector{<:Integer} = Int[];
    sections = _default_sections(chains),
    width = 500,
    height = 250,
    colordim = :chain,
    append_chains = false
)
    selected = if isempty(parameters)
        Chains(chains, _clean_sections(chains, sections))
    else
        chains
    end
    c = append_chains ? pool_chain(selected) : selected

    # A single series type is normalised to a one-element tuple.
    requested = get(plotattributes, :seriestype, (:traceplot, :mixeddensity))
    ptypes = requested isa Symbol ? (requested,) : requested
    @assert all(ptype -> ptype ∈ supportedplots, ptypes)
    ntypes = length(ptypes)

    nrows, nvars, nchains = size(c)
    if isempty(parameters)
        parameters = colordim == :chain ? (1:nvars) : (1:nchains)
    end
    N = length(parameters)

    if :corner ∈ ptypes
        ntypes > 1 && error(":corner is not compatible with multiple seriestypes")
        Corner(c, names(c)[parameters])
    else
        size --> (ntypes*width, N*height)
        legend --> false
        multiple_plots = N * ntypes > 1
        if multiple_plots
            layout := (N, ntypes)
        end
        # Subplots are numbered row by row: one row per parameter.
        for (row, par) in enumerate(parameters)
            for (col, ptype) in enumerate(ptypes)
                panel = (row - 1) * ntypes + col
                @series begin
                    if multiple_plots
                        subplot := panel
                    end
                    colordim := colordim
                    seriestype := ptype
                    c, par
                end
            end
        end
    end
end
# Input of the corner-plot recipe: the chain object `c` and the selected `parameters`.
struct Corner; c; parameters; end
# Corner plot: the selected parameters of every chain are stacked vertically
# and handed to the `:cornerplot` recipe type.
@recipe function f(corner::Corner)
    pars = corner.parameters
    label --> permutedims(pars)
    compact --> true
    size --> (600, 600)
    per_chain = [Array(corner.c.value[:, pars, i]) for i in chains(corner.c)]
    RecipesBase.recipetype(:cornerplot, vcat(per_chain...))
end
# Compute everything the ridgeline and forest recipes need for the `i`-th
# parameter: the (optionally reordered) parameter names, HPD bounds for each
# level in `hpd_val`, the baseline height `h`, the quantiles `qs`, a kernel
# density estimate, the region to fill, and the mean/median markers.
#
# All arguments are positional. `barbounds`, `show_mean`, `show_median` and
# `show_hpdi` are accepted but not read in this function.
function _compute_plot_data(
    i::Integer,
    chains::Chains,
    par_names::AbstractVector{Symbol},
    hpd_val = [0.05, 0.2],
    q = [0.1, 0.9],
    spacer = 0.4,
    _riser = 0.2,
    barbounds = (-Inf, Inf),
    show_mean = true,
    show_median = true,
    show_qi = false,
    show_hpdi = true,
    fill_q = true,
    fill_hpd = false,
    ordered = false
)
    # Pair columns 1 and 4 of the quantile table, then sort on the column-4 value.
    # NOTE(review): presumably column 1 holds parameter names and column 4 the
    # median — confirm against the quantile table layout.
    qtable = quantile(chains)
    chain_dic = Dict(zip(qtable[:, 1], qtable[:, 4]))
    sorted_chain = sort([(value, key) for (key, value) in chain_dic])
    sorted_par = [sorted_chain[n][2] for n in 1:length(par_names)]
    par = ordered ? sorted_par : par_names

    hpdi = sort(hpd_val)
    chain_sections = MCMCChains.group(chains, Symbol(par[i]))
    chain_vec = vec(chain_sections.value.data)
    intervals = [MCMCChains.hpd(chain_sections, alpha = level).nt for level in hpdi]
    lower_hpd = [interval.lower for interval in intervals]
    upper_hpd = [interval.upper for interval in intervals]

    # Baseline of this parameter's ridge.
    h = _riser + spacer*(i-1)
    qs = quantile(chain_vec, q)
    k_density = kde(chain_vec)

    # `fill_hpd` takes precedence over `fill_q`; with neither, the whole curve is used.
    if fill_hpd || fill_q
        lo, hi = fill_hpd ? (lower_hpd[1][1], upper_hpd[1][1]) : (qs[1], qs[2])
        x_int = filter(x -> lo <= x <= hi, k_density.x)
        val = pdf(k_density, x_int) .+ h
    else
        x_int = k_density.x
        val = k_density.density .+ h
    end

    chain_med = median(chain_vec)
    chain_mean = mean(chain_vec)
    lowest = minimum(k_density.density .+ h)
    q_int = show_qi ? [qs[1], chain_med, qs[2]] : [chain_med]

    return par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med,
        chain_mean, lowest, q_int
end
# Ridgeline plot: one kernel-density "ridge" per parameter, stacked vertically
# `spacer` apart starting at `_riser`.
#
# Inputs are read from `p.args`: the chain object first, the parameter names second.
# Keywords:
#   hpd_val      HPD levels (alpha); the smallest one is drawn as the HPDI bar
#   q            pair of quantile probabilities for the quantile interval
#   spacer       vertical distance between consecutive ridges
#   _riser       height of the first ridge's baseline
#   show_mean, show_median, show_qi, show_hpdi   toggle the respective markers
#   fill_q, fill_hpd   fill the density over the quantile or the HPD interval
#   ordered      forwarded to `_compute_plot_data` to reorder the parameters
@recipe function f(
    p::RidgelinePlot;
    hpd_val = [0.05, 0.2],
    q = [0.1, 0.9],
    spacer = 0.5,
    _riser = 0.2,
    show_mean = true,
    show_median = true,
    show_qi = false,
    show_hpdi = true,
    fill_q = true,
    fill_hpd = false,
    ordered = false
)
    chn = p.args[1]
    par_names = p.args[2]
    n_par = length(par_names)

    for i in 1:n_par
        # `_compute_plot_data` takes `barbounds` positionally between `_riser` and
        # `show_mean`. It has to be passed explicitly: leaving it out shifts every
        # following flag by one slot (e.g. `ordered` would be read as `fill_hpd`).
        par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med,
            chain_mean, ymin, q_int = _compute_plot_data(
                i, chn, par_names, hpd_val, q, spacer, _riser, (-Inf, Inf),
                show_mean, show_median, show_qi, show_hpdi, fill_q, fill_hpd, ordered
            )

        yticks --> (n_par > 1 ?
            (_riser .+ ((1:n_par) .- 1) .* spacer, string.(par)) : :default)
        yaxis --> (n_par > 1 ? "Parameters" : "Density")

        # Baseline of the ridge.
        @series begin
            seriestype := :hline
            label := nothing
            linecolor := "#BBBBBB"
            linewidth --> 1.2
            [h]
        end
        # Filled part of the density.
        @series begin
            seriestype := :path
            label := nothing
            fillrange --> ymin
            fillalpha --> 0.8
            x_int, val
        end
        # Outline of the full density.
        @series begin
            seriestype := :path
            label := nothing
            linecolor --> "#000000"
            k_density.x, k_density.density .+ h
        end
        # Mean marker (hidden through a zero line width).
        @series begin
            seriestype := :path
            label --> (show_mean ? (i == 1 ? "Mean" : nothing) : nothing)
            linecolor --> "dark red"
            linewidth --> (show_mean ? 1.2 : 0)
            [chain_mean, chain_mean], [ymin, ymin + pdf(k_density, chain_mean)]
        end
        # Median marker.
        @series begin
            seriestype := :path
            label --> (show_median ? (i == 1 ? "Median" : nothing) : nothing)
            linecolor --> "#000000"
            linewidth --> (show_median ? 1.2 : 0)
            [chain_med, chain_med], [ymin, ymin + pdf(k_density, chain_med)]
        end
        # Quantile interval: end points and connecting line.
        @series begin
            seriestype := :scatter
            label := (show_qi ? (i == 1 ? "Q$(q[1]), Q$(q[2])" : nothing) : nothing)
            markershape --> (show_qi ? :diamond : :circle)
            markercolor --> "#000000"
            markersize --> (show_qi ? 2 : 0)
            q_int, [h]
        end
        @series begin
            seriestype := :path
            label := nothing
            linecolor := "#000000"
            linewidth --> (show_qi ? 1.2 : 0)
            [qs[1], qs[2]], [h, h]
        end
        # HPD interval. `round` keeps the label from throwing when
        # `(1 - alpha) * 100` is not exactly integral in floating point.
        @series begin
            seriestype := :path
            label := (show_hpdi ?
                (i == 1 ? "$(round(Int, (1 - hpdi[1]) * 100))% HPDI" : nothing) : nothing)
            linewidth --> (show_hpdi ? 2 : 0)
            seriesalpha --> 0.80
            linecolor --> :darkblue
            [lower_hpd[1][1], upper_hpd[1][1]], [h, h]
        end
    end
end
# Forest plot: for each parameter, horizontal HPD interval bars (one per level in
# `hpd_val`) with mean, median and optional quantile markers, stacked `spacer`
# apart starting at `_riser`.
#
# Inputs are read from `p.args`: the chain object first, the parameter names second.
# Keywords mirror the ridgeline recipe; `fill_q` and `fill_hpd` are forwarded to
# `_compute_plot_data` but the filled region is not drawn here.
@recipe function f(
    p::ForestPlot;
    hpd_val = [0.05, 0.2],
    q = [0.1, 0.9],
    spacer = 0.5,
    _riser = 0.2,
    show_mean = true,
    show_median = true,
    show_qi = false,
    show_hpdi = true,
    fill_q = true,
    fill_hpd = false,
    ordered = false
)
    chn = p.args[1]
    par_names = p.args[2]
    n_par = length(par_names)

    for i in 1:n_par
        # `_compute_plot_data` takes `barbounds` positionally between `_riser` and
        # `show_mean`. It has to be passed explicitly: leaving it out shifts every
        # following flag by one slot (e.g. `show_hpdi` would be read as `show_qi`).
        par, hpdi, lower_hpd, upper_hpd, h, qs, k_density, x_int, val, chain_med,
            chain_mean, ymin, q_int = _compute_plot_data(
                i, chn, par_names, hpd_val, q, spacer, _riser, (-Inf, Inf),
                show_mean, show_median, show_qi, show_hpdi, fill_q, fill_hpd, ordered
            )

        yticks --> (n_par > 1 ?
            (_riser .+ ((1:n_par) .- 1) .* spacer, string.(par)) : :default)
        yaxis --> (n_par > 1 ? "Parameters" : "Density")

        # One bar per HPD level; later (narrower) intervals are drawn thicker.
        for j in 1:length(hpdi)
            @series begin
                seriestype := :path
                # `round` keeps the label from throwing when `(1 - alpha) * 100`
                # is not exactly integral in floating point.
                label := (show_hpdi ?
                    (i == 1 ? "$(round(Int, (1 - hpdi[j]) * 100))% HPDI" : nothing) :
                    nothing)
                linecolor --> j
                linewidth --> (show_hpdi ? 1.5*j : 0)
                seriesalpha --> 0.80
                [lower_hpd[j][1], upper_hpd[j][1]], [h, h]
            end
        end
        # Median marker (hidden through a zero marker size).
        @series begin
            seriestype := :scatter
            label := (show_median ? (i == 1 ? "Median" : nothing) : nothing)
            markershape --> :diamond
            markercolor --> "#000000"
            markersize --> (show_median ? length(hpdi) : 0)
            [chain_med], [h]
        end
        # Mean marker.
        @series begin
            seriestype := :scatter
            label := (show_mean ? (i == 1 ? "Mean" : nothing) : nothing)
            markershape --> :circle
            markercolor --> :gray
            markersize --> (show_mean ? length(hpdi) : 0)
            [chain_mean], [h]
        end
        # Quantile interval: end points and connecting line.
        @series begin
            seriestype := :scatter
            label := (show_qi ? (i == 1 ? "Q1 = $(q[1]), Q3 = $(q[2])" : nothing) : nothing)
            markershape --> (show_qi ? :diamond : :circle)
            markercolor --> "#000000"
            markersize --> (show_qi ? 2 : 0)
            q_int, [h]
        end
        @series begin
            seriestype := :path
            label := nothing
            linecolor := "#000000"
            linewidth --> (show_qi ? 1.2 : 0.0)
            [qs[1], qs[2]], [h, h]
        end
    end
end
# Title for the density/cumulative predictive-check layouts; rejects unknown kinds.
function _ppc_title(predictive_check)
    predictive_check == :posterior && return "Posterior predictive check"
    predictive_check == :prior && return "Prior predictive check"
    throw(ArgumentError("`predictive_check` must be one of `:prior` or `:posterior`"))
end

# Flatten `values`; for `plot_type == :cumulative` wrap them in an empirical CDF.
_ppc_values(values, plot_type) =
    plot_type == :cumulative ? ecdf(vec(values)) : vec(values)

# Prior/posterior predictive check recipe (active only for `seriestype = :ppcplot`).
#
# Arguments:
#   yobs_data   observed data: a vector (one variable) or a matrix with one
#               column per variable
#   ypred_data  `Chains` object holding the predictive draws
# Keywords:
#   yvar_name         names of the dependent variables; optional for vector data,
#                     required (one per column, in column order) for matrix data
#   plot_type         `:density`, `:cumulative` or `:histogram`
#   predictive_check  `:posterior` or `:prior`; selects the plot title
#   n_samples         number of draws to plot, capped at the number of draws
#
# Keyword arguments carry no type annotations (as far as I know, recipes discard
# them with a warning); the `yvar_name` default is a typed empty vector.
#
# Throws `ArgumentError` for an unknown `plot_type` or `predictive_check`, for
# matrix data without enough `yvar_name` entries, and for zero-dimensional data.
@recipe function f(
    yobs_data,
    ypred_data::Chains;
    yvar_name = Symbol[],
    plot_type = :density,
    predictive_check = :posterior,
    n_samples = 50
)
    st = get(plotattributes, :seriestype, :traceplot)
    if st == :ppcplot
        n_draws = size(ypred_data)[1]
        N = min(n_samples, n_draws)
        index = sample(1:n_draws, N, replace = false, ordered = true)
        # Density and cumulative plots overlay all draws in one subplot per variable.
        # NOTE(review): for `:cumulative` the series carry ECDF objects but are still
        # emitted with `seriestype := :density` — confirm this renders as a CDF.
        overlay = plot_type == :density || plot_type == :cumulative

        if ndims(yobs_data) == 1
            var_label = isempty(yvar_name) ? "y" : "$(yvar_name[1])"
            y_obs = _ppc_values(yobs_data, plot_type)
            predictions = ypred_data.value.data[index, :, :]
            ymean_pred = _ppc_values(mean(ypred_data.value.data, dims = 1), plot_type)

            if overlay
                title --> _ppc_title(predictive_check)
                for i in 1:N
                    y_pred = _ppc_values(predictions[i, :, :], plot_type)
                    @series begin
                        seriestype := :density
                        seriesalpha --> 0.3
                        linecolor --> "#BBBBBB"
                        # Only the first draw gets a legend entry.
                        label --> (i == 1 ? "$var_label pred" : nothing)
                        y_pred
                    end
                end
                @series begin
                    seriestype := :density
                    label --> "$var_label obs"
                    y_obs
                end
                @series begin
                    seriestype := :density
                    label --> "$var_label mean"
                    ymean_pred
                end
            elseif plot_type == :histogram
                # Subplot 1: observed, subplot 2: mean prediction, then one per draw.
                layout --> N + 2
                @series begin
                    subplot := 1
                    seriestype := :histogram
                    label --> "$var_label obs"
                    y_obs
                end
                @series begin
                    subplot := 2
                    seriestype := :histogram
                    label --> "$var_label mean"
                    ymean_pred
                end
                for i in 1:N
                    y_pred = predictions[i, :, :]
                    @series begin
                        subplot := 2 + i
                        seriestype := :histogram
                        label --> nothing
                        y_pred
                    end
                end
            else
                throw(ArgumentError(
                    "`plot_type` must be one of `:density`, `:cumulative` or `:histogram`"))
            end
        elseif ndims(yobs_data) > 1
            n_yval = size(yobs_data)[1]
            n_yvar = size(yobs_data)[2]
            # Each column is matched to its predictions by name, so names are required.
            length(yvar_name) >= n_yvar || throw(ArgumentError(
                "`yvar_name` must provide one name per column of the observed data " *
                "($n_yvar expected, got $(length(yvar_name)))"))
            mean_arr = reshape(mean(ypred_data.value.data, dims = 1), (n_yval, n_yvar))
            per_var = N + 2  # histogram subplots used by one variable

            if overlay
                layout --> (1, n_yvar)
                title --> _ppc_title(predictive_check)
            elseif plot_type == :histogram
                # Every variable gets its own run of subplots instead of sharing
                # (and overdrawing) subplots 1..N+2.
                layout --> n_yvar * per_var
            else
                throw(ArgumentError(
                    "`plot_type` must be one of `:density`, `:cumulative` or `:histogram`"))
            end

            for j in 1:n_yvar
                sections = MCMCChains.group(ypred_data, Symbol(yvar_name[j]))
                predictions = sections.value.data[index, :, :]
                y_obs = _ppc_values(yobs_data[:, j], plot_type)
                ymean_pred = _ppc_values(mean_arr[:, j], plot_type)

                if overlay
                    for i in 1:N
                        y_pred = _ppc_values(predictions[i, :, :], plot_type)
                        @series begin
                            subplot := j
                            seriestype := :density
                            seriesalpha --> 0.3
                            linecolor --> "#BBBBBB"
                            label --> (i == 1 ? "$(yvar_name[j]) pred" : nothing)
                            y_pred
                        end
                    end
                    @series begin
                        subplot := j
                        seriestype := :density
                        label --> "$(yvar_name[j]) obs"
                        y_obs
                    end
                    @series begin
                        subplot := j
                        seriestype := :density
                        label --> "$(yvar_name[j]) mean"
                        ymean_pred
                    end
                else
                    offset = (j - 1) * per_var
                    @series begin
                        subplot := offset + 1
                        seriestype := :histogram
                        label --> "$(yvar_name[j]) obs"
                        y_obs
                    end
                    @series begin
                        subplot := offset + 2
                        seriestype := :histogram
                        label --> "$(yvar_name[j]) mean"
                        ymean_pred
                    end
                    for i in 1:N
                        y_pred = predictions[i, :, :]
                        @series begin
                            subplot := offset + 2 + i
                            seriestype := :histogram
                            label --> nothing
                            y_pred
                        end
                    end
                end
            end
        else
            throw(ArgumentError("Observed data must have at least one dimension"))
        end
    end
end
# Recipe for the `_PPCPlot` wrapper: hands the observed and predicted values over
# as the plot data; `p.ymean_pred` is not used here.
# NOTE(review): `translationdict[st](c, val)` would call `_PPCPlot` with two
# arguments while the struct has three fields — confirm this path is never taken.
@recipe function f(p::_PPCPlot)
p.y_obs, p.y_pred
end |
420f9b2
to
537514a
Compare
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #319 +/- ##
==========================================
+ Coverage 86.13% 87.57% +1.43%
==========================================
Files 20 20
Lines 1147 1360 +213
==========================================
+ Hits 988 1191 +203
- Misses 159 169 +10 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
@yebai @penelopeysm Ready for review! Please take a look at plots here: preview |
Looks very useful. Thanks @shravanngoswamii. @sunxd3, you might be interested in this, too. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds posterior/prior predictive check (PPC) plotting functionality to MCMCChains.jl. PPC plots are essential tools for Bayesian model validation that compare observed data with samples from the posterior (or prior) predictive distribution to assess model fit.
- Implements the `ppcplot` function with four plot types: density, histogram, cumulative distribution, and scatter plots
- Adds comprehensive test coverage for the new functionality
- Provides extensive documentation with examples for all plot types and customization options
Reviewed Changes
Copilot reviewed 5 out of 6 changed files in this pull request and generated 3 comments.
Show a summary per file
File | Description |
---|---|
src/plot.jl | Core implementation of ppcplot function with recipe for different plot types |
src/MCMCChains.jl | Adds import for ecdf function needed for cumulative distribution plots |
test/ppc_tests.jl | Comprehensive test suite covering plot creation, validation, options, and edge cases |
test/runtests.jl | Integrates PPC tests into the main test suite |
docs/src/statsplots.md | Detailed documentation with examples for all PPC plot types and options |
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Happy with this. Realistically I don't think anybody is going to be constructing Chains manually and plotting them so it would probably be a good thing to have some docs focusing on how to generate the data with Turing, but that's not a part of this PR
@sunxd3 Any comments before we merge? |
The `ppcplot` function was added for plotting prior/posterior predictive checks for one or more dependent variables. As args, this function receives `yobs_data`, the observed data for the dependent variables (a vector or matrix), and `ypred_data`, the posterior/prior predictive results (a `Chains` object). It plots the observed data, a sample of predictions, and the mean of the predictions.

As kwargs, this function receives:
- `yvar_name` (vector of `Symbol`), which contains the names of the dependent variables to be plotted,
- `plot_type`, which can take `:density`, `:cumulative`, and `:histogram` as values,
- `predictive_check`, which is used for plot titles and can be `:prior` or `:posterior` (default value is `:posterior`),
- `n_samples`, which establishes the number of samples to be plotted (default value is 50, but when plotting it is redefined as the minimum between 50 and the sample size in `ypred_data`).

For more than one dependent variable in a single model, `yvar_name` must be provided, and the order in which the variable names appear must be the same as in the observed data matrix. This was done in order to separate predictions for every dependent variable, because `predict` does not return predictions ordered by variable.

The following is a working example for a model with one dependent variable
And for posterior predictive check
Plot_type = :density

Plot_type = :cumulative
Plot_type = :histogram

Additionally, this is a working example for a model with two dependent variables