diff --git a/src/Graphs.jl b/src/Graphs.jl
index ca07d57c0..245a9f3da 100644
--- a/src/Graphs.jl
+++ b/src/Graphs.jl
@@ -64,7 +64,7 @@ is_bipartite, bipartite_map,
 is_cyclic, topological_sort_by_dfs, dfs_tree, dfs_parents,
 
 # random
-randomwalk, self_avoiding_walk, non_backtracking_randomwalk,
+randomwalk, self_avoiding_walk, non_backtracking_randomwalk, loop_erased_randomwalk,
 
 # diffusion
 diffusion, diffusion_rate,
@@ -144,6 +144,9 @@ euclidean_graph,
 #minimum_spanning_trees
 boruvka_mst, kruskal_mst, prim_mst,
 
+#random_spanning_trees
+wilson_rst,
+
 #steinertree
 steiner_tree,
 
@@ -259,6 +262,7 @@ include("community/assortativity.jl")
 include("spanningtrees/boruvka.jl")
 include("spanningtrees/kruskal.jl")
 include("spanningtrees/prim.jl")
+include("spanningtrees/wilson.jl")
 include("steinertree/steiner_tree.jl")
 include("biconnectivity/articulation.jl")
 include("biconnectivity/biconnect.jl")
diff --git a/src/spanningtrees/wilson.jl b/src/spanningtrees/wilson.jl
new file mode 100644
index 000000000..df16df036
--- /dev/null
+++ b/src/spanningtrees/wilson.jl
@@ -0,0 +1,76 @@
+"""
+    wilson_rst(g, distmx=weights(g); seed=-1, rng=GLOBAL_RNG)
+
+Randomly sample a spanning tree from an undirected connected graph `g` using
+[Wilson's algorithm](https://en.wikipedia.org/wiki/Loop-erased_random_walk).
+The tree is sampled with probability proportional to the product of its edge
+weights (given by `distmx`).
+
+The exact probability of a tree may be determined by computing the
+normalization constant of the distribution via
+[Kirchhoff's](https://en.wikipedia.org/wiki/Kirchhoff%27s_theorem) or
+[Tutte's Matrix Tree](https://personalpages.manchester.ac.uk/staff/mark.muldoon/Teaching/DiscreteMaths/LectureNotes/MatrixTreeProof.pdf)
+theorem, which is equivalent to the determinant of any minor of the Laplacian
+matrix, e.g.
+
+```julia
+tree = wilson_rst(g)
+probability_of_tree = prod([weights(g)[src(e), dst(e)] for e in tree])
+probability_of_tree /= det(Matrix(laplacian_matrix(g)[2:nv(g), 2:nv(g)]))
+```
+
+### Optional Arguments
+- `seed=-1`: when non-negative, sample with the generator returned by
+  `getRNG(seed)` instead of `rng`.
+- `rng=GLOBAL_RNG`: random number generator used when `seed` is negative.
+
+Return the edges of the sampled tree as a vector; a graph without vertices
+yields an empty vector. Every branch of the tree is grown with
+[`loop_erased_randomwalk`](@ref), whose `ErrorException` propagates when a
+walk fails to reach the tree within its default number of steps (for instance
+when `g` is not connected).
+"""
+function wilson_rst end
+@traitfn function wilson_rst(g::AG::(!IsDirected),
+    distmx::AbstractMatrix{T}=weights(g);
+    seed::Int=-1,
+    rng::AbstractRNG=GLOBAL_RNG
+) where {T <: Real, U, AG <: AbstractGraph{U}}
+
+    if seed >= 0
+        rng = getRNG(seed)
+    end
+
+    tree = SimpleGraph(nv(g))
+    # There is no vertex to start a walk from in an empty graph.
+    nv(g) == 0 && return collect(edges(tree))
+
+    # The first branch is a loop-erased walk between two vertices drawn
+    # uniformly at random (a single vertex when both draws coincide).
+    start = rand(rng, 1:nv(g))
+    root = rand(rng, 1:nv(g))
+    walk = loop_erased_randomwalk(g, start, distmx=distmx, f=[root], rng=rng)
+    _add_walk_edges!(tree, walk)
+
+    visited = Set(walk)
+    unvisited = setdiff!(Set(vertices(g)), visited)
+
+    # Grow the tree by attaching loop-erased walks that start at a random
+    # unvisited vertex and stop as soon as they hit the current tree.
+    while !isempty(unvisited)
+        v = rand(rng, unvisited)
+        walk = loop_erased_randomwalk(g, v, distmx=distmx, f=visited, rng=rng)
+        _add_walk_edges!(tree, walk)
+        union!(visited, walk)
+        setdiff!(unvisited, walk)
+    end
+    return collect(edges(tree))
+end
+
+# Add the edges joining consecutive vertices of `walk` to `tree`.
+function _add_walk_edges!(tree, walk)
+    for i in 1:(length(walk) - 1)
+        add_edge!(tree, walk[i], walk[i + 1])
+    end
+    return tree
+end
diff --git a/src/traversals/randomwalks.jl b/src/traversals/randomwalks.jl
index 3631d7b1d..8fd0d664c 100644
--- a/src/traversals/randomwalks.jl
+++ b/src/traversals/randomwalks.jl
@@ -118,3 +118,77 @@ function self_avoiding_walk(g::AG, s::Integer, niter::Integer; seed::Int=-1) whe
     end
     return visited[1:(i - 1)]
 end
+
+"""
+    loop_erased_randomwalk(g, s, niter=max(100, nv(g)^2); distmx=weights(g),
+                           f=Set(), seed=-1, rng=GLOBAL_RNG)
+
+Perform a [loop-erased random walk](https://en.wikipedia.org/wiki/Loop-erased_random_walk)
+on graph `g` starting at vertex `s` and continuing until one of the following
+conditions is met: (i) `niter` steps are performed, (ii) the path has no more
+places to go, or (iii) the walk reaches an element of `f`.
+
+Each step moves to a neighbor of the current vertex chosen with probability
+proportional to the corresponding entry of `distmx`. When a step lands on a
+vertex that is already on the path, the loop closed by that step is erased.
+
+### Optional Arguments
+- `distmx=weights(g)`: matrix of step weights, indexed as `distmx[u, v]`.
+- `f=Set()`: terminating set of vertices.
+- `seed=-1`: when non-negative, walk with the generator returned by
+  `getRNG(seed)` instead of `rng`.
+- `rng=GLOBAL_RNG`: random number generator used when `seed` is negative.
+
+Throw a `BoundsError` if `s` is not a vertex of `g`, an `ArgumentError` if the
+weights of the steps leaving a visited vertex do not have a positive sum, and
+an `ErrorException` if `f` is non-empty and the final vertex of the walk is not
+in `f`.
+
+Return a vector of vertices visited in order.
+"""
+function loop_erased_randomwalk(
+    g::AG, s::Integer,
+    niter::Integer=max(100, nv(g)^2);
+    distmx::AbstractMatrix{T}=weights(g),
+    f::Union{Set{V},Vector{V}}=Set{Integer}(),
+    seed::Int=-1,
+    rng::AbstractRNG=GLOBAL_RNG
+)::Vector{Int} where {T <: Real, U, AG <: AbstractGraph{U}, V <: Integer}
+    s in vertices(g) || throw(BoundsError())
+
+    if seed >= 0
+        rng = getRNG(seed)
+    end
+
+    # `path` always holds the loop-erased walk performed so far.
+    path = Int[s]
+    for _ in 1:niter
+        cur = path[end]
+        cur in f && break
+        nbrs = neighbors(g, cur)
+        isempty(nbrs) && break
+
+        # Draw the next vertex by inverting the cumulative weights. The total
+        # is read off the cumulative sum itself, so that the threshold stays
+        # below the last cumulative weight regardless of round-off in `sum`.
+        wght = [distmx[cur, n] for n in nbrs]
+        cumwght = cumsum(wght)
+        total = cumwght[end]
+        total > 0 || throw(ArgumentError(
+            "weights of the steps leaving vertex $cur must have a positive sum"))
+        threshold = rand(rng) * total
+        v = nbrs[findfirst(c -> c > threshold, cumwght)]
+
+        loop_start = findfirst(isequal(v), path)
+        if loop_start === nothing
+            push!(path, v)
+        else
+            # `v` closes a loop: erase everything visited after its first visit.
+            resize!(path, loop_start)
+        end
+    end
+
+    isempty(f) || path[end] in f ||
+        throw(ErrorException("terminating set was not reached"))
+    return path
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 303ae1584..19eaf5682 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -73,6 +73,7 @@ tests = [
     "spanningtrees/boruvka",
     "spanningtrees/kruskal",
     "spanningtrees/prim",
+    "spanningtrees/wilson",
     "steinertree/steiner_tree",
     "biconnectivity/articulation",
     "biconnectivity/biconnect",
diff --git a/test/spanningtrees/wilson.jl b/test/spanningtrees/wilson.jl
new file mode 100644
index 000000000..40d2288bb
--- /dev/null
+++ b/test/spanningtrees/wilson.jl
@@ -0,0 +1,59 @@
+@testset "Wilson" begin
+    g = SimpleGraph(4)
+    add_edge!(g, 1, 2)
+    add_edge!(g, 1, 3)
+    add_edge!(g, 1, 4)
+    add_edge!(g, 2, 3)
+    add_edge!(g, 3, 4)
+
+    N = 2*10^4
+    diag_edge = Edge(1, 3)
+    tol = 0.01
+
+    # checking spanning tree count
+    minor = Matrix(view(laplacian_matrix(g), 2:4, 2:4))
+    @test round(Int, det(minor)) == 8
+
+    # every sample must be a spanning tree of g
+    st = @inferred(wilson_rst(g, seed=1))
+    @test length(st) == nv(g) - 1
+    @test all(e -> has_edge(g, e), st)
+    sampled = SimpleGraph(nv(g))
+    for e in st
+        add_edge!(sampled, e)
+    end
+    @test is_connected(sampled)
+
+    # graphs with at most one vertex have an empty spanning tree
+    @test isempty(@inferred(wilson_rst(SimpleGraph(0))))
+    @test isempty(@inferred(wilson_rst(SimpleGraph(1))))
+
+    # Testing Wilson's algorithm on an unweighted graph
+    diag_count = 0
+    for ii = 1:N
+        st = @inferred(wilson_rst(g, seed=5124154 + 312*ii))
+        diag_count += Int(diag_edge in st)
+    end
+    @test abs(0.5 - diag_count/N) < tol
+
+    rng = MersenneTwister(123411223)
+    diag_count = 0
+    for ii = 1:N
+        st = @inferred(wilson_rst(g, rng=rng))
+        diag_count += Int(diag_edge in st)
+    end
+    @test abs(0.5 - diag_count/N) < tol
+
+    # Testing Wilson's algorithm on a weighted graph
+    distmx = ones(Float64, 4, 4)
+    distmx[1,3] = 0.5
+    distmx[3,1] = 0.5
+
+    diag_count = 0
+    N = 10^5
+    for ii = 1:N
+        st = @inferred(wilson_rst(g, distmx, seed=1509812+192*ii))
+        diag_count += Int(diag_edge in st)
+    end
+    @test abs(1/3 - diag_count/N) < tol
+end
diff --git a/test/traversals/randomwalks.jl b/test/traversals/randomwalks.jl
index 11b7c4cad..68626c02b 100644
--- a/test/traversals/randomwalks.jl
+++ b/test/traversals/randomwalks.jl
@@ -34,6 +34,10 @@
         @test_throws BoundsError randomwalk(g, 20, 20)
         @test @inferred(non_backtracking_randomwalk(g, 10, 20)) == [10]
         @test @inferred(non_backtracking_randomwalk(g, 1, 20)) == [1:10;]
+        @test @inferred(loop_erased_randomwalk(g, 1, 20)) == [1:10;]
+        @test @inferred(loop_erased_randomwalk(g, 1, 20, f=[5])) == [1:5;]
+        @test_throws ErrorException loop_erased_randomwalk(g, 5, 20, f=[1])
+        @test_throws BoundsError loop_erased_randomwalk(g, 20, 20)
     end
 
     gx = path_graph(10)
@@ -42,12 +46,18 @@
         @test_throws BoundsError self_avoiding_walk(g, 20, 20)
         @test @inferred(non_backtracking_randomwalk(g, 1, 20)) == [1:10;]
         @test_throws BoundsError non_backtracking_randomwalk(g, 20, 20)
+        visited = @inferred(loop_erased_randomwalk(g, 1, 20))
+        @test visited == [1:length(visited);]
+        @test_throws BoundsError loop_erased_randomwalk(g, 20, 20)
     end
 
     gx = SimpleDiGraph(path_graph(10))
     for g in testdigraphs(gx)
         @test @inferred(non_backtracking_randomwalk(g, 1, 20)) == [1:10;]
         @test_throws BoundsError non_backtracking_randomwalk(g, 20, 20)
+        visited = @inferred(loop_erased_randomwalk(g, 1, 20))
+        @test visited == [1:length(visited);]
+        @test_throws BoundsError loop_erased_randomwalk(g, 20, 20)
     end
 
     gx = cycle_graph(10)