Skip to content

Commit 94aed2f

Browse files
committed
Some tweaks + docs
1 parent b25d2f4 commit 94aed2f

File tree

4 files changed

+107
-28
lines changed

4 files changed

+107
-28
lines changed

src/Libtask.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@ using Core.Compiler: Argument, IRCode, ReturnNode
1515
include("copyable_task.jl")
1616
include("test_utils.jl")
1717

18-
export CopyableTask, consume, produce
18+
export TapedTask, consume, produce
1919

2020
end

src/copyable_task.jl

Lines changed: 89 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,103 @@ __v::Int = 5
44
return nothing
55
end
66

7-
mutable struct CopyableTask{Tmc<:MistyClosure,Targs}
7+
mutable struct TapedTask{Tmc<:MistyClosure,Targs}
88
const mc::Tmc
99
args::Targs
1010
const position::Base.RefValue{Int32}
11+
const deepcopy_types::Type
1112
end
1213

13-
@inline consume(t::CopyableTask) = t.mc(t.args...)
14+
"""
15+
Base.copy(t::TapedTask)
16+
17+
Makes a copy of `t` which can be run. For the most part, calls to [`consume`](@ref) on the
18+
copied task will give the same results as the original. There are, however, substantial
19+
limitations to this, detailed in the extended help.
20+
21+
# Extended Help
22+
23+
We call a copy of a `TapedTask` _consistent_ with the original if the call to `==` in the
24+
loop below always returns `true`:
25+
```julia
26+
t = <some_TapedTask>
27+
tc = copy(t)
28+
for (v, vc) in zip(t, tc)
29+
v == vc
30+
end
31+
```
32+
(provided that `==` is implemented for all `v` that are produced). Conversely, we refer to a
33+
copy as _inconsistent_ if this property doesn't hold. In order to ensure
34+
consistency, we need to ensure that independent copies are made of anything which might be
35+
mutated by the task or its copy during subsequent `consume` calls. Failure to do this can
36+
cause problems if, for example, a task reads from and writes to some memory.
37+
If we call `consume` on the original task, and then on a copy of it, any changes made by the
38+
original will be visible to the copy, potentially causing its behaviour to differ. This can
39+
manifest itself as a race condition if the task and its copies are run concurrently.
40+
41+
To understand a bit more about when a task is / is not consistent, we need to dig into the
42+
rather specific semantics of `copy`. Calling `copy` on a `TapedTask` does the following:
43+
1. `copy` the `position` field,
44+
2. `map`s `_tape_copy` over the `args` field, and
45+
3. `map`s `_tape_copy` over all of the data closed over in the `OpaqueClosure` which
46+
implements the task (specifically the values _inside_ the `Ref`s) -- call these the
47+
`captures`. The exception is the last element of this data, because it is `===` to the
48+
`position` field -- for this element we use the copy we made in step 1.
49+
50+
`_tape_copy` doesn't actually make a copy of the object at all if it is not either an
51+
`Array`, a `Ref`, or an instance of one of the types listed in the task's `deepcopy_types`
52+
field. If it is an instance of one of these types then `_tape_copy` just calls `deepcopy`.
53+
54+
This behaviour is plainly entirely acceptable if the argument to `_tape_copy` is a bits
55+
type. For any `mutable struct`s which aren't flagged for `deepcopy`ing, we have an immediate
56+
risk of inconsistency. Similarly, for any `struct` types which aren't bits types (e.g.
57+
those which contain an `Array`, `Ref`, or some other `mutable struct` either directly as one
58+
of their fields, or as a field of a field, etc), we have an inconsistency risk.
59+
60+
Furthermore, for anything which _is_ `deepcopy`ed we introduce inconsistency risks. If, for
61+
example, two elements of the data closed over by the task alias one another, calling
62+
`deepcopy` on them separately will cause the copies to _not_ alias one another.
63+
The same thing can happen if one element is `deepcopy`ed and the other not. For example, if
64+
we have both an `Array` `x` and `view(x, inds)` stored in separate elements of `captures`,
65+
`x` will be `deepcopy`ed, while `view(x, inds)` will not. In the copy of `captures`, the
66+
`view` will still be a view into the original `x`, not the `deepcopy`ed version. Again, this
67+
introduces inconsistency.
68+
69+
Why do we have these semantics? We have them because Libtask has always had them, and at the
70+
time of writing we're unsure whether AdvancedPS.jl, and by extension Turing.jl, rely on this
71+
behaviour.
72+
73+
What other options do we have? Simply calling `deepcopy` on a `TapedTask` works fine, and
74+
should reliably result in consistent behaviour between a `TapedTask` and any copies of it.
75+
This would, therefore, be a preferable implementation. We should try to determine whether
76+
this is a viable option.
77+
"""
78+
function Base.copy(t::T) where {T<:TapedTask}
79+
captures = t.mc.oc.captures
80+
new_captures = map(Base.Fix2(_tape_copy, t.deepcopy_types), captures)
81+
new_position = new_captures[end] # baked in later on.
82+
new_args = map(Base.Fix2(_tape_copy, t.deepcopy_types), t.args)
83+
new_mc = Mooncake.replace_captures(t.mc, new_captures)
84+
return T(new_mc, new_args, new_position, t.deepcopy_types)
85+
end
86+
87+
_tape_copy(v, deepcopy_types::Type) = v isa deepcopy_types ? deepcopy(v) : v
88+
89+
# Not sure that we need this in the new implementation.
90+
_tape_copy(box::Core.Box, deepcopy_types::Type) = error("Found a box")
91+
92+
@inline consume(t::TapedTask) = t.mc(t.args...)
1493

15-
function initialise!(t::CopyableTask, args::Vararg{Any,N})::Nothing where {N}
94+
function initialise!(t::TapedTask, args::Vararg{Any,N})::Nothing where {N}
1695
t.position[] = -1
1796
t.args = args
1897
return nothing
1998
end
2099

21-
function CopyableTask(fargs...)
100+
function TapedTask(fargs...; deepcopy_types::Type=Union{})
22101
sig = typeof(fargs)
23102
mc, count_ref = build_callable(Base.code_ircode_by_type(sig)[1][1])
24-
return CopyableTask(mc, fargs[2:end], count_ref)
103+
return TapedTask(mc, fargs[2:end], count_ref, Union{deepcopy_types, Array, Ref})
25104
end
26105

27106
function build_callable(ir::IRCode)
@@ -36,10 +115,10 @@ end
36115
might_produce(sig::Type{<:Tuple})::Bool
37116
38117
`true` if a call to method with signature `sig` is permitted to contain
39-
`CopyableTasks.produce` statements.
118+
`Libtask.produce` statements.
40119
41120
This is an opt-in mechanism. The fallback method of this function returns `false` indicating
42-
that, by default, we assume that calls do not contain `CopyableTasks.produce` statements.
121+
that, by default, we assume that calls do not contain `Libtask.produce` statements.
43122
"""
44123
might_produce(::Type{<:Tuple}) = false
45124

@@ -382,9 +461,9 @@ end
382461
@inline deref_phi(::R, x) where {R<:Tuple} = x
383462

384463
# Implement iterator interface.
385-
function Base.iterate(t::CopyableTask, state::Nothing=nothing)
464+
function Base.iterate(t::TapedTask, state::Nothing=nothing)
386465
v = consume(t)
387466
return v === nothing ? nothing : (v, nothing)
388467
end
389-
Base.IteratorSize(::Type{<:CopyableTask}) = Base.SizeUnknown()
390-
Base.IteratorEltype(::Type{<:CopyableTask}) = Base.EltypeUnknown()
468+
Base.IteratorSize(::Type{<:TapedTask}) = Base.SizeUnknown()
469+
Base.IteratorEltype(::Type{<:TapedTask}) = Base.EltypeUnknown()

src/test_utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ module TestUtils
22

33
using ..Libtask
44
using Test
5-
using ..Libtask: CopyableTask
5+
using ..Libtask: TapedTask
66

77
struct Testcase
88
name::String
@@ -14,14 +14,14 @@ function (case::Testcase)()
1414
testset = @testset "$(case.name)" begin
1515

1616
# Construct the task.
17-
t = CopyableTask(case.fargs...)
17+
t = TapedTask(case.fargs...)
1818

1919
# Iterate through t. Record the results, and take a copy after each iteration.
2020
iteration_results = []
21-
t_copies = [deepcopy(t)]
21+
t_copies = [copy(t)]
2222
for val in t
2323
push!(iteration_results, val)
24-
push!(t_copies, deepcopy(t))
24+
push!(t_copies, copy(t))
2525
end
2626

2727
# Check that iterating the original task gives the expected results.

test/copyable_task.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
end
2828
end
2929

30-
ttask = CopyableTask(f)
30+
ttask = TapedTask(f)
3131

3232
next = iterate(ttask)
3333
@test next === (1, nothing)
@@ -57,7 +57,7 @@
5757
end
5858
end
5959

60-
ttask = CopyableTask(f)
60+
ttask = TapedTask(f)
6161
try
6262
consume(ttask)
6363
catch ex
@@ -75,7 +75,7 @@
7575
end
7676
end
7777

78-
ttask = CopyableTask(f)
78+
ttask = TapedTask(f)
7979
try
8080
consume(ttask)
8181
catch ex
@@ -94,7 +94,7 @@
9494
end
9595
end
9696

97-
ttask = CopyableTask(f)
97+
ttask = TapedTask(f)
9898
try
9999
consume(ttask)
100100
catch ex
@@ -113,7 +113,7 @@
113113
end
114114
end
115115

116-
ttask = CopyableTask(f)
116+
ttask = TapedTask(f)
117117
@test consume(ttask) == 2
118118
try
119119
consume(ttask)
@@ -133,9 +133,9 @@
133133
end
134134
end
135135

136-
ttask = CopyableTask(f)
136+
ttask = TapedTask(f)
137137
@test consume(ttask) == 2
138-
ttask2 = deepcopy(ttask)
138+
ttask2 = copy(ttask)
139139
try
140140
consume(ttask2)
141141
catch ex
@@ -155,10 +155,10 @@
155155
end
156156
end
157157

158-
ttask = CopyableTask(f)
158+
ttask = TapedTask(f)
159159
@test consume(ttask) == 0
160160
@test consume(ttask) == 1
161-
a = deepcopy(ttask)
161+
a = copy(ttask)
162162
@test consume(a) == 2
163163
@test consume(a) == 3
164164
@test consume(ttask) == 2
@@ -175,10 +175,10 @@
175175
end
176176
end
177177

178-
ttask = CopyableTask(f)
178+
ttask = TapedTask(f)
179179
@test consume(ttask) == 0
180180
@test consume(ttask) == 1
181-
a = deepcopy(ttask)
181+
a = copy(ttask)
182182
@test consume(a) == 2
183183
@test consume(a) == 3
184184
@test consume(ttask) == 2
@@ -221,11 +221,11 @@
221221
end
222222
end
223223

224-
ttask = CopyableTask(f)
224+
ttask = TapedTask(f)
225225

226226
consume(ttask)
227227
consume(ttask)
228-
a = deepcopy(ttask)
228+
a = copy(ttask)
229229
consume(a)
230230
consume(a)
231231

0 commit comments

Comments
 (0)