Commits (29)
05096e7  Matmul complex POC (PawelSwider2000, Aug 8, 2025)
f70f4fc  Merge remote-tracking branch 'origin/main' into pswider/complex-matmul (PawelSwider2000, Aug 11, 2025)
e865b3f  MM kernels improvements (PawelSwider2000, Aug 26, 2025)
55dc07e  Switch to TORCH_LIBRARY makro (PawelSwider2000, Aug 28, 2025)
963531c  Refactor (PawelSwider2000, Aug 29, 2025)
ac98994  Merge branch 'main' into pswider/complex-matmul (PawelSwider2000, Aug 29, 2025)
35930d2  Complex matmul passing test unskipping (PawelSwider2000, Aug 29, 2025)
5702779  Merge branch 'main' into pswider/complex-matmul-with-tests (PawelSwider2000, Aug 29, 2025)
312d8ed  Add device guard (PawelSwider2000, Sep 1, 2025)
0e86518  Merge branch 'main' into pswider/complex-matmul (PawelSwider2000, Sep 1, 2025)
4b905ad  Merge branch 'main' into pswider/complex-matmul (CuiYifeng, Sep 3, 2025)
50291f2  Switch to column_major implementation (PawelSwider2000, Sep 15, 2025)
0a118a2  Merge branch 'main' into pswider/complex-matmul (PawelSwider2000, Sep 18, 2025)
7e652e7  Update src/ATen/native/xpu/Blas.cpp (PawelSwider2000, Sep 18, 2025)
1dca021  Switch to calling function instread of custom ops (PawelSwider2000, Sep 19, 2025)
a165c0b  Merge branch 'main' into pswider/complex-matmul (PawelSwider2000, Sep 22, 2025)
0e5ca25  Add EOF (PawelSwider2000, Sep 22, 2025)
29d2124  Merge branch 'pswider/complex-matmul-with-tests' into pswider/complex… (PawelSwider2000, Sep 22, 2025)
2d76644  Merge pull request #2 from PawelSwider2000/pswider/complex-matmul (PawelSwider2000, Sep 22, 2025)
13ee83d  Update skip_list_common.py (PawelSwider2000, Sep 22, 2025)
00e5d08  Remove test from PR (PawelSwider2000, Sep 22, 2025)
4c5d95b  Apply review comments (PawelSwider2000, Sep 23, 2025)
b369f18  Merge branch 'main' into pswider/complex-matmul (PawelSwider2000, Sep 24, 2025)
1c9cf3c  Merge branch 'pswider/complex-matmul-with-tests' into pswider/complex… (PawelSwider2000, Sep 24, 2025)
58a4ec0  Update skip_list_common.py (PawelSwider2000, Sep 25, 2025)
619c3bc  Review fixes (PawelSwider2000, Sep 25, 2025)
dc0ee24  Merge branch 'main' into pswider/complex-matmul (PawelSwider2000, Sep 26, 2025)
3217b10  Merge branch 'main' into pswider/complex-matmul (PawelSwider2000, Sep 29, 2025)
1e7b21c  Merge branch 'main' into pswider/complex-matmul (tadkrawiec, Sep 29, 2025)
306 changes: 306 additions & 0 deletions src/ATen/native/xpu/Blas.cpp
@@ -0,0 +1,306 @@
#include <ATen/ATen.h>
#include <ATen/xpu/XPUContext.h>
#include <comm/Runtime.h>
#include <oneapi/mkl/blas.hpp>
#include <torch/library.h>

#include <complex>
#include <vector>

namespace at::native {

#if defined(USE_ONEMKL_XPU)

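// Copies the kernel's result back into the caller's `out` when the kernel
// wrote into a separate contiguous temporary; a view() bridges the case
// where the two tensors hold the same number of elements in different shapes.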
at::Tensor& handle_output_copy(at::Tensor& out, const at::Tensor& result) {
if (!out.is_same(result)) {
if (out.sizes() == result.sizes()) {
out.copy_(result);
} else {
out.copy_(result.view(out.sizes()));
}
}

return out;
}

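// out = self @ mat2 for complex dtypes via oneMKL GEMM. Inputs are made
// contiguous and any lazy conjugation is materialized first, so with
// self (m x k), mat2 (k x n) and out (m x n) in row-major layout the
// leading dimensions are simply lda = k, ldb = n, ldc = n.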
template <typename T>
at::Tensor& mm_complex_out_xpu_impl(
const at::Tensor& self,
const at::Tensor& mat2,
at::Tensor& out) {
at::Tensor self_cont = self.contiguous().resolve_conj();
at::Tensor mat2_cont = mat2.contiguous().resolve_conj();
at::Tensor out_cont = out.contiguous().resolve_conj();

const int64_t m = self_cont.sizes().at(0);
const int64_t n = mat2_cont.sizes().at(1);
const int64_t k = self_cont.sizes().at(1);

constexpr std::complex<T> alpha = {T(1), T(0)};
constexpr std::complex<T> beta = {T(0), T(0)};

oneapi::mkl::blas::row_major::gemm(
c10::xpu::getCurrentXPUStream().queue(),
oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans,
m,
n,
k,
alpha,
reinterpret_cast<const std::complex<T>*>(self_cont.const_data_ptr()),
k,
reinterpret_cast<const std::complex<T>*>(mat2_cont.const_data_ptr()),
n,
beta,
reinterpret_cast<std::complex<T>*>(out_cont.data_ptr()),
n);

return handle_output_copy(out, out_cont);
}

at::Tensor& mm_complex_out_xpu(
const at::Tensor& self,
const at::Tensor& mat2,
at::Tensor& out) {
TORCH_CHECK(
self.is_complex(), "_mm_mkl.out expects self to be a complex datatype.");

AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "mm_complex_out_xpu", [&] {
using underlying_t = typename c10::scalar_value_type<scalar_t>::type;
mm_complex_out_xpu_impl<underlying_t>(self, mat2, out);
});

return out;
}

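// Batched out = self @ mat2. oneMKL's strided gemm_batch takes the stride
// between consecutive matrices of a batch, which for contiguous 3-D tensors
// is m*k (self), k*n (mat2) and m*n (out).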
template <typename T>
at::Tensor& bmm_complex_out_xpu_impl(
const at::Tensor& self,
const at::Tensor& mat2,
at::Tensor& out) {
at::Tensor self_cont = self.contiguous().resolve_conj();
at::Tensor mat2_cont = mat2.contiguous().resolve_conj();
at::Tensor out_cont = out.contiguous().resolve_conj();

const int64_t batch_size = self_cont.sizes().at(0);
const int64_t m = self_cont.sizes().at(1);
const int64_t n = mat2_cont.sizes().at(2);
const int64_t k = self_cont.sizes().at(2);

constexpr std::complex<T> alpha = {T(1), T(0)};
constexpr std::complex<T> beta = {T(0), T(0)};

oneapi::mkl::blas::row_major::gemm_batch(
c10::xpu::getCurrentXPUStream().queue(),
oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans,
m,
n,
k,
alpha,
reinterpret_cast<const std::complex<T>*>(self_cont.const_data_ptr()),
k,
m * k,
reinterpret_cast<const std::complex<T>*>(mat2_cont.const_data_ptr()),
n,
k * n,
beta,
reinterpret_cast<std::complex<T>*>(out_cont.data_ptr()),
n,
m * n,
batch_size);

return handle_output_copy(out, out_cont);
}

at::Tensor& bmm_complex_out_xpu(
const at::Tensor& self,
const at::Tensor& mat2,
at::Tensor& out) {
TORCH_CHECK(
self.is_complex(), "_bmm_mkl.out expects self to be a complex datatype.");

AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "bmm_complex_out_xpu", [&] {
using underlying_t = typename c10::scalar_value_type<scalar_t>::type;
bmm_complex_out_xpu_impl<underlying_t>(self, mat2, out);
});

return out;
}

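// out = beta * self + alpha * (mat1 @ mat2). `self` is cloned into a
// contiguous buffer (broadcast to {m, n} if necessary) and handed to GEMM
// as the C matrix, which oneMKL scales by beta and accumulates into in
// place; the result is then copied to `out`.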
template <typename T>
at::Tensor& addmm_complex_out_xpu_impl(
const Tensor& self,
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
Tensor& out) {
at::Tensor mat1_cont = mat1.contiguous().resolve_conj();
at::Tensor mat2_cont = mat2.contiguous().resolve_conj();
at::Tensor self_cont = self.contiguous().resolve_conj().clone().detach();

const int64_t m = mat1_cont.sizes().at(0);
const int64_t n = mat2_cont.sizes().at(1);
const int64_t k = mat1_cont.sizes().at(1);

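// Empty inner dimension: the product term vanishes, so the result reduces
// to beta * self (or zeros when beta == 0), matching addmm semantics.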
if (k == 0) {
if (out.numel() == 0) {
return out;
}
if (beta.toComplexDouble() == 0.0) {
out.zero_();
} else {
if (!self.is_same(out)) {
out.copy_(self);
}
out.mul_(beta);
}
return out;
}

if (m == 0 || n == 0) {
return out;
}

const std::vector<int64_t> mm_output_size = {m, n};
if (self_cont.sizes() != mm_output_size) {
self_cont = at::broadcast_to(self_cont, mm_output_size).contiguous();
}

std::complex<T> complex_alpha =
static_cast<std::complex<T>>(alpha.toComplexDouble());
std::complex<T> complex_beta =
static_cast<std::complex<T>>(beta.toComplexDouble());

oneapi::mkl::blas::row_major::gemm(
c10::xpu::getCurrentXPUStream().queue(),
oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans,
m,
n,
k,
complex_alpha,
reinterpret_cast<const std::complex<T>*>(mat1_cont.const_data_ptr()),
k,
reinterpret_cast<const std::complex<T>*>(mat2_cont.const_data_ptr()),
n,
complex_beta,
reinterpret_cast<std::complex<T>*>(self_cont.data_ptr()),
n);

return handle_output_copy(out, self_cont);
}

at::Tensor& addmm_complex_out_xpu(
const Tensor& self,
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
Tensor& out) {
TORCH_CHECK(
self.is_complex(),
"_addmm_mkl.out expects self to be a complex datatype.");

AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "addmm_complex_out_xpu", [&] {
using underlying_t = typename c10::scalar_value_type<scalar_t>::type;
addmm_complex_out_xpu_impl<underlying_t>(
self, mat1, mat2, beta, alpha, out);
});

return out;
}

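// Batched addmm: out = beta * self + alpha * (batch1 @ batch2), reusing the
// broadcast clone of `self` as the in-place C matrices of gemm_batch.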
template <typename T>
at::Tensor& baddbmm_complex_out_xpu_impl(
const Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha,
Tensor& out) {
at::Tensor batch1_cont = batch1.contiguous().resolve_conj();
at::Tensor batch2_cont = batch2.contiguous().resolve_conj();
at::Tensor self_cont = self.contiguous().resolve_conj().clone().detach();

const int64_t batch_size = batch1_cont.sizes().at(0);
const int64_t m = batch1_cont.sizes().at(1);
const int64_t n = batch2_cont.sizes().at(2);
const int64_t k = batch1_cont.sizes().at(2);

const std::vector<int64_t> mm_output_size = {batch_size, m, n};
if (self_cont.sizes() != mm_output_size) {
self_cont = at::broadcast_to(self_cont, mm_output_size).contiguous();
}

std::complex<T> complex_alpha =
static_cast<std::complex<T>>(alpha.toComplexDouble());
std::complex<T> complex_beta =
static_cast<std::complex<T>>(beta.toComplexDouble());

oneapi::mkl::blas::row_major::gemm_batch(
c10::xpu::getCurrentXPUStream().queue(),
oneapi::mkl::transpose::nontrans,
oneapi::mkl::transpose::nontrans,
m,
n,
k,
complex_alpha,
reinterpret_cast<const std::complex<T>*>(batch1_cont.const_data_ptr()),
k,
m * k,
reinterpret_cast<const std::complex<T>*>(batch2_cont.const_data_ptr()),
n,
k * n,
complex_beta,
reinterpret_cast<std::complex<T>*>(self_cont.data_ptr()),
n,
m * n,
batch_size);

return handle_output_copy(out, self_cont);
}

at::Tensor& baddbmm_complex_out_xpu(
const Tensor& self,
const Tensor& batch1,
const Tensor& batch2,
const Scalar& beta,
const Scalar& alpha,
Tensor& out) {
TORCH_CHECK(
self.is_complex(),
"_baddbmm_mkl.out expects self to be a complex datatype.");

AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "baddbmm_complex_out_xpu", [&] {
using underlying_t = typename c10::scalar_value_type<scalar_t>::type;
baddbmm_complex_out_xpu_impl<underlying_t>(
self, batch1, batch2, beta, alpha, out);
});

return out;
}

#endif // USE_ONEMKL_XPU

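// The schemas are registered unconditionally so the operators always exist;
// the XPU kernels below are only registered when the build has oneMKL
// support (USE_ONEMKL_XPU).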
TORCH_LIBRARY_FRAGMENT(aten, m) {
m.def(
"aten::_mm_mkl.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"aten::_bmm_mkl.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)");
m.def(
"aten::_addmm_mkl.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)");
m.def(
"aten::_baddbmm_mkl.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)");
}

#if defined(USE_ONEMKL_XPU)

TORCH_LIBRARY_IMPL(aten, XPU, m) {
m.impl("aten::_mm_mkl.out", mm_complex_out_xpu);
m.impl("aten::_bmm_mkl.out", bmm_complex_out_xpu);
m.impl("aten::_addmm_mkl.out", addmm_complex_out_xpu);
m.impl("aten::_baddbmm_mkl.out", baddbmm_complex_out_xpu);
}

#endif // USE_ONEMKL_XPU

} // namespace at::native
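
For reference, a minimal usage sketch (not part of the diff): one way the newly registered `aten::_mm_mkl.out` schema could be invoked from C++ through the dispatcher, assuming an XPU build with USE_ONEMKL_XPU. The tensor names and shapes are illustrative.

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

int main() {
  const auto opts =
      at::TensorOptions().device(at::kXPU).dtype(at::kComplexFloat);
  at::Tensor a = at::randn({4, 5}, opts);   // self: m x k
  at::Tensor b = at::randn({5, 3}, opts);   // mat2: k x n
  at::Tensor out = at::empty({4, 3}, opts); // out:  m x n

  // Look up the schema registered by TORCH_LIBRARY_FRAGMENT above and call
  // the XPU kernel through the dispatcher.
  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("aten::_mm_mkl", "out")
                .typed<at::Tensor&(
                    const at::Tensor&, const at::Tensor&, at::Tensor&)>();
  op.call(a, b, out);
  return 0;
}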