Skip to content

Commit

Permalink
fixup! Make Dagger.finish_stream() propagate downstream
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesWrigley committed Nov 16, 2024
1 parent 3f5f8d4 commit e1ccbfe
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 6 deletions.
22 changes: 16 additions & 6 deletions src/stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ function Base.put!(store::StreamStore{T,B}, value) where {T,B}
end
@dagdebug thunk_id :stream "adding $value ($(length(store.output_streams)) outputs)"
for output_uid in keys(store.output_streams)
if !haskey(store.output_buffers, output_uid)
initialize_output_stream!(store, output_uid)
end
buffer = store.output_buffers[output_uid]
while isfull(buffer)
if !isopen(store)
Expand Down Expand Up @@ -257,10 +254,13 @@ function initialize_input_stream!(our_store::StreamStore{OT,OB}, input_stream::S
end
initialize_input_stream!(our_store::StreamStore, arg) = arg
function initialize_output_stream!(our_store::StreamStore{T,B}, output_uid::UInt) where {T,B}
@assert islocked(our_store.lock)
@dagdebug STREAM_THUNK_ID[] :stream "initializing output stream $output_uid"
buffer = initialize_stream_buffer(B, T, our_store.output_buffer_amount)
our_store.output_buffers[output_uid] = buffer
local buffer
@lock our_store.lock begin
buffer = initialize_stream_buffer(B, T, our_store.output_buffer_amount)
our_store.output_buffers[output_uid] = buffer
end

our_uid = our_store.uid
output_stream = our_store.output_streams[output_uid]
output_fetcher = our_store.output_fetchers[output_uid]
Expand Down Expand Up @@ -595,6 +595,16 @@ function stream!(sf::StreamingFunction, uid,
f = move(thunk_processor(), sf.f)
counter = 0

# Initialize output streams. We can't do this in add_waiters!() because the
# output handlers depend on the DTaskTLS, so they have to be set up from
# within the DTask.
store = sf.stream.store
for output_uid in keys(store.output_streams)
if !haskey(store.output_buffers, output_uid)
initialize_output_stream!(store, output_uid)
end
end

while true
# Yield to other (streaming) tasks
yield()
Expand Down
13 changes: 13 additions & 0 deletions test/streaming.jl
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,19 @@ for idx in 1:5
# end

@testset "Graceful finishing" begin
@test test_finishes("finish_stream() without return value") do
B = Dagger.spawn_streaming() do
A = Dagger.@spawn scope=rand(scopes) Dagger.finish_stream()

Dagger.@spawn scope=rand(scopes) accumulator(A)
end

fetch(B)
# Since we don't return any value in the call to finish_stream(), B
# should never execute.
@test isempty(ACCUMULATOR)
end

@test test_finishes("finish_stream() with one downstream task") do
B = Dagger.spawn_streaming() do
A = Dagger.@spawn scope=rand(scopes) Dagger.finish_stream(42)
Expand Down

0 comments on commit e1ccbfe

Please sign in to comment.