Skip to content

Commit 6d581b1

Browse files
nalimilan authored and kleinschmidt committed
Fix predict with confidence interval (#160)
* Fix predict with confidence interval

  DataFrames aren't a dependency anymore.
* Add tests
* Bump version to 0.6.6
1 parent 04a0ccf commit 6d581b1

File tree

3 files changed

+66
-13
lines changed

3 files changed

+66
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
name = "StatsModels"
22
uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d"
3-
version = "0.6.5"
3+
version = "0.6.6"
44

55
[deps]
66
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"

src/statsmodel.jl

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -131,26 +131,26 @@ StatsBase.adjr2(mm::TableRegressionModel) = adjr2(mm.model)
131131
StatsBase.r2(mm::TableRegressionModel, variant::Symbol) = r2(mm.model, variant)
132132
StatsBase.adjr2(mm::TableRegressionModel, variant::Symbol) = adjr2(mm.model, variant)
133133

134-
function _return_predictions(yp::AbstractVector, nonmissings, len)
134+
function _return_predictions(T, yp::AbstractVector, nonmissings, len)
135135
out = missings(eltype(yp), len)
136136
out[nonmissings] = yp
137137
out
138138
end
139139

140-
function _return_predictions(yp::AbstractMatrix, nonmissings, len)
140+
function _return_predictions(T, yp::AbstractMatrix, nonmissings, len)
141141
out = missings(eltype(yp), (len, 3))
142142
out[nonmissings, :] = yp
143-
DataFrame(prediction = out[:,1], lower = out[:,2], upper = out[:,3])
143+
T((prediction = out[:,1], lower = out[:,2], upper = out[:,3]))
144144
end
145145

146-
function _return_predictions(yp::NamedTuple, nonmissings, len)
146+
function _return_predictions(T, yp::NamedTuple, nonmissings, len)
147147
y = missings(eltype(yp[:prediction]), len)
148148
l, h = similar(y), similar(y)
149149
out = (prediction = y, lower = l, upper = h)
150150
for key in (:prediction, :lower, :upper)
151151
out[key][nonmissings] = yp[key]
152152
end
153-
DataFrame(out)
153+
T(out)
154154
end
155155

156156
# Predict function that takes data table as predictor instead of matrix
@@ -163,7 +163,7 @@ function StatsBase.predict(mm::TableRegressionModel, data; kwargs...)
163163
new_x = modelcols(f.rhs, cols)
164164
y_pred = predict(mm.model, reshape(new_x, size(new_x, 1), :);
165165
kwargs...)
166-
_return_predictions(y_pred, nonmissings, length(nonmissings))
166+
_return_predictions(Tables.materializer(data), y_pred, nonmissings, length(nonmissings))
167167
end
168168

169169
StatsBase.coefnames(model::TableModels) = coefnames(model.mf)

test/statsmodel.jl

Lines changed: 59 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,6 @@ struct DummyMod <: RegressionModel
77
y::Vector
88
end
99

10-
StatsBase.predict(mod::DummyMod) = mod.x * mod.beta
11-
StatsBase.predict(mod::DummyMod, newX::Matrix) = newX * mod.beta
1210
## dumb fit method: just copy the x and y input over
1311
StatsBase.fit(::Type{DummyMod}, x::Matrix, y::Vector) =
1412
DummyMod(collect(1:size(x, 2)), x, y)
@@ -19,6 +17,29 @@ StatsBase.coeftable(mod::DummyMod) =
1917
["'beta' value"],
2018
["" for n in 1:size(mod.x,2)],
2119
0)
20+
# dumb predict: return values predicted by "beta" and dummy confidence bounds
21+
function StatsBase.predict(mod::DummyMod;
22+
interval::Union{Nothing,Symbol}=nothing)
23+
pred = mod.x * mod.beta
24+
if interval === nothing
25+
return pred
26+
elseif interval === :prediction
27+
return (prediction=pred, lower=pred .- 1, upper=pred .+ 1)
28+
else
29+
throw(ArgumentError("value not allowed for interval"))
30+
end
31+
end
32+
function StatsBase.predict(mod::DummyMod, newX::Matrix;
33+
interval::Union{Nothing,Symbol}=nothing)
34+
pred = newX * mod.beta
35+
if interval === nothing
36+
return pred
37+
elseif interval === :prediction
38+
return (prediction=pred, lower=pred .- 1, upper=pred .+ 1)
39+
else
40+
throw(ArgumentError("value not allowed for interval"))
41+
end
42+
end
2243

2344
# A dummy RegressionModel type that does not support intercept
2445
struct DummyModNoIntercept <: RegressionModel
@@ -39,8 +60,29 @@ StatsBase.coeftable(mod::DummyModNoIntercept) =
3960
["'beta' value"],
4061
["" for n in 1:size(mod.x,2)],
4162
0)
42-
StatsBase.predict(mod::DummyModNoIntercept) = mod.x * mod.beta
43-
StatsBase.predict(mod::DummyModNoIntercept, newX::Matrix) = newX * mod.beta
63+
# dumb predict: return values predicted by "beta" and dummy confidence bounds
64+
function StatsBase.predict(mod::DummyModNoIntercept;
65+
interval::Union{Nothing,Symbol}=nothing)
66+
pred = mod.x * mod.beta
67+
if interval === nothing
68+
return pred
69+
elseif interval === :prediction
70+
return (prediction=pred, lower=pred .- 1, upper=pred .+ 1)
71+
else
72+
throw(ArgumentError("value not allowed for interval"))
73+
end
74+
end
75+
function StatsBase.predict(mod::DummyModNoIntercept, newX::Matrix;
76+
interval::Union{Nothing,Symbol}=nothing)
77+
pred = newX * mod.beta
78+
if interval === nothing
79+
return pred
80+
elseif interval === :prediction
81+
return (prediction=pred, lower=pred .- 1, upper=pred .+ 1)
82+
else
83+
throw(ArgumentError("value not allowed for interval"))
84+
end
85+
end
4486

4587
## Another dummy model type to test fall-through show method
4688
struct DummyModTwo <: RegressionModel
@@ -74,10 +116,21 @@ Base.show(io::IO, m::DummyModTwo) = println(io, m.msg)
74116

75117
## new data from matrix
76118
mm = ModelMatrix(ModelFrame(f, d))
77-
@test predict(m, mm.m) == mm.m * collect(1:4)
119+
p = predict(m, mm.m)
120+
@test p == mm.m * collect(1:4)
121+
p2 = predict(m, mm.m, interval=:prediction)
122+
@test p2 isa NamedTuple
123+
@test p2.prediction == p
124+
@test p2.lower == p .- 1
125+
@test p2.upper == p .+ 1
78126

79127
## new data from DataFrame (via ModelMatrix)
80-
@test predict(m, d) == predict(m, mm.m)
128+
@test predict(m, d) == p
129+
p3 = predict(m, d, interval=:prediction)
130+
@test p3 isa DataFrame
131+
@test p3.prediction == p
132+
@test p3.lower == p .- 1
133+
@test p3.upper == p .+ 1
81134

82135
d2 = deepcopy(d)
83136
d2[3, :x1] = missing

0 commit comments

Comments
 (0)