diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index d3fc08f12..f7ab7d985 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -6,6 +6,7 @@ os: linux arch: x86_64 command: "julia --project -e 'using Pkg; Pkg.develop(;path=\"lib/TimespanLogging\")'" + .bench: &bench if: build.message =~ /\[run benchmarks\]/ agents: @@ -14,6 +15,7 @@ os: linux arch: x86_64 num_cpus: 16 + steps: - label: Julia 1.9 timeout_in_minutes: 90 @@ -25,6 +27,7 @@ steps: julia_args: "--threads=1" - JuliaCI/julia-coverage#v1: codecov: true + - label: Julia 1.10 timeout_in_minutes: 90 <<: *test @@ -35,6 +38,7 @@ steps: julia_args: "--threads=1" - JuliaCI/julia-coverage#v1: codecov: true + - label: Julia nightly timeout_in_minutes: 90 <<: *test @@ -77,6 +81,7 @@ steps: - JuliaCI/julia-coverage#v1: codecov: true command: "julia -e 'using Pkg; Pkg.develop(;path=pwd()); Pkg.develop(;path=\"lib/TimespanLogging\"); Pkg.develop(;path=\"lib/DaggerWebDash\"); include(\"lib/DaggerWebDash/test/runtests.jl\")'" + - label: Benchmarks timeout_in_minutes: 120 <<: *bench @@ -93,6 +98,7 @@ steps: BENCHMARK_SCALE: "5:5:50" artifacts: - benchmarks/result* + - label: DTables.jl stability test timeout_in_minutes: 20 plugins: diff --git a/Project.toml b/Project.toml index fd7508cd7..7289ff5f7 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ GraphViz = "0.2" Graphs = "1" JSON3 = "1" MacroTools = "0.5" -MemPool = "0.4.6" +MemPool = "0.4.10" OnlineStats = "1" Plots = "1" PrecompileTools = "1.2" diff --git a/docs/make.jl b/docs/make.jl index c21c03f2d..8f1f97f5c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -22,6 +22,7 @@ makedocs(; "Task Spawning" => "task-spawning.md", "Data Management" => "data-management.md", "Distributed Arrays" => "darray.md", + "Streaming Tasks" => "streaming.md", "Scopes" => "scopes.md", "Processors" => "processors.md", "Task Queues" => "task-queues.md", diff --git a/docs/src/streaming.md b/docs/src/streaming.md new file mode 100644 index 000000000..338f5739b --- /dev/null +++ b/docs/src/streaming.md @@ -0,0 +1,105 @@ +# Streaming Tasks + +Dagger tasks have a limited lifetime - they are created, execute, finish, and +are eventually destroyed when they're no longer needed. Thus, if one wants +to run the same kind of computations over and over, one might re-create a +similar set of tasks for each unit of data that needs processing. + +This might be fine for computations which take a long time to run (thus +dwarfing the cost of task creation, which is quite small), or when working with +a limited set of data, but this approach is not great for doing lots of small +computations on a large (or endless) amount of data. For example, processing +image frames from a webcam, reacting to messages from a message bus, reading +samples from a software radio, etc. All of these tasks are better suited to a +"streaming" model of data processing, where data is simply piped into a +continuously-running task (or DAG of tasks) forever, or until the data runs +out. + +Thankfully, if you have a problem which is best modeled as a streaming system +of tasks, Dagger has you covered! 
Building on its support for
[Task Queues](@ref), Dagger provides a means to convert an entire DAG of
tasks into a streaming DAG, where data flows into and out of each task
asynchronously, using the `spawn_streaming` function:

```julia
Dagger.spawn_streaming() do # enters a streaming region
    vals = Dagger.@spawn rand()
    print_vals = Dagger.@spawn println(vals)
end # exits the streaming region, and starts the DAG running
```

In the above example, `vals` is a Dagger task which has been transformed to run
in a streaming manner - instead of just calling `rand()` once and returning its
result, it will re-run `rand()` endlessly, continuously producing new random
values. In typical Dagger style, `print_vals` is a Dagger task which depends on
`vals`, but in streaming form - it will continuously `println` the random
values produced from `vals`. Both tasks will run forever, and will run
efficiently, only doing the work necessary to generate, transfer, and consume
values.

As the comments point out, `spawn_streaming` creates a streaming region, during
which `vals` and `print_vals` are created and configured. Both tasks are halted
until `spawn_streaming` returns, allowing large DAGs to be built all at once,
without any task losing a single value. If desired, streaming regions can be
connected, although some values might be lost while tasks are being connected:

```julia
vals = Dagger.spawn_streaming() do
    Dagger.@spawn rand()
end

# Some values might be generated by `vals` but thrown away
# before `print_vals` is fully set up and connected to it

print_vals = Dagger.spawn_streaming() do
    Dagger.@spawn println(vals)
end
```

More complicated streaming DAGs can be easily constructed, without doing
anything differently. For example, we can generate multiple streams of random
numbers, write them all to their own files, and print the combined results:

```julia
Dagger.spawn_streaming() do
    all_vals = [Dagger.spawn(rand) for i in 1:4]
    all_vals_written = map(1:4) do i
        Dagger.spawn(all_vals[i]) do val
            open("results_$i.txt"; write=true, create=true, append=true) do io
                println(io, repr(val))
            end
            return val
        end
    end
    Dagger.spawn(all_vals_written...) do all_vals_written...
        vals_sum = sum(all_vals_written)
        println(vals_sum)
    end
end
```

If you want to stop the streaming DAG and tear it all down, you can call
`Dagger.cancel!(all_vals[1])` (or `Dagger.cancel!(all_vals_written[2])`, etc.;
the cancellation propagates throughout the DAG).

Alternatively, tasks can stop themselves from the inside with `finish_stream`,
optionally returning a value that can be `fetch`'d. Let's do this when our
randomly-drawn number falls within some arbitrary range:

```julia
vals = Dagger.spawn_streaming() do
    Dagger.spawn() do
        x = rand()
        if x < 0.001
            # That's good enough, let's be done
            return Dagger.finish_stream(x; result="Finished!")
        end
        return x
    end
end
fetch(vals)
```

In this example, the call to `fetch` will hang (while random numbers continue
to be drawn), until a drawn number is less than 0.001; at that point, `fetch`
will return with "Finished!", and the task `vals` will have terminated.
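Streaming tasks also accept a few per-task options that the examples above leave
at their defaults: `stream_max_evals` (an upper bound on the number of
evaluations, with `-1` meaning unlimited and `0` disallowed),
`stream_input_buffer_amount` and `stream_output_buffer_amount` (per-edge buffer
capacities, which must be positive), and `stream_buffer_type` (the buffer
implementation, `ProcessRingBuffer` by default). The snippet below is a minimal
sketch, assuming these options propagate like any other task option, e.g. via
`Dagger.with_options` (the test suite passes `stream_max_evals` this way):

```julia
# Sketch: cap each task at 100 evaluations and enlarge the buffers.
# Assumes streaming options propagate through `Dagger.with_options`,
# as the test suite does for `stream_max_evals`.
Dagger.with_options(; stream_max_evals=100,
                      stream_input_buffer_amount=128,
                      stream_output_buffer_amount=128) do
    Dagger.spawn_streaming() do
        vals = Dagger.@spawn rand()
        Dagger.@spawn println(vals)
    end
end
```

Once a task reaches its evaluation limit it returns, its output stream is
closed, and downstream streaming tasks terminate after draining any values
still buffered for them.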
diff --git a/src/Dagger.jl b/src/Dagger.jl index b478ece0f..70725340a 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -23,6 +23,9 @@ if !isdefined(Base, :ScopedValues) else import Base.ScopedValues: ScopedValue, with end +import TaskLocalValues: TaskLocalValue + +import TaskLocalValues: TaskLocalValue if !isdefined(Base, :get_extension) import Requires: @require @@ -46,16 +49,16 @@ include("processor.jl") include("threadproc.jl") include("context.jl") include("utils/processors.jl") +include("dtask.jl") +include("cancellation.jl") include("task-tls.jl") include("scopes.jl") include("utils/scopes.jl") -include("dtask.jl") include("queue.jl") include("thunk.jl") include("submission.jl") include("chunks.jl") include("memory-spaces.jl") -include("cancellation.jl") # Task scheduling include("compute.jl") @@ -67,6 +70,11 @@ include("sch/Sch.jl"); using .Sch # Data dependency task queue include("datadeps.jl") +# Streaming +include("stream.jl") +include("stream-buffers.jl") +include("stream-transfer.jl") + # Array computations include("array/darray.jl") include("array/alloc.jl") @@ -145,6 +153,20 @@ function __init__() ThreadProc(myid(), tid) end end + + # Set up @dagdebug categories, if specified + try + if haskey(ENV, "JULIA_DAGGER_DEBUG") + empty!(DAGDEBUG_CATEGORIES) + for category in split(ENV["JULIA_DAGGER_DEBUG"], ",") + if category != "" + push!(DAGDEBUG_CATEGORIES, Symbol(category)) + end + end + end + catch err + @warn "Error parsing JULIA_DAGGER_DEBUG" exception=err + end end end # module diff --git a/src/array/indexing.jl b/src/array/indexing.jl index 82f44fbff..69725eb7a 100644 --- a/src/array/indexing.jl +++ b/src/array/indexing.jl @@ -1,5 +1,3 @@ -import TaskLocalValues: TaskLocalValue - ### getindex struct GetIndex{T,N} <: ArrayOp{T,N} diff --git a/src/cancellation.jl b/src/cancellation.jl index c982fd20c..86f562bc4 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -1,3 +1,51 @@ +# DTask-level cancellation + +mutable struct CancelToken + @atomic cancelled::Bool + @atomic graceful::Bool + event::Base.Event +end +CancelToken() = CancelToken(false, false, Base.Event()) +function cancel!(token::CancelToken; graceful::Bool=true) + if !graceful + @atomic token.graceful = false + end + @atomic token.cancelled = true + notify(token.event) + return +end +function is_cancelled(token::CancelToken; must_force::Bool=false) + if token.cancelled[] + if must_force && token.graceful[] + # If we're only responding to forced cancellation, ignore graceful cancellations + return false + end + return true + end + return false +end +Base.wait(token::CancelToken) = wait(token.event) +# TODO: Enable this for safety +#Serialization.serialize(io::AbstractSerializer, ::CancelToken) = +# throw(ConcurrencyViolationError("Cannot serialize a CancelToken")) + +const DTASK_CANCEL_TOKEN = TaskLocalValue{Union{CancelToken,Nothing}}(()->nothing) + +function clone_cancel_token_remote(orig_token::CancelToken, wid::Integer) + remote_token = remotecall_fetch(wid) do + return poolset(CancelToken()) + end + errormonitor_tracked("remote cancel_token communicator", Threads.@spawn begin + wait(orig_token) + @dagdebug nothing :cancel "Cancelling remote token on worker $wid" + MemPool.access_ref(remote_token) do remote_token + cancel!(remote_token) + end + end) +end + +# Global-level cancellation + """ cancel!(task::DTask; force::Bool=false, halt_sch::Bool=false) @@ -48,7 +96,7 @@ function _cancel!(state, tid, force, halt_sch) for task in state.ready tid !== nothing && task.id != tid && continue @dagdebug tid 
:cancel "Cancelling ready task" - state.cache[task] = InterruptException() + state.cache[task] = DTaskFailedException(task, task, InterruptException()) state.errored[task] = true Sch.set_failed!(state, task) end @@ -58,7 +106,7 @@ function _cancel!(state, tid, force, halt_sch) for task in keys(state.waiting) tid !== nothing && task.id != tid && continue @dagdebug tid :cancel "Cancelling waiting task" - state.cache[task] = InterruptException() + state.cache[task] = DTaskFailedException(task, task, InterruptException()) state.errored[task] = true Sch.set_failed!(state, task) end @@ -80,11 +128,11 @@ function _cancel!(state, tid, force, halt_sch) Tf === typeof(Sch.eager_thunk) && continue istaskdone(task) && continue any_cancelled = true - @dagdebug tid :cancel "Cancelling running task ($Tf)" if force @dagdebug tid :cancel "Interrupting running task ($Tf)" Threads.@spawn Base.throwto(task, InterruptException()) else + @dagdebug tid :cancel "Cancelling running task ($Tf)" # Tell the processor to just drop this task task_occupancy = task_spec[4] time_util = task_spec[2] @@ -93,6 +141,7 @@ function _cancel!(state, tid, force, halt_sch) push!(istate.cancelled, tid) to_proc = istate.proc put!(istate.return_queue, (myid(), to_proc, tid, (InterruptException(), nothing))) + cancel!(istate.cancel_tokens[tid]; graceful=false) end end end diff --git a/src/compute.jl b/src/compute.jl index f421eaccc..093b527f4 100644 --- a/src/compute.jl +++ b/src/compute.jl @@ -36,12 +36,6 @@ end Base.@deprecate gather(ctx, x) collect(ctx, x) Base.@deprecate gather(x) collect(x) -cleanup() = cleanup(Context(global_context())) -function cleanup(ctx::Context) - Sch.cleanup(ctx) - nothing -end - function get_type(s::String) local T for t in split(s, ".") diff --git a/src/dtask.jl b/src/dtask.jl index 68f2d3c1b..98f74005a 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -39,6 +39,16 @@ end Options(;options...) = Options((;options...)) Options(options...) = Options((;options...)) +""" + DTaskMetadata + +Represents some useful metadata pertaining to a `DTask`: +- `return_type::Type` - The inferred return type of the task +""" +mutable struct DTaskMetadata + return_type::Type +end + """ DTask @@ -50,9 +60,11 @@ more details. mutable struct DTask uid::UInt future::ThunkFuture + metadata::DTaskMetadata finalizer_ref::DRef thunk_ref::DRef - DTask(uid, future, finalizer_ref) = new(uid, future, finalizer_ref) + + DTask(uid, future, metadata, finalizer_ref) = new(uid, future, metadata, finalizer_ref) end const EagerThunk = DTask diff --git a/src/options.jl b/src/options.jl index 1c1e3ff29..00196dd59 100644 --- a/src/options.jl +++ b/src/options.jl @@ -20,6 +20,12 @@ function with_options(f, options::NamedTuple) end with_options(f; options...) = with_options(f, NamedTuple(options)) +function _without_options(f) + with(options_context => NamedTuple()) do + f() + end +end + """ get_options(key::Symbol, default) -> Any get_options(key::Symbol) -> Any diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 73bb07bf9..794afce92 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -253,9 +253,11 @@ end Combine `SchedulerOptions` and `ThunkOptions` into a new `ThunkOptions`. """ function Base.merge(sopts::SchedulerOptions, topts::ThunkOptions) - single = topts.single !== nothing ? topts.single : sopts.single - allow_errors = topts.allow_errors !== nothing ? topts.allow_errors : sopts.allow_errors - proclist = topts.proclist !== nothing ? topts.proclist : sopts.proclist + select_option = (sopt, topt) -> isnothing(topt) ? 
sopt : topt + + single = select_option(sopts.single, topts.single) + allow_errors = select_option(sopts.allow_errors, topts.allow_errors) + proclist = select_option(sopts.proclist, topts.proclist) ThunkOptions(single, proclist, topts.time_util, @@ -307,9 +309,6 @@ function populate_defaults(opts::ThunkOptions, Tf, Targs) ) end -function cleanup(ctx) -end - # Eager scheduling include("eager.jl") @@ -1180,6 +1179,7 @@ struct ProcessorInternalState proc_occupancy::Base.RefValue{UInt32} time_pressure::Base.RefValue{UInt64} cancelled::Set{Int} + cancel_tokens::Dict{Int,Dagger.CancelToken} done::Base.RefValue{Bool} end struct ProcessorState @@ -1199,7 +1199,7 @@ function proc_states(f::Base.Callable, uid::UInt64) end end proc_states(f::Base.Callable) = - proc_states(f, task_local_storage(:_dagger_sch_uid)::UInt64) + proc_states(f, Dagger.get_tls().sch_uid) task_tid_for_processor(::Processor) = nothing task_tid_for_processor(proc::Dagger.ThreadProc) = proc.tid @@ -1329,7 +1329,14 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Execute the task and return its result t = @task begin + # Set up cancellation + cancel_token = Dagger.CancelToken() + Dagger.DTASK_CANCEL_TOKEN[] = cancel_token + lock(istate.queue) do _ + istate.cancel_tokens[thunk_id] = cancel_token + end was_cancelled = false + result = try do_task(to_proc, task) catch err @@ -1346,6 +1353,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Task was cancelled, so occupancy and pressure are # already reduced pop!(istate.cancelled, thunk_id) + delete!(istate.cancel_tokens, thunk_id) was_cancelled = true end end @@ -1363,6 +1371,9 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re else rethrow(err) end + finally + # Ensure that any spawned tasks get cleaned up + Dagger.cancel!(cancel_token) end end lock(istate.queue) do _ @@ -1412,6 +1423,7 @@ function do_tasks(to_proc, return_queue, tasks) Dict{Int,Vector{Any}}(), Ref(UInt32(0)), Ref(UInt64(0)), Set{Int}(), + Dict{Int,Dagger.CancelToken}(), Ref(false)) runner = start_processor_runner!(istate, uid, return_queue) @static if VERSION < v"1.9" @@ -1653,6 +1665,7 @@ function do_task(to_proc, task_desc) sch_handle, processor=to_proc, task_spec=task_desc, + cancel_token=Dagger.DTASK_CANCEL_TOKEN[], )) res = Dagger.with_options(propagated) do diff --git a/src/sch/dynamic.jl b/src/sch/dynamic.jl index e02085ee6..5b917fdb5 100644 --- a/src/sch/dynamic.jl +++ b/src/sch/dynamic.jl @@ -17,7 +17,7 @@ struct SchedulerHandle end "Gets the scheduler handle for the currently-executing thunk." -sch_handle() = task_local_storage(:_dagger_sch_handle)::SchedulerHandle +sch_handle() = Dagger.get_tls().sch_handle::SchedulerHandle "Thrown when the scheduler halts before finishing processing the DAG." struct SchedulerHaltedException <: Exception end diff --git a/src/sch/eager.jl b/src/sch/eager.jl index 87a109788..aea0abbf6 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -6,7 +6,7 @@ const EAGER_STATE = Ref{Union{ComputeState,Nothing}}(nothing) function eager_context() if EAGER_CONTEXT[] === nothing - EAGER_CONTEXT[] = Context([myid(),workers()...]) + EAGER_CONTEXT[] = Context(procs()) end return EAGER_CONTEXT[] end @@ -124,6 +124,13 @@ function eager_cleanup(state, uid) # N.B. 
cache and errored expire automatically delete!(state.thunk_dict, tid) end + remotecall_wait(1, uid) do uid + lock(Dagger.EAGER_THUNK_STREAMS) do global_streams + if haskey(global_streams, uid) + delete!(global_streams, uid) + end + end + end end function _find_thunk(e::Dagger.DTask) @@ -134,3 +141,6 @@ function _find_thunk(e::Dagger.DTask) unwrap_weak_checked(EAGER_STATE[].thunk_dict[tid]) end end +Dagger.task_id(t::Dagger.DTask) = lock(EAGER_ID_MAP) do id_map + id_map[t.uid] +end diff --git a/src/sch/util.jl b/src/sch/util.jl index e81703db5..2e090b26c 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -29,6 +29,10 @@ unwrap_nested_exception(err::CapturedException) = unwrap_nested_exception(err.ex) unwrap_nested_exception(err::RemoteException) = unwrap_nested_exception(err.captured) +unwrap_nested_exception(err::DTaskFailedException) = + unwrap_nested_exception(err.ex) +unwrap_nested_exception(err::TaskFailedException) = + unwrap_nested_exception(err.t.exception) unwrap_nested_exception(err) = err "Gets a `NamedTuple` of options propagated by `thunk`." @@ -406,12 +410,19 @@ function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig) else get(state.signature_alloc_cost, sig, UInt64(0)) end::UInt64 - est_occupancy = if occupancy !== nothing && haskey(occupancy, T) - # Clamp to 0-1, and scale between 0 and `typemax(UInt32)` - Base.unsafe_trunc(UInt32, clamp(occupancy[T], 0, 1) * typemax(UInt32)) - else - typemax(UInt32) - end::UInt32 + est_occupancy::UInt32 = typemax(UInt32) + if occupancy !== nothing + occ = nothing + if haskey(occupancy, T) + occ = occupancy[T] + elseif haskey(occupancy, Any) + occ = occupancy[Any] + end + if occ !== nothing + # Clamp to 0-1, and scale between 0 and `typemax(UInt32)` + est_occupancy = Base.unsafe_trunc(UInt32, clamp(occ, 0, 1) * typemax(UInt32)) + end + end #= FIXME: Estimate if cached data can be swapped to storage storage = storage_resource(p) real_alloc_util = state.worker_storage_pressure[gp][storage] diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl new file mode 100644 index 000000000..579a59753 --- /dev/null +++ b/src/stream-buffers.jl @@ -0,0 +1,98 @@ +""" +A buffer that drops all elements put into it. +""" +mutable struct DropBuffer{T} + open::Bool + DropBuffer{T}() where T = new{T}(true) +end +DropBuffer{T}(_) where T = DropBuffer{T}() +Base.isempty(::DropBuffer) = true +isfull(::DropBuffer) = false +capacity(::DropBuffer) = typemax(Int) +Base.length(::DropBuffer) = 0 +Base.isopen(buf::DropBuffer) = buf.open +function Base.close(buf::DropBuffer) + buf.open = false +end +function Base.put!(buf::DropBuffer, _) + if !isopen(buf) + throw(InvalidStateException("DropBuffer is closed", :closed)) + end + task_may_cancel!(; must_force=true) + yield() + return +end +function Base.take!(buf::DropBuffer) + while true + if !isopen(buf) + throw(InvalidStateException("DropBuffer is closed", :closed)) + end + task_may_cancel!(; must_force=true) + yield() + end +end + +"A process-local ring buffer." 
+mutable struct ProcessRingBuffer{T} + read_idx::Int + write_idx::Int + @atomic count::Int + buffer::Vector{T} + @atomic open::Bool + function ProcessRingBuffer{T}(len::Int=1024) where T + buffer = Vector{T}(undef, len) + return new{T}(1, 1, 0, buffer, true) + end +end +Base.isempty(rb::ProcessRingBuffer) = (@atomic rb.count) == 0 +isfull(rb::ProcessRingBuffer) = (@atomic rb.count) == length(rb.buffer) +capacity(rb::ProcessRingBuffer) = length(rb.buffer) +Base.length(rb::ProcessRingBuffer) = @atomic rb.count +Base.isopen(rb::ProcessRingBuffer) = @atomic rb.open +function Base.close(rb::ProcessRingBuffer) + @atomic rb.open = false +end +function Base.put!(rb::ProcessRingBuffer{T}, x) where T + while isfull(rb) + yield() + if !isopen(rb) + throw(InvalidStateException("ProcessRingBuffer is closed", :closed)) + end + task_may_cancel!(; must_force=true) + end + to_write_idx = mod1(rb.write_idx, length(rb.buffer)) + rb.buffer[to_write_idx] = convert(T, x) + rb.write_idx += 1 + @atomic rb.count += 1 +end +function Base.take!(rb::ProcessRingBuffer) + while isempty(rb) + yield() + if !isopen(rb) && isempty(rb) + throw(InvalidStateException("ProcessRingBuffer is closed", :closed)) + end + if task_cancelled() && isempty(rb) + # We respect a graceful cancellation only if the buffer is empty. + # Otherwise, we may have values to continue communicating. + task_may_cancel!() + end + task_may_cancel!(; must_force=true) + end + to_read_idx = rb.read_idx + rb.read_idx += 1 + @atomic rb.count -= 1 + to_read_idx = mod1(to_read_idx, length(rb.buffer)) + return rb.buffer[to_read_idx] +end + +""" +`take!()` all the elements from a buffer and put them in a `Vector`. +""" +function collect!(rb::ProcessRingBuffer{T}) where T + output = Vector{T}(undef, rb.count) + for i in 1:rb.count + output[i] = take!(rb) + end + + return output +end diff --git a/src/stream-transfer.jl b/src/stream-transfer.jl new file mode 100644 index 000000000..667808762 --- /dev/null +++ b/src/stream-transfer.jl @@ -0,0 +1,128 @@ +struct RemoteChannelFetcher + chan::RemoteChannel + RemoteChannelFetcher() = new(RemoteChannel()) +end +const _THEIR_TID = TaskLocalValue{Int}(()->0) +function stream_push_values!(fetcher::RemoteChannelFetcher, T, our_store::StreamStore, their_stream::Stream, buffer) + our_tid = STREAM_THUNK_ID[] + our_uid = our_store.uid + their_uid = their_stream.uid + if _THEIR_TID[] == 0 + _THEIR_TID[] = remotecall_fetch(1) do + lock(Sch.EAGER_ID_MAP) do id_map + id_map[their_uid] + end + end + end + their_tid = _THEIR_TID[] + @dagdebug our_tid :stream_push "taking output value: $our_tid -> $their_tid" + value = try + take!(buffer) + catch + close(fetcher.chan) + rethrow() + end + @lock our_store.lock notify(our_store.lock) + @dagdebug our_tid :stream_push "pushing output value: $our_tid -> $their_tid" + try + put!(fetcher.chan, value) + catch err + if err isa InvalidStateException && !isopen(fetcher.chan) + @dagdebug our_tid :stream_push "channel closed: $our_tid -> $their_tid" + throw(InterruptException()) + end + rethrow(err) + end + @dagdebug our_tid :stream_push "finished pushing output value: $our_tid -> $their_tid" +end +function stream_pull_values!(fetcher::RemoteChannelFetcher, T, our_store::StreamStore, their_stream::Stream, buffer) + our_tid = STREAM_THUNK_ID[] + our_uid = our_store.uid + their_uid = their_stream.uid + if _THEIR_TID[] == 0 + _THEIR_TID[] = remotecall_fetch(1) do + lock(Sch.EAGER_ID_MAP) do id_map + id_map[their_uid] + end + end + end + their_tid = _THEIR_TID[] + @dagdebug our_tid :stream_pull "pulling 
input value: $their_tid -> $our_tid" + value = try + take!(fetcher.chan) + catch err + if err isa InvalidStateException && !isopen(fetcher.chan) + @dagdebug our_tid :stream_pull "channel closed: $their_tid -> $our_tid" + throw(InterruptException()) + end + rethrow(err) + end + @dagdebug our_tid :stream_pull "putting input value: $their_tid -> $our_tid" + try + put!(buffer, value) + catch + close(fetcher.chan) + rethrow() + end + @lock our_store.lock notify(our_store.lock) + @dagdebug our_tid :stream_pull "finished putting input value: $their_tid -> $our_tid" +end + +#= TODO: Remove me +# This is a bad implementation because it wants to sleep on the remote side to +# wait for values, but this isn't semantically valid when done with MemPool.access_ref +struct RemoteFetcher end +function stream_push_values!(::Type{RemoteFetcher}, T, our_store::StreamStore, their_stream::Stream, buffer) + sleep(1) +end +function stream_pull_values!(::Type{RemoteFetcher}, T, our_store::StreamStore, their_stream::Stream, buffer) + id = our_store.uid + thunk_id = STREAM_THUNK_ID[] + @dagdebug thunk_id :stream "fetching values" + + free_space = capacity(buffer) - length(buffer) + if free_space == 0 + @dagdebug thunk_id :stream "waiting for drain of full input buffer" + yield() + task_may_cancel!() + wait_for_nonfull_input(our_store, their_stream.uid) + return + end + + values = T[] + while isempty(values) + values, closed = MemPool.access_ref(their_stream.store_ref.handle, id, T, thunk_id, free_space) do their_store, id, T, thunk_id, free_space + @dagdebug thunk_id :stream "trying to fetch values at worker $(myid())" + STREAM_THUNK_ID[] = thunk_id + values = T[] + @dagdebug thunk_id :stream "trying to fetch with free_space: $free_space" + wait_for_nonempty_output(their_store, id) + if isempty(their_store, id) && !isopen(their_store, id) + @dagdebug thunk_id :stream "remote stream is closed, returning" + return values, true + end + while !isempty(their_store, id) && length(values) < free_space + value = take!(their_store, id)::T + @dagdebug thunk_id :stream "fetched $value" + push!(values, value) + end + return values, false + end::Tuple{Vector{T},Bool} + if closed + throw(InterruptException()) + end + + # We explicitly yield in the loop to allow other tasks to run. This + # matters on single-threaded instances because MemPool.access_ref() + # might not yield when accessing data locally, which can cause this loop + # to spin forever. 
+ yield() + task_may_cancel!() + end + + @dagdebug thunk_id :stream "fetched $(length(values)) values" + for value in values + put!(buffer, value) + end +end +=# diff --git a/src/stream.jl b/src/stream.jl new file mode 100644 index 000000000..ddc303c98 --- /dev/null +++ b/src/stream.jl @@ -0,0 +1,757 @@ +mutable struct StreamStore{T,B} + uid::UInt + waiters::Vector{Int} + input_streams::Dict{UInt,Any} # FIXME: Concrete type + output_streams::Dict{UInt,Any} # FIXME: Concrete type + input_buffers::Dict{UInt,B} + output_buffers::Dict{UInt,B} + input_buffer_amount::Int + output_buffer_amount::Int + input_fetchers::Dict{UInt,Any} + output_fetchers::Dict{UInt,Any} + open::Bool + migrating::Bool + lock::Threads.Condition + StreamStore{T,B}(uid::UInt, input_buffer_amount::Integer, output_buffer_amount::Integer) where {T,B} = + new{T,B}(uid, zeros(Int, 0), + Dict{UInt,Any}(), Dict{UInt,Any}(), + Dict{UInt,B}(), Dict{UInt,B}(), + input_buffer_amount, output_buffer_amount, + Dict{UInt,Any}(), Dict{UInt,Any}(), + true, false, Threads.Condition()) +end + +function tid_to_uid(thunk_id) + lock(Sch.EAGER_ID_MAP) do id_map + for (uid, otid) in id_map + if thunk_id == otid + return uid + end + end + end +end + +function Base.put!(store::StreamStore{T,B}, value) where {T,B} + thunk_id = STREAM_THUNK_ID[] + @lock store.lock begin + if !isopen(store) + @dagdebug thunk_id :stream "closed!" + throw(InvalidStateException("Stream is closed", :closed)) + end + @dagdebug thunk_id :stream "adding $value ($(length(store.output_streams)) outputs)" + for output_uid in keys(store.output_streams) + buffer = store.output_buffers[output_uid] + while isfull(buffer) + if !isopen(store) + @dagdebug thunk_id :stream "closed!" + throw(InvalidStateException("Stream is closed", :closed)) + end + @dagdebug thunk_id :stream "buffer full ($(length(buffer)) values), waiting" + wait(store.lock) + if !isfull(buffer) + @dagdebug thunk_id :stream "buffer has space ($(length(buffer)) values), continuing" + end + task_may_cancel!() + end + put!(buffer, value) + end + notify(store.lock) + end +end + +function Base.take!(store::StreamStore, id::UInt) + thunk_id = STREAM_THUNK_ID[] + @lock store.lock begin + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + error("Must first check isempty(store, id) before taking from a stream") + end + buffer = store.output_buffers[id] + while isempty(buffer) && isopen(store, id) + @dagdebug thunk_id :stream "no elements, not taking" + wait(store.lock) + task_may_cancel!() + end + @dagdebug thunk_id :stream "wait finished" + if !isopen(store, id) + @dagdebug thunk_id :stream "closed!" 
+ throw(InvalidStateException("Stream is closed", :closed)) + end + unlock(store.lock) + value = try + take!(buffer) + finally + lock(store.lock) + end + @dagdebug thunk_id :stream "value accepted" + notify(store.lock) + return value + end +end +function wait_for_nonfull_input(store::StreamStore, id::UInt) + @lock store.lock begin + @assert haskey(store.input_streams, id) + @assert haskey(store.input_buffers, id) + buffer = store.input_buffers[id] + while isfull(buffer) && isopen(store) + @dagdebug STREAM_THUNK_ID[] :stream "waiting for space in input buffer" + wait(store.lock) + end + end +end +function wait_for_nonempty_output(store::StreamStore, id::UInt) + @lock store.lock begin + @assert haskey(store.output_streams, id) + + # Wait for the output buffer to be initialized + while !haskey(store.output_buffers, id) && isopen(store, id) + @dagdebug STREAM_THUNK_ID[] :stream "waiting for output buffer to be initialized" + wait(store.lock) + end + isopen(store, id) || return + + # Wait for the output buffer to be nonempty + buffer = store.output_buffers[id] + while isempty(buffer) && isopen(store, id) + @dagdebug STREAM_THUNK_ID[] :stream "waiting for output buffer to be nonempty" + wait(store.lock) + end + end +end + +function Base.isempty(store::StreamStore, id::UInt) + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + return true + end + return isempty(store.output_buffers[id]) +end +isfull(store::StreamStore, id::UInt) = isfull(store.output_buffers[id]) + +"Returns whether the store is actively open. Only check this when deciding if new values can be pushed." +Base.isopen(store::StreamStore) = store.open + +""" +Returns whether the store is actively open, or if closing, still has remaining +messages for `id`. Only check this when deciding if existing values can be +taken. 
+""" +function Base.isopen(store::StreamStore, id::UInt) + @lock store.lock begin + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + return store.open + end + if !isempty(store.output_buffers[id]) + return true + end + return store.open + end +end + +function Base.close(store::StreamStore) + @lock store.lock begin + store.open || return + + store.open = false + for buffer in values(store.input_buffers) + close(buffer) + end + for buffer in values(store.output_buffers) + close(buffer) + end + notify(store.lock) + end +end + +# FIXME: Just pass Stream directly, rather than its uid +function add_waiters!(store::StreamStore{T,B}, waiters::Vector{Pair{UInt,Any}}) where {T,B} + our_uid = store.uid + @lock store.lock begin + for (output_uid, output_fetcher) in waiters + store.output_streams[output_uid] = task_to_stream(output_uid) + push!(store.waiters, output_uid) + store.output_fetchers[output_uid] = output_fetcher + end + notify(store.lock) + end +end + +function remove_waiters!(store::StreamStore, waiters::Vector{UInt}) + @lock store.lock begin + for w in waiters + delete!(store.output_buffers, w) + idx = findfirst(wo->wo==w, store.waiters) + deleteat!(store.waiters, idx) + delete!(store.input_streams, w) + end + notify(store.lock) + end +end + +mutable struct Stream{T,B} + uid::UInt + store::Union{StreamStore{T,B},Nothing} + store_ref::Chunk + function Stream{T,B}(uid::UInt, input_buffer_amount::Integer, output_buffer_amount::Integer) where {T,B} + # Creates a new output stream + store = StreamStore{T,B}(uid, input_buffer_amount, output_buffer_amount) + store_ref = tochunk(store) + return new{T,B}(uid, store, store_ref) + end + function Stream(stream::Stream{T,B}) where {T,B} + # References an existing output stream + return new{T,B}(stream.uid, nothing, stream.store_ref) + end +end + +struct StreamCancelledException <: Exception end +struct StreamingValue{B} + buffer::B +end +Base.take!(sv::StreamingValue) = take!(sv.buffer) + +function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::Stream{IT,IB}) where {IT,OT,IB,OB} + input_uid = input_stream.uid + our_uid = our_store.uid + local buffer, input_fetcher + @lock our_store.lock begin + if haskey(our_store.input_buffers, input_uid) + return StreamingValue(our_store.input_buffers[input_uid]) + end + + buffer = initialize_stream_buffer(OB, IT, our_store.input_buffer_amount) + # FIXME: Also pass a RemoteChannel to track remote closure + our_store.input_buffers[input_uid] = buffer + input_fetcher = our_store.input_fetchers[input_uid] + end + thunk_id = STREAM_THUNK_ID[] + tls = get_tls() + Sch.errormonitor_tracked("streaming input: $input_uid -> $our_uid", Threads.@spawn begin + set_tls!(tls) + STREAM_THUNK_ID[] = thunk_id + try + while isopen(our_store) + stream_pull_values!(input_fetcher, IT, our_store, input_stream, buffer) + end + catch err + unwrapped_err = Sch.unwrap_nested_exception(err) + if unwrapped_err isa InterruptException || (unwrapped_err isa InvalidStateException && !isopen(input_fetcher.chan)) + return + else + rethrow() + end + finally + # Close the buffer because there will be no more values put into + # it. We don't close the entire store because there might be some + # remaining elements in the buffer to process and send to downstream + # tasks. 
+ close(buffer) + @dagdebug STREAM_THUNK_ID[] :stream "input stream closed" + end + end) + return StreamingValue(buffer) +end +initialize_input_stream!(our_store::StreamStore, arg) = arg +function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt) where {T,B} + @dagdebug STREAM_THUNK_ID[] :stream "initializing output stream $output_uid" + local buffer + @lock our_store.lock begin + buffer = initialize_stream_buffer(B, T, our_store.output_buffer_amount) + our_store.output_buffers[output_uid] = buffer + end + + our_uid = our_store.uid + output_stream = our_store.output_streams[output_uid] + output_fetcher = our_store.output_fetchers[output_uid] + thunk_id = STREAM_THUNK_ID[] + tls = get_tls() + Sch.errormonitor_tracked("streaming output: $our_uid -> $output_uid", Threads.@spawn begin + set_tls!(tls) + STREAM_THUNK_ID[] = thunk_id + try + while true + if !isopen(our_store) && isempty(buffer) + # Only exit if the buffer is empty; otherwise, we need to + # continue draining it + break + end + stream_push_values!(output_fetcher, T, our_store, output_stream, buffer) + end + catch err + if err isa InterruptException || (err isa InvalidStateException && !isopen(buffer)) + return + else + rethrow(err) + end + finally + close(output_fetcher.chan) + @dagdebug thunk_id :stream "output stream closed" + end + end) +end + +Base.put!(stream::Stream, @nospecialize(value)) = put!(stream.store, value) + +function Base.isopen(stream::Stream, id::UInt)::Bool + return MemPool.access_ref(stream.store_ref.handle, id) do store, id + return isopen(store::StreamStore, id) + end +end + +function Base.close(stream::Stream) + MemPool.access_ref(stream.store_ref.handle) do store + close(store::StreamStore) + return + end + return +end + +function add_waiters!(stream::Stream, waiters::Vector{Pair{UInt,Any}}) + MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters + add_waiters!(store::StreamStore, waiters) + return + end + return +end + +add_waiters!(stream::Stream, waiter::Integer) = add_waiters!(stream, UInt[waiter]) + +function remove_waiters!(stream::Stream, waiters::Vector{UInt}) + MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters + remove_waiters!(store::StreamStore, waiters) + return + end + return +end + +remove_waiters!(stream::Stream, waiter::Integer) = remove_waiters!(stream, Int[waiter]) + +struct StreamingFunction{F, S} + f::F + stream::S + max_evals::Int + + StreamingFunction(f::F, stream::S, max_evals) where {F, S} = + new{F, S}(f, stream, max_evals) +end + +function migrate_stream!(stream::Stream, w::Integer=myid()) + # Perform migration of the StreamStore + # MemPool will block access to the new ref until the migration completes + # FIXME: Do this with MemPool.access_ref, in case stream was already migrated + if stream.store_ref.handle.owner != w + thunk_id = STREAM_THUNK_ID[] + @dagdebug thunk_id :stream "Beginning migration... ($(length(stream.store.input_streams)) -> $(length(stream.store.output_streams)))" + + # TODO: Wire up listener to ferry cancel_token notifications to remote worker + tls = get_tls() + @assert w == myid() "Only pull-based migration is currently supported" + #remote_cancel_token = clone_cancel_token_remote(get_tls().cancel_token, worker_id) + + new_store_ref = MemPool.migrate!(stream.store_ref.handle, w; + pre_migration=store->begin + # Lock store to prevent any further modifications + # N.B. 
Serialization automatically unlocks the migrated copy + lock((store::StreamStore).lock) + + # Return the serializeable unsent inputs/outputs. We can't send the + # buffers themselves because they may be mmap'ed or something. + unsent_inputs = Dict(uid => collect!(buffer) for (uid, buffer) in store.input_buffers) + unsent_outputs = Dict(uid => collect!(buffer) for (uid, buffer) in store.output_buffers) + empty!(store.input_buffers) + empty!(store.output_buffers) + return (unsent_inputs, unsent_outputs) + end, + dest_post_migration=(store, unsent)->begin + # Initialize the StreamStore on the destination with the unsent inputs/outputs. + STREAM_THUNK_ID[] = thunk_id + @assert !in_task() + set_tls!(tls) + #get_tls().cancel_token = MemPool.access_ref(identity, remote_cancel_token; local_only=true) + unsent_inputs, unsent_outputs = unsent + for (input_uid, inputs) in unsent_inputs + input_stream = store.input_streams[input_uid] + initialize_input_stream!(store, input_stream) + for item in inputs + put!(store.input_buffers[input_uid], item) + end + end + for (output_uid, outputs) in unsent_outputs + initialize_output_stream!(store, output_uid) + for item in outputs + put!(store.output_buffers[output_uid], item) + end + end + + # Reset the state of this new store + store.open = true + store.migrating = false + end, + post_migration=store->begin + # Indicate that this store has migrated + store.migrating = true + store.open = false + + # Unlock the store + unlock((store::StreamStore).lock) + end) + if w == myid() + stream.store_ref.handle = new_store_ref # FIXME: It's not valid to mutate the Chunk handle, but we want to update this to enable fast location queries + stream.store = MemPool.access_ref(identity, new_store_ref; local_only=true) + end + + @dagdebug thunk_id :stream "Migration complete ($(length(stream.store.input_streams)) -> $(length(stream.store.output_streams)))" + end +end + +struct StreamingTaskQueue <: AbstractTaskQueue + tasks::Vector{Pair{DTaskSpec,DTask}} + self_streams::Dict{UInt,Any} + StreamingTaskQueue() = new(Pair{DTaskSpec,DTask}[], + Dict{UInt,Any}()) +end + +function enqueue!(queue::StreamingTaskQueue, spec::Pair{DTaskSpec,DTask}) + push!(queue.tasks, spec) + initialize_streaming!(queue.self_streams, spec...) +end + +function enqueue!(queue::StreamingTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) + append!(queue.tasks, specs) + for (spec, task) in specs + initialize_streaming!(queue.self_streams, spec, task) + end +end + +function initialize_streaming!(self_streams, spec, task) + if !isa(spec.f, StreamingFunction) + # Calculate the return type of the called function + T_old = Base.uniontypes(task.metadata.return_type) + T_old = map(t->(t !== Union{} && t <: FinishStream) ? first(t.parameters) : t, T_old) + # N.B. We treat non-dominating error paths as unreachable + T_old = filter(t->t !== Union{}, T_old) + T = task.metadata.return_type = !isempty(T_old) ? 
Union{T_old...} : Any + + # Get input buffer configuration + input_buffer_amount = get(spec.options, :stream_input_buffer_amount, 1) + if input_buffer_amount <= 0 + throw(ArgumentError("Input buffering is required; please specify a `stream_input_buffer_amount` greater than 0")) + end + + # Get output buffer configuration + output_buffer_amount = get(spec.options, :stream_output_buffer_amount, 1) + if output_buffer_amount <= 0 + throw(ArgumentError("Output buffering is required; please specify a `stream_output_buffer_amount` greater than 0")) + end + + # Create the Stream + buffer_type = get(spec.options, :stream_buffer_type, ProcessRingBuffer) + stream = Stream{T,buffer_type}(task.uid, input_buffer_amount, output_buffer_amount) + self_streams[task.uid] = stream + + # Get max evaluation count + max_evals = get(spec.options, :stream_max_evals, -1) + if max_evals == 0 + throw(ArgumentError("stream_max_evals cannot be 0")) + end + + spec.f = StreamingFunction(spec.f, stream, max_evals) + spec.options = merge(spec.options, (;occupancy=Dict(Any=>0))) + + # Register Stream globally + remotecall_wait(1, task.uid, stream) do uid, stream + lock(EAGER_THUNK_STREAMS) do global_streams + global_streams[uid] = stream + end + end + end +end + +function spawn_streaming(f::Base.Callable) + queue = StreamingTaskQueue() + result = with_options(f; task_queue=queue) + if length(queue.tasks) > 0 + finalize_streaming!(queue.tasks, queue.self_streams) + enqueue!(queue.tasks) + end + return result +end + +struct FinishStream{T,R} + value::Union{Some{T},Nothing} + result::R +end + +""" + finish_stream(value=nothing; result=nothing) + +Tell Dagger to stop executing the streaming function and all of its downstream +[`DTask`](@ref)'s. + +# Arguments +- `value`: The final value to be returned by the streaming function. This will + be passed to all downstream [`DTask`](@ref)'s. +- `result`: The value that will be returned by `fetch()`'ing the [`DTask`](@ref). +""" +finish_stream(value::T; result::R=nothing) where {T,R} = FinishStream{T,R}(Some{T}(value), result) + +finish_stream(; result::R=nothing) where R = FinishStream{Union{},R}(nothing, result) + +const STREAM_THUNK_ID = TaskLocalValue{Int}(()->0) + +chunktype(sf::StreamingFunction{F}) where F = F + +struct StreamMigrating end + +function (sf::StreamingFunction)(args...; kwargs...) + thunk_id = Sch.sch_handle().thunk_id.id + STREAM_THUNK_ID[] = thunk_id + + # Migrate our output stream store to this worker + if sf.stream isa Stream + remote_cancel_token = migrate_stream!(sf.stream) + end + + @label start + @dagdebug thunk_id :stream "Starting StreamingFunction" + worker_id = sf.stream.store_ref.handle.owner + result = if worker_id == myid() + _run_streamingfunction(nothing, nothing, sf, args...; kwargs...) + else + tls = get_tls() + remotecall_fetch(_run_streamingfunction, worker_id, tls, remote_cancel_token, sf, args...; kwargs...) + end + if result === StreamMigrating() + @goto start + end + return result +end + +function _run_streamingfunction(tls, cancel_token, sf, args...; kwargs...) 
+ @nospecialize sf args kwargs + + store = sf.stream.store = MemPool.access_ref(identity, sf.stream.store_ref.handle; local_only=true) + @assert isopen(store) + + if tls !== nothing + # Setup TLS on this new task + tls.cancel_token = MemPool.access_ref(identity, cancel_token; local_only=true) + set_tls!(tls) + end + + thunk_id = Sch.sch_handle().thunk_id.id + STREAM_THUNK_ID[] = thunk_id + + # FIXME: Remove when scheduler is distributed + uid = remotecall_fetch(1, thunk_id) do thunk_id + lock(Sch.EAGER_ID_MAP) do id_map + for (uid, otid) in id_map + if thunk_id == otid + return uid + end + end + end + end + + try + # TODO: This kwarg song-and-dance is required to ensure that we don't + # allocate boxes within `stream!`, when possible + kwarg_names = map(name->Val{name}(), map(first, (kwargs...,))) + kwarg_values = map(last, (kwargs...,)) + args = map(arg->initialize_input_stream!(store, arg), args) + kwarg_values = map(kwarg->initialize_input_stream!(store, kwarg), kwarg_values) + return stream!(sf, uid, (args...,), kwarg_names, kwarg_values) + finally + if !sf.stream.store.migrating + # Remove ourself as a waiter for upstream Streams + streams = Set{Stream}() + for (idx, arg) in enumerate(args) + if arg isa Stream + push!(streams, arg) + end + end + for (idx, (pos, arg)) in enumerate(kwargs) + if arg isa Stream + push!(streams, arg) + end + end + for stream in streams + @dagdebug thunk_id :stream "dropping waiter" + remove_waiters!(stream, uid) + @dagdebug thunk_id :stream "dropped waiter" + end + + # Ensure downstream tasks also terminate + close(sf.stream) + @dagdebug thunk_id :stream "closed stream store" + end + end +end + +# N.B We specialize to minimize/eliminate allocations +function stream!(sf::StreamingFunction, uid, + args::Tuple, kwarg_names::Tuple, kwarg_values::Tuple) + f = move(thunk_processor(), sf.f) + counter = 0 + + # Initialize output streams. We can't do this in add_waiters!() because the + # output handlers depend on the DTaskTLS, so they have to be set up from + # within the DTask. + store = sf.stream.store + for output_uid in keys(store.output_streams) + if !haskey(store.output_buffers, output_uid) + initialize_output_stream!(store, output_uid) + end + end + + while true + # Yield to other (streaming) tasks + yield() + + # Exit streaming on cancellation + task_may_cancel!() + + # Exit streaming on migration + if sf.stream.store.migrating + error("FIXME: max_evals should be retained") + @dagdebug STREAM_THUNK_ID[] :stream "returning for migration" + return StreamMigrating() + end + + # Get values from Stream args/kwargs + local stream_args, stream_kwarg_values + try + stream_args = _stream_take_values!(args) + stream_kwarg_values = _stream_take_values!(kwarg_values) + catch ex + if ex isa InvalidStateException + # This means a buffer has been closed because an upstream task + # finished. + @dagdebug STREAM_THUNK_ID[] :stream "Upstream task finished, returning" + return nothing + else + rethrow() + end + end + + stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values) + + if length(stream_args) > 0 || length(stream_kwarg_values) > 0 + # Notify tasks that input buffers may have space + @lock sf.stream.store.lock notify(sf.stream.store.lock) + end + + # Run a single cycle of f + counter += 1 + @dagdebug STREAM_THUNK_ID[] :stream "executing $f (eval $counter)" + stream_result = f(stream_args...; stream_kwargs...) 
+ + # Exit streaming on graceful request + if stream_result isa FinishStream + if stream_result.value !== nothing + value = something(stream_result.value) + put!(sf.stream, value) + end + @dagdebug STREAM_THUNK_ID[] :stream "voluntarily returning" + return stream_result.result + end + + # Put the result into the output stream + put!(sf.stream, stream_result) + + # Exit streaming on eval limit + if sf.max_evals > 0 && counter >= sf.max_evals + @dagdebug STREAM_THUNK_ID[] :stream "max evals reached ($counter)" + return + end + end +end + +function _stream_take_values!(args) + return ntuple(length(args)) do idx + arg = args[idx] + if arg isa StreamingValue + return take!(arg) + else + return arg + end + end +end + +@inline @generated function _stream_namedtuple(kwarg_names::Tuple, + stream_kwarg_values::Tuple) + name_ex = Expr(:tuple, map(name->QuoteNode(name.parameters[1]), kwarg_names.parameters)...) + NT = :(NamedTuple{$name_ex,$stream_kwarg_values}) + return :($NT(stream_kwarg_values)) +end + +initialize_stream_buffer(B, T, buffer_amount) = B{T}(buffer_amount) + +const EAGER_THUNK_STREAMS = LockedObject(Dict{UInt,Any}()) +function task_to_stream(uid::UInt) + if myid() != 1 + return remotecall_fetch(task_to_stream, 1, uid) + end + lock(EAGER_THUNK_STREAMS) do global_streams + if haskey(global_streams, uid) + return global_streams[uid] + end + return + end +end + +function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) + stream_waiter_changes = Dict{UInt,Vector{Pair{UInt,Any}}}() + + for (spec, task) in tasks + @assert haskey(self_streams, task.uid) + our_stream = self_streams[task.uid] + + # Adapt args to accept Stream output of other streaming tasks + for (idx, (pos, arg)) in enumerate(spec.args) + if arg isa DTask + # Check if this is a streaming task + if haskey(self_streams, arg.uid) + other_stream = self_streams[arg.uid] + else + other_stream = task_to_stream(arg.uid) + end + + if other_stream !== nothing + # Generate Stream handle for input + # FIXME: Be configurable + input_fetcher = RemoteChannelFetcher() + other_stream_handle = Stream(other_stream) + spec.args[idx] = pos => other_stream_handle + our_stream.store.input_streams[arg.uid] = other_stream_handle + our_stream.store.input_fetchers[arg.uid] = input_fetcher + + # Add this task as a waiter for the associated output Stream + changes = get!(stream_waiter_changes, arg.uid) do + Pair{UInt,Any}[] + end + push!(changes, task.uid => input_fetcher) + end + end + end + + # Filter out all streaming options + to_filter = (:stream_buffer_type, + :stream_input_buffer_amount, :stream_output_buffer_amount, + :stream_max_evals) + spec.options = NamedTuple(filter(opt -> !(opt[1] in to_filter), + Base.pairs(spec.options))) + if haskey(spec.options, :propagates) + propagates = filter(opt -> !(opt in to_filter), + spec.options.propagates) + spec.options = merge(spec.options, (;propagates)) + end + end + + # Adjust waiter count of Streams with dependencies + for (uid, waiters) in stream_waiter_changes + stream = task_to_stream(uid) + add_waiters!(stream, waiters) + end +end diff --git a/src/submission.jl b/src/submission.jl index 7312e378d..f23539271 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -218,15 +218,27 @@ function eager_process_options_submission_to_local(id_map, options::NamedTuple) return options end end + +function DTaskMetadata(spec::DTaskSpec) + f = spec.f isa StreamingFunction ? 
spec.f.f : spec.f + arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args)) + return_type = Base.promote_op(f, arg_types...) + return DTaskMetadata(return_type) +end + function eager_spawn(spec::DTaskSpec) # Generate new DTask uid = eager_next_id() future = ThunkFuture() + metadata = DTaskMetadata(spec) finalizer_ref = poolset(DTaskFinalizer(uid); device=MemPool.CPURAMDevice()) # Create unlaunched DTask - return DTask(uid, future, finalizer_ref) + return DTask(uid, future, metadata, finalizer_ref) end + +chunktype(t::DTask) = t.metadata.return_type + function eager_launch!((spec, task)::Pair{DTaskSpec,DTask}) # Assign a name, if specified eager_assign_name!(spec, task) diff --git a/src/task-tls.jl b/src/task-tls.jl index ea188e004..5c7d0375b 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -1,41 +1,81 @@ # In-Thunk Helpers +mutable struct DTaskTLS + processor::Processor + sch_uid::UInt + sch_handle::Any # FIXME: SchedulerHandle + task_spec::Vector{Any} # FIXME: TaskSpec + cancel_token::CancelToken +end + +const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing) + +Base.copy(tls::DTaskTLS) = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token) + """ - task_processor() + get_tls() -> DTaskTLS -Get the current processor executing the current Dagger task. +Gets all Dagger TLS variable as a `DTaskTLS`. """ -task_processor() = task_local_storage(:_dagger_processor)::Processor -@deprecate thunk_processor() task_processor() +get_tls() = DTASK_TLS[]::DTaskTLS """ - in_task() + set_tls!(tls) + +Sets all Dagger TLS variables from `tls`, which may be a `DTaskTLS` or a `NamedTuple`. +""" +function set_tls!(tls) + DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token) +end + +""" + in_task() -> Bool Returns `true` if currently executing in a [`DTask`](@ref), else `false`. """ -in_task() = haskey(task_local_storage(), :_dagger_sch_uid) -@deprecate in_thunk() in_task() +in_task() = DTASK_TLS[] !== nothing +@deprecate(in_thunk(), in_task()) """ - get_tls() + task_id() -> Int -Gets all Dagger TLS variable as a `NamedTuple`. +Returns the ID of the current [`DTask`](@ref). """ -get_tls() = ( - sch_uid=task_local_storage(:_dagger_sch_uid), - sch_handle=task_local_storage(:_dagger_sch_handle), - processor=task_processor(), - task_spec=task_local_storage(:_dagger_task_spec), -) +task_id() = get_tls().sch_handle.thunk_id.id """ - set_tls!(tls) + task_processor() -> Processor -Sets all Dagger TLS variables from the `NamedTuple` `tls`. +Get the current processor executing the current [`DTask`](@ref). """ -function set_tls!(tls) - task_local_storage(:_dagger_sch_uid, tls.sch_uid) - task_local_storage(:_dagger_sch_handle, tls.sch_handle) - task_local_storage(:_dagger_processor, tls.processor) - task_local_storage(:_dagger_task_spec, tls.task_spec) +task_processor() = get_tls().processor +@deprecate(thunk_processor(), task_processor()) + +""" + task_cancelled(; must_force::Bool=false) -> Bool + +Returns `true` if the current [`DTask`](@ref) has been cancelled, else `false`. +If `must_force=true`, then only return `true` if the cancellation was forced. +""" +task_cancelled(; must_force::Bool=false) = + is_cancelled(get_tls().cancel_token; must_force) + +""" + task_may_cancel!(; must_force::Bool=false) + +Throws an `InterruptException` if the current [`DTask`](@ref) has been cancelled. +If `must_force=true`, then only throw if the cancellation was forced. 
+""" +function task_may_cancel!(;must_force::Bool=false) + if task_cancelled(;must_force) + throw(InterruptException()) + end end + +""" + task_cancel!(; graceful::Bool=true) + +Cancels the current [`DTask`](@ref). If `graceful=true`, then the task will be +cancelled gracefully, otherwise it will be forced. +""" +task_cancel!(; graceful::Bool=true) = cancel!(get_tls().cancel_token; graceful) diff --git a/src/threadproc.jl b/src/threadproc.jl index 09099889a..b75c90ca3 100644 --- a/src/threadproc.jl +++ b/src/threadproc.jl @@ -27,8 +27,9 @@ function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @n return result[] catch err if err isa InterruptException + # Direct interrupt hit us, propagate cancellation signal + # FIXME: We should tell the scheduler that the user hit Ctrl-C if !istaskdone(task) - # Propagate cancellation signal Threads.@spawn Base.throwto(task, InterruptException()) end end diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 9a9d24167..6a71e5c52 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -2,7 +2,8 @@ function istask end function task_id end const DAGDEBUG_CATEGORIES = Symbol[:global, :submit, :schedule, :scope, - :take, :execute, :move, :processor, :cancel] + :take, :execute, :move, :processor, :cancel, + :stream] macro dagdebug(thunk, category, msg, args...) cat_sym = category.value @gensym id @@ -31,6 +32,10 @@ macro dagdebug(thunk, category, msg, args...) $debug_ex_noid end end + + # Always yield to reduce differing behavior for debug vs. non-debug + # TODO: Remove this eventually + yield() end end) end diff --git a/test/mutation.jl b/test/mutation.jl index b6ac7143b..a245f445d 100644 --- a/test/mutation.jl +++ b/test/mutation.jl @@ -1,3 +1,5 @@ +import Dagger.Sch: SchedulingException + @everywhere begin struct DynamicHistogram bins::Vector{Float64} @@ -48,7 +50,7 @@ end x = Dagger.@mutable worker=w Ref{Int}() @test fetch(Dagger.@spawn mutable_update!(x)) == w wo_scope = Dagger.ProcessScope(wo) - @test_throws_unwrap Dagger.DTaskFailedException fetch(Dagger.@spawn scope=wo_scope mutable_update!(x)) + @test_throws_unwrap SchedulingException fetch(Dagger.@spawn scope=wo_scope mutable_update!(x)) end end # @testset "@mutable" diff --git a/test/processors.jl b/test/processors.jl index e97a1d239..4cedcd340 100644 --- a/test/processors.jl +++ b/test/processors.jl @@ -1,6 +1,6 @@ using Distributed import Dagger: Context, Processor, OSProc, ThreadProc, get_parent, get_processors -import Dagger.Sch: ThunkOptions +import Dagger.Sch: ThunkOptions, SchedulingException @everywhere begin @@ -37,9 +37,9 @@ end end @testset "Processor exhaustion" begin opts = ThunkOptions(proclist=[OptOutProc]) - @test_throws_unwrap Dagger.DTaskFailedException ex isa Dagger.Sch.SchedulingException ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) + @test_throws_unwrap SchedulingException reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) opts = ThunkOptions(proclist=(proc)->false) - @test_throws_unwrap Dagger.DTaskFailedException ex isa Dagger.Sch.SchedulingException ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) + @test_throws_unwrap SchedulingException reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) opts = ThunkOptions(proclist=nothing) @test collect(delayed(sum; options=opts)([1,2,3])) == 6 end diff --git a/test/runtests.jl b/test/runtests.jl 
index 04871f6b9..67d25276e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ tests = [ ("Mutation", "mutation.jl"), ("Task Queues", "task-queues.jl"), ("Datadeps", "datadeps.jl"), + ("Streaming", "streaming.jl"), ("Domain Utilities", "domain.jl"), ("Array - Allocation", "array/allocation.jl"), ("Array - Indexing", "array/indexing.jl"), @@ -34,7 +35,10 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ pushfirst!(LOAD_PATH, joinpath(@__DIR__, "..")) using Pkg Pkg.activate(@__DIR__) - Pkg.instantiate() + try + Pkg.instantiate() + catch + end using ArgParse s = ArgParseSettings(description = "Dagger Testsuite") @@ -51,6 +55,12 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ arg_type = Int default = additional_workers help = "How many additional workers to launch" + "-v", "--verbose" + action = :store_true + help = "Run the tests with debug logs from Dagger" + "-O", "--offline" + action = :store_true + help = "Set Pkg into offline mode" end end @@ -80,12 +90,20 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ parsed_args["simulate"] && exit(0) additional_workers = parsed_args["procs"] + + if parsed_args["verbose"] + ENV["JULIA_DEBUG"] = "Dagger" + end + + if parsed_args["offline"] + Pkg.UPDATED_REGISTRY_THIS_SESSION[] = true + Pkg.offline(true) + end else to_test = all_test_names @info "Running all tests" end - using Distributed if additional_workers > 0 # We put this inside a branch because addprocs() takes a minimum of 1s to diff --git a/test/scheduler.jl b/test/scheduler.jl index b9fe01872..a949ffc1d 100644 --- a/test/scheduler.jl +++ b/test/scheduler.jl @@ -182,7 +182,7 @@ end @testset "allow errors" begin opts = ThunkOptions(;allow_errors=true) a = delayed(error; options=opts)("Test") - @test_throws_unwrap Dagger.DTaskFailedException collect(a) + @test_throws_unwrap ErrorException collect(a) end end @@ -396,7 +396,7 @@ end ([Dagger.tochunk(MyStruct(1)), Dagger.tochunk(1)], sizeof(MyStruct)+sizeof(Int)), ] for arg in args - if arg isa Chunk + if arg isa Dagger.Chunk aff = Dagger.affinity(arg) @test aff[1] == OSProc(1) @test aff[2] == MemPool.approx_size(MemPool.poolget(arg.handle)) @@ -540,7 +540,7 @@ end t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) sleep(100) start_time = time_ns() Dagger.cancel!(t) - @test_throws_unwrap Dagger.DTaskFailedException fetch(t) + @test_throws_unwrap InterruptException fetch(t) t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) yield() fetch(t) finish_time = time_ns() diff --git a/test/scopes.jl b/test/scopes.jl index 5f82a71a0..a92cc42f2 100644 --- a/test/scopes.jl +++ b/test/scopes.jl @@ -1,3 +1,5 @@ +import Dagger.Sch: SchedulingException + @testset "Chunk Scopes" begin wid1, wid2 = addprocs(2, exeflags=["-t 2"]) @everywhere [wid1,wid2] using Dagger @@ -56,7 +58,7 @@ # Different nodes for (ch1, ch2) in [(ns1_ch, ns2_ch), (ns2_ch, ns1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap SchedulingException reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Process Scope" begin @@ -75,7 +77,7 @@ # Different process for (ch1, ch2) in [(ps1_ch, ps2_ch), (ps2_ch, ps1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap SchedulingException reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end # Same process and node @@ -83,7 +85,7 @@ # Different 
process and node for (ch1, ch2) in [(ps1_ch, ns2_ch), (ns2_ch, ps1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap SchedulingException reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Exact Scope" begin @@ -104,14 +106,14 @@ # Different process, different processor for (ch1, ch2) in [(es1_ch, es2_ch), (es2_ch, es1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap SchedulingException reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end # Same process, different processor es1_2 = ExactScope(Dagger.ThreadProc(wid1, 2)) es1_2_ch = Dagger.tochunk(nothing, OSProc(), es1_2) for (ch1, ch2) in [(es1_ch, es1_2_ch), (es1_2_ch, es1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap SchedulingException reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Union Scope" begin diff --git a/test/streaming.jl b/test/streaming.jl new file mode 100644 index 000000000..06ee0a660 --- /dev/null +++ b/test/streaming.jl @@ -0,0 +1,472 @@ +const ACCUMULATOR = Dict{Int,Vector{Real}}() +@everywhere function accumulator(x=0) + tid = Dagger.task_id() + remotecall_wait(1, tid, x) do tid, x + acc = get!(Vector{Real}, ACCUMULATOR, tid) + push!(acc, x) + end + return +end +@everywhere accumulator(xs...) = accumulator(sum(xs)) +@everywhere accumulator(::Nothing) = accumulator(0) + +function catch_interrupt(f) + try + f() + catch err + if err isa Dagger.DTaskFailedException && err.ex isa InterruptException + return + elseif err isa Dagger.Sch.SchedulingException + return + end + rethrow(err) + end +end + +function merge_testset!(inner::Test.DefaultTestSet) + outer = Test.get_testset() + append!(outer.results, inner.results) + outer.n_passed += inner.n_passed +end + +function test_finishes(f, message::String; timeout=10, ignore_timeout=false, max_evals=10) + t = @eval Threads.@spawn begin + tset = nothing + try + @testset $message begin + try + @testset $message begin + Dagger.with_options(;stream_max_evals=$max_evals) do + catch_interrupt($f) + end + end + finally + tset = Test.get_testset() + end + end + catch + end + return tset + end + + timed_out = timedwait(()->istaskdone(t), timeout) == :timed_out + if timed_out + if !ignore_timeout + @warn "Testing task timed out: $message" + end + Dagger.cancel!(;halt_sch=true) + @everywhere GC.gc() + fetch(Dagger.@spawn 1+1) + end + + tset = fetch(t)::Test.DefaultTestSet + merge_testset!(tset) + return !timed_out +end + +all_scopes = [Dagger.ExactScope(proc) for proc in Dagger.all_processors()] +for idx in 1:5 + if idx == 1 + scopes = [Dagger.scope(worker = 1, thread = 1)] + scope_str = "Worker 1" + elseif idx == 2 && nprocs() > 1 + scopes = [Dagger.scope(worker = 2, thread = 1)] + scope_str = "Worker 2" + else + scopes = all_scopes + scope_str = "All Workers" + end + + @testset "Single Task Control Flow ($scope_str)" begin + @test !test_finishes("Single task running forever"; max_evals=1_000_000, ignore_timeout=true) do + local x + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) () -> begin + y = rand() + sleep(1) + return y + end + end + @test_throws_unwrap InterruptException fetch(x) + end + + @test test_finishes("Single task without result") do + local x + Dagger.spawn_streaming() do + x = 
Dagger.@spawn scope=rand(scopes) rand() + end + @test fetch(x) === nothing + end + + @test test_finishes("Single task with result"; max_evals=1_000_000) do + local x + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) () -> begin + x = rand() + if x < 0.1 + return Dagger.finish_stream(x; result=123) + end + return x + end + end + @test fetch(x) == 123 + end + end + + @testset "Non-Streaming Inputs ($scope_str)" begin + @test test_finishes("() -> A") do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator() + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(0), values[A_tid]) + end + @test test_finishes("42 -> A") do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator(42) + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(42), values[A_tid]) + end + @test test_finishes("(42, 43) -> A") do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator(42, 43) + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(42 + 43), values[A_tid]) + end + end + + @testset "Non-Streaming Outputs ($scope_str)" begin + @test test_finishes("x -> A") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + end + Dagger._without_options() do + A = Dagger.@spawn accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[A_tid]) + end + + @test test_finishes("x -> (A, B)") do + local x, A, B + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + end + Dagger._without_options() do + A = Dagger.@spawn accumulator(x) + B = Dagger.@spawn accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + @test fetch(B) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[A_tid]) + B_tid = Dagger.task_id(B) + @test length(values[B_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[B_tid]) + end + end + + @testset "Multiple Tasks ($scope_str)" begin + @test test_finishes("x -> A") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 1, values[A_tid]) + end + + @test test_finishes("(x, A)") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(1.0) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> v == 1, values[A_tid]) + end + + @test test_finishes("x -> y -> A") do + local x, y, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x+1 + 
A = Dagger.@spawn scope=rand(scopes) accumulator(y) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 1 <= v <= 2, values[A_tid]) + end + + @test test_finishes("x -> (y, A)") do + local x, y, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x+1 + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 1, values[A_tid]) + end + + @test test_finishes("(x, y) -> A") do + local x, y, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(x, y) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[A_tid]) + end + + @test test_finishes("(x, y) -> z -> A") do + local x, y, z, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + z = Dagger.@spawn scope=rand(scopes) x + y + A = Dagger.@spawn scope=rand(scopes) accumulator(z) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(z) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[A_tid]) + end + + @test test_finishes("x -> (y, z) -> A") do + local x, y, z, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x + 1 + z = Dagger.@spawn scope=rand(scopes) x + 2 + A = Dagger.@spawn scope=rand(scopes) accumulator(y, z) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(z) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 3 <= v <= 5, values[A_tid]) + end + + @test test_finishes("(x, y) -> z -> (A, B)") do + local x, y, z, A, B + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + z = Dagger.@spawn scope=rand(scopes) x + y + A = Dagger.@spawn scope=rand(scopes) accumulator(z) + B = Dagger.@spawn scope=rand(scopes) accumulator(z) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(z) === nothing + @test fetch(A) === nothing + @test fetch(B) === nothing + + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[A_tid]) + B_tid = Dagger.task_id(B) + @test length(values[B_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[B_tid]) + end + + for T in (Float64, Int32, BigFloat) + @test test_finishes("Stream eltype $T") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand(T) + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); 
empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> v isa T, values[A_tid]) + end + end + end + + @testset "Max Evals ($scope_str)" begin + @test test_finishes("max_evals=0"; max_evals=0) do + @test_throws ArgumentError Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator() + end + end + @test test_finishes("max_evals=1"; max_evals=1) do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator() + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 1 + end + @test test_finishes("max_evals=100"; max_evals=100) do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator() + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 100 + end + end + + # @testset "DropBuffer ($scope_str)" begin + # # TODO: Test that accumulator never gets called + # @test !test_finishes("x (drop)-> A"; ignore_timeout=false, max_evals=typemax(Int)) do + # # ENV["JULIA_DEBUG"] = "Dagger" + + # local x, A + # Dagger.spawn_streaming() do + # Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do + # x = Dagger.@spawn scope=rand(scopes) rand() + # end + # A = Dagger.@spawn scope=rand(scopes) accumulator(x) + # end + # @test fetch(x) === nothing + # fetch(A) + # @test_throws_unwrap InterruptException fetch(A) + # end + + # @test !test_finishes("x ->(drop) A"; ignore_timeout=true) do + # local x, A + # Dagger.spawn_streaming() do + # x = Dagger.@spawn scope=rand(scopes) rand() + # Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do + # A = Dagger.@spawn scope=rand(scopes) accumulator(x) + # end + # end + # @test fetch(x) === nothing + # @test_throws_unwrap InterruptException fetch(A) === nothing + # end + + # @test !test_finishes("x -(drop)> A"; ignore_timeout=true) do + # local x, A + # Dagger.spawn_streaming() do + # Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do + # x = Dagger.@spawn scope=rand(scopes) rand() + # A = Dagger.@spawn scope=rand(scopes) accumulator(x) + # end + # end + # @test fetch(x) === nothing + # @test_throws_unwrap InterruptException fetch(A) === nothing + # end + # end + + @testset "Graceful finishing" begin + @test test_finishes("finish_stream() without return value") do + B = Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) Dagger.finish_stream() + + Dagger.@spawn scope=rand(scopes) accumulator(A) + end + + fetch(B) + # Since we don't return any value in the call to finish_stream(), B + # should never execute. 
+ @test isempty(ACCUMULATOR) + end + + @test test_finishes("finish_stream() with one downstream task") do + B = Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) Dagger.finish_stream(42) + + Dagger.@spawn scope=rand(scopes) accumulator(A) + end + + fetch(B) + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + @test values[Dagger.task_id(B)] == [42] + end + + @test test_finishes("finish_stream() with multiple downstream tasks"; max_evals=2) do + D, E = Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) Dagger.finish_stream(1) + B = Dagger.@spawn scope=rand(scopes) A + 1 + C = Dagger.@spawn scope=rand(scopes) A + 1 + D = Dagger.@spawn scope=rand(scopes) accumulator(B, C) + + E = Dagger.@spawn scope=rand(scopes) accumulator() + + D, E + end + + fetch(D) + fetch(E) + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + + # D should only execute once since it depends on A/B/C + @test values[Dagger.task_id(D)] == [4] + + # E should run max_evals times since it has no dependencies + @test length(values[Dagger.task_id(E)]) == 2 + end + end + + # FIXME: Varying buffer amounts + + #= TODO: Zero-allocation test + # First execution of a streaming task will almost guaranteed allocate (compiling, setup, etc.) + # BUT, second and later executions could possibly not allocate any further ("steady-state") + # We want to be able to validate that the steady-state execution for certain tasks is non-allocating + =# +end diff --git a/test/thunk.jl b/test/thunk.jl index e6fb7e86b..5e193a505 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -69,7 +69,7 @@ end A = rand(4, 4) @test fetch(@spawn sum(A; dims=1)) ≈ sum(A; dims=1) - @test_throws_unwrap Dagger.DTaskFailedException fetch(@spawn sum(A; fakearg=2)) + @test_throws_unwrap MethodError fetch(@spawn sum(A; fakearg=2)) @test fetch(@spawn reduce(+, A; dims=1, init=2.0)) ≈ reduce(+, A; dims=1, init=2.0) @@ -194,7 +194,7 @@ end a = @spawn error("Test") wait(a) @test isready(a) - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap ErrorException fetch(a) b = @spawn 1+2 @test fetch(b) == 3 end @@ -207,7 +207,6 @@ end catch err err end - ex = Dagger.Sch.unwrap_nested_exception(ex) ex_str = sprint(io->Base.showerror(io,ex)) @test occursin(r"^DTaskFailedException:", ex_str) @test occursin("Test", ex_str) @@ -218,7 +217,6 @@ end catch err err end - ex = Dagger.Sch.unwrap_nested_exception(ex) ex_str = sprint(io->Base.showerror(io,ex)) @test occursin("Test", ex_str) @test occursin("Root Task", ex_str) @@ -226,28 +224,28 @@ end @testset "single dependent" begin a = @spawn error("Test") b = @spawn a+2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap ErrorException fetch(a) end @testset "multi dependent" begin a = @spawn error("Test") b = @spawn a+2 c = @spawn a*2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(b) - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap ErrorException fetch(b) + @test_throws_unwrap ErrorException fetch(c) end @testset "dependent chain" begin a = @spawn error("Test") - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap ErrorException fetch(a) b = @spawn a+1 - @test_throws_unwrap Dagger.DTaskFailedException fetch(b) + @test_throws_unwrap ErrorException fetch(b) c = @spawn b+2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap ErrorException fetch(c) end @testset "single input" begin a = @spawn 1+1 b = @spawn (a->error("Test"))(a) @test fetch(a) == 2 - @test_throws_unwrap 
Dagger.DTaskFailedException fetch(b) + @test_throws_unwrap ErrorException fetch(b) end @testset "multi input" begin a = @spawn 1+1 @@ -255,7 +253,7 @@ end c = @spawn ((a,b)->error("Test"))(a,b) @test fetch(a) == 2 @test fetch(b) == 4 - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap ErrorException fetch(c) end @testset "diamond" begin a = @spawn 1+1 @@ -265,9 +263,10 @@ end @test fetch(a) == 2 @test fetch(b) == 3 @test fetch(c) == 4 - @test_throws_unwrap Dagger.DTaskFailedException fetch(d) + @test_throws_unwrap ErrorException fetch(d) end end + @testset "remote spawn" begin a = fetch(Distributed.@spawnat 2 Dagger.@spawn 1+2) @test Dagger.Sch.EAGER_INIT[] @@ -283,7 +282,7 @@ end t1 = Dagger.@spawn 1+"fail" Dagger.@spawn t1+1 end - @test_throws_unwrap Dagger.DTaskFailedException fetch(t2) + @test_throws_unwrap MethodError fetch(t2) end @testset "undefined function" begin # Issues #254, #255