@@ -235,15 +235,15 @@ torch::Tensor coo_spmm(torch::Tensor const &rows, torch::Tensor const &cols,
235235 CUDA_CHECK (cudaMemcpy (sorted_val_ptr, values_ptr, nnz * sizeof (scalar_t ),
236236 cudaMemcpyDeviceToDevice));
237237
238- thrust::sort_by_key (thrust::device, //
239- sorted_row_ptr, // key begin
240- sorted_row_ptr + nnz, // key end
241- thrust::make_zip_iterator ( // value begin
242- thrust::make_tuple ( //
243- sorted_col_ptr, //
244- sorted_val_ptr //
245- ) //
246- ));
238+ THRUST_CHECK ( thrust::sort_by_key (thrust::device, //
239+ sorted_row_ptr, // key begin
240+ sorted_row_ptr + nnz, // key end
241+ thrust::make_zip_iterator ( // value begin
242+ thrust::make_tuple ( //
243+ sorted_col_ptr, //
244+ sorted_val_ptr //
245+ ) //
246+ ) ));
247247 LOG_DEBUG (" sorted row" , cudaDeviceSynchronize ());
248248 } else {
249249 sorted_row_ptr = row_indices_ptr;
@@ -481,10 +481,10 @@ coo_spmm_average(torch::Tensor const &rows, torch::Tensor const &cols,
481481 CUDA_CHECK (cudaMemcpy (sorted_col_ptr, col_indices_ptr,
482482 nnz * sizeof (th_int_type), cudaMemcpyDeviceToDevice));
483483
484- thrust::sort_by_key (thrust::device, //
485- sorted_row_ptr, // key begin
486- sorted_row_ptr + nnz, // key end
487- sorted_col_ptr);
484+ THRUST_CHECK ( thrust::sort_by_key (thrust::device, //
485+ sorted_row_ptr, // key begin
486+ sorted_row_ptr + nnz, // key end
487+ sorted_col_ptr) );
488488
489489 // ///////////////////////////////////////////////////////////////////////
490490 // Create vals
@@ -496,21 +496,20 @@ coo_spmm_average(torch::Tensor const &rows, torch::Tensor const &cols,
496496 (scalar_t *)c10::cuda::CUDACachingAllocator::raw_alloc (
497497 nnz * sizeof (scalar_t ));
498498 torch::Tensor ones = at::ones ({nnz}, mat2.options ());
499-
500- // reduce by key
501- auto end = thrust::reduce_by_key (
502- thrust::device, // policy
503- sorted_row_ptr, // key begin
504- sorted_row_ptr + nnz, // key end
505- reinterpret_cast <scalar_t *>(ones.data_ptr ()), // value begin
506- unique_row_ptr, // key out begin
507- reduced_val_ptr // value out begin
508- );
509-
510- int num_unique_keys = end.first - unique_row_ptr;
511- LOG_DEBUG (" Num unique keys:" , num_unique_keys);
512-
513- // Create values
499+ int num_unique_keys;
500+ try {
501+ // reduce by key
502+ auto end = thrust::reduce_by_key (
503+ thrust::device, // policy
504+ sorted_row_ptr, // key begin
505+ sorted_row_ptr + nnz, // key end
506+ reinterpret_cast <scalar_t *>(ones.data_ptr ()), // value begin
507+ unique_row_ptr, // key out begin
508+ reduced_val_ptr // value out begin
509+ );
510+ num_unique_keys = end.first - unique_row_ptr;
511+ LOG_DEBUG (" Num unique keys:" , num_unique_keys);
512+ } THRUST_CATCH;
514513
515514 // Copy the results to the correct output
516515 inverse_val<th_int_type, scalar_t >
0 commit comments