Skip to content

Commit cf05fb5

Browse files
authored
Add wait to MKL calls (#518)
1 parent 115a10f commit cf05fb5

File tree

4 files changed

+774
-2
lines changed

4 files changed

+774
-2
lines changed

deps/generate_interfaces.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,13 +447,17 @@ function generate_cpp(library::String, filename::Vector{String}, output::String;
447447
write(oneapi_cpp, "extern \"C\" $header {\n")
448448
if template
449449
type = version_types[version]
450-
!occursin("scratchpad_size", name) && write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name<$type>($parameters, {});\n")
451-
occursin("scratchpad_size", name) && write(oneapi_cpp, " int64_t scratchpad_size = oneapi::mkl::$library::$variant$name<$type>($parameters);\n")
450+
!occursin("scratchpad_size", name) && write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name<$type>($parameters, {});\n device_queue->val.wait_and_throw();\n")
451+
occursin("scratchpad_size", name) && write(oneapi_cpp, " int64_t scratchpad_size = oneapi::mkl::$library::$variant$name<$type>($parameters);\n device_queue->val.wait_and_throw();\n")
452+
# !occursin("scratchpad_size", name) && write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name<$type>($parameters, {});\n")
453+
# occursin("scratchpad_size", name) && write(oneapi_cpp, " int64_t scratchpad_size = oneapi::mkl::$library::$variant$name<$type>($parameters);\n")
452454
else
453455
if !(name ∈ void_output)
454456
write(oneapi_cpp, " auto status = oneapi::mkl::$library::$variant$name($parameters, {});\n")
457+
occursin("device_queue", parameters) && write(oneapi_cpp, " device_queue->val.wait_and_throw();\n")
455458
else
456459
write(oneapi_cpp, " oneapi::mkl::$library::$variant$name($parameters);\n")
460+
occursin("device_queue", parameters) && write(oneapi_cpp, " device_queue->val.wait_and_throw();\n")
457461
end
458462
end
459463
if occursin("scratchpad_size", name)

deps/onemkl_epilogue.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
extern "C" int onemklXsparse_matmat(syclQueue_t device_queue, matrix_handle_t A, matrix_handle_t B, matrix_handle_t C, onemklMatmatRequest req, matmat_descr_t descr, int64_t *sizeTempBuffer, void *tempBuffer) {
22
auto status = oneapi::mkl::sparse::matmat(device_queue->val, (oneapi::mkl::sparse::matrix_handle_t) A, (oneapi::mkl::sparse::matrix_handle_t) B, (oneapi::mkl::sparse::matrix_handle_t) C, convert(req), (oneapi::mkl::sparse::matmat_descr_t) descr, sizeTempBuffer, tempBuffer, {});
3+
device_queue->val.wait_and_throw();
34
return 0;
45
}
56

deps/onemkl_prologue.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,13 +393,15 @@ extern "C" int onemklHgemm_batch(syclQueue_t device_queue, onemklTranspose trans
393393
int64_t *ldb, uint16_t *beta, short **c,
394394
int64_t *ldc, int64_t group_count, int64_t *group_size) {
395395
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
396+
device_queue->val.wait_and_throw();
396397
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
397398
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
398399
m, n, k, reinterpret_cast<sycl::half *>(alpha),
399400
reinterpret_cast<const sycl::half **>(&a[0]), lda,
400401
reinterpret_cast<const sycl::half **>(&b[0]), ldb,
401402
reinterpret_cast<sycl::half *>(beta), reinterpret_cast<sycl::half **>(&c[0]),
402403
ldc, group_count, group_size, {});
404+
device_queue->val.wait_and_throw();
403405
return 0;
404406
}
405407

@@ -410,13 +412,15 @@ extern "C" int onemklSgemm_batch(syclQueue_t device_queue, onemklTranspose trans
410412
int64_t *ldb, float *beta, float **c,
411413
int64_t *ldc, int64_t group_count, int64_t *group_size) {
412414
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
415+
device_queue->val.wait_and_throw();
413416
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
414417
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
415418
m, n, k, alpha,
416419
(const float **)&a[0], lda,
417420
(const float **)&b[0], ldb,
418421
beta, &c[0], ldc,
419422
group_count, group_size, {});
423+
device_queue->val.wait_and_throw();
420424
return 0;
421425
}
422426

@@ -427,13 +431,15 @@ extern "C" int onemklDgemm_batch(syclQueue_t device_queue, onemklTranspose trans
427431
int64_t *ldb, double *beta, double **c,
428432
int64_t *ldc, int64_t group_count, int64_t *group_size) {
429433
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
434+
device_queue->val.wait_and_throw();
430435
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
431436
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
432437
m, n, k, alpha,
433438
(const double **)&a[0], lda,
434439
(const double **)&b[0], ldb,
435440
beta, &c[0], ldc,
436441
group_count, group_size, {});
442+
device_queue->val.wait_and_throw();
437443
return 0;
438444
}
439445

@@ -445,6 +451,7 @@ extern "C" int onemklCgemm_batch(syclQueue_t device_queue, onemklTranspose trans
445451
int64_t *ldb, float _Complex *beta, float _Complex **c,
446452
int64_t *ldc, int64_t group_count, int64_t *group_size) {
447453
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
454+
device_queue->val.wait_and_throw();
448455
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
449456
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
450457
m, n, k, reinterpret_cast<std::complex<float> *>(alpha),
@@ -455,6 +462,7 @@ extern "C" int onemklCgemm_batch(syclQueue_t device_queue, onemklTranspose trans
455462
reinterpret_cast<std::complex<float> *>(beta),
456463
reinterpret_cast<std::complex<float> **>(&c[0]), ldc,
457464
group_count, group_size, {});
465+
device_queue->val.wait_and_throw();
458466
return 0;
459467
}
460468

@@ -467,6 +475,7 @@ extern "C" int onemklZgemm_batch(syclQueue_t device_queue, onemklTranspose trans
467475
double _Complex **c,
468476
int64_t *ldc, int64_t group_count, int64_t *group_size) {
469477
gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb);
478+
device_queue->val.wait_and_throw();
470479
auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val,
471480
&gemmInfo.m_transa[0], &gemmInfo.m_transb[0],
472481
m, n, k, reinterpret_cast<std::complex<double> *>(alpha),
@@ -477,6 +486,7 @@ extern "C" int onemklZgemm_batch(syclQueue_t device_queue, onemklTranspose trans
477486
reinterpret_cast<std::complex<double> *>(beta),
478487
reinterpret_cast<std::complex<double> **>(&c[0]), ldc,
479488
group_count, group_size, {});
489+
device_queue->val.wait_and_throw();
480490
return 0;
481491
}
482492

@@ -487,12 +497,14 @@ extern "C" int onemklStrsm_batch(syclQueue_t device_queue, onemklSide left_right
487497
int64_t group_count, int64_t *group_size) {
488498
trsmBatchInfo trsmInfo(device_queue, left_right, upper_lower, transa,
489499
unit_diag, group_count);
500+
device_queue->val.wait_and_throw();
490501

491502
auto status = oneapi::mkl::blas::column_major::trsm_batch(device_queue->val,
492503
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
493504
&trsmInfo.m_transa[0], &trsmInfo.m_unitdiag[0],
494505
m, n, alpha, (const float **)&a[0], lda,
495506
&b[0], ldb, group_count, group_size, {});
507+
device_queue->val.wait_and_throw();
496508
return 0;
497509
}
498510

@@ -504,12 +516,14 @@ extern "C" int onemklDtrsm_batch(syclQueue_t device_queue, onemklSide left_right
504516
int64_t *group_size) {
505517
trsmBatchInfo trsmInfo(device_queue, left_right, upper_lower, transa,
506518
unit_diag, group_count);
519+
device_queue->val.wait_and_throw();
507520

508521
auto status = oneapi::mkl::blas::column_major::trsm_batch(device_queue->val,
509522
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
510523
&trsmInfo.m_transa[0], &trsmInfo.m_unitdiag[0],
511524
m, n, alpha, (const double **)&a[0], lda, &b[0],
512525
ldb, group_count, group_size, {});
526+
device_queue->val.wait_and_throw();
513527
return 0;
514528
}
515529

@@ -521,6 +535,7 @@ extern "C" int onemklCtrsm_batch(syclQueue_t device_queue, onemklSide left_right
521535
int64_t group_count, int64_t *group_size) {
522536
trsmBatchInfo trsmInfo(device_queue, left_right, upper_lower, transa,
523537
unit_diag, group_count);
538+
device_queue->val.wait_and_throw();
524539

525540
auto status = oneapi::mkl::blas::column_major::trsm_batch(device_queue->val,
526541
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
@@ -529,6 +544,7 @@ extern "C" int onemklCtrsm_batch(syclQueue_t device_queue, onemklSide left_right
529544
reinterpret_cast<const std::complex<float> **>(&a[0]),
530545
lda, reinterpret_cast<std::complex<float> **>(&b[0]),
531546
ldb, group_count, group_size, {});
547+
device_queue->val.wait_and_throw();
532548
return 0;
533549
}
534550

@@ -540,6 +556,7 @@ extern "C" int onemklZtrsm_batch(syclQueue_t device_queue, onemklSide left_right
540556
int64_t group_count, int64_t *group_size) {
541557
trsmBatchInfo trsmInfo(device_queue, left_right,
542558
upper_lower, transa, unit_diag, group_count);
559+
device_queue->val.wait_and_throw();
543560

544561
auto status = oneapi::mkl::blas::column_major::trsm_batch(device_queue->val,
545562
&trsmInfo.m_leftright[0], &trsmInfo.m_upperlower[0],
@@ -548,5 +565,6 @@ extern "C" int onemklZtrsm_batch(syclQueue_t device_queue, onemklSide left_right
548565
reinterpret_cast<const std::complex<double> **>(&a[0]),
549566
lda, reinterpret_cast<std::complex<double> **>(&b[0]),
550567
ldb, group_count, group_size, {});
568+
device_queue->val.wait_and_throw();
551569
return 0;
552570
}

0 commit comments

Comments
 (0)