From 90a974f2457dc6a8d65afc4c66653bc44aa00ae3 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 30 Nov 2023 20:42:50 -0700 Subject: [PATCH 01/56] Add metadata to EagerThunk --- Project.toml | 1 + src/dtask.jl | 14 +++++++++++++- src/submission.jl | 14 +++++++++++++- 3 files changed, 27 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index fd7508cd7..439d89081 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" +Mmap = "a63ad114-7e13-5084-954f-fe012c677804" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" diff --git a/src/dtask.jl b/src/dtask.jl index 68f2d3c1b..98f74005a 100644 --- a/src/dtask.jl +++ b/src/dtask.jl @@ -39,6 +39,16 @@ end Options(;options...) = Options((;options...)) Options(options...) = Options((;options...)) +""" + DTaskMetadata + +Represents some useful metadata pertaining to a `DTask`: +- `return_type::Type` - The inferred return type of the task +""" +mutable struct DTaskMetadata + return_type::Type +end + """ DTask @@ -50,9 +60,11 @@ more details. mutable struct DTask uid::UInt future::ThunkFuture + metadata::DTaskMetadata finalizer_ref::DRef thunk_ref::DRef - DTask(uid, future, finalizer_ref) = new(uid, future, finalizer_ref) + + DTask(uid, future, metadata, finalizer_ref) = new(uid, future, metadata, finalizer_ref) end const EagerThunk = DTask diff --git a/src/submission.jl b/src/submission.jl index 7312e378d..cbeb2a795 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -218,15 +218,27 @@ function eager_process_options_submission_to_local(id_map, options::NamedTuple) return options end end + +function DTaskMetadata(spec::DTaskSpec) + f = chunktype(spec.f).instance + arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args)) + return_type = Base._return_type(f, Base.to_tuple_type(arg_types)) + return DTaskMetadata(return_type) +end + function eager_spawn(spec::DTaskSpec) # Generate new DTask uid = eager_next_id() future = ThunkFuture() + metadata = DTaskMetadata(spec) finalizer_ref = poolset(DTaskFinalizer(uid); device=MemPool.CPURAMDevice()) # Create unlaunched DTask - return DTask(uid, future, finalizer_ref) + return DTask(uid, future, metadata, finalizer_ref) end + +chunktype(t::DTask) = t.metadata.return_type + function eager_launch!((spec, task)::Pair{DTaskSpec,DTask}) # Assign a name, if specified eager_assign_name!(spec, task) From cbac6056a013394169ffbfdb1efe3bdbc254c717 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 30 Nov 2023 20:44:14 -0700 Subject: [PATCH 02/56] Sch: Allow occupancy key to be Any --- src/sch/util.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/sch/util.jl b/src/sch/util.jl index e81703db5..cd006838b 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -406,12 +406,19 @@ function has_capacity(state, p, gp, time_util, alloc_util, occupancy, sig) else get(state.signature_alloc_cost, sig, UInt64(0)) end::UInt64 - est_occupancy = if occupancy !== nothing && haskey(occupancy, T) - # Clamp to 0-1, and scale between 0 and `typemax(UInt32)` - Base.unsafe_trunc(UInt32, clamp(occupancy[T], 0, 1) * typemax(UInt32)) - else - typemax(UInt32) - end::UInt32 + est_occupancy::UInt32 = typemax(UInt32) + if occupancy !== nothing + occ = nothing + if haskey(occupancy, T) + occ = occupancy[T] + elseif haskey(occupancy, Any) + occ = occupancy[Any] + end + if occ !== nothing + # Clamp to 0-1, and scale between 0 and `typemax(UInt32)` + est_occupancy = Base.unsafe_trunc(UInt32, clamp(occ, 0, 1) * typemax(UInt32)) + end + end #= FIXME: Estimate if cached data can be swapped to storage storage = storage_resource(p) real_alloc_util = state.worker_storage_pressure[gp][storage] From e441bd06d2057b097ec3f8418fa26312d0de29b3 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 12 Sep 2023 10:56:47 -0500 Subject: [PATCH 03/56] Add streaming API --- Project.toml | 1 - docs/make.jl | 1 + docs/src/streaming.md | 105 +++++++++++ src/Dagger.jl | 5 + src/sch/eager.jl | 7 + src/stream-buffers.jl | 202 ++++++++++++++++++++ src/stream-fetchers.jl | 24 +++ src/stream.jl | 418 +++++++++++++++++++++++++++++++++++++++++ 8 files changed, 762 insertions(+), 1 deletion(-) create mode 100644 docs/src/streaming.md create mode 100644 src/stream-buffers.jl create mode 100644 src/stream-fetchers.jl create mode 100644 src/stream.jl diff --git a/Project.toml b/Project.toml index 439d89081..fd7508cd7 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,6 @@ Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MemPool = "f9f48841-c794-520a-933b-121f7ba6ed94" -Mmap = "a63ad114-7e13-5084-954f-fe012c677804" OnlineStats = "a15396b6-48d5-5d58-9928-6d29437db91e" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Profile = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79" diff --git a/docs/make.jl b/docs/make.jl index c21c03f2d..8f1f97f5c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -22,6 +22,7 @@ makedocs(; "Task Spawning" => "task-spawning.md", "Data Management" => "data-management.md", "Distributed Arrays" => "darray.md", + "Streaming Tasks" => "streaming.md", "Scopes" => "scopes.md", "Processors" => "processors.md", "Task Queues" => "task-queues.md", diff --git a/docs/src/streaming.md b/docs/src/streaming.md new file mode 100644 index 000000000..0a13a1472 --- /dev/null +++ b/docs/src/streaming.md @@ -0,0 +1,105 @@ +# Streaming Tasks + +Dagger tasks have a limited lifetime - they are created, execute, finish, and +are eventually destroyed when they're no longer needed. Thus, if one wants +to run the same kind of computations over and over, one might re-create a +similar set of tasks for each unit of data that needs processing. + +This might be fine for computations which take a long time to run (thus +dwarfing the cost of task creation, which is quite small), or when working with +a limited set of data, but this approach is not great for doing lots of small +computations on a large (or endless) amount of data. For example, processing +image frames from a webcam, reacting to messages from a message bus, reading +samples from a software radio, etc. All of these tasks are better suited to a +"streaming" model of data processing, where data is simply piped into a +continuously-running task (or DAG of tasks) forever, or until the data runs +out. + +Thankfully, if you have a problem which is best modeled as a streaming system +of tasks, Dagger has you covered! Building on its support for +["Task Queues"](@ref), Dagger provides a means to convert an entire DAG of +tasks into a streaming DAG, where data flows into and out of each task +asynchronously, using the `spawn_streaming` function: + +```julia +Dagger.spawn_streaming() do # enters a streaming region + vals = Dagger.@spawn rand() + print_vals = Dagger.@spawn println(vals) +end # exits the streaming region, and starts the DAG running +``` + +In the above example, `vals` is a Dagger task which has been transformed to run +in a streaming manner - instead of just calling `rand()` once and returning its +result, it will re-run `rand()` endlessly, continuously producing new random +values. In typical Dagger style, `print_vals` is a Dagger task which depends on +`vals`, but in streaming form - it will continuously `println` the random +values produced from `vals`. Both tasks will run forever, and will run +efficiently, only doing the work necessary to generate, transfer, and consume +values. + +As the comments point out, `spawn_streaming` creates a streaming region, during +which `vals` and `print_vals` are created and configured. Both tasks are halted +until `spawn_streaming` returns, allowing large DAGs to be built all at once, +without any task losing a single value. If desired, streaming regions can be +connected, although some values might be lost while tasks are being connected: + +```julia +vals = Dagger.spawn_streaming() do + Dagger.@spawn rand() +end + +# Some values might be generated by `vals` but thrown away +# before `print_vals` is fully setup and connected to it + +print_vals = Dagger.spawn_streaming() do + Dagger.@spawn println(vals) +end +``` + +More complicated streaming DAGs can be easily constructed, without doing +anything different. For example, we can generate multiple streams of random +numbers, write them all to their own files, and print the combined results: + +```julia +Dagger.spawn_streaming() do + all_vals = [Dagger.spawn(rand) for i in 1:4] + all_vals_written = map(1:4) do i + Dagger.spawn(all_vals[i]) do val + open("results_$i.txt"; write=true, create=true, append=true) do io + println(io, repr(val)) + end + return val + end + end + Dagger.spawn(all_vals_written...) do all_vals_written... + vals_sum = sum(all_vals_written) + println(vals_sum) + end +end +``` + +If you want to stop the streaming DAG and tear it all down, you can call +`Dagger.kill!(all_vals[1])` (or `Dagger.kill!(all_vals_written[2])`, etc., the +kill propagates throughout the DAG). + +Alternatively, tasks can stop themselves from the inside with +`finish_streaming`, optionally returning a value that can be `fetch`'d. Let's +do this when our randomly-drawn number falls within some arbitrary range: + +```julia +vals = Dagger.spawn_streaming() do + Dagger.spawn() do + x = rand() + if x < 0.001 + # That's good enough, let's be done + return Dagger.finish_streaming("Finished!") + end + return x + end +end +fetch(vals) +``` + +In this example, the call to `fetch` will hang (while random numbers continue +to be drawn), until a drawn number is less than 0.001; at that point, `fetch` +will return with "Finished!", and the task `vals` will have terminated. diff --git a/src/Dagger.jl b/src/Dagger.jl index b478ece0f..c340c6579 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -67,6 +67,11 @@ include("sch/Sch.jl"); using .Sch # Data dependency task queue include("datadeps.jl") +# Streaming +include("stream-buffers.jl") +include("stream-fetchers.jl") +include("stream.jl") + # Array computations include("array/darray.jl") include("array/alloc.jl") diff --git a/src/sch/eager.jl b/src/sch/eager.jl index 87a109788..a849957ea 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -124,6 +124,13 @@ function eager_cleanup(state, uid) # N.B. cache and errored expire automatically delete!(state.thunk_dict, tid) end + remotecall_wait(1, uid) do uid + lock(EAGER_THUNK_STREAMS) do global_streams + if haskey(global_streams, uid) + delete!(global_streams, uid) + end + end + end end function _find_thunk(e::Dagger.DTask) diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl new file mode 100644 index 000000000..753f8c11c --- /dev/null +++ b/src/stream-buffers.jl @@ -0,0 +1,202 @@ +""" +A buffer that drops all elements put into it. Only to be used as the output +buffer for a task - will throw if attached as an input. +""" +struct DropBuffer{T} end +DropBuffer{T}(_) where T = DropBuffer{T}() +Base.isempty(::DropBuffer) = true +isfull(::DropBuffer) = false +Base.put!(::DropBuffer, _) = nothing +Base.take!(::DropBuffer) = error("Cannot `take!` from a DropBuffer") + +"A process-local buffer backed by a `Channel{T}`." +struct ChannelBuffer{T} + channel::Channel{T} + len::Int + count::Threads.Atomic{Int} + ChannelBuffer{T}(len::Int=1024) where T = + new{T}(Channel{T}(len), len, Threads.Atomic{Int}(0)) +end +Base.isempty(cb::ChannelBuffer) = isempty(cb.channel) +isfull(cb::ChannelBuffer) = cb.count[] == cb.len +function Base.put!(cb::ChannelBuffer{T}, x) where T + put!(cb.channel, convert(T, x)) + Threads.atomic_add!(cb.count, 1) +end +function Base.take!(cb::ChannelBuffer) + take!(cb.channel) + Threads.atomic_sub!(cb.count, 1) +end + +"A cross-worker buffer backed by a `RemoteChannel{T}`." +struct RemoteChannelBuffer{T} + channel::RemoteChannel{Channel{T}} + len::Int + count::Threads.Atomic{Int} + RemoteChannelBuffer{T}(len::Int=1024) where T = + new{T}(RemoteChannel(()->Channel{T}(len)), len, Threads.Atomic{Int}(0)) +end +Base.isempty(cb::RemoteChannelBuffer) = isempty(cb.channel) +isfull(cb::RemoteChannelBuffer) = cb.count[] == cb.len +function Base.put!(cb::RemoteChannelBuffer{T}, x) where T + put!(cb.channel, convert(T, x)) + Threads.atomic_add!(cb.count, 1) +end +function Base.take!(cb::RemoteChannelBuffer) + take!(cb.channel) + Threads.atomic_sub!(cb.count, 1) +end + +"A process-local ring buffer." +mutable struct ProcessRingBuffer{T} + read_idx::Int + write_idx::Int + @atomic count::Int + buffer::Vector{T} + function ProcessRingBuffer{T}(len::Int=1024) where T + buffer = Vector{T}(undef, len) + return new{T}(1, 1, 0, buffer) + end +end +Base.isempty(rb::ProcessRingBuffer) = (@atomic rb.count) == 0 +isfull(rb::ProcessRingBuffer) = (@atomic rb.count) == length(rb.buffer) +function Base.put!(rb::ProcessRingBuffer{T}, x) where T + len = length(rb.buffer) + while (@atomic rb.count) == len + yield() + end + to_write_idx = mod1(rb.write_idx, len) + rb.buffer[to_write_idx] = convert(T, x) + rb.write_idx += 1 + @atomic rb.count += 1 +end +function Base.take!(rb::ProcessRingBuffer) + while (@atomic rb.count) == 0 + yield() + end + to_read_idx = rb.read_idx + rb.read_idx += 1 + @atomic rb.count -= 1 + to_read_idx = mod1(to_read_idx, length(rb.buffer)) + return rb.buffer[to_read_idx] +end + +#= TODO +"A server-local ring buffer backed by shared-memory." +mutable struct ServerRingBuffer{T} + read_idx::Int + write_idx::Int + @atomic count::Int + buffer::Vector{T} + function ServerRingBuffer{T}(len::Int=1024) where T + buffer = Vector{T}(undef, len) + return new{T}(1, 1, 0, buffer) + end +end +Base.isempty(rb::ServerRingBuffer) = (@atomic rb.count) == 0 +function Base.put!(rb::ServerRingBuffer{T}, x) where T + len = length(rb.buffer) + while (@atomic rb.count) == len + yield() + end + to_write_idx = mod1(rb.write_idx, len) + rb.buffer[to_write_idx] = convert(T, x) + rb.write_idx += 1 + @atomic rb.count += 1 +end +function Base.take!(rb::ServerRingBuffer) + while (@atomic rb.count) == 0 + yield() + end + to_read_idx = rb.read_idx + rb.read_idx += 1 + @atomic rb.count -= 1 + to_read_idx = mod1(to_read_idx, length(rb.buffer)) + return rb.buffer[to_read_idx] +end +=# + +#= +"A TCP-based ring buffer." +mutable struct TCPRingBuffer{T} + read_idx::Int + write_idx::Int + @atomic count::Int + buffer::Vector{T} + function TCPRingBuffer{T}(len::Int=1024) where T + buffer = Vector{T}(undef, len) + return new{T}(1, 1, 0, buffer) + end +end +Base.isempty(rb::TCPRingBuffer) = (@atomic rb.count) == 0 +function Base.put!(rb::TCPRingBuffer{T}, x) where T + len = length(rb.buffer) + while (@atomic rb.count) == len + yield() + end + to_write_idx = mod1(rb.write_idx, len) + rb.buffer[to_write_idx] = convert(T, x) + rb.write_idx += 1 + @atomic rb.count += 1 +end +function Base.take!(rb::TCPRingBuffer) + while (@atomic rb.count) == 0 + yield() + end + to_read_idx = rb.read_idx + rb.read_idx += 1 + @atomic rb.count -= 1 + to_read_idx = mod1(to_read_idx, length(rb.buffer)) + return rb.buffer[to_read_idx] +end +=# + +#= +""" +A flexible puller which switches to the most efficient buffer type based +on the sender and receiver locations. +""" +mutable struct UniBuffer{T} + buffer::Union{ProcessRingBuffer{T}, Nothing} +end +function initialize_stream_buffer!(::Type{UniBuffer{T}}, T, send_proc, recv_proc, buffer_amount) where T + if buffer_amount == 0 + error("Return NullBuffer") + end + send_osproc = get_parent(send_proc) + recv_osproc = get_parent(recv_proc) + if send_osproc.pid == recv_osproc.pid + inner = RingBuffer{T}(buffer_amount) + elseif system_uuid(send_osproc.pid) == system_uuid(recv_osproc.pid) + inner = ProcessBuffer{T}(buffer_amount) + else + inner = RemoteBuffer{T}(buffer_amount) + end + return UniBuffer{T}(buffer_amount) +end + +struct LocalPuller{T,B} + buffer::B{T} + id::UInt + function LocalPuller{T,B}(id::UInt, buffer_amount::Integer) where {T,B} + buffer = initialize_stream_buffer!(B, T, buffer_amount) + return new{T,B}(buffer, id) + end +end +function Base.take!(pull::LocalPuller{T,B}) where {T,B} + if pull.buffer === nothing + pull.buffer = + error("Return NullBuffer") + end + value = take!(pull.buffer) +end +function initialize_input_stream!(stream::Stream{T,B}, id::UInt, send_proc::Processor, recv_proc::Processor, buffer_amount::Integer) where {T,B} + local_buffer = remotecall_fetch(stream.ref.handle.owner, stream.ref.handle, id) do ref, id + local_buffer, remote_buffer = initialize_stream_buffer!(B, T, send_proc, recv_proc, buffer_amount) + ref.buffers[id] = remote_buffer + return local_buffer + end + stream.buffer = local_buffer + return stream +end +=# diff --git a/src/stream-fetchers.jl b/src/stream-fetchers.jl new file mode 100644 index 000000000..f8660cdf1 --- /dev/null +++ b/src/stream-fetchers.jl @@ -0,0 +1,24 @@ +struct RemoteFetcher end +function stream_fetch_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_remote}, buffer::Blocal, id::UInt) where {Store_remote, Blocal} + if store_ref.handle.owner == myid() + store = fetch(store_ref)::Store_remote + while !isfull(buffer) + value = take!(store, id)::T + put!(buffer, value) + end + else + tls = Dagger.get_tls() + values = remotecall_fetch(store_ref.handle.owner, store_ref.handle, id, T, Store_remote) do store_ref, id, T, Store_remote + store = MemPool.poolget(store_ref)::Store_remote + values = T[] + while !isempty(store, id) + value = take!(store, id)::T + push!(values, value) + end + return values + end::Vector{T} + for value in values + put!(buffer, value) + end + end +end diff --git a/src/stream.jl b/src/stream.jl new file mode 100644 index 000000000..656e646f0 --- /dev/null +++ b/src/stream.jl @@ -0,0 +1,418 @@ +mutable struct StreamStore{T,B} + waiters::Vector{Int} + buffers::Dict{Int,B} + buffer_amount::Int + open::Bool + lock::Threads.Condition + StreamStore{T,B}(buffer_amount::Integer) where {T,B} = + new{T,B}(zeros(Int, 0), Dict{Int,B}(), buffer_amount, + true, Threads.Condition()) +end +tid() = Dagger.Sch.sch_handle().thunk_id.id +function uid() + thunk_id = tid() + lock(Sch.EAGER_ID_MAP) do id_map + for (uid, otid) in id_map + if thunk_id == otid + return uid + end + end + end +end +function Base.put!(store::StreamStore{T,B}, value) where {T,B} + @lock store.lock begin + if !isopen(store) + @dagdebug nothing :stream_put "[$(uid())] closed!" + throw(InvalidStateException("Stream is closed", :closed)) + end + @dagdebug nothing :stream_put "[$(uid())] adding $value" + for buffer in values(store.buffers) + while isfull(buffer) + @dagdebug nothing :stream_put "[$(uid())] buffer full, waiting" + wait(store.lock) + end + put!(buffer, value) + end + notify(store.lock) + end +end +function Base.take!(store::StreamStore, id::UInt) + @lock store.lock begin + buffer = store.buffers[id] + while isempty(buffer) && isopen(store, id) + @dagdebug nothing :stream_take "[$(uid())] no elements, not taking" + wait(store.lock) + end + @dagdebug nothing :stream_take "[$(uid())] wait finished" + if !isopen(store, id) + @dagdebug nothing :stream_take "[$(uid())] closed!" + throw(InvalidStateException("Stream is closed", :closed)) + end + unlock(store.lock) + value = try + take!(buffer) + finally + lock(store.lock) + end + @dagdebug nothing :stream_take "[$(uid())] value accepted" + notify(store.lock) + return value + end +end +Base.isempty(store::StreamStore, id::UInt) = isempty(store.buffers[id]) +isfull(store::StreamStore, id::UInt) = isfull(store.buffers[id]) +"Returns whether the store is actively open. Only check this when deciding if new values can be pushed." +Base.isopen(store::StreamStore) = store.open +"Returns whether the store is actively open, or if closing, still has remaining messages for `id`. Only check this when deciding if existing values can be taken." +function Base.isopen(store::StreamStore, id::UInt) + @lock store.lock begin + if !isempty(store.buffers[id]) + return true + end + return store.open + end +end +function Base.close(store::StreamStore) + if store.open + store.open = false + @lock store.lock notify(store.lock) + end +end +function add_waiters!(store::StreamStore{T,B}, waiters::Vector{Int}) where {T,B} + @lock store.lock begin + for w in waiters + buffer = initialize_stream_buffer(B, T, store.buffer_amount) + store.buffers[w] = buffer + end + append!(store.waiters, waiters) + notify(store.lock) + end +end +function remove_waiters!(store::StreamStore, waiters::Vector{Int}) + @lock store.lock begin + for w in waiters + delete!(store.buffers, w) + idx = findfirst(wo->wo==w, store.waiters) + deleteat!(store.waiters, idx) + end + notify(store.lock) + end +end + +mutable struct Stream{T,B} + store::Union{StreamStore{T,B},Nothing} + store_ref::Chunk + input_buffer::Union{B,Nothing} + buffer_amount::Int + function Stream{T,B}(buffer_amount::Integer=0) where {T,B} + # Creates a new output stream + store = StreamStore{T,B}(buffer_amount) + store_ref = tochunk(store) + return new{T,B}(store, store_ref, nothing, buffer_amount) + end + function Stream{B}(stream::Stream{T}, buffer_amount::Integer=0) where {T,B} + # References an existing output stream + return new{T,B}(nothing, stream.store_ref, nothing, buffer_amount) + end +end +function initialize_input_stream!(stream::Stream{T,B}) where {T,B} + stream.input_buffer = initialize_stream_buffer(B, T, stream.buffer_amount) +end + +Base.put!(stream::Stream, @nospecialize(value)) = + put!(stream.store, value) +function Base.take!(stream::Stream{T,B}, id::UInt) where {T,B} + # FIXME: Make remote fetcher configurable + stream_fetch_values!(RemoteFetcher, T, stream.store_ref, stream.input_buffer, id) + return take!(stream.input_buffer) +end +function Base.isopen(stream::Stream, id::UInt)::Bool + return remotecall_fetch(stream.store_ref.handle.owner, stream.store_ref.handle) do ref + return isopen(MemPool.poolget(ref)::StreamStore, id) + end +end +function Base.close(stream::Stream) + remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref + close(MemPool.poolget(ref)::StreamStore) + end +end +function add_waiters!(stream::Stream, waiters::Vector{Int}) + remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref + add_waiters!(MemPool.poolget(ref)::StreamStore, waiters) + end +end +add_waiters!(stream::Stream, waiter::Integer) = + add_waiters!(stream::Stream, Int[waiter]) +function remove_waiters!(stream::Stream, waiters::Vector{Int}) + remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref + remove_waiters!(MemPool.poolget(ref)::StreamStore, waiters) + end +end +remove_waiters!(stream::Stream, waiter::Integer) = + remove_waiters!(stream::Stream, Int[waiter]) + +function migrate_stream!(stream::Stream, w::Integer=myid()) + if !isdefined(MemPool, :migrate!) + @warn "MemPool migration support not enabled!\nPerformance may be degraded" maxlog=1 + return + end + + # Perform migration of the StreamStore + # MemPool will block access to the new ref until the migration completes + if stream.store_ref.handle.owner != w + # Take lock to prevent any further modifications + # N.B. Serialization automatically unlocks + remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref + lock((MemPool.poolget(ref)::StreamStore).lock) + end + + MemPool.migrate!(stream.store_ref.handle, w) + end +end + +struct StreamingTaskQueue <: AbstractTaskQueue + tasks::Vector{Pair{DTaskSpec,DTask}} + self_streams::Dict{UInt,Any} + StreamingTaskQueue() = new(Pair{DTaskSpec,DTask}[], + Dict{UInt,Any}()) +end + +function enqueue!(queue::StreamingTaskQueue, spec::Pair{DTaskSpec,DTask}) + push!(queue.tasks, spec) + initialize_streaming!(queue.self_streams, spec...) +end +function enqueue!(queue::StreamingTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) + append!(queue.tasks, specs) + for (spec, task) in specs + initialize_streaming!(queue.self_streams, spec, task) + end +end +function initialize_streaming!(self_streams, spec, task) + if !isa(spec.f, StreamingFunction) + # Adapt called function for streaming and generate output Streams + T_old = Base.uniontypes(task.metadata.return_type) + T_old = map(t->(t !== Union{} && t <: FinishStream) ? first(t.parameters) : t, T_old) + # We treat non-dominating error paths as unreachable + T_old = filter(t->t !== Union{}, T_old) + T = task.metadata.return_type = !isempty(T_old) ? Union{T_old...} : Any + output_buffer_amount = get(spec.options, :stream_output_buffer_amount, 1) + if output_buffer_amount <= 0 + throw(ArgumentError("Output buffering is required; please specify a `stream_output_buffer_amount` greater than 0")) + end + output_buffer = get(spec.options, :stream_output_buffer, ProcessRingBuffer) + stream = Stream{T,output_buffer}(output_buffer_amount) + spec.options = NamedTuple(filter(opt -> opt[1] != :stream_output_buffer && + opt[1] != :stream_output_buffer_amount, + Base.pairs(spec.options))) + self_streams[task.uid] = stream + + spec.f = StreamingFunction(spec.f, stream) + spec.options = merge(spec.options, (;occupancy=Dict(Any=>0))) + + # Register Stream globally + remotecall_wait(1, task.uid, stream) do uid, stream + lock(EAGER_THUNK_STREAMS) do global_streams + global_streams[uid] = stream + end + end + end +end + +function spawn_streaming(f::Base.Callable) + queue = StreamingTaskQueue() + result = with_options(f; task_queue=queue) + if length(queue.tasks) > 0 + finalize_streaming!(queue.tasks, queue.self_streams) + enqueue!(queue.tasks) + end + return result +end + +struct FinishStream{T,R} + value::Union{Some{T},Nothing} + result::R +end +finish_stream(value::T; result::R=nothing) where {T,R} = + FinishStream{T,R}(Some{T}(value), result) +finish_stream(; result::R=nothing) where R = + FinishStream{Union{},R}(nothing, result) + +function cancel_stream!(t::DTask) + stream = task_to_stream(t.uid) + if stream !== nothing + close(stream) + end +end + +struct StreamingFunction{F, S} + f::F + stream::S +end +chunktype(sf::StreamingFunction{F}) where F = F +function (sf::StreamingFunction)(args...; kwargs...) + @nospecialize sf args kwargs + result = nothing + thunk_id = tid() + uid = remotecall_fetch(1, thunk_id) do thunk_id + lock(Sch.EAGER_ID_MAP) do id_map + for (uid, otid) in id_map + if thunk_id == otid + return uid + end + end + end + end + + # Migrate our output stream to this worker + if sf.stream isa Stream + migrate_stream!(sf.stream) + end + + try + # TODO: This kwarg song-and-dance is required to ensure that we don't + # allocate boxes within `stream!`, when possible + kwarg_names = map(name->Val{name}(), map(first, (kwargs...,))) + kwarg_values = map(last, (kwargs...,)) + for arg in args + if arg isa Stream + initialize_input_stream!(arg) + end + end + return stream!(sf, uid, (args...,), kwarg_names, kwarg_values) + finally + # Remove ourself as a waiter for upstream Streams + streams = Set{Stream}() + for (idx, arg) in enumerate(args) + if arg isa Stream + push!(streams, arg) + end + end + for (idx, (pos, arg)) in enumerate(kwargs) + if arg isa Stream + push!(streams, arg) + end + end + for stream in streams + @dagdebug nothing :stream_close "[$uid] dropping waiter" + remove_waiters!(stream, uid) + @dagdebug nothing :stream_close "[$uid] dropped waiter" + end + + # Ensure downstream tasks also terminate + @dagdebug nothing :stream_close "[$uid] closed stream" + close(sf.stream) + end +end +# N.B We specialize to minimize/eliminate allocations +function stream!(sf::StreamingFunction, uid, + args::Tuple, kwarg_names::Tuple, kwarg_values::Tuple) + f = move(thunk_processor(), sf.f) + while true + # Get values from Stream args/kwargs + stream_args = _stream_take_values!(args, uid) + stream_kwarg_values = _stream_take_values!(kwarg_values, uid) + stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values) + + # Run a single cycle of f + stream_result = f(stream_args...; stream_kwargs...) + + # Exit streaming on graceful request + if stream_result isa FinishStream + if stream_result.value !== nothing + value = something(stream_result.value) + put!(sf.stream, value) + end + return stream_result.result + end + + # Put the result into the output stream + put!(sf.stream, stream_result) + end +end +function _stream_take_values!(args, uid) + return ntuple(length(args)) do idx + arg = args[idx] + if arg isa Stream + take!(arg, uid) + else + arg + end + end +end +@inline @generated function _stream_namedtuple(kwarg_names::Tuple, + stream_kwarg_values::Tuple) + name_ex = Expr(:tuple, map(name->QuoteNode(name.parameters[1]), kwarg_names.parameters)...) + NT = :(NamedTuple{$name_ex,$stream_kwarg_values}) + return :($NT(stream_kwarg_values)) +end +initialize_stream_buffer(B, T, buffer_amount) = B{T}(buffer_amount) + +const EAGER_THUNK_STREAMS = LockedObject(Dict{UInt,Any}()) +function task_to_stream(uid::UInt) + if myid() != 1 + return remotecall_fetch(task_to_stream, 1, uid) + end + lock(EAGER_THUNK_STREAMS) do global_streams + if haskey(global_streams, uid) + return global_streams[uid] + end + return + end +end + +function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) + stream_waiter_changes = Dict{UInt,Vector{Int}}() + + for (spec, task) in tasks + @assert haskey(self_streams, task.uid) + + # Adapt args to accept Stream output of other streaming tasks + for (idx, (pos, arg)) in enumerate(spec.args) + if arg isa DTask + # Check if this is a streaming task + if haskey(self_streams, arg.uid) + other_stream = self_streams[arg.uid] + else + other_stream = task_to_stream(arg.uid) + end + + if other_stream !== nothing + # Get input stream configs and configure input stream + input_buffer_amount = get(spec.options, :stream_input_buffer_amount, 1) + if input_buffer_amount <= 0 + throw(ArgumentError("Input buffering is required; please specify a `stream_input_buffer_amount` greater than 0")) + end + input_buffer = get(spec.options, :stream_input_buffer, ProcessRingBuffer) + # FIXME: input_fetcher = get(spec.options, :stream_input_fetcher, RemoteFetcher) + input_stream = Stream{input_buffer}(other_stream, input_buffer_amount) + + # Replace the DTask with the input Stream + spec.args[idx] = pos => other_stream + + # Add this task as a waiter for the associated output Stream + changes = get!(stream_waiter_changes, arg.uid) do + Int[] + end + push!(changes, task.uid) + end + end + end + + # Filter out all streaming options + to_filter = (:stream_input_buffer, :stream_input_buffer_amount, + :stream_output_buffer, :stream_output_buffer_amount) + spec.options = NamedTuple(filter(opt -> !(opt[1] in to_filter), + Base.pairs(spec.options))) + if haskey(spec.options, :propagates) + propagates = filter(opt -> !(opt in to_filter), + spec.options.propagates) + spec.options = merge(spec.options, (;propagates)) + end + end + + # Adjust waiter count of Streams with dependencies + for (uid, waiters) in stream_waiter_changes + stream = task_to_stream(uid) + add_waiters!(stream, waiters) + end +end From 17096fa33bfc5287e9762bc1ce2a158eedfe8353 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Wed, 13 Mar 2024 22:43:44 +0100 Subject: [PATCH 04/56] Reference Dagger.EAGER_THUNK_STREAMS explicitly --- src/sch/eager.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sch/eager.jl b/src/sch/eager.jl index a849957ea..c964b722d 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -125,7 +125,7 @@ function eager_cleanup(state, uid) delete!(state.thunk_dict, tid) end remotecall_wait(1, uid) do uid - lock(EAGER_THUNK_STREAMS) do global_streams + lock(Dagger.EAGER_THUNK_STREAMS) do global_streams if haskey(global_streams, uid) delete!(global_streams, uid) end From 563f6467a7e15583b6d00bd73f3e7e8fe9339d85 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Thu, 14 Mar 2024 10:42:33 +0100 Subject: [PATCH 05/56] Use Base.promote_op() instead of Base._return_type() return_type() is kinda broken in v1.10, see: https://github.com/JuliaLang/julia/issues/52385 In any case Base.promote_op() is the official public API for this operation so we should use it anyway. --- src/submission.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/submission.jl b/src/submission.jl index cbeb2a795..bfb8cb8be 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -222,7 +222,7 @@ end function DTaskMetadata(spec::DTaskSpec) f = chunktype(spec.f).instance arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args)) - return_type = Base._return_type(f, Base.to_tuple_type(arg_types)) + return_type = Base.promote_op(f, arg_types...) return DTaskMetadata(return_type) end From 2f29be7d6cad1cf07073403627ed1a5ddf64f35c Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Sun, 17 Mar 2024 11:55:17 +0100 Subject: [PATCH 06/56] Special-case StreamingFunction in EagerThunkMetadata() constructor This always us to handle all the other kinds of task specs. --- src/submission.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/submission.jl b/src/submission.jl index bfb8cb8be..f23539271 100644 --- a/src/submission.jl +++ b/src/submission.jl @@ -220,7 +220,7 @@ function eager_process_options_submission_to_local(id_map, options::NamedTuple) end function DTaskMetadata(spec::DTaskSpec) - f = chunktype(spec.f).instance + f = spec.f isa StreamingFunction ? spec.f.f : spec.f arg_types = ntuple(i->chunktype(spec.args[i][2]), length(spec.args)) return_type = Base.promote_op(f, arg_types...) return DTaskMetadata(return_type) From d25d6c1729621c6480e96d7b3276601db5d643b0 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Mon, 1 Apr 2024 13:42:28 +0200 Subject: [PATCH 07/56] Fix reference to task-queues.md in the docs This should get the docs building again. --- docs/src/streaming.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/streaming.md b/docs/src/streaming.md index 0a13a1472..338f5739b 100644 --- a/docs/src/streaming.md +++ b/docs/src/streaming.md @@ -17,7 +17,7 @@ out. Thankfully, if you have a problem which is best modeled as a streaming system of tasks, Dagger has you covered! Building on its support for -["Task Queues"](@ref), Dagger provides a means to convert an entire DAG of +[Task Queues](@ref), Dagger provides a means to convert an entire DAG of tasks into a streaming DAG, where data flows into and out of each task asynchronously, using the `spawn_streaming` function: From aa07cc9645417c49d9a0fc478a60dcdf8327ad26 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Mon, 1 Apr 2024 14:25:17 +0200 Subject: [PATCH 08/56] Delete Dagger.cleanup() Because it doesn't actually do anything now. --- src/compute.jl | 6 ------ src/sch/Sch.jl | 3 --- 2 files changed, 9 deletions(-) diff --git a/src/compute.jl b/src/compute.jl index f421eaccc..093b527f4 100644 --- a/src/compute.jl +++ b/src/compute.jl @@ -36,12 +36,6 @@ end Base.@deprecate gather(ctx, x) collect(ctx, x) Base.@deprecate gather(x) collect(x) -cleanup() = cleanup(Context(global_context())) -function cleanup(ctx::Context) - Sch.cleanup(ctx) - nothing -end - function get_type(s::String) local T for t in split(s, ".") diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 73bb07bf9..12d259352 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -307,9 +307,6 @@ function populate_defaults(opts::ThunkOptions, Tf, Targs) ) end -function cleanup(ctx) -end - # Eager scheduling include("eager.jl") From f8d0b8bbf6a28fa256ad658a8f090fa83e651f69 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 10 Apr 2024 10:32:11 -0700 Subject: [PATCH 09/56] streaming: Show thunk ID in logs --- src/Dagger.jl | 1 + src/stream-fetchers.jl | 3 ++- src/stream.jl | 34 ++++++++++++++++++++-------------- src/utils/dagdebug.jl | 3 ++- 4 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index c340c6579..8f7d36cc7 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -23,6 +23,7 @@ if !isdefined(Base, :ScopedValues) else import Base.ScopedValues: ScopedValue, with end +import TaskLocalValues: TaskLocalValue if !isdefined(Base, :get_extension) import Requires: @require diff --git a/src/stream-fetchers.jl b/src/stream-fetchers.jl index f8660cdf1..5c6834d6e 100644 --- a/src/stream-fetchers.jl +++ b/src/stream-fetchers.jl @@ -7,8 +7,9 @@ function stream_fetch_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_r put!(buffer, value) end else - tls = Dagger.get_tls() + thunk_id = STREAM_THUNK_ID[] values = remotecall_fetch(store_ref.handle.owner, store_ref.handle, id, T, Store_remote) do store_ref, id, T, Store_remote + STREAM_THUNK_ID[] = thunk_id store = MemPool.poolget(store_ref)::Store_remote values = T[] while !isempty(store, id) diff --git a/src/stream.jl b/src/stream.jl index 656e646f0..5702cc522 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -8,9 +8,7 @@ mutable struct StreamStore{T,B} new{T,B}(zeros(Int, 0), Dict{Int,B}(), buffer_amount, true, Threads.Condition()) end -tid() = Dagger.Sch.sch_handle().thunk_id.id -function uid() - thunk_id = tid() +function tid_to_uid(thunk_id) lock(Sch.EAGER_ID_MAP) do id_map for (uid, otid) in id_map if thunk_id == otid @@ -20,15 +18,17 @@ function uid() end end function Base.put!(store::StreamStore{T,B}, value) where {T,B} + thunk_id = STREAM_THUNK_ID[] + uid = tid_to_uid(thunk_id) @lock store.lock begin if !isopen(store) - @dagdebug nothing :stream_put "[$(uid())] closed!" + @dagdebug thunk_id :stream "[$uid] closed!" throw(InvalidStateException("Stream is closed", :closed)) end - @dagdebug nothing :stream_put "[$(uid())] adding $value" + @dagdebug thunk_id :stream "[$uid] adding $value" for buffer in values(store.buffers) while isfull(buffer) - @dagdebug nothing :stream_put "[$(uid())] buffer full, waiting" + @dagdebug thunk_id :stream "[$uid] buffer full, waiting" wait(store.lock) end put!(buffer, value) @@ -37,15 +37,17 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B} end end function Base.take!(store::StreamStore, id::UInt) + thunk_id = STREAM_THUNK_ID[] + uid = tid_to_uid(thunk_id) @lock store.lock begin buffer = store.buffers[id] while isempty(buffer) && isopen(store, id) - @dagdebug nothing :stream_take "[$(uid())] no elements, not taking" + @dagdebug thunk_id :stream "[$uid] no elements, not taking" wait(store.lock) end - @dagdebug nothing :stream_take "[$(uid())] wait finished" + @dagdebug thunk_id :stream "[$uid] wait finished" if !isopen(store, id) - @dagdebug nothing :stream_take "[$(uid())] closed!" + @dagdebug thunk_id :stream "[$uid] closed!" throw(InvalidStateException("Stream is closed", :closed)) end unlock(store.lock) @@ -54,7 +56,7 @@ function Base.take!(store::StreamStore, id::UInt) finally lock(store.lock) end - @dagdebug nothing :stream_take "[$(uid())] value accepted" + @dagdebug thunk_id :stream "[$uid] value accepted" notify(store.lock) return value end @@ -244,6 +246,8 @@ function cancel_stream!(t::DTask) end end +const STREAM_THUNK_ID = TaskLocalValue{Int}(()->0) + struct StreamingFunction{F, S} f::F stream::S @@ -252,7 +256,9 @@ chunktype(sf::StreamingFunction{F}) where F = F function (sf::StreamingFunction)(args...; kwargs...) @nospecialize sf args kwargs result = nothing - thunk_id = tid() + thunk_id = Sch.sch_handle().thunk_id.id + STREAM_THUNK_ID[] = thunk_id + # FIXME: Remove when scheduler is distributed uid = remotecall_fetch(1, thunk_id) do thunk_id lock(Sch.EAGER_ID_MAP) do id_map for (uid, otid) in id_map @@ -293,13 +299,13 @@ function (sf::StreamingFunction)(args...; kwargs...) end end for stream in streams - @dagdebug nothing :stream_close "[$uid] dropping waiter" + @dagdebug thunk_id :stream "[$uid] dropping waiter" remove_waiters!(stream, uid) - @dagdebug nothing :stream_close "[$uid] dropped waiter" + @dagdebug thunk_id :stream "[$uid] dropped waiter" end # Ensure downstream tasks also terminate - @dagdebug nothing :stream_close "[$uid] closed stream" + @dagdebug thunk_id :stream "[$uid] closed stream" close(sf.stream) end end diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 9a9d24167..1e2b625bd 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -2,7 +2,8 @@ function istask end function task_id end const DAGDEBUG_CATEGORIES = Symbol[:global, :submit, :schedule, :scope, - :take, :execute, :move, :processor, :cancel] + :take, :execute, :move, :processor, :cancel, + :stream] macro dagdebug(thunk, category, msg, args...) cat_sym = category.value @gensym id From 1be1a4115776ecaf3aa0cf331211fae04f0315e8 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 10 Apr 2024 10:33:05 -0700 Subject: [PATCH 10/56] streaming: Add tests --- test/runtests.jl | 2 +- test/streaming.jl | 94 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 test/streaming.jl diff --git a/test/runtests.jl b/test/runtests.jl index 04871f6b9..6c2114d85 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ tests = [ ("Mutation", "mutation.jl"), ("Task Queues", "task-queues.jl"), ("Datadeps", "datadeps.jl"), + ("Streaming", "streaming.jl"), ("Domain Utilities", "domain.jl"), ("Array - Allocation", "array/allocation.jl"), ("Array - Indexing", "array/indexing.jl"), @@ -85,7 +86,6 @@ else @info "Running all tests" end - using Distributed if additional_workers > 0 # We put this inside a branch because addprocs() takes a minimum of 1s to diff --git a/test/streaming.jl b/test/streaming.jl new file mode 100644 index 000000000..00a4be3f3 --- /dev/null +++ b/test/streaming.jl @@ -0,0 +1,94 @@ +@everywhere ENV["JULIA_DEBUG"] = "Dagger" + +@everywhere function rand_finite() + x = rand() + if x < 0.1 + return Dagger.finish_stream(x) + end + return x +end +function catch_interrupt(f) + try + f() + catch err + if err isa Dagger.ThunkFailedException && err.ex isa InterruptException + return + elseif err isa Dagger.Sch.SchedulingException + return + end + rethrow(err) + end +end +function test_finishes(f, message::String; ignore_timeout=false) + t = @eval Threads.@spawn @testset $message catch_interrupt($f) + if timedwait(()->istaskdone(t), 10) == :timed_out + if !ignore_timeout + @warn "Testing task timed out: $message" + end + Dagger.cancel!(;halt_sch=true, force=true) + fetch(Dagger.@spawn 1+1) + return false + end + return true +end +@testset "Basics" begin + @test test_finishes("Single task") do + local x + Dagger.spawn_streaming() do + x = Dagger.@spawn rand_finite() + end + @test fetch(x) === nothing + end + + @test !test_finishes("Single task running forever"; ignore_timeout=true) do + local x + Dagger.spawn_streaming() do + x = Dagger.spawn() do + y = rand() + sleep(1) + return y + end + end + fetch(x) + end + + @test test_finishes("Two tasks (sequential)") do + local x, y + @warn "\n\n\nStart streaming\n\n\n" + Dagger.spawn_streaming() do + x = Dagger.@spawn rand_finite() + y = Dagger.@spawn x+1 + end + @test fetch(x) === nothing + @test_throws Dagger.ThunkFailedException fetch(y) + end + + # TODO: Two tasks (parallel) + + # TODO: Three tasks (2 -> 1) and (1 -> 2) + # TODO: Four tasks (diamond) + + # TODO: With pass-through/Without result + # TODO: With pass-through/With result + # TODO: Without pass-through/Without result + + @test test_finishes("Without pass-through/With result") do + local x + Dagger.spawn_streaming() do + x = Dagger.spawn() do + x = rand() + if x < 0.1 + return Dagger.finish_stream(x; result=123) + end + return x + end + end + @test fetch(x) == 123 + end +end +# TODO: Custom stream buffers/buffer amounts +# TODO: Cross-worker streaming +# TODO: Different stream element types (immutable and mutable) + +# TODO: Zero-allocation examples +# FIXME: Streaming across threads From f58404a7139f9055eaeeea4d6ba5c38a9ccc4d82 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Tue, 9 Apr 2024 03:44:35 +0200 Subject: [PATCH 11/56] Use procs() when initializing EAGER_CONTEXT Using `myid()` with `workers()` meant that when the context was initialized with a single worker the processor list would be: `[OSProc(1), OSProc(1)]`. `procs()` will always include PID 1 and any other workers, which is what we want. --- src/sch/eager.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sch/eager.jl b/src/sch/eager.jl index c964b722d..7ccdfbb29 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -6,7 +6,7 @@ const EAGER_STATE = Ref{Union{ComputeState,Nothing}}(nothing) function eager_context() if EAGER_CONTEXT[] === nothing - EAGER_CONTEXT[] = Context([myid(),workers()...]) + EAGER_CONTEXT[] = Context(procs()) end return EAGER_CONTEXT[] end From ee11f3fc9a1f3145f922781803619dd93804231a Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Thu, 2 May 2024 12:48:48 -0700 Subject: [PATCH 12/56] streaming: Fix concurrency issues --- src/stream-fetchers.jl | 44 +++++++++++++++------------ src/stream.jl | 68 +++++++++++++++++++++++------------------- 2 files changed, 62 insertions(+), 50 deletions(-) diff --git a/src/stream-fetchers.jl b/src/stream-fetchers.jl index 5c6834d6e..cc4d942bb 100644 --- a/src/stream-fetchers.jl +++ b/src/stream-fetchers.jl @@ -1,25 +1,31 @@ struct RemoteFetcher end function stream_fetch_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_remote}, buffer::Blocal, id::UInt) where {Store_remote, Blocal} - if store_ref.handle.owner == myid() - store = fetch(store_ref)::Store_remote - while !isfull(buffer) - value = take!(store, id)::T - put!(buffer, value) + thunk_id = STREAM_THUNK_ID[] + @dagdebug thunk_id :stream "fetching values" + @label fetch_values + # FIXME: Pass buffer free space + # TODO: It would be ideal if we could wait on store.lock, but get unlocked during migration + values = MemPool.access_ref(store_ref.handle, id, T, Store_remote, thunk_id) do store, id, T, Store_remote, thunk_id + if !isopen(store) + throw(InvalidStateException("Stream is closed", :closed)) end - else - thunk_id = STREAM_THUNK_ID[] - values = remotecall_fetch(store_ref.handle.owner, store_ref.handle, id, T, Store_remote) do store_ref, id, T, Store_remote - STREAM_THUNK_ID[] = thunk_id - store = MemPool.poolget(store_ref)::Store_remote - values = T[] - while !isempty(store, id) - value = take!(store, id)::T - push!(values, value) - end - return values - end::Vector{T} - for value in values - put!(buffer, value) + @dagdebug thunk_id :stream "trying to fetch values at $(myid())" + store::Store_remote + in_store = store + STREAM_THUNK_ID[] = thunk_id + values = T[] + while !isempty(store, id) + value = take!(store, id)::T + push!(values, value) end + return values + end::Vector{T} + if isempty(values) + @goto fetch_values + end + + @dagdebug thunk_id :stream "fetched $(length(values)) values" + for value in values + put!(buffer, value) end end diff --git a/src/stream.jl b/src/stream.jl index 5702cc522..17f47e95f 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -19,16 +19,15 @@ function tid_to_uid(thunk_id) end function Base.put!(store::StreamStore{T,B}, value) where {T,B} thunk_id = STREAM_THUNK_ID[] - uid = tid_to_uid(thunk_id) @lock store.lock begin if !isopen(store) - @dagdebug thunk_id :stream "[$uid] closed!" + @dagdebug thunk_id :stream "closed!" throw(InvalidStateException("Stream is closed", :closed)) end - @dagdebug thunk_id :stream "[$uid] adding $value" + @dagdebug thunk_id :stream "adding $value" for buffer in values(store.buffers) while isfull(buffer) - @dagdebug thunk_id :stream "[$uid] buffer full, waiting" + @dagdebug thunk_id :stream "buffer full, waiting" wait(store.lock) end put!(buffer, value) @@ -38,16 +37,15 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B} end function Base.take!(store::StreamStore, id::UInt) thunk_id = STREAM_THUNK_ID[] - uid = tid_to_uid(thunk_id) @lock store.lock begin buffer = store.buffers[id] while isempty(buffer) && isopen(store, id) - @dagdebug thunk_id :stream "[$uid] no elements, not taking" + @dagdebug thunk_id :stream "no elements, not taking" wait(store.lock) end - @dagdebug thunk_id :stream "[$uid] wait finished" + @dagdebug thunk_id :stream "wait finished" if !isopen(store, id) - @dagdebug thunk_id :stream "[$uid] closed!" + @dagdebug thunk_id :stream "closed!" throw(InvalidStateException("Stream is closed", :closed)) end unlock(store.lock) @@ -56,7 +54,7 @@ function Base.take!(store::StreamStore, id::UInt) finally lock(store.lock) end - @dagdebug thunk_id :stream "[$uid] value accepted" + @dagdebug thunk_id :stream "value accepted" notify(store.lock) return value end @@ -129,46 +127,53 @@ function Base.take!(stream::Stream{T,B}, id::UInt) where {T,B} return take!(stream.input_buffer) end function Base.isopen(stream::Stream, id::UInt)::Bool - return remotecall_fetch(stream.store_ref.handle.owner, stream.store_ref.handle) do ref - return isopen(MemPool.poolget(ref)::StreamStore, id) + return MemPool.access_ref(stream.store_ref.handle, id) do store, id + return isopen(store::StreamStore, id) end end function Base.close(stream::Stream) - remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref - close(MemPool.poolget(ref)::StreamStore) + MemPool.access_ref(stream.store_ref.handle) do store + close(store::StreamStore) + return end + return end function add_waiters!(stream::Stream, waiters::Vector{Int}) - remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref - add_waiters!(MemPool.poolget(ref)::StreamStore, waiters) + MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters + add_waiters!(store::StreamStore, waiters) + return end + return end add_waiters!(stream::Stream, waiter::Integer) = add_waiters!(stream::Stream, Int[waiter]) function remove_waiters!(stream::Stream, waiters::Vector{Int}) - remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref - remove_waiters!(MemPool.poolget(ref)::StreamStore, waiters) + MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters + remove_waiters!(store::StreamStore, waiters) + return end + return end remove_waiters!(stream::Stream, waiter::Integer) = remove_waiters!(stream::Stream, Int[waiter]) function migrate_stream!(stream::Stream, w::Integer=myid()) - if !isdefined(MemPool, :migrate!) - @warn "MemPool migration support not enabled!\nPerformance may be degraded" maxlog=1 - return - end - # Perform migration of the StreamStore # MemPool will block access to the new ref until the migration completes + # FIXME: Do this with MemPool.access_ref, in case stream was already migrated if stream.store_ref.handle.owner != w - # Take lock to prevent any further modifications - # N.B. Serialization automatically unlocks - remotecall_wait(stream.store_ref.handle.owner, stream.store_ref.handle) do ref - lock((MemPool.poolget(ref)::StreamStore).lock) + new_store_ref = MemPool.migrate!(stream.store_ref.handle, w; pre_migration=store->begin + # Lock store to prevent any further modifications + # N.B. Serialization automatically unlocks the migrated copy + lock((store::StreamStore).lock) + end, post_migration=store->begin + # Unlock the store + # FIXME: Indicate to all waiters that this store is dead + unlock((store::StreamStore).lock) + end) + if w == myid() + stream.store = MemPool.access_ref(identity, new_store_ref; local_only=true) end - - MemPool.migrate!(stream.store_ref.handle, w) end end @@ -272,6 +277,7 @@ function (sf::StreamingFunction)(args...; kwargs...) # Migrate our output stream to this worker if sf.stream isa Stream migrate_stream!(sf.stream) + @dagdebug thunk_id :stream "Migration complete" end try @@ -299,13 +305,13 @@ function (sf::StreamingFunction)(args...; kwargs...) end end for stream in streams - @dagdebug thunk_id :stream "[$uid] dropping waiter" + @dagdebug thunk_id :stream "dropping waiter" remove_waiters!(stream, uid) - @dagdebug thunk_id :stream "[$uid] dropped waiter" + @dagdebug thunk_id :stream "dropped waiter" end # Ensure downstream tasks also terminate - @dagdebug thunk_id :stream "[$uid] closed stream" + @dagdebug thunk_id :stream "closed stream" close(sf.stream) end end From 0d70835f6816b45b31f110e47b1960f4ae8edb40 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Fri, 24 May 2024 18:29:48 +0200 Subject: [PATCH 13/56] Add a --verbose option to runtests.jl This is a bit nicer than commenting/uncommenting a line in the code. --- test/runtests.jl | 7 +++++++ test/streaming.jl | 2 -- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 6c2114d85..a4863c791 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -52,6 +52,9 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ arg_type = Int default = additional_workers help = "How many additional workers to launch" + "-v", "--verbose" + action = :store_true + help = "Run the tests with debug logs from Dagger" end end @@ -81,6 +84,10 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ parsed_args["simulate"] && exit(0) additional_workers = parsed_args["procs"] + + if parsed_args["verbose"] + ENV["JULIA_DEBUG"] = "Dagger" + end else to_test = all_test_names @info "Running all tests" diff --git a/test/streaming.jl b/test/streaming.jl index 00a4be3f3..39df7b4d6 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -1,5 +1,3 @@ -@everywhere ENV["JULIA_DEBUG"] = "Dagger" - @everywhere function rand_finite() x = rand() if x < 0.1 From 27392f0637d5f3eb32f27b44eca62d02b97b5ac2 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Thu, 16 May 2024 16:38:51 +0200 Subject: [PATCH 14/56] Ensure that stream_fetch_values!() yields in its loop Otherwise it may spin (see comments for details). Also refactored it into a while-loop instead of using a @goto. --- src/stream-fetchers.jl | 46 ++++++++++++++++++++++++------------------ 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/src/stream-fetchers.jl b/src/stream-fetchers.jl index cc4d942bb..0a738479d 100644 --- a/src/stream-fetchers.jl +++ b/src/stream-fetchers.jl @@ -2,26 +2,32 @@ struct RemoteFetcher end function stream_fetch_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_remote}, buffer::Blocal, id::UInt) where {Store_remote, Blocal} thunk_id = STREAM_THUNK_ID[] @dagdebug thunk_id :stream "fetching values" - @label fetch_values - # FIXME: Pass buffer free space - # TODO: It would be ideal if we could wait on store.lock, but get unlocked during migration - values = MemPool.access_ref(store_ref.handle, id, T, Store_remote, thunk_id) do store, id, T, Store_remote, thunk_id - if !isopen(store) - throw(InvalidStateException("Stream is closed", :closed)) - end - @dagdebug thunk_id :stream "trying to fetch values at $(myid())" - store::Store_remote - in_store = store - STREAM_THUNK_ID[] = thunk_id - values = T[] - while !isempty(store, id) - value = take!(store, id)::T - push!(values, value) - end - return values - end::Vector{T} - if isempty(values) - @goto fetch_values + + values = T[] + while isempty(values) + # FIXME: Pass buffer free space + # TODO: It would be ideal if we could wait on store.lock, but get unlocked during migration + values = MemPool.access_ref(store_ref.handle, id, T, Store_remote, thunk_id) do store, id, T, Store_remote, thunk_id + if !isopen(store) + throw(InvalidStateException("Stream is closed", :closed)) + end + @dagdebug thunk_id :stream "trying to fetch values at $(myid())" + store::Store_remote + in_store = store + STREAM_THUNK_ID[] = thunk_id + values = T[] + while !isempty(store, id) + value = take!(store, id)::T + push!(values, value) + end + return values + end::Vector{T} + + # We explicitly yield in the loop to allow other tasks to run. This + # matters on single-threaded instances because MemPool.access_ref() + # might not yield when accessing data locally, which can cause this loop + # to spin forever. + yield() end @dagdebug thunk_id :stream "fetched $(length(values)) values" From 717ceb7b59262555eb7c864652f2d03923332570 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Sat, 25 May 2024 00:57:43 +0200 Subject: [PATCH 15/56] Add support for limiting the evaluations of a streaming DAG This is useful for testing and benchmarking. --- src/sch/Sch.jl | 8 +++++--- src/stream.jl | 15 +++++++++------ test/streaming.jl | 18 ++++++++++++++++++ 3 files changed, 32 insertions(+), 9 deletions(-) diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 12d259352..9da41e618 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -253,9 +253,11 @@ end Combine `SchedulerOptions` and `ThunkOptions` into a new `ThunkOptions`. """ function Base.merge(sopts::SchedulerOptions, topts::ThunkOptions) - single = topts.single !== nothing ? topts.single : sopts.single - allow_errors = topts.allow_errors !== nothing ? topts.allow_errors : sopts.allow_errors - proclist = topts.proclist !== nothing ? topts.proclist : sopts.proclist + select_option = (sopt, topt) -> isnothing(topt) ? sopt : topt + + single = select_option(sopts.single, topts.single) + allow_errors = select_option(sopts.allow_errors, topts.allow_errors) + proclist = select_option(sopts.proclist, topts.proclist) ThunkOptions(single, proclist, topts.time_util, diff --git a/src/stream.jl b/src/stream.jl index 17f47e95f..691691fb5 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -208,12 +208,10 @@ function initialize_streaming!(self_streams, spec, task) end output_buffer = get(spec.options, :stream_output_buffer, ProcessRingBuffer) stream = Stream{T,output_buffer}(output_buffer_amount) - spec.options = NamedTuple(filter(opt -> opt[1] != :stream_output_buffer && - opt[1] != :stream_output_buffer_amount, - Base.pairs(spec.options))) self_streams[task.uid] = stream - spec.f = StreamingFunction(spec.f, stream) + max_evals = get(spec.options, :stream_max_evals, -1) + spec.f = StreamingFunction(spec.f, stream, max_evals) spec.options = merge(spec.options, (;occupancy=Dict(Any=>0))) # Register Stream globally @@ -256,6 +254,7 @@ const STREAM_THUNK_ID = TaskLocalValue{Int}(()->0) struct StreamingFunction{F, S} f::F stream::S + max_evals::Int end chunktype(sf::StreamingFunction{F}) where F = F function (sf::StreamingFunction)(args...; kwargs...) @@ -319,7 +318,9 @@ end function stream!(sf::StreamingFunction, uid, args::Tuple, kwarg_names::Tuple, kwarg_values::Tuple) f = move(thunk_processor(), sf.f) - while true + counter = 0 + + while sf.max_evals < 0 || counter < sf.max_evals # Get values from Stream args/kwargs stream_args = _stream_take_values!(args, uid) stream_kwarg_values = _stream_take_values!(kwarg_values, uid) @@ -327,6 +328,7 @@ function stream!(sf::StreamingFunction, uid, # Run a single cycle of f stream_result = f(stream_args...; stream_kwargs...) + counter += 1 # Exit streaming on graceful request if stream_result isa FinishStream @@ -412,7 +414,8 @@ function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) # Filter out all streaming options to_filter = (:stream_input_buffer, :stream_input_buffer_amount, - :stream_output_buffer, :stream_output_buffer_amount) + :stream_output_buffer, :stream_output_buffer_amount, + :stream_max_evals) spec.options = NamedTuple(filter(opt -> !(opt[1] in to_filter), Base.pairs(spec.options))) if haskey(spec.options, :propagates) diff --git a/test/streaming.jl b/test/streaming.jl index 39df7b4d6..f87f4aec5 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -5,6 +5,7 @@ end return x end + function catch_interrupt(f) try f() @@ -17,6 +18,7 @@ function catch_interrupt(f) rethrow(err) end end + function test_finishes(f, message::String; ignore_timeout=false) t = @eval Threads.@spawn @testset $message catch_interrupt($f) if timedwait(()->istaskdone(t), 10) == :timed_out @@ -29,6 +31,7 @@ function test_finishes(f, message::String; ignore_timeout=false) end return true end + @testset "Basics" begin @test test_finishes("Single task") do local x @@ -50,6 +53,21 @@ end fetch(x) end + @test test_finishes("Max evaluations") do + counter = 0 + function incrementer() + counter += 1 + end + + x = Dagger.with_options(; stream_max_evals=10) do + Dagger.spawn_streaming() do + Dagger.@spawn incrementer() + end + end + wait(x) + @test counter == 10 + end + @test test_finishes("Two tasks (sequential)") do local x, y @warn "\n\n\nStart streaming\n\n\n" From a0c0805e73765bdb5e7343e54559b1a213096ab6 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Sat, 25 May 2024 01:30:08 +0200 Subject: [PATCH 16/56] Dev the migration-helper branch of MemPool.jl This is currently necessary for the streaming branch, we'll have to change this later but it's good to have CI working for now. --- .buildkite/pipeline.yml | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index d3fc08f12..0f7ef011b 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -5,7 +5,8 @@ sandbox_capable: "true" os: linux arch: x86_64 - command: "julia --project -e 'using Pkg; Pkg.develop(;path=\"lib/TimespanLogging\")'" + command: "julia --project -e 'using Pkg; Pkg.develop(;path=\"lib/TimespanLogging\"); Pkg.add(; url=\"https://github.com/JuliaData/MemPool.jl\", rev=\"jps/migration-helper\")'" + .bench: &bench if: build.message =~ /\[run benchmarks\]/ agents: @@ -14,6 +15,7 @@ os: linux arch: x86_64 num_cpus: 16 + steps: - label: Julia 1.9 timeout_in_minutes: 90 @@ -25,6 +27,7 @@ steps: julia_args: "--threads=1" - JuliaCI/julia-coverage#v1: codecov: true + - label: Julia 1.10 timeout_in_minutes: 90 <<: *test @@ -35,6 +38,7 @@ steps: julia_args: "--threads=1" - JuliaCI/julia-coverage#v1: codecov: true + - label: Julia nightly timeout_in_minutes: 90 <<: *test @@ -77,6 +81,7 @@ steps: - JuliaCI/julia-coverage#v1: codecov: true command: "julia -e 'using Pkg; Pkg.develop(;path=pwd()); Pkg.develop(;path=\"lib/TimespanLogging\"); Pkg.develop(;path=\"lib/DaggerWebDash\"); include(\"lib/DaggerWebDash/test/runtests.jl\")'" + - label: Benchmarks timeout_in_minutes: 120 <<: *bench @@ -93,6 +98,7 @@ steps: BENCHMARK_SCALE: "5:5:50" artifacts: - benchmarks/result* + - label: DTables.jl stability test timeout_in_minutes: 20 plugins: From f4d709cf7d600876a2474a46f8b4800d33eccaa0 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Tue, 25 Jun 2024 16:14:38 +0200 Subject: [PATCH 17/56] Minor style cleanup --- src/stream.jl | 48 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/src/stream.jl b/src/stream.jl index 691691fb5..7b822e630 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -8,6 +8,7 @@ mutable struct StreamStore{T,B} new{T,B}(zeros(Int, 0), Dict{Int,B}(), buffer_amount, true, Threads.Condition()) end + function tid_to_uid(thunk_id) lock(Sch.EAGER_ID_MAP) do id_map for (uid, otid) in id_map @@ -17,6 +18,7 @@ function tid_to_uid(thunk_id) end end end + function Base.put!(store::StreamStore{T,B}, value) where {T,B} thunk_id = STREAM_THUNK_ID[] @lock store.lock begin @@ -35,6 +37,7 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B} notify(store.lock) end end + function Base.take!(store::StreamStore, id::UInt) thunk_id = STREAM_THUNK_ID[] @lock store.lock begin @@ -59,11 +62,18 @@ function Base.take!(store::StreamStore, id::UInt) return value end end + Base.isempty(store::StreamStore, id::UInt) = isempty(store.buffers[id]) isfull(store::StreamStore, id::UInt) = isfull(store.buffers[id]) + "Returns whether the store is actively open. Only check this when deciding if new values can be pushed." Base.isopen(store::StreamStore) = store.open -"Returns whether the store is actively open, or if closing, still has remaining messages for `id`. Only check this when deciding if existing values can be taken." + +""" +Returns whether the store is actively open, or if closing, still has remaining +messages for `id`. Only check this when deciding if existing values can be +taken. +""" function Base.isopen(store::StreamStore, id::UInt) @lock store.lock begin if !isempty(store.buffers[id]) @@ -72,12 +82,14 @@ function Base.isopen(store::StreamStore, id::UInt) return store.open end end + function Base.close(store::StreamStore) if store.open store.open = false @lock store.lock notify(store.lock) end end + function add_waiters!(store::StreamStore{T,B}, waiters::Vector{Int}) where {T,B} @lock store.lock begin for w in waiters @@ -88,6 +100,7 @@ function add_waiters!(store::StreamStore{T,B}, waiters::Vector{Int}) where {T,B} notify(store.lock) end end + function remove_waiters!(store::StreamStore, waiters::Vector{Int}) @lock store.lock begin for w in waiters @@ -115,22 +128,25 @@ mutable struct Stream{T,B} return new{T,B}(nothing, stream.store_ref, nothing, buffer_amount) end end + function initialize_input_stream!(stream::Stream{T,B}) where {T,B} stream.input_buffer = initialize_stream_buffer(B, T, stream.buffer_amount) end -Base.put!(stream::Stream, @nospecialize(value)) = - put!(stream.store, value) +Base.put!(stream::Stream, @nospecialize(value)) = put!(stream.store, value) + function Base.take!(stream::Stream{T,B}, id::UInt) where {T,B} # FIXME: Make remote fetcher configurable stream_fetch_values!(RemoteFetcher, T, stream.store_ref, stream.input_buffer, id) return take!(stream.input_buffer) end + function Base.isopen(stream::Stream, id::UInt)::Bool return MemPool.access_ref(stream.store_ref.handle, id) do store, id return isopen(store::StreamStore, id) end end + function Base.close(stream::Stream) MemPool.access_ref(stream.store_ref.handle) do store close(store::StreamStore) @@ -138,6 +154,7 @@ function Base.close(stream::Stream) end return end + function add_waiters!(stream::Stream, waiters::Vector{Int}) MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters add_waiters!(store::StreamStore, waiters) @@ -145,8 +162,9 @@ function add_waiters!(stream::Stream, waiters::Vector{Int}) end return end -add_waiters!(stream::Stream, waiter::Integer) = - add_waiters!(stream::Stream, Int[waiter]) + +add_waiters!(stream::Stream, waiter::Integer) = add_waiters!(stream, Int[waiter]) + function remove_waiters!(stream::Stream, waiters::Vector{Int}) MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters remove_waiters!(store::StreamStore, waiters) @@ -154,8 +172,8 @@ function remove_waiters!(stream::Stream, waiters::Vector{Int}) end return end -remove_waiters!(stream::Stream, waiter::Integer) = - remove_waiters!(stream::Stream, Int[waiter]) + +remove_waiters!(stream::Stream, waiter::Integer) = remove_waiters!(stream, Int[waiter]) function migrate_stream!(stream::Stream, w::Integer=myid()) # Perform migration of the StreamStore @@ -188,12 +206,14 @@ function enqueue!(queue::StreamingTaskQueue, spec::Pair{DTaskSpec,DTask}) push!(queue.tasks, spec) initialize_streaming!(queue.self_streams, spec...) end + function enqueue!(queue::StreamingTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}) append!(queue.tasks, specs) for (spec, task) in specs initialize_streaming!(queue.self_streams, spec, task) end end + function initialize_streaming!(self_streams, spec, task) if !isa(spec.f, StreamingFunction) # Adapt called function for streaming and generate output Streams @@ -237,10 +257,10 @@ struct FinishStream{T,R} value::Union{Some{T},Nothing} result::R end -finish_stream(value::T; result::R=nothing) where {T,R} = - FinishStream{T,R}(Some{T}(value), result) -finish_stream(; result::R=nothing) where R = - FinishStream{Union{},R}(nothing, result) + +finish_stream(value::T; result::R=nothing) where {T,R} = FinishStream{T,R}(Some{T}(value), result) + +finish_stream(; result::R=nothing) where R = FinishStream{Union{},R}(nothing, result) function cancel_stream!(t::DTask) stream = task_to_stream(t.uid) @@ -256,7 +276,9 @@ struct StreamingFunction{F, S} stream::S max_evals::Int end + chunktype(sf::StreamingFunction{F}) where F = F + function (sf::StreamingFunction)(args...; kwargs...) @nospecialize sf args kwargs result = nothing @@ -314,6 +336,7 @@ function (sf::StreamingFunction)(args...; kwargs...) close(sf.stream) end end + # N.B We specialize to minimize/eliminate allocations function stream!(sf::StreamingFunction, uid, args::Tuple, kwarg_names::Tuple, kwarg_values::Tuple) @@ -343,6 +366,7 @@ function stream!(sf::StreamingFunction, uid, put!(sf.stream, stream_result) end end + function _stream_take_values!(args, uid) return ntuple(length(args)) do idx arg = args[idx] @@ -353,12 +377,14 @@ function _stream_take_values!(args, uid) end end end + @inline @generated function _stream_namedtuple(kwarg_names::Tuple, stream_kwarg_values::Tuple) name_ex = Expr(:tuple, map(name->QuoteNode(name.parameters[1]), kwarg_names.parameters)...) NT = :(NamedTuple{$name_ex,$stream_kwarg_values}) return :($NT(stream_kwarg_values)) end + initialize_stream_buffer(B, T, buffer_amount) = B{T}(buffer_amount) const EAGER_THUNK_STREAMS = LockedObject(Dict{UInt,Any}()) From a7bdfdb5568f181917757478c1248133eadac4c2 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Mon, 19 Aug 2024 00:39:19 +0200 Subject: [PATCH 18/56] Use `DTaskFailedException` and increase the default timeout --- test/streaming.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/streaming.jl b/test/streaming.jl index f87f4aec5..f142b4962 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -10,18 +10,18 @@ function catch_interrupt(f) try f() catch err - if err isa Dagger.ThunkFailedException && err.ex isa InterruptException + if err isa Dagger.DTaskFailedException && err.ex isa InterruptException return elseif err isa Dagger.Sch.SchedulingException return end - rethrow(err) + rethrow() end end function test_finishes(f, message::String; ignore_timeout=false) t = @eval Threads.@spawn @testset $message catch_interrupt($f) - if timedwait(()->istaskdone(t), 10) == :timed_out + if timedwait(()->istaskdone(t), 20) == :timed_out if !ignore_timeout @warn "Testing task timed out: $message" end @@ -76,7 +76,7 @@ end y = Dagger.@spawn x+1 end @test fetch(x) === nothing - @test_throws Dagger.ThunkFailedException fetch(y) + @test_throws Dagger.DTaskFailedException fetch(y) end # TODO: Two tasks (parallel) From 770a241f842986cbdb6047875d9eb3862e0f9138 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Mon, 19 Aug 2024 00:43:33 +0200 Subject: [PATCH 19/56] Initial support for robustly migrating streaming tasks This works by converting the output buffers into a safely-serializeable container and sending that to the new node. --- src/stream-buffers.jl | 12 ++++++++++++ src/stream.jl | 32 +++++++++++++++++++++++--------- 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl index 753f8c11c..e73d9990c 100644 --- a/src/stream-buffers.jl +++ b/src/stream-buffers.jl @@ -81,6 +81,18 @@ function Base.take!(rb::ProcessRingBuffer) return rb.buffer[to_read_idx] end +""" +`take!()` all the elements from a buffer and put them in a `Vector`. +""" +function collect!(rb::ProcessRingBuffer{T}) where T + output = Vector{T}(undef, rb.count) + for i in 1:rb.count + output[i] = take!(rb) + end + + return output +end + #= TODO "A server-local ring buffer backed by shared-memory." mutable struct ServerRingBuffer{T} diff --git a/src/stream.jl b/src/stream.jl index 7b822e630..9bbba972c 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -180,15 +180,29 @@ function migrate_stream!(stream::Stream, w::Integer=myid()) # MemPool will block access to the new ref until the migration completes # FIXME: Do this with MemPool.access_ref, in case stream was already migrated if stream.store_ref.handle.owner != w - new_store_ref = MemPool.migrate!(stream.store_ref.handle, w; pre_migration=store->begin - # Lock store to prevent any further modifications - # N.B. Serialization automatically unlocks the migrated copy - lock((store::StreamStore).lock) - end, post_migration=store->begin - # Unlock the store - # FIXME: Indicate to all waiters that this store is dead - unlock((store::StreamStore).lock) - end) + new_store_ref = MemPool.migrate!(stream.store_ref.handle, w; + pre_migration=store->begin + # Lock store to prevent any further modifications + # N.B. Serialization automatically unlocks the migrated copy + lock((store::StreamStore).lock) + + # Return the serializeable unsent outputs. We can't send the + # buffers themselves because they may be mmap'ed or something. + Dict(id => collect!(buffer) for (id, buffer) in store.buffers) + end, + dest_post_migration=(store, unsent_outputs)->begin + # Initialize the StreamStore on the destination with the unsent outputs. + for (id, outputs) in unsent_outputs + for item in outputs + put!(store.buffers[id], item) + end + end + end, + post_migration=store->begin + # Unlock the store + # FIXME: Indicate to all waiters that this store is dead + unlock((store::StreamStore).lock) + end) if w == myid() stream.store = MemPool.access_ref(identity, new_store_ref; local_only=true) end From 0b968d62b0eb59cded8af41f1831aab7ed1e5834 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Tue, 20 Aug 2024 22:35:25 +0200 Subject: [PATCH 20/56] Inherit the top-level testsets in the streaming tests This makes them be displayed as if they were running under the original task. --- test/streaming.jl | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/test/streaming.jl b/test/streaming.jl index f142b4962..4c3d1f7a6 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -6,21 +6,31 @@ return x end -function catch_interrupt(f) - try - f() - catch err - if err isa Dagger.DTaskFailedException && err.ex isa InterruptException - return - elseif err isa Dagger.Sch.SchedulingException - return +function test_in_task(f, message, parent_testsets) + task_local_storage(:__BASETESTNEXT__, parent_testsets) + + @testset "$message" begin + try + f() + catch err + if err isa Dagger.DTaskFailedException && err.ex isa InterruptException + return + elseif err isa Dagger.Sch.SchedulingException + return + end + rethrow() end - rethrow() end end function test_finishes(f, message::String; ignore_timeout=false) - t = @eval Threads.@spawn @testset $message catch_interrupt($f) + # We sneakily pass a magic variable from the current TLS into the new + # task. It's used by the Test stdlib to hold a list of the current + # testsets, so we need it to be able to record the tests from the new + # task in the original testset that we're currently running under. + parent_testsets = get(task_local_storage(), :__BASETESTNEXT__, []) + t = Threads.@spawn test_in_task(f, message, parent_testsets) + if timedwait(()->istaskdone(t), 20) == :timed_out if !ignore_timeout @warn "Testing task timed out: $message" From 0268b7e3610792de9ec8a31023ff771749196908 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Tue, 20 Aug 2024 23:01:06 +0200 Subject: [PATCH 21/56] Replace `rand_finite()` with a deterministic `Producer` functor This makes the tests a little easier to control. --- test/streaming.jl | 60 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/test/streaming.jl b/test/streaming.jl index 4c3d1f7a6..623600667 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -1,9 +1,41 @@ -@everywhere function rand_finite() - x = rand() - if x < 0.1 - return Dagger.finish_stream(x) +@everywhere begin + """ + A functor to produce a certain number of outputs. + + Note: always use this like `Dagger.spawn(Producer())` rather than + `Dagger.@spawn Producer()`. The macro form will just create fresh objects + every time and stream forever. + """ + mutable struct Producer + N::Union{Int, Float64} + count::Int + mailbox::Union{RemoteChannel, Nothing} + + Producer(N=5, mailbox=nothing) = new(N, 0, mailbox) + end + + function (self::Producer)() + self.count += 1 + + # Sleeping will make the loop yield (handy for single-threaded + # processes), and stops Dagger from being too spammy in debug mode. + if self.N == Inf + sleep(0.1) + end + + # Check if there are any instructions for us + if !isnothing(self.mailbox) && isready(self.mailbox) + msg = take!(self.mailbox) + if msg === :exit + put!(self.mailbox, self.count) + return Dagger.finish_stream(self.count) + else + error("Unrecognized Producer message: $msg") + end + end + + self.count >= self.N ? Dagger.finish_stream(self.count) : self.count end - return x end function test_in_task(f, message, parent_testsets) @@ -43,10 +75,12 @@ function test_finishes(f, message::String; ignore_timeout=false) end @testset "Basics" begin + master_scope = Dagger.scope(worker=myid()) + @test test_finishes("Single task") do local x Dagger.spawn_streaming() do - x = Dagger.@spawn rand_finite() + x = Dagger.spawn(Producer()) end @test fetch(x) === nothing end @@ -64,25 +98,23 @@ end end @test test_finishes("Max evaluations") do - counter = 0 - function incrementer() - counter += 1 - end - + producer = Producer(20) x = Dagger.with_options(; stream_max_evals=10) do Dagger.spawn_streaming() do - Dagger.@spawn incrementer() + # Spawn on the same node so we can access the local `producer` variable + Dagger.spawn(producer, Dagger.Options(; scope=master_scope)) end end + wait(x) - @test counter == 10 + @test producer.count == 10 end @test test_finishes("Two tasks (sequential)") do local x, y @warn "\n\n\nStart streaming\n\n\n" Dagger.spawn_streaming() do - x = Dagger.@spawn rand_finite() + x = Dagger.spawn(Producer()) y = Dagger.@spawn x+1 end @test fetch(x) === nothing From 0dbdab3d61c98fa296c5bba88cdd5f4981a51fa3 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Fri, 6 Sep 2024 16:15:06 +0200 Subject: [PATCH 22/56] fixup! Initial support for robustly migrating streaming tasks --- src/stream.jl | 80 +++++++++++++++++++++++++++++++++++------------ test/streaming.jl | 40 ++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 20 deletions(-) diff --git a/src/stream.jl b/src/stream.jl index 9bbba972c..e32ce91d6 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -3,10 +3,11 @@ mutable struct StreamStore{T,B} buffers::Dict{Int,B} buffer_amount::Int open::Bool + migrating::Bool lock::Threads.Condition StreamStore{T,B}(buffer_amount::Integer) where {T,B} = new{T,B}(zeros(Int, 0), Dict{Int,B}(), buffer_amount, - true, Threads.Condition()) + true, false, Threads.Condition()) end function tid_to_uid(thunk_id) @@ -175,11 +176,24 @@ end remove_waiters!(stream::Stream, waiter::Integer) = remove_waiters!(stream, Int[waiter]) +function migrate_streamingfunction!(sf::StreamingFunction, w::Integer=myid()) + current_worker = sf.stream.store_ref.handle.owner + if myid() != current_worker + return remotecall_fetch(migrate_streamingfunction!, current_worker, sf, w) + end + + sf.stream.store.migrating = true + @lock sf.status_event wait(sf.status_event) # Wait for the streaming function to finish +end + function migrate_stream!(stream::Stream, w::Integer=myid()) # Perform migration of the StreamStore # MemPool will block access to the new ref until the migration completes # FIXME: Do this with MemPool.access_ref, in case stream was already migrated if stream.store_ref.handle.owner != w + thunk_id = STREAM_THUNK_ID[] + @dagdebug thunk_id :stream "Beginning migration..." + new_store_ref = MemPool.migrate!(stream.store_ref.handle, w; pre_migration=store->begin # Lock store to prevent any further modifications @@ -197,6 +211,9 @@ function migrate_stream!(stream::Stream, w::Integer=myid()) put!(store.buffers[id], item) end end + + # Ensure that the 'migrating' flag is not set + store.migrating = false end, post_migration=store->begin # Unlock the store @@ -206,6 +223,8 @@ function migrate_stream!(stream::Stream, w::Integer=myid()) if w == myid() stream.store = MemPool.access_ref(identity, new_store_ref; local_only=true) end + + @dagdebug thunk_id :stream "Migration complete" end end @@ -289,11 +308,25 @@ struct StreamingFunction{F, S} f::F stream::S max_evals::Int + status_event::Threads.Event + migration_complete::Threads.Event end chunktype(sf::StreamingFunction{F}) where F = F function (sf::StreamingFunction)(args...; kwargs...) + ret = :migrating + while ret === :migrating + worker_id = sf.stream.store_ref.handle.owner + ret = if worker_id == myid() + _run_streamingfunction(args...; kwargs...) + else + remotecall_fetch(_run_streamingfunction, worker_id, args...; kwargs...) + end + end +end + +function _run_streamingfunction(args...; kwargs...) @nospecialize sf args kwargs result = nothing thunk_id = Sch.sch_handle().thunk_id.id @@ -309,10 +342,9 @@ function (sf::StreamingFunction)(args...; kwargs...) end end - # Migrate our output stream to this worker + # Migrate our output stream store to this worker if sf.stream isa Stream migrate_stream!(sf.stream) - @dagdebug thunk_id :stream "Migration complete" end try @@ -327,27 +359,31 @@ function (sf::StreamingFunction)(args...; kwargs...) end return stream!(sf, uid, (args...,), kwarg_names, kwarg_values) finally - # Remove ourself as a waiter for upstream Streams - streams = Set{Stream}() - for (idx, arg) in enumerate(args) - if arg isa Stream - push!(streams, arg) + if !sf.stream.store.migrated + # Remove ourself as a waiter for upstream Streams + streams = Set{Stream}() + for (idx, arg) in enumerate(args) + if arg isa Stream + push!(streams, arg) + end end - end - for (idx, (pos, arg)) in enumerate(kwargs) - if arg isa Stream - push!(streams, arg) + for (idx, (pos, arg)) in enumerate(kwargs) + if arg isa Stream + push!(streams, arg) + end end - end - for stream in streams - @dagdebug thunk_id :stream "dropping waiter" - remove_waiters!(stream, uid) - @dagdebug thunk_id :stream "dropped waiter" + for stream in streams + @dagdebug thunk_id :stream "dropping waiter" + remove_waiters!(stream, uid) + @dagdebug thunk_id :stream "dropped waiter" + end + + # Ensure downstream tasks also terminate + @dagdebug thunk_id :stream "closed stream" + close(sf.stream) end - # Ensure downstream tasks also terminate - @dagdebug thunk_id :stream "closed stream" - close(sf.stream) + notify(sf.status_event) end end @@ -358,6 +394,10 @@ function stream!(sf::StreamingFunction, uid, counter = 0 while sf.max_evals < 0 || counter < sf.max_evals + if sf.stream.store.migrating + return :migrating + end + # Get values from Stream args/kwargs stream_args = _stream_take_values!(args, uid) stream_kwarg_values = _stream_take_values!(kwarg_values, uid) diff --git a/test/streaming.jl b/test/streaming.jl index 623600667..87e0550ad 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -1,3 +1,5 @@ +import MemPool: access_ref + @everywhere begin """ A functor to produce a certain number of outputs. @@ -77,6 +79,44 @@ end @testset "Basics" begin master_scope = Dagger.scope(worker=myid()) + @test test_finishes("Migration") do + if nprocs() == 1 + @warn "Skipping migration test because it requires at least 1 extra worker" + return + end + + # Start streaming locally + mailbox = RemoteChannel() + producer = Producer(Inf, mailbox) + x = Dagger.spawn_streaming() do + Dagger.spawn(producer, Dagger.Options(; scope=master_scope)) + end + + # Wait for the stream to get started + while producer.count < 2 + sleep(0.1) + end + + # Migrate to another worker + access_ref(x.thunk_ref) do thunk + access_ref(thunk.f.handle) do streaming_function + Dagger.migrate_stream!(streaming_function.stream, workers()[1]) + end + end + + # Wait a bit for the stream to get started again on the other node + sleep(0.5) + + # Stop it + put!(mailbox, :exit) + fetch(x) + + final_count = take!(mailbox) + @info "Counts:" producer.count final_count + end + + return + @test test_finishes("Single task") do local x Dagger.spawn_streaming() do From 1cf99b8a8e81d1dfc0faeee904c610804dfeeb68 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 13 Sep 2024 12:22:38 -0400 Subject: [PATCH 23/56] fixup! fixup! Initial support for robustly migrating streaming tasks --- src/stream.jl | 52 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/src/stream.jl b/src/stream.jl index e32ce91d6..ce07eeaf3 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -176,6 +176,18 @@ end remove_waiters!(stream::Stream, waiter::Integer) = remove_waiters!(stream, Int[waiter]) +struct StreamingFunction{F, S} + f::F + stream::S + max_evals::Int + + status_event::Threads.Event + migration_complete::Threads.Event + + StreamingFunction(f::F, stream::S, max_evals) where {F, S} = + new{F, S}(f, stream, max_evals, Threads.Event(), Threads.Event()) +end + function migrate_streamingfunction!(sf::StreamingFunction, w::Integer=myid()) current_worker = sf.stream.store_ref.handle.owner if myid() != current_worker @@ -304,30 +316,32 @@ end const STREAM_THUNK_ID = TaskLocalValue{Int}(()->0) -struct StreamingFunction{F, S} - f::F - stream::S - max_evals::Int - status_event::Threads.Event - migration_complete::Threads.Event -end - chunktype(sf::StreamingFunction{F}) where F = F +struct StreamMigrating end + function (sf::StreamingFunction)(args...; kwargs...) - ret = :migrating - while ret === :migrating - worker_id = sf.stream.store_ref.handle.owner - ret = if worker_id == myid() - _run_streamingfunction(args...; kwargs...) - else - remotecall_fetch(_run_streamingfunction, worker_id, args...; kwargs...) - end + thunk_id = Sch.sch_handle().thunk_id.id + @label start + @dagdebug nothing :stream "Starting StreamingFunction" + worker_id = sf.stream.store_ref.handle.owner + result = if worker_id == myid() + _run_streamingfunction(nothing, sf, args...; kwargs...) + else + tls = get_tls() + remotecall_fetch(_run_streamingfunction, worker_id, tls, sf, args...; kwargs...) end + if result === StreamMigrating() + @goto start + end + return result end -function _run_streamingfunction(args...; kwargs...) +function _run_streamingfunction(tls, sf, args...; kwargs...) @nospecialize sf args kwargs + if tls !== nothing + set_tls!(tls) + end result = nothing thunk_id = Sch.sch_handle().thunk_id.id STREAM_THUNK_ID[] = thunk_id @@ -359,7 +373,7 @@ function _run_streamingfunction(args...; kwargs...) end return stream!(sf, uid, (args...,), kwarg_names, kwarg_values) finally - if !sf.stream.store.migrated + if !sf.stream.store.migrating # Remove ourself as a waiter for upstream Streams streams = Set{Stream}() for (idx, arg) in enumerate(args) @@ -395,7 +409,7 @@ function stream!(sf::StreamingFunction, uid, while sf.max_evals < 0 || counter < sf.max_evals if sf.stream.store.migrating - return :migrating + return StreamMigrating() end # Get values from Stream args/kwargs From 71ee8549aa429ae04b55cb9d4d861445d319536c Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 22 May 2024 12:48:53 -0500 Subject: [PATCH 24/56] task-tls: Refactor into DTaskTLS struct --- src/Dagger.jl | 2 ++ src/array/indexing.jl | 2 -- src/sch/Sch.jl | 2 +- src/sch/dynamic.jl | 2 +- src/task-tls.jl | 43 ++++++++++++++++++++++--------------------- 5 files changed, 26 insertions(+), 25 deletions(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index 8f7d36cc7..cd081e65f 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -25,6 +25,8 @@ else end import TaskLocalValues: TaskLocalValue +import TaskLocalValues: TaskLocalValue + if !isdefined(Base, :get_extension) import Requires: @require end diff --git a/src/array/indexing.jl b/src/array/indexing.jl index 82f44fbff..69725eb7a 100644 --- a/src/array/indexing.jl +++ b/src/array/indexing.jl @@ -1,5 +1,3 @@ -import TaskLocalValues: TaskLocalValue - ### getindex struct GetIndex{T,N} <: ArrayOp{T,N} diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 9da41e618..2b41b1cd6 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -1198,7 +1198,7 @@ function proc_states(f::Base.Callable, uid::UInt64) end end proc_states(f::Base.Callable) = - proc_states(f, task_local_storage(:_dagger_sch_uid)::UInt64) + proc_states(f, Dagger.get_tls().sch_uid) task_tid_for_processor(::Processor) = nothing task_tid_for_processor(proc::Dagger.ThreadProc) = proc.tid diff --git a/src/sch/dynamic.jl b/src/sch/dynamic.jl index e02085ee6..5b917fdb5 100644 --- a/src/sch/dynamic.jl +++ b/src/sch/dynamic.jl @@ -17,7 +17,7 @@ struct SchedulerHandle end "Gets the scheduler handle for the currently-executing thunk." -sch_handle() = task_local_storage(:_dagger_sch_handle)::SchedulerHandle +sch_handle() = Dagger.get_tls().sch_handle::SchedulerHandle "Thrown when the scheduler halts before finishing processing the DAG." struct SchedulerHaltedException <: Exception end diff --git a/src/task-tls.jl b/src/task-tls.jl index ea188e004..46fdbc9ca 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -1,41 +1,42 @@ # In-Thunk Helpers -""" - task_processor() +struct DTaskTLS + processor::Processor + sch_uid::UInt + sch_handle::Any # FIXME: SchedulerHandle + task_spec::Vector{Any} # FIXME: TaskSpec +end -Get the current processor executing the current Dagger task. -""" -task_processor() = task_local_storage(:_dagger_processor)::Processor -@deprecate thunk_processor() task_processor() +const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing) """ - in_task() + in_task() -> Bool Returns `true` if currently executing in a [`DTask`](@ref), else `false`. """ -in_task() = haskey(task_local_storage(), :_dagger_sch_uid) +in_task() = DTASK_TLS[] !== nothing @deprecate in_thunk() in_task() """ - get_tls() + task_processor() -> Processor + +Get the current processor executing the current [`DTask`](@ref). +""" +task_processor() = get_tls().processor +@deprecate thunk_processor() task_processor() + +""" + get_tls() -> DTaskTLS -Gets all Dagger TLS variable as a `NamedTuple`. +Gets all Dagger TLS variable as a `DTaskTLS`. """ -get_tls() = ( - sch_uid=task_local_storage(:_dagger_sch_uid), - sch_handle=task_local_storage(:_dagger_sch_handle), - processor=task_processor(), - task_spec=task_local_storage(:_dagger_task_spec), -) +get_tls() = DTASK_TLS[]::DTaskTLS """ set_tls!(tls) -Sets all Dagger TLS variables from the `NamedTuple` `tls`. +Sets all Dagger TLS variables from `tls`, which may be a `DTaskTLS` or a `NamedTuple`. """ function set_tls!(tls) - task_local_storage(:_dagger_sch_uid, tls.sch_uid) - task_local_storage(:_dagger_sch_handle, tls.sch_handle) - task_local_storage(:_dagger_processor, tls.processor) - task_local_storage(:_dagger_task_spec, tls.task_spec) + DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec) end From 79ee021e16aec745325f3949550dba6d3445596c Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 13 Sep 2024 12:18:56 -0400 Subject: [PATCH 25/56] fixup! task-tls: Refactor into DTaskTLS struct --- src/task-tls.jl | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/task-tls.jl b/src/task-tls.jl index 46fdbc9ca..90fdfedb3 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -9,22 +9,6 @@ end const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing) -""" - in_task() -> Bool - -Returns `true` if currently executing in a [`DTask`](@ref), else `false`. -""" -in_task() = DTASK_TLS[] !== nothing -@deprecate in_thunk() in_task() - -""" - task_processor() -> Processor - -Get the current processor executing the current [`DTask`](@ref). -""" -task_processor() = get_tls().processor -@deprecate thunk_processor() task_processor() - """ get_tls() -> DTaskTLS @@ -40,3 +24,19 @@ Sets all Dagger TLS variables from `tls`, which may be a `DTaskTLS` or a `NamedT function set_tls!(tls) DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec) end + +""" + in_task() -> Bool + +Returns `true` if currently executing in a [`DTask`](@ref), else `false`. +""" +in_task() = DTASK_TLS[] !== nothing +@deprecate in_thunk() in_task() + +""" + task_processor() -> Processor + +Get the current processor executing the current [`DTask`](@ref). +""" +task_processor() = get_tls().processor +@deprecate thunk_processor() task_processor() From 09e5826584fcebc9f83191df3bfb35a6e4faec54 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 13 Sep 2024 12:21:13 -0400 Subject: [PATCH 26/56] cancellation: Add cancel token support --- src/Dagger.jl | 2 +- src/cancellation.jl | 17 ++++++++++++++++- src/sch/Sch.jl | 11 +++++++++++ src/task-tls.jl | 21 ++++++++++++++++++++- src/threadproc.jl | 3 ++- 5 files changed, 50 insertions(+), 4 deletions(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index cd081e65f..5fbd7e3fb 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -49,6 +49,7 @@ include("processor.jl") include("threadproc.jl") include("context.jl") include("utils/processors.jl") +include("cancellation.jl") include("task-tls.jl") include("scopes.jl") include("utils/scopes.jl") @@ -58,7 +59,6 @@ include("thunk.jl") include("submission.jl") include("chunks.jl") include("memory-spaces.jl") -include("cancellation.jl") # Task scheduling include("compute.jl") diff --git a/src/cancellation.jl b/src/cancellation.jl index c982fd20c..5387f101a 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -1,3 +1,17 @@ +# DTask-level cancellation + +struct CancelToken + cancelled::Base.RefValue{Bool} +end +CancelToken() = CancelToken(Ref(false)) +function cancel!(token::CancelToken) + token.cancelled[] = true +end + +const DTASK_CANCEL_TOKEN = TaskLocalValue{Union{CancelToken,Nothing}}(()->nothing) + +# Global-level cancellation + """ cancel!(task::DTask; force::Bool=false, halt_sch::Bool=false) @@ -80,11 +94,11 @@ function _cancel!(state, tid, force, halt_sch) Tf === typeof(Sch.eager_thunk) && continue istaskdone(task) && continue any_cancelled = true - @dagdebug tid :cancel "Cancelling running task ($Tf)" if force @dagdebug tid :cancel "Interrupting running task ($Tf)" Threads.@spawn Base.throwto(task, InterruptException()) else + @dagdebug tid :cancel "Cancelling running task ($Tf)" # Tell the processor to just drop this task task_occupancy = task_spec[4] time_util = task_spec[2] @@ -93,6 +107,7 @@ function _cancel!(state, tid, force, halt_sch) push!(istate.cancelled, tid) to_proc = istate.proc put!(istate.return_queue, (myid(), to_proc, tid, (InterruptException(), nothing))) + cancel!(istate.cancel_tokens[tid]) end end end diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 2b41b1cd6..48eba1b31 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -1179,6 +1179,7 @@ struct ProcessorInternalState proc_occupancy::Base.RefValue{UInt32} time_pressure::Base.RefValue{UInt64} cancelled::Set{Int} + cancel_tokens::Dict{Int,Dagger.CancelToken} done::Base.RefValue{Bool} end struct ProcessorState @@ -1328,7 +1329,14 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Execute the task and return its result t = @task begin + # Set up cancellation + cancel_token = Dagger.CancelToken() + Dagger.DTASK_CANCEL_TOKEN[] = cancel_token + lock(istate.queue) do _ + istate.cancel_tokens[thunk_id] = cancel_token + end was_cancelled = false + result = try do_task(to_proc, task) catch err @@ -1345,6 +1353,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re # Task was cancelled, so occupancy and pressure are # already reduced pop!(istate.cancelled, thunk_id) + delete!(istate.cancel_tokens, thunk_id) was_cancelled = true end end @@ -1411,6 +1420,7 @@ function do_tasks(to_proc, return_queue, tasks) Dict{Int,Vector{Any}}(), Ref(UInt32(0)), Ref(UInt64(0)), Set{Int}(), + Dict{Int,Dagger.CancelToken}(), Ref(false)) runner = start_processor_runner!(istate, uid, return_queue) @static if VERSION < v"1.9" @@ -1652,6 +1662,7 @@ function do_task(to_proc, task_desc) sch_handle, processor=to_proc, task_spec=task_desc, + cancel_token=Dagger.DTASK_CANCEL_TOKEN[], )) res = Dagger.with_options(propagated) do diff --git a/src/task-tls.jl b/src/task-tls.jl index 90fdfedb3..8a8b6c66d 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -5,6 +5,7 @@ struct DTaskTLS sch_uid::UInt sch_handle::Any # FIXME: SchedulerHandle task_spec::Vector{Any} # FIXME: TaskSpec + cancel_token::CancelToken end const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing) @@ -22,7 +23,7 @@ get_tls() = DTASK_TLS[]::DTaskTLS Sets all Dagger TLS variables from `tls`, which may be a `DTaskTLS` or a `NamedTuple`. """ function set_tls!(tls) - DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec) + DTASK_TLS[] = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token) end """ @@ -40,3 +41,21 @@ Get the current processor executing the current [`DTask`](@ref). """ task_processor() = get_tls().processor @deprecate thunk_processor() task_processor() + +""" + task_cancelled() -> Bool + +Returns `true` if the current [`DTask`](@ref) has been cancelled, else `false`. +""" +task_cancelled() = get_tls().cancel_token.cancelled[] + +""" + task_may_cancel!() + +Throws an `InterruptException` if the current [`DTask`](@ref) has been cancelled. +""" +function task_may_cancel!() + if task_cancelled() + throw(InterruptException()) + end +end diff --git a/src/threadproc.jl b/src/threadproc.jl index 09099889a..b75c90ca3 100644 --- a/src/threadproc.jl +++ b/src/threadproc.jl @@ -27,8 +27,9 @@ function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @n return result[] catch err if err isa InterruptException + # Direct interrupt hit us, propagate cancellation signal + # FIXME: We should tell the scheduler that the user hit Ctrl-C if !istaskdone(task) - # Propagate cancellation signal Threads.@spawn Base.throwto(task, InterruptException()) end end From 3911a732506c29f79e6055b6555f96e99c3d4f5f Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 13 Sep 2024 12:35:32 -0400 Subject: [PATCH 27/56] streaming: Handle cancellation --- src/stream-buffers.jl | 2 ++ src/stream-fetchers.jl | 1 + src/stream.jl | 13 +++++++++++-- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl index e73d9990c..85c73b198 100644 --- a/src/stream-buffers.jl +++ b/src/stream-buffers.jl @@ -64,6 +64,7 @@ function Base.put!(rb::ProcessRingBuffer{T}, x) where T len = length(rb.buffer) while (@atomic rb.count) == len yield() + task_may_cancel!() end to_write_idx = mod1(rb.write_idx, len) rb.buffer[to_write_idx] = convert(T, x) @@ -73,6 +74,7 @@ end function Base.take!(rb::ProcessRingBuffer) while (@atomic rb.count) == 0 yield() + task_may_cancel!() end to_read_idx = rb.read_idx rb.read_idx += 1 diff --git a/src/stream-fetchers.jl b/src/stream-fetchers.jl index 0a738479d..39c6cdd6e 100644 --- a/src/stream-fetchers.jl +++ b/src/stream-fetchers.jl @@ -28,6 +28,7 @@ function stream_fetch_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_r # might not yield when accessing data locally, which can cause this loop # to spin forever. yield() + task_may_cancel!() end @dagdebug thunk_id :stream "fetched $(length(values)) values" diff --git a/src/stream.jl b/src/stream.jl index ce07eeaf3..7ad81546b 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -30,6 +30,10 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B} @dagdebug thunk_id :stream "adding $value" for buffer in values(store.buffers) while isfull(buffer) + if !isopen(store) + @dagdebug thunk_id :stream "closed!" + throw(InvalidStateException("Stream is closed", :closed)) + end @dagdebug thunk_id :stream "buffer full, waiting" wait(store.lock) end @@ -85,9 +89,10 @@ function Base.isopen(store::StreamStore, id::UInt) end function Base.close(store::StreamStore) - if store.open + @lock store.lock begin + store.open && return store.open = false - @lock store.lock notify(store.lock) + notify(store.lock) end end @@ -408,6 +413,10 @@ function stream!(sf::StreamingFunction, uid, counter = 0 while sf.max_evals < 0 || counter < sf.max_evals + # Exit streaming on cancellation + task_may_cancel!() + + # Exit streaming on migration if sf.stream.store.migrating return StreamMigrating() end From f71f6045cc63680721a067d497ebdb0f6ba7e2ef Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 13 Sep 2024 13:14:59 -0400 Subject: [PATCH 28/56] fixup! cancellation: Add cancel token support --- src/Dagger.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index 5fbd7e3fb..444af8121 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -49,11 +49,11 @@ include("processor.jl") include("threadproc.jl") include("context.jl") include("utils/processors.jl") +include("dtask.jl") include("cancellation.jl") include("task-tls.jl") include("scopes.jl") include("utils/scopes.jl") -include("dtask.jl") include("queue.jl") include("thunk.jl") include("submission.jl") From b930a423d7e4c3d11a8001563531e9cec7db2b8d Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Fri, 13 Sep 2024 13:15:45 -0400 Subject: [PATCH 29/56] fixup! fixup! fixup! Initial support for robustly migrating streaming tasks --- src/stream.jl | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/stream.jl b/src/stream.jl index 7ad81546b..13b903f0a 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -238,6 +238,7 @@ function migrate_stream!(stream::Stream, w::Integer=myid()) unlock((store::StreamStore).lock) end) if w == myid() + stream.store_ref.handle = new_store_ref # FIXME: It's not valid to mutate the Chunk handle, but we want to update this to enable fast location queries stream.store = MemPool.access_ref(identity, new_store_ref; local_only=true) end @@ -327,13 +328,21 @@ struct StreamMigrating end function (sf::StreamingFunction)(args...; kwargs...) thunk_id = Sch.sch_handle().thunk_id.id + STREAM_THUNK_ID[] = thunk_id + + # Migrate our output stream store to this worker + if sf.stream isa Stream + migrate_stream!(sf.stream) + end + @label start - @dagdebug nothing :stream "Starting StreamingFunction" + @dagdebug thunk_id :stream "Starting StreamingFunction" worker_id = sf.stream.store_ref.handle.owner result = if worker_id == myid() _run_streamingfunction(nothing, sf, args...; kwargs...) else tls = get_tls() + # FIXME: Wire up listener to ferry cancel_token notifications to remote worker remotecall_fetch(_run_streamingfunction, worker_id, tls, sf, args...; kwargs...) end if result === StreamMigrating() @@ -344,12 +353,14 @@ end function _run_streamingfunction(tls, sf, args...; kwargs...) @nospecialize sf args kwargs + if tls !== nothing set_tls!(tls) end - result = nothing + thunk_id = Sch.sch_handle().thunk_id.id STREAM_THUNK_ID[] = thunk_id + # FIXME: Remove when scheduler is distributed uid = remotecall_fetch(1, thunk_id) do thunk_id lock(Sch.EAGER_ID_MAP) do id_map @@ -361,11 +372,6 @@ function _run_streamingfunction(tls, sf, args...; kwargs...) end end - # Migrate our output stream store to this worker - if sf.stream isa Stream - migrate_stream!(sf.stream) - end - try # TODO: This kwarg song-and-dance is required to ensure that we don't # allocate boxes within `stream!`, when possible From 16d73c9b6bee6776fda5243a3963ce4a9da60f8d Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sat, 14 Sep 2024 11:54:34 -0400 Subject: [PATCH 30/56] Sch: Add unwrap_nested_exception for DTaskFailedException --- src/sch/util.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/sch/util.jl b/src/sch/util.jl index cd006838b..eb5a285b4 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -29,6 +29,8 @@ unwrap_nested_exception(err::CapturedException) = unwrap_nested_exception(err.ex) unwrap_nested_exception(err::RemoteException) = unwrap_nested_exception(err.captured) +unwrap_nested_exception(err::DTaskFailedException) = + unwrap_nested_exception(err.ex) unwrap_nested_exception(err) = err "Gets a `NamedTuple` of options propagated by `thunk`." From 2b2da8ee1e3bd75d449aba3e8a84dfde63e8054b Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sat, 14 Sep 2024 11:54:55 -0400 Subject: [PATCH 31/56] ProcessRingBuffer: Add length method --- src/stream-buffers.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl index 85c73b198..e8c1c9488 100644 --- a/src/stream-buffers.jl +++ b/src/stream-buffers.jl @@ -60,6 +60,7 @@ mutable struct ProcessRingBuffer{T} end Base.isempty(rb::ProcessRingBuffer) = (@atomic rb.count) == 0 isfull(rb::ProcessRingBuffer) = (@atomic rb.count) == length(rb.buffer) +Base.length(rb::ProcessRingBuffer) = @atomic rb.count function Base.put!(rb::ProcessRingBuffer{T}, x) where T len = length(rb.buffer) while (@atomic rb.count) == len From 5be724f2d731d8a46271f2d2729451b4dfdc8bd6 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sat, 14 Sep 2024 11:57:11 -0400 Subject: [PATCH 32/56] fixup! fixup! cancellation: Add cancel token support --- src/stream.jl | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/src/stream.jl b/src/stream.jl index 13b903f0a..986a70301 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -36,6 +36,7 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B} end @dagdebug thunk_id :stream "buffer full, waiting" wait(store.lock) + task_may_cancel!() end put!(buffer, value) end @@ -50,6 +51,7 @@ function Base.take!(store::StreamStore, id::UInt) while isempty(buffer) && isopen(store, id) @dagdebug thunk_id :stream "no elements, not taking" wait(store.lock) + task_may_cancel!() end @dagdebug thunk_id :stream "wait finished" if !isopen(store, id) @@ -313,13 +315,6 @@ finish_stream(value::T; result::R=nothing) where {T,R} = FinishStream{T,R}(Some{ finish_stream(; result::R=nothing) where R = FinishStream{Union{},R}(nothing, result) -function cancel_stream!(t::DTask) - stream = task_to_stream(t.uid) - if stream !== nothing - close(stream) - end -end - const STREAM_THUNK_ID = TaskLocalValue{Int}(()->0) chunktype(sf::StreamingFunction{F}) where F = F From d5456373f551e9994604cf79a55b4b672ec0f377 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sat, 14 Sep 2024 12:03:34 -0400 Subject: [PATCH 33/56] streaming: Buffers and tasks per input/output Instead of taking/putting values sequentially (which may block), runs "pull" and "push" tasks for each input and output, respectively. Uses buffers to communicate values between pullers/pushers and the streaming task, instead of only using one buffer per task-to-task connection. --- src/Dagger.jl | 2 +- ...{stream-fetchers.jl => stream-transfer.jl} | 11 +- src/stream.jl | 235 ++++++++++++------ 3 files changed, 172 insertions(+), 76 deletions(-) rename src/{stream-fetchers.jl => stream-transfer.jl} (69%) diff --git a/src/Dagger.jl b/src/Dagger.jl index 444af8121..505f41421 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -72,7 +72,7 @@ include("datadeps.jl") # Streaming include("stream-buffers.jl") -include("stream-fetchers.jl") +include("stream-transfer.jl") include("stream.jl") # Array computations diff --git a/src/stream-fetchers.jl b/src/stream-transfer.jl similarity index 69% rename from src/stream-fetchers.jl rename to src/stream-transfer.jl index 39c6cdd6e..defa24463 100644 --- a/src/stream-fetchers.jl +++ b/src/stream-transfer.jl @@ -1,13 +1,14 @@ struct RemoteFetcher end -function stream_fetch_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_remote}, buffer::Blocal, id::UInt) where {Store_remote, Blocal} +function stream_pull_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_remote}, buffer::Blocal, id::UInt) where {Store_remote, Blocal} thunk_id = STREAM_THUNK_ID[] @dagdebug thunk_id :stream "fetching values" values = T[] + free_space = length(buffer.buffer) - length(buffer) while isempty(values) # FIXME: Pass buffer free space # TODO: It would be ideal if we could wait on store.lock, but get unlocked during migration - values = MemPool.access_ref(store_ref.handle, id, T, Store_remote, thunk_id) do store, id, T, Store_remote, thunk_id + values = MemPool.access_ref(store_ref.handle, id, T, Store_remote, thunk_id, free_space) do store, id, T, Store_remote, thunk_id, free_space if !isopen(store) throw(InvalidStateException("Stream is closed", :closed)) end @@ -16,8 +17,9 @@ function stream_fetch_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_r in_store = store STREAM_THUNK_ID[] = thunk_id values = T[] - while !isempty(store, id) + while !isempty(store, id) && length(values) < free_space value = take!(store, id)::T + @dagdebug thunk_id :stream "fetched $value" push!(values, value) end return values @@ -36,3 +38,6 @@ function stream_fetch_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_r put!(buffer, value) end end +function stream_push_values!(::Type{RemoteFetcher}, T, store_ref::Store_remote, buffer::Blocal, id::UInt) where {Store_remote, Blocal} + sleep(0.1) +end diff --git a/src/stream.jl b/src/stream.jl index 986a70301..b0e5c10ba 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -1,12 +1,20 @@ mutable struct StreamStore{T,B} + uid::UInt waiters::Vector{Int} - buffers::Dict{Int,B} - buffer_amount::Int + input_streams::Dict{UInt,Any} # FIXME: Concrete type + output_streams::Dict{UInt,Any} # FIXME: Concrete type + input_buffers::Dict{UInt,B} + output_buffers::Dict{UInt,B} + input_buffer_amount::Int + output_buffer_amount::Int open::Bool migrating::Bool lock::Threads.Condition - StreamStore{T,B}(buffer_amount::Integer) where {T,B} = - new{T,B}(zeros(Int, 0), Dict{Int,B}(), buffer_amount, + StreamStore{T,B}(uid::UInt, input_buffer_amount::Integer, output_buffer_amount::Integer) where {T,B} = + new{T,B}(uid, zeros(Int, 0), + Dict{UInt,Any}(), Dict{UInt,Any}(), + Dict{UInt,B}(), Dict{UInt,B}(), + input_buffer_amount, output_buffer_amount, true, false, Threads.Condition()) end @@ -27,8 +35,12 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B} @dagdebug thunk_id :stream "closed!" throw(InvalidStateException("Stream is closed", :closed)) end - @dagdebug thunk_id :stream "adding $value" - for buffer in values(store.buffers) + @dagdebug thunk_id :stream "adding $value ($(length(store.output_streams)) outputs)" + for output_uid in keys(store.output_streams) + if !haskey(store.output_buffers, output_uid) + initialize_output_stream!(store, output_uid) + end + buffer = store.output_buffers[output_uid] while isfull(buffer) if !isopen(store) @dagdebug thunk_id :stream "closed!" @@ -47,7 +59,11 @@ end function Base.take!(store::StreamStore, id::UInt) thunk_id = STREAM_THUNK_ID[] @lock store.lock begin - buffer = store.buffers[id] + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + error("Must first check isempty(store, id) before taking from a stream") + end + buffer = store.output_buffers[id] while isempty(buffer) && isopen(store, id) @dagdebug thunk_id :stream "no elements, not taking" wait(store.lock) @@ -70,8 +86,14 @@ function Base.take!(store::StreamStore, id::UInt) end end -Base.isempty(store::StreamStore, id::UInt) = isempty(store.buffers[id]) -isfull(store::StreamStore, id::UInt) = isfull(store.buffers[id]) +function Base.isempty(store::StreamStore, id::UInt) + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + return true + end + return isempty(store.output_buffers[id]) +end +isfull(store::StreamStore, id::UInt) = isfull(store.output_buffers[id]) "Returns whether the store is actively open. Only check this when deciding if new values can be pushed." Base.isopen(store::StreamStore) = store.open @@ -83,7 +105,7 @@ taken. """ function Base.isopen(store::StreamStore, id::UInt) @lock store.lock begin - if !isempty(store.buffers[id]) + if !isempty(store.output_buffers[id]) return true end return store.open @@ -92,63 +114,114 @@ end function Base.close(store::StreamStore) @lock store.lock begin - store.open && return + store.open || return store.open = false notify(store.lock) end end -function add_waiters!(store::StreamStore{T,B}, waiters::Vector{Int}) where {T,B} +# FIXME: Just pass Stream directly, rather than its uid +function add_waiters!(store::StreamStore{T,B}, waiters::Vector{UInt}) where {T,B} + our_uid = store.uid @lock store.lock begin - for w in waiters - buffer = initialize_stream_buffer(B, T, store.buffer_amount) - store.buffers[w] = buffer + for output_uid in waiters + store.output_streams[output_uid] = task_to_stream(output_uid) end append!(store.waiters, waiters) notify(store.lock) end end -function remove_waiters!(store::StreamStore, waiters::Vector{Int}) +function remove_waiters!(store::StreamStore, waiters::Vector{UInt}) @lock store.lock begin for w in waiters - delete!(store.buffers, w) + delete!(store.output_buffers, w) idx = findfirst(wo->wo==w, store.waiters) deleteat!(store.waiters, idx) + delete!(store.input_streams, w) end notify(store.lock) end end mutable struct Stream{T,B} + uid::UInt store::Union{StreamStore{T,B},Nothing} store_ref::Chunk - input_buffer::Union{B,Nothing} - buffer_amount::Int - function Stream{T,B}(buffer_amount::Integer=0) where {T,B} + function Stream{T,B}(uid::UInt, input_buffer_amount::Integer, output_buffer_amount::Integer) where {T,B} # Creates a new output stream - store = StreamStore{T,B}(buffer_amount) + store = StreamStore{T,B}(uid, input_buffer_amount, output_buffer_amount) store_ref = tochunk(store) - return new{T,B}(store, store_ref, nothing, buffer_amount) + return new{T,B}(uid, store, store_ref) end - function Stream{B}(stream::Stream{T}, buffer_amount::Integer=0) where {T,B} + function Stream(stream::Stream{T,B}) where {T,B} # References an existing output stream - return new{T,B}(nothing, stream.store_ref, nothing, buffer_amount) + return new{T,B}(stream.uid, nothing, stream.store_ref) end end -function initialize_input_stream!(stream::Stream{T,B}) where {T,B} - stream.input_buffer = initialize_stream_buffer(B, T, stream.buffer_amount) +struct StreamCancelledException <: Exception end +struct StreamingValue{B} + buffer::B end +Base.take!(sv::StreamingValue) = take!(sv.buffer) + +function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::Stream{IT,IB}) where {IT,OT,IB,OB} + input_uid = input_stream.uid + our_uid = our_store.uid + buffer = @lock our_store.lock begin + if haskey(our_store.input_buffers, input_uid) + return StreamingValue(our_store.input_buffers[input_uid]) + end -Base.put!(stream::Stream, @nospecialize(value)) = put!(stream.store, value) - -function Base.take!(stream::Stream{T,B}, id::UInt) where {T,B} - # FIXME: Make remote fetcher configurable - stream_fetch_values!(RemoteFetcher, T, stream.store_ref, stream.input_buffer, id) - return take!(stream.input_buffer) + buffer = initialize_stream_buffer(OB, IT, our_store.input_buffer_amount) + # FIXME: Also pass a RemoteChannel to track remote closure + our_store.input_buffers[input_uid] = buffer + buffer + end + thunk_id = STREAM_THUNK_ID[] + tls = get_tls() + Sch.errormonitor_tracked("streaming input: $input_uid -> $our_uid", Threads.@spawn begin + set_tls!(tls) + STREAM_THUNK_ID[] = thunk_id + try + while isopen(our_store) + # FIXME: Make remote fetcher configurable + stream_pull_values!(RemoteFetcher, IT, input_stream.store_ref, buffer, our_uid) + end + catch err + err isa InterruptException || rethrow(err) + finally + @dagdebug STREAM_THUNK_ID[] :stream "input stream closed" + end + end) + return StreamingValue(buffer) +end +initialize_input_stream!(our_store::StreamStore, arg) = arg +function initialize_output_stream!(store::StreamStore{T,B}, output_uid::UInt) where {T,B} + @assert islocked(store.lock) + @dagdebug STREAM_THUNK_ID[] :stream "initializing output stream $output_uid" + buffer = initialize_stream_buffer(B, T, store.output_buffer_amount) + store.output_buffers[output_uid] = buffer + our_uid = store.uid + thunk_id = STREAM_THUNK_ID[] + Sch.errormonitor_tracked("streaming output: $our_uid -> $output_uid", Threads.@spawn begin + # FIXME: Track remote closure + try + while isopen(store) + # FIXME: Make remote fetcher configurable + stream_push_values!(RemoteFetcher, T, store, buffer, output_uid) + end + catch err + err isa InterruptException || rethrow(err) + finally + @dagdebug thunk_id :stream "output stream closed" + end + end) end +Base.put!(stream::Stream, @nospecialize(value)) = put!(stream.store, value) + function Base.isopen(stream::Stream, id::UInt)::Bool return MemPool.access_ref(stream.store_ref.handle, id) do store, id return isopen(store::StreamStore, id) @@ -163,7 +236,7 @@ function Base.close(stream::Stream) return end -function add_waiters!(stream::Stream, waiters::Vector{Int}) +function add_waiters!(stream::Stream, waiters::Vector{UInt}) MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters add_waiters!(store::StreamStore, waiters) return @@ -171,9 +244,9 @@ function add_waiters!(stream::Stream, waiters::Vector{Int}) return end -add_waiters!(stream::Stream, waiter::Integer) = add_waiters!(stream, Int[waiter]) +add_waiters!(stream::Stream, waiter::Integer) = add_waiters!(stream, UInt[waiter]) -function remove_waiters!(stream::Stream, waiters::Vector{Int}) +function remove_waiters!(stream::Stream, waiters::Vector{UInt}) MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters remove_waiters!(store::StreamStore, waiters) return @@ -211,7 +284,7 @@ function migrate_stream!(stream::Stream, w::Integer=myid()) # FIXME: Do this with MemPool.access_ref, in case stream was already migrated if stream.store_ref.handle.owner != w thunk_id = STREAM_THUNK_ID[] - @dagdebug thunk_id :stream "Beginning migration..." + @dagdebug thunk_id :stream "Beginning migration... ($(length(stream.store.input_streams)) -> $(length(stream.store.output_streams)))" new_store_ref = MemPool.migrate!(stream.store_ref.handle, w; pre_migration=store->begin @@ -219,15 +292,29 @@ function migrate_stream!(stream::Stream, w::Integer=myid()) # N.B. Serialization automatically unlocks the migrated copy lock((store::StreamStore).lock) - # Return the serializeable unsent outputs. We can't send the + # Return the serializeable unsent inputs/outputs. We can't send the # buffers themselves because they may be mmap'ed or something. - Dict(id => collect!(buffer) for (id, buffer) in store.buffers) + unsent_inputs = Dict(uid => collect!(buffer) for (uid, buffer) in store.input_buffers) + unsent_outputs = Dict(uid => collect!(buffer) for (uid, buffer) in store.output_buffers) + empty!(store.input_buffers) + empty!(store.output_buffers) + return (unsent_inputs, unsent_outputs) end, - dest_post_migration=(store, unsent_outputs)->begin - # Initialize the StreamStore on the destination with the unsent outputs. - for (id, outputs) in unsent_outputs + dest_post_migration=(store, unsent)->begin + # Initialize the StreamStore on the destination with the unsent inputs/outputs. + STREAM_THUNK_ID[] = thunk_id + unsent_inputs, unsent_outputs = unsent + for (input_uid, inputs) in unsent_inputs + input_stream = store.input_streams[input_uid] + initialize_input_stream!(store, input_stream) + for item in inputs + put!(store.input_buffers[input_uid], item) + end + end + for (output_uid, outputs) in unsent_outputs + initialize_output_stream!(store, output_uid) for item in outputs - put!(store.buffers[id], item) + put!(store.output_buffers[output_uid], item) end end @@ -244,7 +331,7 @@ function migrate_stream!(stream::Stream, w::Integer=myid()) stream.store = MemPool.access_ref(identity, new_store_ref; local_only=true) end - @dagdebug thunk_id :stream "Migration complete" + @dagdebug thunk_id :stream "Migration complete ($(length(stream.store.input_streams)) -> $(length(stream.store.output_streams)))" end end @@ -269,18 +356,28 @@ end function initialize_streaming!(self_streams, spec, task) if !isa(spec.f, StreamingFunction) - # Adapt called function for streaming and generate output Streams + # Calculate the return type of the called function T_old = Base.uniontypes(task.metadata.return_type) T_old = map(t->(t !== Union{} && t <: FinishStream) ? first(t.parameters) : t, T_old) - # We treat non-dominating error paths as unreachable + # N.B. We treat non-dominating error paths as unreachable T_old = filter(t->t !== Union{}, T_old) T = task.metadata.return_type = !isempty(T_old) ? Union{T_old...} : Any + + # Get input buffer configuration + input_buffer_amount = get(spec.options, :stream_input_buffer_amount, 1) + if input_buffer_amount <= 0 + throw(ArgumentError("Input buffering is required; please specify a `stream_input_buffer_amount` greater than 0")) + end + + # Get output buffer configuration output_buffer_amount = get(spec.options, :stream_output_buffer_amount, 1) if output_buffer_amount <= 0 throw(ArgumentError("Output buffering is required; please specify a `stream_output_buffer_amount` greater than 0")) end - output_buffer = get(spec.options, :stream_output_buffer, ProcessRingBuffer) - stream = Stream{T,output_buffer}(output_buffer_amount) + + # Create the Stream + buffer_type = get(spec.options, :stream_buffer_type, ProcessRingBuffer) + stream = Stream{T,buffer_type}(task.uid, input_buffer_amount, output_buffer_amount) self_streams[task.uid] = stream max_evals = get(spec.options, :stream_max_evals, -1) @@ -349,6 +446,8 @@ end function _run_streamingfunction(tls, sf, args...; kwargs...) @nospecialize sf args kwargs + store = sf.stream.store = MemPool.access_ref(identity, sf.stream.store_ref.handle; local_only=true) + if tls !== nothing set_tls!(tls) end @@ -372,11 +471,8 @@ function _run_streamingfunction(tls, sf, args...; kwargs...) # allocate boxes within `stream!`, when possible kwarg_names = map(name->Val{name}(), map(first, (kwargs...,))) kwarg_values = map(last, (kwargs...,)) - for arg in args - if arg isa Stream - initialize_input_stream!(arg) - end - end + args = map(arg->initialize_input_stream!(store, arg), args) + kwarg_values = map(kwarg->initialize_input_stream!(store, kwarg), kwarg_values) return stream!(sf, uid, (args...,), kwarg_names, kwarg_values) finally if !sf.stream.store.migrating @@ -423,8 +519,8 @@ function stream!(sf::StreamingFunction, uid, end # Get values from Stream args/kwargs - stream_args = _stream_take_values!(args, uid) - stream_kwarg_values = _stream_take_values!(kwarg_values, uid) + stream_args = _stream_take_values!(args) + stream_kwarg_values = _stream_take_values!(kwarg_values) stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values) # Run a single cycle of f @@ -445,13 +541,13 @@ function stream!(sf::StreamingFunction, uid, end end -function _stream_take_values!(args, uid) +function _stream_take_values!(args) return ntuple(length(args)) do idx arg = args[idx] - if arg isa Stream - take!(arg, uid) + if arg isa StreamingValue + return take!(arg) else - arg + return arg end end end @@ -479,10 +575,11 @@ function task_to_stream(uid::UInt) end function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) - stream_waiter_changes = Dict{UInt,Vector{Int}}() + stream_waiter_changes = Dict{UInt,Vector{UInt}}() for (spec, task) in tasks @assert haskey(self_streams, task.uid) + our_stream = self_streams[task.uid] # Adapt args to accept Stream output of other streaming tasks for (idx, (pos, arg)) in enumerate(spec.args) @@ -495,21 +592,15 @@ function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) end if other_stream !== nothing - # Get input stream configs and configure input stream - input_buffer_amount = get(spec.options, :stream_input_buffer_amount, 1) - if input_buffer_amount <= 0 - throw(ArgumentError("Input buffering is required; please specify a `stream_input_buffer_amount` greater than 0")) - end - input_buffer = get(spec.options, :stream_input_buffer, ProcessRingBuffer) + # Generate Stream handle for input # FIXME: input_fetcher = get(spec.options, :stream_input_fetcher, RemoteFetcher) - input_stream = Stream{input_buffer}(other_stream, input_buffer_amount) - - # Replace the DTask with the input Stream - spec.args[idx] = pos => other_stream + other_stream_handle = Stream(other_stream) + spec.args[idx] = pos => other_stream_handle + our_stream.store.input_streams[arg.uid] = other_stream_handle # Add this task as a waiter for the associated output Stream changes = get!(stream_waiter_changes, arg.uid) do - Int[] + UInt[] end push!(changes, task.uid) end @@ -517,8 +608,8 @@ function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) end # Filter out all streaming options - to_filter = (:stream_input_buffer, :stream_input_buffer_amount, - :stream_output_buffer, :stream_output_buffer_amount, + to_filter = (:stream_buffer_type, + :stream_input_buffer_amount, :stream_output_buffer_amount, :stream_max_evals) spec.options = NamedTuple(filter(opt -> !(opt[1] in to_filter), Base.pairs(spec.options))) From 61ab9c1433ea1344cd736eb03bc973a76a96baa0 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 12:18:16 -0500 Subject: [PATCH 34/56] fixup! fixup! fixup! cancellation: Add cancel token support --- src/cancellation.jl | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/src/cancellation.jl b/src/cancellation.jl index 5387f101a..dcb0f5add 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -2,14 +2,35 @@ struct CancelToken cancelled::Base.RefValue{Bool} + event::Base.Event end -CancelToken() = CancelToken(Ref(false)) +CancelToken() = CancelToken(Ref(false), Base.Event()) function cancel!(token::CancelToken) token.cancelled[] = true + notify(token.event) + return end +is_cancelled(token::CancelToken) = token.cancelled[] +Base.wait(token::CancelToken) = wait(token.event) +# TODO: Enable this for safety +#Serialization.serialize(io::AbstractSerializer, ::CancelToken) = +# throw(ConcurrencyViolationError("Cannot serialize a CancelToken")) const DTASK_CANCEL_TOKEN = TaskLocalValue{Union{CancelToken,Nothing}}(()->nothing) +function clone_cancel_token_remote(orig_token::CancelToken, wid::Integer) + remote_token = remotecall_fetch(wid) do + return poolset(CancelToken()) + end + errormonitor_tracked("remote cancel_token communicator", Threads.@spawn begin + wait(orig_token) + @dagdebug nothing :cancel "Cancelling remote token on worker $wid" + MemPool.access_ref(remote_token) do remote_token + cancel!(remote_token) + end + end) +end + # Global-level cancellation """ From a51cbf958868bc9f523d9376bba0f7aa4f12d772 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 12:18:55 -0500 Subject: [PATCH 35/56] Sch: Trigger cancel token on task exit --- src/sch/Sch.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/sch/Sch.jl b/src/sch/Sch.jl index 48eba1b31..794afce92 100644 --- a/src/sch/Sch.jl +++ b/src/sch/Sch.jl @@ -1371,6 +1371,9 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re else rethrow(err) end + finally + # Ensure that any spawned tasks get cleaned up + Dagger.cancel!(cancel_token) end end lock(istate.queue) do _ From 31944af65f58274496b1f1b3d0285768f5c35e9b Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 12:22:52 -0500 Subject: [PATCH 36/56] Add task_id for DTask --- src/sch/eager.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/sch/eager.jl b/src/sch/eager.jl index 7ccdfbb29..aea0abbf6 100644 --- a/src/sch/eager.jl +++ b/src/sch/eager.jl @@ -141,3 +141,6 @@ function _find_thunk(e::Dagger.DTask) unwrap_weak_checked(EAGER_STATE[].thunk_dict[tid]) end end +Dagger.task_id(t::Dagger.DTask) = lock(EAGER_ID_MAP) do id_map + id_map[t.uid] +end From d5c27abeee3354970246194958a1d764159b082f Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 12:23:32 -0500 Subject: [PATCH 37/56] ProcessRingBuffer: Allow closure --- src/stream-buffers.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl index e8c1c9488..cd00000b4 100644 --- a/src/stream-buffers.jl +++ b/src/stream-buffers.jl @@ -53,18 +53,26 @@ mutable struct ProcessRingBuffer{T} write_idx::Int @atomic count::Int buffer::Vector{T} + open::Bool function ProcessRingBuffer{T}(len::Int=1024) where T buffer = Vector{T}(undef, len) - return new{T}(1, 1, 0, buffer) + return new{T}(1, 1, 0, buffer, true) end end Base.isempty(rb::ProcessRingBuffer) = (@atomic rb.count) == 0 isfull(rb::ProcessRingBuffer) = (@atomic rb.count) == length(rb.buffer) Base.length(rb::ProcessRingBuffer) = @atomic rb.count +Base.isopen(rb::ProcessRingBuffer) = rb.open +function Base.close(rb::ProcessRingBuffer) + rb.open = false +end function Base.put!(rb::ProcessRingBuffer{T}, x) where T len = length(rb.buffer) while (@atomic rb.count) == len yield() + if !isopen(rb) + throw(InvalidStateException("Stream is closed", :closed)) + end task_may_cancel!() end to_write_idx = mod1(rb.write_idx, len) @@ -75,6 +83,9 @@ end function Base.take!(rb::ProcessRingBuffer) while (@atomic rb.count) == 0 yield() + if !isopen(rb) + throw(InvalidStateException("Stream is closed", :closed)) + end task_may_cancel!() end to_read_idx = rb.read_idx From fbae73f0b15566b449a5cc20cbf4ada51ab1813c Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 12:24:48 -0500 Subject: [PATCH 38/56] RemoteFetcher: Only collect values up to free buffer space --- src/stream-transfer.jl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/stream-transfer.jl b/src/stream-transfer.jl index defa24463..3251abb9a 100644 --- a/src/stream-transfer.jl +++ b/src/stream-transfer.jl @@ -1,22 +1,25 @@ struct RemoteFetcher end +# TODO: Switch to RemoteChannel approach function stream_pull_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_remote}, buffer::Blocal, id::UInt) where {Store_remote, Blocal} thunk_id = STREAM_THUNK_ID[] @dagdebug thunk_id :stream "fetching values" - values = T[] free_space = length(buffer.buffer) - length(buffer) + if free_space == 0 + yield() + task_may_cancel!() + return + end + + values = T[] while isempty(values) - # FIXME: Pass buffer free space - # TODO: It would be ideal if we could wait on store.lock, but get unlocked during migration values = MemPool.access_ref(store_ref.handle, id, T, Store_remote, thunk_id, free_space) do store, id, T, Store_remote, thunk_id, free_space - if !isopen(store) - throw(InvalidStateException("Stream is closed", :closed)) - end @dagdebug thunk_id :stream "trying to fetch values at $(myid())" store::Store_remote in_store = store STREAM_THUNK_ID[] = thunk_id values = T[] + @dagdebug thunk_id :stream "trying to fetch: $(store.output_buffers[id].count) values, free_space: $free_space" while !isempty(store, id) && length(values) < free_space value = take!(store, id)::T @dagdebug thunk_id :stream "fetched $value" @@ -39,5 +42,5 @@ function stream_pull_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_re end end function stream_push_values!(::Type{RemoteFetcher}, T, store_ref::Store_remote, buffer::Blocal, id::UInt) where {Store_remote, Blocal} - sleep(0.1) + sleep(1) end From bf53117f1f07e57a7a00e943e9862e61556b3873 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 12:26:51 -0500 Subject: [PATCH 39/56] streaming: Close buffers on closing StreamStore --- src/stream.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/stream.jl b/src/stream.jl index b0e5c10ba..b367f8c95 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -113,9 +113,15 @@ function Base.isopen(store::StreamStore, id::UInt) end function Base.close(store::StreamStore) + store.open || return + store.open = false @lock store.lock begin - store.open || return - store.open = false + for buffer in values(store.input_buffers) + close(buffer) + end + for buffer in values(store.output_buffers) + close(buffer) + end notify(store.lock) end end From b9e3c70f7bc9724aaa3426cdd1bf0e273ae622b5 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 12:59:55 -0500 Subject: [PATCH 40/56] task-tls: Tweaks and fixes, task_id helper --- src/task-tls.jl | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/task-tls.jl b/src/task-tls.jl index 8a8b6c66d..af57181fe 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -1,6 +1,6 @@ # In-Thunk Helpers -struct DTaskTLS +mutable struct DTaskTLS processor::Processor sch_uid::UInt sch_handle::Any # FIXME: SchedulerHandle @@ -10,6 +10,8 @@ end const DTASK_TLS = TaskLocalValue{Union{DTaskTLS,Nothing}}(()->nothing) +Base.copy(tls::DTaskTLS) = DTaskTLS(tls.processor, tls.sch_uid, tls.sch_handle, tls.task_spec, tls.cancel_token) + """ get_tls() -> DTaskTLS @@ -32,7 +34,14 @@ end Returns `true` if currently executing in a [`DTask`](@ref), else `false`. """ in_task() = DTASK_TLS[] !== nothing -@deprecate in_thunk() in_task() +@deprecate(in_thunk(), in_task()) + +""" + task_id() -> Int + +Returns the ID of the current [`DTask`](@ref). +""" +task_id() = get_tls().sch_handle.thunk_id.id """ task_processor() -> Processor @@ -40,7 +49,7 @@ in_task() = DTASK_TLS[] !== nothing Get the current processor executing the current [`DTask`](@ref). """ task_processor() = get_tls().processor -@deprecate thunk_processor() task_processor() +@deprecate(thunk_processor(), task_processor()) """ task_cancelled() -> Bool From 8908478ef2974209234ebb5559a020d53b4a4d4c Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 13:00:31 -0500 Subject: [PATCH 41/56] task-tls: Add task_cancel! --- src/task-tls.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/task-tls.jl b/src/task-tls.jl index af57181fe..f6889bbb1 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -56,7 +56,7 @@ task_processor() = get_tls().processor Returns `true` if the current [`DTask`](@ref) has been cancelled, else `false`. """ -task_cancelled() = get_tls().cancel_token.cancelled[] +task_cancelled() = is_cancelled(get_tls().cancel_token) """ task_may_cancel!() @@ -68,3 +68,10 @@ function task_may_cancel!() throw(InterruptException()) end end + +""" + task_cancel!() + +Cancels the current [`DTask`](@ref). +""" +task_cancel!() = cancel!(get_tls().cancel_token) From 1f21693c104678c01d34dce2843a0d0b77f6f421 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 13:01:20 -0500 Subject: [PATCH 42/56] streaming: max_evals cannot be specified as 0 --- src/stream.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/stream.jl b/src/stream.jl index b367f8c95..7ec23f41b 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -386,7 +386,12 @@ function initialize_streaming!(self_streams, spec, task) stream = Stream{T,buffer_type}(task.uid, input_buffer_amount, output_buffer_amount) self_streams[task.uid] = stream + # Get max evaluation count max_evals = get(spec.options, :stream_max_evals, -1) + if max_evals == 0 + throw(ArgumentError("stream_max_evals cannot be 0")) + end + spec.f = StreamingFunction(spec.f, stream, max_evals) spec.options = merge(spec.options, (;occupancy=Dict(Any=>0))) From c4bc7b2c7a1dc15609712d4dc7592e72b4ce85b1 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 13:05:21 -0500 Subject: [PATCH 43/56] streaming: Small tweaks to migration and cancellation --- src/stream.jl | 62 ++++++++++++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/src/stream.jl b/src/stream.jl index 7ec23f41b..493183b99 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -46,7 +46,7 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B} @dagdebug thunk_id :stream "closed!" throw(InvalidStateException("Stream is closed", :closed)) end - @dagdebug thunk_id :stream "buffer full, waiting" + @dagdebug thunk_id :stream "buffer full ($(length(buffer)) values), waiting" wait(store.lock) task_may_cancel!() end @@ -211,8 +211,9 @@ function initialize_output_stream!(store::StreamStore{T,B}, output_uid::UInt) wh store.output_buffers[output_uid] = buffer our_uid = store.uid thunk_id = STREAM_THUNK_ID[] + tls = get_tls() Sch.errormonitor_tracked("streaming output: $our_uid -> $output_uid", Threads.@spawn begin - # FIXME: Track remote closure + set_tls!(tls) try while isopen(store) # FIXME: Make remote fetcher configurable @@ -267,21 +268,8 @@ struct StreamingFunction{F, S} stream::S max_evals::Int - status_event::Threads.Event - migration_complete::Threads.Event - StreamingFunction(f::F, stream::S, max_evals) where {F, S} = - new{F, S}(f, stream, max_evals, Threads.Event(), Threads.Event()) -end - -function migrate_streamingfunction!(sf::StreamingFunction, w::Integer=myid()) - current_worker = sf.stream.store_ref.handle.owner - if myid() != current_worker - return remotecall_fetch(migrate_streamingfunction!, current_worker, sf, w) - end - - sf.stream.store.migrating = true - @lock sf.status_event wait(sf.status_event) # Wait for the streaming function to finish + new{F, S}(f, stream, max_evals) end function migrate_stream!(stream::Stream, w::Integer=myid()) @@ -292,6 +280,11 @@ function migrate_stream!(stream::Stream, w::Integer=myid()) thunk_id = STREAM_THUNK_ID[] @dagdebug thunk_id :stream "Beginning migration... ($(length(stream.store.input_streams)) -> $(length(stream.store.output_streams)))" + # TODO: Wire up listener to ferry cancel_token notifications to remote worker + tls = get_tls() + @assert w == myid() "Only pull-based migration is currently supported" + #remote_cancel_token = clone_cancel_token_remote(get_tls().cancel_token, worker_id) + new_store_ref = MemPool.migrate!(stream.store_ref.handle, w; pre_migration=store->begin # Lock store to prevent any further modifications @@ -309,6 +302,9 @@ function migrate_stream!(stream::Stream, w::Integer=myid()) dest_post_migration=(store, unsent)->begin # Initialize the StreamStore on the destination with the unsent inputs/outputs. STREAM_THUNK_ID[] = thunk_id + @assert !in_task() + set_tls!(tls) + #get_tls().cancel_token = MemPool.access_ref(identity, remote_cancel_token; local_only=true) unsent_inputs, unsent_outputs = unsent for (input_uid, inputs) in unsent_inputs input_stream = store.input_streams[input_uid] @@ -324,12 +320,16 @@ function migrate_stream!(stream::Stream, w::Integer=myid()) end end - # Ensure that the 'migrating' flag is not set + # Reset the state of this new store + store.open = true store.migrating = false end, post_migration=store->begin + # Indicate that this store has migrated + store.migrating = true + store.open = false + # Unlock the store - # FIXME: Indicate to all waiters that this store is dead unlock((store::StreamStore).lock) end) if w == myid() @@ -435,18 +435,17 @@ function (sf::StreamingFunction)(args...; kwargs...) # Migrate our output stream store to this worker if sf.stream isa Stream - migrate_stream!(sf.stream) + remote_cancel_token = migrate_stream!(sf.stream) end @label start @dagdebug thunk_id :stream "Starting StreamingFunction" worker_id = sf.stream.store_ref.handle.owner result = if worker_id == myid() - _run_streamingfunction(nothing, sf, args...; kwargs...) + _run_streamingfunction(nothing, nothing, sf, args...; kwargs...) else tls = get_tls() - # FIXME: Wire up listener to ferry cancel_token notifications to remote worker - remotecall_fetch(_run_streamingfunction, worker_id, tls, sf, args...; kwargs...) + remotecall_fetch(_run_streamingfunction, worker_id, tls, remote_cancel_token, sf, args...; kwargs...) end if result === StreamMigrating() @goto start @@ -454,12 +453,15 @@ function (sf::StreamingFunction)(args...; kwargs...) return result end -function _run_streamingfunction(tls, sf, args...; kwargs...) +function _run_streamingfunction(tls, cancel_token, sf, args...; kwargs...) @nospecialize sf args kwargs store = sf.stream.store = MemPool.access_ref(identity, sf.stream.store_ref.handle; local_only=true) + @assert isopen(store) if tls !== nothing + # Setup TLS on this new task + tls.cancel_token = MemPool.access_ref(identity, cancel_token; local_only=true) set_tls!(tls) end @@ -509,8 +511,6 @@ function _run_streamingfunction(tls, sf, args...; kwargs...) @dagdebug thunk_id :stream "closed stream" close(sf.stream) end - - notify(sf.status_event) end end @@ -520,12 +520,16 @@ function stream!(sf::StreamingFunction, uid, f = move(thunk_processor(), sf.f) counter = 0 - while sf.max_evals < 0 || counter < sf.max_evals + while true + # Yield to other (streaming) tasks + yield() + # Exit streaming on cancellation task_may_cancel!() # Exit streaming on migration if sf.stream.store.migrating + error("FIXME: max_evals should be retained") return StreamMigrating() end @@ -549,6 +553,12 @@ function stream!(sf::StreamingFunction, uid, # Put the result into the output stream put!(sf.stream, stream_result) + + # Exit streaming on eval limit + if sf.max_evals >= 0 && counter >= sf.max_evals + @dagdebug STREAM_THUNK_ID[] :stream "max evals reached" + return + end end end From 51e1606f02c105a95195740b4390c743d5aa18bc Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 13:05:49 -0500 Subject: [PATCH 44/56] dagdebug: Always yield to avoid heisenbugs --- src/utils/dagdebug.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/utils/dagdebug.jl b/src/utils/dagdebug.jl index 1e2b625bd..6a71e5c52 100644 --- a/src/utils/dagdebug.jl +++ b/src/utils/dagdebug.jl @@ -32,6 +32,10 @@ macro dagdebug(thunk, category, msg, args...) $debug_ex_noid end end + + # Always yield to reduce differing behavior for debug vs. non-debug + # TODO: Remove this eventually + yield() end end) end From 4ea09c46ce83e42ce7d93d7616d4c705e65eea2f Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 13:06:37 -0500 Subject: [PATCH 45/56] tests: Revamp streaming tests --- test/streaming.jl | 517 +++++++++++++++++++++++++++++++++------------- 1 file changed, 375 insertions(+), 142 deletions(-) diff --git a/test/streaming.jl b/test/streaming.jl index 87e0550ad..a994b4109 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -1,192 +1,425 @@ -import MemPool: access_ref - -@everywhere begin - """ - A functor to produce a certain number of outputs. - - Note: always use this like `Dagger.spawn(Producer())` rather than - `Dagger.@spawn Producer()`. The macro form will just create fresh objects - every time and stream forever. - """ - mutable struct Producer - N::Union{Int, Float64} - count::Int - mailbox::Union{RemoteChannel, Nothing} - - Producer(N=5, mailbox=nothing) = new(N, 0, mailbox) +@everywhere function rand_finite(T=Float64) + x = rand(T) + if rand() < 0.1 + return Dagger.finish_stream(x) end + return x +end +@everywhere function rand_finite_returns(T=Float64) + x = rand(T) + if rand() < 0.1 + return Dagger.finish_stream(x; result=x) + end + return x +end - function (self::Producer)() - self.count += 1 - - # Sleeping will make the loop yield (handy for single-threaded - # processes), and stops Dagger from being too spammy in debug mode. - if self.N == Inf - sleep(0.1) - end +const ACCUMULATOR = Dict{Int,Vector{Real}}() +@everywhere function accumulator(x=0) + tid = Dagger.task_id() + remotecall_wait(1, tid, x) do tid, x + acc = get!(Vector{Real}, ACCUMULATOR, tid) + push!(acc, x) + end + return +end +@everywhere accumulator(xs...) = accumulator(sum(xs)) - # Check if there are any instructions for us - if !isnothing(self.mailbox) && isready(self.mailbox) - msg = take!(self.mailbox) - if msg === :exit - put!(self.mailbox, self.count) - return Dagger.finish_stream(self.count) - else - error("Unrecognized Producer message: $msg") - end +function catch_interrupt(f) + try + f() + catch err + if err isa Dagger.DTaskFailedException && err.ex isa InterruptException + return + elseif err isa Dagger.Sch.SchedulingException + return end - - self.count >= self.N ? Dagger.finish_stream(self.count) : self.count + rethrow(err) end end - -function test_in_task(f, message, parent_testsets) - task_local_storage(:__BASETESTNEXT__, parent_testsets) - - @testset "$message" begin +function merge_testset!(inner::Test.DefaultTestSet) + outer = Test.get_testset() + append!(outer.results, inner.results) + outer.n_passed += inner.n_passed +end +function test_finishes(f, message::String; ignore_timeout=false, max_evals=10) + t = @eval Threads.@spawn begin + tset = nothing try - f() - catch err - if err isa Dagger.DTaskFailedException && err.ex isa InterruptException - return - elseif err isa Dagger.Sch.SchedulingException - return + @testset $message begin + try + @testset $message begin + Dagger.with_options(;stream_max_evals=$max_evals) do + catch_interrupt($f) + end + end + finally + tset = Test.get_testset() + end end - rethrow() + catch end + return tset end -end - -function test_finishes(f, message::String; ignore_timeout=false) - # We sneakily pass a magic variable from the current TLS into the new - # task. It's used by the Test stdlib to hold a list of the current - # testsets, so we need it to be able to record the tests from the new - # task in the original testset that we're currently running under. - parent_testsets = get(task_local_storage(), :__BASETESTNEXT__, []) - t = Threads.@spawn test_in_task(f, message, parent_testsets) - - if timedwait(()->istaskdone(t), 20) == :timed_out + timed_out = timedwait(()->istaskdone(t), 5) == :timed_out + if timed_out if !ignore_timeout @warn "Testing task timed out: $message" end - Dagger.cancel!(;halt_sch=true, force=true) + Dagger.cancel!(;halt_sch=true) fetch(Dagger.@spawn 1+1) - return false end - return true + tset = fetch(t)::Test.DefaultTestSet + merge_testset!(tset) + return !timed_out end -@testset "Basics" begin - master_scope = Dagger.scope(worker=myid()) +all_scopes = [Dagger.ExactScope(proc) for proc in Dagger.all_processors()] +for idx in 1:5 + if idx == 1 + scopes = [Dagger.scope(worker = 1, thread = 1)] + scope_str = "Worker 1" + elseif idx == 2 && nprocs() > 1 + scopes = [Dagger.scope(worker = 2, thread = 1)] + scope_str = "Worker 2" + else + scopes = all_scopes + scope_str = "All Workers" + end - @test test_finishes("Migration") do - if nprocs() == 1 - @warn "Skipping migration test because it requires at least 1 extra worker" - return + @testset "Single Task Control Flow ($scope_str)" begin + @test !test_finishes("Single task running forever"; max_evals=1_000_000, ignore_timeout=true) do + local x + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) () -> begin + y = rand() + sleep(1) + return y + end + end + fetch(x) end - # Start streaming locally - mailbox = RemoteChannel() - producer = Producer(Inf, mailbox) - x = Dagger.spawn_streaming() do - Dagger.spawn(producer, Dagger.Options(; scope=master_scope)) + @test test_finishes("Single task without result") do + local x + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + end + @test fetch(x) === nothing end - # Wait for the stream to get started - while producer.count < 2 - sleep(0.1) + @test test_finishes("Single task with result"; max_evals=1_000_000) do + local x + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) () -> begin + x = rand() + if x < 0.1 + return Dagger.finish_stream(x; result=123) + end + return x + end + end + @test fetch(x) == 123 end + end - # Migrate to another worker - access_ref(x.thunk_ref) do thunk - access_ref(thunk.f.handle) do streaming_function - Dagger.migrate_stream!(streaming_function.stream, workers()[1]) + @testset "Non-Streaming Inputs ($scope_str)" begin + @test test_finishes("() -> A") do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator() end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(0), values[A_tid]) end + @test test_finishes("42 -> A") do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator(42) + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(42), values[A_tid]) + end + @test test_finishes("(42, 43) -> A") do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator(42, 43) + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(==(42 + 43), values[A_tid]) + end + end - # Wait a bit for the stream to get started again on the other node - sleep(0.5) - - # Stop it - put!(mailbox, :exit) - fetch(x) - - final_count = take!(mailbox) - @info "Counts:" producer.count final_count + @testset "Non-Streaming Outputs ($scope_str)" begin + @test test_finishes("x -> A") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + end + A = Dagger.@spawn accumulator(x) + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[A_tid]) + end + @test test_finishes("x -> (A, B)") do + local x, A, B + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + end + A = Dagger.@spawn accumulator(x) + B = Dagger.@spawn accumulator(x) + @test fetch(x) === nothing + @test fetch(A) === nothing + @test fetch(B) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[A_tid]) + B_tid = Dagger.task_id(B) + @test length(values[B_tid]) == 1 + @test all(v -> 0 <= v <= 10, values[B_tid]) + end end - return + @testset "Multiple Tasks ($scope_str)" begin + @test test_finishes("x -> A") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 1, values[A_tid]) + end - @test test_finishes("Single task") do - local x - Dagger.spawn_streaming() do - x = Dagger.spawn(Producer()) + @test test_finishes("(x, A)") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(1.0) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> v == 1, values[A_tid]) end - @test fetch(x) === nothing - end - @test !test_finishes("Single task running forever"; ignore_timeout=true) do - local x - Dagger.spawn_streaming() do - x = Dagger.spawn() do - y = rand() - sleep(1) - return y + @test test_finishes("x -> y -> A") do + local x, y, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x+1 + A = Dagger.@spawn scope=rand(scopes) accumulator(y) end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 1 <= v <= 2, values[A_tid]) end - fetch(x) - end - @test test_finishes("Max evaluations") do - producer = Producer(20) - x = Dagger.with_options(; stream_max_evals=10) do + @test test_finishes("x -> (y, A)") do + local x, y, A Dagger.spawn_streaming() do - # Spawn on the same node so we can access the local `producer` variable - Dagger.spawn(producer, Dagger.Options(; scope=master_scope)) + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x+1 + A = Dagger.@spawn scope=rand(scopes) accumulator(x) end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 1, values[A_tid]) end - wait(x) - @test producer.count == 10 - end + @test test_finishes("(x, y) -> A") do + local x, y, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(x, y) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[A_tid]) + end - @test test_finishes("Two tasks (sequential)") do - local x, y - @warn "\n\n\nStart streaming\n\n\n" - Dagger.spawn_streaming() do - x = Dagger.spawn(Producer()) - y = Dagger.@spawn x+1 + @test test_finishes("(x, y) -> z -> A") do + local x, y, z, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + z = Dagger.@spawn scope=rand(scopes) x + y + A = Dagger.@spawn scope=rand(scopes) accumulator(z) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(z) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[A_tid]) end - @test fetch(x) === nothing - @test_throws Dagger.DTaskFailedException fetch(y) - end - # TODO: Two tasks (parallel) + @test test_finishes("x -> (y, z) -> A") do + local x, y, z, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) x + 1 + z = Dagger.@spawn scope=rand(scopes) x + 2 + A = Dagger.@spawn scope=rand(scopes) accumulator(y, z) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(z) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 3 <= v <= 5, values[A_tid]) + end - # TODO: Three tasks (2 -> 1) and (1 -> 2) - # TODO: Four tasks (diamond) + @test test_finishes("(x, y) -> z -> (A, B)") do + local x, y, z, A, B + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + y = Dagger.@spawn scope=rand(scopes) rand() + z = Dagger.@spawn scope=rand(scopes) x + y + A = Dagger.@spawn scope=rand(scopes) accumulator(z) + B = Dagger.@spawn scope=rand(scopes) accumulator(z) + end + @test fetch(x) === nothing + @test fetch(y) === nothing + @test fetch(z) === nothing + @test fetch(A) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[A_tid]) + B_tid = Dagger.task_id(B) + @test length(values[B_tid]) == 10 + @test all(v -> 0 <= v <= 2, values[B_tid]) + end - # TODO: With pass-through/Without result - # TODO: With pass-through/With result - # TODO: Without pass-through/Without result + for T in (Float64, Int32, BigFloat) + @test test_finishes("Stream eltype $T") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand(T) + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + @test fetch(x) === nothing + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 10 + @test all(v -> v isa T, values[A_tid]) + end + end + end + + @testset "Max Evals ($scope_str)" begin + @test test_finishes("max_evals=0"; max_evals=0) do + @test_throws ArgumentError Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator() + end + end + @test test_finishes("max_evals=1"; max_evals=1) do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) accumulator() + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 1 + end + @test test_finishes("max_evals=100"; max_evals=100) do + local A + Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) rand() + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test length(values[A_tid]) == 100 + end + end - @test test_finishes("Without pass-through/With result") do - local x - Dagger.spawn_streaming() do - x = Dagger.spawn() do - x = rand() - if x < 0.1 - return Dagger.finish_stream(x; result=123) + @testset "DropBuffer ($scope_str)" begin + @test test_finishes("x (drop)-> A") do + local x, A + Dagger.spawn_streaming() do + Dagger.with_options(;stream_buffer_type=>Dagger.DropBuffer) do + x = Dagger.@spawn scope=rand(scopes) rand() end - return x + A = Dagger.@spawn scope=rand(scopes) accumulator(x) end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test !haskey(values, A_tid) + end + @test test_finishes("x ->(drop) A") do + local x, A + Dagger.spawn_streaming() do + x = Dagger.@spawn scope=rand(scopes) rand() + Dagger.with_options(;stream_buffer_type=>Dagger.DropBuffer) do + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test !haskey(values, A_tid) + end + @test test_finishes("x -(drop)> A") do + local x, A + Dagger.spawn_streaming() do + Dagger.with_options(;stream_buffer_type=>Dagger.DropBuffer) do + x = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator(x) + end + end + @test fetch(A) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + A_tid = Dagger.task_id(A) + @test !haskey(values, A_tid) end - @test fetch(x) == 123 end -end -# TODO: Custom stream buffers/buffer amounts -# TODO: Cross-worker streaming -# TODO: Different stream element types (immutable and mutable) -# TODO: Zero-allocation examples -# FIXME: Streaming across threads + # FIXME: Varying buffer amounts + + #= TODO: Zero-allocation test + # First execution of a streaming task will almost guaranteed allocate (compiling, setup, etc.) + # BUT, second and later executions could possibly not allocate any further ("steady-state") + # We want to be able to validate that the steady-state execution for certain tasks is non-allocating + =# +end From 8bf5fbf8440544576fe7bddd7aa80464169c76c5 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Tue, 24 Sep 2024 13:06:59 -0500 Subject: [PATCH 46/56] tests: Add offline mode --- test/runtests.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index a4863c791..67d25276e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,7 +35,10 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ pushfirst!(LOAD_PATH, joinpath(@__DIR__, "..")) using Pkg Pkg.activate(@__DIR__) - Pkg.instantiate() + try + Pkg.instantiate() + catch + end using ArgParse s = ArgParseSettings(description = "Dagger Testsuite") @@ -55,6 +58,9 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ "-v", "--verbose" action = :store_true help = "Run the tests with debug logs from Dagger" + "-O", "--offline" + action = :store_true + help = "Set Pkg into offline mode" end end @@ -88,6 +94,11 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__ if parsed_args["verbose"] ENV["JULIA_DEBUG"] = "Dagger" end + + if parsed_args["offline"] + Pkg.UPDATED_REGISTRY_THIS_SESSION[] = true + Pkg.offline(true) + end else to_test = all_test_names @info "Running all tests" From 07ba8b1ef4b239fd3b43e615d54345a5fcd6fe03 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 2 Oct 2024 09:00:23 -0500 Subject: [PATCH 47/56] dagdebug: Add JULIA_DAGGER_DEBUG config variable --- src/Dagger.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/Dagger.jl b/src/Dagger.jl index 505f41421..afbc37b9f 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -153,6 +153,20 @@ function __init__() ThreadProc(myid(), tid) end end + + # Set up @dagdebug categories, if specified + try + if haskey(ENV, "JULIA_DAGGER_DEBUG") + empty!(DAGDEBUG_CATEGORIES) + for category in split(ENV["JULIA_DAGGER_DEBUG"], ",") + if category != "" + push!(DAGDEBUG_CATEGORIES, Symbol(category)) + end + end + end + catch err + @warn "Error parsing JULIA_DAGGER_DEBUG" exception=err + end end end # module From 3aba122692482261823d2cf70dbab4736b1468cb Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 2 Oct 2024 19:07:18 -0500 Subject: [PATCH 48/56] cancellation: Add graceful vs. forced --- src/cancellation.jl | 27 ++++++++++++++++++++------- src/task-tls.jl | 20 ++++++++++++-------- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/src/cancellation.jl b/src/cancellation.jl index dcb0f5add..0aa150331 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -1,16 +1,29 @@ # DTask-level cancellation -struct CancelToken - cancelled::Base.RefValue{Bool} +mutable struct CancelToken + @atomic cancelled::Bool + @atomic graceful::Bool event::Base.Event end -CancelToken() = CancelToken(Ref(false), Base.Event()) -function cancel!(token::CancelToken) - token.cancelled[] = true +CancelToken() = CancelToken(false, false, Base.Event()) +function cancel!(token::CancelToken; graceful::Bool=true) + if !graceful + @atomic token.graceful = false + end + @atomic token.cancelled = true notify(token.event) return end -is_cancelled(token::CancelToken) = token.cancelled[] +function is_cancelled(token::CancelToken; must_force::Bool=false) + if token.cancelled[] + if must_force && token.graceful[] + # If we're only responding to forced cancellation, ignore graceful cancellations + return false + end + return true + end + return false +end Base.wait(token::CancelToken) = wait(token.event) # TODO: Enable this for safety #Serialization.serialize(io::AbstractSerializer, ::CancelToken) = @@ -128,7 +141,7 @@ function _cancel!(state, tid, force, halt_sch) push!(istate.cancelled, tid) to_proc = istate.proc put!(istate.return_queue, (myid(), to_proc, tid, (InterruptException(), nothing))) - cancel!(istate.cancel_tokens[tid]) + cancel!(istate.cancel_tokens[tid]; graceful=false) end end end diff --git a/src/task-tls.jl b/src/task-tls.jl index f6889bbb1..5c7d0375b 100644 --- a/src/task-tls.jl +++ b/src/task-tls.jl @@ -52,26 +52,30 @@ task_processor() = get_tls().processor @deprecate(thunk_processor(), task_processor()) """ - task_cancelled() -> Bool + task_cancelled(; must_force::Bool=false) -> Bool Returns `true` if the current [`DTask`](@ref) has been cancelled, else `false`. +If `must_force=true`, then only return `true` if the cancellation was forced. """ -task_cancelled() = is_cancelled(get_tls().cancel_token) +task_cancelled(; must_force::Bool=false) = + is_cancelled(get_tls().cancel_token; must_force) """ - task_may_cancel!() + task_may_cancel!(; must_force::Bool=false) Throws an `InterruptException` if the current [`DTask`](@ref) has been cancelled. +If `must_force=true`, then only throw if the cancellation was forced. """ -function task_may_cancel!() - if task_cancelled() +function task_may_cancel!(;must_force::Bool=false) + if task_cancelled(;must_force) throw(InterruptException()) end end """ - task_cancel!() + task_cancel!(; graceful::Bool=true) -Cancels the current [`DTask`](@ref). +Cancels the current [`DTask`](@ref). If `graceful=true`, then the task will be +cancelled gracefully, otherwise it will be forced. """ -task_cancel!() = cancel!(get_tls().cancel_token) +task_cancel!(; graceful::Bool=true) = cancel!(get_tls().cancel_token; graceful) From 6ac140c8ac8cb487300345cbb4e2993e7c5c84b6 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 2 Oct 2024 19:07:56 -0500 Subject: [PATCH 49/56] cancellation: Wrap InterruptException in DTaskFailedException --- src/cancellation.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/cancellation.jl b/src/cancellation.jl index 0aa150331..86f562bc4 100644 --- a/src/cancellation.jl +++ b/src/cancellation.jl @@ -96,7 +96,7 @@ function _cancel!(state, tid, force, halt_sch) for task in state.ready tid !== nothing && task.id != tid && continue @dagdebug tid :cancel "Cancelling ready task" - state.cache[task] = InterruptException() + state.cache[task] = DTaskFailedException(task, task, InterruptException()) state.errored[task] = true Sch.set_failed!(state, task) end @@ -106,7 +106,7 @@ function _cancel!(state, tid, force, halt_sch) for task in keys(state.waiting) tid !== nothing && task.id != tid && continue @dagdebug tid :cancel "Cancelling waiting task" - state.cache[task] = InterruptException() + state.cache[task] = DTaskFailedException(task, task, InterruptException()) state.errored[task] = true Sch.set_failed!(state, task) end From f60cb779ec80a1c3c222f69ca5c57b308842742d Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 2 Oct 2024 19:08:31 -0500 Subject: [PATCH 50/56] options: Add internal helper to strip all options --- src/options.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/options.jl b/src/options.jl index 1c1e3ff29..00196dd59 100644 --- a/src/options.jl +++ b/src/options.jl @@ -20,6 +20,12 @@ function with_options(f, options::NamedTuple) end with_options(f; options...) = with_options(f, NamedTuple(options)) +function _without_options(f) + with(options_context => NamedTuple()) do + f() + end +end + """ get_options(key::Symbol, default) -> Any get_options(key::Symbol) -> Any From b3b70e1846d876a0fa62bdde1ca6331746de8bb5 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Wed, 2 Oct 2024 19:14:11 -0500 Subject: [PATCH 51/56] streaming: Get tests passing Switch from RemoteFetcher to RemoteChannelFetcher Pass object rather than type to `stream_{push,pull}_values!` ProcessRingBuffer: Don't exit on graceful interrupt when non-empty --- src/Dagger.jl | 2 +- src/stream-buffers.jl | 218 +++++++++-------------------------------- src/stream-transfer.jl | 112 ++++++++++++++++++--- src/stream.jl | 118 +++++++++++++++++----- test/streaming.jl | 49 ++++----- 5 files changed, 259 insertions(+), 240 deletions(-) diff --git a/src/Dagger.jl b/src/Dagger.jl index afbc37b9f..70725340a 100644 --- a/src/Dagger.jl +++ b/src/Dagger.jl @@ -71,9 +71,9 @@ include("sch/Sch.jl"); using .Sch include("datadeps.jl") # Streaming +include("stream.jl") include("stream-buffers.jl") include("stream-transfer.jl") -include("stream.jl") # Array computations include("array/darray.jl") diff --git a/src/stream-buffers.jl b/src/stream-buffers.jl index cd00000b4..579a59753 100644 --- a/src/stream-buffers.jl +++ b/src/stream-buffers.jl @@ -1,50 +1,35 @@ """ -A buffer that drops all elements put into it. Only to be used as the output -buffer for a task - will throw if attached as an input. +A buffer that drops all elements put into it. """ -struct DropBuffer{T} end +mutable struct DropBuffer{T} + open::Bool + DropBuffer{T}() where T = new{T}(true) +end DropBuffer{T}(_) where T = DropBuffer{T}() Base.isempty(::DropBuffer) = true isfull(::DropBuffer) = false -Base.put!(::DropBuffer, _) = nothing -Base.take!(::DropBuffer) = error("Cannot `take!` from a DropBuffer") - -"A process-local buffer backed by a `Channel{T}`." -struct ChannelBuffer{T} - channel::Channel{T} - len::Int - count::Threads.Atomic{Int} - ChannelBuffer{T}(len::Int=1024) where T = - new{T}(Channel{T}(len), len, Threads.Atomic{Int}(0)) -end -Base.isempty(cb::ChannelBuffer) = isempty(cb.channel) -isfull(cb::ChannelBuffer) = cb.count[] == cb.len -function Base.put!(cb::ChannelBuffer{T}, x) where T - put!(cb.channel, convert(T, x)) - Threads.atomic_add!(cb.count, 1) -end -function Base.take!(cb::ChannelBuffer) - take!(cb.channel) - Threads.atomic_sub!(cb.count, 1) -end - -"A cross-worker buffer backed by a `RemoteChannel{T}`." -struct RemoteChannelBuffer{T} - channel::RemoteChannel{Channel{T}} - len::Int - count::Threads.Atomic{Int} - RemoteChannelBuffer{T}(len::Int=1024) where T = - new{T}(RemoteChannel(()->Channel{T}(len)), len, Threads.Atomic{Int}(0)) -end -Base.isempty(cb::RemoteChannelBuffer) = isempty(cb.channel) -isfull(cb::RemoteChannelBuffer) = cb.count[] == cb.len -function Base.put!(cb::RemoteChannelBuffer{T}, x) where T - put!(cb.channel, convert(T, x)) - Threads.atomic_add!(cb.count, 1) -end -function Base.take!(cb::RemoteChannelBuffer) - take!(cb.channel) - Threads.atomic_sub!(cb.count, 1) +capacity(::DropBuffer) = typemax(Int) +Base.length(::DropBuffer) = 0 +Base.isopen(buf::DropBuffer) = buf.open +function Base.close(buf::DropBuffer) + buf.open = false +end +function Base.put!(buf::DropBuffer, _) + if !isopen(buf) + throw(InvalidStateException("DropBuffer is closed", :closed)) + end + task_may_cancel!(; must_force=true) + yield() + return +end +function Base.take!(buf::DropBuffer) + while true + if !isopen(buf) + throw(InvalidStateException("DropBuffer is closed", :closed)) + end + task_may_cancel!(; must_force=true) + yield() + end end "A process-local ring buffer." @@ -53,7 +38,7 @@ mutable struct ProcessRingBuffer{T} write_idx::Int @atomic count::Int buffer::Vector{T} - open::Bool + @atomic open::Bool function ProcessRingBuffer{T}(len::Int=1024) where T buffer = Vector{T}(undef, len) return new{T}(1, 1, 0, buffer, true) @@ -61,32 +46,37 @@ mutable struct ProcessRingBuffer{T} end Base.isempty(rb::ProcessRingBuffer) = (@atomic rb.count) == 0 isfull(rb::ProcessRingBuffer) = (@atomic rb.count) == length(rb.buffer) +capacity(rb::ProcessRingBuffer) = length(rb.buffer) Base.length(rb::ProcessRingBuffer) = @atomic rb.count -Base.isopen(rb::ProcessRingBuffer) = rb.open +Base.isopen(rb::ProcessRingBuffer) = @atomic rb.open function Base.close(rb::ProcessRingBuffer) - rb.open = false + @atomic rb.open = false end function Base.put!(rb::ProcessRingBuffer{T}, x) where T - len = length(rb.buffer) - while (@atomic rb.count) == len + while isfull(rb) yield() if !isopen(rb) - throw(InvalidStateException("Stream is closed", :closed)) + throw(InvalidStateException("ProcessRingBuffer is closed", :closed)) end - task_may_cancel!() + task_may_cancel!(; must_force=true) end - to_write_idx = mod1(rb.write_idx, len) + to_write_idx = mod1(rb.write_idx, length(rb.buffer)) rb.buffer[to_write_idx] = convert(T, x) rb.write_idx += 1 @atomic rb.count += 1 end function Base.take!(rb::ProcessRingBuffer) - while (@atomic rb.count) == 0 + while isempty(rb) yield() - if !isopen(rb) - throw(InvalidStateException("Stream is closed", :closed)) + if !isopen(rb) && isempty(rb) + throw(InvalidStateException("ProcessRingBuffer is closed", :closed)) end - task_may_cancel!() + if task_cancelled() && isempty(rb) + # We respect a graceful cancellation only if the buffer is empty. + # Otherwise, we may have values to continue communicating. + task_may_cancel!() + end + task_may_cancel!(; must_force=true) end to_read_idx = rb.read_idx rb.read_idx += 1 @@ -106,123 +96,3 @@ function collect!(rb::ProcessRingBuffer{T}) where T return output end - -#= TODO -"A server-local ring buffer backed by shared-memory." -mutable struct ServerRingBuffer{T} - read_idx::Int - write_idx::Int - @atomic count::Int - buffer::Vector{T} - function ServerRingBuffer{T}(len::Int=1024) where T - buffer = Vector{T}(undef, len) - return new{T}(1, 1, 0, buffer) - end -end -Base.isempty(rb::ServerRingBuffer) = (@atomic rb.count) == 0 -function Base.put!(rb::ServerRingBuffer{T}, x) where T - len = length(rb.buffer) - while (@atomic rb.count) == len - yield() - end - to_write_idx = mod1(rb.write_idx, len) - rb.buffer[to_write_idx] = convert(T, x) - rb.write_idx += 1 - @atomic rb.count += 1 -end -function Base.take!(rb::ServerRingBuffer) - while (@atomic rb.count) == 0 - yield() - end - to_read_idx = rb.read_idx - rb.read_idx += 1 - @atomic rb.count -= 1 - to_read_idx = mod1(to_read_idx, length(rb.buffer)) - return rb.buffer[to_read_idx] -end -=# - -#= -"A TCP-based ring buffer." -mutable struct TCPRingBuffer{T} - read_idx::Int - write_idx::Int - @atomic count::Int - buffer::Vector{T} - function TCPRingBuffer{T}(len::Int=1024) where T - buffer = Vector{T}(undef, len) - return new{T}(1, 1, 0, buffer) - end -end -Base.isempty(rb::TCPRingBuffer) = (@atomic rb.count) == 0 -function Base.put!(rb::TCPRingBuffer{T}, x) where T - len = length(rb.buffer) - while (@atomic rb.count) == len - yield() - end - to_write_idx = mod1(rb.write_idx, len) - rb.buffer[to_write_idx] = convert(T, x) - rb.write_idx += 1 - @atomic rb.count += 1 -end -function Base.take!(rb::TCPRingBuffer) - while (@atomic rb.count) == 0 - yield() - end - to_read_idx = rb.read_idx - rb.read_idx += 1 - @atomic rb.count -= 1 - to_read_idx = mod1(to_read_idx, length(rb.buffer)) - return rb.buffer[to_read_idx] -end -=# - -#= -""" -A flexible puller which switches to the most efficient buffer type based -on the sender and receiver locations. -""" -mutable struct UniBuffer{T} - buffer::Union{ProcessRingBuffer{T}, Nothing} -end -function initialize_stream_buffer!(::Type{UniBuffer{T}}, T, send_proc, recv_proc, buffer_amount) where T - if buffer_amount == 0 - error("Return NullBuffer") - end - send_osproc = get_parent(send_proc) - recv_osproc = get_parent(recv_proc) - if send_osproc.pid == recv_osproc.pid - inner = RingBuffer{T}(buffer_amount) - elseif system_uuid(send_osproc.pid) == system_uuid(recv_osproc.pid) - inner = ProcessBuffer{T}(buffer_amount) - else - inner = RemoteBuffer{T}(buffer_amount) - end - return UniBuffer{T}(buffer_amount) -end - -struct LocalPuller{T,B} - buffer::B{T} - id::UInt - function LocalPuller{T,B}(id::UInt, buffer_amount::Integer) where {T,B} - buffer = initialize_stream_buffer!(B, T, buffer_amount) - return new{T,B}(buffer, id) - end -end -function Base.take!(pull::LocalPuller{T,B}) where {T,B} - if pull.buffer === nothing - pull.buffer = - error("Return NullBuffer") - end - value = take!(pull.buffer) -end -function initialize_input_stream!(stream::Stream{T,B}, id::UInt, send_proc::Processor, recv_proc::Processor, buffer_amount::Integer) where {T,B} - local_buffer = remotecall_fetch(stream.ref.handle.owner, stream.ref.handle, id) do ref, id - local_buffer, remote_buffer = initialize_stream_buffer!(B, T, send_proc, recv_proc, buffer_amount) - ref.buffers[id] = remote_buffer - return local_buffer - end - stream.buffer = local_buffer - return stream -end -=# diff --git a/src/stream-transfer.jl b/src/stream-transfer.jl index 3251abb9a..667808762 100644 --- a/src/stream-transfer.jl +++ b/src/stream-transfer.jl @@ -1,32 +1,116 @@ +struct RemoteChannelFetcher + chan::RemoteChannel + RemoteChannelFetcher() = new(RemoteChannel()) +end +const _THEIR_TID = TaskLocalValue{Int}(()->0) +function stream_push_values!(fetcher::RemoteChannelFetcher, T, our_store::StreamStore, their_stream::Stream, buffer) + our_tid = STREAM_THUNK_ID[] + our_uid = our_store.uid + their_uid = their_stream.uid + if _THEIR_TID[] == 0 + _THEIR_TID[] = remotecall_fetch(1) do + lock(Sch.EAGER_ID_MAP) do id_map + id_map[their_uid] + end + end + end + their_tid = _THEIR_TID[] + @dagdebug our_tid :stream_push "taking output value: $our_tid -> $their_tid" + value = try + take!(buffer) + catch + close(fetcher.chan) + rethrow() + end + @lock our_store.lock notify(our_store.lock) + @dagdebug our_tid :stream_push "pushing output value: $our_tid -> $their_tid" + try + put!(fetcher.chan, value) + catch err + if err isa InvalidStateException && !isopen(fetcher.chan) + @dagdebug our_tid :stream_push "channel closed: $our_tid -> $their_tid" + throw(InterruptException()) + end + rethrow(err) + end + @dagdebug our_tid :stream_push "finished pushing output value: $our_tid -> $their_tid" +end +function stream_pull_values!(fetcher::RemoteChannelFetcher, T, our_store::StreamStore, their_stream::Stream, buffer) + our_tid = STREAM_THUNK_ID[] + our_uid = our_store.uid + their_uid = their_stream.uid + if _THEIR_TID[] == 0 + _THEIR_TID[] = remotecall_fetch(1) do + lock(Sch.EAGER_ID_MAP) do id_map + id_map[their_uid] + end + end + end + their_tid = _THEIR_TID[] + @dagdebug our_tid :stream_pull "pulling input value: $their_tid -> $our_tid" + value = try + take!(fetcher.chan) + catch err + if err isa InvalidStateException && !isopen(fetcher.chan) + @dagdebug our_tid :stream_pull "channel closed: $their_tid -> $our_tid" + throw(InterruptException()) + end + rethrow(err) + end + @dagdebug our_tid :stream_pull "putting input value: $their_tid -> $our_tid" + try + put!(buffer, value) + catch + close(fetcher.chan) + rethrow() + end + @lock our_store.lock notify(our_store.lock) + @dagdebug our_tid :stream_pull "finished putting input value: $their_tid -> $our_tid" +end + +#= TODO: Remove me +# This is a bad implementation because it wants to sleep on the remote side to +# wait for values, but this isn't semantically valid when done with MemPool.access_ref struct RemoteFetcher end -# TODO: Switch to RemoteChannel approach -function stream_pull_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_remote}, buffer::Blocal, id::UInt) where {Store_remote, Blocal} +function stream_push_values!(::Type{RemoteFetcher}, T, our_store::StreamStore, their_stream::Stream, buffer) + sleep(1) +end +function stream_pull_values!(::Type{RemoteFetcher}, T, our_store::StreamStore, their_stream::Stream, buffer) + id = our_store.uid thunk_id = STREAM_THUNK_ID[] @dagdebug thunk_id :stream "fetching values" - free_space = length(buffer.buffer) - length(buffer) + free_space = capacity(buffer) - length(buffer) if free_space == 0 + @dagdebug thunk_id :stream "waiting for drain of full input buffer" yield() task_may_cancel!() + wait_for_nonfull_input(our_store, their_stream.uid) return end values = T[] while isempty(values) - values = MemPool.access_ref(store_ref.handle, id, T, Store_remote, thunk_id, free_space) do store, id, T, Store_remote, thunk_id, free_space - @dagdebug thunk_id :stream "trying to fetch values at $(myid())" - store::Store_remote - in_store = store + values, closed = MemPool.access_ref(their_stream.store_ref.handle, id, T, thunk_id, free_space) do their_store, id, T, thunk_id, free_space + @dagdebug thunk_id :stream "trying to fetch values at worker $(myid())" STREAM_THUNK_ID[] = thunk_id values = T[] - @dagdebug thunk_id :stream "trying to fetch: $(store.output_buffers[id].count) values, free_space: $free_space" - while !isempty(store, id) && length(values) < free_space - value = take!(store, id)::T + @dagdebug thunk_id :stream "trying to fetch with free_space: $free_space" + wait_for_nonempty_output(their_store, id) + if isempty(their_store, id) && !isopen(their_store, id) + @dagdebug thunk_id :stream "remote stream is closed, returning" + return values, true + end + while !isempty(their_store, id) && length(values) < free_space + value = take!(their_store, id)::T @dagdebug thunk_id :stream "fetched $value" push!(values, value) end - return values - end::Vector{T} + return values, false + end::Tuple{Vector{T},Bool} + if closed + throw(InterruptException()) + end # We explicitly yield in the loop to allow other tasks to run. This # matters on single-threaded instances because MemPool.access_ref() @@ -41,6 +125,4 @@ function stream_pull_values!(::Type{RemoteFetcher}, T, store_ref::Chunk{Store_re put!(buffer, value) end end -function stream_push_values!(::Type{RemoteFetcher}, T, store_ref::Store_remote, buffer::Blocal, id::UInt) where {Store_remote, Blocal} - sleep(1) -end +=# diff --git a/src/stream.jl b/src/stream.jl index 493183b99..f3c2cfbcb 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -7,6 +7,8 @@ mutable struct StreamStore{T,B} output_buffers::Dict{UInt,B} input_buffer_amount::Int output_buffer_amount::Int + input_fetchers::Dict{UInt,Any} + output_fetchers::Dict{UInt,Any} open::Bool migrating::Bool lock::Threads.Condition @@ -15,6 +17,7 @@ mutable struct StreamStore{T,B} Dict{UInt,Any}(), Dict{UInt,Any}(), Dict{UInt,B}(), Dict{UInt,B}(), input_buffer_amount, output_buffer_amount, + Dict{UInt,Any}(), Dict{UInt,Any}(), true, false, Threads.Condition()) end @@ -48,6 +51,9 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B} end @dagdebug thunk_id :stream "buffer full ($(length(buffer)) values), waiting" wait(store.lock) + if !isfull(buffer) + @dagdebug thunk_id :stream "buffer has space ($(length(buffer)) values), continuing" + end task_may_cancel!() end put!(buffer, value) @@ -85,6 +91,36 @@ function Base.take!(store::StreamStore, id::UInt) return value end end +function wait_for_nonfull_input(store::StreamStore, id::UInt) + @lock store.lock begin + @assert haskey(store.input_streams, id) + @assert haskey(store.input_buffers, id) + buffer = store.input_buffers[id] + while isfull(buffer) && isopen(store) + @dagdebug STREAM_THUNK_ID[] :stream "waiting for space in input buffer" + wait(store.lock) + end + end +end +function wait_for_nonempty_output(store::StreamStore, id::UInt) + @lock store.lock begin + @assert haskey(store.output_streams, id) + + # Wait for the output buffer to be initialized + while !haskey(store.output_buffers, id) && isopen(store, id) + @dagdebug STREAM_THUNK_ID[] :stream "waiting for output buffer to be initialized" + wait(store.lock) + end + isopen(store, id) || return + + # Wait for the output buffer to be nonempty + buffer = store.output_buffers[id] + while isempty(buffer) && isopen(store, id) + @dagdebug STREAM_THUNK_ID[] :stream "waiting for output buffer to be nonempty" + wait(store.lock) + end + end +end function Base.isempty(store::StreamStore, id::UInt) if !haskey(store.output_buffers, id) @@ -105,6 +141,10 @@ taken. """ function Base.isopen(store::StreamStore, id::UInt) @lock store.lock begin + if !haskey(store.output_buffers, id) + @assert haskey(store.output_streams, id) + return store.open + end if !isempty(store.output_buffers[id]) return true end @@ -127,13 +167,14 @@ function Base.close(store::StreamStore) end # FIXME: Just pass Stream directly, rather than its uid -function add_waiters!(store::StreamStore{T,B}, waiters::Vector{UInt}) where {T,B} +function add_waiters!(store::StreamStore{T,B}, waiters::Vector{Pair{UInt,Any}}) where {T,B} our_uid = store.uid @lock store.lock begin - for output_uid in waiters + for (output_uid, output_fetcher) in waiters store.output_streams[output_uid] = task_to_stream(output_uid) + push!(store.waiters, output_uid) + store.output_fetchers[output_uid] = output_fetcher end - append!(store.waiters, waiters) notify(store.lock) end end @@ -175,7 +216,8 @@ Base.take!(sv::StreamingValue) = take!(sv.buffer) function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::Stream{IT,IB}) where {IT,OT,IB,OB} input_uid = input_stream.uid our_uid = our_store.uid - buffer = @lock our_store.lock begin + local buffer, input_fetcher + @lock our_store.lock begin if haskey(our_store.input_buffers, input_uid) return StreamingValue(our_store.input_buffers[input_uid]) end @@ -183,7 +225,7 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S buffer = initialize_stream_buffer(OB, IT, our_store.input_buffer_amount) # FIXME: Also pass a RemoteChannel to track remote closure our_store.input_buffers[input_uid] = buffer - buffer + input_fetcher = our_store.input_fetchers[input_uid] end thunk_id = STREAM_THUNK_ID[] tls = get_tls() @@ -192,11 +234,14 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S STREAM_THUNK_ID[] = thunk_id try while isopen(our_store) - # FIXME: Make remote fetcher configurable - stream_pull_values!(RemoteFetcher, IT, input_stream.store_ref, buffer, our_uid) + stream_pull_values!(input_fetcher, IT, our_store, input_stream, buffer) end catch err - err isa InterruptException || rethrow(err) + if err isa InterruptException || (err isa InvalidStateException && !isopen(buffer)) + return + else + rethrow(err) + end finally @dagdebug STREAM_THUNK_ID[] :stream "input stream closed" end @@ -204,23 +249,34 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S return StreamingValue(buffer) end initialize_input_stream!(our_store::StreamStore, arg) = arg -function initialize_output_stream!(store::StreamStore{T,B}, output_uid::UInt) where {T,B} - @assert islocked(store.lock) +function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt) where {T,B} + @assert islocked(our_store.lock) @dagdebug STREAM_THUNK_ID[] :stream "initializing output stream $output_uid" - buffer = initialize_stream_buffer(B, T, store.output_buffer_amount) - store.output_buffers[output_uid] = buffer - our_uid = store.uid + buffer = initialize_stream_buffer(B, T, our_store.output_buffer_amount) + our_store.output_buffers[output_uid] = buffer + our_uid = our_store.uid + output_stream = our_store.output_streams[output_uid] + output_fetcher = our_store.output_fetchers[output_uid] thunk_id = STREAM_THUNK_ID[] tls = get_tls() Sch.errormonitor_tracked("streaming output: $our_uid -> $output_uid", Threads.@spawn begin set_tls!(tls) + STREAM_THUNK_ID[] = thunk_id try - while isopen(store) - # FIXME: Make remote fetcher configurable - stream_push_values!(RemoteFetcher, T, store, buffer, output_uid) + while true + if !isopen(our_store) && isempty(buffer) + # Only exit if the buffer is empty; otherwise, we need to + # continue draining it + break + end + stream_push_values!(output_fetcher, T, our_store, output_stream, buffer) end catch err - err isa InterruptException || rethrow(err) + if err isa InterruptException || (err isa InvalidStateException && !isopen(buffer)) + return + else + rethrow(err) + end finally @dagdebug thunk_id :stream "output stream closed" end @@ -243,7 +299,7 @@ function Base.close(stream::Stream) return end -function add_waiters!(stream::Stream, waiters::Vector{UInt}) +function add_waiters!(stream::Stream, waiters::Vector{Pair{UInt,Any}}) MemPool.access_ref(stream.store_ref.handle, waiters) do store, waiters add_waiters!(store::StreamStore, waiters) return @@ -508,8 +564,8 @@ function _run_streamingfunction(tls, cancel_token, sf, args...; kwargs...) end # Ensure downstream tasks also terminate - @dagdebug thunk_id :stream "closed stream" close(sf.stream) + @dagdebug thunk_id :stream "closed stream store" end end end @@ -530,6 +586,7 @@ function stream!(sf::StreamingFunction, uid, # Exit streaming on migration if sf.stream.store.migrating error("FIXME: max_evals should be retained") + @dagdebug STREAM_THUNK_ID[] :stream "returning for migration" return StreamMigrating() end @@ -538,9 +595,15 @@ function stream!(sf::StreamingFunction, uid, stream_kwarg_values = _stream_take_values!(kwarg_values) stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values) + if length(stream_args) > 0 || length(stream_kwarg_values) > 0 + # Notify tasks that input buffers may have space + @lock sf.stream.store.lock notify(sf.stream.store.lock) + end + # Run a single cycle of f - stream_result = f(stream_args...; stream_kwargs...) counter += 1 + @dagdebug STREAM_THUNK_ID[] :stream "executing $f (eval $counter)" + stream_result = f(stream_args...; stream_kwargs...) # Exit streaming on graceful request if stream_result isa FinishStream @@ -548,6 +611,7 @@ function stream!(sf::StreamingFunction, uid, value = something(stream_result.value) put!(sf.stream, value) end + @dagdebug STREAM_THUNK_ID[] :stream "voluntarily returning" return stream_result.result end @@ -555,8 +619,8 @@ function stream!(sf::StreamingFunction, uid, put!(sf.stream, stream_result) # Exit streaming on eval limit - if sf.max_evals >= 0 && counter >= sf.max_evals - @dagdebug STREAM_THUNK_ID[] :stream "max evals reached" + if sf.max_evals > 0 && counter >= sf.max_evals + @dagdebug STREAM_THUNK_ID[] :stream "max evals reached ($counter)" return end end @@ -596,7 +660,7 @@ function task_to_stream(uid::UInt) end function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) - stream_waiter_changes = Dict{UInt,Vector{UInt}}() + stream_waiter_changes = Dict{UInt,Vector{Pair{UInt,Any}}}() for (spec, task) in tasks @assert haskey(self_streams, task.uid) @@ -614,16 +678,18 @@ function finalize_streaming!(tasks::Vector{Pair{DTaskSpec,DTask}}, self_streams) if other_stream !== nothing # Generate Stream handle for input - # FIXME: input_fetcher = get(spec.options, :stream_input_fetcher, RemoteFetcher) + # FIXME: Be configurable + input_fetcher = RemoteChannelFetcher() other_stream_handle = Stream(other_stream) spec.args[idx] = pos => other_stream_handle our_stream.store.input_streams[arg.uid] = other_stream_handle + our_stream.store.input_fetchers[arg.uid] = input_fetcher # Add this task as a waiter for the associated output Stream changes = get!(stream_waiter_changes, arg.uid) do - UInt[] + Pair{UInt,Any}[] end - push!(changes, task.uid) + push!(changes, task.uid => input_fetcher) end end end diff --git a/test/streaming.jl b/test/streaming.jl index a994b4109..b7d3ce328 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -23,6 +23,7 @@ const ACCUMULATOR = Dict{Int,Vector{Real}}() return end @everywhere accumulator(xs...) = accumulator(sum(xs)) +@everywhere accumulator(::Nothing) = accumulator(0) function catch_interrupt(f) try @@ -60,12 +61,13 @@ function test_finishes(f, message::String; ignore_timeout=false, max_evals=10) end return tset end - timed_out = timedwait(()->istaskdone(t), 5) == :timed_out + timed_out = timedwait(()->istaskdone(t), 10) == :timed_out if timed_out if !ignore_timeout @warn "Testing task timed out: $message" end Dagger.cancel!(;halt_sch=true) + @everywhere GC.gc() fetch(Dagger.@spawn 1+1) end tset = fetch(t)::Test.DefaultTestSet @@ -96,7 +98,7 @@ for idx in 1:5 return y end end - fetch(x) + @test_throws_unwrap InterruptException fetch(x) end @test test_finishes("Single task without result") do @@ -164,7 +166,9 @@ for idx in 1:5 Dagger.spawn_streaming() do x = Dagger.@spawn scope=rand(scopes) rand() end - A = Dagger.@spawn accumulator(x) + Dagger._without_options() do + A = Dagger.@spawn accumulator(x) + end @test fetch(x) === nothing @test fetch(A) === nothing values = copy(ACCUMULATOR); empty!(ACCUMULATOR) @@ -177,8 +181,10 @@ for idx in 1:5 Dagger.spawn_streaming() do x = Dagger.@spawn scope=rand(scopes) rand() end - A = Dagger.@spawn accumulator(x) - B = Dagger.@spawn accumulator(x) + Dagger._without_options() do + A = Dagger.@spawn accumulator(x) + B = Dagger.@spawn accumulator(x) + end @test fetch(x) === nothing @test fetch(A) === nothing @test fetch(B) === nothing @@ -364,7 +370,7 @@ for idx in 1:5 @test test_finishes("max_evals=100"; max_evals=100) do local A Dagger.spawn_streaming() do - A = Dagger.@spawn scope=rand(scopes) rand() + A = Dagger.@spawn scope=rand(scopes) accumulator() end @test fetch(A) === nothing values = copy(ACCUMULATOR); empty!(ACCUMULATOR) @@ -374,44 +380,39 @@ for idx in 1:5 end @testset "DropBuffer ($scope_str)" begin - @test test_finishes("x (drop)-> A") do + # TODO: Test that accumulator never gets called + @test !test_finishes("x (drop)-> A"; ignore_timeout=true) do local x, A Dagger.spawn_streaming() do - Dagger.with_options(;stream_buffer_type=>Dagger.DropBuffer) do + Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do x = Dagger.@spawn scope=rand(scopes) rand() end A = Dagger.@spawn scope=rand(scopes) accumulator(x) end - @test fetch(A) === nothing - values = copy(ACCUMULATOR); empty!(ACCUMULATOR) - A_tid = Dagger.task_id(A) - @test !haskey(values, A_tid) + @test fetch(x) === nothing + @test_throws_unwrap InterruptException fetch(A) === nothing end - @test test_finishes("x ->(drop) A") do + @test !test_finishes("x ->(drop) A"; ignore_timeout=true) do local x, A Dagger.spawn_streaming() do x = Dagger.@spawn scope=rand(scopes) rand() - Dagger.with_options(;stream_buffer_type=>Dagger.DropBuffer) do + Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do A = Dagger.@spawn scope=rand(scopes) accumulator(x) end end - @test fetch(A) === nothing - values = copy(ACCUMULATOR); empty!(ACCUMULATOR) - A_tid = Dagger.task_id(A) - @test !haskey(values, A_tid) + @test fetch(x) === nothing + @test_throws_unwrap InterruptException fetch(A) === nothing end - @test test_finishes("x -(drop)> A") do + @test !test_finishes("x -(drop)> A"; ignore_timeout=true) do local x, A Dagger.spawn_streaming() do - Dagger.with_options(;stream_buffer_type=>Dagger.DropBuffer) do + Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do x = Dagger.@spawn scope=rand(scopes) rand() A = Dagger.@spawn scope=rand(scopes) accumulator(x) end end - @test fetch(A) === nothing - values = copy(ACCUMULATOR); empty!(ACCUMULATOR) - A_tid = Dagger.task_id(A) - @test !haskey(values, A_tid) + @test fetch(x) === nothing + @test_throws_unwrap InterruptException fetch(A) === nothing end end From efc80be4109e8d4ae320c2b239e408ce817ce5b6 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Sat, 16 Nov 2024 21:13:29 +0100 Subject: [PATCH 52/56] Bump MemPool compat --- .buildkite/pipeline.yml | 2 +- Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 0f7ef011b..f7ab7d985 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -5,7 +5,7 @@ sandbox_capable: "true" os: linux arch: x86_64 - command: "julia --project -e 'using Pkg; Pkg.develop(;path=\"lib/TimespanLogging\"); Pkg.add(; url=\"https://github.com/JuliaData/MemPool.jl\", rev=\"jps/migration-helper\")'" + command: "julia --project -e 'using Pkg; Pkg.develop(;path=\"lib/TimespanLogging\")'" .bench: &bench if: build.message =~ /\[run benchmarks\]/ diff --git a/Project.toml b/Project.toml index fd7508cd7..7289ff5f7 100644 --- a/Project.toml +++ b/Project.toml @@ -52,7 +52,7 @@ GraphViz = "0.2" Graphs = "1" JSON3 = "1" MacroTools = "0.5" -MemPool = "0.4.6" +MemPool = "0.4.10" OnlineStats = "1" Plots = "1" PrecompileTools = "1.2" From e6a504dc8609387391b5926ffa843140979fac2f Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Sat, 16 Nov 2024 20:58:54 +0100 Subject: [PATCH 53/56] Fully lock StreamStore in close(::StreamStore) --- src/stream.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/stream.jl b/src/stream.jl index f3c2cfbcb..e22347b6f 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -153,9 +153,10 @@ function Base.isopen(store::StreamStore, id::UInt) end function Base.close(store::StreamStore) - store.open || return - store.open = false @lock store.lock begin + store.open || return + + store.open = false for buffer in values(store.input_buffers) close(buffer) end From 149adb59fa0c0b1bd22dd41f037cb576b8f73174 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Sat, 16 Nov 2024 20:59:38 +0100 Subject: [PATCH 54/56] Streaming tests cleanup and fixes - Added some whitespace. - Deleted the unused `rand_finite()` methods. - Allow passing the `timeout` to `test_finishes()` - Fix bug in one of the tests where we weren't waiting for all the tasks to finish, which would occasionally cause test failures because of the race condition. --- test/streaming.jl | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/test/streaming.jl b/test/streaming.jl index b7d3ce328..a93f44bd9 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -1,18 +1,3 @@ -@everywhere function rand_finite(T=Float64) - x = rand(T) - if rand() < 0.1 - return Dagger.finish_stream(x) - end - return x -end -@everywhere function rand_finite_returns(T=Float64) - x = rand(T) - if rand() < 0.1 - return Dagger.finish_stream(x; result=x) - end - return x -end - const ACCUMULATOR = Dict{Int,Vector{Real}}() @everywhere function accumulator(x=0) tid = Dagger.task_id() @@ -37,12 +22,14 @@ function catch_interrupt(f) rethrow(err) end end + function merge_testset!(inner::Test.DefaultTestSet) outer = Test.get_testset() append!(outer.results, inner.results) outer.n_passed += inner.n_passed end -function test_finishes(f, message::String; ignore_timeout=false, max_evals=10) + +function test_finishes(f, message::String; timeout=10, ignore_timeout=false, max_evals=10) t = @eval Threads.@spawn begin tset = nothing try @@ -61,7 +48,8 @@ function test_finishes(f, message::String; ignore_timeout=false, max_evals=10) end return tset end - timed_out = timedwait(()->istaskdone(t), 10) == :timed_out + + timed_out = timedwait(()->istaskdone(t), timeout) == :timed_out if timed_out if !ignore_timeout @warn "Testing task timed out: $message" @@ -70,6 +58,7 @@ function test_finishes(f, message::String; ignore_timeout=false, max_evals=10) @everywhere GC.gc() fetch(Dagger.@spawn 1+1) end + tset = fetch(t)::Test.DefaultTestSet merge_testset!(tset) return !timed_out @@ -176,6 +165,7 @@ for idx in 1:5 @test length(values[A_tid]) == 1 @test all(v -> 0 <= v <= 10, values[A_tid]) end + @test test_finishes("x -> (A, B)") do local x, A, B Dagger.spawn_streaming() do @@ -324,7 +314,8 @@ for idx in 1:5 @test fetch(y) === nothing @test fetch(z) === nothing @test fetch(A) === nothing - @test fetch(A) === nothing + @test fetch(B) === nothing + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) A_tid = Dagger.task_id(A) @test length(values[A_tid]) == 10 From 079e9fa2d22911d14a38d3c30bbed65a0a72c723 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Sat, 16 Nov 2024 20:52:31 +0100 Subject: [PATCH 55/56] Make Dagger.finish_stream() propagate downstream Previously a streaming task calling `Dagger.finish_stream()` would only stop the caller, but now it will also stop all downstream tasks. This is done by: - Getting the output handler tasks to close their `RemoteChannel` when exiting. - Making the input handler tasks close their buffers when the `RemoteChannel` is closed. - Exiting `stream!()` when an input buffer is closed. --- src/sch/util.jl | 2 + src/stream.jl | 61 +++++++++++++++++++++---- test/streaming.jl | 113 ++++++++++++++++++++++++++++++++++------------ 3 files changed, 137 insertions(+), 39 deletions(-) diff --git a/src/sch/util.jl b/src/sch/util.jl index eb5a285b4..2e090b26c 100644 --- a/src/sch/util.jl +++ b/src/sch/util.jl @@ -31,6 +31,8 @@ unwrap_nested_exception(err::RemoteException) = unwrap_nested_exception(err.captured) unwrap_nested_exception(err::DTaskFailedException) = unwrap_nested_exception(err.ex) +unwrap_nested_exception(err::TaskFailedException) = + unwrap_nested_exception(err.t.exception) unwrap_nested_exception(err) = err "Gets a `NamedTuple` of options propagated by `thunk`." diff --git a/src/stream.jl b/src/stream.jl index e22347b6f..ddc303c98 100644 --- a/src/stream.jl +++ b/src/stream.jl @@ -40,9 +40,6 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B} end @dagdebug thunk_id :stream "adding $value ($(length(store.output_streams)) outputs)" for output_uid in keys(store.output_streams) - if !haskey(store.output_buffers, output_uid) - initialize_output_stream!(store, output_uid) - end buffer = store.output_buffers[output_uid] while isfull(buffer) if !isopen(store) @@ -238,12 +235,18 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S stream_pull_values!(input_fetcher, IT, our_store, input_stream, buffer) end catch err - if err isa InterruptException || (err isa InvalidStateException && !isopen(buffer)) + unwrapped_err = Sch.unwrap_nested_exception(err) + if unwrapped_err isa InterruptException || (unwrapped_err isa InvalidStateException && !isopen(input_fetcher.chan)) return else - rethrow(err) + rethrow() end finally + # Close the buffer because there will be no more values put into + # it. We don't close the entire store because there might be some + # remaining elements in the buffer to process and send to downstream + # tasks. + close(buffer) @dagdebug STREAM_THUNK_ID[] :stream "input stream closed" end end) @@ -251,10 +254,13 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S end initialize_input_stream!(our_store::StreamStore, arg) = arg function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt) where {T,B} - @assert islocked(our_store.lock) @dagdebug STREAM_THUNK_ID[] :stream "initializing output stream $output_uid" - buffer = initialize_stream_buffer(B, T, our_store.output_buffer_amount) - our_store.output_buffers[output_uid] = buffer + local buffer + @lock our_store.lock begin + buffer = initialize_stream_buffer(B, T, our_store.output_buffer_amount) + our_store.output_buffers[output_uid] = buffer + end + our_uid = our_store.uid output_stream = our_store.output_streams[output_uid] output_fetcher = our_store.output_fetchers[output_uid] @@ -279,6 +285,7 @@ function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt rethrow(err) end finally + close(output_fetcher.chan) @dagdebug thunk_id :stream "output stream closed" end end) @@ -476,6 +483,17 @@ struct FinishStream{T,R} result::R end +""" + finish_stream(value=nothing; result=nothing) + +Tell Dagger to stop executing the streaming function and all of its downstream +[`DTask`](@ref)'s. + +# Arguments +- `value`: The final value to be returned by the streaming function. This will + be passed to all downstream [`DTask`](@ref)'s. +- `result`: The value that will be returned by `fetch()`'ing the [`DTask`](@ref). +""" finish_stream(value::T; result::R=nothing) where {T,R} = FinishStream{T,R}(Some{T}(value), result) finish_stream(; result::R=nothing) where R = FinishStream{Union{},R}(nothing, result) @@ -577,6 +595,16 @@ function stream!(sf::StreamingFunction, uid, f = move(thunk_processor(), sf.f) counter = 0 + # Initialize output streams. We can't do this in add_waiters!() because the + # output handlers depend on the DTaskTLS, so they have to be set up from + # within the DTask. + store = sf.stream.store + for output_uid in keys(store.output_streams) + if !haskey(store.output_buffers, output_uid) + initialize_output_stream!(store, output_uid) + end + end + while true # Yield to other (streaming) tasks yield() @@ -592,8 +620,21 @@ function stream!(sf::StreamingFunction, uid, end # Get values from Stream args/kwargs - stream_args = _stream_take_values!(args) - stream_kwarg_values = _stream_take_values!(kwarg_values) + local stream_args, stream_kwarg_values + try + stream_args = _stream_take_values!(args) + stream_kwarg_values = _stream_take_values!(kwarg_values) + catch ex + if ex isa InvalidStateException + # This means a buffer has been closed because an upstream task + # finished. + @dagdebug STREAM_THUNK_ID[] :stream "Upstream task finished, returning" + return nothing + else + rethrow() + end + end + stream_kwargs = _stream_namedtuple(kwarg_names, stream_kwarg_values) if length(stream_args) > 0 || length(stream_kwarg_values) > 0 diff --git a/test/streaming.jl b/test/streaming.jl index a93f44bd9..06ee0a660 100644 --- a/test/streaming.jl +++ b/test/streaming.jl @@ -370,40 +370,95 @@ for idx in 1:5 end end - @testset "DropBuffer ($scope_str)" begin - # TODO: Test that accumulator never gets called - @test !test_finishes("x (drop)-> A"; ignore_timeout=true) do - local x, A - Dagger.spawn_streaming() do - Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do - x = Dagger.@spawn scope=rand(scopes) rand() - end - A = Dagger.@spawn scope=rand(scopes) accumulator(x) + # @testset "DropBuffer ($scope_str)" begin + # # TODO: Test that accumulator never gets called + # @test !test_finishes("x (drop)-> A"; ignore_timeout=false, max_evals=typemax(Int)) do + # # ENV["JULIA_DEBUG"] = "Dagger" + + # local x, A + # Dagger.spawn_streaming() do + # Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do + # x = Dagger.@spawn scope=rand(scopes) rand() + # end + # A = Dagger.@spawn scope=rand(scopes) accumulator(x) + # end + # @test fetch(x) === nothing + # fetch(A) + # @test_throws_unwrap InterruptException fetch(A) + # end + + # @test !test_finishes("x ->(drop) A"; ignore_timeout=true) do + # local x, A + # Dagger.spawn_streaming() do + # x = Dagger.@spawn scope=rand(scopes) rand() + # Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do + # A = Dagger.@spawn scope=rand(scopes) accumulator(x) + # end + # end + # @test fetch(x) === nothing + # @test_throws_unwrap InterruptException fetch(A) === nothing + # end + + # @test !test_finishes("x -(drop)> A"; ignore_timeout=true) do + # local x, A + # Dagger.spawn_streaming() do + # Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do + # x = Dagger.@spawn scope=rand(scopes) rand() + # A = Dagger.@spawn scope=rand(scopes) accumulator(x) + # end + # end + # @test fetch(x) === nothing + # @test_throws_unwrap InterruptException fetch(A) === nothing + # end + # end + + @testset "Graceful finishing" begin + @test test_finishes("finish_stream() without return value") do + B = Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) Dagger.finish_stream() + + Dagger.@spawn scope=rand(scopes) accumulator(A) end - @test fetch(x) === nothing - @test_throws_unwrap InterruptException fetch(A) === nothing + + fetch(B) + # Since we don't return any value in the call to finish_stream(), B + # should never execute. + @test isempty(ACCUMULATOR) end - @test !test_finishes("x ->(drop) A"; ignore_timeout=true) do - local x, A - Dagger.spawn_streaming() do - x = Dagger.@spawn scope=rand(scopes) rand() - Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do - A = Dagger.@spawn scope=rand(scopes) accumulator(x) - end + + @test test_finishes("finish_stream() with one downstream task") do + B = Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) Dagger.finish_stream(42) + + Dagger.@spawn scope=rand(scopes) accumulator(A) end - @test fetch(x) === nothing - @test_throws_unwrap InterruptException fetch(A) === nothing + + fetch(B) + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + @test values[Dagger.task_id(B)] == [42] end - @test !test_finishes("x -(drop)> A"; ignore_timeout=true) do - local x, A - Dagger.spawn_streaming() do - Dagger.with_options(;stream_buffer_type=Dagger.DropBuffer) do - x = Dagger.@spawn scope=rand(scopes) rand() - A = Dagger.@spawn scope=rand(scopes) accumulator(x) - end + + @test test_finishes("finish_stream() with multiple downstream tasks"; max_evals=2) do + D, E = Dagger.spawn_streaming() do + A = Dagger.@spawn scope=rand(scopes) Dagger.finish_stream(1) + B = Dagger.@spawn scope=rand(scopes) A + 1 + C = Dagger.@spawn scope=rand(scopes) A + 1 + D = Dagger.@spawn scope=rand(scopes) accumulator(B, C) + + E = Dagger.@spawn scope=rand(scopes) accumulator() + + D, E end - @test fetch(x) === nothing - @test_throws_unwrap InterruptException fetch(A) === nothing + + fetch(D) + fetch(E) + values = copy(ACCUMULATOR); empty!(ACCUMULATOR) + + # D should only execute once since it depends on A/B/C + @test values[Dagger.task_id(D)] == [4] + + # E should run max_evals times since it has no dependencies + @test length(values[Dagger.task_id(E)]) == 2 end end From 999bdd7e7e639afad3689c5f1e34383eb6671a90 Mon Sep 17 00:00:00 2001 From: JamesWrigley Date: Sun, 17 Nov 2024 00:44:07 +0100 Subject: [PATCH 56/56] Fix @test_throws_unwrap tests `unwrap_nested_exception()` now supports `DTaskFailedException` so we can match against the real exceptions thrown. --- test/mutation.jl | 4 +++- test/processors.jl | 6 +++--- test/scheduler.jl | 6 +++--- test/scopes.jl | 12 +++++++----- test/thunk.jl | 27 +++++++++++++-------------- 5 files changed, 29 insertions(+), 26 deletions(-) diff --git a/test/mutation.jl b/test/mutation.jl index b6ac7143b..a245f445d 100644 --- a/test/mutation.jl +++ b/test/mutation.jl @@ -1,3 +1,5 @@ +import Dagger.Sch: SchedulingException + @everywhere begin struct DynamicHistogram bins::Vector{Float64} @@ -48,7 +50,7 @@ end x = Dagger.@mutable worker=w Ref{Int}() @test fetch(Dagger.@spawn mutable_update!(x)) == w wo_scope = Dagger.ProcessScope(wo) - @test_throws_unwrap Dagger.DTaskFailedException fetch(Dagger.@spawn scope=wo_scope mutable_update!(x)) + @test_throws_unwrap SchedulingException fetch(Dagger.@spawn scope=wo_scope mutable_update!(x)) end end # @testset "@mutable" diff --git a/test/processors.jl b/test/processors.jl index e97a1d239..4cedcd340 100644 --- a/test/processors.jl +++ b/test/processors.jl @@ -1,6 +1,6 @@ using Distributed import Dagger: Context, Processor, OSProc, ThreadProc, get_parent, get_processors -import Dagger.Sch: ThunkOptions +import Dagger.Sch: ThunkOptions, SchedulingException @everywhere begin @@ -37,9 +37,9 @@ end end @testset "Processor exhaustion" begin opts = ThunkOptions(proclist=[OptOutProc]) - @test_throws_unwrap Dagger.DTaskFailedException ex isa Dagger.Sch.SchedulingException ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) + @test_throws_unwrap SchedulingException reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) opts = ThunkOptions(proclist=(proc)->false) - @test_throws_unwrap Dagger.DTaskFailedException ex isa Dagger.Sch.SchedulingException ex.reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) + @test_throws_unwrap SchedulingException reason="No processors available, try widening scope" collect(delayed(sum; options=opts)([1,2,3])) opts = ThunkOptions(proclist=nothing) @test collect(delayed(sum; options=opts)([1,2,3])) == 6 end diff --git a/test/scheduler.jl b/test/scheduler.jl index b9fe01872..a949ffc1d 100644 --- a/test/scheduler.jl +++ b/test/scheduler.jl @@ -182,7 +182,7 @@ end @testset "allow errors" begin opts = ThunkOptions(;allow_errors=true) a = delayed(error; options=opts)("Test") - @test_throws_unwrap Dagger.DTaskFailedException collect(a) + @test_throws_unwrap ErrorException collect(a) end end @@ -396,7 +396,7 @@ end ([Dagger.tochunk(MyStruct(1)), Dagger.tochunk(1)], sizeof(MyStruct)+sizeof(Int)), ] for arg in args - if arg isa Chunk + if arg isa Dagger.Chunk aff = Dagger.affinity(arg) @test aff[1] == OSProc(1) @test aff[2] == MemPool.approx_size(MemPool.poolget(arg.handle)) @@ -540,7 +540,7 @@ end t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) sleep(100) start_time = time_ns() Dagger.cancel!(t) - @test_throws_unwrap Dagger.DTaskFailedException fetch(t) + @test_throws_unwrap InterruptException fetch(t) t = Dagger.@spawn scope=Dagger.scope(worker=1, thread=1) yield() fetch(t) finish_time = time_ns() diff --git a/test/scopes.jl b/test/scopes.jl index 5f82a71a0..a92cc42f2 100644 --- a/test/scopes.jl +++ b/test/scopes.jl @@ -1,3 +1,5 @@ +import Dagger.Sch: SchedulingException + @testset "Chunk Scopes" begin wid1, wid2 = addprocs(2, exeflags=["-t 2"]) @everywhere [wid1,wid2] using Dagger @@ -56,7 +58,7 @@ # Different nodes for (ch1, ch2) in [(ns1_ch, ns2_ch), (ns2_ch, ns1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap SchedulingException reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Process Scope" begin @@ -75,7 +77,7 @@ # Different process for (ch1, ch2) in [(ps1_ch, ps2_ch), (ps2_ch, ps1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap SchedulingException reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end # Same process and node @@ -83,7 +85,7 @@ # Different process and node for (ch1, ch2) in [(ps1_ch, ns2_ch), (ns2_ch, ps1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap SchedulingException reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Exact Scope" begin @@ -104,14 +106,14 @@ # Different process, different processor for (ch1, ch2) in [(es1_ch, es2_ch), (es2_ch, es1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap SchedulingException reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end # Same process, different processor es1_2 = ExactScope(Dagger.ThreadProc(wid1, 2)) es1_2_ch = Dagger.tochunk(nothing, OSProc(), es1_2) for (ch1, ch2) in [(es1_ch, es1_2_ch), (es1_2_ch, es1_ch)] - @test_throws_unwrap Dagger.DTaskFailedException ex.reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) + @test_throws_unwrap SchedulingException reason<"Scopes are not compatible:" fetch(Dagger.@spawn ch1 + ch2) end end @testset "Union Scope" begin diff --git a/test/thunk.jl b/test/thunk.jl index e6fb7e86b..5e193a505 100644 --- a/test/thunk.jl +++ b/test/thunk.jl @@ -69,7 +69,7 @@ end A = rand(4, 4) @test fetch(@spawn sum(A; dims=1)) ≈ sum(A; dims=1) - @test_throws_unwrap Dagger.DTaskFailedException fetch(@spawn sum(A; fakearg=2)) + @test_throws_unwrap MethodError fetch(@spawn sum(A; fakearg=2)) @test fetch(@spawn reduce(+, A; dims=1, init=2.0)) ≈ reduce(+, A; dims=1, init=2.0) @@ -194,7 +194,7 @@ end a = @spawn error("Test") wait(a) @test isready(a) - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap ErrorException fetch(a) b = @spawn 1+2 @test fetch(b) == 3 end @@ -207,7 +207,6 @@ end catch err err end - ex = Dagger.Sch.unwrap_nested_exception(ex) ex_str = sprint(io->Base.showerror(io,ex)) @test occursin(r"^DTaskFailedException:", ex_str) @test occursin("Test", ex_str) @@ -218,7 +217,6 @@ end catch err err end - ex = Dagger.Sch.unwrap_nested_exception(ex) ex_str = sprint(io->Base.showerror(io,ex)) @test occursin("Test", ex_str) @test occursin("Root Task", ex_str) @@ -226,28 +224,28 @@ end @testset "single dependent" begin a = @spawn error("Test") b = @spawn a+2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap ErrorException fetch(a) end @testset "multi dependent" begin a = @spawn error("Test") b = @spawn a+2 c = @spawn a*2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(b) - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap ErrorException fetch(b) + @test_throws_unwrap ErrorException fetch(c) end @testset "dependent chain" begin a = @spawn error("Test") - @test_throws_unwrap Dagger.DTaskFailedException fetch(a) + @test_throws_unwrap ErrorException fetch(a) b = @spawn a+1 - @test_throws_unwrap Dagger.DTaskFailedException fetch(b) + @test_throws_unwrap ErrorException fetch(b) c = @spawn b+2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap ErrorException fetch(c) end @testset "single input" begin a = @spawn 1+1 b = @spawn (a->error("Test"))(a) @test fetch(a) == 2 - @test_throws_unwrap Dagger.DTaskFailedException fetch(b) + @test_throws_unwrap ErrorException fetch(b) end @testset "multi input" begin a = @spawn 1+1 @@ -255,7 +253,7 @@ end c = @spawn ((a,b)->error("Test"))(a,b) @test fetch(a) == 2 @test fetch(b) == 4 - @test_throws_unwrap Dagger.DTaskFailedException fetch(c) + @test_throws_unwrap ErrorException fetch(c) end @testset "diamond" begin a = @spawn 1+1 @@ -265,9 +263,10 @@ end @test fetch(a) == 2 @test fetch(b) == 3 @test fetch(c) == 4 - @test_throws_unwrap Dagger.DTaskFailedException fetch(d) + @test_throws_unwrap ErrorException fetch(d) end end + @testset "remote spawn" begin a = fetch(Distributed.@spawnat 2 Dagger.@spawn 1+2) @test Dagger.Sch.EAGER_INIT[] @@ -283,7 +282,7 @@ end t1 = Dagger.@spawn 1+"fail" Dagger.@spawn t1+1 end - @test_throws_unwrap Dagger.DTaskFailedException fetch(t2) + @test_throws_unwrap MethodError fetch(t2) end @testset "undefined function" begin # Issues #254, #255