@@ -958,6 +958,32 @@ struct ggml_cuda_graph {
958958#endif
959959};
960960
// Fork/join synchronization state used to run independent graph nodes on
// multiple CUDA streams concurrently.  Holds one event per secondary stream
// plus a fork event and a join event; stream_mapping assigns graph nodes to
// stream indices and join_node marks where execution re-joins a single stream.
// NOTE(review): exact record/wait ordering of these events is defined by the
// code that consumes this struct — not visible here, confirm against callers.
struct ggml_cuda_concurrent_event {
    std::vector<cudaEvent_t> per_stream_events; // one event per concurrent stream
    cudaEvent_t fork_event = nullptr;           // nullptr until the (int) ctor runs
    cudaEvent_t join_event = nullptr;           // nullptr until the (int) ctor runs

    int n_streams = 0; // number of streams covered by per_stream_events

    // maps a graph node to the index of the stream it should execute on
    std::unordered_map<const ggml_tensor *, int> stream_mapping;

    // node at which the concurrent section ends and streams are joined
    const ggml_tensor * join_node = nullptr;

    ggml_cuda_concurrent_event() = default;

    // Creates n_streams per-stream events plus the fork/join pair, all with
    // timing disabled (cheapest event type, sufficient for synchronization).
    // NOTE(review): cudaEventCreateWithFlags return codes are not checked and
    // no destructor releases the events — confirm the owning context is meant
    // to keep them alive for its whole lifetime.
    explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
        per_stream_events.resize(n_streams);

        for (size_t i = 0; i < per_stream_events.size(); ++i) {
            cudaEventCreateWithFlags(&per_stream_events[i], cudaEventDisableTiming);
        }

        cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming);
        cudaEventCreateWithFlags(&join_event, cudaEventDisableTiming);
    }
};

// Maps the first node of a concurrent section to its fork/join state,
// one map per backend context.
using ggml_cuda_stream_context = std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event>;
986+
961987struct ggml_backend_cuda_context {
962988 int device;
963989 std::string name;
@@ -968,11 +994,15 @@ struct ggml_backend_cuda_context {
968994
969995 std::unique_ptr<ggml_cuda_graph> cuda_graph;
970996
997+ int curr_stream_no = 0 ;
998+
971999 explicit ggml_backend_cuda_context (int device) :
9721000 device(device),
9731001 name(GGML_CUDA_NAME + std::to_string(device)) {
9741002 }
9751003
1004+ ggml_cuda_stream_context concurrent_stream_context;
1005+
9761006 ~ggml_backend_cuda_context ();
9771007
9781008 cudaStream_t stream (int device, int stream) {
@@ -983,9 +1013,9 @@ struct ggml_backend_cuda_context {
9831013 return streams[device][stream];
9841014 }
9851015
986- cudaStream_t stream () {
987- return stream (device, 0 );
988- }
1016+ cudaStream_t stream () { return stream (device, curr_stream_no); }
1017+
1018+ ggml_cuda_stream_context & stream_context () { return concurrent_stream_context; }
9891019
9901020 cublasHandle_t cublas_handle (int device) {
9911021 if (cublas_handles[device] == nullptr ) {
@@ -1001,15 +1031,15 @@ struct ggml_backend_cuda_context {
10011031 }
10021032
10031033 // pool
1004- std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
1034+ std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] ;
10051035
1006- static std::unique_ptr<ggml_cuda_pool> new_pool_for_device (int device);
1036+ static std::unique_ptr<ggml_cuda_pool> new_pool_for_device (int device, int stream_no );
10071037
10081038 ggml_cuda_pool & pool (int device) {
1009- if (pools[device] == nullptr ) {
1010- pools[device] = new_pool_for_device (device);
1039+ if (pools[device][curr_stream_no] == nullptr ) {
1040+ pools[device][curr_stream_no] = new_pool_for_device (device, curr_stream_no );
10111041 }
1012- return *pools[device];
1042+ return *pools[device][curr_stream_no] ;
10131043 }
10141044
10151045 ggml_cuda_pool & pool () {
0 commit comments