Skip to content

Commit

Permalink
Add dynamic shared memory allocation
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-kenzel committed Aug 1, 2024
1 parent cf797d7 commit bb992d8
Show file tree
Hide file tree
Showing 13 changed files with 71 additions and 5 deletions.
22 changes: 22 additions & 0 deletions src/thorin/be/c/c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,9 @@ void CCodeGen::emit_module() {
stream_.fmt("__device__ inline int blockDim_{}() {{ return blockDim.{}; }}\n", x, x);
stream_.fmt("__device__ inline int gridDim_{}() {{ return gridDim.{}; }}\n", x, x);
}

stream_.fmt("\n"
"extern __shared__ unsigned char __dynamic_smem[];\n");
}

stream_.endl() << func_impls_.str();
Expand Down Expand Up @@ -742,7 +745,10 @@ void CCodeGen::emit_epilogue(Continuation* cont) {
bb.tail.fmt("goto {};", label_name(callee));
} else if (auto callee = body->callee()->isa_nom<Continuation>(); callee && callee->is_intrinsic()) {
if (callee->intrinsic() == Intrinsic::Reserve) {
assert(body->num_args() == 3 && "incorrect number of arguments");

emit_unsafe(body->arg(0));

if (!body->arg(1)->isa<PrimLit>())
world().edef(body->arg(1), "reserve_shared: couldn't extract memory size");

Expand All @@ -758,6 +764,16 @@ void CCodeGen::emit_epilogue(Continuation* cont) {
}
bb.tail.fmt("p_{} = {}_reserved;\n", ret_cont->param(1)->unique_name(), cont->unique_name());
bb.tail.fmt("goto {};", label_name(ret_cont));
} else if (callee->intrinsic() == Intrinsic::LocalMemory) {
if (lang_ == Lang::HLS)
world().edef(body, "local_memory not supported for HLS");
assert(body->num_args() == 2 && "incorrect number of arguments");

emit_unsafe(body->arg(0));

auto ret_cont = body->arg(1)->as_nom<Continuation>();
bb.tail.fmt("p_{} = __dynamic_smem;\n", ret_cont->param(1)->unique_name());
bb.tail.fmt("goto {};", label_name(ret_cont));
} else if (callee->intrinsic() == Intrinsic::Pipeline) {
assert((lang_ == Lang::OpenCL || lang_ == Lang::HLS) && "pipelining not supported on this backend");

Expand Down Expand Up @@ -1440,6 +1456,12 @@ std::string CCodeGen::emit_fun_head(Continuation* cont, bool is_proto) {
}
needs_comma = true;
}

if (cont->is_exported() && lang_ == Lang::OpenCL) {
if (needs_comma) s.fmt(", ");
s.fmt("__local unsigned char* __dynamic_smem");
}

s << ")";
return s.str();
}
Expand Down
1 change: 1 addition & 0 deletions src/thorin/be/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct LaunchArgs {
Device,
Space,
Config,
LocalMem,
Body,
Return,
Num
Expand Down
4 changes: 4 additions & 0 deletions src/thorin/be/llvm/amdgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,8 @@ llvm::Value* AMDGPUCodeGen::emit_reserve(llvm::IRBuilder<>& irbuilder, const Con
return emit_reserve_shared(irbuilder, continuation, true);
}

llvm::Value* AMDGPUCodeGen::emit_local_memory(llvm::IRBuilder<>& irbuilder, const Continuation* continuation) {
return emit_local_memory_base_ptr(irbuilder, continuation);
}

}
1 change: 1 addition & 0 deletions src/thorin/be/llvm/amdgpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class AMDGPUCodeGen : public CodeGen {
llvm::Value* emit_global(const Global*) override;
llvm::Value* emit_mathop(llvm::IRBuilder<>&, const MathOp*) override;
llvm::Value* emit_reserve(llvm::IRBuilder<>&, const Continuation*) override;
llvm::Value* emit_local_memory(llvm::IRBuilder<>&, const Continuation*) override;
std::string get_alloc_name() const override { return "malloc"; }

const Cont2Config& kernel_config_;
Expand Down
25 changes: 24 additions & 1 deletion src/thorin/be/llvm/llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1300,6 +1300,7 @@ std::vector<llvm::Value*> CodeGen::emit_intrinsic(llvm::IRBuilder<>& irbuilder,
case Intrinsic::CmpXchgWeak: return emit_cmpxchg(irbuilder, continuation, true);
case Intrinsic::Fence: emit_fence(irbuilder, continuation); break;
case Intrinsic::Reserve: return { emit_reserve(irbuilder, continuation) };
case Intrinsic::LocalMemory: return { emit_local_memory(irbuilder, continuation) };
case Intrinsic::CUDA: runtime_->emit_host_code(*this, irbuilder, Runtime::CUDA_PLATFORM, ".cu", continuation); break;
case Intrinsic::NVVM: runtime_->emit_host_code(*this, irbuilder, Runtime::CUDA_PLATFORM, ".nvvm", continuation); break;
case Intrinsic::OpenCL: runtime_->emit_host_code(*this, irbuilder, Runtime::OPENCL_PLATFORM, ".cl", continuation); break;
Expand Down Expand Up @@ -1420,7 +1421,7 @@ llvm::Value* CodeGen::emit_reserve(llvm::IRBuilder<>&, const Continuation* conti
llvm::Value* CodeGen::emit_reserve_shared(llvm::IRBuilder<>& irbuilder, const Continuation* continuation, bool init_undef) {
assert(continuation->has_body());
auto body = continuation->body();
assert(body->num_args() == 3 && "required arguments are missing");
assert(body->num_args() == 3 && "incorrect number of arguments");
if (!body->arg(1)->isa<PrimLit>())
world().edef(body->arg(1), "reserve_shared: couldn't extract memory size");
auto num_elems = body->arg(1)->as<PrimLit>()->ps32_value();
Expand All @@ -1437,6 +1438,28 @@ llvm::Value* CodeGen::emit_reserve_shared(llvm::IRBuilder<>& irbuilder, const Co
return call;
}

llvm::Value* CodeGen::emit_local_memory(llvm::IRBuilder<>&, const Continuation* continuation) {
world().edef(continuation, "local_memory: only allowed in device code");
THORIN_UNREACHABLE;
}

llvm::Value* CodeGen::emit_local_memory_base_ptr(llvm::IRBuilder<>& irbuilder, const Continuation* continuation) {
static constexpr auto name = "__dynamic_smem";

assert(continuation->has_body());
auto body = continuation->body();
assert(body->num_args() == 2 && "incorrect number of arguments");
auto cont = body->arg(1)->as_nom<Continuation>();

if (auto found = module().getGlobalVariable(name))
return found;

auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(context()), 0);
auto global = new llvm::GlobalVariable(module(), type, false, llvm::GlobalValue::ExternalLinkage, nullptr, name, nullptr, llvm::GlobalVariable::NotThreadLocal, 3);
global->setAlignment(llvm::Align(16));
return global;
}

/*
* backend-specific stuff
*/
Expand Down
2 changes: 2 additions & 0 deletions src/thorin/be/llvm/llvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class CodeGen : public thorin::CodeGen, public thorin::Emitter<llvm::Value*, llv

virtual llvm::Value* emit_reserve(llvm::IRBuilder<>&, const Continuation*);
llvm::Value* emit_reserve_shared(llvm::IRBuilder<>&, const Continuation*, bool=false);
virtual llvm::Value* emit_local_memory(llvm::IRBuilder<>&, const Continuation*);
llvm::Value* emit_local_memory_base_ptr(llvm::IRBuilder<>& irbuilder, const Continuation* continuation);

virtual std::string get_alloc_name() const = 0;
llvm::BasicBlock* cont2bb(Continuation* cont) { return cont2bb_[cont].first; }
Expand Down
4 changes: 4 additions & 0 deletions src/thorin/be/llvm/nvvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ llvm::Value* NVVMCodeGen::emit_reserve(llvm::IRBuilder<>& irbuilder, const Conti
return emit_reserve_shared(irbuilder, continuation);
}

llvm::Value* NVVMCodeGen::emit_local_memory(llvm::IRBuilder<>& irbuilder, const Continuation* continuation) {
return emit_local_memory_base_ptr(irbuilder, continuation);
}

llvm::Value* NVVMCodeGen::emit_mathop(llvm::IRBuilder<>& irbuilder, const MathOp* mathop) {
auto make_key = [] (MathOpTag tag, unsigned bitwidth) { return (static_cast<unsigned>(tag) << 16) | bitwidth; };
static const std::unordered_map<unsigned, std::string> libdevice_functions = {
Expand Down
1 change: 1 addition & 0 deletions src/thorin/be/llvm/nvvm.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class NVVMCodeGen : public CodeGen {
llvm::Value* emit_mathop(llvm::IRBuilder<>&, const MathOp*) override;

llvm::Value* emit_reserve(llvm::IRBuilder<>&, const Continuation*) override;
llvm::Value* emit_local_memory(llvm::IRBuilder<>&, const Continuation*) override;

llvm::Value* emit_global(const Global*) override;

Expand Down
11 changes: 8 additions & 3 deletions src/thorin/be/llvm/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,13 @@ void Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& builder, Plat
assert(continuation->has_body());
auto body = continuation->body();
// to-target is the desired kernel call
// target(mem, device, (dim.x, dim.y, dim.z), (block.x, block.y, block.z), body, return, free_vars)
// target(mem, device, (dim.x, dim.y, dim.z), (block.x, block.y, block.z), lmem, body, return, free_vars)
auto target = body->callee()->as_nom<Continuation>();
assert_unused(target->is_intrinsic());
assert(body->num_args() >= LaunchArgs::Num && "required arguments are missing");

auto& world = continuation->world();

// arguments
auto target_device_id = code_gen.emit(body->arg(LaunchArgs::Device));
auto target_platform = builder.getInt32(platform);
Expand All @@ -78,7 +80,6 @@ void Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& builder, Plat
auto it_config = body->arg(LaunchArgs::Config);
auto kernel = body->arg(LaunchArgs::Body)->as<Global>()->init()->as<Continuation>();

auto& world = continuation->world();
//auto kernel_name = builder.CreateGlobalStringPtr(kernel->name() == "hls_top" ? kernel->name() : kernel->name());
auto kernel_name = builder.CreateGlobalStringPtr(kernel->name());
auto file_name = builder.CreateGlobalStringPtr(world.name() + ext);
Expand Down Expand Up @@ -179,9 +180,12 @@ void Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& builder, Plat
allocs = builder.CreateInBoundsGEP(llvm::cast<llvm::AllocaInst>(allocs)->getAllocatedType(), allocs, gep_first_elem);
types = builder.CreateInBoundsGEP(llvm::cast<llvm::AllocaInst>(types)->getAllocatedType(), types, gep_first_elem);

auto lmem = code_gen.emit(body->arg(LaunchArgs::LocalMem));

launch_kernel(code_gen, builder, target_device,
file_name, kernel_name,
grid_size, block_size,
lmem,
args, sizes, aligns, allocs, types,
builder.getInt32(num_kernel_args));
}
Expand All @@ -190,10 +194,11 @@ llvm::Value* Runtime::launch_kernel(
CodeGen& code_gen, llvm::IRBuilder<>& builder, llvm::Value* device,
llvm::Value* file, llvm::Value* kernel,
llvm::Value* grid, llvm::Value* block,
llvm::Value* lmem,
llvm::Value* args, llvm::Value* sizes, llvm::Value* aligns, llvm::Value* allocs, llvm::Value* types,
llvm::Value* num_args)
{
llvm::Value* launch_args[] = { device, file, kernel, grid, block, args, sizes, aligns, allocs, types, num_args };
llvm::Value* launch_args[] = { device, file, kernel, grid, block, lmem, args, sizes, aligns, allocs, types, num_args };
return builder.CreateCall(get(code_gen, "anydsl_launch_kernel"), launch_args);
}

Expand Down
1 change: 1 addition & 0 deletions src/thorin/be/llvm/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class Runtime {
CodeGen&, llvm::IRBuilder<>&, llvm::Value* device,
llvm::Value* file, llvm::Value* kernel,
llvm::Value* grid, llvm::Value* block,
llvm::Value* lmem,
llvm::Value* args, llvm::Value* sizes, llvm::Value* aligns, llvm::Value* allocs, llvm::Value* types,
llvm::Value* num_args);

Expand Down
2 changes: 1 addition & 1 deletion src/thorin/be/llvm/runtime.inc
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ namespace thorin {
declare noalias ptr @anydsl_alloc(i32, i64);
declare noalias ptr @anydsl_alloc_unified(i32, i64);
declare void @anydsl_release(i32, ptr);
declare void @anydsl_launch_kernel(i32, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, ptr, i32);
declare void @anydsl_launch_kernel(i32, ptr, ptr, ptr, ptr, i32, ptr, ptr, ptr, ptr, ptr, i32);
declare void @anydsl_parallel_for(i32, i32, i32, ptr, ptr);
declare void @anydsl_fibers_spawn(i32, i32, i32, ptr, ptr);
declare i32 @anydsl_spawn_thread(ptr, ptr);
Expand Down
1 change: 1 addition & 0 deletions src/thorin/continuation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ void Continuation::set_intrinsic() {
else if (name() == "pe_info") attributes().intrinsic = Intrinsic::PeInfo;
else if (name() == "pipeline") attributes().intrinsic = Intrinsic::Pipeline;
else if (name() == "reserve_shared") attributes().intrinsic = Intrinsic::Reserve;
else if (name() == "local_memory") attributes().intrinsic = Intrinsic::LocalMemory;
else if (name() == "atomic") attributes().intrinsic = Intrinsic::Atomic;
else if (name() == "atomic_load") attributes().intrinsic = Intrinsic::AtomicLoad;
else if (name() == "atomic_store") attributes().intrinsic = Intrinsic::AtomicStore;
Expand Down
1 change: 1 addition & 0 deletions src/thorin/continuation.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ enum class Intrinsic : uint8_t {
Vectorize, ///< External vectorizer.
AcceleratorEnd,
Reserve = AcceleratorEnd, ///< Intrinsic memory reserve function
LocalMemory, ///< Intrinsic get local memory base pointer
Atomic, ///< Intrinsic atomic function
AtomicLoad, ///< Intrinsic atomic load function
AtomicStore, ///< Intrinsic atomic store function
Expand Down

0 comments on commit bb992d8

Please sign in to comment.