Skip to content

Conversation

amontoison
Copy link
Member

@amontoison amontoison commented Sep 16, 2025

@michel2323 I started to work on the support of oneSparseMatrixCSC.
We can work on it together if you want.

@amontoison amontoison marked this pull request as ready for review September 16, 2025 21:03
Copy link
Contributor

github-actions bot commented Sep 16, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/lib/mkl/array.jl b/lib/mkl/array.jl
index 0db254e..1931d4d 100644
--- a/lib/mkl/array.jl
+++ b/lib/mkl/array.jl
@@ -18,7 +18,7 @@ mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
     colPtr::oneVector{Ti}
     rowVal::oneVector{Ti}
     nzVal::oneVector{Tv}
-    dims::NTuple{2,Int}
+    dims::NTuple{2, Int}
     nnz::Ti
 end
 
@@ -46,7 +46,7 @@ SparseArrays.nnz(A::oneAbstractSparseMatrix) = A.nnz
 SparseArrays.nonzeros(A::oneAbstractSparseMatrix) = A.nzVal
 
 for (gpu, cpu) in [:oneSparseMatrixCSR => :SparseMatrixCSC,
-                   :oneSparseMatrixCSC => :SparseMatrixCSC,
+        :oneSparseMatrixCSC => :SparseMatrixCSC,
                    :oneSparseMatrixCOO => :SparseMatrixCSC]
     @eval Base.show(io::IOContext, x::$gpu) =
         show(io, $cpu(x))
diff --git a/lib/mkl/interfaces.jl b/lib/mkl/interfaces.jl
index 343131d..7ccc57e 100644
--- a/lib/mkl/interfaces.jl
+++ b/lib/mkl/interfaces.jl
@@ -7,9 +7,9 @@ function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::
     sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
 end
 
-function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasReal
+function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where {T <: BlasReal}
     tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
-    sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
+    return sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
 end
 
 function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasFloat
@@ -18,16 +18,16 @@ function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseM
     sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
 end
 
-function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasReal
+function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where {T <: BlasReal}
     tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
     tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
-    sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
+    return sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
 end
 
-function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneVector{T}) where T <: BlasFloat
-    sparse_trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
+function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneVector{T}) where {T <: BlasFloat}
+    return sparse_trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
 end
 
-function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where T <: BlasFloat
-    sparse_trsm!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', 'N', isunitc, one(T), A, B, C)
+function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where {T <: BlasFloat}
+    return sparse_trsm!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', 'N', isunitc, one(T), A, B, C)
 end
diff --git a/lib/mkl/wrappers_sparse.jl b/lib/mkl/wrappers_sparse.jl
index 360c00b..bb9df18 100644
--- a/lib/mkl/wrappers_sparse.jl
+++ b/lib/mkl/wrappers_sparse.jl
@@ -46,7 +46,7 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data   , :Float32   , :Int3
             nnzA = length(A.nzval)
             queue = global_queue(context(nzVal), device())
             $fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal)  # CSC of A is CSR of Aᵀ
-            dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m,n), nnzA)
+            dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, (m, n), nnzA)
             finalizer(sparse_release_matrix_handle, dA)
             return dA
         end
@@ -122,19 +122,23 @@ for SparseMatrix in (:oneSparseMatrixCSR, :oneSparseMatrixCOO)
 end
 
 for SparseMatrix in (:oneSparseMatrixCSC,)
-    for (fname, elty) in ((:onemklSsparse_gemv, :Float32),
-                          (:onemklDsparse_gemv, :Float64))
+    for (fname, elty) in (
+            (:onemklSsparse_gemv, :Float32),
+            (:onemklDsparse_gemv, :Float64),
+        )
         @eval begin
-            function sparse_gemv!(trans::Char,
-                                  alpha::Number,
-                                  A::$SparseMatrix{$elty},
-                                  x::oneStridedVector{$elty},
-                                  beta::Number,
-                                  y::oneStridedVector{$elty})
+            function sparse_gemv!(
+                    trans::Char,
+                    alpha::Number,
+                    A::$SparseMatrix{$elty},
+                    x::oneStridedVector{$elty},
+                    beta::Number,
+                    y::oneStridedVector{$elty}
+                )
 
                 queue = global_queue(context(x), device())
                 $fname(sycl_queue(queue), flip_trans(trans), alpha, A.handle, x, beta, y)
-                y
+                return y
             end
         end
     end
@@ -187,27 +191,31 @@ function sparse_optimize_gemm!(trans::Char, transB::Char, nrhs::Int, A::oneSpars
     return A
 end
 
-for (fname, elty) in ((:onemklSsparse_gemm, :Float32),
-                      (:onemklDsparse_gemm, :Float64))
+for (fname, elty) in (
+        (:onemklSsparse_gemm, :Float32),
+        (:onemklDsparse_gemm, :Float64),
+    )
     @eval begin
-        function sparse_gemm!(transa::Char,
-                              transb::Char,
-                              alpha::Number,
-                              A::oneSparseMatrixCSC{$elty},
-                              B::oneStridedMatrix{$elty},
-                              beta::Number,
-                              C::oneStridedMatrix{$elty})
+        function sparse_gemm!(
+                transa::Char,
+                transb::Char,
+                alpha::Number,
+                A::oneSparseMatrixCSC{$elty},
+                B::oneStridedMatrix{$elty},
+                beta::Number,
+                C::oneStridedMatrix{$elty}
+            )
 
             mB, nB = size(B)
             mC, nC = size(C)
             (nB != nC) && (transb == 'N') && throw(ArgumentError("B and C must have the same number of columns."))
             (mB != nC) && (transb != 'N') && throw(ArgumentError("Bᵀ and C must have the same number of columns."))
             nrhs = size(B, 2)
-            ldb = max(1,stride(B,2))
-            ldc = max(1,stride(C,2))
+            ldb = max(1, stride(B, 2))
+            ldc = max(1, stride(C, 2))
             queue = global_queue(context(C), device())
             $fname(sycl_queue(queue), 'C', flip_trans(transa), transb, alpha, A.handle, B, nrhs, ldb, beta, C, ldc)
-            C
+            return C
         end
     end
 end
@@ -243,19 +251,23 @@ for (fname, elty) in ((:onemklSsparse_symv, :Float32),
     end
 end
 
-for (fname, elty) in ((:onemklSsparse_symv, :Float32),
-                      (:onemklDsparse_symv, :Float64))
+for (fname, elty) in (
+        (:onemklSsparse_symv, :Float32),
+        (:onemklDsparse_symv, :Float64),
+    )
     @eval begin
-        function sparse_symv!(uplo::Char,
-                              alpha::Number,
-                              A::oneSparseMatrixCSC{$elty},
-                              x::oneStridedVector{$elty},
-                              beta::Number,
-                              y::oneStridedVector{$elty})
+        function sparse_symv!(
+                uplo::Char,
+                alpha::Number,
+                A::oneSparseMatrixCSC{$elty},
+                x::oneStridedVector{$elty},
+                beta::Number,
+                y::oneStridedVector{$elty}
+            )
 
             queue = global_queue(context(y), device())
             $fname(sycl_queue(queue), flip_uplo(uplo), alpha, A.handle, x, beta, y)
-            y
+            return y
         end
     end
 end
diff --git a/test/onemkl.jl b/test/onemkl.jl
index 95c0ed3..ef668dd 100644
--- a/test/onemkl.jl
+++ b/test/onemkl.jl
@@ -3,7 +3,7 @@ if Sys.iswindows()
 else
 
 using oneAPI
-using oneAPI.oneMKL: band, bandex, oneSparseMatrixCSR, oneSparseMatrixCOO, oneSparseMatrixCSC
+    using oneAPI.oneMKL: band, bandex, oneSparseMatrixCSR, oneSparseMatrixCOO, oneSparseMatrixCSC
 
 using SparseArrays
 using LinearAlgebra
@@ -1078,7 +1078,7 @@ end
 
 @testset "SPARSE" begin
     @testset "$T" for T in intersect(eltypes, [Float32, Float64, ComplexF32, ComplexF64])
-        ε = sqrt(eps(real(T)))
+            ε = sqrt(eps(real(T)))
 
         @testset "oneSparseMatrixCSR" begin
             for S in (Int32, Int64)
@@ -1090,16 +1090,16 @@ end
             end
         end
 
-        @testset "oneSparseMatrixCSC" begin
-            (T isa Complex) && continue
-            for S in (Int32, Int64)
-                A = sprand(T, 20, 10, 0.5)
-                A = SparseMatrixCSC{T, S}(A)
-                B = oneSparseMatrixCSC(A)
-                A2 = SparseMatrixCSC(B)
-                @test A == A2
+            @testset "oneSparseMatrixCSC" begin
+                (T isa Complex) && continue
+                for S in (Int32, Int64)
+                    A = sprand(T, 20, 10, 0.5)
+                    A = SparseMatrixCSC{T, S}(A)
+                    B = oneSparseMatrixCSC(A)
+                    A2 = SparseMatrixCSC(B)
+                    @test A == A2
+                end
             end
-        end
 
         @testset "oneSparseMatrixCOO" begin
             for S in (Int32, Int64)
@@ -1112,9 +1112,9 @@ end
         end
 
         @testset "sparse gemv" begin
-            @testset  "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCOO, oneSparseMatrixCSR, oneSparseMatrixCSC)
+                @testset  "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCOO, oneSparseMatrixCSR, oneSparseMatrixCSC)
                 @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
-                    (T <: Complex) && (SparseMatrix == oneSparseMatrixCSC) && continue
+                        (T <: Complex) && (SparseMatrix == oneSparseMatrixCSC) && continue
                     A = sprand(T, 20, 10, 0.5)
                     x = transa == 'N' ? rand(T, 10) : rand(T, 20)
                     y = transa == 'N' ? rand(T, 20) : rand(T, 10)
@@ -1127,142 +1127,146 @@ end
                     beta = rand(T)
                     oneMKL.sparse_optimize_gemv!(transa, dA)
                     oneMKL.sparse_gemv!(transa, alpha, dA, dx, beta, dy)
-                    @test isapprox(alpha * opa(A) * x + beta * y, collect(dy), atol=ε)
+                        @test isapprox(alpha * opa(A) * x + beta * y, collect(dy), atol = ε)
                 end
             end
         end
 
         @testset "sparse gemm" begin
-            @testset  "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC)
-                @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
-                    @testset "transb = $transb" for (transb, opb) in [('N', identity), ('T', transpose), ('C', adjoint)]
-                        (transb == 'N') || continue
-                        (T <: Complex) && (SparseMatrix == oneSparseMatrixCSC) && continue
-                        A = sprand(T, 10, 10, 0.5)
-                        B = transb == 'N' ? rand(T, 10, 2) : rand(T, 2, 10)
-                        C = rand(T, 10, 2)
+                @testset  "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC)
+                    @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
+                        @testset "transb = $transb" for (transb, opb) in [('N', identity), ('T', transpose), ('C', adjoint)]
+                            (transb == 'N') || continue
+                            (T <: Complex) && (SparseMatrix == oneSparseMatrixCSC) && continue
+                            A = sprand(T, 10, 10, 0.5)
+                            B = transb == 'N' ? rand(T, 10, 2) : rand(T, 2, 10)
+                            C = rand(T, 10, 2)
 
-                        dA = SparseMatrix(A)
-                        dB = oneMatrix{T}(B)
-                        dC = oneMatrix{T}(C)
+                            dA = SparseMatrix(A)
+                            dB = oneMatrix{T}(B)
+                            dC = oneMatrix{T}(C)
 
-                        alpha = rand(T)
-                        beta = rand(T)
-                        oneMKL.sparse_optimize_gemm!(transa, dA)
-                        oneMKL.sparse_gemm!(transa, transb, alpha, dA, dB, beta, dC)
-                        @test isapprox(alpha * opa(A) * opb(B) + beta * C, collect(dC), atol=ε)
+                            alpha = rand(T)
+                            beta = rand(T)
+                            oneMKL.sparse_optimize_gemm!(transa, dA)
+                            oneMKL.sparse_gemm!(transa, transb, alpha, dA, dB, beta, dC)
+                            @test isapprox(alpha * opa(A) * opb(B) + beta * C, collect(dC), atol = ε)
 
-                        oneMKL.sparse_optimize_gemm!(transa, transb, 2, dA)
-                    end
+                            oneMKL.sparse_optimize_gemm!(transa, transb, 2, dA)
+                        end
                 end
             end
         end
 
         @testset "sparse symv" begin
-            @testset  "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC)
-                @testset "uplo = $uplo" for uplo in ('L', 'U')
-                    (T <: Complex) && (SparseMatrix == oneSparseMatrixCSC) && continue
+                @testset  "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR, oneSparseMatrixCSC)
+                    @testset "uplo = $uplo" for uplo in ('L', 'U')
+                        (T <: Complex) && (SparseMatrix == oneSparseMatrixCSC) && continue
                     A = sprand(T, 10, 10, 0.5)
-                    A = A + transpose(A)
+                        A = A + transpose(A)
                     x = rand(T, 10)
                     y = rand(T, 10)
 
-                    dA = uplo == 'L' ? SparseMatrix(A |> tril) : SparseMatrix(A |> triu)
+                        dA = uplo == 'L' ? SparseMatrix(A |> tril) : SparseMatrix(A |> triu)
                     dx = oneVector{T}(x)
                     dy = oneVector{T}(y)
 
                     alpha = rand(T)
                     beta = rand(T)
-                    oneMKL.sparse_symv!(uplo, alpha, dA, dx, beta, dy)
-                    @test isapprox(alpha * A * x + beta * y, collect(dy), atol=ε)
+                        oneMKL.sparse_symv!(uplo, alpha, dA, dx, beta, dy)
+                        @test isapprox(alpha * A * x + beta * y, collect(dy), atol = ε)
                 end
             end
         end
 
-        @testset "sparse trmv" begin
-            @testset  "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR,)
-                @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
-                    for (uplo, diag, wrapper) in [('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular),
-                                                  ('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular)]
-                        (transa == 'N') || continue
-                        (T <: Complex) && (SparseMatrix == oneSparseMatrixCSC) && continue
-                        A = sprand(T, 10, 10, 0.5)
-                        x = rand(T, 10)
-                        y = rand(T, 10)
+            @testset "sparse trmv" begin
+                @testset  "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR,)
+                    @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
+                        for (uplo, diag, wrapper) in [
+                                ('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular),
+                                ('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular),
+                            ]
+                            (transa == 'N') || continue
+                            (T <: Complex) && (SparseMatrix == oneSparseMatrixCSC) && continue
+                            A = sprand(T, 10, 10, 0.5)
+                            x = rand(T, 10)
+                            y = rand(T, 10)
 
-                        B = uplo == 'L' ? tril(A) : triu(A)
-                        B = diag == 'U' ? B - Diagonal(B) + I : B
-                        dA = SparseMatrix(B)
-                        dx = oneVector{T}(x)
-                        dy = oneVector{T}(y)
+                            B = uplo == 'L' ? tril(A) : triu(A)
+                            B = diag == 'U' ? B - Diagonal(B) + I : B
+                            dA = SparseMatrix(B)
+                            dx = oneVector{T}(x)
+                            dy = oneVector{T}(y)
 
-                        alpha = rand(T)
-                        beta = rand(T)
+                            alpha = rand(T)
+                            beta = rand(T)
 
-                        oneMKL.sparse_optimize_trmv!(uplo, transa, diag, dA)
-                        oneMKL.sparse_trmv!(uplo, transa, diag, alpha, dA, dx, beta, dy)
-                        @test isapprox(alpha * wrapper(opa(A)) * x + beta * y, collect(dy), atol=ε)
-                    end
+                            oneMKL.sparse_optimize_trmv!(uplo, transa, diag, dA)
+                            oneMKL.sparse_trmv!(uplo, transa, diag, alpha, dA, dx, beta, dy)
+                            @test isapprox(alpha * wrapper(opa(A)) * x + beta * y, collect(dy), atol = ε)
+                        end
                 end
             end
         end
 
-        @testset "sparse trsv" begin
-            @testset  "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR,)
-                @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
+            @testset "sparse trsv" begin
+                @testset  "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR,)
+                    @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
                     for (uplo, diag, wrapper) in [('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular),
                                                   ('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular)]
                         (transa == 'N') || continue
-                        (T <: Complex) && (SparseMatrix == oneSparseMatrixCSC) && continue
+                            (T <: Complex) && (SparseMatrix == oneSparseMatrixCSC) && continue
                         alpha = rand(T)
                         A = rand(T, 10, 10) + I
                         A = sparse(A)
-                        x = rand(T, 10)
-                        y = rand(T, 10)
+                            x = rand(T, 10)
+                            y = rand(T, 10)
 
                         B = uplo == 'L' ? tril(A) : triu(A)
                         B = diag == 'U' ? B - Diagonal(B) + I : B
-                        dA = SparseMatrix(B)
-                        dx = oneVector{T}(x)
-                        dy = oneVector{T}(y)
-
-                        oneMKL.sparse_optimize_trsv!(uplo, transa, diag, dA)
-                        oneMKL.sparse_trsv!(uplo, transa, diag, alpha, dA, dx, dy)
-                        y = wrapper(opa(A)) \ (alpha * x)
-                        @test isapprox(y, collect(dy), atol=ε)
+                            dA = SparseMatrix(B)
+                            dx = oneVector{T}(x)
+                            dy = oneVector{T}(y)
+
+                            oneMKL.sparse_optimize_trsv!(uplo, transa, diag, dA)
+                            oneMKL.sparse_trsv!(uplo, transa, diag, alpha, dA, dx, dy)
+                            y = wrapper(opa(A)) \ (alpha * x)
+                            @test isapprox(y, collect(dy), atol = ε)
+                        end
                     end
                 end
             end
-        end
-
-        @testset "sparse trsm" begin
-            @testset  "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR,)
-                @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
-                    @testset "transx = $transx" for (transx, opx) in [('N', identity), ('T', transpose), ('C', adjoint)]
-                        (transx != 'N') && continue
-                        for (uplo, diag, wrapper) in [('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular),
-                                                      ('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular)]
-                            (transa == 'N') || continue
-                            (T <: Complex) && (SparseMatrix == oneSparseMatrixCSC) && continue
-                            alpha = rand(T)
-                            A = rand(T, 10, 10) + I
-                            A = sparse(A)
-                            X = transx == 'N' ? rand(T, 10, 4) : rand(T, 4, 10)
-                            Y = rand(T, 10, 4)
 
-                            B = uplo == 'L' ? tril(A) : triu(A)
-                            B = diag == 'U' ? B - Diagonal(B) + I : B
-                            dA = SparseMatrix(B)
-                            dX = oneMatrix{T}(X)
-                            dY = oneMatrix{T}(Y)
-
-                            oneMKL.sparse_optimize_trsm!(uplo, transa, diag, dA)
-                            oneMKL.sparse_trsm!(uplo, transa, transx, diag, alpha, dA, dX, dY)
-                            Y = wrapper(opa(A)) \ (alpha * opx(X))
-                            @test isapprox(Y, collect(dY), atol=ε)
-
-                            oneMKL.sparse_optimize_trsm!(uplo, transa, diag, 4, dA)
-                        end
+            @testset "sparse trsm" begin
+                @testset  "$SparseMatrix" for SparseMatrix in (oneSparseMatrixCSR,)
+                    @testset "transa = $transa" for (transa, opa) in [('N', identity), ('T', transpose), ('C', adjoint)]
+                        @testset "transx = $transx" for (transx, opx) in [('N', identity), ('T', transpose), ('C', adjoint)]
+                            (transx != 'N') && continue
+                            for (uplo, diag, wrapper) in [
+                                    ('L', 'N', LowerTriangular), ('L', 'U', UnitLowerTriangular),
+                                    ('U', 'N', UpperTriangular), ('U', 'U', UnitUpperTriangular),
+                                ]
+                                (transa == 'N') || continue
+                                (T <: Complex) && (SparseMatrix == oneSparseMatrixCSC) && continue
+                                alpha = rand(T)
+                                A = rand(T, 10, 10) + I
+                                A = sparse(A)
+                                X = transx == 'N' ? rand(T, 10, 4) : rand(T, 4, 10)
+                                Y = rand(T, 10, 4)
+
+                                B = uplo == 'L' ? tril(A) : triu(A)
+                                B = diag == 'U' ? B - Diagonal(B) + I : B
+                                dA = SparseMatrix(B)
+                                dX = oneMatrix{T}(X)
+                                dY = oneMatrix{T}(Y)
+
+                                oneMKL.sparse_optimize_trsm!(uplo, transa, diag, dA)
+                                oneMKL.sparse_trsm!(uplo, transa, transx, diag, alpha, dA, dX, dY)
+                                Y = wrapper(opa(A)) \ (alpha * opx(X))
+                                @test isapprox(Y, collect(dY), atol = ε)
+
+                                oneMKL.sparse_optimize_trsm!(uplo, transa, diag, 4, dA)
+                            end
                     end
                 end
             end

Copy link

codecov bot commented Sep 17, 2025

Codecov Report

❌ Patch coverage is 82.25806% with 11 lines in your changes missing coverage. Please review.
✅ Project coverage is 79.95%. Comparing base (3d3278d) to head (a8bbba2).

Files with missing lines Patch % Lines
lib/mkl/interfaces.jl 0.00% 11 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master     #526      +/-   ##
==========================================
+ Coverage   79.70%   79.95%   +0.24%     
==========================================
  Files          45       45              
  Lines        2818     2878      +60     
==========================================
+ Hits         2246     2301      +55     
- Misses        572      577       +5     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@michel2323
Copy link
Member

michel2323 commented Sep 17, 2025

Can we support the complex case too? Actually, I was working yesterday on it too 🙂 . I got the complex case for gemv and gemm working. Maybe we can merge somehow. I'll check.

@amontoison
Copy link
Member Author

amontoison commented Sep 17, 2025

We can support the complex case for gemv and gemm but we need to conjugate the vecteur x, do the product and then conjugate the output y.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants