@@ -7,6 +7,10 @@ representation.
7
7
8
8
include (" types.jl" )
9
9
10
+ function _is_multiclass (y:: Vector{Int64} ):: Bool
11
+ return length (unique (y)) > 2
12
+ end
13
+
10
14
"""
11
15
from_vector(X::Matrix{Int}, y::Vector{Int}, names::Union{Vector{String}, Nothing} = nothing)
12
16
@@ -26,11 +30,22 @@ function from_vector(X::Matrix{Int64}, y::Vector{Int64}, names::Union{Vector{Str
26
30
27
31
pos, neg, facts = String[], String[], String[]
28
32
29
- for (i, row) in enumerate (y)
30
- if Bool (row)
31
- push! (pos, " $(last (names)) (id$(i) )." )
32
- else
33
- push! (neg, " $(last (names)) (id$(i) )." )
33
+ is_multiclass = _is_multiclass (y)
34
+
35
+ if is_multiclass
36
+
37
+ for (i, row) in enumerate (y)
38
+ push! (pos, " $(last (names)) (id$(i) ,$(row) )." )
39
+ end
40
+
41
+ else
42
+
43
+ for (i, row) in enumerate (y)
44
+ if Bool (row)
45
+ push! (pos, " $(last (names)) (id$(i) )." )
46
+ else
47
+ push! (neg, " $(last (names)) (id$(i) )." )
48
+ end
34
49
end
35
50
end
36
51
@@ -39,8 +54,14 @@ function from_vector(X::Matrix{Int64}, y::Vector{Int64}, names::Union{Vector{Str
39
54
facts = vcat (facts, [" $(var) (id$(j) ,$(row) )." for (j, row) in enumerate (col)])
40
55
end
41
56
57
+
42
58
modes = [" $(name) (+id,#var$(name) )." for name in names[1 : end - 1 ]]
43
- push! (modes, " $(last (names)) (+id)." )
59
+
60
+ if is_multiclass
61
+ push! (modes, " $(last (names)) (+id,#classlabel)." )
62
+ else
63
+ push! (modes, " $(last (names)) (+id)." )
64
+ end
44
65
45
66
return RelationalDataset ((pos, neg, facts)), modes
46
67
end
0 commit comments