Skip to content

Commit ff2aaf6

Browse files
Add dynamic shared memory allocation
1 parent cf797d7 commit ff2aaf6

File tree

13 files changed

+93
-27
lines changed

13 files changed

+93
-27
lines changed

src/thorin/be/c/c.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,9 @@ void CCodeGen::emit_module() {
512512
stream_.fmt("__device__ inline int blockDim_{}() {{ return blockDim.{}; }}\n", x, x);
513513
stream_.fmt("__device__ inline int gridDim_{}() {{ return gridDim.{}; }}\n", x, x);
514514
}
515+
516+
stream_.fmt("\n"
517+
"extern __shared__ unsigned char __dynamic_smem[];\n");
515518
}
516519

517520
stream_.endl() << func_impls_.str();
@@ -742,7 +745,10 @@ void CCodeGen::emit_epilogue(Continuation* cont) {
742745
bb.tail.fmt("goto {};", label_name(callee));
743746
} else if (auto callee = body->callee()->isa_nom<Continuation>(); callee && callee->is_intrinsic()) {
744747
if (callee->intrinsic() == Intrinsic::Reserve) {
748+
assert(body->num_args() == 3 && "incorrect number of arguments");
749+
745750
emit_unsafe(body->arg(0));
751+
746752
if (!body->arg(1)->isa<PrimLit>())
747753
world().edef(body->arg(1), "reserve_shared: couldn't extract memory size");
748754

@@ -758,6 +764,16 @@ void CCodeGen::emit_epilogue(Continuation* cont) {
758764
}
759765
bb.tail.fmt("p_{} = {}_reserved;\n", ret_cont->param(1)->unique_name(), cont->unique_name());
760766
bb.tail.fmt("goto {};", label_name(ret_cont));
767+
} else if (callee->intrinsic() == Intrinsic::LocalMemory) {
768+
if (lang_ == Lang::HLS)
769+
world().edef(body, "local_memory not supported for HLS");
770+
assert(body->num_args() == 2 && "incorrect number of arguments");
771+
772+
emit_unsafe(body->arg(0));
773+
774+
auto ret_cont = body->arg(1)->as_nom<Continuation>();
775+
bb.tail.fmt("p_{} = __dynamic_smem;\n", ret_cont->param(1)->unique_name());
776+
bb.tail.fmt("goto {};", label_name(ret_cont));
761777
} else if (callee->intrinsic() == Intrinsic::Pipeline) {
762778
assert((lang_ == Lang::OpenCL || lang_ == Lang::HLS) && "pipelining not supported on this backend");
763779

@@ -1440,6 +1456,12 @@ std::string CCodeGen::emit_fun_head(Continuation* cont, bool is_proto) {
14401456
}
14411457
needs_comma = true;
14421458
}
1459+
1460+
if (cont->is_exported() && lang_ == Lang::OpenCL) {
1461+
if (needs_comma) s.fmt(", ");
1462+
s.fmt("__local unsigned char* __dynamic_smem");
1463+
}
1464+
14431465
s << ")";
14441466
return s.str();
14451467
}

src/thorin/be/codegen.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ struct LaunchArgs {
3333
Device,
3434
Space,
3535
Config,
36+
LocalMem,
3637
Body,
3738
Return,
3839
Num

src/thorin/be/llvm/amdgpu.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,8 @@ llvm::Value* AMDGPUCodeGen::emit_reserve(llvm::IRBuilder<>& irbuilder, const Con
7676
return emit_reserve_shared(irbuilder, continuation, true);
7777
}
7878

79+
// AMDGPU handler for the LocalMemory intrinsic.
// Forwards unchanged to CodeGen::emit_local_memory_base_ptr and returns its result.
llvm::Value* AMDGPUCodeGen::emit_local_memory(llvm::IRBuilder<>& irbuilder, const Continuation* continuation) {
80+
return emit_local_memory_base_ptr(irbuilder, continuation);
81+
}
82+
7983
}

src/thorin/be/llvm/amdgpu.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class AMDGPUCodeGen : public CodeGen {
2121
llvm::Value* emit_global(const Global*) override;
2222
llvm::Value* emit_mathop(llvm::IRBuilder<>&, const MathOp*) override;
2323
llvm::Value* emit_reserve(llvm::IRBuilder<>&, const Continuation*) override;
24+
llvm::Value* emit_local_memory(llvm::IRBuilder<>&, const Continuation*) override;
2425
std::string get_alloc_name() const override { return "malloc"; }
2526

2627
const Cont2Config& kernel_config_;

src/thorin/be/llvm/llvm.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1300,6 +1300,7 @@ std::vector<llvm::Value*> CodeGen::emit_intrinsic(llvm::IRBuilder<>& irbuilder,
13001300
case Intrinsic::CmpXchgWeak: return emit_cmpxchg(irbuilder, continuation, true);
13011301
case Intrinsic::Fence: emit_fence(irbuilder, continuation); break;
13021302
case Intrinsic::Reserve: return { emit_reserve(irbuilder, continuation) };
1303+
case Intrinsic::LocalMemory: return { emit_local_memory(irbuilder, continuation) };
13031304
case Intrinsic::CUDA: runtime_->emit_host_code(*this, irbuilder, Runtime::CUDA_PLATFORM, ".cu", continuation); break;
13041305
case Intrinsic::NVVM: runtime_->emit_host_code(*this, irbuilder, Runtime::CUDA_PLATFORM, ".nvvm", continuation); break;
13051306
case Intrinsic::OpenCL: runtime_->emit_host_code(*this, irbuilder, Runtime::OPENCL_PLATFORM, ".cl", continuation); break;
@@ -1420,7 +1421,7 @@ llvm::Value* CodeGen::emit_reserve(llvm::IRBuilder<>&, const Continuation* conti
14201421
llvm::Value* CodeGen::emit_reserve_shared(llvm::IRBuilder<>& irbuilder, const Continuation* continuation, bool init_undef) {
14211422
assert(continuation->has_body());
14221423
auto body = continuation->body();
1423-
assert(body->num_args() == 3 && "required arguments are missing");
1424+
assert(body->num_args() == 3 && "incorrect number of arguments");
14241425
if (!body->arg(1)->isa<PrimLit>())
14251426
world().edef(body->arg(1), "reserve_shared: couldn't extract memory size");
14261427
auto num_elems = body->arg(1)->as<PrimLit>()->ps32_value();
@@ -1437,6 +1438,28 @@ llvm::Value* CodeGen::emit_reserve_shared(llvm::IRBuilder<>& irbuilder, const Co
14371438
return call;
14381439
}
14391440

1441+
// Base-class handler for the LocalMemory intrinsic.
// Reports "local_memory: only allowed in device code" through world().edef for the
// given continuation; the IRBuilder argument is unnamed and unused.
// Back ends that support the intrinsic override this virtual function.
llvm::Value* CodeGen::emit_local_memory(llvm::IRBuilder<>&, const Continuation* continuation) {
1442+
world().edef(continuation, "local_memory: only allowed in device code");
1443+
// No value is returned: control is not expected to continue past edef
// (presumably it aborts or throws — confirm against World::edef).
THORIN_UNREACHABLE;
1444+
}
1445+
1446+
// Returns the module-level global named "__dynamic_smem", creating it on first use.
//
// continuation: must have a body with exactly two arguments (checked by assert).
// irbuilder:    not used by this function body.
// Returns:      the llvm::GlobalVariable, either the one already present in the
//               module or a newly created one.
llvm::Value* CodeGen::emit_local_memory_base_ptr(llvm::IRBuilder<>& irbuilder, const Continuation* continuation) {
1447+
// Symbol name used both for the lookup and for the creation below.
static constexpr auto name = "__dynamic_smem";
1448+
1449+
assert(continuation->has_body());
1450+
auto body = continuation->body();
1451+
assert(body->num_args() == 2 && "incorrect number of arguments");
1452+
// NOTE(review): `cont` is never referenced afterwards; as_nom presumably only
// serves as a check that argument 1 is a nominal Continuation — confirm.
auto cont = body->arg(1)->as_nom<Continuation>();
1453+
1454+
// Reuse the global if an earlier call already added it to this module.
if (auto found = module().getGlobalVariable(name))
1455+
return found;
1456+
1457+
// Zero-length i8 array: the declaration itself carries no size.
auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(context()), 0);
1458+
// Non-constant, external linkage, no initializer, not thread-local, address space 3
// (presumably the shared/local memory address space of the GPU targets — confirm).
auto global = new llvm::GlobalVariable(module(), type, false, llvm::GlobalValue::ExternalLinkage, nullptr, name, nullptr, llvm::GlobalVariable::NotThreadLocal, 3);
1459+
// Request 16-byte alignment for the global.
global->setAlignment(llvm::Align(16));
1460+
return global;
1461+
}
1462+
14401463
/*
14411464
* backend-specific stuff
14421465
*/

src/thorin/be/llvm/llvm.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ class CodeGen : public thorin::CodeGen, public thorin::Emitter<llvm::Value*, llv
8585

8686
virtual llvm::Value* emit_reserve(llvm::IRBuilder<>&, const Continuation*);
8787
llvm::Value* emit_reserve_shared(llvm::IRBuilder<>&, const Continuation*, bool=false);
88+
virtual llvm::Value* emit_local_memory(llvm::IRBuilder<>&, const Continuation*);
89+
llvm::Value* emit_local_memory_base_ptr(llvm::IRBuilder<>& irbuilder, const Continuation* continuation);
8890

8991
virtual std::string get_alloc_name() const = 0;
9092
llvm::BasicBlock* cont2bb(Continuation* cont) { return cont2bb_[cont].first; }

src/thorin/be/llvm/nvvm.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ llvm::Value* NVVMCodeGen::emit_reserve(llvm::IRBuilder<>& irbuilder, const Conti
246246
return emit_reserve_shared(irbuilder, continuation);
247247
}
248248

249+
// NVVM handler for the LocalMemory intrinsic.
// Forwards unchanged to CodeGen::emit_local_memory_base_ptr and returns its result.
llvm::Value* NVVMCodeGen::emit_local_memory(llvm::IRBuilder<>& irbuilder, const Continuation* continuation) {
250+
return emit_local_memory_base_ptr(irbuilder, continuation);
251+
}
252+
249253
llvm::Value* NVVMCodeGen::emit_mathop(llvm::IRBuilder<>& irbuilder, const MathOp* mathop) {
250254
auto make_key = [] (MathOpTag tag, unsigned bitwidth) { return (static_cast<unsigned>(tag) << 16) | bitwidth; };
251255
static const std::unordered_map<unsigned, std::string> libdevice_functions = {

src/thorin/be/llvm/nvvm.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class NVVMCodeGen : public CodeGen {
2929
llvm::Value* emit_mathop(llvm::IRBuilder<>&, const MathOp*) override;
3030

3131
llvm::Value* emit_reserve(llvm::IRBuilder<>&, const Continuation*) override;
32+
llvm::Value* emit_local_memory(llvm::IRBuilder<>&, const Continuation*) override;
3233

3334
llvm::Value* emit_global(const Global*) override;
3435

src/thorin/be/llvm/runtime.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,13 @@ void Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& builder, Plat
6464
assert(continuation->has_body());
6565
auto body = continuation->body();
6666
// to-target is the desired kernel call
67-
// target(mem, device, (dim.x, dim.y, dim.z), (block.x, block.y, block.z), body, return, free_vars)
67+
// target(mem, device, (dim.x, dim.y, dim.z), (block.x, block.y, block.z), lmem, body, return, free_vars)
6868
auto target = body->callee()->as_nom<Continuation>();
6969
assert_unused(target->is_intrinsic());
7070
assert(body->num_args() >= LaunchArgs::Num && "required arguments are missing");
7171

72+
auto& world = continuation->world();
73+
7274
// arguments
7375
auto target_device_id = code_gen.emit(body->arg(LaunchArgs::Device));
7476
auto target_platform = builder.getInt32(platform);
@@ -78,7 +80,6 @@ void Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& builder, Plat
7880
auto it_config = body->arg(LaunchArgs::Config);
7981
auto kernel = body->arg(LaunchArgs::Body)->as<Global>()->init()->as<Continuation>();
8082

81-
auto& world = continuation->world();
8283
//auto kernel_name = builder.CreateGlobalStringPtr(kernel->name() == "hls_top" ? kernel->name() : kernel->name());
8384
auto kernel_name = builder.CreateGlobalStringPtr(kernel->name());
8485
auto file_name = builder.CreateGlobalStringPtr(world.name() + ext);
@@ -179,9 +180,12 @@ void Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& builder, Plat
179180
allocs = builder.CreateInBoundsGEP(llvm::cast<llvm::AllocaInst>(allocs)->getAllocatedType(), allocs, gep_first_elem);
180181
types = builder.CreateInBoundsGEP(llvm::cast<llvm::AllocaInst>(types)->getAllocatedType(), types, gep_first_elem);
181182

183+
auto lmem = code_gen.emit(body->arg(LaunchArgs::LocalMem));
184+
182185
launch_kernel(code_gen, builder, target_device,
183186
file_name, kernel_name,
184187
grid_size, block_size,
188+
lmem,
185189
args, sizes, aligns, allocs, types,
186190
builder.getInt32(num_kernel_args));
187191
}
@@ -190,10 +194,11 @@ llvm::Value* Runtime::launch_kernel(
190194
CodeGen& code_gen, llvm::IRBuilder<>& builder, llvm::Value* device,
191195
llvm::Value* file, llvm::Value* kernel,
192196
llvm::Value* grid, llvm::Value* block,
197+
llvm::Value* lmem,
193198
llvm::Value* args, llvm::Value* sizes, llvm::Value* aligns, llvm::Value* allocs, llvm::Value* types,
194199
llvm::Value* num_args)
195200
{
196-
llvm::Value* launch_args[] = { device, file, kernel, grid, block, args, sizes, aligns, allocs, types, num_args };
201+
llvm::Value* launch_args[] = { device, file, kernel, grid, block, lmem, args, sizes, aligns, allocs, types, num_args };
197202
return builder.CreateCall(get(code_gen, "anydsl_launch_kernel"), launch_args);
198203
}
199204

src/thorin/be/llvm/runtime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class Runtime {
3232
CodeGen&, llvm::IRBuilder<>&, llvm::Value* device,
3333
llvm::Value* file, llvm::Value* kernel,
3434
llvm::Value* grid, llvm::Value* block,
35+
llvm::Value* lmem,
3536
llvm::Value* args, llvm::Value* sizes, llvm::Value* aligns, llvm::Value* allocs, llvm::Value* types,
3637
llvm::Value* num_args);
3738

0 commit comments

Comments
 (0)