Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
license = "MIT"
desc = "Tape based task copying in Turing"
repo = "https://github.com/TuringLang/Libtask.jl.git"
version = "0.9.4"
version = "0.9.5"

[deps]
MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4"
Expand Down
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,5 @@ An opt-in mechanism marks functions that might contain `Libtask.produce` stateme

```@docs; canonical=true
Libtask.might_produce(::Type{<:Tuple})
Libtask.@might_produce
```
4 changes: 2 additions & 2 deletions perf/p0.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ end

# Case 1: Sample from the prior.
rng = MersenneTwister()
m = Turing.Core.TracedModel(gdemo(1.5, 2.0), SampleFromPrior(), VarInfo(), rng)
m = Turing.Inference.TracedModel(gdemo(1.5, 2.0), SampleFromPrior(), VarInfo(), rng)
f = m.evaluator[1];
args = m.evaluator[2:end];

Expand All @@ -27,7 +27,7 @@ println("Run a tape...")
@btime t.tf(args...)

# Case 2: SMC sampler
m = Turing.Core.TracedModel(gdemo(1.5, 2.0), Sampler(SMC(50)), VarInfo(), rng)
m = Turing.Inference.TracedModel(gdemo(1.5, 2.0), Sampler(SMC(50)), VarInfo(), rng)
f = m.evaluator[1];
args = m.evaluator[2:end];

Expand Down
2 changes: 1 addition & 1 deletion perf/p2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Random.seed!(rng, 2)
iterations = 500
model_fun = infiniteGMM(data)

m = Turing.Core.TracedModel(model_fun, Sampler(SMC(50)), VarInfo(), rng)
m = Turing.Inference.TracedModel(model_fun, Sampler(SMC(50)), VarInfo(), rng)
f = m.evaluator[1]
args = m.evaluator[2:end]

Expand Down
61 changes: 60 additions & 1 deletion src/copyable_task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,70 @@ end
`true` if a call to method with signature `sig` is permitted to contain
`Libtask.produce` statements.

This is an opt-in mechanism. the fallback method of this function returns `false` indicating
This is an opt-in mechanism. The fallback method of this function returns `false` indicating
that, by default, we assume that calls do not contain `Libtask.produce` statements.
"""
might_produce(::Type{<:Tuple}) = false

"""
@might_produce(f)

If `f` is a function that may call `Libtask.produce` inside it, then `@might_produce(f)`
will generate the appropriate methods needed to ensure that `Libtask.might_produce` returns
`true` for all relevant signatures of `f`. This works even if `f` has methods with keyword
arguments.

```jldoctest might_produce_macro
julia> # For this demonstration we need to mark `g` as not being inlineable.
@noinline function g(x; y, z=0)
produce(x + y + z)
end
g (generic function with 1 method)

julia> function f()
g(1; y=2, z=3)
end
f (generic function with 1 method)

julia> # This returns nothing because `g` isn't yet marked as being able to `produce`.
consume(Libtask.TapedTask(nothing, f))

julia> Libtask.@might_produce(g)

julia> # Now it works!
consume(Libtask.TapedTask(nothing, f))
6
"""
macro might_produce(f)
# See https://github.com/TuringLang/Libtask.jl/issues/197 for discussion of this macro.
quote
function $(Libtask).might_produce(::Type{<:Tuple{typeof($(esc(f))),Vararg}})
return true
Comment on lines +400 to +401
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line is a little bit of a sledgehammer: we're basically saying, 'any invocation of f with any positional arguments might produce'. This is not necessarily true because some methods of f might produce and some might not.

But since there isn't any real downside to marking all methods are produceable, I don't think this is a huge issue. And if someone wants to be surgical, they can still use the non-macro version.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But since there isn't any real downside to marking all methods are produceable, I don't think this is a huge issue.

I would guess that there's a performance downside. Maybe for both compile-time and for run-time.

I was just thinking of proposing a comment in the docstring noting this aspect, and directing the user to might_produce if the function has many methods and they want to be performance-optimal.

Also, in the future we could make a version where you could restrict the types, like @might_produce(f(::Int,::Any)). Not in this PR though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would guess that there's a performance downside. Maybe for both compile-time and for run-time.

Not super important now, but would be curious to hear more about this, since I never looked closely at how it actually works.

a comment in the docstring

Will do

Also, in the future we could make a version where you could restrict the types

Interestingly enough parsing a function signature in a macro is exactly what I did here so this actually feels a bit familiar 😄 agree it can be a separate thing though

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not super important now, but would be curious to hear more about this, since I never looked closely at how it actually works.

I think if might_produce returns false for some method/function f, the Libtask-transformed function just calls f. Whereas if might_produce returns true, I think we recurse into transforming f, which means doing all the slow reading/writing of all intermediate variables within f as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see, thanks :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will make the warning a bit more strident then because that actually sounds non-trivial 😅

end
possible_n_kwargs = unique(map(length ∘ Base.kwarg_decl, methods($(esc(f)))))
if possible_n_kwargs != [0]
# Oddly we need to interpolate the module and not the function: either
# `$(might_produce)` or $(Libtask.might_produce) seem more natural but both of
# those cause the entire `Libtask.might_produce` to be treated as a single
# symbol. See https://discourse.julialang.org/t/128613
function $(Libtask).might_produce(
::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,typeof($(esc(f))),Vararg}}
)
return true
end
for n in possible_n_kwargs
# We only need `Any` and not `<:Any` because tuples are covariant.
kwarg_types = fill(Any, n)
function $(Libtask).might_produce(
::Type{<:Tuple{<:Function,kwarg_types...,typeof($(esc(f))),Vararg}}
)
return true
end
end
end
end
end

# Helper struct used in `derive_copyable_task_ir`.
struct TupleRef
n::Int
Expand Down
49 changes: 49 additions & 0 deletions test/copyable_task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,4 +251,53 @@
@test Libtask.consume(tt) === :a
@test Libtask.consume(tt) === nothing
end

@testset "@might_produce macro" begin
# Positional arguments only
@noinline g1(x) = produce(x)
f1(x) = g1(x)
# Without marking it as might_produce
tt = Libtask.TapedTask(nothing, f1, 0)
@test Libtask.consume(tt) === nothing
# Now marking it
Libtask.@might_produce(g1)
tt = Libtask.TapedTask(nothing, f1, 0)
@test Libtask.consume(tt) === 0
@test Libtask.consume(tt) === nothing

# Keyword arguments only
@noinline g2(x; y=1, z=2) = produce(x + y + z)
f2(x) = g2(x)
# Without marking it as might_produce
tt = Libtask.TapedTask(nothing, f2, 0)
@test Libtask.consume(tt) === nothing
# Now marking it
Libtask.@might_produce(g2)
tt = Libtask.TapedTask(nothing, f2, 0)
@test Libtask.consume(tt) === 3
@test Libtask.consume(tt) === nothing

# A function with multiple methods.
# The function reference is used to ensure that it really doesn't get inlined
# (otherwise, for reasons that are yet unknown, these functions do get inlined when
# inside a testset)
@noinline g3(x) = produce(x)
@noinline g3(x, y; z) = produce(x + y + z)
@noinline g3(x, y, z; p, q) = produce(x + y + z + p + q)
function f3(x, fref)
fref[](x)
fref[](x, 1; z=2)
fref[](x, 1, 2; p=3, q=4)
return nothing
end
tt = Libtask.TapedTask(nothing, f3, 0, Ref(g3))
@test Libtask.consume(tt) === nothing
# Now marking it
Libtask.@might_produce(g3)
tt = Libtask.TapedTask(nothing, f3, 0, Ref(g3))
@test Libtask.consume(tt) === 0
@test Libtask.consume(tt) === 3
@test Libtask.consume(tt) === 10
@test Libtask.consume(tt) === nothing
end
end
Loading