diff --git a/src/Dagger.jl b/src/Dagger.jl index c0cb23526..fa30c7c1a 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -60,6 +60,8 @@ include("scopes.jl") include("utils/scopes.jl") include("chunks.jl") include("utils/signature.jl") +include("thunkid.jl") +include("utils/lfucache.jl") include("options.jl") include("dtask.jl") include("cancellation.jl") diff --git a/src/cancellation.jl b/src/cancellation.jl index 748c2aacd..ff9e19fcb 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -129,8 +129,9 @@ function _cancel!(state, tid, force, graceful, halt_sch) wids = unique(map(root_worker_id, values(state.running_on))) for wid in wids remotecall_fetch(wid, tid, sch_uid, force) do _tid, sch_uid, force - Dagger.Sch.proc_states(sch_uid) do states - for (proc, state) in states + states = Dagger.Sch.proc_states(sch_uid) + MemPool.lock_read(states.lock) do + for (proc, state) in states.dict istate = state.state any_cancelled = false @lock istate.queue begin diff --git a/src/compute.jl b/src/compute.jl index f655cbac4..189c7ec24 100644 --- a/src/compute.jl +++ b/src/compute.jl @@ -48,7 +48,7 @@ end Find the set of direct dependents for each task. """ function dependents(node::Thunk) - deps = Dict{Union{Thunk,Chunk}, Set{Thunk}}() + deps = Dict{Thunk, Set{Thunk}}() visited = Set{Thunk}() to_visit = Set{Thunk}() push!(to_visit, node) @@ -58,13 +58,11 @@ function dependents(node::Thunk) if !haskey(deps, next) deps[next] = Set{Thunk}() end - for inp in next.options.syncdeps - if istask(inp) || (inp isa Chunk) - s = get!(()->Set{Thunk}(), deps, inp) - push!(s, next) - if istask(inp) && !(inp in visited) - push!(to_visit, inp) - end + for inp in Iterators.map(syncdep->unwrap_weak_checked(something(syncdep.thunk)), next.options.syncdeps) + s = get!(()->Set{Thunk}(), deps, inp) + push!(s, next) + if istask(inp) && !(inp in visited) + push!(to_visit, inp) end end push!(visited, next) @@ -73,14 +71,14 @@ function dependents(node::Thunk) end """ - noffspring(dpents::Dict{Union{Thunk,Chunk}, Set{Thunk}}) -> Dict{Thunk, Int} + noffspring(dpents::Dict{Thunk, Set{Thunk}}) -> Dict{Thunk, Int} Recursively find the number of tasks dependent on each task in the DAG. Takes a Dict as returned by [`dependents`](@ref). """ -function noffspring(dpents::Dict{Union{Thunk,Chunk}, Set{Thunk}}) +function noffspring(dpents::Dict{Thunk, Set{Thunk}}) noff = Dict{Thunk,Int}() - to_visit = collect(filter(istask, keys(dpents))) + to_visit = collect(keys(dpents)) while !isempty(to_visit) next = popfirst!(to_visit) haskey(noff, next) && continue @@ -126,7 +124,7 @@ function order(node::Thunk, ndeps) haskey(output, next) && continue s += 1 output[next] = s - parents = collect(filter(istask, next.options.syncdeps)) + parents = collect(Iterators.map(syncdep->unwrap_weak_checked(something(syncdep.thunk)), next.options.syncdeps)) if !isempty(parents) # If parents is empty, sort! 
should be a no-op, but raises an ambiguity error # when InlineStrings.jl is loaded (at least, version 1.1.0), because InlineStrings diff --git a/src/datadeps.jl b/src/datadeps.jl index eee2a6ece..225bcd5a4 100644 --- a/src/datadeps.jl +++ b/src/datadeps.jl @@ -396,7 +396,7 @@ function _get_write_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::Ab other_task, other_write_num = other_task_write_num write_num == other_write_num && continue @dagdebug nothing :spawn_datadeps "Sync with writer via $ainfo -> $other_ainfo" - push!(syncdeps, other_task) + push!(syncdeps, ThunkSyncdep(other_task)) end end function _get_read_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps) @@ -408,7 +408,7 @@ function _get_read_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::Abs for (other_task, other_write_num) in other_tasks write_num == other_write_num && continue @dagdebug nothing :spawn_datadeps "Sync with reader via $ainfo -> $other_ainfo" - push!(syncdeps, other_task) + push!(syncdeps, ThunkSyncdep(other_task)) end end end @@ -427,14 +427,14 @@ function _get_write_deps!(state::DataDepsState{DataDepsNonAliasingState}, arg, t if other_task_write_num !== nothing other_task, other_write_num = other_task_write_num if write_num != other_write_num - push!(syncdeps, other_task) + push!(syncdeps, ThunkSyncdep(other_task)) end end end function _get_read_deps!(state::DataDepsState{DataDepsNonAliasingState}, arg, task, write_num, syncdeps) for (other_task, other_write_num) in state.alias_state.args_readers[arg] if write_num != other_write_num - push!(syncdeps, other_task) + push!(syncdeps, ThunkSyncdep(other_task)) end end end @@ -590,6 +590,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) write_num = 1 proc_idx = 1 pressures = Dict{Processor,Int}() + proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024) for (spec, task) in queue.seen_tasks[task_order] # Populate all task dependencies populate_task_info!(state, spec, task) @@ -723,9 +724,20 @@ function distribute_tasks!(queue::DataDepsTaskQueue) end @assert our_proc in all_procs our_space = only(memory_spaces(our_proc)) - our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) - task_scope = @something(spec.options.scope, AnyScope()) - our_scope = constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) + + # Find the scope for this task (and its copies) + task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope()) + if task_scope == scope + # Optimize for the common case, cache the proc=>scope mapping + our_scope = get!(proc_to_scope_lfu, our_proc) do + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + return constrain(UnionScope(map(ExactScope, our_procs)...), scope) + end + else + # Use the provided scope and constrain it to the available processors + our_procs = filter(proc->proc in all_procs, collect(processors(our_space))) + our_scope = constrain(UnionScope(map(ExactScope, our_procs)...), task_scope) + end if our_scope isa InvalidScope throw(Sch.SchedulingException("Scopes are not compatible: $(our_scope.x), $(our_scope.y)")) end @@ -769,10 +781,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue) generate_slot!(state, data_space, arg) end copy_to_scope = our_scope - copy_to_syncdeps = Set{Any}() + copy_to_syncdeps = Set{ThunkSyncdep}() get_write_deps!(state, ainfo, task, write_num, copy_to_syncdeps) @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx][$dep_mod] 
$(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope syncdeps=copy_to_syncdeps meta=true Dagger.move!(dep_mod, our_space, data_space, arg_remote, arg_local) + copy_to = Dagger.@spawn scope=copy_to_scope exec_scope=copy_to_scope syncdeps=copy_to_syncdeps meta=true Dagger.move!(dep_mod, our_space, data_space, arg_remote, arg_local) add_writer!(state, ainfo, copy_to, write_num) astate.data_locality[ainfo] = our_space @@ -790,10 +802,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue) generate_slot!(state, data_space, arg) end copy_to_scope = our_scope - copy_to_syncdeps = Set{Any}() + copy_to_syncdeps = Set{ThunkSyncdep}() get_write_deps!(state, arg, task, write_num, copy_to_syncdeps) @dagdebug nothing :spawn_datadeps "($(repr(value(f))))[$idx] $(length(copy_to_syncdeps)) syncdeps" - copy_to = Dagger.@spawn scope=copy_to_scope syncdeps=copy_to_syncdeps Dagger.move!(identity, our_space, data_space, arg_remote, arg_local) + copy_to = Dagger.@spawn scope=copy_to_scope exec_scope=copy_to_scope syncdeps=copy_to_syncdeps Dagger.move!(identity, our_space, data_space, arg_remote, arg_local) add_writer!(state, arg, copy_to, write_num) astate.data_locality[arg] = our_space @@ -820,7 +832,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # Calculate this task's syncdeps if spec.options.syncdeps === nothing - spec.options.syncdeps = Set{Any}() + spec.options.syncdeps = Set{ThunkSyncdep}() end syncdeps = spec.options.syncdeps for (idx, (_, arg)) in enumerate(task_args) @@ -853,6 +865,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue) # Launch user's task spec.options.scope = our_scope + spec.options.exec_scope = our_scope enqueue!(upper_queue, spec=>task) # Update read/write tracking for arguments @@ -947,10 +960,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue) @assert arg_remote !== arg_local data_local_proc = first(processors(data_local_space)) copy_from_scope = UnionScope(map(ExactScope, collect(processors(data_local_space)))...) 
- copy_from_syncdeps = Set() + copy_from_syncdeps = Set{ThunkSyncdep}() get_write_deps!(state, ainfo, nothing, write_num, copy_from_syncdeps) @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(dep_mod, data_local_space, data_remote_space, arg_local, arg_remote) + copy_from = Dagger.@spawn scope=copy_from_scope exec_scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(dep_mod, data_local_space, data_remote_space, arg_local, arg_remote) else @dagdebug nothing :spawn_datadeps "[$dep_mod] Skipped copy-from (local): $data_remote_space" end @@ -980,10 +993,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue) @assert arg_remote !== arg_local data_local_proc = first(processors(data_local_space)) copy_from_scope = ExactScope(data_local_proc) - copy_from_syncdeps = Set() + copy_from_syncdeps = Set{ThunkSyncdep}() get_write_deps!(state, arg, nothing, write_num, copy_from_syncdeps) @dagdebug nothing :spawn_datadeps "$(length(copy_from_syncdeps)) syncdeps" - copy_from = Dagger.@spawn scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(identity, data_local_space, data_remote_space, arg_local, arg_remote) + copy_from = Dagger.@spawn scope=copy_from_scope exec_scope=copy_from_scope syncdeps=copy_from_syncdeps meta=true Dagger.move!(identity, data_local_space, data_remote_space, arg_local, arg_remote) else @dagdebug nothing :spawn_datadeps "Skipped copy-from (local): $data_remote_space" end diff --git a/src/dtask.jl b/src/dtask.jl index b74774287..e94803502 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -106,12 +106,15 @@ function Base.show(io::IO, t::DTask) print(io, "DTask ($status)") end istask(t::DTask) = true +function Base.convert(::Type{ThunkSyncdep}, task::Dagger.DTask) + return ThunkSyncdep(ThunkID(task.uid, isdefined(task, :thunk_ref) ? 
task.thunk_ref : nothing)) +end +ThunkSyncdep(task::DTask) = convert(ThunkSyncdep, task) -const EAGER_ID_COUNTER = Threads.Atomic{UInt64}(1) function eager_next_id() if myid() == 1 - Threads.atomic_add!(EAGER_ID_COUNTER, one(UInt64)) + return UInt64(next_id()) else - remotecall_fetch(eager_next_id, 1) + return remotecall_fetch(eager_next_id, 1)::UInt64 end end diff --git a/src/memory-spaces.jl b/src/memory-spaces.jl index 090c0a14f..2467dd625 100644 --- a/src/memory-spaces.jl +++ b/src/memory-spaces.jl @@ -134,6 +134,7 @@ equivalent_structure(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash || equivalent_structure(x.inner, y.inner) Base.hash(x::AliasingWrapper, h::UInt64) = hash(x.hash, h) Base.isequal(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash +Base.:(==)(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash will_alias(x::AliasingWrapper, y::AliasingWrapper) = will_alias(x.inner, y.inner) @@ -225,6 +226,8 @@ struct ContiguousAliasing{S} <: AbstractAliasing span::MemorySpan{S} end memory_spans(a::ContiguousAliasing{S}) where S = MemorySpan{S}[a.span] +will_alias(x::ContiguousAliasing{S}, y::ContiguousAliasing{S}) where S = + will_alias(x.span, y.span) struct IteratedAliasing{T} <: AbstractAliasing x::T end @@ -284,20 +287,29 @@ function aliasing(x::SubArray{T,N,A}) where {T,N,A<:Array} return UnknownAliasing() end end -function will_alias(x::StridedAliasing{T,N,S}, y::StridedAliasing{T,N,S}) where {T,N,S} +function will_alias(x::StridedAliasing{T1,N1,S1}, y::StridedAliasing{T2,N2,S2}) where {T1,T2,N1,N2,S1,S2} + # Check if the base pointers are the same + # FIXME: Conservatively incorrect via `unsafe_wrap` and friends if x.base_ptr != y.base_ptr - # FIXME: Conservatively incorrect via `unsafe_wrap` and friends return false end - for dim in 1:N - if ((x.base_inds[dim].stop) < (y.base_inds[dim].start) || (y.base_inds[dim].stop) < (x.base_inds[dim].start)) - return false + if T1 === T2 && N1 == N2 && may_alias(x.base_ptr.space, y.base_ptr.space) + # Check if the base indices overlap + for dim in 1:N1 + if ((x.base_inds[dim].stop) < (y.base_inds[dim].start) || (y.base_inds[dim].stop) < (x.base_inds[dim].start)) + return false + end end + return true + else + return invoke(will_alias, Tuple{Any, Any}, x, y) end - - return true end +will_alias(x::StridedAliasing{T,N,S}, y::ContiguousAliasing{S}) where {T,N,S} = + x.base_ptr == y.span.ptr +will_alias(x::ContiguousAliasing{S}, y::StridedAliasing{T,N,S}) where {T,N,S} = + will_alias(y, x) # FIXME: Upgrade Contiguous/StridedAlising to same number of dims struct TriangularAliasing{T,S} <: AbstractAliasing diff --git a/src/options.jl b/src/options.jl index 12e8c3dcf..eca59fbc9 100644 --- a/src/options.jl +++ b/src/options.jl @@ -10,6 +10,7 @@ Stores per-task options to be passed to the scheduler. - `processor::Processor`: The processor associated with this task's function. Generally ignored by the scheduler. - `compute_scope::AbstractScope`: The execution scope of the task, which determines where the task can be scheduled and executed. `scope` is another name for this option. - `result_scope::AbstractScope`: The data scope of the task's result, which determines where the task's result can be accessed from. +- `exec_scope::AbstractScope`: The execution scope of the task, which determines where the task can be scheduled and executed. Can be set to avoid computing the scope in the scheduler, when known. - `single::Int=0`: (Deprecated) Force task onto worker with specified id. `0` disables this option. 
- `proclist=nothing`: (Deprecated) Force task to use one or more processors that are instances/subtypes of a contained type. Alternatively, a function can be supplied, and the function will be called with a processor as the sole argument and should return a `Bool` result to indicate whether or not to use the given processor. `nothing` enables all default processors. - `get_result::Bool=false`: Whether the worker should store the result directly (`true`) or as a `Chunk` (`false`) @@ -37,13 +38,14 @@ Base.@kwdef mutable struct Options scope::Union{AbstractScope,Nothing} = nothing compute_scope::Union{AbstractScope,Nothing} = scope result_scope::Union{AbstractScope,Nothing} = nothing + exec_scope::Union{AbstractScope,Nothing} = nothing single::Union{Int,Nothing} = nothing proclist = nothing get_result::Union{Bool,Nothing} = nothing meta::Union{Bool,Nothing} = nothing - syncdeps::Union{Set{Any},Nothing} = nothing + syncdeps::Union{Set{ThunkSyncdep},Nothing} = nothing time_util::Union{Dict{Type,Any},Nothing} = nothing alloc_util::Union{Dict{Type,UInt64},Nothing} = nothing @@ -101,6 +103,15 @@ _set_option!(options::Base.Pairs, field, value) = error("Cannot set option in Ba end return ex end +function Base.setproperty!(options::Options, field::Symbol, value) + if field == :scope || field == :compute_scope || field == :result_scope + # If the scope is changed, we need to clear the exec_scope as it is no longer valid + setfield!(options, :exec_scope, nothing) + end + fidx = findfirst(==(field), fieldnames(Options)) + ftype = fieldtypes(Options)[fidx] + return setfield!(options, field, convert(ftype, value)) +end """ populate_defaults!(opts::Options, sig::Vector{DataType}) -> Options @@ -113,6 +124,7 @@ function populate_defaults!(opts::Options, sig) maybe_default!(opts, Val{:processor}(), sig) maybe_default!(opts, Val{:compute_scope}(), sig) maybe_default!(opts, Val{:result_scope}(), sig) + maybe_default!(opts, Val{:exec_scope}(), sig) maybe_default!(opts, Val{:single}(), sig) maybe_default!(opts, Val{:proclist}(), sig) maybe_default!(opts, Val{:get_result}(), sig) @@ -143,30 +155,6 @@ function maybe_default!(opts::Options, ::Val{opt}, sig::Signature) where opt end end -struct BasicLFUCache{K,V} - cache::Dict{K,V} - freq::Dict{K,Int} - max_size::Int - - BasicLFUCache{K,V}(max_size::Int) where {K,V} = new(Dict{K,V}(), Dict{K,Int}(), max_size) -end -function Base.get!(f, cache::BasicLFUCache{K,V}, key::K) where {K,V} - if haskey(cache.cache, key) - cache.freq[key] += 1 - return cache.cache[key] - end - val = f()::V - cache.cache[key] = val - cache.freq[key] = 1 - if length(cache.cache) > cache.max_size - # Find the least frequently used key - _, lfu_key::K = findmin(cache.freq) - delete!(cache.cache, lfu_key) - delete!(cache.freq, lfu_key) - end - return val -end - const SIGNATURE_DEFAULT_CACHE = TaskLocalValue{BasicLFUCache{Tuple{UInt,Symbol},Any}}(()->BasicLFUCache{Tuple{UInt,Symbol},Any}(256)) # SchedulerOptions integration diff --git a/src/queue.jl b/src/queue.jl index b0b0ea45d..c8c6007ec 100644 --- a/src/queue.jl +++ b/src/queue.jl @@ -45,9 +45,9 @@ end function _add_prev_deps!(queue::InOrderTaskQueue, spec::DTaskSpec) # Add previously-enqueued task(s) to this task's syncdeps opts = spec.options - syncdeps = opts.syncdeps = @something(opts.syncdeps, Set()) + syncdeps = opts.syncdeps = @something(opts.syncdeps, Set{ThunkSyncdep}()) for task in queue.prev_tasks - push!(syncdeps, task) + push!(syncdeps, ThunkSyncdep(task)) end end function enqueue!(queue::InOrderTaskQueue, 
spec::Pair{DTaskSpec,DTask}) diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 41d160413..1aea1f319 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -14,7 +14,7 @@ import Random: randperm, randperm! import Base: @invokelatest import ..Dagger -import ..Dagger: Context, Processor, SchedulerOptions, Options, Thunk, WeakThunk, ThunkFuture, DTaskFailedException, Chunk, WeakChunk, OSProc, AnyScope, DefaultScope, InvalidScope, LockedObject, Argument, Signature +import ..Dagger: Context, Processor, SchedulerOptions, Options, Thunk, WeakThunk, ThunkFuture, ThunkID, DTaskFailedException, Chunk, WeakChunk, OSProc, AnyScope, DefaultScope, InvalidScope, LockedObject, Argument, Signature import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, wrap_weak, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, default_enabled, processor, get_processors, get_parent, execute!, rmprocs!, task_processor, constrain, cputhreadtime, maybe_take_or_alloc! import ..Dagger: @dagdebug, @safe_lock_spin1, @maybelog, @take_or_alloc! import DataStructures: PriorityQueue, enqueue!, dequeue_pair!, peek @@ -80,8 +80,7 @@ struct ComputeState waiting::OneToMany waiting_data::Dict{Union{Thunk,Chunk},Set{Thunk}} ready::Vector{Thunk} - cache::WeakKeyDict{Thunk, Any} - valid::WeakKeyDict{Thunk, Nothing} + valid::Dict{Thunk, Nothing} running::Set{Thunk} running_on::Dict{Thunk,OSProc} thunk_dict::Dict{Int, WeakThunk} @@ -98,7 +97,7 @@ struct ComputeState halt::Base.Event lock::ReentrantLock futures::Dict{Thunk, Vector{ThunkFuture}} - errored::WeakKeyDict{Thunk,Bool} + errored::Dict{Thunk,Bool} thunks_to_delete::Set{Thunk} chan::RemoteChannel{Channel{AnyTaskResult}} end @@ -110,8 +109,7 @@ function start_state(deps::Dict, node_order, chan) OneToMany(), deps, Vector{Thunk}(undef, 0), - WeakKeyDict{Thunk, Any}(), - WeakKeyDict{Thunk, Nothing}(), + Dict{Thunk, Nothing}(), Set{Thunk}(), Dict{Thunk,OSProc}(), Dict{Int, WeakThunk}(), @@ -128,13 +126,13 @@ function start_state(deps::Dict, node_order, chan) Base.Event(), ReentrantLock(), Dict{Thunk, Vector{ThunkFuture}}(), - WeakKeyDict{Thunk,Bool}(), + Dict{Thunk,Bool}(), Set{Thunk}(), chan) for k in sort(collect(keys(deps)), by=node_order) if istask(k) - waiting = Set{Thunk}(Iterators.filter(istask, k.options.syncdeps)) + waiting = Set{Thunk}(Iterators.map(syncdep->unwrap_weak_checked(something(syncdep.thunk)), k.options.syncdeps)) if isempty(waiting) push!(state.ready, k) else @@ -223,13 +221,16 @@ function init_proc(state, p, log_sink) end function _cleanup_proc(uid, log_sink) empty!(CHUNK_CACHE) # FIXME: Should be keyed on uid! 
- proc_states(uid) do states - for (proc, state) in states + states = proc_states(uid) + MemPool.lock_read(states.lock) do + for (proc, state) in states.dict istate = state.state istate.done[] = true notify(istate.reschedule) end - empty!(states) + end + MemPool.lock(states.lock) do + empty!(states.dict) end end function cleanup_proc(state, p, log_sink) @@ -588,6 +589,11 @@ end Dagger.populate_defaults!(options, sig) # Calculate scope + if options.exec_scope !== nothing + # Bypass scope calculation if it's been done for us already + scope = options.exec_scope + @goto scope_computed + end scope = constrain(@something(options.compute_scope, options.scope, DefaultScope()), @something(options.result_scope, AnyScope())) if scope isa InvalidScope @@ -614,6 +620,7 @@ end @goto pop_task end end + @label scope_computed input_procs = @reusable_vector :schedule!_input_procs Processor OSProc() 32 input_procs_cleanup = @reuse_defer_cleanup empty!(input_procs) @@ -772,6 +779,7 @@ end function task_delete!(state, thunk) clear_result!(state, thunk) delete!(state.valid, thunk) + delete!(state.errored, thunk) delete!(state.thunk_dict, thunk.id) end @@ -993,19 +1001,72 @@ struct ProcessorState runner::Task end -const PROCESSOR_TASK_STATE = LockedObject(Dict{UInt64,Dict{Processor,ProcessorState}}()) +const PROCESSOR_TASK_STATE_LOCK = MemPool.ReadWriteLock() +struct ProcessorStateDict + lock::MemPool.ReadWriteLock + dict::Dict{Processor,ProcessorState} + ProcessorStateDict() = new(MemPool.ReadWriteLock(), Dict{Processor,ProcessorState}()) +end +const PROCESSOR_TASK_STATE = Dict{UInt64,ProcessorStateDict}() -function proc_states(f::Base.Callable, uid::UInt64) - lock(PROCESSOR_TASK_STATE) do all_states - if !haskey(all_states, uid) - all_states[uid] = Dict{Processor,ProcessorState}() +function proc_states(uid::UInt64=Dagger.get_tls().sch_uid) + states = MemPool.lock_read(PROCESSOR_TASK_STATE_LOCK) do + if haskey(PROCESSOR_TASK_STATE, uid) + return PROCESSOR_TASK_STATE[uid] + end + return nothing + end + if states === nothing + states = MemPool.lock(PROCESSOR_TASK_STATE_LOCK) do + dict = ProcessorStateDict() + PROCESSOR_TASK_STATE[uid] = dict + return dict + end + end + return states +end +function proc_states_values(uid::UInt64=Dagger.get_tls().sch_uid) + states = proc_states(uid) + return MemPool.lock_read(states.lock) do + return collect(values(states.dict)) + end +end +function proc_state!(f, uid::UInt64, proc::Processor) + states = proc_states(uid) + state = MemPool.lock_read(states.lock) do + return get(states.dict, proc, nothing) + end + if state === nothing + state = f()::ProcessorState + MemPool.lock(states.lock) do + states.dict[proc] = state end - our_states = all_states[uid] - return f(our_states) end + return state end -proc_states(f::Base.Callable) = - proc_states(f, Dagger.get_tls().sch_uid) +proc_state!(f, proc::Processor) = proc_state!(f, Dagger.get_tls().sch_uid, proc) +function proc_state(uid::UInt64, proc::Processor) + states = proc_states(uid) + state = MemPool.lock_read(states.lock) do + return get(states.dict, proc, nothing) + end + if state === nothing + state = MemPool.lock(states.lock) do + state = ProcessorState(ProcessorInternalState(), Task()) + states.dict[proc] = state + return state + end + end + return state +end +proc_state(proc::Processor) = proc_state(Dagger.get_tls().sch_uid, proc) +function maybe_proc_state(uid::UInt64, proc::Processor) + states = proc_states(uid) + return MemPool.lock_read(states.lock) do + return get(states.dict, proc, nothing) + end +end 
+maybe_proc_state(proc::Processor) = maybe_proc_state(Dagger.get_tls().sch_uid, proc) task_tid_for_processor(::Processor) = nothing task_tid_for_processor(proc::Dagger.ThreadProc) = proc.tid @@ -1085,7 +1146,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Try to steal from local queues randomly # TODO: Prioritize stealing from busiest processors - states = proc_states(all_states->collect(values(all_states)), uid) + states = proc_states_values(uid) # TODO: Try to pre-allocate this P = randperm(length(states)) for state in getindex.(Ref(states), P) @@ -1186,14 +1247,10 @@ function (dts::DoTaskSpec)() bt = catch_backtrace() (CapturedException(err, bt), nothing) finally - istate = proc_states(task.sch_uid) do states - if haskey(states, to_proc) - return states[to_proc].state - end - # Processor was removed due to scheduler exit - return nothing - end - if istate !== nothing + state = maybe_proc_state(task.sch_uid, to_proc) + # state will be nothing if processor was removed due to scheduler exit + if state !== nothing + istate = state.state while true # Wait until the task has been recorded in the processor state done = lock(istate.queue) do _ @@ -1253,11 +1310,8 @@ function do_tasks(to_proc, return_queue, tasks) ctx_vars = first(tasks).ctx_vars ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile) uid = first(tasks).sch_uid - state = proc_states(uid) do states - if haskey(states, to_proc) - return states[to_proc] - end - + start_event = nothing + state = proc_state!(uid, to_proc) do # Initialize the processor state and runner queue = PriorityQueue{TaskSpec, UInt32}() queue_locked = LockedObject(queue) @@ -1275,9 +1329,10 @@ function do_tasks(to_proc, return_queue, tasks) @static if VERSION < v"1.9" reschedule.waiter = runner end - state = states[to_proc] = ProcessorState(istate, runner) + return ProcessorState(istate, runner) + end + if start_event !== nothing notify(start_event) - return state end istate = state.state lock(istate.queue) do queue @@ -1304,7 +1359,7 @@ function do_tasks(to_proc, return_queue, tasks) # Kick other processors to make them steal # TODO: Alternatively, automatically balance work instead of blindly enqueueing - states = collect(proc_states(values, uid)) + states = proc_states_values(uid) P = randperm(length(states)) for other_state in getindex.(Ref(states), P) other_istate = other_state.state diff --git a/src/sch/dynamic.jl b/src/sch/dynamic.jl index 0b972bdf1..5e8cbe9bf 100644 --- a/src/sch/dynamic.jl +++ b/src/sch/dynamic.jl @@ -1,14 +1,6 @@ export SchedulerHaltedException export sch_handle, halt!, exec!, get_dag_ids, add_thunk! -"Identifies a thunk by its ID, and preserves the thunk in the scheduler." -struct ThunkID - id::Int - ref::Union{DRef,Nothing} -end -ThunkID(id::Int) = ThunkID(id, nothing) -Dagger.istask(::ThunkID) = true - "A handle to the scheduler, used by dynamic thunks." 
struct SchedulerHandle thunk_id::ThunkID diff --git a/src/sch/eager.jl b/src/sch/eager.jl index bd703f1ee..67e895815 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -78,9 +78,7 @@ function thunk_yield(f) h = sch_handle() tls = Dagger.get_tls() proc = Dagger.task_processor() - proc_istate = proc_states(tls.sch_uid) do states - states[proc].state - end + proc_istate = proc_state(tls.sch_uid, proc).state task_occupancy = tls.task_spec.est_occupancy # Decrease our occupancy and inform the processor to reschedule diff --git a/src/sch/util.jl b/src/sch/util.jl index 9141f9a3d..b18fc7d49 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -142,14 +142,10 @@ function schedule_dependents!(state, thunk, failed) end ctr = 0 for dep in state.waiting_data[thunk] - @dagdebug dep :schedule "Checking dependent" dep_isready = false if haskey(state.waiting, dep) set = state.waiting[dep] thunk in set && pop!(set, thunk) - if length(set) > 0 - @dagdebug dep :schedule "Dependent has $(length(set)) upstreams" - end dep_isready = isempty(set) if dep_isready delete!(state.waiting, dep) @@ -175,8 +171,9 @@ end Prepares the scheduler to schedule `thunk`. Will mark `thunk` as ready if its inputs are satisfied. """ -function reschedule_syncdeps!(state, thunk, seen=nothing) - Dagger.maybe_take_or_alloc!(RESCHEDULE_SYNCDEPS_SEEN_CACHE[], seen) do seen +function reschedule_syncdeps!(state, thunk) + #Dagger.maybe_take_or_alloc!(RESCHEDULE_SYNCDEPS_SEEN_CACHE[], seen) do seen + seen = Vector{Thunk}() #=FIXME:REALLOC=# to_visit = Thunk[thunk] while !isempty(to_visit) @@ -203,13 +200,14 @@ function reschedule_syncdeps!(state, thunk, seen=nothing) end w = get!(()->Set{Thunk}(), state.waiting, thunk) if thunk.options.syncdeps !== nothing - for input in thunk.options.syncdeps - input = unwrap_weak_checked(input) - istask(input) && input in seen && continue + for weak_input in thunk.options.syncdeps + @assert weak_input isa Dagger.ThunkSyncdep && weak_input.thunk !== nothing + input = unwrap_weak_checked(weak_input.thunk::WeakThunk)::Thunk + input in seen && continue # Unseen push!(get!(()->Set{Thunk}(), state.waiting_data, input), thunk) - istask(input) || continue + #istask(input) || continue # Unseen task if get(state.errored, input, false) @@ -232,9 +230,10 @@ function reschedule_syncdeps!(state, thunk, seen=nothing) end end end - end + #end end -const RESCHEDULE_SYNCDEPS_SEEN_CACHE = TaskLocalValue{ReusableCache{Set{Thunk},Nothing}}(()->ReusableCache(Set{Thunk}, nothing, 1)) +# N.B. Vector is faster than Set for small collections (which are probably most common) +const RESCHEDULE_SYNCDEPS_SEEN_CACHE = TaskLocalValue{ReusableCache{Vector{Thunk},Nothing}}(()->ReusableCache(Vector{Thunk}, nothing, 1)) "Marks `thunk` and all dependent thunks as failed." 
function set_failed!(state, origin, thunk=origin) diff --git a/src/submission.jl b/src/submission.jl index 32cdc6d05..1befffa43 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -74,7 +74,7 @@ const UID_TO_TID_CACHE = TaskLocalValue{ReusableCache{Dict{UInt64,Int},Nothing}} uid, future = payload.uid, payload.future fargs, options, reschedule = payload.fargs, payload.options, payload.reschedule - id = next_id() + id = Int(uid) @maybelog ctx timespan_start(ctx, :add_thunk, (;thunk_id=id), (;f=fargs[1], args=fargs[2:end], options, uid)) @@ -82,7 +82,7 @@ const UID_TO_TID_CACHE = TaskLocalValue{ReusableCache{Dict{UInt64,Int},Nothing}} old_fargs_cleanup = @reuse_defer_cleanup empty!(old_fargs) append!(old_fargs, Iterators.map(copy, fargs)) - syncdeps_vec = @reusable_vector :eager_submit_interal!_syncdeps_vec Any nothing 32 + syncdeps_vec = @reusable_vector :eager_submit_interal!_syncdeps_vec ThunkSyncdep ThunkSyncdep() 32 syncdeps_vec_cleanup = @reuse_defer_cleanup empty!(syncdeps_vec) if options.syncdeps !== nothing append!(syncdeps_vec, options.syncdeps) @@ -143,12 +143,16 @@ const UID_TO_TID_CACHE = TaskLocalValue{ReusableCache{Dict{UInt64,Int},Nothing}} @lock state.lock begin @inbounds syncdeps_vec[idx] = state.thunk_dict[tid] end + elseif dep isa ThunkSyncdep + @assert dep.id !== nothing && dep.thunk === nothing + thunk = @lock state.lock state.thunk_dict[dep.id.id] + @inbounds syncdeps_vec[idx] = ThunkSyncdep(thunk) end end end if !isempty(syncdeps_vec) || any(arg->istask(value(arg)), fargs) if options.syncdeps === nothing - options.syncdeps = Set{Any}() + options.syncdeps = Set{ThunkSyncdep}() else empty!(options.syncdeps) end @@ -158,7 +162,7 @@ const UID_TO_TID_CACHE = TaskLocalValue{ReusableCache{Dict{UInt64,Int},Nothing}} end for arg in fargs if istask(value(arg)) - push!(syncdeps, value(arg)) + push!(syncdeps, ThunkSyncdep(value(arg))) end end end @@ -205,6 +209,7 @@ const UID_TO_TID_CACHE = TaskLocalValue{ReusableCache{Dict{UInt64,Int},Nothing}} end end + @assert options.syncdeps === nothing || all(dep->dep isa Dagger.ThunkSyncdep && dep.thunk isa Dagger.WeakThunk, options.syncdeps) @maybelog ctx timespan_finish(ctx, :add_thunk, (;thunk_id=id), (;f=fargs[1], args=fargs[2:end], options, uid)) return thunk_id @@ -281,22 +286,6 @@ function eager_process_args_submission_to_local!(id_map, spec_pairs::Vector{Pair eager_process_args_submission_to_local!(id_map, spec_pair) end end -function eager_process_options_submission_to_local!(id_map, options::Options) - if options.syncdeps !== nothing - raw_syncdeps = options.syncdeps - syncdeps = Set{Any}() - for raw_dep in raw_syncdeps - if raw_dep isa DTask - push!(syncdeps, Sch.ThunkID(id_map[raw_dep.uid], raw_dep.thunk_ref)) - elseif raw_dep isa Sch.ThunkID - push!(syncdeps, raw_dep) - else - error("Invalid syncdep type: $(typeof(raw_dep))") - end - end - options.syncdeps = syncdeps - end -end function DTaskMetadata(spec::DTaskSpec) f = value(spec.fargs[1]) @@ -323,7 +312,6 @@ function eager_launch!((spec, task)::Pair{DTaskSpec,DTask}) # Lookup DTask -> ThunkID lock(Sch.EAGER_ID_MAP) do id_map eager_process_args_submission_to_local!(id_map, spec=>task) - eager_process_options_submission_to_local!(id_map, spec.options) end # Submit the task diff --git a/src/thunk.jl b/src/thunk.jl index e4299aae1..cf45ef2a5 100644 --- a/src/thunk.jl +++ b/src/thunk.jl @@ -4,7 +4,7 @@ const ID_COUNTER = Threads.Atomic{Int}(1) next_id() = Threads.atomic_add!(ID_COUNTER, 1) const EMPTY_ARGS = Argument[] -const EMPTY_SYNCDEPS = Set{Any}() +const 
EMPTY_SYNCDEPS = Set{ThunkSyncdep}() Base.@kwdef mutable struct ThunkSpec fargs::Vector{Argument} = EMPTY_ARGS id::Int = 0 @@ -100,18 +100,18 @@ function Thunk(f, xs...; end spec.options = options::Options if options.syncdeps === nothing - options.syncdeps = Set{Any}() + options.syncdeps = Set{ThunkSyncdep}() end syncdeps_set = options.syncdeps for idx in 2:length(spec.fargs) x = value(spec.fargs[idx]) - if is_task_or_chunk(x) - push!(syncdeps_set, x) + if istask(x) + push!(syncdeps_set, ThunkSyncdep(x)) end end if syncdeps !== nothing for dep in syncdeps - push!(syncdeps_set, dep) + push!(syncdeps_set, ThunkSyncdep(dep)) end end spec.id = id @@ -134,6 +134,8 @@ function Base.getproperty(thunk::Thunk, field::Symbol) return getfield(thunk, field) end end +Base.convert(::Type{ThunkSyncdep}, thunk::Thunk) = + ThunkSyncdep(nothing, thunk) function affinity(t::Thunk) if t.affinity !== nothing @@ -235,12 +237,11 @@ istask(::WeakThunk) = true task_id(t::WeakThunk) = task_id(unwrap_weak_checked(t)) unwrap_weak(t::WeakThunk) = t.x.value unwrap_weak(t) = t -function unwrap_weak_checked(t::WeakThunk) - t = unwrap_weak(t) - @assert t !== nothing - t +function unwrap_weak_checked(t) + t_val = unwrap_weak(t) + @assert !isweak(t) || t_val !== nothing + return t_val end -unwrap_weak_checked(t) = t wrap_weak(t::Thunk) = WeakThunk(t) wrap_weak(t::WeakThunk) = t wrap_weak(t) = t @@ -250,6 +251,8 @@ isweak(t) = true Base.show(io::IO, t::WeakThunk) = (print(io, "~"); Base.show(io, t.x.value)) Base.convert(::Type{WeakThunk}, t::Thunk) = WeakThunk(t) chunktype(t::WeakThunk) = chunktype(unwrap_weak_checked(t)) +Base.convert(::Type{ThunkSyncdep}, t::WeakThunk) = ThunkSyncdep(nothing, t) +ThunkSyncdep(t::WeakThunk) = ThunkSyncdep(nothing, t) "A summary of the data contained in a Thunk, which can be safely serialized." struct ThunkSummary diff --git a/src/thunkid.jl b/src/thunkid.jl new file mode 100644 index 000000000..4df63abf7 --- /dev/null +++ b/src/thunkid.jl @@ -0,0 +1,20 @@ + +"Identifies a thunk by its ID, and preserves the thunk in the scheduler." 
+struct ThunkID + id::Int + ref::Union{DRef,Nothing} +end +ThunkID(id::Int) = ThunkID(id, nothing) +istask(::ThunkID) = true + +struct ThunkSyncdep + id::Union{ThunkID,Nothing} + thunk +end +ThunkSyncdep() = ThunkSyncdep(nothing, nothing) +ThunkSyncdep(id::ThunkID) = ThunkSyncdep(id, nothing) +ThunkSyncdep(x) = convert(ThunkSyncdep, x) +Base.getindex(syncdep::ThunkSyncdep) = @something(syncdep.id, syncdep.thunk) +Base.convert(::Type{ThunkSyncdep}, id::ThunkID) = ThunkSyncdep(id, nothing) +unwrap_weak(t::ThunkSyncdep) = unwrap_weak(t.thunk) +istask(::ThunkSyncdep) = true diff --git a/src/utils/lfucache.jl b/src/utils/lfucache.jl new file mode 100644 index 000000000..43c4081d5 --- /dev/null +++ b/src/utils/lfucache.jl @@ -0,0 +1,23 @@ +struct BasicLFUCache{K,V} + cache::Dict{K,V} + freq::Dict{K,Int} + max_size::Int + + BasicLFUCache{K,V}(max_size::Int) where {K,V} = new(Dict{K,V}(), Dict{K,Int}(), max_size) +end +function Base.get!(f, cache::BasicLFUCache{K,V}, key::K) where {K,V} + if haskey(cache.cache, key) + cache.freq[key] += 1 + return cache.cache[key] + end + val = f()::V + cache.cache[key] = val + cache.freq[key] = 1 + if length(cache.cache) > cache.max_size + # Find the least frequently used key + _, lfu_key::K = findmin(cache.freq) + delete!(cache.cache, lfu_key) + delete!(cache.freq, lfu_key) + end + return val +end \ No newline at end of file diff --git a/src/utils/logging-events.jl b/src/utils/logging-events.jl index 07111e254..7d05946c6 100644 --- a/src/utils/logging-events.jl +++ b/src/utils/logging-events.jl @@ -172,8 +172,8 @@ init_similar(::TaskArgumentMoves) = TaskArgumentMoves() function (ta::TaskArgumentMoves)(ev::Event{:start}) if ev.category == :move data = ev.timeline.data - if ismutable(data) - thunk_id = ev.id.thunk_id::Int + thunk_id = ev.id.thunk_id::Int + if ismutable(data) && thunk_id != 0 # Ignore Datadeps moves, because we don't have TIDs for them position = Dagger.raw_position(ev.id.position::Dagger.ArgPosition)::Union{Symbol,Int} d = get!(Dict{Union{Int,Symbol},Dagger.LoggedMutableObject}, ta.pre_move_args, thunk_id) d[position] = Dagger.objectid_or_chunkid(data) @@ -195,7 +195,7 @@ function (ta::TaskArgumentMoves)(ev::Event{:finish}) else @warn "No TID $(thunk_id), Position $(position)" end - else + elseif thunk_id != 0 @warn "No TID $(thunk_id)" end end @@ -227,32 +227,24 @@ end Records the dependencies of each submitted task. 
""" struct TaskDependencies end -function (::TaskDependencies)(ev::Event{:start}) +(td::TaskDependencies)(ev::Event{:start}) = nothing +function (::TaskDependencies)(ev::Event{:finish}) local deps_tids::Vector{Int} function get_deps!(deps) for dep in deps + @assert dep isa Dagger.ThunkSyncdep && dep.thunk isa Dagger.WeakThunk dep = Dagger.unwrap_weak_checked(dep) - if dep isa Dagger.Thunk || dep isa Dagger.Sch.ThunkID - push!(deps_tids, dep.id) - elseif dep isa Dagger.DTask && myid() == 1 - tid = lock(Dagger.Sch.EAGER_ID_MAP) do id_map - id_map[dep.uid] - end - push!(deps_tids, tid) - else - @warn "Unexpected dependency type: $dep" - end + @assert dep isa Dagger.Thunk + push!(deps_tids, dep.id) end end if ev.category == :add_thunk deps_tids = Int[] - get_deps!(Iterators.filter(Dagger.istask, Iterators.map(Dagger.value, ev.timeline.args))) get_deps!(@something(ev.timeline.options.syncdeps, Set())) return ev.id.thunk_id => deps_tids end return end -(td::TaskDependencies)(ev::Event{:finish}) = nothing """ TaskUIDtoTID diff --git a/src/utils/scopes.jl b/src/utils/scopes.jl index 2b310e4ce..949ae2276 100644 --- a/src/utils/scopes.jl +++ b/src/utils/scopes.jl @@ -33,8 +33,7 @@ function compatible_processors(scope::AbstractScope, procs::Vector{<:Processor}) gproc_scope = ProcessScope(gproc) if !isa(constrain(scope, gproc_scope), InvalidScope) for proc in get_processors(gproc) - proc_scope = ExactScope(proc) - if !isa(constrain(scope, proc_scope), InvalidScope) + if proc_in_scope(proc, scope) push!(compat_procs, proc) end end diff --git a/src/utils/viz.jl b/src/utils/viz.jl index 14bdc3d02..605d5655d 100644 --- a/src/utils/viz.jl +++ b/src/utils/viz.jl @@ -88,15 +88,30 @@ function logs_to_dot(logs::Dict; disconnected=false, show_data::Bool=true, id = logs[w][:id][idx] if category == :add_thunk && kind == :start id::NamedTuple - taskdeps = logs[w][:taskdeps][idx]::Pair{Int,Vector{Int}} + tid = id.thunk_id::Int taskname = logs[w][:taskfuncnames][idx]::String - tid, deps = taskdeps v = get!(tid_to_vertex, tid) do add_vertex!(g) tid_to_vertex[tid] = nv(g) nv(g) end tid_to_auto_name[tid] = taskname + if haskey(logs[w], :taskuidtotid) + uid_tid = logs[w][:taskuidtotid][idx] + if uid_tid !== nothing + uid, tid = uid_tid::Pair{UInt,Int} + uid_to_tid[uid] = tid + end + end + elseif category == :add_thunk && kind == :finish + id::NamedTuple + taskdeps = logs[w][:taskdeps][idx]::Pair{Int,Vector{Int}} + tid, deps = taskdeps + v = get!(tid_to_vertex, tid) do + add_vertex!(g) + tid_to_vertex[tid] = nv(g) + nv(g) + end for dep in deps dep_v = get!(tid_to_vertex, dep) do add_vertex!(g) @@ -105,13 +120,6 @@ function logs_to_dot(logs::Dict; disconnected=false, show_data::Bool=true, end add_edge!(g, dep_v, v) end - if haskey(logs[w], :taskuidtotid) - uid_tid = logs[w][:taskuidtotid][idx] - if uid_tid !== nothing - uid, tid = uid_tid::Pair{UInt,Int} - uid_to_tid[uid] = tid - end - end elseif category == :compute && kind == :start id::NamedTuple tid = id.thunk_id diff --git a/test/datadeps.jl b/test/datadeps.jl index 6e7d25b0c..cd83be95f 100644 --- a/test/datadeps.jl +++ b/test/datadeps.jl @@ -98,7 +98,7 @@ function taskdeps_for_task(logs::Dict{Int,<:Dict}, tid::Int) _logs = logs[w] for idx in 1:length(_logs[:core]) core_log = _logs[:core][idx] - if core_log.category == :add_thunk && core_log.kind == :start + if core_log.category == :add_thunk && core_log.kind == :finish taskdeps = _logs[:taskdeps][idx]::Pair{Int,Vector{Int}} if taskdeps[1] == tid return taskdeps[2] diff --git a/test/extlang/python.jl 
b/test/extlang/python.jl
index 31a64c1e3..64e6d62c2 100644
--- a/test/extlang/python.jl
+++ b/test/extlang/python.jl
@@ -1,8 +1,7 @@
-using PythonCall
 using CondaPkg
-
 CondaPkg.add("numpy")
+using PythonCall
 
 np = pyimport("numpy")
 
 # Restart scheduler to see new methods
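
Note: for context on the `BasicLFUCache` that this patch factors out into `src/utils/lfucache.jl` and reuses in `distribute_tasks!` (the new `proc_to_scope_lfu`) to memoize the processor-to-scope computation, the following is a minimal standalone sketch of the pattern. The cache body mirrors the new file; `expensive_lookup` and the integer keys are purely illustrative stand-ins for the real `UnionScope`/`constrain` work and are not part of Dagger's API.

```julia
# Sketch of the LFU-cache memoization pattern used above (assumptions noted in comments).
struct BasicLFUCache{K,V}
    cache::Dict{K,V}
    freq::Dict{K,Int}
    max_size::Int

    BasicLFUCache{K,V}(max_size::Int) where {K,V} =
        new(Dict{K,V}(), Dict{K,Int}(), max_size)
end

function Base.get!(f, cache::BasicLFUCache{K,V}, key::K) where {K,V}
    if haskey(cache.cache, key)
        # Cache hit: bump the frequency counter and return the stored value
        cache.freq[key] += 1
        return cache.cache[key]
    end
    # Cache miss: compute the value once via the supplied closure
    val = f()::V
    cache.cache[key] = val
    cache.freq[key] = 1
    if length(cache.cache) > cache.max_size
        # Evict the least frequently used entry (findmin on a Dict returns (value, key))
        _, lfu_key = findmin(cache.freq)
        delete!(cache.cache, lfu_key)
        delete!(cache.freq, lfu_key)
    end
    return val
end

# Hypothetical expensive computation standing in for the scope-constraining
# work that proc_to_scope_lfu avoids repeating per processor.
expensive_lookup(key::Int) = (sleep(0.01); key^2)

cache = BasicLFUCache{Int,Int}(4)
for key in (1, 2, 1, 3, 1, 2)
    val = get!(cache, key) do   # closure only runs on a miss
        expensive_lookup(key)
    end
    @show key => val
end
@show cache.freq
```

The `do`-block form maps onto the `get!(f, cache, key)` method above, which is the same call shape used for `SIGNATURE_DEFAULT_CACHE` and the new per-processor scope cache: on a hit only the counter is updated, and on a miss the closure runs exactly once before any eviction check.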