Skip to content

Commit 14f3b54

Browse files
authored
Handle MatrixTerm when calculating column-to-term indexes in ModelMatrix (#134)
* handle MatrixTerm in asgn model matrix calculation * add test for #133 * bump version (bugfix)
1 parent 3891ba6 commit 14f3b54

File tree

4 files changed

+23
-4
lines changed

4 files changed

+23
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "StatsModels"
22
uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
3-
version = "0.6.1"
3+
version = "0.6.2"
44

55
[deps]
66
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/modelframe.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,16 @@ end
172172

173173
Base.size(mm::ModelMatrix, dim...) = size(mm.m, dim...)
174174

175+
asgn(f::FormulaTerm) = asgn(f.rhs)
176+
asgn(mt::MatrixTerm) = asgn(mt.terms)
177+
asgn(t) = mapreduce(((i,t), ) -> i*ones(width(t)),
178+
append!,
179+
enumerate(vectorize(t)),
180+
init=Int[])
181+
175182
function ModelMatrix{T}(mf::ModelFrame) where T<:AbstractMatrix{<:AbstractFloat}
176183
mat = modelmatrix(mf)
177-
asgn = mapreduce((it)->first(it)*ones(width(last(it))), append!,
178-
enumerate(vectorize(mf.f.rhs)), init=Int[])
179-
ModelMatrix(convert(T, mat), asgn)
184+
ModelMatrix(convert(T, mat), asgn(mf.f))
180185
end
181186

182187
ModelMatrix(mf::ModelFrame) = ModelMatrix{Matrix{Float64}}(mf)

test/modelframe.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
@testset "ModelFrame (legacy API)" begin
2+
3+
@testset "#133" begin
4+
df = (x = rand(12),
5+
y = categorical(repeat(1:3, inner=4)),
6+
z = categorical(repeat(1:2, outer=6)));
7+
f = @formula(x ~ y * z);
8+
mf = ModelFrame(f, df)
9+
mm = ModelMatrix(mf)
10+
@test mm.assign == [1, 2, 2, 3, 4, 4]
11+
end
12+
13+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ my_tests = ["formula.jl",
1313
"temporal_terms.jl",
1414
"schema.jl",
1515
"modelmatrix.jl",
16+
"modelframe.jl",
1617
"statsmodel.jl",
1718
"contrasts.jl",
1819
"extension.jl"]

0 commit comments

Comments
 (0)