Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions crates/cubecl-core/src/runtime_tests/matmul_2d.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use cubecl_core::{Runtime, client::ComputeClient, prelude::*};
use cubecl_linalg::matmul::{
components::{MatrixLayout, MatmulProblem},
kernels::matmul::base::matmul_cmma_ref,
};

/// Smoke test for `matmul_cmma_ref` on two degenerate output shapes: a single
/// output row (`m == 1`) and a single output column (`n == 1`).
///
/// Each case allocates `f32` tensors of shape `[m, k]`, `[k, n]` and `[m, n]`,
/// launches the kernel and unwraps the returned `Result`. No output values are
/// inspected, so the test only proves that the launch itself succeeds.
#[test]
fn test_2d_input_matmul() {
    let client = ComputeClient::new().unwrap();

    // One `(m, n, k)` triple per case: first the 1xN output, then the Mx1 output.
    for (m, n, k) in [(1, 32, 16), (32, 1, 16)] {
        let lhs = client.create_tensor::<f32>(&[m, k]).unwrap();
        let rhs = client.create_tensor::<f32>(&[k, n]).unwrap();
        let out = client.create_tensor::<f32>(&[m, n]).unwrap();

        // NOTE(review): this descriptor is constructed but never handed to the
        // kernel call below — confirm whether it was meant to be used.
        let _problem = MatmulProblem {
            m,
            n,
            k,
            batches: (vec![1], vec![1]),
            lhs_layout: MatrixLayout::RowMajor,
            rhs_layout: MatrixLayout::RowMajor,
            out_layout: MatrixLayout::RowMajor,
            lhs_line_size: 16,
            rhs_line_size: 16,
            out_line_size: 16,
        };

        matmul_cmma_ref::<_, f32, _>(&client, &lhs, &rhs, &out, (false, false)).unwrap();
    }
}
26 changes: 15 additions & 11 deletions crates/cubecl-linalg/src/convolution/selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ pub fn matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
// to be the rough cutoff for the k=4 size.
let stage_size_k = if problem.k >= 4096 { 4 } else { 2 };

let (instruction_m, instruction_n, instruction_k) = find_instruction_shape(
let (instruction_m, instruction_n, _instruction_k) = find_instruction_shape(
if TMM::requires_tensor_cores() {
Some((client.properties(), (elem_stage, elem_stage, elem_acc)))
} else {
Expand All @@ -143,17 +143,21 @@ pub fn matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
stage_size_k,
);

let tile_shape = MatmulSize {
m: stage_size_m as u32,
n: stage_size_n as u32,
k: stage_size_k as u32,
};

let tile_count = MatmulSize {
m: (problem.m as u32 + stage_size_m as u32 - 1) / stage_size_m as u32,
n: (problem.n as u32 + stage_size_n as u32 - 1) / stage_size_n as u32,
k: (problem.k as u32 + stage_size_k as u32 - 1) / stage_size_k as u32,
};

MatmulSelection {
tile_shape: MatmulSize {
m: instruction_m as u32,
n: instruction_n as u32,
k: instruction_k as u32,
},
tile_count: MatmulSize {
m: stage_size_m as u32,
n: stage_size_n as u32,
k: stage_size_k as u32,
},
tile_shape,
tile_count,
plane_dim,
rows_per_plane: 1,
}
Expand Down
24 changes: 23 additions & 1 deletion crates/cubecl-linalg/src/matmul/components/tile/accelerated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::matmul::components::{
Ident, InvalidConfigError, MatmulConfigFactory, MatmulPrecision, MatmulProblem, MatmulSize,
MatrixLayout, as_cmma_layout,
};
use crate::matmul::kernels::MatmulAvailabilityError;
use crate::matmul::kernels::{MatmulAvailabilityError, matmul::find_instruction_shape};
use cubecl_core::ir::{Elem, FloatKind};
use cubecl_core::{self as cubecl, Feature};
use cubecl_core::{cmma, prelude::*};
Expand Down Expand Up @@ -126,6 +126,28 @@ impl MatmulConfigFactory for Accelerated {
"Error: Expected plane dimension to be 32, but found {}. Please ensure that cube dimension x is set correctly.",
));
}

// Get instruction shapes
let (instruction_m, instruction_n, _) = find_instruction_shape(None, config.size.m as usize, config.size.n as usize);

// Validate stage sizes for 2D inputs
let size = config.size;
if size.m == 1 || size.n == 1 {
// For 2D inputs, ensure stage sizes are valid
if size.m == 1 && size.n % instruction_m as u32 != 0 {
return Err(Box::new(format!(
"Error: For 1xN input, stage size n ({}) must divide input dimension evenly",
size.n
)));
}
if size.n == 1 && size.m % instruction_n as u32 != 0 {
return Err(Box::new(format!(
"Error: For Mx1 input, stage size m ({}) must divide input dimension evenly",
size.m
)));
}
}

Ok(())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,29 @@ pub fn matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
// Makes all rows the length of plane_dim
let k = plane_dim / instruction_k as u32;

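// For degenerate shapes (m == 1 or n == 1), halve the stage size until it divides
// the other problem dimension and is a multiple of the instruction size (or reaches 1)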
let stage_size_m = if problem.m == 1 {
// For 1xN input, ensure stage size divides n evenly
let mut valid_size = stage_size_m;
while valid_size > 1 && (problem.n % valid_size != 0 || valid_size % instruction_m != 0) {
valid_size /= 2;
}
valid_size
} else {
stage_size_m
};

let stage_size_n = if problem.n == 1 {
// For Mx1 input, ensure stage size divides m evenly
let mut valid_size = stage_size;
while valid_size > 1 && (problem.m % valid_size != 0 || valid_size % instruction_n != 0) {
valid_size /= 2;
}
valid_size
} else {
stage_size
};

MatmulSelection {
tile_shape: MatmulSize {
m: instruction_m as u32,
Expand All @@ -253,7 +276,7 @@ pub fn matmul_selection<TMM: TileMatmulFamily, R: Runtime>(
},
tile_count: MatmulSize {
m: stage_size_m as u32,
n: stage_size as u32,
n: stage_size_n as u32,
k,
},
plane_dim,
Expand Down