diff --git a/doc/source/validate.md b/doc/source/validate.md
index abd19881..be57af62 100644
--- a/doc/source/validate.md
+++ b/doc/source/validate.md
@@ -99,3 +99,13 @@ the similarity of two different clusterings of a dataset.
 ```@docs
 mutualinfo
 ```
+
+## Confusion matrix
+
+The pair [confusion matrix](https://en.wikipedia.org/wiki/Confusion_matrix)
+arising from two clusterings is a 2×2 contingency table that summarizes
+the partition co-occurrence; see also [`counts`](@ref).
+
+```@docs
+confusion
+```
diff --git a/src/Clustering.jl b/src/Clustering.jl
index 2da54517..a3513a17 100644
--- a/src/Clustering.jl
+++ b/src/Clustering.jl
@@ -65,7 +65,10 @@ module Clustering
         Hclust, hclust, cutree,
 
         # MCL
-        mcl, MCLResult
+        mcl, MCLResult,
+
+        # pair confusion matrix
+        confusion
 
     ## source files
 
@@ -85,6 +88,7 @@ module Clustering
     include("varinfo.jl")
     include("vmeasure.jl")
     include("mutualinfo.jl")
+    include("confusion.jl")
 
     include("hclust.jl")
 
diff --git a/src/confusion.jl b/src/confusion.jl
new file mode 100644
index 00000000..0aabb97d
--- /dev/null
+++ b/src/confusion.jl
@@ -0,0 +1,39 @@
+"""
+    confusion(a::Union{ClusteringResult, AbstractVector},
+              b::Union{ClusteringResult, AbstractVector}) -> Matrix{Int}
+
+Return the 2×2 pair confusion matrix `C` that summarizes the partition
+co-occurrence (similarity) of two clusterings. It is built by considering all
+pairs of samples and counting the pairs that are assigned to the same or to
+different clusters under the true and predicted clusterings.
+
+A pair of samples sharing a cluster is a **positive pair**, and a pair split
+across clusters is a **negative pair**. With `a` as the true and `b` as the
+predicted clustering, `C₁₁` counts true positives, `C₁₂` false negatives,
+`C₂₁` false positives, and `C₂₂` true negatives:
+
+|  | Positive | Negative |
+|:--:|:-:|:-:|
+|Positive|C₁₁|C₁₂|
+|Negative|C₂₁|C₂₂|
+"""
+function confusion(a::AbstractVector{<:Integer}, b::AbstractVector{<:Integer})
+    nij = counts(a, b)                       # contingency table of a vs. b
+    npoints = sum(nij)
+    rowsq = sum(abs2, sum(nij, dims=2))      # sum of squared row sums
+    colsq = sum(abs2, sum(nij, dims=1))      # sum of squared column sums
+    cellsq = sum(abs2, nij)                  # sum of squared cells
+
+    same_both = (cellsq - npoints) ÷ 2       # pairs together in both a and b
+    only_a = (rowsq - cellsq) ÷ 2            # together in a, apart in b
+    only_b = (colsq - cellsq) ÷ 2            # apart in a, together in b
+    apart_both = (cellsq + npoints^2 - (rowsq + colsq)) ÷ 2
+    return [same_both only_a; only_b apart_both]
+end
+
+# ClusteringResult arguments are compared through their assignment vectors
+confusion(a::ClusteringResult, b::ClusteringResult) =
+    confusion(assignments(a), assignments(b))
+confusion(a::AbstractVector{<:Integer}, b::ClusteringResult) = confusion(a, assignments(b))
+confusion(a::ClusteringResult, b::AbstractVector{<:Integer}) = confusion(assignments(a), b)
+
diff --git a/src/randindex.jl b/src/randindex.jl
index 23576a31..3cc146f2 100644
--- a/src/randindex.jl
+++ b/src/randindex.jl
@@ -14,39 +14,29 @@ Returns a tuple of indices:
 
 # References
 > Lawrence Hubert and Phipps Arabie (1985). *Comparing partitions.*
-> Journal of Classification 2 (1): 193–218
+> Journal of Classification 2 (1): 193-218
 
 > Meila, Marina (2003). *Comparing Clusterings by the Variation of
-> Information.* Learning Theory and Kernel Machines: 173–187.
+> Information.* Learning Theory and Kernel Machines: 173-187.
+
+> Steinley, Douglas (2004). *Properties of the Hubert-Arabie Adjusted
+> Rand Index.* Psychological Methods, Vol. 9, No. 
3: 386-396
 """
 function randindex(a, b)
-    c = counts(a, b)
-
-    n = sum(c)
-    nis = sum(abs2, sum(c, dims=2))        # sum of squares of sums of rows
-    njs = sum(abs2, sum(c, dims=1))        # sum of squares of sums of columns
-
-    t1 = binomial(n, 2)                    # total number of pairs of entities
-    t2 = sum(abs2, c)                      # sum over rows & columnns of nij^2
-    t3 = .5*(nis+njs)
-
-    # Expected index (for adjustment)
-    nc = (n*(n^2+1)-(n+1)*nis-(n+1)*njs+2*(nis*njs)/n)/(2*(n-1))
+    # widen the pair counts: t*A, t*t and ERI below can exceed typemax(Int) for large inputs
+    c11, c21, c12, c22 = widen.(confusion(a, b)) # column-major; Table 2 from Steinley 2004
 
-    A = t1+t2-t3;        # agreements count
-    D = -t2+t3;          # disagreements count
+    t = c11 + c12 + c21 + c22   # total number of pairs of entities
+    A = c11 + c22               # agreements count
+    D = c12 + c21               # disagreements count
 
-    if t1 == nc
-        # avoid division by zero; if k=1, define Rand = 0
-        ARI = 0
-    else
-        # adjusted Rand - Hubert & Arabie 1985
-        ARI = (A-nc)/(t1-nc)
-    end
+    ERI = (c11+c12)*(c11+c21)+(c21+c22)*(c12+c22)   # expected index (times t)
+    # adjusted Rand - Hubert & Arabie 1985, (9) from Steinley 2004; D == 0 would give 0/0
+    ARI = D == 0 ? 1.0 : (t*A-ERI)/(t*t-ERI)
 
-    RI = A/t1            # Rand 1971      # Probability of agreement
-    MI = D/t1            # Mirkin 1970    # p(disagreement)
-    HI = (A-D)/t1        # Hubert 1977    # p(agree)-p(disagree)
+    RI = A/t            # Rand 1971      # Probability of agreement
+    MI = D/t            # Mirkin 1970    # p(disagreement)
+    HI = (A-D)/t        # Hubert 1977    # p(agree)-p(disagree)
 
     return (ARI, RI, MI, HI)
 end
diff --git a/test/confusion.jl b/test/confusion.jl
new file mode 100644
index 00000000..24c931ac
--- /dev/null
+++ b/test/confusion.jl
@@ -0,0 +1,44 @@
+# Test confusion matrix
+
+using Test
+using Clustering
+
+@testset "confusion() (Confusion matrix)" begin
+
+    @testset "small size tests" begin
+        @test confusion([0,0,0], [0,0,0]) == [3 0; 0 0]
+        @test confusion([0,0,1], [0,0,0]) == [1 0; 2 0]
+        @test confusion([0,1,1], [0,0,0]) == [1 0; 2 0]
+        @test confusion([1,1,1], [0,0,0]) == [3 0; 0 0]
+
+        @test confusion([0,0,0], [0,0,1]) == [1 2; 0 0]
+        @test confusion([0,0,1], [0,0,1]) == [1 0; 0 2]
+        @test confusion([0,1,1], [0,0,1]) == [0 1; 1 1]
+        @test 
confusion([1,1,1], [0,0,1]) == [1 2; 0 0]
+
+        @test confusion([0,0,0], [0,1,1]) == [1 2; 0 0]
+        @test confusion([0,0,1], [0,1,1]) == [0 1; 1 1]
+        @test confusion([0,1,1], [0,1,1]) == [1 0; 0 2]
+        @test confusion([1,1,1], [0,1,1]) == [1 2; 0 0]
+
+        @test confusion([0,0,0], [1,1,1]) == [3 0; 0 0]
+        @test confusion([0,0,1], [1,1,1]) == [1 0; 2 0]
+        @test confusion([0,1,1], [1,1,1]) == [1 0; 2 0]
+        @test confusion([1,1,1], [1,1,1]) == [3 0; 0 0]
+    end
+
+    @testset "comparing 2 k-means clusterings" begin
+        m = 3
+        n = 100
+        k = 1
+        x = rand(m, n)
+
+        # with k = 1 both runs put every point in one cluster: all pairs are true positives
+        r1 = kmeans(x, k; maxiter=5)
+        r2 = kmeans(x, k; maxiter=5)
+        C = confusion(r1, r2)
+        @test C == [n*(n-1)÷2 0; 0 0]
+    end
+
+end
+
diff --git a/test/randindex.jl b/test/randindex.jl
index d3b83033..c6ad3de4 100644
--- a/test/randindex.jl
+++ b/test/randindex.jl
@@ -34,4 +34,10 @@ a3 = [3, 3, 3, 2, 2, 2, 1, 1, 1, 1]
 
 @test randindex(a1, a2) == randindex(a2, a1)
 
+@test randindex(ones(Int, 3), ones(Int, 3)) == (1, 1, 0, 1)
+
+# independent random labelings: the adjusted Rand index must be close to zero on both sides
+a, b = rand(1:5, 10_000), rand(1:5, 10_000)
+@test abs(randindex(a, b)[1]) < 1.0e-2
+
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 1f9d483a..9eaca2ed 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -19,7 +19,8 @@ tests = ["seeding",
          "hclust",
          "mcl",
          "vmeasure",
-         "mutualinfo"]
+         "mutualinfo",
+         "confusion"]
 
 println("Runing tests:")
 for t in tests