Skip to content

Commit 776aa9e

Browse files
committed
Add streaming API
1 parent ec7dba3 commit 776aa9e

File tree

2 files changed

+188
-0
lines changed

2 files changed

+188
-0
lines changed

src/Dagger.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ include("utils/locked-object.jl")
4141
include("utils/caching.jl")
4242
include("sch/Sch.jl"); using .Sch
4343

44+
# Streaming
45+
include("stream.jl")
46+
4447
# Array computations
4548
include("array/darray.jl")
4649
include("array/alloc.jl")

src/stream.jl

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
struct StreamingTaskQueue <: AbstractTaskQueue
2+
tasks::Vector{Pair{EagerTaskSpec,EagerThunk}}
3+
StreamingTaskQueue() = new(Pair{EagerTaskSpec,EagerThunk}[])
4+
end
5+
6+
function enqueue!(queue::StreamingTaskQueue, spec::Pair{EagerTaskSpec,EagerThunk})
7+
push!(queue.tasks, spec)
8+
end
9+
function enqueue!(queue::StreamingTaskQueue, specs::Vector{Pair{EagerTaskSpec,EagerThunk}})
10+
append!(queue.tasks, specs)
11+
end
12+
13+
function spawn_streaming(f::Base.Callable)
14+
queue = StreamingTaskQueue()
15+
result = with_options(;task_queue=queue) do
16+
spawn_bulk(f)
17+
end
18+
if length(queue.tasks) > 0
19+
setup_streaming!(queue.tasks)
20+
enqueue!(queue.tasks)
21+
end
22+
return result
23+
end
24+
25+
mutable struct StreamElement
26+
value
27+
remaining::Int
28+
end
29+
function consume!(elem::StreamElement)
30+
elem.remaining -= 1
31+
return (elem.value, elem.remaining)
32+
end
33+
mutable struct StreamStore
34+
waiters::Int
35+
buffer::Vector{StreamElement}
36+
open::Bool
37+
lock::Threads.Condition
38+
StreamStore(waiters) = new(waiters, StreamElement[], true, Threads.Condition())
39+
end
40+
function Base.put!(store::StreamStore, @nospecialize(value))
41+
elem = StreamElement(value, store.waiters)
42+
@lock store.lock begin
43+
@debug "[$(myid())] adding $value"
44+
push!(store.buffer, elem)
45+
notify(store.lock)
46+
end
47+
end
48+
function Base.take!(store::StreamStore)
49+
@lock store.lock begin
50+
while true
51+
@debug "[$(myid())] take loop"
52+
if length(store.buffer) == 0
53+
wait(store.lock)
54+
end
55+
@debug "[$(myid())] wait finished"
56+
if !isopen(store)
57+
@debug "[$(myid())] closed!"
58+
throw(InvalidStateException("Stream is closed", :closed))
59+
end
60+
value, remaining = consume!(first(store.buffer))
61+
@debug "[$(myid())] took $value ($remaining remaining)"
62+
if remaining < 0
63+
throw(ConcurrencyViolationError("StoreElement consumed too many times"))
64+
end
65+
if remaining == 0
66+
@debug "[$(myid())] removing from buffer"
67+
popfirst!(store.buffer)
68+
end
69+
if remaining >= 0
70+
@debug "[$(myid())] value accepted"
71+
return value
72+
end
73+
@debug "[$(myid())] value ignored"
74+
end
75+
end
76+
end
77+
Base.isopen(store::StreamStore) = store.open
78+
function Base.close(store::StreamStore)
79+
store.open = false
80+
end
81+
82+
mutable struct Stream{T} <: AbstractChannel{T}
83+
ref::Union{Chunk,Nothing}
84+
waiters::Int
85+
end
86+
Stream() = Stream{Any}(nothing, 0)
87+
function initialize!(stream::Stream)
88+
@assert stream.ref === nothing "Stream already initialized"
89+
stream.ref = tochunk(StreamStore(stream.waiters))
90+
end
91+
function Base.put!(stream::Stream, @nospecialize(value))
92+
remotecall_wait(stream.ref.handle.owner, stream.ref.handle, value) do ref, value
93+
@nospecialize value
94+
store = MemPool.poolget(ref)::StreamStore
95+
put!(store, value)
96+
end
97+
end
98+
function Base.take!(stream::Stream{T}) where T
99+
return remotecall_fetch(stream.ref.handle.owner, stream.ref.handle) do ref
100+
store = MemPool.poolget(ref)::StreamStore
101+
return take!(store)::T
102+
end
103+
end
104+
function Base.isopen(stream::Stream)::Bool
105+
return remotecall_fetch(stream.ref.handle.owner, stream.ref.handle) do ref
106+
return isopen(MemPool.poolget(ref)::StreamStore)
107+
end
108+
end
109+
function Base.close(stream::Stream)
110+
remotecall_wait(stream.ref.handle.owner, stream.ref.handle) do ref
111+
close(MemPool.poolget(ref)::StreamStore)
112+
end
113+
end
114+
115+
struct FinishedStreaming end
116+
finish_streaming() = FinishedStreaming()
117+
118+
struct StreamingFunction{F, T}
119+
f::F
120+
stream::Stream{T}
121+
end
122+
function (sf::StreamingFunction)(args...; kwargs...)
123+
@nospecialize sf args kwargs
124+
result = nothing
125+
stream_args = Base.mapany(identity, args)
126+
stream_kwargs = Base.mapany(identity, kwargs)
127+
try
128+
while true
129+
# Get values from Stream args/kwargs
130+
for (idx, arg) in enumerate(args)
131+
if arg isa Stream
132+
stream_args[idx] = take!(arg)
133+
end
134+
end
135+
for (idx, (pos, arg)) in enumerate(kwargs)
136+
if arg isa Stream
137+
stream_kwargs[idx] = pos => take!(arg)
138+
end
139+
end
140+
141+
# Run a single cycle of f
142+
stream_result = sf.f(stream_args...; stream_kwargs...)
143+
144+
# Exit streaming on graceful request
145+
if stream_result === FinishedStreaming()
146+
return
147+
end
148+
149+
# Put the result into the output stream
150+
put!(sf.stream, stream_result)
151+
end
152+
finally
153+
# Ensure downstream consumers also terminate
154+
close(sf.stream)
155+
end
156+
end
157+
158+
function setup_streaming!(tasks::Vector{Pair{EagerTaskSpec,EagerThunk}})
159+
# Adapt called function for streaming and generate output Streams
160+
self_streams = Dict{UInt,Stream}()
161+
for (spec, task) in tasks
162+
# FIXME: Infer type
163+
stream = Stream()
164+
self_streams[task.uid] = stream
165+
spec.f = StreamingFunction(spec.f, stream)
166+
# FIXME: Generalize to other processors
167+
spec.options = merge(spec.options, (;occupancy=Dict(ThreadProc=>0)))
168+
end
169+
170+
# Adapt args to accept Stream output of other streaming tasks
171+
for (spec, task) in tasks
172+
for (idx, (pos, arg)) in reverse(collect(enumerate(spec.args)))
173+
if arg isa EagerThunk && haskey(self_streams, arg.uid)
174+
other_stream = self_streams[arg.uid]
175+
spec.args[idx] = pos => other_stream
176+
other_stream.waiters += 1
177+
end
178+
end
179+
end
180+
181+
# Initialize all Streams
182+
for stream in values(self_streams)
183+
initialize!(stream)
184+
end
185+
end

0 commit comments

Comments
 (0)