Skip to content

Commit 3abecb2

Browse files
github-actions[bot]CompatHelper Juliasethaxen
authored
CompatHelper: bump compat for DimensionalData to 0.29, (keep existing compat) (#87)
* CompatHelper: bump compat for DimensionalData to 0.29, (keep existing compat) * Use DimensionalData.maplayers if available * Update test/mcmcdiagnostictools.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Increment patch number --------- Co-authored-by: CompatHelper Julia <[email protected]> Co-authored-by: Seth Axen <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent dd5f879 commit 3abecb2

File tree

7 files changed

+19
-9
lines changed

7 files changed

+19
-9
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "InferenceObjects"
22
uuid = "b5cf5a8d-e756-4ee3-b014-01d49d192c00"
33
authors = ["Seth Axen <[email protected]> and contributors"]
4-
version = "0.4.6"
4+
version = "0.4.7"
55

66
[deps]
77
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
@@ -23,7 +23,7 @@ InferenceObjectsPosteriorStatsExt = ["PosteriorStats", "StatsBase"]
2323
[compat]
2424
ArviZExampleData = "0.1.10"
2525
Dates = "1.9"
26-
DimensionalData = "0.27, 0.28"
26+
DimensionalData = "0.27, 0.28, 0.29"
2727
EvoTrees = "0.16"
2828
MCMCDiagnosticTools = "0.3.4"
2929
MLJBase = "1"

ext/InferenceObjectsMCMCDiagnosticToolsExt/InferenceObjectsMCMCDiagnosticToolsExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ using DimensionalData: DimensionalData, Dimensions, LookupArrays
55
using InferenceObjects: InferenceObjects, Random
66
using MCMCDiagnosticTools: MCMCDiagnosticTools
77

8+
maplayers = isdefined(DimensionalData, :maplayers) ? DimensionalData.maplayers : map
9+
810
include("utils.jl")
911
include("bfmi.jl")
1012
include("ess_rhat.jl")

ext/InferenceObjectsMCMCDiagnosticToolsExt/ess_rhat.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ end
2121
for f in (:ess, :rhat)
2222
@eval begin
2323
function MCMCDiagnosticTools.$f(data::InferenceObjects.Dataset; kwargs...)
24-
ds = map(data) do var
24+
ds = maplayers(data) do var
2525
return _as_dimarray(MCMCDiagnosticTools.$f(_params_array(var); kwargs...), var)
2626
end
2727
return DimensionalData.rebuild(ds; metadata=DimensionalData.NoMetadata())

ext/InferenceObjectsMCMCDiagnosticToolsExt/mcse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ function MCMCDiagnosticTools.mcse(data::InferenceObjects.InferenceData; kwargs..
88
return MCMCDiagnosticTools.mcse(data.posterior; kwargs...)
99
end
1010
function MCMCDiagnosticTools.mcse(data::InferenceObjects.Dataset; kwargs...)
11-
ds = map(data) do var
11+
ds = maplayers(data) do var
1212
return _as_dimarray(MCMCDiagnosticTools.mcse(_params_array(var); kwargs...), var)
1313
end
1414
return DimensionalData.rebuild(ds; metadata=DimensionalData.NoMetadata())

ext/InferenceObjectsMCMCDiagnosticToolsExt/rstar.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ end
1616
function MCMCDiagnosticTools.rstar(
1717
rng::Random.AbstractRNG, clf, data::InferenceObjects.Dataset; kwargs...
1818
)
19-
data_array = cat(map(_as_3d_array _params_array, data)...; dims=3)
19+
data_array = cat(maplayers(_as_3d_array _params_array, data)...; dims=3)
2020
return MCMCDiagnosticTools.rstar(rng, clf, data_array; kwargs...)
2121
end
2222
function MCMCDiagnosticTools.rstar(

src/dataset.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ for f in [:data, :dims, :refdims, :metadata, :layerdims, :layermetadata]
126126
end
127127
end
128128

129+
DimensionalData.modify(f, s::Dataset) = Dataset(DimensionalData.modify(f, parent(s)))
130+
129131
# Warning: this is not an API function and probably should be implemented abstractly upstream
130132
DimensionalData.show_after(io, mime, ::Dataset) = nothing
131133

test/mcmcdiagnostictools.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ using Random
77
using Statistics
88
using Test
99

10+
if !isdefined(DimensionalData, :maplayers)
11+
maplayers = map
12+
end
13+
1014
@testset "MCMCDiagnosticTools integration" begin
1115
nchains, ndraws = 4, 10
1216
sizes = (x=(), y=(2,), z=(3, 5))
@@ -16,12 +20,12 @@ using Test
1620
dict1 = Dict(Symbol(k) => randn(ndraws, nchains, sz...) for (k, sz) in pairs(sizes))
1721
idata1 = from_dict(dict1; dims, coords, sample_stats=Dict(:energy => energy))
1822
# permute dimensions to test that diagnostics are invariant to dimension order
19-
post2 = map(idata1.posterior) do var
23+
post2 = maplayers(idata1.posterior) do var
2024
n = ndims(var)
2125
permdims = ((3:n)..., 2, 1)
2226
return permutedims(var, permdims)
2327
end
24-
sample_stats2 = map(permutedims, idata1.sample_stats)
28+
sample_stats2 = maplayers(permutedims, idata1.sample_stats)
2529
idata2 = InferenceData(; posterior=post2, sample_stats=sample_stats2)
2630

2731
@testset for f in (ess, rhat, ess_rhat, mcse)
@@ -35,7 +39,7 @@ using Test
3539
@test issetequal(keys(metric), keys(idata1.posterior))
3640
@test metric == f(idata1.posterior; kind)
3741
@test metric2 == f(idata2.posterior; kind)
38-
@test all(map(, metric2, metric))
42+
@test all(maplayers(, metric2, metric))
3943
for k in keys(sizes)
4044
@test all(
4145
hasdim(
@@ -81,7 +85,9 @@ using Test
8185
r4 = rstar(rng, classifier(rng), idata2.posterior; subset)
8286
rng = Random.seed!(123)
8387
post_mat = cat(
84-
map(var -> reshape(parent(var), ndraws, nchains, :), idata1.posterior)...;
88+
maplayers(
89+
var -> reshape(parent(var), ndraws, nchains, :), idata1.posterior
90+
)...;
8591
dims=3,
8692
)
8793
r5 = rstar(rng, classifier(rng), post_mat; subset)

0 commit comments

Comments
 (0)