@@ -393,13 +393,15 @@ extern "C" int onemklHgemm_batch(syclQueue_t device_queue, onemklTranspose trans
393
393
int64_t *ldb, uint16_t *beta, short **c,
394
394
int64_t *ldc, int64_t group_count, int64_t *group_size) {
395
395
gemmBatchInfo gemmInfo (device_queue, group_count, transa, transb);
396
+ device_queue->val .wait_and_throw ();
396
397
auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val ,
397
398
&gemmInfo.m_transa [0 ], &gemmInfo.m_transb [0 ],
398
399
m, n, k, reinterpret_cast <sycl::half *>(alpha),
399
400
reinterpret_cast <const sycl::half **>(&a[0 ]), lda,
400
401
reinterpret_cast <const sycl::half **>(&b[0 ]), ldb,
401
402
reinterpret_cast <sycl::half *>(beta), reinterpret_cast <sycl::half **>(&c[0 ]),
402
403
ldc, group_count, group_size, {});
404
+ device_queue->val .wait_and_throw ();
403
405
return 0 ;
404
406
}
405
407
@@ -410,13 +412,15 @@ extern "C" int onemklSgemm_batch(syclQueue_t device_queue, onemklTranspose trans
410
412
int64_t *ldb, float *beta, float **c,
411
413
int64_t *ldc, int64_t group_count, int64_t *group_size) {
412
414
gemmBatchInfo gemmInfo (device_queue, group_count, transa, transb);
415
+ device_queue->val .wait_and_throw ();
413
416
auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val ,
414
417
&gemmInfo.m_transa [0 ], &gemmInfo.m_transb [0 ],
415
418
m, n, k, alpha,
416
419
(const float **)&a[0 ], lda,
417
420
(const float **)&b[0 ], ldb,
418
421
beta, &c[0 ], ldc,
419
422
group_count, group_size, {});
423
+ device_queue->val .wait_and_throw ();
420
424
return 0 ;
421
425
}
422
426
@@ -427,13 +431,15 @@ extern "C" int onemklDgemm_batch(syclQueue_t device_queue, onemklTranspose trans
427
431
int64_t *ldb, double *beta, double **c,
428
432
int64_t *ldc, int64_t group_count, int64_t *group_size) {
429
433
gemmBatchInfo gemmInfo (device_queue, group_count, transa, transb);
434
+ device_queue->val .wait_and_throw ();
430
435
auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val ,
431
436
&gemmInfo.m_transa [0 ], &gemmInfo.m_transb [0 ],
432
437
m, n, k, alpha,
433
438
(const double **)&a[0 ], lda,
434
439
(const double **)&b[0 ], ldb,
435
440
beta, &c[0 ], ldc,
436
441
group_count, group_size, {});
442
+ device_queue->val .wait_and_throw ();
437
443
return 0 ;
438
444
}
439
445
@@ -445,6 +451,7 @@ extern "C" int onemklCgemm_batch(syclQueue_t device_queue, onemklTranspose trans
445
451
int64_t *ldb, float _Complex *beta, float _Complex **c,
446
452
int64_t *ldc, int64_t group_count, int64_t *group_size) {
447
453
gemmBatchInfo gemmInfo (device_queue, group_count, transa, transb);
454
+ device_queue->val .wait_and_throw ();
448
455
auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val ,
449
456
&gemmInfo.m_transa [0 ], &gemmInfo.m_transb [0 ],
450
457
m, n, k, reinterpret_cast <std::complex <float > *>(alpha),
@@ -455,6 +462,7 @@ extern "C" int onemklCgemm_batch(syclQueue_t device_queue, onemklTranspose trans
455
462
reinterpret_cast <std::complex <float > *>(beta),
456
463
reinterpret_cast <std::complex <float > **>(&c[0 ]), ldc,
457
464
group_count, group_size, {});
465
+ device_queue->val .wait_and_throw ();
458
466
return 0 ;
459
467
}
460
468
@@ -467,6 +475,7 @@ extern "C" int onemklZgemm_batch(syclQueue_t device_queue, onemklTranspose trans
467
475
double _Complex **c,
468
476
int64_t *ldc, int64_t group_count, int64_t *group_size) {
469
477
gemmBatchInfo gemmInfo (device_queue, group_count, transa, transb);
478
+ device_queue->val .wait_and_throw ();
470
479
auto status = oneapi::mkl::blas::column_major::gemm_batch (device_queue->val ,
471
480
&gemmInfo.m_transa [0 ], &gemmInfo.m_transb [0 ],
472
481
m, n, k, reinterpret_cast <std::complex <double > *>(alpha),
@@ -477,6 +486,7 @@ extern "C" int onemklZgemm_batch(syclQueue_t device_queue, onemklTranspose trans
477
486
reinterpret_cast <std::complex <double > *>(beta),
478
487
reinterpret_cast <std::complex <double > **>(&c[0 ]), ldc,
479
488
group_count, group_size, {});
489
+ device_queue->val .wait_and_throw ();
480
490
return 0 ;
481
491
}
482
492
@@ -487,12 +497,14 @@ extern "C" int onemklStrsm_batch(syclQueue_t device_queue, onemklSide left_right
487
497
int64_t group_count, int64_t *group_size) {
488
498
trsmBatchInfo trsmInfo (device_queue, left_right, upper_lower, transa,
489
499
unit_diag, group_count);
500
+ device_queue->val .wait_and_throw ();
490
501
491
502
auto status = oneapi::mkl::blas::column_major::trsm_batch (device_queue->val ,
492
503
&trsmInfo.m_leftright [0 ], &trsmInfo.m_upperlower [0 ],
493
504
&trsmInfo.m_transa [0 ], &trsmInfo.m_unitdiag [0 ],
494
505
m, n, alpha, (const float **)&a[0 ], lda,
495
506
&b[0 ], ldb, group_count, group_size, {});
507
+ device_queue->val .wait_and_throw ();
496
508
return 0 ;
497
509
}
498
510
@@ -504,12 +516,14 @@ extern "C" int onemklDtrsm_batch(syclQueue_t device_queue, onemklSide left_right
504
516
int64_t *group_size) {
505
517
trsmBatchInfo trsmInfo (device_queue, left_right, upper_lower, transa,
506
518
unit_diag, group_count);
519
+ device_queue->val .wait_and_throw ();
507
520
508
521
auto status = oneapi::mkl::blas::column_major::trsm_batch (device_queue->val ,
509
522
&trsmInfo.m_leftright [0 ], &trsmInfo.m_upperlower [0 ],
510
523
&trsmInfo.m_transa [0 ], &trsmInfo.m_unitdiag [0 ],
511
524
m, n, alpha, (const double **)&a[0 ], lda, &b[0 ],
512
525
ldb, group_count, group_size, {});
526
+ device_queue->val .wait_and_throw ();
513
527
return 0 ;
514
528
}
515
529
@@ -521,6 +535,7 @@ extern "C" int onemklCtrsm_batch(syclQueue_t device_queue, onemklSide left_right
521
535
int64_t group_count, int64_t *group_size) {
522
536
trsmBatchInfo trsmInfo (device_queue, left_right, upper_lower, transa,
523
537
unit_diag, group_count);
538
+ device_queue->val .wait_and_throw ();
524
539
525
540
auto status = oneapi::mkl::blas::column_major::trsm_batch (device_queue->val ,
526
541
&trsmInfo.m_leftright [0 ], &trsmInfo.m_upperlower [0 ],
@@ -529,6 +544,7 @@ extern "C" int onemklCtrsm_batch(syclQueue_t device_queue, onemklSide left_right
529
544
reinterpret_cast <const std::complex <float > **>(&a[0 ]),
530
545
lda, reinterpret_cast <std::complex <float > **>(&b[0 ]),
531
546
ldb, group_count, group_size, {});
547
+ device_queue->val .wait_and_throw ();
532
548
return 0 ;
533
549
}
534
550
@@ -540,6 +556,7 @@ extern "C" int onemklZtrsm_batch(syclQueue_t device_queue, onemklSide left_right
540
556
int64_t group_count, int64_t *group_size) {
541
557
trsmBatchInfo trsmInfo (device_queue, left_right,
542
558
upper_lower, transa, unit_diag, group_count);
559
+ device_queue->val .wait_and_throw ();
543
560
544
561
auto status = oneapi::mkl::blas::column_major::trsm_batch (device_queue->val ,
545
562
&trsmInfo.m_leftright [0 ], &trsmInfo.m_upperlower [0 ],
@@ -548,5 +565,6 @@ extern "C" int onemklZtrsm_batch(syclQueue_t device_queue, onemklSide left_right
548
565
reinterpret_cast <const std::complex <double > **>(&a[0 ]),
549
566
lda, reinterpret_cast <std::complex <double > **>(&b[0 ]),
550
567
ldb, group_count, group_size, {});
568
+ device_queue->val .wait_and_throw ();
551
569
return 0 ;
552
570
}
0 commit comments