Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/Graphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ is_bipartite, bipartite_map,
is_cyclic, topological_sort_by_dfs, dfs_tree, dfs_parents,

# random
randomwalk, self_avoiding_walk, non_backtracking_randomwalk,
randomwalk, self_avoiding_walk, non_backtracking_randomwalk, loop_erased_randomwalk,

# diffusion
diffusion, diffusion_rate,
Expand Down Expand Up @@ -144,6 +144,9 @@ euclidean_graph,
#minimum_spanning_trees
boruvka_mst, kruskal_mst, prim_mst,

#random_spanning_trees
wilson_rst,

#steinertree
steiner_tree,

Expand Down Expand Up @@ -259,6 +262,7 @@ include("community/assortativity.jl")
include("spanningtrees/boruvka.jl")
include("spanningtrees/kruskal.jl")
include("spanningtrees/prim.jl")
include("spanningtrees/wilson.jl")
include("steinertree/steiner_tree.jl")
include("biconnectivity/articulation.jl")
include("biconnectivity/biconnect.jl")
Expand Down
58 changes: 58 additions & 0 deletions src/spanningtrees/wilson.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
wilson_rst(g, distmx=weights(g); seed=-1, rng=GLOBAL_RNG)
Randomly sample a spanning tree from an undirected connected graph g using
[Wilson's algorithm](https://en.wikipedia.org/wiki/Loop-erased_random_walk).
The tree will be sampled with probability measure proportional to the product
of the edge weights (given by distmx).
The exact probability of the tree may be determined by computing the
normalization constant of the distribution via [Kirchoff's](https://en.wikipedia.org/wiki/Kirchhoff%27s_theorem)
or [Tutte's Matrix Tree](https://personalpages.manchester.ac.uk/staff/mark.muldoon/Teaching/DiscreteMaths/LectureNotes/MatrixTreeProof.pdf)
theorem, which is equivalent to the determinant of any minor of the Laplacian matrix.
e.g.
tree = wilson_rst(g)
probability_of_tree = prod([weights(g)[src(e), dst(e)] for e in tree])
probability_of_tree /= det(laplacian_matrix(g)[2:nv(g), 2:nv(g)]))
Comment on lines +16 to +18
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use a julia or jldoctest block here?

"""
function wilson_rst end
@traitfn function wilson_rst(g::AG::(!IsDirected),
distmx::AbstractMatrix{T}=weights(g);
seed::Int=-1,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need a seed, when we can pass in a random generator?

rng::AbstractRNG=GLOBAL_RNG
) where {T <: Real, U, AG <: AbstractGraph{U}}

if seed >= 0
rng = getRNG(seed)
end

start1 = rand(rng, 1:nv(g))
start2 = rand(rng, 1:nv(g))
Comment on lines +31 to +32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
start1 = rand(rng, 1:nv(g))
start2 = rand(rng, 1:nv(g))
start1 = rand(rng, vertices(g))
start2 = rand(rng, vertices(g))

might make more sense, although at least for SimpleGraph, this will not be of type Int but of type eltype(g)


walk = loop_erased_randomwalk(g, start1, distmx=distmx, f=[start2], rng=rng)

tree = SimpleGraph(nv(g))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tree is only used to together with add_edge!, which is a costly operation. It is more efficient to directly store the unique edges.

for i = 1:length(walk)-1
add_edge!(tree, walk[i], walk[i+1])
end

visited_vertices = Set(walk)
unvisited_vertices = setdiff(Set([i for i = 1:nv(g)]), visited_vertices)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is more efficient to just flag visited vertices (with a linear pass over walk) than to do complex set operations (which I think are not even linear...). We just maintain a BitVector is_visitedof size nv(g).
In loop_erased_randomwalk, e in f becomes is_visited[e], which is also more efficient.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
unvisited_vertices = setdiff(Set([i for i = 1:nv(g)]), visited_vertices)
unvisited_vertices = setdiff(Set(vertices(g)), visited_vertices)


while length(unvisited_vertices) > 0
v = rand(rng, unvisited_vertices)
walk = loop_erased_randomwalk(g, v, distmx=distmx, f=visited_vertices,
rng=rng)

for i = 1:length(walk)-1
add_edge!(tree, walk[i], walk[i+1])
end
walk_set = Set(walk)
union!(visited_vertices, walk_set)
unvisited_vertices = setdiff(unvisited_vertices, walk_set)
end
return [e for e in edges(tree)]
end

63 changes: 62 additions & 1 deletion src/traversals/randomwalks.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""
"""
randomwalk(g, s, niter; seed=-1)

Perform a random walk on graph `g` starting at vertex `s` and continuing for
Expand Down Expand Up @@ -118,3 +118,64 @@ function self_avoiding_walk(g::AG, s::Integer, niter::Integer; seed::Int=-1) whe
end
return visited[1:(i - 1)]
end

"""
loop_erased_randomwalk(g, s, niter, distmx=weights(g); f=Set(), seed=-1,
rng=GLOBAL_RNG)

Perform a [loop-erased random walk](https://en.wikipedia.org/wiki/Loop-erased_random_walk)
on graph `g` starting at vertex `s` and continuing until one of the following
conditions are met: (i) `niter` steps are performed, (ii) the path has no more
places to go, (iii) or the walk reaches an element in f.

If f is specified and the final element of the walk is not in f, the function
will throw an error.

Return a vector of vertices visited in order.
"""
function loop_erased_randomwalk(
g::AG, s::Integer,
niter::Integer=max(100, nv(g)^2);
distmx::AbstractMatrix{T}=weights(g),
f::Union{Set{V},Vector{V}}=Set{Integer}(),
seed::Int=-1,
rng::AbstractRNG=GLOBAL_RNG
Comment on lines +141 to +142
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seem here, do we need seed?

)::Vector{Int} where {T <: Real, U, AG <: AbstractGraph{U}, V <: Integer}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I remember ::SomeType after a function does actually force convert to be used on the result of that function.

s in vertices(g) || throw(BoundsError())

if seed >= 0
rng = getRNG(seed)
end

visited = Vector{Integer}(undef, 1)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Vector should be typed with eltype(g) (passed by dynamic dispatch)

visited_view = view(visited, 1:1)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to carry a view? The only place where it is used is line 164-165 (and only in one line with my proposed change). The indexing of visited_view is the same as visited because it always start at the first index of visited, so in line 152, 156 and 178, we can replace visited_view by visited.

visited_view[1] = s
i = 1
cur_pos = 1
while i <= niter
cur = visited_view[cur_pos]
if cur in f
break
end
nbrs = neighbors(g, cur)
length(nbrs) == 0 && break
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do this check only for the first vertex? Graph is undirected so if we move to an adjacent neighbor, this vertex must have at least one neighbor ?

wght = [distmx[cur, n] for n in nbrs]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
wght = [distmx[cur, n] for n in nbrs]
wght = @view distmx[cur, nbrs]

Maybe this is better?

v = nbrs[findfirst(cumsum(wght) .> rand(rng)*sum(wght))]
if v in visited_view
cur_pos = indexin(v, visited_view)[1]
visited_view = view(visited, 1:cur_pos)
Comment on lines +164 to +166
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if v in visited_view
cur_pos = indexin(v, visited_view)[1]
visited_view = view(visited, 1:cur_pos)
new_cur_pos = findfirst(v, visited_view)
if !isnothing(new_cur_pos)
cur_pos = new_cur_pos
visited_view = view(visited, 1:cur_pos)
end

else
cur_pos += 1
if length(visited) < cur_pos
resize!(visited, min(2*cur_pos, nv(g)))
end
visited[cur_pos] = v
visited_view = view(visited, 1:cur_pos)
end
i += 1
end

length(f) == 0 || visited_view[cur_pos] in f || throw(ErrorException("termiating set was not reached"))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
length(f) == 0 || visited_view[cur_pos] in f || throw(ErrorException("termiating set was not reached"))
length(f) == 0 || visited_view[cur_pos] in f || throw(ErrorException("terminating set was not reached"))

return visited[1:cur_pos]
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ tests = [
"spanningtrees/boruvka",
"spanningtrees/kruskal",
"spanningtrees/prim",
"spanningtrees/wilson",
"steinertree/steiner_tree",
"biconnectivity/articulation",
"biconnectivity/biconnect",
Expand Down
45 changes: 45 additions & 0 deletions test/spanningtrees/wilson.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
@testset "Wilson" begin
g = SimpleGraph(4)
add_edge!(g, 1, 2)
add_edge!(g, 1, 3)
add_edge!(g, 1, 4)
add_edge!(g, 2, 3)
add_edge!(g, 3, 4)

N = 2*10^4
diag_edge = Edge(1, 3)
tol = 0.01

# checking spanning tree count
minor = Matrix(view(laplacian_matrix(g), 2:4, 2:4))
@test Int(det(minor)) == 8

# Testing Wilson's algorithm on an unweighted graphs
diag_count = 0
for ii = 1:N
st = @inferred(wilson_rst(g, seed=5124154 + 312*ii))
diag_count += Int(diag_edge in st)
end
@test abs(0.5 - diag_count/N) < tol

rng = MersenneTwister(123411223)
diag_count = 0
for ii = 1:N
st = @inferred(wilson_rst(g, rng=rng))
diag_count += Int(diag_edge in st)
end
@test abs(0.5 - diag_count/N) < tol

# Testing Wilson's algorithm on a weighted graph
distmx = ones(Float64, 4, 4)
distmx[1,3] = 0.5
distmx[3,1] = 0.5

diag_count = 0
N = 10^5
for ii = 1:N
st = @inferred(wilson_rst(g, distmx, seed=1509812+192*ii))
diag_count += Int(Edge(1,3) in st)
end
@test abs(1/3 - diag_count/N) < tol
end
8 changes: 8 additions & 0 deletions test/traversals/randomwalks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
@test_throws BoundsError randomwalk(g, 20, 20)
@test @inferred(non_backtracking_randomwalk(g, 10, 20)) == [10]
@test @inferred(non_backtracking_randomwalk(g, 1, 20)) == [1:10;]
lerw = loop_erased_randomwalk(g, 1, 20)
@test lerw == [1:length(lerw);]
end

gx = path_graph(10)
Expand All @@ -42,12 +44,18 @@
@test_throws BoundsError self_avoiding_walk(g, 20, 20)
@test @inferred(non_backtracking_randomwalk(g, 1, 20)) == [1:10;]
@test_throws BoundsError non_backtracking_randomwalk(g, 20, 20)
visited = @inferred(loop_erased_randomwalk(g, 1, 20))
@test visited == [1:length(visited);]
@test_throws BoundsError loop_erased_randomwalk(g, 20, 20)
end

gx = SimpleDiGraph(path_graph(10))
for g in testdigraphs(gx)
@test @inferred(non_backtracking_randomwalk(g, 1, 20)) == [1:10;]
@test_throws BoundsError non_backtracking_randomwalk(g, 20, 20)
visited = @inferred(loop_erased_randomwalk(g, 1, 20))
@test visited == [1:length(visited);]
@test_throws BoundsError loop_erased_randomwalk(g, 20, 20)
end

gx = cycle_graph(10)
Expand Down