Skip to content

Commit 143d576

Browse files
authored
Merge pull request #14 from srlearn/multiclass
✨ Add multiclass from_vector
2 parents 759c733 + 07012c3 commit 143d576

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

src/convert.jl

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ representation.
77

88
include("types.jl")
99

10+
function _is_multiclass(y::Vector{Int64})::Bool
11+
return length(unique(y)) > 2
12+
end
13+
1014
"""
1115
from_vector(X::Matrix{Int}, y::Vector{Int}, names::Union{Vector{String}, Nothing} = nothing)
1216
@@ -26,11 +30,22 @@ function from_vector(X::Matrix{Int64}, y::Vector{Int64}, names::Union{Vector{Str
2630

2731
pos, neg, facts = String[], String[], String[]
2832

29-
for (i, row) in enumerate(y)
30-
if Bool(row)
31-
push!(pos, "$(last(names))(id$(i)).")
32-
else
33-
push!(neg, "$(last(names))(id$(i)).")
33+
is_multiclass = _is_multiclass(y)
34+
35+
if is_multiclass
36+
37+
for (i, row) in enumerate(y)
38+
push!(pos, "$(last(names))(id$(i),$(row)).")
39+
end
40+
41+
else
42+
43+
for (i, row) in enumerate(y)
44+
if Bool(row)
45+
push!(pos, "$(last(names))(id$(i)).")
46+
else
47+
push!(neg, "$(last(names))(id$(i)).")
48+
end
3449
end
3550
end
3651

@@ -39,8 +54,14 @@ function from_vector(X::Matrix{Int64}, y::Vector{Int64}, names::Union{Vector{Str
3954
facts = vcat(facts, ["$(var)(id$(j),$(row))." for (j, row) in enumerate(col)])
4055
end
4156

57+
4258
modes = ["$(name)(+id,#var$(name))." for name in names[1:end-1]]
43-
push!(modes, "$(last(names))(+id).")
59+
60+
if is_multiclass
61+
push!(modes, "$(last(names))(+id,#classlabel).")
62+
else
63+
push!(modes, "$(last(names))(+id).")
64+
end
4465

4566
return RelationalDataset((pos, neg, facts)), modes
4667
end

test/test_convert.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,32 @@ using RelationalDatasets
101101
"c(+id).",
102102
]
103103
)
104+
105+
@test RelationalDatasets.from_vector(
106+
[0 1 1; 1 0 2; 2 2 0],
107+
[0, 1, 2],
108+
) == (RelationalDatasets.RelationalDataset((
109+
pos=["v4(id1,0).", "v4(id2,1).", "v4(id3,2)."],
110+
neg=[],
111+
facts=[
112+
"v1(id1,0).",
113+
"v1(id2,1).",
114+
"v1(id3,2).",
115+
"v2(id1,1).",
116+
"v2(id2,0).",
117+
"v2(id3,2).",
118+
"v3(id1,1).",
119+
"v3(id2,2).",
120+
"v3(id3,0).",
121+
])),
122+
[
123+
"v1(+id,#varv1).",
124+
"v2(+id,#varv2).",
125+
"v3(+id,#varv3).",
126+
"v4(+id,#classlabel).",
127+
]
128+
)
129+
104130
end
105131

106132
@testset "Misaligned Classification and Regression" begin

0 commit comments

Comments
 (0)