Skip to content

Commit

Permalink
Merge pull request #472 from JuliaParallel/jps/darray-no-cache
Browse files Browse the repository at this point in the history
DArray: Remove the stage cache
  • Loading branch information
jpsamaroo authored Feb 22, 2024
2 parents 8172122 + 0aad01f commit 7364106
Show file tree
Hide file tree
Showing 12 changed files with 42 additions and 74 deletions.
38 changes: 4 additions & 34 deletions src/array/darray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ Base.IndexStyle(::Type{<:ArrayOp}) = IndexCartesian()

collect(x::ArrayOp) = collect(fetch(x))

_to_darray(x::ArrayOp) = cached_stage(Context(global_context()), x)::DArray
_to_darray(x::ArrayOp) = stage(Context(global_context()), x)::DArray
Base.fetch(x::ArrayOp) = fetch(_to_darray(x))

collect(x::Computation) = collect(fetch(x))

Base.fetch(x::Computation) = fetch(cached_stage(Context(global_context()), x))
Base.fetch(x::Computation) = fetch(stage(Context(global_context()), x))

function Base.show(io::IO, ::MIME"text/plain", x::ArrayOp)
write(io, string(typeof(x)))
Expand Down Expand Up @@ -288,36 +288,6 @@ function Base.fetch(c::DArray{T}) where T
end
end

# Per-context memoization table: maps each `Context` to a `Dict` that caches
# the staged `DArray` for every object staged in that context. Keys are weak
# so a context can be garbage-collected along with its cache.
global _stage_cache = WeakKeyDict{Context, Dict}()

"""
    cached_stage(ctx::Context, x)

Memoized wrapper around `stage`. Staging the same object twice within a
given `Context` must yield the *identical* `DArray`, so that, for example:

```julia
A = rand(Blocks(100,100), Float64, 1000, 1000)
compute(A+A')
```

does not compute `A` twice.
"""
function cached_stage(ctx::Context, x)
    # Fetch (or lazily create) this context's private cache, then stage `x`
    # only on a cache miss; `get!` returns the cached or freshly-staged value.
    cache = get!(Dict, _stage_cache, ctx)
    return get!(cache, x) do
        stage(ctx, x)
    end
end

Base.@deprecate_binding Cat DArray
Base.@deprecate_binding ComputedArray DArray

Expand Down Expand Up @@ -352,15 +322,15 @@ end
function stage(ctx::Context, d::Distribute)
if isa(d.data, ArrayOp)
# distributing a distributed array
x = cached_stage(ctx, d.data)
x = stage(ctx, d.data)
if d.domainchunks == domainchunks(x)
return x # already properly distributed
end
Nd = ndims(x)
T = eltype(d.data)
concat = x.concat
cs = map(d.domainchunks) do idx
chunks = cached_stage(ctx, x[idx]).chunks
chunks = stage(ctx, x[idx]).chunks
shape = size(chunks)
# TODO: fix hashing
#hash = uhash(idx, Base.hash(Distribute, Base.hash(d.data)))
Expand Down
4 changes: 2 additions & 2 deletions src/array/getindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ GetIndex(input::ArrayOp, idx::Tuple) =
GetIndex{eltype(input), ndims(input)}(input, idx)

function stage(ctx::Context, gidx::GetIndex)
inp = cached_stage(ctx, gidx.input)
inp = stage(ctx, gidx.input)

dmn = domain(inp)
idxs = [if isa(gidx.idx[i], Colon)
Expand All @@ -32,7 +32,7 @@ struct GetIndexScalar <: Computation
end

function stage(ctx::Context, gidx::GetIndexScalar)
inp = cached_stage(ctx, gidx.input)
inp = stage(ctx, gidx.input)
s = view(inp, ArrayDomain(gidx.idx))
Dagger.@spawn identity(collect(s)[1])
end
Expand Down
4 changes: 2 additions & 2 deletions src/array/map-reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ size(x::Map) = size(x.inputs[1])
Map(f, inputs::Tuple) = Map{Any, ndims(inputs[1])}(f, inputs)

function stage(ctx::Context, node::Map)
inputs = Any[cached_stage(ctx, n) for n in node.inputs]
inputs = Any[stage(ctx, n) for n in node.inputs]
primary = inputs[1] # all others will align to this guy
domains = domainchunks(primary)
thunks = similar(domains, Any)
Expand Down Expand Up @@ -130,7 +130,7 @@ function Base.reduce(f::Function, x::ArrayOp; dims = nothing)
end

function stage(ctx::Context, r::Reducedim)
inp = cached_stage(ctx, r.input)
inp = stage(ctx, r.input)
thunks = let op = r.op, dims=r.dims
# do reducedim on each block
tmp = map(p->Dagger.spawn(b->reduce(op,b,dims=dims), p), chunks(inp))
Expand Down
24 changes: 12 additions & 12 deletions src/array/matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function _ctranspose(x::AbstractArray)
end

function stage(ctx::Context, node::Transpose)
inp = cached_stage(ctx, node.input)
inp = stage(ctx, node.input)
thunks = _ctranspose(chunks(inp))
return DArray(eltype(inp), domain(inp)', domainchunks(inp)', thunks, inp.partitioning', inp.concat)
end
Expand Down Expand Up @@ -143,7 +143,7 @@ function promote_distribution(ctx::Context, m::MatMul, a,b)
pb = domainchunks(b)

d = DomainBlocks((1,1), (pa.cumlength[2], pb.cumlength[2])) # FIXME: this is not generic
a, cached_stage(ctx, Distribute(d, b))
a, stage(ctx, Distribute(d, b))
end

function stage_operands(ctx::Context, m::MatMul, a, b)
Expand All @@ -152,14 +152,14 @@ function stage_operands(ctx::Context, m::MatMul, a, b)
end
# take the row distribution of a and get b onto that.

stg_a = cached_stage(ctx, a)
stg_b = cached_stage(ctx, b)
stg_a = stage(ctx, a)
stg_b = stage(ctx, b)
promote_distribution(ctx, m, stg_a, stg_b)
end

"An operand which should be distributed as per convenience"
function stage_operands(ctx::Context, ::MatMul, a::ArrayOp, b::PromotePartition{T,1}) where T
stg_a = cached_stage(ctx, a)
stg_a = stage(ctx, a)
dmn_a = domain(stg_a)
dchunks_a = domainchunks(stg_a)
dmn_b = domain(b.data)
Expand All @@ -168,19 +168,19 @@ function stage_operands(ctx::Context, ::MatMul, a::ArrayOp, b::PromotePartition{
end
dmn_out = DomainBlocks((1,),(dchunks_a.cumlength[2],))

stg_a, cached_stage(ctx, Distribute(dmn_out, b.data))
stg_a, stage(ctx, Distribute(dmn_out, b.data))
end

function stage_operands(ctx::Context, ::MatMul, a::PromotePartition, b::ArrayOp)

if size(a, 2) != size(b, 1)
throw(DimensionMismatch("Cannot promote array of domain $(dmn_b) to multiply with an array of size $(dmn_a)"))
end
stg_b = cached_stage(ctx, b)
stg_b = stage(ctx, b)

ps = domainchunks(stg_b)
dmn_out = DomainBlocks((1,1),([size(a.data, 1)], ps.cumlength[1],))
cached_stage(ctx, Distribute(dmn_out, a.data)), stg_b
stage(ctx, Distribute(dmn_out, a.data)), stg_b
end

function stage(ctx::Context, mul::MatMul{T,N}) where {T,N}
Expand Down Expand Up @@ -215,11 +215,11 @@ scale(l::ArrayOp, r::ArrayOp) = _to_darray(Scale(l, r))
function stage_operand(ctx::Context, ::Scale, a, b::PromotePartition)
ps = domainchunks(a)
b_parts = DomainBlocks((1,), (ps.cumlength[1],))
cached_stage(ctx, Distribute(b_parts, b.data))
stage(ctx, Distribute(b_parts, b.data))
end

function stage_operand(ctx::Context, ::Scale, a, b)
cached_stage(ctx, b)
stage(ctx, b)
end

function _scale(l, r)
Expand All @@ -231,7 +231,7 @@ function _scale(l, r)
end

function stage(ctx::Context, scal::Scale)
r = cached_stage(ctx, scal.r)
r = stage(ctx, scal.r)
l = stage_operand(ctx, scal, r, scal.l)

@assert size(domain(r), 1) == size(domain(l), 1)
Expand Down Expand Up @@ -265,7 +265,7 @@ function Base.cat(d::ArrayDomain, ds::ArrayDomain...; dims::Int)
end

function stage(ctx::Context, c::Concat)
inp = Any[cached_stage(ctx, x) for x in c.inputs]
inp = Any[stage(ctx, x) for x in c.inputs]

dmns = map(domain, inp)
dims = [[i == c.axis ? 0 : i for i in size(d)] for d in dmns]
Expand Down
16 changes: 8 additions & 8 deletions src/array/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ BCast(b::Broadcasted) = BCast{typeof(b), combine_eltypes(b.f, b.args), length(ax
size(x::BCast) = map(length, axes(x.bcasted))

function stage_operands(ctx::Context, ::BCast, xs::ArrayOp...)
map(x->cached_stage(ctx, x), xs)
map(x->stage(ctx, x), xs)
end

function stage_operands(ctx::Context, ::BCast, x::ArrayOp, y::PromotePartition)
stg_x = cached_stage(ctx, x)
stg_x = stage(ctx, x)
y1 = Distribute(domain(stg_x), y.data)
stg_x, cached_stage(ctx, y1)
stg_x, stage(ctx, y1)
end

function stage_operands(ctx::Context, ::BCast, x::PromotePartition, y::ArrayOp)
stg_y = cached_stage(ctx, y)
stg_y = stage(ctx, y)
x1 = Distribute(domain(stg_y), x.data)
cached_stage(ctx, x1), stg_y
stage(ctx, x1), stg_y
end

struct DaggerBroadcastStyle <: BroadcastStyle end
Expand All @@ -57,7 +57,7 @@ function stage(ctx::Context, node::BCast{B,T,N}) where {B,T,N}
bc = Broadcast.flatten(node.bcasted)
args = bc.args
args1 = map(args) do x
x isa ArrayOp ? cached_stage(ctx, x) : x
x isa ArrayOp ? stage(ctx, x) : x
end
ds = map(x->x isa DArray ? domainchunks(x) : nothing, args1)
sz = size(node)
Expand All @@ -84,7 +84,7 @@ function stage(ctx::Context, node::BCast{B,T,N}) where {B,T,N}
end
end |> Tuple
dmn = DomainBlocks(ntuple(_->1, length(s)), splits)
cached_stage(ctx, Distribute(dmn, part, arg)).chunks
stage(ctx, Distribute(dmn, part, arg)).chunks
else
arg
end
Expand All @@ -105,7 +105,7 @@ end
mapchunk(f::Function, xs::ArrayOp...) = MapChunk(f, xs)
Base.@deprecate mappart(args...) mapchunk(args...)
function stage(ctx::Context, node::MapChunk)
inputs = map(x->cached_stage(ctx, x), node.input)
inputs = map(x->stage(ctx, x), node.input)
thunks = map(map(chunks, inputs)...) do ps...
Dagger.spawn(node.f, map(p->nothing=>p, ps)...)
end
Expand Down
2 changes: 1 addition & 1 deletion src/array/setindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function setindex(x::ArrayOp, val, idxs...)
end

function stage(ctx::Context, sidx::SetIndex)
inp = cached_stage(ctx, sidx.input)
inp = stage(ctx, sidx.input)

dmn = domain(inp)
idxs = [if isa(sidx.idx[i], Colon)
Expand Down
2 changes: 1 addition & 1 deletion src/compute.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export stage, cached_stage, compute, debug_compute, cleanup
export compute, debug_compute

###### Scheduler #######

Expand Down
2 changes: 1 addition & 1 deletion src/file-io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ function save(p::Computation, name::AbstractString)
end

function stage(ctx::Context, s::Save)
x = cached_stage(ctx, s.input)
x = stage(ctx, s.input)
dir_path = s.name * "_data"
if !isdir(dir_path)
mkdir(dir_path)
Expand Down
6 changes: 3 additions & 3 deletions src/sch/Sch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1063,7 +1063,7 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state)
@async begin
timespan_start(ctx, :fire, gproc.pid, 0)
try
remotecall_wait(do_tasks, gproc.pid, proc, state.chan, [ts])
remotecall_wait(do_tasks, gproc.pid, proc, state.chan, [ts]);
catch err
bt = catch_backtrace()
thunk_id = ts[1]
Expand Down Expand Up @@ -1552,7 +1552,7 @@ function do_task(to_proc, task_desc)
=#
x = @invokelatest move(to_proc, x)
#end
@dagdebug thunk_id :move "Moved argument $id to $to_proc: $x"
@dagdebug thunk_id :move "Moved argument $id to $to_proc: $(typeof(x))"
timespan_finish(ctx, :move, (;thunk_id, id), (;f, id, data=x); tasks=[Base.current_task()])
return x
end
Expand Down Expand Up @@ -1595,7 +1595,7 @@ function do_task(to_proc, task_desc)
# FIXME
#gcnum_start = Base.gc_num()

@dagdebug thunk_id :execute "Executing"
@dagdebug thunk_id :execute "Executing $(typeof(f))"

result_meta = try
# Set TLS variables
Expand Down
2 changes: 1 addition & 1 deletion src/sch/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ function print_sch_status(io::IO, state, thunk; offset=0, limit=5, max_inputs=3)
println(io, "$(thunk.id): $(thunk.f)")
for (idx, input) in enumerate(thunk.syncdeps)
if input isa WeakThunk
input = unwrap_weak(input)
input = Dagger.unwrap_weak(input)
if input === nothing
println(io, repeat(' ', offset+2), "(???)")
continue
Expand Down
14 changes: 6 additions & 8 deletions src/threadproc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,25 @@ iscompatible_func(proc::ThreadProc, opts, f) = true
iscompatible_arg(proc::ThreadProc, opts, x) = true
function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @nospecialize(kwargs...))
tls = get_tls()
# FIXME: Use return type of the call to specialize container
result = Ref{Any}()
task = Task() do
set_tls!(tls)
TimespanLogging.prof_task_put!(tls.sch_handle.thunk_id.id)
@invokelatest f(args...; kwargs...)
result[] = @invokelatest f(args...; kwargs...)
return
end
set_task_tid!(task, proc.tid)
schedule(task)
try
fetch(task)
return result[]
catch err
@static if VERSION < v"1.7-rc1"
stk = Base.catch_stack(task)
else
stk = Base.current_exceptions(task)
end
err, frames = stk[1]
err, frames = Base.current_exceptions(task)[1]
rethrow(CapturedException(err, frames))
end
end
get_parent(proc::ThreadProc) = OSProc(proc.owner)
default_enabled(proc::ThreadProc) = true

# TODO: ThreadGroupProc?

2 changes: 1 addition & 1 deletion src/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ Base.hash(x::Thunk, h::UInt) = hash(x.id, hash(h, 0x7ad3bac49089a05f % UInt))
Base.isequal(x::Thunk, y::Thunk) = x.id==y.id

function show_thunk(io::IO, t)
lvl = get(io, :lazy_level, 2)
lvl = get(io, :lazy_level, 0)
f = if t.f isa Chunk
Tf = t.f.chunktype
if isdefined(Tf, :instance)
Expand Down

0 comments on commit 7364106

Please sign in to comment.