From 82be9e4330e566d5613cdca26620c170b53cbddc Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Mon, 1 Apr 2024 14:25:17 +0200 Subject: [PATCH 01/23] Delete Dagger.cleanup() Because it doesn't actually do anything now. --- src/compute.jl | 6 ------ src/sch/Sch.jl | 3 --- 2 files changed, 9 deletions(-) diff --git a/src/compute.jl b/src/compute.jl index f421eaccc..093b527f4 100644 --- a/src/compute.jl +++ b/src/compute.jl @@ -36,12 +36,6 @@ end Base.@deprecate gather(ctx, x) collect(ctx, x) Base.@deprecate gather(x) collect(x) -cleanup() = cleanup(Context(global_context())) -function cleanup(ctx::Context) - Sch.cleanup(ctx) - nothing -end - function get_type(s::String) local T for t in split(s, ".") diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 73bb07bf9..12d259352 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -307,9 +307,6 @@ function populate_defaults(opts::ThunkOptions, Tf, Targs) ) end -function cleanup(ctx) -end - # Eager scheduling include("eager.jl") From 3e82247613c90ee84722b495508c8f36d158d38e Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Tue, 9 Apr 2024 03:44:35 +0200 Subject: [PATCH 02/23] Use procs() when initializing EAGER_CONTEXT Using `myid()` with `workers()` meant that when the context was initialized with a single worker the processor list would be: `[OSProc(1), OSProc(1)]`. `procs()` will always include PID 1 and any other workers, which is what we want. --- src/sch/eager.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sch/eager.jl b/src/sch/eager.jl index 87a109788..259c1b6f8 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -6,7 +6,7 @@ const EAGER_STATE = Ref{Union{ComputeState,Nothing}}(nothing) function eager_context() if EAGER_CONTEXT[] === nothing - EAGER_CONTEXT[] = Context([myid(),workers()...]) + EAGER_CONTEXT[] = Context(procs()) end return EAGER_CONTEXT[] end From 6d24fd53cd1f053a2a89ae9a16ae3512c53c1626 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 30 Nov 2023 20:42:50 -0700 Subject: [PATCH 03/23] Add metadata to EagerThunk --- Project.toml | 1 + src/dtask.jl | 14 +++++++++++++- src/submission.jl | 14 +++++++++++++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index fd7508cd7..439d89081 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" +Mmap = "a63ad114-7e13-5084-954f-fe012c677804" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" diff --git a/src/dtask.jl b/src/dtask.jl index 68f2d3c1b..98f74005a 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -39,6 +39,16 @@ end Options(;options...) = Options((;options...)) Options(options...) = Options((;options...)) +""" + DTaskMetadata + +Represents some useful metadata pertaining to a `DTask`: +- `return_type::Type` - The inferred return type of the task +""" +mutable struct DTaskMetadata + return_type::Type +end + """ DTask @@ -50,9 +60,11 @@ more details. 
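
A minimal usage sketch (illustrative only):

    t = Dagger.@spawn 1 + 2
    fetch(t)  # returns 3
"""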
mutable struct DTask uid::UInt future::ThunkFuture + metadata::DTaskMetadata finalizer_ref::DRef thunk_ref::DRef - DTask(uid, future, finalizer_ref) = new(uid, future, finalizer_ref) + + DTask(uid, future, metadata, finalizer_ref) = new(uid, future, metadata, finalizer_ref) end const EagerThunk = DTask diff --git a/src/submission.jl b/src/submission.jl index 7312e378d..bfb8cb8be 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -218,15 +218,27 @@ function eager_process_options_submission_to_local(id_map, options::NamedTuple) return options end end + +function DTaskMetadata(spec::DTaskSpec) + f = chunktype(spec.f).instance + arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args)) + return_type = Base.promote_op(f, arg_types...) + return DTaskMetadata(return_type) +end + function eager_spawn(spec::DTaskSpec) # Generate new DTask uid = eager_next_id() future = ThunkFuture() + metadata = DTaskMetadata(spec) finalizer_ref = poolset(DTaskFinalizer(uid); device=MemPool.CPURAMDevice()) # Create unlaunched DTask - return DTask(uid, future, finalizer_ref) + return DTask(uid, future, metadata, finalizer_ref) end + +chunktype(t::DTask) = t.metadata.return_type + function eager_launch!((spec, task)::Pair{DTaskSpec,DTask}) # Assign a name, if specified eager_assign_name!(spec, task) From 7d3712ea201b9a6b123c8c096179bc880c766be3 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 30 Nov 2023 20:44:14 -0700 Subject: [PATCH 04/23] Sch: Allow occupancy key to be Any --- src/sch/util.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/sch/util.jl b/src/sch/util.jl index e81703db5..cd006838b 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -406,12 +406,19 @@ function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig) else get(state.signature_alloc_cost, sig, UInt64(0)) end::UInt64 - est_occupancy = if occupancy !== nothing && haskey(occupancy, T) - # Clamp to 0-1, and scale between 0 and `typemax(UInt32)` - Base.unsafe_trunc(UInt32, clamp(occupancy[T], 0, 1) * typemax(UInt32)) - else - typemax(UInt32) - end::UInt32 + est_occupancy::UInt32 = typemax(UInt32) + if occupancy !== nothing + occ = nothing + if haskey(occupancy, T) + occ = occupancy[T] + elseif haskey(occupancy, Any) + occ = occupancy[Any] + end + if occ !== nothing + # Clamp to 0-1, and scale between 0 and `typemax(UInt32)` + est_occupancy = Base.unsafe_trunc(UInt32, clamp(occ, 0, 1) * typemax(UInt32)) + end + end #= FIXME: Estimate if cached data can be swapped to storage storage = storage_resource(p) real_alloc_util = state.worker_storage_pressure[gp][storage] From aaef3434623ca86f923b8517a9673cefc2e22412 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 26 Nov 2024 12:29:32 -0600 Subject: [PATCH 05/23] Add a --verbose option to runtests.jl --- test/runtests.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 04871f6b9..fe92e41fb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -51,6 +51,9 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ arg_type = Int default = additional_workers help = "How many additional workers to launch" + "-v", "--verbose" + action = :store_true + help = "Run the tests with debug logs from Dagger" end end @@ -80,6 +83,10 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ parsed_args["simulate"] && exit(0) additional_workers = parsed_args["procs"] + + if parsed_args["verbose"] + ENV["JULIA_DEBUG"] = "Dagger" + end else to_test = all_test_names 
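    # No CLI test filters were given; run the full suite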
@info "Running all tests" From 0763d99660fdae3d47b9e29384a763ae449130da Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 22 May 2024 12:48:53 -0500 Subject: [PATCH 06/23] task-tls: Refactor into DTaskTLS struct --- src/Dagger.jl | 2 ++ src/array/indexing.jl | 2 -- src/sch/Sch.jl | 2 +- src/sch/dynamic.jl | 2 +- src/task-tls.jl | 49 ++++++++++++++++++++++--------------------- 5 files changed, 29 insertions(+), 28 deletions(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index b478ece0f..c3a589c16 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -24,6 +24,8 @@ else import Base.ScopedValues: ScopedValue, with end +import TaskLocalValues: TaskLocalValue + if !isdefined(Base, :get_extension) import Requires: @require end diff --git a/src/array/indexing.jl b/src/array/indexing.jl index 82f44fbff..69725eb7a 100644 --- a/src/array/indexing.jl +++ b/src/array/indexing.jl @@ -1,5 +1,3 @@ -import TaskLocalValues: TaskLocalValue - ### getindex struct GetIndex{T,N} <: ArrayOp{T,N} diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 12d259352..548a1104d 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -1196,7 +1196,7 @@ function proc_states(f::Base.Callable, uid::UInt64) end end proc_states(f::Base.Callable) = - proc_states(f, task_local_storage(:_dagger_sch_uid)::UInt64) + proc_states(f, Dagger.get_tls().sch_uid) task_tid_for_processor(::Processor) = nothing task_tid_for_processor(proc::Dagger.ThreadProc) = proc.tid diff --git a/src/sch/dynamic.jl b/src/sch/dynamic.jl index e02085ee6..5b917fdb5 100644 --- a/src/sch/dynamic.jl +++ b/src/sch/dynamic.jl @@ -17,7 +17,7 @@ struct SchedulerHandle end "Gets the scheduler handle for the currently-executing thunk." -sch_handle() = task_local_storage(:_dagger_sch_handle)::SchedulerHandle +sch_handle() = Dagger.get_tls().sch_handle::SchedulerHandle "Thrown when the scheduler halts before finishing processing the DAG." struct SchedulerHaltedException <: Exception end diff --git a/src/task-tls.jl b/src/task-tls.jl index ea188e004..90fdfedb3 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -1,41 +1,42 @@ # In-Thunk Helpers +struct DTaskTLS + processor::Processor + sch_uid::UInt + sch_handle::Any # FIXME: SchedulerHandle + task_spec::Vector{Any} # FIXME: TaskSpec +end + +const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing) + """ - task_processor() + get_tls() -> DTaskTLS -Get the current processor executing the current Dagger task. +Gets all Dagger TLS variable as a `DTaskTLS`. """ -task_processor() = task_local_storage(:_dagger_processor)::Processor -@deprecate thunk_processor() task_processor() +get_tls() = DTASK_TLS[]::DTaskTLS """ - in_task() + set_tls!(tls) -Returns `true` if currently executing in a [`DTask`](@ref), else `false`. +Sets all Dagger TLS variables from `tls`, which may be a `DTaskTLS` or a `NamedTuple`. """ -in_task() = haskey(task_local_storage(), :_dagger_sch_uid) -@deprecate in_thunk() in_task() +function set_tls!(tls) + DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec) +end """ - get_tls() + in_task() -> Bool -Gets all Dagger TLS variable as a `NamedTuple`. +Returns `true` if currently executing in a [`DTask`](@ref), else `false`. 
""" -get_tls() = ( - sch_uid=task_local_storage(:_dagger_sch_uid), - sch_handle=task_local_storage(:_dagger_sch_handle), - processor=task_processor(), - task_spec=task_local_storage(:_dagger_task_spec), -) +in_task() = DTASK_TLS[] !== nothing +@deprecate in_thunk() in_task() """ - set_tls!(tls) + task_processor() -> Processor -Sets all Dagger TLS variables from the `NamedTuple` `tls`. +Get the current processor executing the current [`DTask`](@ref). """ -function set_tls!(tls) - task_local_storage(:_dagger_sch_uid, tls.sch_uid) - task_local_storage(:_dagger_sch_handle, tls.sch_handle) - task_local_storage(:_dagger_processor, tls.processor) - task_local_storage(:_dagger_task_spec, tls.task_spec) -end +task_processor() = get_tls().processor +@deprecate thunk_processor() task_processor() From 488ae7a58b2f09404d05e8805749d52d0dadbb5b Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 13 Sep 2024 12:21:13 -0400 Subject: [PATCH 07/23] cancellation: Add cancel token support --- src/Dagger.jl | 4 ++-- src/cancellation.jl | 38 +++++++++++++++++++++++++++++++++++++- src/sch/Sch.jl | 14 ++++++++++++++ src/task-tls.jl | 21 ++++++++++++++++++++- src/threadproc.jl | 3 ++- 5 files changed, 75 insertions(+), 5 deletions(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index c3a589c16..1ec68d071 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -48,16 +48,16 @@ include("processor.jl") include("threadproc.jl") include("context.jl") include("utils/processors.jl") +include("dtask.jl") +include("cancellation.jl") include("task-tls.jl") include("scopes.jl") include("utils/scopes.jl") -include("dtask.jl") include("queue.jl") include("thunk.jl") include("submission.jl") include("chunks.jl") include("memory-spaces.jl") -include("cancellation.jl") # Task scheduling include("compute.jl") diff --git a/src/cancellation.jl b/src/cancellation.jl index c982fd20c..dcb0f5add 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -1,3 +1,38 @@ +# DTask-level cancellation + +struct CancelToken + cancelled::Base.RefValue{Bool} + event::Base.Event +end +CancelToken() = CancelToken(Ref(false), Base.Event()) +function cancel!(token::CancelToken) + token.cancelled[] = true + notify(token.event) + return +end +is_cancelled(token::CancelToken) = token.cancelled[] +Base.wait(token::CancelToken) = wait(token.event) +# TODO: Enable this for safety +#Serialization.serialize(io::AbstractSerializer, ::CancelToken) = +# throw(ConcurrencyViolationError("Cannot serialize a CancelToken")) + +const DTASK_CANCEL_TOKEN = TaskLocalValue{Union{CancelToken,Nothing}}(()->nothing) + +function clone_cancel_token_remote(orig_token::CancelToken, wid::Integer) + remote_token = remotecall_fetch(wid) do + return poolset(CancelToken()) + end + errormonitor_tracked("remote cancel_token communicator", Threads.@spawn begin + wait(orig_token) + @dagdebug nothing :cancel "Cancelling remote token on worker $wid" + MemPool.access_ref(remote_token) do remote_token + cancel!(remote_token) + end + end) +end + +# Global-level cancellation + """ cancel!(task::DTask; force::Bool=false, halt_sch::Bool=false) @@ -80,11 +115,11 @@ function _cancel!(state, tid, force, halt_sch) Tf === typeof(Sch.eager_thunk) && continue istaskdone(task) && continue any_cancelled = true - @dagdebug tid :cancel "Cancelling running task ($Tf)" if force @dagdebug tid :cancel "Interrupting running task ($Tf)" Threads.@spawn Base.throwto(task, InterruptException()) else + @dagdebug tid :cancel "Cancelling running task ($Tf)" # Tell the processor to just drop this task task_occupancy 
= task_spec[4] time_util = task_spec[2] @@ -93,6 +128,7 @@ function _cancel!(state, tid, force, halt_sch) push!(istate.cancelled, tid) to_proc = istate.proc put!(istate.return_queue, (myid(), to_proc, tid, (InterruptException(), nothing))) + cancel!(istate.cancel_tokens[tid]) end end end diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 548a1104d..3df38c182 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -1177,6 +1177,7 @@ struct ProcessorInternalState proc_occupancy::Base.RefValue{UInt32} time_pressure::Base.RefValue{UInt64} cancelled::Set{Int} + cancel_tokens::Dict{Int,Dagger.CancelToken} done::Base.RefValue{Bool} end struct ProcessorState @@ -1326,7 +1327,14 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Execute the task and return its result t = @task begin + # Set up cancellation + cancel_token = Dagger.CancelToken() + Dagger.DTASK_CANCEL_TOKEN[] = cancel_token + lock(istate.queue) do _ + istate.cancel_tokens[thunk_id] = cancel_token + end was_cancelled = false + result = try do_task(to_proc, task) catch err @@ -1343,6 +1351,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Task was cancelled, so occupancy and pressure are # already reduced pop!(istate.cancelled, thunk_id) + delete!(istate.cancel_tokens, thunk_id) was_cancelled = true end end @@ -1360,6 +1369,9 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re else rethrow(err) end + finally + # Ensure that any spawned tasks get cleaned up + Dagger.cancel!(cancel_token) end end lock(istate.queue) do _ @@ -1409,6 +1421,7 @@ function do_tasks(to_proc, return_queue, tasks) Dict{Int,Vector{Any}}(), Ref(UInt32(0)), Ref(UInt64(0)), Set{Int}(), + Dict{Int,Dagger.CancelToken}(), Ref(false)) runner = start_processor_runner!(istate, uid, return_queue) @static if VERSION < v"1.9" @@ -1650,6 +1663,7 @@ function do_task(to_proc, task_desc) sch_handle, processor=to_proc, task_spec=task_desc, + cancel_token=Dagger.DTASK_CANCEL_TOKEN[], )) res = Dagger.with_options(propagated) do diff --git a/src/task-tls.jl b/src/task-tls.jl index 90fdfedb3..8a8b6c66d 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -5,6 +5,7 @@ struct DTaskTLS sch_uid::UInt sch_handle::Any # FIXME: SchedulerHandle task_spec::Vector{Any} # FIXME: TaskSpec + cancel_token::CancelToken end const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing) @@ -22,7 +23,7 @@ get_tls() = DTASK_TLS[]::DTaskTLS Sets all Dagger TLS variables from `tls`, which may be a `DTaskTLS` or a `NamedTuple`. """ function set_tls!(tls) - DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec) + DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token) end """ @@ -40,3 +41,21 @@ Get the current processor executing the current [`DTask`](@ref). """ task_processor() = get_tls().processor @deprecate thunk_processor() task_processor() + +""" + task_cancelled() -> Bool + +Returns `true` if the current [`DTask`](@ref) has been cancelled, else `false`. +""" +task_cancelled() = get_tls().cancel_token.cancelled[] + +""" + task_may_cancel!() + +Throws an `InterruptException` if the current [`DTask`](@ref) has been cancelled. 
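+
+A sketch of a cooperative, cancellation-aware loop (illustrative; `work_items`
+and `process` are hypothetical):
+
+    for item in work_items
+        Dagger.task_may_cancel!()  # throws InterruptException if cancelled
+        process(item)
+    end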
+""" +function task_may_cancel!() + if task_cancelled() + throw(InterruptException()) + end +end diff --git a/src/threadproc.jl b/src/threadproc.jl index 09099889a..b75c90ca3 100644 --- a/src/threadproc.jl +++ b/src/threadproc.jl @@ -27,8 +27,9 @@ function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @n return result[] catch err if err isa InterruptException + # Direct interrupt hit us, propagate cancellation signal + # FIXME: We should tell the scheduler that the user hit Ctrl-C if !istaskdone(task) - # Propagate cancellation signal Threads.@spawn Base.throwto(task, InterruptException()) end end From 12e1bda3ffac4d85782d0b1a095fffa06d2d7d26 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 12:59:55 -0500 Subject: [PATCH 08/23] task-tls: Tweaks and fixes, task_id helper --- src/task-tls.jl | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/task-tls.jl b/src/task-tls.jl index 8a8b6c66d..f6889bbb1 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -1,6 +1,6 @@ # In-Thunk Helpers -struct DTaskTLS +mutable struct DTaskTLS processor::Processor sch_uid::UInt sch_handle::Any # FIXME: SchedulerHandle @@ -10,6 +10,8 @@ end const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing) +Base.copy(tls::DTaskTLS) = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token) + """ get_tls() -> DTaskTLS @@ -32,7 +34,14 @@ end Returns `true` if currently executing in a [`DTask`](@ref), else `false`. """ in_task() = DTASK_TLS[] !== nothing -@deprecate in_thunk() in_task() +@deprecate(in_thunk(), in_task()) + +""" + task_id() -> Int + +Returns the ID of the current [`DTask`](@ref). +""" +task_id() = get_tls().sch_handle.thunk_id.id """ task_processor() -> Processor @@ -40,14 +49,14 @@ in_task() = DTASK_TLS[] !== nothing Get the current processor executing the current [`DTask`](@ref). """ task_processor() = get_tls().processor -@deprecate thunk_processor() task_processor() +@deprecate(thunk_processor(), task_processor()) """ task_cancelled() -> Bool Returns `true` if the current [`DTask`](@ref) has been cancelled, else `false`. """ -task_cancelled() = get_tls().cancel_token.cancelled[] +task_cancelled() = is_cancelled(get_tls().cancel_token) """ task_may_cancel!() @@ -59,3 +68,10 @@ function task_may_cancel!() throw(InterruptException()) end end + +""" + task_cancel!() + +Cancels the current [`DTask`](@ref). +""" +task_cancel!() = cancel!(get_tls().cancel_token) From d1c467b8fe08e9245bf4ee422ddd9c557f346eb1 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 2 Oct 2024 19:07:18 -0500 Subject: [PATCH 09/23] cancellation: Add graceful vs. 
forced --- src/cancellation.jl | 27 ++++++++++++++++++++------- src/task-tls.jl | 20 ++++++++++++-------- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/src/cancellation.jl b/src/cancellation.jl index dcb0f5add..0aa150331 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -1,16 +1,29 @@ # DTask-level cancellation -struct CancelToken - cancelled::Base.RefValue{Bool} +mutable struct CancelToken + @atomic cancelled::Bool + @atomic graceful::Bool event::Base.Event end -CancelToken() = CancelToken(Ref(false), Base.Event()) -function cancel!(token::CancelToken) - token.cancelled[] = true +CancelToken() = CancelToken(false, false, Base.Event()) +function cancel!(token::CancelToken; graceful::Bool=true) + if !graceful + @atomic token.graceful = false + end + @atomic token.cancelled = true notify(token.event) return end -is_cancelled(token::CancelToken) = token.cancelled[] +function is_cancelled(token::CancelToken; must_force::Bool=false) + if token.cancelled[] + if must_force && token.graceful[] + # If we're only responding to forced cancellation, ignore graceful cancellations + return false + end + return true + end + return false +end Base.wait(token::CancelToken) = wait(token.event) # TODO: Enable this for safety #Serialization.serialize(io::AbstractSerializer, ::CancelToken) = @@ -128,7 +141,7 @@ function _cancel!(state, tid, force, halt_sch) push!(istate.cancelled, tid) to_proc = istate.proc put!(istate.return_queue, (myid(), to_proc, tid, (InterruptException(), nothing))) - cancel!(istate.cancel_tokens[tid]) + cancel!(istate.cancel_tokens[tid]; graceful=false) end end end diff --git a/src/task-tls.jl b/src/task-tls.jl index f6889bbb1..5c7d0375b 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -52,26 +52,30 @@ task_processor() = get_tls().processor @deprecate(thunk_processor(), task_processor()) """ - task_cancelled() -> Bool + task_cancelled(; must_force::Bool=false) -> Bool Returns `true` if the current [`DTask`](@ref) has been cancelled, else `false`. +If `must_force=true`, then only return `true` if the cancellation was forced. """ -task_cancelled() = is_cancelled(get_tls().cancel_token) +task_cancelled(; must_force::Bool=false) = + is_cancelled(get_tls().cancel_token; must_force) """ - task_may_cancel!() + task_may_cancel!(; must_force::Bool=false) Throws an `InterruptException` if the current [`DTask`](@ref) has been cancelled. +If `must_force=true`, then only throw if the cancellation was forced. """ -function task_may_cancel!() - if task_cancelled() +function task_may_cancel!(;must_force::Bool=false) + if task_cancelled(;must_force) throw(InterruptException()) end end """ - task_cancel!() + task_cancel!(; graceful::Bool=true) -Cancels the current [`DTask`](@ref). +Cancels the current [`DTask`](@ref). If `graceful=true`, then the task will be +cancelled gracefully, otherwise it will be forced. 
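+
+A forced cancellation is observed at every cancellation point, while a graceful
+one is ignored by points that opt out with `must_force=true` (e.g. internal
+communication loops that still need to drain in-flight values).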
""" -task_cancel!() = cancel!(get_tls().cancel_token) +task_cancel!(; graceful::Bool=true) = cancel!(get_tls().cancel_token; graceful) From ef2ea1e8d77d5fba02fb16a47d314cd2246aaaf6 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 2 Oct 2024 19:07:56 -0500 Subject: [PATCH 10/23] cancellation: Wrap InterruptException in DTaskFailedException --- src/cancellation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cancellation.jl b/src/cancellation.jl index 0aa150331..86f562bc4 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -96,7 +96,7 @@ function _cancel!(state, tid, force, halt_sch) for task in state.ready tid !== nothing && task.id != tid && continue @dagdebug tid :cancel "Cancelling ready task" - state.cache[task] = InterruptException() + state.cache[task] = DTaskFailedException(task, task, InterruptException()) state.errored[task] = true Sch.set_failed!(state, task) end @@ -106,7 +106,7 @@ function _cancel!(state, tid, force, halt_sch) for task in keys(state.waiting) tid !== nothing && task.id != tid && continue @dagdebug tid :cancel "Cancelling waiting task" - state.cache[task] = InterruptException() + state.cache[task] = DTaskFailedException(task, task, InterruptException()) state.errored[task] = true Sch.set_failed!(state, task) end From 09fa4b7be5892e48068fb99173a05e726a845307 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sat, 14 Sep 2024 11:54:34 -0400 Subject: [PATCH 11/23] Sch: Add unwrap_nested_exception for DTaskFailedException --- src/sch/util.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sch/util.jl b/src/sch/util.jl index cd006838b..eb5a285b4 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -29,6 +29,8 @@ unwrap_nested_exception(err::CapturedException) = unwrap_nested_exception(err.ex) unwrap_nested_exception(err::RemoteException) = unwrap_nested_exception(err.captured) +unwrap_nested_exception(err::DTaskFailedException) = + unwrap_nested_exception(err.ex) unwrap_nested_exception(err) = err "Gets a `NamedTuple` of options propagated by `thunk`." From f593648874b1625c69f678e38fd9f471ff83be0b Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 12:22:52 -0500 Subject: [PATCH 12/23] Add task_id for DTask --- src/sch/eager.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/sch/eager.jl b/src/sch/eager.jl index 259c1b6f8..f3aca2ca0 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -134,3 +134,6 @@ function _find_thunk(e::Dagger.DTask) unwrap_weak_checked(EAGER_STATE[].thunk_dict[tid]) end end +Dagger.task_id(t::Dagger.DTask) = lock(EAGER_ID_MAP) do id_map + id_map[t.uid] +end From 711489be0581aa0ed354ec02c305bf407c165bc6 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 13:05:49 -0500 Subject: [PATCH 13/23] dagdebug: Always yield to avoid heisenbugs --- src/utils/dagdebug.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 9a9d24167..8b6d3530f 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -31,6 +31,10 @@ macro dagdebug(thunk, category, msg, args...) $debug_ex_noid end end + + # Always yield to reduce differing behavior for debug vs. 
non-debug + # TODO: Remove this eventually + yield() end end) end From 0620a024096e83700a2b02ba66602b451008f7b4 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 13:06:59 -0500 Subject: [PATCH 14/23] tests: Add offline mode --- test/runtests.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index fe92e41fb..aa5b34aec 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,7 +34,10 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ pushfirst!(LOAD_PATH, joinpath(@__DIR__, "..")) using Pkg Pkg.activate(@__DIR__) - Pkg.instantiate() + try + Pkg.instantiate() + catch + end using ArgParse s = ArgParseSettings(description = "Dagger Testsuite") @@ -54,6 +57,9 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ "-v", "--verbose" action = :store_true help = "Run the tests with debug logs from Dagger" + "-O", "--offline" + action = :store_true + help = "Set Pkg into offline mode" end end @@ -87,6 +93,11 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ if parsed_args["verbose"] ENV["JULIA_DEBUG"] = "Dagger" end + + if parsed_args["offline"] + Pkg.UPDATED_REGISTRY_THIS_SESSION[] = true + Pkg.offline(true) + end else to_test = all_test_names @info "Running all tests" From 4daf50eeafd2f17aa5c6b10ab54057d3bd63b2c2 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 2 Oct 2024 09:00:23 -0500 Subject: [PATCH 15/23] dagdebug: Add JULIA_DAGGER_DEBUG config variable --- src/Dagger.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/Dagger.jl b/src/Dagger.jl index 1ec68d071..2e1e7387f 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -147,6 +147,20 @@ function __init__() ThreadProc(myid(), tid) end end + + # Set up @dagdebug categories, if specified + try + if haskey(ENV, "JULIA_DAGGER_DEBUG") + empty!(DAGDEBUG_CATEGORIES) + for category in split(ENV["JULIA_DAGGER_DEBUG"], ",") + if category != "" + push!(DAGDEBUG_CATEGORIES, Symbol(category)) + end + end + end + catch err + @warn "Error parsing JULIA_DAGGER_DEBUG" exception=err + end end end # module From 1d796a20e9a4277788d7753bbb0757cf0190d74d Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 2 Oct 2024 19:08:31 -0500 Subject: [PATCH 16/23] options: Add internal helper to strip all options --- src/options.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/options.jl b/src/options.jl index 1c1e3ff29..00196dd59 100644 --- a/src/options.jl +++ b/src/options.jl @@ -20,6 +20,12 @@ function with_options(f, options::NamedTuple) end with_options(f; options...) 
= with_options(f, NamedTuple(options)) +function _without_options(f) + with(options_context => NamedTuple()) do + f() + end +end + """ get_options(key::Symbol, default) -> Any get_options(key::Symbol) -> Any From 73a24ad81b701c9eced77acbb25dd2ccc77b9802 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 21 Nov 2024 07:09:17 -0500 Subject: [PATCH 17/23] tests: Test DTaskFailedException inner type --- test/mutation.jl | 2 +- test/processors.jl | 4 ++-- test/scheduler.jl | 12 ++++++------ test/scopes.jl | 11 ++++++----- test/thunk.jl | 30 ++++++++++++++---------------- test/util.jl | 15 ++++++++++----- 6 files changed, 39 insertions(+), 35 deletions(-) diff --git a/test/mutation.jl b/test/mutation.jl index b6ac7143b..fa2f62bcf 100644 --- a/test/mutation.jl +++ b/test/mutation.jl @@ -48,7 +48,7 @@ end x = Dagger.@mutable worker=w Ref{Int}() @test fetch(Dagger.@spawn mutable_update!(x)) == w wo_scope = Dagger.ProcessScope(wo) - @test_throws_unwrap Dagger.DTaskFailedException fetch(Dagger.@spawn scope=wo_scope mutable_update!(x)) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) fetch(Dagger.@spawn scope=wo_scope mutable_update!(x)) end end # @testset "@mutable" diff --git a/test/processors.jl b/test/processors.jl index e97a1d239..6e56876dd 100644 --- a/test/processors.jl +++ b/test/processors.jl @@ -37,9 +37,9 @@ end end @testset "Processor exhaustion" begin opts = ThunkOptions(proclist=[OptOutProc]) - @test_throws_unwrap Dagger.DTaskFailedException ex isa Dagger.Sch.SchedulingException ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) opts = ThunkOptions(proclist=(proc)->false) - @test_throws_unwrap Dagger.DTaskFailedException ex isa Dagger.Sch.SchedulingException ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) opts = ThunkOptions(proclist=nothing) @test collect(delayed(sum; options=opts)([1,2,3])) == 6 end diff --git a/test/scheduler.jl b/test/scheduler.jl index b9fe01872..b12ad3e1e 100644 --- a/test/scheduler.jl +++ b/test/scheduler.jl @@ -182,7 +182,7 @@ end @testset "allow errors" begin opts = ThunkOptions(;allow_errors=true) a = delayed(error; options=opts)("Test") - @test_throws_unwrap Dagger.DTaskFailedException collect(a) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) collect(a) end end @@ -396,7 +396,7 @@ end ([Dagger.tochunk(MyStruct(1)), Dagger.tochunk(1)], sizeof(MyStruct)+sizeof(Int)), ] for arg in args - if arg isa Chunk + if arg isa Dagger.Chunk aff = Dagger.affinity(arg) @test aff[1] == OSProc(1) @test aff[2] == MemPool.approx_size(MemPool.poolget(arg.handle)) @@ -477,7 +477,7 @@ end @test res == 2 @testset "self as input" begin a = delayed(dynamic_add_thunk_self_dominated)(1) - @test_throws_unwrap Dagger.Sch.DynamicThunkException reason="Cannot fetch result of dominated thunk" collect(Context(), a) + @test_throws_unwrap (RemoteException, Dagger.Sch.DynamicThunkException) reason="Cannot fetch result of dominated thunk" collect(Context(), a) end end @testset "Fetch/Wait" begin @@ -487,11 +487,11 @@ end end @testset "self" begin a = 
delayed(dynamic_fetch_self)(1) - @test_throws_unwrap Dagger.Sch.DynamicThunkException reason="Cannot fetch own result" collect(Context(), a) + @test_throws_unwrap (RemoteException, Dagger.Sch.DynamicThunkException) reason="Cannot fetch own result" collect(Context(), a) end @testset "dominated" begin a = delayed(identity)(delayed(dynamic_fetch_dominated)(1)) - @test_throws_unwrap Dagger.Sch.DynamicThunkException reason="Cannot fetch result of dominated thunk" collect(Context(), a) + @test_throws_unwrap (RemoteException, Dagger.Sch.DynamicThunkException) reason="Cannot fetch result of dominated thunk" collect(Context(), a) end end end @@ -540,7 +540,7 @@ end t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) sleep(100) start_time = time_ns() Dagger.cancel!(t) - @test_throws_unwrap Dagger.DTaskFailedException fetch(t) + @test_throws_unwrap (Dagger.DTaskFailedException, InterruptException) fetch(t) t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) yield() fetch(t) finish_time = time_ns() diff --git a/test/scopes.jl b/test/scopes.jl index 5f82a71a0..065e5158f 100644 --- a/test/scopes.jl +++ b/test/scopes.jl @@ -1,3 +1,4 @@ +#@everywhere ENV["JULIA_DEBUG"] = "Dagger" @testset "Chunk Scopes" begin wid1, wid2 = addprocs(2, exeflags=["-t 2"]) @everywhere [wid1,wid2] using Dagger @@ -56,7 +57,7 @@ # Different nodes for (ch1, ch2) in [(ns1_ch, ns2_ch), (ns2_ch, ns1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Process Scope" begin @@ -75,7 +76,7 @@ # Different process for (ch1, ch2) in [(ps1_ch, ps2_ch), (ps2_ch, ps1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end # Same process and node @@ -83,7 +84,7 @@ # Different process and node for (ch1, ch2) in [(ps1_ch, ns2_ch), (ns2_ch, ps1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Exact Scope" begin @@ -104,14 +105,14 @@ # Different process, different processor for (ch1, ch2) in [(es1_ch, es2_ch), (es2_ch, es1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end # Same process, different processor es1_2 = ExactScope(Dagger.ThreadProc(wid1, 2)) es1_2_ch = Dagger.tochunk(nothing, OSProc(), es1_2) for (ch1, ch2) in [(es1_ch, es1_2_ch), (es1_2_ch, es1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap (Dagger.DTaskFailedException, Dagger.Sch.SchedulingException) reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Union Scope" begin diff --git a/test/thunk.jl b/test/thunk.jl index e6fb7e86b..73879545b 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -69,7 +69,7 @@ end A = rand(4, 
4) @test fetch(@spawn sum(A; dims=1)) ≈ sum(A; dims=1) - @test_throws_unwrap Dagger.DTaskFailedException fetch(@spawn sum(A; fakearg=2)) + @test_throws_unwrap (Dagger.DTaskFailedException, MethodError) fetch(@spawn sum(A; fakearg=2)) @test fetch(@spawn reduce(+, A; dims=1, init=2.0)) ≈ reduce(+, A; dims=1, init=2.0) @@ -194,7 +194,7 @@ end a = @spawn error("Test") wait(a) @test isready(a) - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(a) b = @spawn 1+2 @test fetch(b) == 3 end @@ -207,8 +207,7 @@ end catch err err end - ex = Dagger.Sch.unwrap_nested_exception(ex) - ex_str = sprint(io->Base.showerror(io,ex)) + ex_str = sprint(io->Base.showerror(io, ex)) @test occursin(r"^DTaskFailedException:", ex_str) @test occursin("Test", ex_str) @test !occursin("Root Task", ex_str) @@ -218,36 +217,35 @@ end catch err err end - ex = Dagger.Sch.unwrap_nested_exception(ex) - ex_str = sprint(io->Base.showerror(io,ex)) + ex_str = sprint(io->Base.showerror(io, ex)) @test occursin("Test", ex_str) @test occursin("Root Task", ex_str) end @testset "single dependent" begin a = @spawn error("Test") b = @spawn a+2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(a) end @testset "multi dependent" begin a = @spawn error("Test") b = @spawn a+2 c = @spawn a*2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(b) - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(b) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(c) end @testset "dependent chain" begin a = @spawn error("Test") - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(a) b = @spawn a+1 - @test_throws_unwrap Dagger.DTaskFailedException fetch(b) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(b) c = @spawn b+2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(c) end @testset "single input" begin a = @spawn 1+1 b = @spawn (a->error("Test"))(a) @test fetch(a) == 2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(b) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(b) end @testset "multi input" begin a = @spawn 1+1 @@ -255,7 +253,7 @@ end c = @spawn ((a,b)->error("Test"))(a,b) @test fetch(a) == 2 @test fetch(b) == 4 - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(c) end @testset "diamond" begin a = @spawn 1+1 @@ -265,7 +263,7 @@ end @test fetch(a) == 2 @test fetch(b) == 3 @test fetch(c) == 4 - @test_throws_unwrap Dagger.DTaskFailedException fetch(d) + @test_throws_unwrap (Dagger.DTaskFailedException, ErrorException) fetch(d) end end @testset "remote spawn" begin @@ -283,7 +281,7 @@ end t1 = Dagger.@spawn 1+"fail" Dagger.@spawn t1+1 end - @test_throws_unwrap Dagger.DTaskFailedException fetch(t2) + @test_throws_unwrap (Dagger.DTaskFailedException, MethodError) fetch(t2) end @testset "undefined function" begin # Issues #254, #255 diff --git a/test/util.jl b/test/util.jl index f01b3d95d..1131a9ebe 100644 --- a/test/util.jl +++ b/test/util.jl @@ -14,7 +14,7 @@ end replace_obj!(ex::Symbol, obj) = Expr(:(.), obj, QuoteNode(ex)) replace_obj!(ex, obj) = ex function _test_throws_unwrap(terr, ex; 
to_match=[]) - @gensym rerr + @gensym oerr rerr match_expr = Expr(:block) for m in to_match if m.head == :(=) @@ -35,12 +35,17 @@ function _test_throws_unwrap(terr, ex; to_match=[]) end end quote - $rerr = try - $(esc(ex)) + $oerr, $rerr = try + nothing, $(esc(ex)) catch err - Dagger.Sch.unwrap_nested_exception(err) + (err, Dagger.Sch.unwrap_nested_exception(err)) + end + if $terr isa Tuple + @test $oerr isa $terr[1] + @test $rerr isa $terr[2] + else + @test $rerr isa $terr end - @test $rerr isa $terr $match_expr end end From 8d0414496892ac5cd2d78ad6a51762ca812c4f57 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Mon, 25 Nov 2024 12:18:50 -0600 Subject: [PATCH 18/23] Sch: Skip not-yet-inited workers --- src/sch/Sch.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 3df38c182..9083f6282 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -678,6 +678,9 @@ function schedule!(ctx, state, procs=procs_to_use(ctx)) safepoint(state) @assert length(procs) > 0 + # Remove processors that aren't yet initialized + procs = filter(p -> haskey(state.worker_chans, Dagger.root_worker_id(p)), procs) + populate_processor_cache_list!(state, procs) # Schedule tasks From e772af00bd4785b569c265463de2ae23ff9dcd0e Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 12 Sep 2023 10:56:47 -0500 Subject: [PATCH 19/23] Add streaming API --- Project.toml | 1 - docs/make.jl | 1 + docs/src/streaming.md | 105 ++++++ src/Dagger.jl | 6 +- src/sch/Sch.jl | 8 +- src/sch/eager.jl | 7 + src/stream-buffers.jl | 98 ++++++ src/stream-transfer.jl | 128 ++++++++ src/stream.jl | 715 +++++++++++++++++++++++++++++++++++++++++ src/submission.jl | 2 +- src/utils/dagdebug.jl | 3 +- test/runtests.jl | 2 +- test/streaming.jl | 426 ++++++++++++++++++++++++ 13 files changed, 1494 insertions(+), 8 deletions(-) create mode 100644 docs/src/streaming.md create mode 100644 src/stream-buffers.jl create mode 100644 src/stream-transfer.jl create mode 100644 src/stream.jl create mode 100644 test/streaming.jl diff --git a/Project.toml b/Project.toml index 439d89081..fd7508cd7 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,6 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" -Mmap = "a63ad114-7e13-5084-954f-fe012c677804" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" diff --git a/docs/make.jl b/docs/make.jl index c21c03f2d..8f1f97f5c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -22,6 +22,7 @@ makedocs(; "Task Spawning" => "task-spawning.md", "Data Management" => "data-management.md", "Distributed Arrays" => "darray.md", + "Streaming Tasks" => "streaming.md", "Scopes" => "scopes.md", "Processors" => "processors.md", "Task Queues" => "task-queues.md", diff --git a/docs/src/streaming.md b/docs/src/streaming.md new file mode 100644 index 000000000..338f5739b --- /dev/null +++ b/docs/src/streaming.md @@ -0,0 +1,105 @@ +# Streaming Tasks + +Dagger tasks have a limited lifetime - they are created, execute, finish, and +are eventually destroyed when they're no longer needed. Thus, if one wants +to run the same kind of computations over and over, one might re-create a +similar set of tasks for each unit of data that needs processing. 
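+
+For example, a non-streaming version of such a workload might look like this
+(a sketch; `process` stands in for some user-provided function):
+
+```julia
+for x in dataset
+    # One short-lived task per datum, paying the task-creation cost every time
+    fetch(Dagger.@spawn process(x))
+end
+```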
+ +This might be fine for computations which take a long time to run (thus +dwarfing the cost of task creation, which is quite small), or when working with +a limited set of data, but this approach is not great for doing lots of small +computations on a large (or endless) amount of data. For example, processing +image frames from a webcam, reacting to messages from a message bus, reading +samples from a software radio, etc. All of these tasks are better suited to a +"streaming" model of data processing, where data is simply piped into a +continuously-running task (or DAG of tasks) forever, or until the data runs +out. + +Thankfully, if you have a problem which is best modeled as a streaming system +of tasks, Dagger has you covered! Building on its support for +[Task Queues](@ref), Dagger provides a means to convert an entire DAG of +tasks into a streaming DAG, where data flows into and out of each task +asynchronously, using the `spawn_streaming` function: + +```julia +Dagger.spawn_streaming() do # enters a streaming region + vals = Dagger.@spawn rand() + print_vals = Dagger.@spawn println(vals) +end # exits the streaming region, and starts the DAG running +``` + +In the above example, `vals` is a Dagger task which has been transformed to run +in a streaming manner - instead of just calling `rand()` once and returning its +result, it will re-run `rand()` endlessly, continuously producing new random +values. In typical Dagger style, `print_vals` is a Dagger task which depends on +`vals`, but in streaming form - it will continuously `println` the random +values produced from `vals`. Both tasks will run forever, and will run +efficiently, only doing the work necessary to generate, transfer, and consume +values. + +As the comments point out, `spawn_streaming` creates a streaming region, during +which `vals` and `print_vals` are created and configured. Both tasks are halted +until `spawn_streaming` returns, allowing large DAGs to be built all at once, +without any task losing a single value. If desired, streaming regions can be +connected, although some values might be lost while tasks are being connected: + +```julia +vals = Dagger.spawn_streaming() do + Dagger.@spawn rand() +end + +# Some values might be generated by `vals` but thrown away +# before `print_vals` is fully setup and connected to it + +print_vals = Dagger.spawn_streaming() do + Dagger.@spawn println(vals) +end +``` + +More complicated streaming DAGs can be easily constructed, without doing +anything different. For example, we can generate multiple streams of random +numbers, write them all to their own files, and print the combined results: + +```julia +Dagger.spawn_streaming() do + all_vals = [Dagger.spawn(rand) for i in 1:4] + all_vals_written = map(1:4) do i + Dagger.spawn(all_vals[i]) do val + open("results_$i.txt"; write=true, create=true, append=true) do io + println(io, repr(val)) + end + return val + end + end + Dagger.spawn(all_vals_written...) do all_vals_written... + vals_sum = sum(all_vals_written) + println(vals_sum) + end +end +``` + +If you want to stop the streaming DAG and tear it all down, you can call +`Dagger.kill!(all_vals[1])` (or `Dagger.kill!(all_vals_written[2])`, etc., the +kill propagates throughout the DAG). + +Alternatively, tasks can stop themselves from the inside with +`finish_streaming`, optionally returning a value that can be `fetch`'d. 
Let's +do this when our randomly-drawn number falls within some arbitrary range: + +```julia +vals = Dagger.spawn_streaming() do + Dagger.spawn() do + x = rand() + if x < 0.001 + # That's good enough, let's be done + return Dagger.finish_streaming("Finished!") + end + return x + end +end +fetch(vals) +``` + +In this example, the call to `fetch` will hang (while random numbers continue +to be drawn), until a drawn number is less than 0.001; at that point, `fetch` +will return with "Finished!", and the task `vals` will have terminated. diff --git a/src/Dagger.jl b/src/Dagger.jl index 2e1e7387f..4ce475871 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -23,7 +23,6 @@ if !isdefined(Base, :ScopedValues) else import Base.ScopedValues: ScopedValue, with end - import TaskLocalValues: TaskLocalValue if !isdefined(Base, :get_extension) @@ -69,6 +68,11 @@ include("sch/Sch.jl"); using .Sch # Data dependency task queue include("datadeps.jl") +# Streaming +include("stream.jl") +include("stream-buffers.jl") +include("stream-transfer.jl") + # Array computations include("array/darray.jl") include("array/alloc.jl") diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 9083f6282..f42ed634e 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -253,9 +253,11 @@ end Combine `SchedulerOptions` and `ThunkOptions` into a new `ThunkOptions`. """ function Base.merge(sopts::SchedulerOptions, topts::ThunkOptions) - single = topts.single !== nothing ? topts.single : sopts.single - allow_errors = topts.allow_errors !== nothing ? topts.allow_errors : sopts.allow_errors - proclist = topts.proclist !== nothing ? topts.proclist : sopts.proclist + select_option = (sopt, topt) -> isnothing(topt) ? sopt : topt + + single = select_option(sopts.single, topts.single) + allow_errors = select_option(sopts.allow_errors, topts.allow_errors) + proclist = select_option(sopts.proclist, topts.proclist) ThunkOptions(single, proclist, topts.time_util, diff --git a/src/sch/eager.jl b/src/sch/eager.jl index f3aca2ca0..aea0abbf6 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -124,6 +124,13 @@ function eager_cleanup(state, uid) # N.B. cache and errored expire automatically delete!(state.thunk_dict, tid) end + remotecall_wait(1, uid) do uid + lock(Dagger.EAGER_THUNK_STREAMS) do global_streams + if haskey(global_streams, uid) + delete!(global_streams, uid) + end + end + end end function _find_thunk(e::Dagger.DTask) diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl new file mode 100644 index 000000000..579a59753 --- /dev/null +++ b/src/stream-buffers.jl @@ -0,0 +1,98 @@ +""" +A buffer that drops all elements put into it. +""" +mutable struct DropBuffer{T} + open::Bool + DropBuffer{T}() where T = new{T}(true) +end +DropBuffer{T}(_) where T = DropBuffer{T}() +Base.isempty(::DropBuffer) = true +isfull(::DropBuffer) = false +capacity(::DropBuffer) = typemax(Int) +Base.length(::DropBuffer) = 0 +Base.isopen(buf::DropBuffer) = buf.open +function Base.close(buf::DropBuffer) + buf.open = false +end +function Base.put!(buf::DropBuffer, _) + if !isopen(buf) + throw(InvalidStateException("DropBuffer is closed", :closed)) + end + task_may_cancel!(; must_force=true) + yield() + return +end +function Base.take!(buf::DropBuffer) + while true + if !isopen(buf) + throw(InvalidStateException("DropBuffer is closed", :closed)) + end + task_may_cancel!(; must_force=true) + yield() + end +end + +"A process-local ring buffer." 
+mutable struct ProcessRingBuffer{T} + read_idx::Int + write_idx::Int + @atomic count::Int + buffer::Vector{T} + @atomic open::Bool + function ProcessRingBuffer{T}(len::Int=1024) where T + buffer = Vector{T}(undef, len) + return new{T}(1, 1, 0, buffer, true) + end +end +Base.isempty(rb::ProcessRingBuffer) = (@atomic rb.count) == 0 +isfull(rb::ProcessRingBuffer) = (@atomic rb.count) == length(rb.buffer) +capacity(rb::ProcessRingBuffer) = length(rb.buffer) +Base.length(rb::ProcessRingBuffer) = @atomic rb.count +Base.isopen(rb::ProcessRingBuffer) = @atomic rb.open +function Base.close(rb::ProcessRingBuffer) + @atomic rb.open = false +end +function Base.put!(rb::ProcessRingBuffer{T}, x) where T + while isfull(rb) + yield() + if !isopen(rb) + throw(InvalidStateException("ProcessRingBuffer is closed", :closed)) + end + task_may_cancel!(; must_force=true) + end + to_write_idx = mod1(rb.write_idx, length(rb.buffer)) + rb.buffer[to_write_idx] = convert(T, x) + rb.write_idx += 1 + @atomic rb.count += 1 +end +function Base.take!(rb::ProcessRingBuffer) + while isempty(rb) + yield() + if !isopen(rb) && isempty(rb) + throw(InvalidStateException("ProcessRingBuffer is closed", :closed)) + end + if task_cancelled() && isempty(rb) + # We respect a graceful cancellation only if the buffer is empty. + # Otherwise, we may have values to continue communicating. + task_may_cancel!() + end + task_may_cancel!(; must_force=true) + end + to_read_idx = rb.read_idx + rb.read_idx += 1 + @atomic rb.count -= 1 + to_read_idx = mod1(to_read_idx, length(rb.buffer)) + return rb.buffer[to_read_idx] +end + +""" +`take!()` all the elements from a buffer and put them in a `Vector`. +""" +function collect!(rb::ProcessRingBuffer{T}) where T + output = Vector{T}(undef, rb.count) + for i in 1:rb.count + output[i] = take!(rb) + end + + return output +end diff --git a/src/stream-transfer.jl b/src/stream-transfer.jl new file mode 100644 index 000000000..667808762 --- /dev/null +++ b/src/stream-transfer.jl @@ -0,0 +1,128 @@ +struct RemoteChannelFetcher + chan::RemoteChannel + RemoteChannelFetcher() = new(RemoteChannel()) +end +const _THEIR_TID = TaskLocalValue{Int}(()->0) +function stream_push_values!(fetcher::RemoteChannelFetcher, T, our_store::StreamStore, their_stream::Stream, buffer) + our_tid = STREAM_THUNK_ID[] + our_uid = our_store.uid + their_uid = their_stream.uid + if _THEIR_TID[] == 0 + _THEIR_TID[] = remotecall_fetch(1) do + lock(Sch.EAGER_ID_MAP) do id_map + id_map[their_uid] + end + end + end + their_tid = _THEIR_TID[] + @dagdebug our_tid :stream_push "taking output value: $our_tid -> $their_tid" + value = try + take!(buffer) + catch + close(fetcher.chan) + rethrow() + end + @lock our_store.lock notify(our_store.lock) + @dagdebug our_tid :stream_push "pushing output value: $our_tid -> $their_tid" + try + put!(fetcher.chan, value) + catch err + if err isa InvalidStateException && !isopen(fetcher.chan) + @dagdebug our_tid :stream_push "channel closed: $our_tid -> $their_tid" + throw(InterruptException()) + end + rethrow(err) + end + @dagdebug our_tid :stream_push "finished pushing output value: $our_tid -> $their_tid" +end +function stream_pull_values!(fetcher::RemoteChannelFetcher, T, our_store::StreamStore, their_stream::Stream, buffer) + our_tid = STREAM_THUNK_ID[] + our_uid = our_store.uid + their_uid = their_stream.uid + if _THEIR_TID[] == 0 + _THEIR_TID[] = remotecall_fetch(1) do + lock(Sch.EAGER_ID_MAP) do id_map + id_map[their_uid] + end + end + end + their_tid = _THEIR_TID[] + @dagdebug our_tid :stream_pull "pulling 
input value: $their_tid -> $our_tid" + value = try + take!(fetcher.chan) + catch err + if err isa InvalidStateException && !isopen(fetcher.chan) + @dagdebug our_tid :stream_pull "channel closed: $their_tid -> $our_tid" + throw(InterruptException()) + end + rethrow(err) + end + @dagdebug our_tid :stream_pull "putting input value: $their_tid -> $our_tid" + try + put!(buffer, value) + catch + close(fetcher.chan) + rethrow() + end + @lock our_store.lock notify(our_store.lock) + @dagdebug our_tid :stream_pull "finished putting input value: $their_tid -> $our_tid" +end + +#= TODO: Remove me +# This is a bad implementation because it wants to sleep on the remote side to +# wait for values, but this isn't semantically valid when done with MemPool.access_ref +struct RemoteFetcher end +function stream_push_values!(::Type{RemoteFetcher}, T, our_store::StreamStore, their_stream::Stream, buffer) + sleep(1) +end +function stream_pull_values!(::Type{RemoteFetcher}, T, our_store::StreamStore, their_stream::Stream, buffer) + id = our_store.uid + thunk_id = STREAM_THUNK_ID[] + @dagdebug thunk_id :stream "fetching values" + + free_space = capacity(buffer) - length(buffer) + if free_space == 0 + @dagdebug thunk_id :stream "waiting for drain of full input buffer" + yield() + task_may_cancel!() + wait_for_nonfull_input(our_store, their_stream.uid) + return + end + + values = T[] + while isempty(values) + values, closed = MemPool.access_ref(their_stream.store_ref.handle, id, T, thunk_id, free_space) do their_store, id, T, thunk_id, free_space + @dagdebug thunk_id :stream "trying to fetch values at worker $(myid())" + STREAM_THUNK_ID[] = thunk_id + values = T[] + @dagdebug thunk_id :stream "trying to fetch with free_space: $free_space" + wait_for_nonempty_output(their_store, id) + if isempty(their_store, id) && !isopen(their_store, id) + @dagdebug thunk_id :stream "remote stream is closed, returning" + return values, true + end + while !isempty(their_store, id) && length(values) < free_space + value = take!(their_store, id)::T + @dagdebug thunk_id :stream "fetched $value" + push!(values, value) + end + return values, false + end::Tuple{Vector{T},Bool} + if closed + throw(InterruptException()) + end + + # We explicitly yield in the loop to allow other tasks to run. This + # matters on single-threaded instances because MemPool.access_ref() + # might not yield when accessing data locally, which can cause this loop + # to spin forever. 
+ yield() + task_may_cancel!() + end + + @dagdebug thunk_id :stream "fetched $(length(values)) values" + for value in values + put!(buffer, value) + end +end +=# diff --git a/src/stream.jl b/src/stream.jl new file mode 100644 index 000000000..f3c2cfbcb --- /dev/null +++ b/src/stream.jl @@ -0,0 +1,715 @@ +mutable struct StreamStore{T,B} + uid::UInt + waiters::Vector{Int} + input_streams::Dict{UInt,Any} # FIXME: Concrete type + output_streams::Dict{UInt,Any} # FIXME: Concrete type + input_buffers::Dict{UInt,B} + output_buffers::Dict{UInt,B} + input_buffer_amount::Int + output_buffer_amount::Int + input_fetchers::Dict{UInt,Any} + output_fetchers::Dict{UInt,Any} + open::Bool + migrating::Bool + lock::Threads.Condition + StreamStore{T,B}(uid::UInt, input_buffer_amount::Integer, output_buffer_amount::Integer) where {T,B} = + new{T,B}(uid, zeros(Int, 0), + Dict{UInt,Any}(), Dict{UInt,Any}(), + Dict{UInt,B}(), Dict{UInt,B}(), + input_buffer_amount, output_buffer_amount, + Dict{UInt,Any}(), Dict{UInt,Any}(), + true, false, Threads.Condition()) +end + +function tid_to_uid(thunk_id) + lock(Sch.EAGER_ID_MAP) do id_map + for (uid, otid) in id_map + if thunk_id == otid + return uid + end + end + end +end + +function Base.put!(store::StreamStore{T,B}, value) where {T,B} + thunk_id = STREAM_THUNK_ID[] + @lock store.lock begin + if !isopen(store) + @dagdebug thunk_id :stream "closed!" + throw(InvalidStateException("Stream is closed", :closed)) + end + @dagdebug thunk_id :stream "adding $value ($(length(store.output_streams)) outputs)" + for output_uid in keys(store.output_streams) + if !haskey(store.output_buffers, output_uid) + initialize_output_stream!(store, output_uid) + end + buffer = store.output_buffers[output_uid] + while isfull(buffer) + if !isopen(store) + @dagdebug thunk_id :stream "closed!" + throw(InvalidStateException("Stream is closed", :closed)) + end + @dagdebug thunk_id :stream "buffer full ($(length(buffer)) values), waiting" + wait(store.lock) + if !isfull(buffer) + @dagdebug thunk_id :stream "buffer has space ($(length(buffer)) values), continuing" + end + task_may_cancel!() + end + put!(buffer, value) + end + notify(store.lock) + end +end + +function Base.take!(store::StreamStore, id::UInt) + thunk_id = STREAM_THUNK_ID[] + @lock store.lock begin + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + error("Must first check isempty(store, id) before taking from a stream") + end + buffer = store.output_buffers[id] + while isempty(buffer) && isopen(store, id) + @dagdebug thunk_id :stream "no elements, not taking" + wait(store.lock) + task_may_cancel!() + end + @dagdebug thunk_id :stream "wait finished" + if !isopen(store, id) + @dagdebug thunk_id :stream "closed!" 
+ throw(InvalidStateException("Stream is closed", :closed)) + end + unlock(store.lock) + value = try + take!(buffer) + finally + lock(store.lock) + end + @dagdebug thunk_id :stream "value accepted" + notify(store.lock) + return value + end +end +function wait_for_nonfull_input(store::StreamStore, id::UInt) + @lock store.lock begin + @assert haskey(store.input_streams, id) + @assert haskey(store.input_buffers, id) + buffer = store.input_buffers[id] + while isfull(buffer) && isopen(store) + @dagdebug STREAM_THUNK_ID[] :stream "waiting for space in input buffer" + wait(store.lock) + end + end +end +function wait_for_nonempty_output(store::StreamStore, id::UInt) + @lock store.lock begin + @assert haskey(store.output_streams, id) + + # Wait for the output buffer to be initialized + while !haskey(store.output_buffers, id) && isopen(store, id) + @dagdebug STREAM_THUNK_ID[] :stream "waiting for output buffer to be initialized" + wait(store.lock) + end + isopen(store, id) || return + + # Wait for the output buffer to be nonempty + buffer = store.output_buffers[id] + while isempty(buffer) && isopen(store, id) + @dagdebug STREAM_THUNK_ID[] :stream "waiting for output buffer to be nonempty" + wait(store.lock) + end + end +end + +function Base.isempty(store::StreamStore, id::UInt) + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + return true + end + return isempty(store.output_buffers[id]) +end +isfull(store::StreamStore, id::UInt) = isfull(store.output_buffers[id]) + +"Returns whether the store is actively open. Only check this when deciding if new values can be pushed." +Base.isopen(store::StreamStore) = store.open + +""" +Returns whether the store is actively open, or if closing, still has remaining +messages for `id`. Only check this when deciding if existing values can be +taken. 
+""" +function Base.isopen(store::StreamStore, id::UInt) + @lock store.lock begin + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + return store.open + end + if !isempty(store.output_buffers[id]) + return true + end + return store.open + end +end + +function Base.close(store::StreamStore) + store.open || return + store.open = false + @lock store.lock begin + for buffer in values(store.input_buffers) + close(buffer) + end + for buffer in values(store.output_buffers) + close(buffer) + end + notify(store.lock) + end +end + +# FIXME: Just pass Stream directly, rather than its uid +function add_waiters!(store::StreamStore{T,B}, waiters::Vector{Pair{UInt,Any}}) where {T,B} + our_uid = store.uid + @lock store.lock begin + for (output_uid, output_fetcher) in waiters + store.output_streams[output_uid] = task_to_stream(output_uid) + push!(store.waiters, output_uid) + store.output_fetchers[output_uid] = output_fetcher + end + notify(store.lock) + end +end + +function remove_waiters!(store::StreamStore, waiters::Vector{UInt}) + @lock store.lock begin + for w in waiters + delete!(store.output_buffers, w) + idx = findfirst(wo->wo==w, store.waiters) + deleteat!(store.waiters, idx) + delete!(store.input_streams, w) + end + notify(store.lock) + end +end + +mutable struct Stream{T,B} + uid::UInt + store::Union{StreamStore{T,B},Nothing} + store_ref::Chunk + function Stream{T,B}(uid::UInt, input_buffer_amount::Integer, output_buffer_amount::Integer) where {T,B} + # Creates a new output stream + store = StreamStore{T,B}(uid, input_buffer_amount, output_buffer_amount) + store_ref = tochunk(store) + return new{T,B}(uid, store, store_ref) + end + function Stream(stream::Stream{T,B}) where {T,B} + # References an existing output stream + return new{T,B}(stream.uid, nothing, stream.store_ref) + end +end + +struct StreamCancelledException <: Exception end +struct StreamingValue{B} + buffer::B +end +Base.take!(sv::StreamingValue) = take!(sv.buffer) + +function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::Stream{IT,IB}) where {IT,OT,IB,OB} + input_uid = input_stream.uid + our_uid = our_store.uid + local buffer, input_fetcher + @lock our_store.lock begin + if haskey(our_store.input_buffers, input_uid) + return StreamingValue(our_store.input_buffers[input_uid]) + end + + buffer = initialize_stream_buffer(OB, IT, our_store.input_buffer_amount) + # FIXME: Also pass a RemoteChannel to track remote closure + our_store.input_buffers[input_uid] = buffer + input_fetcher = our_store.input_fetchers[input_uid] + end + thunk_id = STREAM_THUNK_ID[] + tls = get_tls() + Sch.errormonitor_tracked("streaming input: $input_uid -> $our_uid", Threads.@spawn begin + set_tls!(tls) + STREAM_THUNK_ID[] = thunk_id + try + while isopen(our_store) + stream_pull_values!(input_fetcher, IT, our_store, input_stream, buffer) + end + catch err + if err isa InterruptException || (err isa InvalidStateException && !isopen(buffer)) + return + else + rethrow(err) + end + finally + @dagdebug STREAM_THUNK_ID[] :stream "input stream closed" + end + end) + return StreamingValue(buffer) +end +initialize_input_stream!(our_store::StreamStore, arg) = arg +function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt) where {T,B} + @assert islocked(our_store.lock) + @dagdebug STREAM_THUNK_ID[] :stream "initializing output stream $output_uid" + buffer = initialize_stream_buffer(B, T, our_store.output_buffer_amount) + our_store.output_buffers[output_uid] = buffer + our_uid = our_store.uid + 
output_stream = our_store.output_streams[output_uid]
+    output_fetcher = our_store.output_fetchers[output_uid]
+    thunk_id = STREAM_THUNK_ID[]
+    tls = get_tls()
+    Sch.errormonitor_tracked("streaming output: $our_uid -> $output_uid", Threads.@spawn begin
+        set_tls!(tls)
+        STREAM_THUNK_ID[] = thunk_id
+        try
+            while true
+                if !isopen(our_store) && isempty(buffer)
+                    # Only exit if the buffer is empty; otherwise, we need to
+                    # continue draining it
+                    break
+                end
+                stream_push_values!(output_fetcher, T, our_store, output_stream, buffer)
+            end
+        catch err
+            if err isa InterruptException || (err isa InvalidStateException && !isopen(buffer))
+                return
+            else
+                rethrow(err)
+            end
+        finally
+            @dagdebug thunk_id :stream "output stream closed"
+        end
+    end)
+end
+
+Base.put!(stream::Stream, @nospecialize(value)) = put!(stream.store, value)
+
+function Base.isopen(stream::Stream, id::UInt)::Bool
+    return MemPool.access_ref(stream.store_ref.handle, id) do store, id
+        return isopen(store::StreamStore, id)
+    end
+end
+
+function Base.close(stream::Stream)
+    MemPool.access_ref(stream.store_ref.handle) do store
+        close(store::StreamStore)
+        return
+    end
+    return
+end
+
+function add_waiters!(stream::Stream, waiters::Vector{Pair{UInt,Any}})
+    MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters
+        add_waiters!(store::StreamStore, waiters)
+        return
+    end
+    return
+end
+
+add_waiters!(stream::Stream, waiter::Integer) = add_waiters!(stream, UInt[waiter])
+
+function remove_waiters!(stream::Stream, waiters::Vector{UInt})
+    MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters
+        remove_waiters!(store::StreamStore, waiters)
+        return
+    end
+    return
+end
+
+remove_waiters!(stream::Stream, waiter::Integer) = remove_waiters!(stream, UInt[waiter])
+
+struct StreamingFunction{F, S}
+    f::F
+    stream::S
+    max_evals::Int
+
+    StreamingFunction(f::F, stream::S, max_evals) where {F, S} =
+        new{F, S}(f, stream, max_evals)
+end
+
+function migrate_stream!(stream::Stream, w::Integer=myid())
+    # Perform migration of the StreamStore
+    # MemPool will block access to the new ref until the migration completes
+    # FIXME: Do this with MemPool.access_ref, in case stream was already migrated
+    if stream.store_ref.handle.owner != w
+        thunk_id = STREAM_THUNK_ID[]
+        @dagdebug thunk_id :stream "Beginning migration... ($(length(stream.store.input_streams)) -> $(length(stream.store.output_streams)))"
+
+        # TODO: Wire up listener to ferry cancel_token notifications to remote worker
+        tls = get_tls()
+        @assert w == myid() "Only pull-based migration is currently supported"
+        #remote_cancel_token = clone_cancel_token_remote(get_tls().cancel_token, worker_id)
+
+        new_store_ref = MemPool.migrate!(stream.store_ref.handle, w;
+                                         pre_migration=store->begin
+                                             # Lock store to prevent any further modifications
+                                             # N.B. Serialization automatically unlocks the migrated copy
+                                             lock((store::StreamStore).lock)
+
+                                             # Return the serializable unsent inputs/outputs. We can't send the
+                                             # buffers themselves because they may be mmap'ed or something.
+                                             unsent_inputs = Dict(uid => collect!(buffer) for (uid, buffer) in store.input_buffers)
+                                             unsent_outputs = Dict(uid => collect!(buffer) for (uid, buffer) in store.output_buffers)
+                                             empty!(store.input_buffers)
+                                             empty!(store.output_buffers)
+                                             return (unsent_inputs, unsent_outputs)
+                                         end,
+                                         dest_post_migration=(store, unsent)->begin
+                                             # Initialize the StreamStore on the destination with the unsent inputs/outputs.
+ STREAM_THUNK_ID[] = thunk_id + @assert !in_task() + set_tls!(tls) + #get_tls().cancel_token = MemPool.access_ref(identity, remote_cancel_token; local_only=true) + unsent_inputs, unsent_outputs = unsent + for (input_uid, inputs) in unsent_inputs + input_stream = store.input_streams[input_uid] + initialize_input_stream!(store, input_stream) + for item in inputs + put!(store.input_buffers[input_uid], item) + end + end + for (output_uid, outputs) in unsent_outputs + initialize_output_stream!(store, output_uid) + for item in outputs + put!(store.output_buffers[output_uid], item) + end + end + + # Reset the state of this new store + store.open = true + store.migrating = false + end, + post_migration=store->begin + # Indicate that this store has migrated + store.migrating = true + store.open = false + + # Unlock the store + unlock((store::StreamStore).lock) + end) + if w == myid() + stream.store_ref.handle = new_store_ref # FIXME: It's not valid to mutate the Chunk handle, but we want to update this to enable fast location queries + stream.store = MemPool.access_ref(identity, new_store_ref; local_only=true) + end + + @dagdebug thunk_id :stream "Migration complete ($(length(stream.store.input_streams)) -> $(length(stream.store.output_streams)))" + end +end + +struct StreamingTaskQueue <: AbstractTaskQueue + tasks::Vector{Pair{DTaskSpec,DTask}} + self_streams::Dict{UInt,Any} + StreamingTaskQueue() = new(Pair{DTaskSpec,DTask}[], + Dict{UInt,Any}()) +end + +function enqueue!(queue::StreamingTaskQueue, spec::Pair{DTaskSpec,DTask}) + push!(queue.tasks, spec) + initialize_streaming!(queue.self_streams, spec...) +end + +function enqueue!(queue::StreamingTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) + append!(queue.tasks, specs) + for (spec, task) in specs + initialize_streaming!(queue.self_streams, spec, task) + end +end + +function initialize_streaming!(self_streams, spec, task) + if !isa(spec.f, StreamingFunction) + # Calculate the return type of the called function + T_old = Base.uniontypes(task.metadata.return_type) + T_old = map(t->(t !== Union{} && t <: FinishStream) ? first(t.parameters) : t, T_old) + # N.B. We treat non-dominating error paths as unreachable + T_old = filter(t->t !== Union{}, T_old) + T = task.metadata.return_type = !isempty(T_old) ? 
Union{T_old...} : Any + + # Get input buffer configuration + input_buffer_amount = get(spec.options, :stream_input_buffer_amount, 1) + if input_buffer_amount <= 0 + throw(ArgumentError("Input buffering is required; please specify a `stream_input_buffer_amount` greater than 0")) + end + + # Get output buffer configuration + output_buffer_amount = get(spec.options, :stream_output_buffer_amount, 1) + if output_buffer_amount <= 0 + throw(ArgumentError("Output buffering is required; please specify a `stream_output_buffer_amount` greater than 0")) + end + + # Create the Stream + buffer_type = get(spec.options, :stream_buffer_type, ProcessRingBuffer) + stream = Stream{T,buffer_type}(task.uid, input_buffer_amount, output_buffer_amount) + self_streams[task.uid] = stream + + # Get max evaluation count + max_evals = get(spec.options, :stream_max_evals, -1) + if max_evals == 0 + throw(ArgumentError("stream_max_evals cannot be 0")) + end + + spec.f = StreamingFunction(spec.f, stream, max_evals) + spec.options = merge(spec.options, (;occupancy=Dict(Any=>0))) + + # Register Stream globally + remotecall_wait(1, task.uid, stream) do uid, stream + lock(EAGER_THUNK_STREAMS) do global_streams + global_streams[uid] = stream + end + end + end +end + +function spawn_streaming(f::Base.Callable) + queue = StreamingTaskQueue() + result = with_options(f; task_queue=queue) + if length(queue.tasks) > 0 + finalize_streaming!(queue.tasks, queue.self_streams) + enqueue!(queue.tasks) + end + return result +end + +struct FinishStream{T,R} + value::Union{Some{T},Nothing} + result::R +end + +finish_stream(value::T; result::R=nothing) where {T,R} = FinishStream{T,R}(Some{T}(value), result) + +finish_stream(; result::R=nothing) where R = FinishStream{Union{},R}(nothing, result) + +const STREAM_THUNK_ID = TaskLocalValue{Int}(()->0) + +chunktype(sf::StreamingFunction{F}) where F = F + +struct StreamMigrating end + +function (sf::StreamingFunction)(args...; kwargs...) + thunk_id = Sch.sch_handle().thunk_id.id + STREAM_THUNK_ID[] = thunk_id + + # Migrate our output stream store to this worker + if sf.stream isa Stream + remote_cancel_token = migrate_stream!(sf.stream) + end + + @label start + @dagdebug thunk_id :stream "Starting StreamingFunction" + worker_id = sf.stream.store_ref.handle.owner + result = if worker_id == myid() + _run_streamingfunction(nothing, nothing, sf, args...; kwargs...) + else + tls = get_tls() + remotecall_fetch(_run_streamingfunction, worker_id, tls, remote_cancel_token, sf, args...; kwargs...) + end + if result === StreamMigrating() + @goto start + end + return result +end + +function _run_streamingfunction(tls, cancel_token, sf, args...; kwargs...) 
+ @nospecialize sf args kwargs + + store = sf.stream.store = MemPool.access_ref(identity, sf.stream.store_ref.handle; local_only=true) + @assert isopen(store) + + if tls !== nothing + # Setup TLS on this new task + tls.cancel_token = MemPool.access_ref(identity, cancel_token; local_only=true) + set_tls!(tls) + end + + thunk_id = Sch.sch_handle().thunk_id.id + STREAM_THUNK_ID[] = thunk_id + + # FIXME: Remove when scheduler is distributed + uid = remotecall_fetch(1, thunk_id) do thunk_id + lock(Sch.EAGER_ID_MAP) do id_map + for (uid, otid) in id_map + if thunk_id == otid + return uid + end + end + end + end + + try + # TODO: This kwarg song-and-dance is required to ensure that we don't + # allocate boxes within `stream!`, when possible + kwarg_names = map(name->Val{name}(), map(first, (kwargs...,))) + kwarg_values = map(last, (kwargs...,)) + args = map(arg->initialize_input_stream!(store, arg), args) + kwarg_values = map(kwarg->initialize_input_stream!(store, kwarg), kwarg_values) + return stream!(sf, uid, (args...,), kwarg_names, kwarg_values) + finally + if !sf.stream.store.migrating + # Remove ourself as a waiter for upstream Streams + streams = Set{Stream}() + for (idx, arg) in enumerate(args) + if arg isa Stream + push!(streams, arg) + end + end + for (idx, (pos, arg)) in enumerate(kwargs) + if arg isa Stream + push!(streams, arg) + end + end + for stream in streams + @dagdebug thunk_id :stream "dropping waiter" + remove_waiters!(stream, uid) + @dagdebug thunk_id :stream "dropped waiter" + end + + # Ensure downstream tasks also terminate + close(sf.stream) + @dagdebug thunk_id :stream "closed stream store" + end + end +end + +# N.B We specialize to minimize/eliminate allocations +function stream!(sf::StreamingFunction, uid, + args::Tuple, kwarg_names::Tuple, kwarg_values::Tuple) + f = move(thunk_processor(), sf.f) + counter = 0 + + while true + # Yield to other (streaming) tasks + yield() + + # Exit streaming on cancellation + task_may_cancel!() + + # Exit streaming on migration + if sf.stream.store.migrating + error("FIXME: max_evals should be retained") + @dagdebug STREAM_THUNK_ID[] :stream "returning for migration" + return StreamMigrating() + end + + # Get values from Stream args/kwargs + stream_args = _stream_take_values!(args) + stream_kwarg_values = _stream_take_values!(kwarg_values) + stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values) + + if length(stream_args) > 0 || length(stream_kwarg_values) > 0 + # Notify tasks that input buffers may have space + @lock sf.stream.store.lock notify(sf.stream.store.lock) + end + + # Run a single cycle of f + counter += 1 + @dagdebug STREAM_THUNK_ID[] :stream "executing $f (eval $counter)" + stream_result = f(stream_args...; stream_kwargs...) 
+ + # Exit streaming on graceful request + if stream_result isa FinishStream + if stream_result.value !== nothing + value = something(stream_result.value) + put!(sf.stream, value) + end + @dagdebug STREAM_THUNK_ID[] :stream "voluntarily returning" + return stream_result.result + end + + # Put the result into the output stream + put!(sf.stream, stream_result) + + # Exit streaming on eval limit + if sf.max_evals > 0 && counter >= sf.max_evals + @dagdebug STREAM_THUNK_ID[] :stream "max evals reached ($counter)" + return + end + end +end + +function _stream_take_values!(args) + return ntuple(length(args)) do idx + arg = args[idx] + if arg isa StreamingValue + return take!(arg) + else + return arg + end + end +end + +@inline @generated function _stream_namedtuple(kwarg_names::Tuple, + stream_kwarg_values::Tuple) + name_ex = Expr(:tuple, map(name->QuoteNode(name.parameters[1]), kwarg_names.parameters)...) + NT = :(NamedTuple{$name_ex,$stream_kwarg_values}) + return :($NT(stream_kwarg_values)) +end + +initialize_stream_buffer(B, T, buffer_amount) = B{T}(buffer_amount) + +const EAGER_THUNK_STREAMS = LockedObject(Dict{UInt,Any}()) +function task_to_stream(uid::UInt) + if myid() != 1 + return remotecall_fetch(task_to_stream, 1, uid) + end + lock(EAGER_THUNK_STREAMS) do global_streams + if haskey(global_streams, uid) + return global_streams[uid] + end + return + end +end + +function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) + stream_waiter_changes = Dict{UInt,Vector{Pair{UInt,Any}}}() + + for (spec, task) in tasks + @assert haskey(self_streams, task.uid) + our_stream = self_streams[task.uid] + + # Adapt args to accept Stream output of other streaming tasks + for (idx, (pos, arg)) in enumerate(spec.args) + if arg isa DTask + # Check if this is a streaming task + if haskey(self_streams, arg.uid) + other_stream = self_streams[arg.uid] + else + other_stream = task_to_stream(arg.uid) + end + + if other_stream !== nothing + # Generate Stream handle for input + # FIXME: Be configurable + input_fetcher = RemoteChannelFetcher() + other_stream_handle = Stream(other_stream) + spec.args[idx] = pos => other_stream_handle + our_stream.store.input_streams[arg.uid] = other_stream_handle + our_stream.store.input_fetchers[arg.uid] = input_fetcher + + # Add this task as a waiter for the associated output Stream + changes = get!(stream_waiter_changes, arg.uid) do + Pair{UInt,Any}[] + end + push!(changes, task.uid => input_fetcher) + end + end + end + + # Filter out all streaming options + to_filter = (:stream_buffer_type, + :stream_input_buffer_amount, :stream_output_buffer_amount, + :stream_max_evals) + spec.options = NamedTuple(filter(opt -> !(opt[1] in to_filter), + Base.pairs(spec.options))) + if haskey(spec.options, :propagates) + propagates = filter(opt -> !(opt in to_filter), + spec.options.propagates) + spec.options = merge(spec.options, (;propagates)) + end + end + + # Adjust waiter count of Streams with dependencies + for (uid, waiters) in stream_waiter_changes + stream = task_to_stream(uid) + add_waiters!(stream, waiters) + end +end diff --git a/src/submission.jl b/src/submission.jl index bfb8cb8be..f23539271 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -220,7 +220,7 @@ function eager_process_options_submission_to_local(id_map, options::NamedTuple) end function DTaskMetadata(spec::DTaskSpec) - f = chunktype(spec.f).instance + f = spec.f isa StreamingFunction ? 
spec.f.f : spec.f arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args)) return_type = Base.promote_op(f, arg_types...) return DTaskMetadata(return_type) diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 8b6d3530f..6a71e5c52 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -2,7 +2,8 @@ function istask end function task_id end const DAGDEBUG_CATEGORIES = Symbol[:global, :submit, :schedule, :scope, - :take, :execute, :move, :processor, :cancel] + :take, :execute, :move, :processor, :cancel, + :stream] macro dagdebug(thunk, category, msg, args...) cat_sym = category.value @gensym id diff --git a/test/runtests.jl b/test/runtests.jl index aa5b34aec..67d25276e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ tests = [ ("Mutation", "mutation.jl"), ("Task Queues", "task-queues.jl"), ("Datadeps", "datadeps.jl"), + ("Streaming", "streaming.jl"), ("Domain Utilities", "domain.jl"), ("Array - Allocation", "array/allocation.jl"), ("Array - Indexing", "array/indexing.jl"), @@ -103,7 +104,6 @@ else @info "Running all tests" end - using Distributed if additional_workers > 0 # We put this inside a branch because addprocs() takes a minimum of 1s to diff --git a/test/streaming.jl b/test/streaming.jl new file mode 100644 index 000000000..b7d3ce328 --- /dev/null +++ b/test/streaming.jl @@ -0,0 +1,426 @@ +@everywhere function rand_finite(T=Float64) + x = rand(T) + if rand() < 0.1 + return Dagger.finish_stream(x) + end + return x +end +@everywhere function rand_finite_returns(T=Float64) + x = rand(T) + if rand() < 0.1 + return Dagger.finish_stream(x; result=x) + end + return x +end + +const ACCUMULATOR = Dict{Int,Vector{Real}}() +@everywhere function accumulator(x=0) + tid = Dagger.task_id() + remotecall_wait(1, tid, x) do tid, x + acc = get!(Vector{Real}, ACCUMULATOR, tid) + push!(acc, x) + end + return +end +@everywhere accumulator(xs...) 
= accumulator(sum(xs)) +@everywhere accumulator(::Nothing) = accumulator(0) + +function catch_interrupt(f) + try + f() + catch err + if err isa Dagger.DTaskFailedException && err.ex isa InterruptException + return + elseif err isa Dagger.Sch.SchedulingException + return + end + rethrow(err) + end +end +function merge_testset!(inner::Test.DefaultTestSet) + outer = Test.get_testset() + append!(outer.results, inner.results) + outer.n_passed += inner.n_passed +end +function test_finishes(f, message::String; ignore_timeout=false, max_evals=10) + t = @eval Threads.@spawn begin + tset = nothing + try + @testset $message begin + try + @testset $message begin + Dagger.with_options(;stream_max_evals=$max_evals) do + catch_interrupt($f) + end + end + finally + tset = Test.get_testset() + end + end + catch + end + return tset + end + timed_out = timedwait(()->istaskdone(t), 10) == :timed_out + if timed_out + if !ignore_timeout + @warn "Testing task timed out: $message" + end + Dagger.cancel!(;halt_sch=true) + @everywhere GC.gc() + fetch(Dagger.@spawn 1+1) + end + tset = fetch(t)::Test.DefaultTestSet + merge_testset!(tset) + return !timed_out +end + +all_scopes = [Dagger.ExactScope(proc) for proc in Dagger.all_processors()] +for idx in 1:5 + if idx == 1 + scopes = [Dagger.scope(worker = 1, thread = 1)] + scope_str = "Worker 1" + elseif idx == 2 && nprocs() > 1 + scopes = [Dagger.scope(worker = 2, thread = 1)] + scope_str = "Worker 2" + else + scopes = all_scopes + scope_str = "All Workers" + end + + @testset "Single Task Control Flow ($scope_str)" begin + @test !test_finishes("Single task running forever"; max_evals=1_000_000, ignore_timeout=true) do + local x + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) () -> begin + y = rand() + sleep(1) + return y + end + end + @test_throws_unwrap InterruptException fetch(x) + end + + @test test_finishes("Single task without result") do + local x + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + end + @test fetch(x) === nothing + end + + @test test_finishes("Single task with result"; max_evals=1_000_000) do + local x + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) () -> begin + x = rand() + if x < 0.1 + return Dagger.finish_stream(x; result=123) + end + return x + end + end + @test fetch(x) == 123 + end + end + + @testset "Non-Streaming Inputs ($scope_str)" begin + @test test_finishes("() -> A") do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator() + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(0), values[A_tid]) + end + @test test_finishes("42 -> A") do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator(42) + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(42), values[A_tid]) + end + @test test_finishes("(42, 43) -> A") do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator(42, 43) + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(42 + 43), values[A_tid]) + end + end + + @testset "Non-Streaming Outputs ($scope_str)" begin + @test test_finishes("x -> A") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) 
rand() + end + Dagger._without_options() do + A = Dagger.@spawn accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[A_tid]) + end + @test test_finishes("x -> (A, B)") do + local x, A, B + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + end + Dagger._without_options() do + A = Dagger.@spawn accumulator(x) + B = Dagger.@spawn accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + @test fetch(B) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[A_tid]) + B_tid = Dagger.task_id(B) + @test length(values[B_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[B_tid]) + end + end + + @testset "Multiple Tasks ($scope_str)" begin + @test test_finishes("x -> A") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 1, values[A_tid]) + end + + @test test_finishes("(x, A)") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(1.0) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> v == 1, values[A_tid]) + end + + @test test_finishes("x -> y -> A") do + local x, y, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x+1 + A = Dagger.@spawn scope=rand(scopes) accumulator(y) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 1 <= v <= 2, values[A_tid]) + end + + @test test_finishes("x -> (y, A)") do + local x, y, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x+1 + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 1, values[A_tid]) + end + + @test test_finishes("(x, y) -> A") do + local x, y, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(x, y) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[A_tid]) + end + + @test test_finishes("(x, y) -> z -> A") do + local x, y, z, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + z = Dagger.@spawn scope=rand(scopes) x + y + A = Dagger.@spawn scope=rand(scopes) accumulator(z) + end + 
@test fetch(x) === nothing
+            @test fetch(y) === nothing
+            @test fetch(z) === nothing
+            @test fetch(A) === nothing
+            values = copy(ACCUMULATOR); empty!(ACCUMULATOR)
+            A_tid = Dagger.task_id(A)
+            @test length(values[A_tid]) == 10
+            @test all(v -> 0 <= v <= 2, values[A_tid])
+        end
+
+        @test test_finishes("x -> (y, z) -> A") do
+            local x, y, z, A
+            Dagger.spawn_streaming() do
+                x = Dagger.@spawn scope=rand(scopes) rand()
+                y = Dagger.@spawn scope=rand(scopes) x + 1
+                z = Dagger.@spawn scope=rand(scopes) x + 2
+                A = Dagger.@spawn scope=rand(scopes) accumulator(y, z)
+            end
+            @test fetch(x) === nothing
+            @test fetch(y) === nothing
+            @test fetch(z) === nothing
+            @test fetch(A) === nothing
+            values = copy(ACCUMULATOR); empty!(ACCUMULATOR)
+            A_tid = Dagger.task_id(A)
+            @test length(values[A_tid]) == 10
+            @test all(v -> 3 <= v <= 5, values[A_tid])
+        end
+
+        @test test_finishes("(x, y) -> z -> (A, B)") do
+            local x, y, z, A, B
+            Dagger.spawn_streaming() do
+                x = Dagger.@spawn scope=rand(scopes) rand()
+                y = Dagger.@spawn scope=rand(scopes) rand()
+                z = Dagger.@spawn scope=rand(scopes) x + y
+                A = Dagger.@spawn scope=rand(scopes) accumulator(z)
+                B = Dagger.@spawn scope=rand(scopes) accumulator(z)
+            end
+            @test fetch(x) === nothing
+            @test fetch(y) === nothing
+            @test fetch(z) === nothing
+            @test fetch(A) === nothing
+            @test fetch(B) === nothing
+            values = copy(ACCUMULATOR); empty!(ACCUMULATOR)
+            A_tid = Dagger.task_id(A)
+            @test length(values[A_tid]) == 10
+            @test all(v -> 0 <= v <= 2, values[A_tid])
+            B_tid = Dagger.task_id(B)
+            @test length(values[B_tid]) == 10
+            @test all(v -> 0 <= v <= 2, values[B_tid])
+        end
+
+        for T in (Float64, Int32, BigFloat)
+            @test test_finishes("Stream eltype $T") do
+                local x, A
+                Dagger.spawn_streaming() do
+                    x = Dagger.@spawn scope=rand(scopes) rand(T)
+                    A = Dagger.@spawn scope=rand(scopes) accumulator(x)
+                end
+                @test fetch(x) === nothing
+                @test fetch(A) === nothing
+                values = copy(ACCUMULATOR); empty!(ACCUMULATOR)
+                A_tid = Dagger.task_id(A)
+                @test length(values[A_tid]) == 10
+                @test all(v -> v isa T, values[A_tid])
+            end
+        end
+    end
+
+    @testset "Max Evals ($scope_str)" begin
+        @test test_finishes("max_evals=0"; max_evals=0) do
+            @test_throws ArgumentError Dagger.spawn_streaming() do
+                A = Dagger.@spawn scope=rand(scopes) accumulator()
+            end
+        end
+        @test test_finishes("max_evals=1"; max_evals=1) do
+            local A
+            Dagger.spawn_streaming() do
+                A = Dagger.@spawn scope=rand(scopes) accumulator()
+            end
+            @test fetch(A) === nothing
+            values = copy(ACCUMULATOR); empty!(ACCUMULATOR)
+            A_tid = Dagger.task_id(A)
+            @test length(values[A_tid]) == 1
+        end
+        @test test_finishes("max_evals=100"; max_evals=100) do
+            local A
+            Dagger.spawn_streaming() do
+                A = Dagger.@spawn scope=rand(scopes) accumulator()
+            end
+            @test fetch(A) === nothing
+            values = copy(ACCUMULATOR); empty!(ACCUMULATOR)
+            A_tid = Dagger.task_id(A)
+            @test length(values[A_tid]) == 100
+        end
+    end
+
+    @testset "DropBuffer ($scope_str)" begin
+        # TODO: Test that accumulator never gets called
+        @test !test_finishes("x (drop)-> A"; ignore_timeout=true) do
+            local x, A
+            Dagger.spawn_streaming() do
+                Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do
+                    x = Dagger.@spawn scope=rand(scopes) rand()
+                end
+                A = Dagger.@spawn scope=rand(scopes) accumulator(x)
+            end
+            @test fetch(x) === nothing
+            @test_throws_unwrap InterruptException fetch(A) === nothing
+        end
+        @test !test_finishes("x ->(drop) A"; ignore_timeout=true) do
+            local x, A
+            Dagger.spawn_streaming() do
+                x = Dagger.@spawn scope=rand(scopes) rand()
+
Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + end + @test fetch(x) === nothing + @test_throws_unwrap InterruptException fetch(A) === nothing + end + @test !test_finishes("x -(drop)> A"; ignore_timeout=true) do + local x, A + Dagger.spawn_streaming() do + Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do + x = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + end + @test fetch(x) === nothing + @test_throws_unwrap InterruptException fetch(A) === nothing + end + end + + # FIXME: Varying buffer amounts + + #= TODO: Zero-allocation test + # First execution of a streaming task will almost guaranteed allocate (compiling, setup, etc.) + # BUT, second and later executions could possibly not allocate any further ("steady-state") + # We want to be able to validate that the steady-state execution for certain tasks is non-allocating + =# +end From ec44e08c96ffc3f68ea4b24892ead428dfd08c95 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Sat, 30 Nov 2024 23:22:01 +0100 Subject: [PATCH 20/23] Remove duplicate errormonitor() This shouldn't be necessary since we `wait()` for the given task in another `errormonitor()`'d task. --- src/sch/util.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sch/util.jl b/src/sch/util.jl index eb5a285b4..dc4497c77 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -1,6 +1,5 @@ "Like `errormonitor`, but tracks how many outstanding tasks are running." function errormonitor_tracked(name::String, t::Task) - errormonitor(t) @safe_lock_spin1 ERRORMONITOR_TRACKED tracked begin push!(tracked, name => t) end From e43151faa14b37912d8f7c2c0d631d8553c26b55 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Sun, 1 Dec 2024 01:26:11 +0100 Subject: [PATCH 21/23] Allow `nothing` to be thrown in dynamic_listener!() This was observed when running the `Single task running forever` test with multiple threads. --- src/sch/dynamic.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/sch/dynamic.jl b/src/sch/dynamic.jl index 5b917fdb5..9d0213928 100644 --- a/src/sch/dynamic.jl +++ b/src/sch/dynamic.jl @@ -59,9 +59,10 @@ function dynamic_listener!(ctx, state, wid) tid, f, data = try take!(inp_chan) catch err - if !(unwrap_nested_exception(err) isa Union{SchedulerHaltedException, - ProcessExitedException, - InvalidStateException}) + if (!isnothing(err) && # `nothing` appears sometimes to be thrown upon cancellation + !(unwrap_nested_exception(err) isa Union{SchedulerHaltedException, + ProcessExitedException, + InvalidStateException})) iob = IOContext(IOBuffer(), :color=>true) println(iob, "Error in receiving dynamic request:") Base.showerror(iob, err) From 09bfd5699e62589488fabd80bf0b9b8face762de Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Sun, 1 Dec 2024 01:28:01 +0100 Subject: [PATCH 22/23] Unwrap nested exceptions in the streaming task input handlers Necessary because `stream_pull_values!()` may throw different exceptions depending on whether the exception occurred locally or remotely. 
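For example, a failure inside a Threads.@spawn'd handler surfaces locally as a
TaskFailedException wrapping the real cause. A minimal sketch in plain Julia
(illustration only, not part of this patch; `unwrap` is a hypothetical helper
mirroring `Sch.unwrap_nested_exception()`) of why matching on the outer
exception type is not enough:

    # Recursively unwrap until we reach the real cause
    unwrap(err) = err isa TaskFailedException ? unwrap(err.task.exception) : err

    t = Threads.@spawn throw(InterruptException())
    try
        wait(t)
    catch err
        @show err isa InterruptException          # false: only the wrapper
        @show unwrap(err) isa InterruptException  # true: the real cause
    end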
---
 src/sch/util.jl | 2 ++
 src/stream.jl   | 3 ++-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/sch/util.jl b/src/sch/util.jl
index dc4497c77..66d308e3b 100644
--- a/src/sch/util.jl
+++ b/src/sch/util.jl
@@ -30,6 +30,8 @@ unwrap_nested_exception(err::RemoteException) =
     unwrap_nested_exception(err.captured)
 unwrap_nested_exception(err::DTaskFailedException) =
     unwrap_nested_exception(err.ex)
+unwrap_nested_exception(err::TaskFailedException) =
+    unwrap_nested_exception(err.task.exception)
 unwrap_nested_exception(err) = err
 
 "Gets a `NamedTuple` of options propagated by `thunk`."
diff --git a/src/stream.jl b/src/stream.jl
index f3c2cfbcb..7b69ccec6 100644
--- a/src/stream.jl
+++ b/src/stream.jl
@@ -237,7 +237,8 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S
             stream_pull_values!(input_fetcher, IT, our_store, input_stream, buffer)
         end
     catch err
-        if err isa InterruptException || (err isa InvalidStateException && !isopen(buffer))
+        unwrapped_err = Sch.unwrap_nested_exception(err)
+        if unwrapped_err isa InterruptException || (unwrapped_err isa InvalidStateException && !isopen(buffer))
             return
         else
             rethrow(err)

From 96a0c46e67e2a25bc42cfa08032f9a1ffac0ad5a Mon Sep 17 00:00:00 2001
From: JamesWrigley
Date: Sun, 1 Dec 2024 01:30:07 +0100
Subject: [PATCH 23/23] Wait for input/output handlers to finish when closing a StreamStore

Otherwise it's possible that the scheduler will close the DTask before all
outputs have been sent, which would cause the downstream tasks to hang. This
is how it could happen:

1. A streaming task starts.
2. The output handler task calls `take!(::ProcessRingBuffer)` on an output
   buffer, finds it empty, and `yield()`'s.
3. The task executes, pushes its output to the output buffers, reaches
   `max_evals` and finishes.
4. The scheduler finishes the corresponding DTask.
5. The `take!(::ProcessRingBuffer)` call resumes. The buffer isn't empty
   anymore, but it calls `task_may_cancel!(; must_force=true)` before
   continuing and throws an exception since the scheduler has finished the
   DTask. The result is that the last output is never sent, and the exception
   is swallowed by the output handler started by
   `initialize_output_stream!()`.
6. Downstream tasks don't get that last result so they never reach
   `max_evals` and spin forever.

Fixed by storing the handler tasks in the `StreamStore` and waiting on them
in `close(::StreamStore)`.

Also increased the timeout of the 'Single task running forever' test because
it will sometimes time out before the default 10s is up.
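As a minimal sketch of the pattern (hypothetical names, not the exact code in
this patch), the store keeps a handle on each handler Task, and closing blocks
until those handlers have drained their buffers:

    mutable struct ToyStore
        open::Bool
        handlers::Vector{Task}  # input/output handler tasks
    end

    function shutdown!(store::ToyStore)
        store.open = false
        # Blocking here guarantees every buffered output has been sent
        # before the scheduler is allowed to finish the DTask.
        foreach(wait, store.handlers)
    end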
--- src/stream.jl | 31 ++++++++++++++++++++++++++++--- test/streaming.jl | 9 ++++++--- 2 files changed, 34 insertions(+), 6 deletions(-) diff --git a/src/stream.jl b/src/stream.jl index 7b69ccec6..a00a0d237 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -12,13 +12,18 @@ mutable struct StreamStore{T,B} open::Bool migrating::Bool lock::Threads.Condition + + input_handlers::Dict{UInt, Task} + output_handlers::Dict{UInt, Task} + StreamStore{T,B}(uid::UInt, input_buffer_amount::Integer, output_buffer_amount::Integer) where {T,B} = new{T,B}(uid, zeros(Int, 0), Dict{UInt,Any}(), Dict{UInt,Any}(), Dict{UInt,B}(), Dict{UInt,B}(), input_buffer_amount, output_buffer_amount, Dict{UInt,Any}(), Dict{UInt,Any}(), - true, false, Threads.Condition()) + true, false, Threads.Condition(), + Dict{UInt, Task}(), Dict{UInt, Task}()) end function tid_to_uid(thunk_id) @@ -164,6 +169,19 @@ function Base.close(store::StreamStore) end notify(store.lock) end + + # We have to close the input fetchers for the input handlers to finish + for fetcher in values(store.input_fetchers) + close(fetcher.chan) + end + + # Wait for the handlers to finish + for handler in values(store.input_handlers) + wait(handler) + end + for handler in values(store.output_handlers) + wait(handler) + end end # FIXME: Just pass Stream directly, rather than its uid @@ -229,7 +247,7 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S end thunk_id = STREAM_THUNK_ID[] tls = get_tls() - Sch.errormonitor_tracked("streaming input: $input_uid -> $our_uid", Threads.@spawn begin + t = Sch.errormonitor_tracked("streaming input: $input_uid -> $our_uid", Threads.@spawn begin set_tls!(tls) STREAM_THUNK_ID[] = thunk_id try @@ -247,9 +265,14 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S @dagdebug STREAM_THUNK_ID[] :stream "input stream closed" end end) + + our_store.input_handlers[input_uid] = t + return StreamingValue(buffer) end + initialize_input_stream!(our_store::StreamStore, arg) = arg + function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt) where {T,B} @assert islocked(our_store.lock) @dagdebug STREAM_THUNK_ID[] :stream "initializing output stream $output_uid" @@ -260,7 +283,7 @@ function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt output_fetcher = our_store.output_fetchers[output_uid] thunk_id = STREAM_THUNK_ID[] tls = get_tls() - Sch.errormonitor_tracked("streaming output: $our_uid -> $output_uid", Threads.@spawn begin + t = Sch.errormonitor_tracked("streaming output: $our_uid -> $output_uid", Threads.@spawn begin set_tls!(tls) STREAM_THUNK_ID[] = thunk_id try @@ -282,6 +305,8 @@ function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt @dagdebug thunk_id :stream "output stream closed" end end) + + our_store.output_handlers[output_uid] = t end Base.put!(stream::Stream, @nospecialize(value)) = put!(stream.store, value) diff --git a/test/streaming.jl b/test/streaming.jl index b7d3ce328..ea2e8fb2b 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -37,12 +37,14 @@ function catch_interrupt(f) rethrow(err) end end + function merge_testset!(inner::Test.DefaultTestSet) outer = Test.get_testset() append!(outer.results, inner.results) outer.n_passed += inner.n_passed end -function test_finishes(f, message::String; ignore_timeout=false, max_evals=10) + +function test_finishes(f, message::String; timeout=10, ignore_timeout=false, max_evals=10) t = @eval Threads.@spawn begin tset = nothing try @@ -61,7 +63,7 @@ 
function test_finishes(f, message::String; ignore_timeout=false, max_evals=10) end return tset end - timed_out = timedwait(()->istaskdone(t), 10) == :timed_out + timed_out = timedwait(()->istaskdone(t), timeout) == :timed_out if timed_out if !ignore_timeout @warn "Testing task timed out: $message" @@ -89,7 +91,7 @@ for idx in 1:5 end @testset "Single Task Control Flow ($scope_str)" begin - @test !test_finishes("Single task running forever"; max_evals=1_000_000, ignore_timeout=true) do + @test !test_finishes("Single task running forever"; timeout=15, max_evals=1_000_000, ignore_timeout=true) do local x Dagger.spawn_streaming() do x = Dagger.@spawn scope=rand(scopes) () -> begin @@ -98,6 +100,7 @@ for idx in 1:5 return y end end + @test_throws_unwrap InterruptException fetch(x) end