@@ -4,6 +4,9 @@ using DataFrames
4
4
using Random
5
5
using Test
6
6
7
+ using GLM: ProbitLink
8
+ using Distributions: Binomial
9
+
7
10
import MLJ, MLJDecisionTreeInterface
8
11
9
12
const SLM = StatsLearnModels
@@ -17,7 +20,7 @@ const SLM = StatsLearnModels
17
20
@testset " interface" begin
18
21
@testset " MLJ" begin
19
22
Random. seed! (123 )
20
- Tree = MLJ. @load (DecisionTreeClassifier, pkg= DecisionTree, verbosity= 0 )
23
+ Tree = MLJ. @load (DecisionTreeClassifier, pkg = DecisionTree, verbosity = 0 )
21
24
fmodel = SLM. fit (Tree (), input[train, :], output[train, :])
22
25
pred = SLM. predict (fmodel, input[test, :])
23
26
accuracy = count (pred. target .== output. target[test]) / length (test)
@@ -32,6 +35,25 @@ const SLM = StatsLearnModels
32
35
accuracy = count (pred. target .== output. target[test]) / length (test)
33
36
@test accuracy > 0.9
34
37
end
38
+
39
+ @testset " GLM" begin
40
+ x = [1 , 2 , 3 ]
41
+ y = [2 , 4 , 7 ]
42
+ input = DataFrame (; ones= ones (length (x)), x)
43
+ output = DataFrame (; y)
44
+ model = LinearRegressor ()
45
+ fmodel = SLM. fit (model, input, output)
46
+ pred = SLM. predict (fmodel, input)
47
+ @test all (isapprox .(pred. y, output. y, atol= 0.5 ))
48
+ x = [1 , 2 , 2 ]
49
+ y = [1 , 0 , 1 ]
50
+ input = DataFrame (; ones= ones (length (x)), x)
51
+ output = DataFrame (; y)
52
+ model = GeneralizedLinearRegressor (Binomial (), ProbitLink ())
53
+ fmodel = SLM. fit (model, input, output)
54
+ pred = SLM. predict (fmodel, input)
55
+ @test all (isapprox .(pred. y, output. y, atol= 0.5 ))
56
+ end
35
57
end
36
58
37
59
@testset " Learn" begin
0 commit comments