Skip to content

Commit 1c4d8f3

Browse files
committed
CUDA: add stream-based concurrency
1 parent 5d8bb90 commit 1c4d8f3

File tree

2 files changed

+326
-12
lines changed

2 files changed

+326
-12
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,32 @@ struct ggml_cuda_graph {
958958
#endif
959959
};
960960

// Bookkeeping for executing independent parts of a graph concurrently on
// multiple CUDA streams: at a fork point the main stream records fork_event and
// the side streams wait on it; at the join node each side stream records its
// per_stream_events entry and the main stream waits on all of them.
struct ggml_cuda_concurrent_event {
    std::vector<cudaEvent_t> per_stream_events; // one event per concurrent stream, recorded at the join
    cudaEvent_t fork_event = nullptr;           // recorded on the main stream at the fork point
    cudaEvent_t join_event = nullptr;           // recorded once the streams have been joined

    int n_streams = 0; // number of concurrent streams covered by this event set
    // maps a graph node to the index of the stream it should be computed on
    std::unordered_map<const ggml_tensor *, int> stream_mapping;

    // node at which the concurrent streams merge back into one; nullptr until assigned
    const ggml_tensor * join_node = nullptr;

    ggml_cuda_concurrent_event() = default;

    // Creates one event per stream plus the fork/join events. Timing is disabled
    // on all events because they are used purely for inter-stream ordering.
    explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
        per_stream_events.resize(n_streams);

        for (size_t i = 0; i < per_stream_events.size(); ++i) {
            cudaEventCreateWithFlags(&per_stream_events[i], cudaEventDisableTiming);
        }

        cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming);
        cudaEventCreateWithFlags(&join_event, cudaEventDisableTiming);
    }

    // NOTE(review): the cudaEventCreateWithFlags return codes are ignored and the
    // events are never passed to cudaEventDestroy, so they leak driver resources.
    // A plain destructor is unsafe here without move semantics (instances are
    // stored by value in an unordered_map, so temporaries would double-destroy
    // the events) — TODO confirm where teardown should happen.
};

// Per-context map from a fork node to the concurrency bookkeeping for its subgraph.
using ggml_cuda_stream_context = std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event>;
961987
struct ggml_backend_cuda_context {
962988
int device;
963989
std::string name;
@@ -968,11 +994,15 @@ struct ggml_backend_cuda_context {
968994

969995
std::unique_ptr<ggml_cuda_graph> cuda_graph;
970996

997+
int curr_stream_no = 0;
998+
971999
explicit ggml_backend_cuda_context(int device) :
9721000
device(device),
9731001
name(GGML_CUDA_NAME + std::to_string(device)) {
9741002
}
9751003

1004+
ggml_cuda_stream_context concurrent_stream_context;
1005+
9761006
~ggml_backend_cuda_context();
9771007

9781008
cudaStream_t stream(int device, int stream) {
@@ -983,9 +1013,9 @@ struct ggml_backend_cuda_context {
9831013
return streams[device][stream];
9841014
}
9851015

986-
cudaStream_t stream() {
987-
return stream(device, 0);
988-
}
1016+
cudaStream_t stream() { return stream(device, curr_stream_no); }
1017+
1018+
ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
9891019

9901020
cublasHandle_t cublas_handle(int device) {
9911021
if (cublas_handles[device] == nullptr) {
@@ -1001,15 +1031,15 @@ struct ggml_backend_cuda_context {
10011031
}
10021032

10031033
// pool
1004-
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
1034+
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
10051035

1006-
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
1036+
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
10071037

10081038
ggml_cuda_pool & pool(int device) {
1009-
if (pools[device] == nullptr) {
1010-
pools[device] = new_pool_for_device(device);
1039+
if (pools[device][curr_stream_no] == nullptr) {
1040+
pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
10111041
}
1012-
return *pools[device];
1042+
return *pools[device][curr_stream_no];
10131043
}
10141044

10151045
ggml_cuda_pool & pool() {

0 commit comments

Comments
 (0)