Skip to content

Commit d1e6389

Browse files
authored
Support CategoricalArrays 1 (#324)
Since `levels(::CategoricalArray)` now returns a `CategoricalArray`, we need to unwrap the result before storing it as an `Array` field. This also works on CategoricalArrays 0.10.
1 parent 5891331 commit d1e6389

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

src/contrasts.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,9 @@ ContrastsMatrix(contrasts_matrix::ContrastsMatrix, levels::AbstractVector)
157157
constructing a model matrix from a `ModelFrame` using different data.
158158
159159
"""
160-
function ContrastsMatrix(contrasts::C, levels::AbstractVector{T}) where {C<:AbstractContrasts, T}
160+
function ContrastsMatrix(contrasts::C, levels::AbstractVector) where {C<:AbstractContrasts}
161+
162+
u_levels = DataAPI.unwrap.(levels)
161163

162164
# if levels are defined on contrasts, use those, validating that they line up.
163165
# what does that mean? either:
@@ -167,9 +169,9 @@ function ContrastsMatrix(contrasts::C, levels::AbstractVector{T}) where {C<:Abst
167169
# better to filter data frame first
168170
# 3. contrast levels missing from data: would have empty columns, generate a
169171
# rank-deficient model matrix.
170-
c_levels = something(DataAPI.levels(contrasts), levels)
172+
c_levels = something(DataAPI.levels(contrasts), u_levels)
171173

172-
mismatched_levels = symdiff(c_levels, levels)
174+
mismatched_levels = symdiff(c_levels, u_levels)
173175
if !isempty(mismatched_levels)
174176
throw(ArgumentError("contrasts levels not found in data or vice-versa: " *
175177
"$mismatched_levels." *
@@ -179,7 +181,7 @@ function ContrastsMatrix(contrasts::C, levels::AbstractVector{T}) where {C<:Abst
179181

180182
# do conversion AFTER checking for levels so users get a nice error message
181183
# when they've made a mistake with the level types
182-
c_levels = convert(Vector{T}, c_levels)
184+
c_levels = convert(Vector{eltype(u_levels)}, c_levels)
183185

184186
n = length(c_levels)
185187
if n == 0
@@ -228,7 +230,7 @@ end
228230

229231
function StatsAPI.coefnames(C::AbstractContrasts, levels::AbstractVector, baseind::Integer)
230232
not_base = [1:(baseind-1); (baseind+1):length(levels)]
231-
levels[not_base]
233+
DataAPI.unwrap.(levels[not_base])
232234
end
233235

234236
Base.getindex(contrasts::ContrastsMatrix, rowinds, colinds) =
@@ -594,7 +596,7 @@ function contrasts_matrix(C::HypothesisCoding, baseind, n)
594596
end
595597

596598
StatsAPI.coefnames(C::HypothesisCoding, levels::AbstractVector, baseind::Int) =
597-
something(C.labels, levels[1:length(levels) .!= baseind])
599+
something(C.labels, DataAPI.unwrap.(levels[1:length(levels) .!= baseind]))
598600

599601
DataAPI.levels(c::HypothesisCoding) = c.levels
600602

0 commit comments

Comments
 (0)