Skip to content

Commit 00bb7bd

Browse files
Add dynamic shared memory allocation
1 parent d360324 commit 00bb7bd

File tree

13 files changed

+77
-6
lines changed

13 files changed

+77
-6
lines changed

src/thorin/be/c/c.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,9 @@ void CCodeGen::emit_module() {
510510
stream_.fmt("__device__ inline int blockDim_{}() {{ return blockDim.{}; }}\n", x, x);
511511
stream_.fmt("__device__ inline int gridDim_{}() {{ return gridDim.{}; }}\n", x, x);
512512
}
513+
514+
stream_.fmt("\n"
515+
"extern __shared__ unsigned char __dynamic_smem[];\n");
513516
}
514517

515518
stream_.endl() << func_impls_.str();
@@ -737,7 +740,10 @@ void CCodeGen::emit_epilogue(Continuation* cont) {
737740
bb.tail.fmt("goto {};", label_name(callee));
738741
} else if (auto callee = body->callee()->isa_nom<Continuation>(); callee && callee->is_intrinsic()) {
739742
if (callee->intrinsic() == Intrinsic::Reserve) {
743+
assert(body->num_args() == 3 && "incorrect number of arguments");
744+
740745
emit_unsafe(body->arg(0));
746+
741747
if (!body->arg(1)->isa<PrimLit>())
742748
world().edef(body->arg(1), "reserve_shared: couldn't extract memory size");
743749

@@ -753,6 +759,16 @@ void CCodeGen::emit_epilogue(Continuation* cont) {
753759
}
754760
bb.tail.fmt("p_{} = {}_reserved;\n", ret_cont->param(1)->unique_name(), cont->unique_name());
755761
bb.tail.fmt("goto {};", label_name(ret_cont));
762+
} else if (callee->intrinsic() == Intrinsic::LocalMemory) {
763+
if (lang_ == Lang::HLS)
764+
world().edef(body, "local_memory not supported for HLS");
765+
assert(body->num_args() == 2 && "incorrect number of arguments");
766+
767+
emit_unsafe(body->arg(0));
768+
769+
auto ret_cont = body->arg(1)->as_nom<Continuation>();
770+
bb.tail.fmt("p_{} = __dynamic_smem;\n", ret_cont->param(1)->unique_name());
771+
bb.tail.fmt("goto {};", label_name(ret_cont));
756772
} else if (callee->intrinsic() == Intrinsic::Pipeline) {
757773
assert((lang_ == Lang::OpenCL || lang_ == Lang::HLS) && "pipelining not supported on this backend");
758774

@@ -1435,6 +1451,12 @@ std::string CCodeGen::emit_fun_head(Continuation* cont, bool is_proto) {
14351451
}
14361452
needs_comma = true;
14371453
}
1454+
1455+
if (cont->is_exported() && lang_ == Lang::OpenCL) {
1456+
if (needs_comma) s.fmt(", ");
1457+
s.fmt("__local unsigned char* __dynamic_smem");
1458+
}
1459+
14381460
s << ")";
14391461
return s.str();
14401462
}

src/thorin/be/codegen.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ struct LaunchArgs {
3232
Device,
3333
Space,
3434
Config,
35+
LocalMem,
3536
Body,
3637
Return,
3738
Num

src/thorin/be/llvm/amdgpu.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,4 +90,8 @@ Continuation* AMDGPUCodeGen::emit_reserve(llvm::IRBuilder<>& irbuilder, const Co
9090
return emit_reserve_shared(irbuilder, continuation, true);
9191
}
9292

93+
// AMDGPU handler for the LocalMemory intrinsic (declared as an override in amdgpu.h).
// Forwards unchanged to CodeGen::emit_local_memory_base_ptr and returns the
// continuation that helper returns.
Continuation* AMDGPUCodeGen::emit_local_memory(llvm::IRBuilder<>& irbuilder, const Continuation* continuation) {
94+
return emit_local_memory_base_ptr(irbuilder, continuation);
95+
}
96+
9397
}

src/thorin/be/llvm/amdgpu.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class AMDGPUCodeGen : public CodeGen {
2323
llvm::Value* emit_global(const Global*) override;
2424
llvm::Value* emit_mathop(llvm::IRBuilder<>&, const MathOp*) override;
2525
Continuation* emit_reserve(llvm::IRBuilder<>&, const Continuation*) override;
26+
Continuation* emit_local_memory(llvm::IRBuilder<>&, const Continuation*) override;
2627
std::string get_alloc_name() const override { return "malloc"; }
2728

2829
const Cont2Config& kernel_config_;

src/thorin/be/llvm/llvm.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1182,6 +1182,7 @@ Continuation* CodeGen::emit_intrinsic(llvm::IRBuilder<>& irbuilder, Continuation
11821182
case Intrinsic::CmpXchgWeak: return emit_cmpxchg(irbuilder, continuation, true);
11831183
case Intrinsic::Fence: return emit_fence(irbuilder, continuation);
11841184
case Intrinsic::Reserve: return emit_reserve(irbuilder, continuation);
1185+
case Intrinsic::LocalMemory: return emit_local_memory(irbuilder, continuation);
11851186
case Intrinsic::CUDA: return runtime_->emit_host_code(*this, irbuilder, Runtime::CUDA_PLATFORM, ".cu", continuation);
11861187
case Intrinsic::NVVM: return runtime_->emit_host_code(*this, irbuilder, Runtime::CUDA_PLATFORM, ".nvvm", continuation);
11871188
case Intrinsic::OpenCL: return runtime_->emit_host_code(*this, irbuilder, Runtime::OPENCL_PLATFORM, ".cl", continuation);
@@ -1309,7 +1310,7 @@ Continuation* CodeGen::emit_reserve(llvm::IRBuilder<>&, const Continuation* cont
13091310
Continuation* CodeGen::emit_reserve_shared(llvm::IRBuilder<>& irbuilder, const Continuation* continuation, bool init_undef) {
13101311
assert(continuation->has_body());
13111312
auto body = continuation->body();
1312-
assert(body->num_args() == 3 && "required arguments are missing");
1313+
assert(body->num_args() == 3 && "incorrect number of arguments");
13131314
if (!body->arg(1)->isa<PrimLit>())
13141315
world().edef(body->arg(1), "reserve_shared: couldn't extract memory size");
13151316
auto num_elems = body->arg(1)->as<PrimLit>()->ps32_value();
@@ -1327,6 +1328,33 @@ Continuation* CodeGen::emit_reserve_shared(llvm::IRBuilder<>& irbuilder, const C
13271328
return cont;
13281329
}
13291330

1331+
// Default handler for Intrinsic::LocalMemory in the generic LLVM code generator.
// It is declared virtual in llvm.h; the NVVM and AMDGPU generators override it.
// This base version only reports an error on the offending continuation.
Continuation* CodeGen::emit_local_memory(llvm::IRBuilder<>&, const Continuation* continuation) {
1332+
world().edef(continuation, "local_memory: only allowed in device code");
1333+
// NOTE(review): edef() is presumably expected not to return normally, since no
// Continuation* is returned after it — confirm against its definition.
THORIN_UNREACHABLE;
1334+
}
1335+
1336+
// Shared implementation of the LocalMemory intrinsic for the device backends
// (called by the NVVM and AMDGPU emit_local_memory overrides).
//
// Expects `continuation` to have a body with exactly two arguments; argument 1 is
// the return continuation. Looks up (or creates on first use) a module-level
// global named "__dynamic_smem", passes it as the phi argument for the return
// continuation's param(1), and returns that return continuation.
Continuation* CodeGen::emit_local_memory_base_ptr(llvm::IRBuilder<>& irbuilder, const Continuation* continuation) {
1337+
// Same symbol name the C backend emits for its dynamic shared-memory array.
static constexpr auto name = "__dynamic_smem";
1338+
1339+
assert(continuation->has_body());
1340+
auto body = continuation->body();
1341+
assert(body->num_args() == 2 && "incorrect number of arguments");
1342+
// arg(1) is the continuation to resume with the base pointer.
auto cont = body->arg(1)->as_nom<Continuation>();
1343+
1344+
// Immediately-invoked lambda: reuse the global if this module already has one,
// otherwise create it, so repeated intrinsic calls share a single symbol.
auto global = [&] {
1345+
if (auto found = module().getGlobalVariable(name))
1346+
return found;
1347+
1348+
// Zero-length i8 array: the declaration itself carries no size.
auto type = llvm::ArrayType::get(llvm::Type::getInt8Ty(context()), 0);
1349+
// Non-constant, external linkage, no initializer, not thread-local, address space 3.
// NOTE(review): address space 3 is presumably the shared/LDS space on both the
// NVPTX and AMDGPU targets — confirm against the LLVM target docs.
auto global = new llvm::GlobalVariable(module(), type, false, llvm::GlobalValue::ExternalLinkage, nullptr, name, nullptr, llvm::GlobalVariable::NotThreadLocal, 3);
1350+
global->setAlignment(llvm::Align(16));
1351+
return global;
1352+
}();
1353+
1354+
// Hand the global to the return continuation as the value of its param(1).
emit_phi_arg(irbuilder, cont->param(1), global);
1355+
return cont;
1356+
}
1357+
13301358
/*
13311359
* backend-specific stuff
13321360
*/

src/thorin/be/llvm/llvm.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ class CodeGen : public thorin::CodeGen, public thorin::Emitter<llvm::Value*, llv
8383

8484
virtual Continuation* emit_reserve(llvm::IRBuilder<>&, const Continuation*);
8585
Continuation* emit_reserve_shared(llvm::IRBuilder<>&, const Continuation*, bool=false);
86+
virtual Continuation* emit_local_memory(llvm::IRBuilder<>&, const Continuation*);
87+
Continuation* emit_local_memory_base_ptr(llvm::IRBuilder<>& irbuilder, const Continuation* continuation);
8688

8789
virtual std::string get_alloc_name() const = 0;
8890
llvm::BasicBlock* cont2bb(Continuation* cont) { return cont2bb_[cont].first; }

src/thorin/be/llvm/nvvm.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ Continuation* NVVMCodeGen::emit_reserve(llvm::IRBuilder<>& irbuilder, const Cont
246246
return emit_reserve_shared(irbuilder, continuation);
247247
}
248248

249+
// NVVM handler for the LocalMemory intrinsic (declared as an override in nvvm.h).
// Forwards unchanged to CodeGen::emit_local_memory_base_ptr and returns the
// continuation that helper returns.
Continuation* NVVMCodeGen::emit_local_memory(llvm::IRBuilder<>& irbuilder, const Continuation* continuation) {
250+
return emit_local_memory_base_ptr(irbuilder, continuation);
251+
}
252+
249253
llvm::Value* NVVMCodeGen::emit_mathop(llvm::IRBuilder<>& irbuilder, const MathOp* mathop) {
250254
auto make_key = [] (MathOpTag tag, unsigned bitwidth) { return (static_cast<unsigned>(tag) << 16) | bitwidth; };
251255
static const std::unordered_map<unsigned, std::string> libdevice_functions = {

src/thorin/be/llvm/nvvm.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class NVVMCodeGen : public CodeGen {
2929
llvm::Value* emit_mathop(llvm::IRBuilder<>&, const MathOp*) override;
3030

3131
Continuation* emit_reserve(llvm::IRBuilder<>&, const Continuation*) override;
32+
Continuation* emit_local_memory(llvm::IRBuilder<>&, const Continuation*) override;
3233

3334
llvm::Value* emit_global(const Global*) override;
3435

src/thorin/be/llvm/runtime.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,21 +64,24 @@ Continuation* Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& buil
6464
assert(continuation->has_body());
6565
auto body = continuation->body();
6666
// to-target is the desired kernel call
67-
// target(mem, device, (dim.x, dim.y, dim.z), (block.x, block.y, block.z), body, return, free_vars)
67+
// target(mem, device, (dim.x, dim.y, dim.z), (block.x, block.y, block.z), lmem, body, return, free_vars)
6868
auto target = body->callee()->as_nom<Continuation>();
6969
assert_unused(target->is_intrinsic());
7070
assert(body->num_args() >= LaunchArgs::Num && "required arguments are missing");
7171

72+
auto& world = continuation->world();
73+
7274
// arguments
7375
auto target_device_id = code_gen.emit(body->arg(LaunchArgs::Device));
7476
auto target_platform = builder.getInt32(platform);
7577
auto target_device = builder.CreateOr(target_platform, builder.CreateShl(target_device_id, builder.getInt32(4)));
7678

7779
auto it_space = body->arg(LaunchArgs::Space);
7880
auto it_config = body->arg(LaunchArgs::Config);
79-
auto kernel = body->arg(LaunchArgs::Body)->as<Global>()->init()->as<Continuation>();
8081

81-
auto& world = continuation->world();
82+
auto lmem = code_gen.emit(body->arg(LaunchArgs::LocalMem));
83+
84+
auto kernel = body->arg(LaunchArgs::Body)->as<Global>()->init()->as<Continuation>();
8285
auto kernel_name = builder.CreateGlobalStringPtr(kernel->name() == "hls_top" ? kernel->name() : kernel->unique_name());
8386
auto file_name = builder.CreateGlobalStringPtr(world.name() + ext);
8487
const size_t num_kernel_args = body->num_args() - LaunchArgs::Num;
@@ -181,6 +184,7 @@ Continuation* Runtime::emit_host_code(CodeGen& code_gen, llvm::IRBuilder<>& buil
181184
launch_kernel(code_gen, builder, target_device,
182185
file_name, kernel_name,
183186
grid_size, block_size,
187+
lmem,
184188
args, sizes, aligns, allocs, types,
185189
builder.getInt32(num_kernel_args));
186190

@@ -191,10 +195,11 @@ llvm::Value* Runtime::launch_kernel(
191195
CodeGen& code_gen, llvm::IRBuilder<>& builder, llvm::Value* device,
192196
llvm::Value* file, llvm::Value* kernel,
193197
llvm::Value* grid, llvm::Value* block,
198+
llvm::Value* lmem,
194199
llvm::Value* args, llvm::Value* sizes, llvm::Value* aligns, llvm::Value* allocs, llvm::Value* types,
195200
llvm::Value* num_args)
196201
{
197-
llvm::Value* launch_args[] = { device, file, kernel, grid, block, args, sizes, aligns, allocs, types, num_args };
202+
llvm::Value* launch_args[] = { device, file, kernel, grid, block, lmem, args, sizes, aligns, allocs, types, num_args };
198203
return builder.CreateCall(get(code_gen, "anydsl_launch_kernel"), launch_args);
199204
}
200205

src/thorin/be/llvm/runtime.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class Runtime {
3030
CodeGen&, llvm::IRBuilder<>&, llvm::Value* device,
3131
llvm::Value* file, llvm::Value* kernel,
3232
llvm::Value* grid, llvm::Value* block,
33+
llvm::Value* lmem,
3334
llvm::Value* args, llvm::Value* sizes, llvm::Value* aligns, llvm::Value* allocs, llvm::Value* types,
3435
llvm::Value* num_args);
3536

0 commit comments

Comments
 (0)