Skip to content

Commit ad89c67

Browse files
paldayararslan
andauthored
fix boolean-as-categorical support (#327)
* fix boolean-as-categorical support * ignore versioned manifests * fix method ambiguity --------- Co-authored-by: Alex Arslan <[email protected]>
1 parent 8bf821e commit ad89c67

File tree

4 files changed

+38
-11
lines changed

4 files changed

+38
-11
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
*.jl.mem
44
docs/build
55
Manifest.toml
6+
Manifest-v*.toml

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StatsModels"
22
uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
3-
version = "0.7.5"
3+
version = "0.7.6"
44

55
[deps]
66
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"

src/contrasts.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ function StatsAPI.coefnames(C::AbstractContrasts, levels::AbstractVector, basein
233233
DataAPI.unwrap.(levels[not_base])
234234
end
235235

236+
function StatsAPI.coefnames(C::AbstractContrasts, levels::AbstractVector{Bool}, baseind::Integer)
237+
not_base = [firstindex(levels):(baseind - 1); (baseind + 1):lastindex(levels)]
238+
# broadcasted DataAPI.unwrap converts Vector{Bool} to BitVector
239+
convert(Vector{Bool}, DataAPI.unwrap.(levels[not_base]))
240+
end
241+
236242
Base.getindex(contrasts::ContrastsMatrix, rowinds, colinds) =
237243
getindex(contrasts.matrix, getindex.(Ref(contrasts.invindex), rowinds), colinds)
238244

@@ -598,6 +604,10 @@ end
598604
StatsAPI.coefnames(C::HypothesisCoding, levels::AbstractVector, baseind::Int) =
599605
something(C.labels, DataAPI.unwrap.(levels[1:length(levels) .!= baseind]))
600606

607+
# We need an explicit method for `AbstractVector{Bool}` to avoid an ambiguity
608+
StatsAPI.coefnames(C::HypothesisCoding, levels::AbstractVector{Bool}, baseind::Int) =
609+
something(C.labels, DataAPI.unwrap.(levels[1:length(levels) .!= baseind]))
610+
601611
DataAPI.levels(c::HypothesisCoding) = c.levels
602612

603613
"""

test/contrasts.jl

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
@testset "contrasts" begin
22

33
cm = StatsModels.ContrastsMatrix(DummyCoding(), ["a", "b"])
4-
@test_logs((:warn,
4+
@test_logs((:warn,
55
"The `termnames` field of `ConstrastsMatrix` is deprecated; use `coefnames(cm)` instead."),
66
cm.termnames)
77
@test cm.termnames == cm.coefnames
@@ -87,9 +87,9 @@
8787
1 -1 -1
8888
1 0 1]
8989
@test coefnames(mf) == ["(Intercept)"; "x: c"; "x: b"]
90-
90+
9191
# respect order of levels
92-
92+
9393
data = DataFrame(x = levels!(categorical(['A', 'B', 'C', 'C', 'D']), ['C', 'B', 'A', 'D']))
9494
f = apply_schema(@formula(x ~ 1), schema(data))
9595
@test modelcols(f.lhs, data) == [0 1 0; 1 0 0; 0 0 0; 0 0 0; 0 0 1]
@@ -239,7 +239,7 @@
239239
f_effects = apply_schema(f, schema(d2, Dict(:x => effects_hyp)))
240240

241241
y_means = combine(groupby(d2, :x), :y => mean).y_mean
242-
242+
243243
y, X_sdiff = modelcols(f_sdiff, d2)
244244
@test X_sdiff \ y [mean(y_means); diff(y_means)]
245245

@@ -271,18 +271,18 @@
271271
[0 1 0 0
272272
0 0 1 0
273273
0 0 0 1]
274-
274+
275275

276276
cmat2 = contrasts_matrix(HelmertCoding(), 1, 4)
277277
@test needs_intercept(cmat2) == false
278-
hmat2 = hypothesis_matrix(cmat2)
278+
hmat2 = hypothesis_matrix(cmat2)
279279
@test hmat2
280280
[-1/2 1/2 0 0
281281
-1/6 -1/6 1/3 0
282282
-1/12 -1/12 -1/12 1/4]
283283

284284
@test eltype(hmat2) <: Rational
285-
285+
286286
@test hypothesis_matrix(cmat2, intercept=true)
287287
vcat([1/4 1/4 1/4 1/4], hmat2)
288288

@@ -350,7 +350,7 @@
350350
@test levels(c) == levs
351351
# no notion of base level for ContrastsCoding
352352
@test_throws MethodError ContrastsCoding(rand(4,3), base=base)
353-
353+
354354
end
355355

356356
@testset "Non-unique levels" begin
@@ -396,10 +396,26 @@
396396

397397
mm = modelcols(term, (; x=repeat('a':'d'; inner=2)))
398398
smm = modelcols(spterm, (; x=repeat('a':'d'; inner=2)))
399-
399+
400400
@test mm isa Matrix
401401
@test smm isa SparseMatrixCSC
402402
@test mm == smm
403403
end
404-
404+
405+
@testset "booleans as categorical" begin
406+
cm = ContrastsMatrix(EffectsCoding(), [true, false])
407+
@test cm.coefnames == [false]
408+
@test issetequal(cm.levels, [true, false])
409+
410+
hypothesis = [0 1]'
411+
cm = ContrastsMatrix(EffectsCoding(), [true, false])
412+
@test cm.coefnames == [false]
413+
@test issetequal(cm.levels, [true, false])
414+
415+
hc = HypothesisCoding([1 0]; levels=[true, false], labels=["yes", "no"])
416+
cm = ContrastsMatrix(hc, [true, false])
417+
@test issetequal(cm.coefnames, ["yes", "no"])
418+
@test issetequal(cm.levels, [true, false])
419+
end
420+
405421
end

0 commit comments

Comments
 (0)