@@ -7,8 +7,6 @@ struct DummyMod <: RegressionModel
77 y:: Vector
88end
99
10- StatsBase. predict (mod:: DummyMod ) = mod. x * mod. beta
11- StatsBase. predict (mod:: DummyMod , newX:: Matrix ) = newX * mod. beta
1210# # dumb fit method: just copy the x and y input over
1311StatsBase. fit (:: Type{DummyMod} , x:: Matrix , y:: Vector ) =
1412 DummyMod (collect (1 : size (x, 2 )), x, y)
@@ -19,6 +17,29 @@ StatsBase.coeftable(mod::DummyMod) =
1917 [" 'beta' value" ],
2018 [" " for n in 1 : size (mod. x,2 )],
2119 0 )
20+ # dumb predict: return values predicted by "beta" and dummy confidence bounds
21+ function StatsBase. predict (mod:: DummyMod ;
22+ interval:: Union{Nothing,Symbol} = nothing )
23+ pred = mod. x * mod. beta
24+ if interval === nothing
25+ return pred
26+ elseif interval === :prediction
27+ return (prediction= pred, lower= pred .- 1 , upper= pred .+ 1 )
28+ else
29+ throw (ArgumentError (" value not allowed for interval" ))
30+ end
31+ end
32+ function StatsBase. predict (mod:: DummyMod , newX:: Matrix ;
33+ interval:: Union{Nothing,Symbol} = nothing )
34+ pred = newX * mod. beta
35+ if interval === nothing
36+ return pred
37+ elseif interval === :prediction
38+ return (prediction= pred, lower= pred .- 1 , upper= pred .+ 1 )
39+ else
40+ throw (ArgumentError (" value not allowed for interval" ))
41+ end
42+ end
2243
2344# A dummy RegressionModel type that does not support intercept
2445struct DummyModNoIntercept <: RegressionModel
@@ -39,8 +60,29 @@ StatsBase.coeftable(mod::DummyModNoIntercept) =
3960 [" 'beta' value" ],
4061 [" " for n in 1 : size (mod. x,2 )],
4162 0 )
42- StatsBase. predict (mod:: DummyModNoIntercept ) = mod. x * mod. beta
43- StatsBase. predict (mod:: DummyModNoIntercept , newX:: Matrix ) = newX * mod. beta
63+ # dumb predict: return values predicted by "beta" and dummy confidence bounds
64+ function StatsBase. predict (mod:: DummyModNoIntercept ;
65+ interval:: Union{Nothing,Symbol} = nothing )
66+ pred = mod. x * mod. beta
67+ if interval === nothing
68+ return pred
69+ elseif interval === :prediction
70+ return (prediction= pred, lower= pred .- 1 , upper= pred .+ 1 )
71+ else
72+ throw (ArgumentError (" value not allowed for interval" ))
73+ end
74+ end
75+ function StatsBase. predict (mod:: DummyModNoIntercept , newX:: Matrix ;
76+ interval:: Union{Nothing,Symbol} = nothing )
77+ pred = newX * mod. beta
78+ if interval === nothing
79+ return pred
80+ elseif interval === :prediction
81+ return (prediction= pred, lower= pred .- 1 , upper= pred .+ 1 )
82+ else
83+ throw (ArgumentError (" value not allowed for interval" ))
84+ end
85+ end
4486
4587# # Another dummy model type to test fall-through show method
4688struct DummyModTwo <: RegressionModel
@@ -74,10 +116,21 @@ Base.show(io::IO, m::DummyModTwo) = println(io, m.msg)
74116
75117 # # new data from matrix
76118 mm = ModelMatrix (ModelFrame (f, d))
77- @test predict (m, mm. m) == mm. m * collect (1 : 4 )
119+ p = predict (m, mm. m)
120+ @test p == mm. m * collect (1 : 4 )
121+ p2 = predict (m, mm. m, interval= :prediction )
122+ @test p2 isa NamedTuple
123+ @test p2. prediction == p
124+ @test p2. lower == p .- 1
125+ @test p2. upper == p .+ 1
78126
79127 # # new data from DataFrame (via ModelMatrix)
80- @test predict (m, d) == predict (m, mm. m)
128+ @test predict (m, d) == p
129+ p3 = predict (m, d, interval= :prediction )
130+ @test p3 isa DataFrame
131+ @test p3. prediction == p
132+ @test p3. lower == p .- 1
133+ @test p3. upper == p .+ 1
81134
82135 d2 = deepcopy (d)
83136 d2[3 , :x1 ] = missing
0 commit comments