1 change: 1 addition & 0 deletions burn-book/src/building-blocks/tensor.md
@@ -202,6 +202,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
| `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` |
| `tensor.contains_nan()` | N/A |
| `tensor.cummin(dim)` | `tensor.cummin(dim)` |
| `tensor.cumsum(dim)` | `tensor.cumsum(dim)` |
| `tensor.div(other)` or `tensor / other` | `tensor / other` |
| `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` |
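For reference, a minimal sketch of the documented semantics, with hypothetical values, assuming the method surfaces on `Tensor` exactly as the table row indicates:

```rust
use burn::tensor::{Tensor, backend::Backend};

/// Running minimum along a dimension: out[i] = min(input[0..=i]) along `dim`.
fn cummin_demo<B: Backend>(device: &B::Device) {
    let t = Tensor::<B, 1>::from_floats([3.0, 1.0, 4.0, 1.0, 0.5], device);
    let out = t.cummin(0);
    // Expected data: [3.0, 1.0, 1.0, 1.0, 0.5]
    println!("{out}");
}
```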
4 changes: 4 additions & 0 deletions crates/burn-autodiff/src/ops/int_tensor.rs
@@ -157,6 +157,10 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_cumsum(tensor, dim)
}

fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
B::int_cummin(tensor, dim)
}

fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
B::int_repeat_dim(tensor, dim, times)
}
9 changes: 9 additions & 0 deletions crates/burn-autodiff/src/ops/tensor.rs
@@ -1665,6 +1665,15 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}

fn float_cummin(_tensor: FloatTensor<Self>, _dim: usize) -> FloatTensor<Self> {
// Cummin backward pass requires scatter_add which is not yet implemented
// The gradient should only flow to the first occurrence of each minimum value
panic!(
"Cummin is not supported for autodiff backend. \
Proper implementation requires scatter_add operation."
);
}
Review comment (Member) on lines +1669 to +1675:

Same comment as cummax: tensor.scatter applies the sum reduction (scatter_add equivalent).

We need to improve this expected behavior discrepancy at the tensor API level. select_assign also performs a sum 😅


fn float_argmax(tensor: FloatTensor<Self>, dim: usize) -> IntTensor<B> {
B::float_argmax(tensor.primitive, dim)
}
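Regarding the review thread above: a plain-Rust 1-D sketch of the gradient routing the panic message and comment describe, where each output gradient is scatter-added to the first occurrence of its running minimum. This is illustrative only, independent of Burn's autodiff types:

```rust
/// 1-D sketch of a cummin backward: route the gradient of output[i] to the
/// input position that produced its running minimum, accumulating
/// scatter-add style.
fn cummin_backward_1d(input: &[f32], grad_output: &[f32]) -> Vec<f32> {
    let mut grad_input = vec![0.0; input.len()];
    let mut argmin = 0; // index of the current running minimum
    for i in 0..input.len() {
        if input[i] < input[argmin] {
            argmin = i; // strict `<` keeps the first occurrence on ties
        }
        grad_input[argmin] += grad_output[i];
    }
    grad_input
}
```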
1 change: 1 addition & 0 deletions crates/burn-candle/src/lib.rs
@@ -120,6 +120,7 @@ mod tests {
burn_tensor::testgen_transpose!();
burn_tensor::testgen_expand!();
burn_tensor::testgen_cumsum!();
burn_tensor::testgen_cummin!();

// test stats
burn_tensor::testgen_var!();
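The new `testgen_cummin!` line pulls the shared cummin tests into the Candle backend's suite. A hypothetical sketch of the kind of assertion such a test makes; the names follow burn's shared-test conventions, not the actual generated body:

```rust
#[test]
fn cummin_should_equal_running_minimum() {
    // Hypothetical data; TestTensor/TensorData per burn's test conventions.
    let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [1.0, 5.0, 0.0]]);
    let output = tensor.cummin(1);
    let expected = TensorData::from([[3.0, 1.0, 1.0], [1.0, 1.0, 0.0]]);
    output.into_data().assert_eq(&expected, false);
}
```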
22 changes: 22 additions & 0 deletions crates/burn-candle/src/ops/int_tensor.rs
@@ -303,6 +303,28 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I>
CandleTensor::new(result_float.to_dtype(dtype).unwrap())
}

fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
// Candle doesn't have cummin for int: convert to float, compute, convert back.
// Note: the f32 round-trip assumes the int values stay within f32's
// exactly-representable integer range.
let dtype = tensor.tensor.dtype();
let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap();

let dim_size = tensor_float.dims()[dim];
let mut slices = Vec::with_capacity(dim_size);

// First slice is just the first element along dim
slices.push(tensor_float.narrow(dim, 0, 1).unwrap());

// For each subsequent position, take min of previous cummin and current element
for i in 1..dim_size {
let curr = tensor_float.narrow(dim, i, 1).unwrap();
let min_val = slices[i - 1].broadcast_minimum(&curr).unwrap();
slices.push(min_val);
}

let result = candle_core::Tensor::cat(&slices, dim).unwrap();
CandleTensor::new(result.to_dtype(dtype).unwrap())
}

fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
CandleTensor::new(
tensor
19 changes: 19 additions & 0 deletions crates/burn-candle/src/ops/tensor.rs
@@ -330,6 +330,25 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I>
CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())
}

fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
// Candle doesn't have cummin; implement it manually with slicing and minimum
let dim_size = tensor.tensor.dims()[dim];
let mut slices = Vec::with_capacity(dim_size);

// First slice is just the first element along dim
slices.push(tensor.tensor.narrow(dim, 0, 1).unwrap());

// For each subsequent position, take min of previous cummin and current element
for i in 1..dim_size {
let curr = tensor.tensor.narrow(dim, i, 1).unwrap();
let min_val = slices[i - 1].broadcast_minimum(&curr).unwrap();
slices.push(min_val);
}

let result = candle_core::Tensor::cat(&slices, dim).unwrap();
CandleTensor::new(result)
}
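The narrow/cat loop above implements the usual running-minimum recurrence; a plain-slice sketch of that recurrence for reference:

```rust
/// out[0] = x[0]; out[i] = min(out[i - 1], x[i]).
fn cummin_1d(xs: &[f32]) -> Vec<f32> {
    let mut out = Vec::with_capacity(xs.len());
    for (i, &x) in xs.iter().enumerate() {
        let m = if i == 0 { x } else { out[i - 1].min(x) };
        out.push(m);
    }
    out
}

// cummin_1d(&[2.0, 3.0, 1.0, 5.0]) == [2.0, 2.0, 1.0, 1.0]
```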

fn float_exp(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
CandleTensor::new(tensor.tensor.exp().unwrap())
}
4 changes: 4 additions & 0 deletions crates/burn-cubecl/src/ops/float_ops.rs
@@ -511,6 +511,10 @@ where
execute_with_dtype!(float(tensor.dtype), E, numeric::cumsum::<R, E>(tensor, dim))
}

fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
execute_with_dtype!(float(tensor.dtype), E, numeric::cummin::<R, E>(tensor, dim))
}

fn float_prod(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
execute_with_dtype!(
float(tensor.dtype),
4 changes: 4 additions & 0 deletions crates/burn-cubecl/src/ops/int_ops.rs
@@ -480,6 +480,10 @@ where
execute_with_dtype!(int(tensor.dtype), I, numeric::cumsum::<R, I>(tensor, dim))
}

fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
execute_with_dtype!(int(tensor.dtype), I, numeric::cummin::<R, I>(tensor, dim))
}

fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
execute_with_dtype!(
int(tensor.dtype),
88 changes: 88 additions & 0 deletions crates/burn-cubecl/src/ops/numeric.rs
@@ -318,3 +318,91 @@ pub fn cumsum<R: CubeRuntime, E: CubeElement>(input: CubeTensor<R>, dim: usize)

output
}

#[cube(launch)]
fn cummin_kernel<C: Numeric>(
input: &Tensor<C>,
output: &mut Tensor<C>,
dim_stride: u32,
#[comptime] dim_size: u32,
) {
if ABSOLUTE_POS >= output.len() {
terminate!();
}

let idx = ABSOLUTE_POS;

// Compute components of the index
let before_dim = idx / dim_stride;
let after_dim = idx % dim_stride;

// Position along `dim` for this output element
let dim_offset = (idx / dim_stride) % dim_size;

// Compute cumulative minimum
let read_idx_0 = (before_dim / dim_size) * (dim_size * dim_stride) + after_dim;
let mut min_val = input[read_idx_0];

for i in 1..dim_size {
if i <= dim_offset {
let read_idx =
(before_dim / dim_size) * (dim_size * dim_stride) + i * dim_stride + after_dim;
let val = input[read_idx];
if val < min_val {
min_val = val;
}
}
}

output[idx] = min_val;
}

/// Compute the cumulative minimum along a dimension
///
/// # Limitations
///
/// This is a **naive sequential implementation** along the cummin dimension:
/// - Each output element sequentially reads all previous elements along the dimension
/// - Computational complexity: O(n²) memory reads where n is the size of the cummin dimension
/// - **Performance:** Suitable for small tensors or small dimensions. For large tensors,
/// performance will degrade significantly compared to an optimized parallel scan algorithm.
///
/// # TODO
///
/// Implement an efficient GPU-optimized parallel scan algorithm (cubecl-scan crate).
/// See draft PR: https://github.com/tracel-ai/cubecl/pull/863
///
/// References:
/// - https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
/// - https://www.w3.org/TR/WGSL/#builtin-subgroupInclusiveAdd
pub fn cummin<R: CubeRuntime, E: CubeElement>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
let client = input.client.clone();
let device = input.device.clone();
let shape = input.shape.clone();
let dim_size = shape.dims[dim];

// Calculate stride for the cummin dimension
let dim_stride: usize = shape.dims[dim + 1..].iter().product();

let output = empty_device::<R, E>(client.clone(), device, shape);

let num_elems = output.shape.num_elements();
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);

cummin_kernel::launch::<E, R>(
&client,
cube_count,
cube_dim,
unsafe {
TensorArg::from_raw_parts::<E>(&input.handle, &input.strides, &input.shape.dims, 1)
},
unsafe {
TensorArg::from_raw_parts::<E>(&output.handle, &output.strides, &output.shape.dims, 1)
},
ScalarArg::new(dim_stride as u32),
dim_size as u32,
);

output
}
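For readers untangling the kernel's `before_dim`/`after_dim`/`dim_offset` arithmetic, a CPU sketch of the same flat-index computation, assuming a contiguous row-major buffer (matching how `dim_stride` is derived from the shape); the kernel's `if i <= dim_offset` guard inside the unrolled loop is folded here into the loop bound:

```rust
/// CPU mirror of cummin_kernel's index math over a contiguous row-major buffer.
fn cummin_ref(input: &[f32], shape: &[usize], dim: usize) -> Vec<f32> {
    let dim_size = shape[dim];
    let dim_stride: usize = shape[dim + 1..].iter().product();
    let mut output = vec![0.0f32; input.len()];
    for idx in 0..input.len() {
        let before_dim = idx / dim_stride; // combined index over dims[..=dim]
        let after_dim = idx % dim_stride; // combined index over dims[dim+1..]
        let dim_offset = before_dim % dim_size; // position along `dim`
        // Flat index of position 0 along `dim`, same outer/inner coordinates.
        let base = (before_dim / dim_size) * (dim_size * dim_stride) + after_dim;
        let mut min_val = input[base];
        for i in 1..=dim_offset {
            let val = input[base + i * dim_stride];
            if val < min_val {
                min_val = val;
            }
        }
        output[idx] = min_val;
    }
    output
}
```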
36 changes: 36 additions & 0 deletions crates/burn-fusion/src/ops/float.rs
@@ -1479,6 +1479,42 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out
}

fn float_cummin(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
#[derive(new, Debug)]
struct CumminOps<B: FusionBackend> {
desc: DimOpIr,
_b: PhantomData<B>,
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for CumminOps<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let input = handles.get_float_tensor::<B>(&self.desc.input);
let output = B::float_cummin(input, self.desc.axis);
handles.register_float_tensor::<B>(&self.desc.out.id, output);
}
}

let mut streams = OperationStreams::default();
streams.tensor(&tensor);
let dtype = tensor.dtype;
let shape = tensor.shape.clone();
let out = tensor.client.tensor_uninitialized(shape, dtype);

let desc = DimOpIr {
input: tensor.into_ir(),
out: out.to_ir_out(),
axis: dim,
};

out.client.register(
streams,
OperationIr::BaseFloat(BaseOperationIr::CumMin(desc.clone())),
CumminOps::<B>::new(desc),
);

out
}

fn float_exp(lhs: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float_ops!(ExpOps, B::float_exp);

35 changes: 35 additions & 0 deletions crates/burn-fusion/src/ops/int.rs
@@ -1314,6 +1314,41 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out
}

fn int_cummin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
#[derive(new, Debug)]
struct CumminOps<B: FusionBackend> {
desc: DimOpIr,
_b: PhantomData<B>,
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for CumminOps<B> {
fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
let input = handles.get_int_tensor::<B>(&self.desc.input);
let output = B::int_cummin(input, self.desc.axis);
handles.register_int_tensor::<B>(&self.desc.out.id, output);
}
}

let dtype = tensor.dtype;
let mut streams = OperationStreams::default();
streams.tensor(&tensor);
let shape = tensor.shape.clone();
let out = tensor.client.tensor_uninitialized(shape, dtype);

let desc = DimOpIr {
out: out.to_ir_out(),
input: tensor.into_ir(),
axis: dim,
};
out.client.register(
streams,
OperationIr::BaseInt(BaseOperationIr::CumMin(desc.clone())),
CumminOps::<B>::new(desc),
);

out
}

fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
reduce_int_ops!(ArgMaxOps, B::int_argmax);

5 changes: 5 additions & 0 deletions crates/burn-fusion/src/stream/context.rs
@@ -962,6 +962,11 @@ impl RelativeOps for BaseOperationIr {
out: desc.out.to_relative(converter),
axis: desc.axis,
}),
BaseOperationIr::CumMin(desc) => BaseOperationIr::CumMin(DimOpIr {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
axis: desc.axis,
}),
BaseOperationIr::Empty(desc) => BaseOperationIr::Empty(desc.to_relative(converter)),
}
}
10 changes: 10 additions & 0 deletions crates/burn-ir/src/operation.rs
@@ -286,6 +286,12 @@ pub enum BaseOperationIr {
/// Int => [cumsum](burn_tensor::ops::IntTensorOps::int_cumsum).
CumSum(DimOpIr),

/// Operation corresponding to:
///
/// Float => [cummin](burn_tensor::ops::FloatTensorOps::float_cummin).
/// Int => [cummin](burn_tensor::ops::IntTensorOps::int_cummin).
CumMin(DimOpIr),

/// Operation corresponding to:
///
/// Float => [empty](burn_tensor::ops::FloatTensorOps::float_empty).
@@ -1523,6 +1529,7 @@ }
}
BaseOperationIr::Cast(repr) => vec![&repr.input, &repr.out],
BaseOperationIr::CumSum(repr) => vec![&repr.input, &repr.out],
BaseOperationIr::CumMin(repr) => vec![&repr.input, &repr.out],
BaseOperationIr::Empty(repr) => vec![repr],
BaseOperationIr::Unfold(repr) => {
vec![&repr.input, &repr.out]
@@ -1579,6 +1586,9 @@ }
BaseOperationIr::CumSum(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
BaseOperationIr::CumMin(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
BaseOperationIr::Unfold(repr) => {
repr.input.mark_read_only(nodes, &mut output);
}
6 changes: 5 additions & 1 deletion crates/burn-ndarray/src/ops/base.rs
@@ -33,7 +33,7 @@ use crate::ops::simd::{
use crate::reshape;
use crate::{
IntNdArrayElement, ShapeOps,
ops::macros::{cumsum_dim, keepdim, mean_dim, prod_dim, sum_dim},
ops::macros::{cummin_dim, cumsum_dim, keepdim, mean_dim, prod_dim, sum_dim},
};
use crate::{SharedArray, element::NdArrayElement};
use burn_tensor::Shape;
@@ -616,6 +616,10 @@ where
cumsum_dim(tensor, dim)
}

pub fn cummin(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
cummin_dim(tensor, dim)
}

pub fn gather<I: NdArrayElement>(
dim: usize,
mut tensor: SharedArray<E>,
4 changes: 4 additions & 0 deletions crates/burn-ndarray/src/ops/int_tensor.rs
@@ -212,6 +212,10 @@ where
execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::cumsum(tensor, dim))
}

fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::cummin(tensor, dim))
}

fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor {
execute_with_int_dtype!(tensor, E, |tensor: SharedArray<E>| -> NdArrayTensor {
execute_with_int_dtype!(indices, |indices| NdArrayMathOps::gather(
20 changes: 20 additions & 0 deletions crates/burn-ndarray/src/ops/macros.rs
@@ -68,3 +68,23 @@ pub(crate) fn cumsum_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize)

result.into_shared()
}

pub(crate) fn cummin_dim<E: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
let axis = Axis(dim);
let shape = tensor.shape().to_vec();
let mut result = tensor.to_owned();

// Compute cumulative minimum along the specified axis
let dim_size = shape[dim];
for i in 1..dim_size {
let prev = result.index_axis(axis, i - 1).to_owned();
let mut current = result.index_axis_mut(axis, i);
Zip::from(&mut current).and(&prev).for_each(|c, &p| {
if p < *c {
*c = p;
}
});
}

result.into_shared()
}
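A self-contained sketch of the same ndarray update scheme; `cummin_axis` is illustrative, not part of the crate:

```rust
use ndarray::{Array2, Axis, Zip, array};

/// Same scheme as cummin_dim: slice i becomes min(slice i, running-min slice i - 1).
fn cummin_axis(mut a: Array2<f64>, axis: Axis) -> Array2<f64> {
    for i in 1..a.len_of(axis) {
        let prev = a.index_axis(axis, i - 1).to_owned();
        let mut curr = a.index_axis_mut(axis, i);
        Zip::from(&mut curr).and(&prev).for_each(|c, &p| {
            if p < *c {
                *c = p;
            }
        });
    }
    a
}

fn main() {
    let out = cummin_axis(array![[3.0, 1.0, 4.0], [1.0, 5.0, 0.0]], Axis(1));
    assert_eq!(out, array![[3.0, 1.0, 1.0], [1.0, 1.0, 0.0]]);
}
```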