diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md
index b70fb62edd..77af0346b4 100644
--- a/burn-book/src/building-blocks/tensor.md
+++ b/burn-book/src/building-blocks/tensor.md
@@ -202,6 +202,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
 | `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
 | `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` |
 | `tensor.contains_nan()` | N/A |
+| `tensor.cummin(dim)` | `tensor.cummin(dim)` |
 | `tensor.cumsum(dim)` | `tensor.cumsum(dim)` |
 | `tensor.div(other)` or `tensor / other` | `tensor / other` |
 | `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` |
diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs
index 5fcbb77478..e229e31a3f 100644
--- a/crates/burn-autodiff/src/ops/int_tensor.rs
+++ b/crates/burn-autodiff/src/ops/int_tensor.rs
@@ -157,6 +157,10 @@ impl IntTensorOps for Autodiff {
         B::int_cumsum(tensor, dim)
     }
 
+    fn int_cummin(tensor: IntTensor, dim: usize) -> IntTensor {
+        B::int_cummin(tensor, dim)
+    }
+
     fn int_repeat_dim(tensor: IntTensor, dim: usize, times: usize) -> IntTensor {
         B::int_repeat_dim(tensor, dim, times)
     }
diff --git a/crates/burn-autodiff/src/ops/tensor.rs b/crates/burn-autodiff/src/ops/tensor.rs
index 47fd87f981..30208f553e 100644
--- a/crates/burn-autodiff/src/ops/tensor.rs
+++ b/crates/burn-autodiff/src/ops/tensor.rs
@@ -1665,6 +1665,15 @@ impl FloatTensorOps for Autodiff
         }
     }
 
+    fn float_cummin(_tensor: FloatTensor, _dim: usize) -> FloatTensor {
+        // Cummin backward pass requires scatter_add, which is not yet implemented.
+        // The gradient should only flow to the first occurrence of each minimum value.
+        panic!(
+            "Cummin is not supported for the autodiff backend. \
+             Proper implementation requires the scatter_add operation."
+        );
+    }
+
     fn float_argmax(tensor: FloatTensor, dim: usize) -> IntTensor {
         B::float_argmax(tensor.primitive, dim)
     }
diff --git a/crates/burn-candle/src/lib.rs b/crates/burn-candle/src/lib.rs
index 315bc15538..b0e632beab 100644
--- a/crates/burn-candle/src/lib.rs
+++ b/crates/burn-candle/src/lib.rs
@@ -120,6 +120,7 @@ mod tests {
     burn_tensor::testgen_transpose!();
     burn_tensor::testgen_expand!();
     burn_tensor::testgen_cumsum!();
+    burn_tensor::testgen_cummin!();
 
     // test stats
     burn_tensor::testgen_var!();
diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs
index 1bc35e9605..483c942f1c 100644
--- a/crates/burn-candle/src/ops/int_tensor.rs
+++ b/crates/burn-candle/src/ops/int_tensor.rs
@@ -303,6 +303,28 @@ impl IntTensorOps for Candle
+    fn int_cummin(tensor: IntTensor, dim: usize) -> IntTensor {
+        // Candle doesn't have cummin for int: convert to float, compute, convert back
+        let dtype = tensor.tensor.dtype();
+        let tensor_float = tensor.tensor.to_dtype(candle_core::DType::F32).unwrap();
+
+        let dim_size = tensor_float.dims()[dim];
+        let mut slices = Vec::with_capacity(dim_size);
+
+        // First slice is just the first element along dim
+        slices.push(tensor_float.narrow(dim, 0, 1).unwrap());
+
+        // For each subsequent position, take min of previous cummin and current element
+        for i in 1..dim_size {
+            let curr = tensor_float.narrow(dim, i, 1).unwrap();
+            let min_val = slices[i - 1].broadcast_minimum(&curr).unwrap();
+            slices.push(min_val);
+        }
+
+        let result = candle_core::Tensor::cat(&slices, dim).unwrap();
+        CandleTensor::new(result.to_dtype(dtype).unwrap())
+    }
+
     fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor {
         CandleTensor::new(
             tensor
diff --git a/crates/burn-candle/src/ops/tensor.rs b/crates/burn-candle/src/ops/tensor.rs
index a48ea3314c..68e8a9cc9b 100644
--- a/crates/burn-candle/src/ops/tensor.rs
+++ b/crates/burn-candle/src/ops/tensor.rs
@@ -330,6 +330,25 @@ impl FloatTensorOps for Candle
         CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())
     }
 
+    fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor {
+        // Candle doesn't have cummin, implement manually using slicing and min
+        let dim_size = tensor.tensor.dims()[dim];
+        let mut slices = Vec::with_capacity(dim_size);
+
+        // First slice is just the first element along dim
+        slices.push(tensor.tensor.narrow(dim, 0, 1).unwrap());
+
+        // For each subsequent position, take min of previous cummin and current element
+        for i in 1..dim_size {
+            let curr = tensor.tensor.narrow(dim, i, 1).unwrap();
+            let min_val = slices[i - 1].broadcast_minimum(&curr).unwrap();
+            slices.push(min_val);
+        }
+
+        let result = candle_core::Tensor::cat(&slices, dim).unwrap();
+        CandleTensor::new(result)
+    }
+
     fn float_exp(tensor: FloatTensor) -> FloatTensor {
         CandleTensor::new(tensor.tensor.exp().unwrap())
     }
diff --git a/crates/burn-cubecl/src/ops/float_ops.rs b/crates/burn-cubecl/src/ops/float_ops.rs
index 4c86e9bfc3..dca9059fd9 100644
--- a/crates/burn-cubecl/src/ops/float_ops.rs
+++ b/crates/burn-cubecl/src/ops/float_ops.rs
@@ -511,6 +511,10 @@ where
         execute_with_dtype!(float(tensor.dtype), E, numeric::cumsum::(tensor, dim))
     }
 
+    fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor {
+        execute_with_dtype!(float(tensor.dtype), E, numeric::cummin::(tensor, dim))
+    }
+
     fn float_prod(tensor: FloatTensor) -> FloatTensor {
         execute_with_dtype!(
             float(tensor.dtype),
diff --git a/crates/burn-cubecl/src/ops/int_ops.rs b/crates/burn-cubecl/src/ops/int_ops.rs
index b3f26f3a50..2dea3c0a77 100644
--- a/crates/burn-cubecl/src/ops/int_ops.rs
+++ b/crates/burn-cubecl/src/ops/int_ops.rs
@@ -480,6 +480,10 @@ where
         execute_with_dtype!(int(tensor.dtype), I, numeric::cumsum::(tensor, dim))
     }
 
+    fn int_cummin(tensor: IntTensor, dim: usize) -> IntTensor {
+        execute_with_dtype!(int(tensor.dtype), I, numeric::cummin::(tensor, dim))
+    }
+
     fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor {
         execute_with_dtype!(
             int(tensor.dtype),
diff --git a/crates/burn-cubecl/src/ops/numeric.rs b/crates/burn-cubecl/src/ops/numeric.rs
index e8533630c6..9417e6ab29 100644
--- a/crates/burn-cubecl/src/ops/numeric.rs
+++ b/crates/burn-cubecl/src/ops/numeric.rs
@@ -318,3 +318,91 @@ pub fn cumsum(input: CubeTensor, dim: usize)
     output
 }
+
+#[cube(launch)]
+fn cummin_kernel(
+    input: &Tensor,
+    output: &mut Tensor,
+    dim_stride: u32,
+    #[comptime] dim_size: u32,
+) {
+    if ABSOLUTE_POS >= output.len() {
+        terminate!();
+    }
+
+    let idx = ABSOLUTE_POS;
+
+    // Compute components of the index
+    let before_dim = idx / dim_stride;
+    let after_dim = idx % dim_stride;
+
+    // Compute how many strides along dim we are
+    let dim_offset = (idx / dim_stride) % dim_size;
+
+    // Compute cumulative minimum
+    let read_idx_0 = (before_dim / dim_size) * (dim_size * dim_stride) + after_dim;
+    let mut min_val = input[read_idx_0];
+
+    for i in 1..dim_size {
+        if i <= dim_offset {
+            let read_idx =
+                (before_dim / dim_size) * (dim_size * dim_stride) + i * dim_stride + after_dim;
+            let val = input[read_idx];
+            if val < min_val {
+                min_val = val;
+            }
+        }
+    }
+
+    output[idx] = min_val;
+}
+
+/// Compute the cumulative minimum along a dimension
+///
+/// # Limitations
+///
+/// This is a **naive sequential implementation** along the cummin dimension:
+/// - Each output element sequentially reads all previous elements along the dimension
+/// - Computational complexity: O(n²) memory reads where n is the size of the cummin dimension
+/// - **Performance:** Suitable for small tensors or small dimensions. For large tensors,
+///   performance will degrade significantly compared to an optimized parallel scan algorithm.
+///
+/// # TODO
+///
+/// Implement an efficient GPU-optimized parallel scan algorithm (cubecl-scan crate).
+/// See draft PR: https://github.com/tracel-ai/cubecl/pull/863
+///
+/// References:
+/// - https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
+/// - https://www.w3.org/TR/WGSL/#builtin-subgroupInclusiveAdd
+pub fn cummin(input: CubeTensor, dim: usize) -> CubeTensor {
+    let client = input.client.clone();
+    let device = input.device.clone();
+    let shape = input.shape.clone();
+    let dim_size = shape.dims[dim];
+
+    // Calculate stride for the cummin dimension
+    let dim_stride: usize = shape.dims[dim + 1..].iter().product();
+
+    let output = empty_device::(client.clone(), device, shape);
+
+    let num_elems = output.shape.num_elements();
+    let cube_dim = CubeDim::default();
+    let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
+
+    cummin_kernel::launch::(
+        &client,
+        cube_count,
+        cube_dim,
+        unsafe {
+            TensorArg::from_raw_parts::(&input.handle, &input.strides, &input.shape.dims, 1)
+        },
+        unsafe {
+            TensorArg::from_raw_parts::(&output.handle, &output.strides, &output.shape.dims, 1)
+        },
+        ScalarArg::new(dim_stride as u32),
+        dim_size as u32,
+    );
+
+    output
+}
diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs
index 7da702f2de..d8ffd9fbc1 100644
--- a/crates/burn-fusion/src/ops/float.rs
+++ b/crates/burn-fusion/src/ops/float.rs
@@ -1479,6 +1479,42 @@ impl FloatTensorOps for Fusion {
         out
     }
 
+    fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor {
+        #[derive(new, Debug)]
+        struct CumminOps {
+            desc: DimOpIr,
+            _b: PhantomData,
+        }
+
+        impl Operation for CumminOps {
+            fn execute(&self, handles: &mut HandleContainer) {
+                let input = handles.get_float_tensor::(&self.desc.input);
+                let output = B::float_cummin(input, self.desc.axis);
+                handles.register_float_tensor::(&self.desc.out.id, output);
+            }
+        }
+
+        let mut streams = OperationStreams::default();
+        streams.tensor(&tensor);
+        let dtype = tensor.dtype;
+        let shape = tensor.shape.clone();
+        let out = tensor.client.tensor_uninitialized(shape, dtype);
+
+        let desc = DimOpIr {
+            input: tensor.into_ir(),
+            out: out.to_ir_out(),
+            axis: dim,
+        };
+
+        out.client.register(
+            streams,
+            OperationIr::BaseFloat(BaseOperationIr::CumMin(desc.clone())),
+            CumminOps::::new(desc),
+        );
+
+        out
+    }
+
     fn float_exp(lhs: FloatTensor) -> FloatTensor {
         unary_float_ops!(ExpOps, B::float_exp);
diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs
index 0f7fcb48c9..187a3e6477 100644
--- a/crates/burn-fusion/src/ops/int.rs
+++ b/crates/burn-fusion/src/ops/int.rs
@@ -1314,6 +1314,41 @@ impl IntTensorOps for Fusion {
         out
     }
 
+    fn int_cummin(tensor: IntTensor, dim: usize) -> IntTensor {
+        #[derive(new, Debug)]
+        struct CumminOps {
+            desc: DimOpIr,
+            _b: PhantomData,
+        }
+
+        impl Operation for CumminOps {
+            fn execute(&self, handles: &mut HandleContainer) {
+                let input = handles.get_int_tensor::(&self.desc.input);
+                let output = B::int_cummin(input, self.desc.axis);
+                handles.register_int_tensor::(&self.desc.out.id, output);
+            }
+        }
+
+        let dtype = tensor.dtype;
+        let mut streams = OperationStreams::default();
+        streams.tensor(&tensor);
+        let shape = tensor.shape.clone();
+        let out = tensor.client.tensor_uninitialized(shape, dtype);
+
+        let desc = DimOpIr {
+            out: out.to_ir_out(),
+            input: tensor.into_ir(),
+            axis: dim,
+        };
+        out.client.register(
+            streams,
+            OperationIr::BaseInt(BaseOperationIr::CumMin(desc.clone())),
+            CumminOps::::new(desc),
+        );
+
+        out
+    }
+
     fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor {
        reduce_int_ops!(ArgMaxOps, B::int_argmax);
diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs
index 19537c5320..90a0b6e15c 100644
--- a/crates/burn-fusion/src/stream/context.rs
+++ b/crates/burn-fusion/src/stream/context.rs
@@ -962,6 +962,11 @@ impl RelativeOps for BaseOperationIr {
                 out: desc.out.to_relative(converter),
                 axis: desc.axis,
             }),
+            BaseOperationIr::CumMin(desc) => BaseOperationIr::CumMin(DimOpIr {
+                input: desc.input.to_relative(converter),
+                out: desc.out.to_relative(converter),
+                axis: desc.axis,
+            }),
             BaseOperationIr::Empty(desc) => BaseOperationIr::Empty(desc.to_relative(converter)),
         }
     }
diff --git a/crates/burn-ir/src/operation.rs b/crates/burn-ir/src/operation.rs
index 102ab674e7..68efd0c1d0 100644
--- a/crates/burn-ir/src/operation.rs
+++ b/crates/burn-ir/src/operation.rs
@@ -286,6 +286,12 @@ pub enum BaseOperationIr {
     /// Int => [cumsum](burn_tensor::ops::IntTensorOps::int_cumsum).
     CumSum(DimOpIr),
 
+    /// Operation corresponding to:
+    ///
+    /// Float => [cummin](burn_tensor::ops::FloatTensorOps::float_cummin).
+    /// Int => [cummin](burn_tensor::ops::IntTensorOps::int_cummin).
+    CumMin(DimOpIr),
+
     /// Operation corresponding to:
     ///
     /// Float => [empty](burn_tensor::ops::FloatTensorOps::float_empty).
@@ -1523,6 +1529,7 @@ impl BaseOperationIr {
             }
             BaseOperationIr::Cast(repr) => vec![&repr.input, &repr.out],
             BaseOperationIr::CumSum(repr) => vec![&repr.input, &repr.out],
+            BaseOperationIr::CumMin(repr) => vec![&repr.input, &repr.out],
             BaseOperationIr::Empty(repr) => vec![repr],
             BaseOperationIr::Unfold(repr) => {
                 vec![&repr.input, &repr.out]
             }
@@ -1579,6 +1586,9 @@ impl BaseOperationIr {
             BaseOperationIr::CumSum(repr) => {
                 repr.input.mark_read_only(nodes, &mut output);
             }
+            BaseOperationIr::CumMin(repr) => {
+                repr.input.mark_read_only(nodes, &mut output);
+            }
             BaseOperationIr::Unfold(repr) => {
                 repr.input.mark_read_only(nodes, &mut output);
             }
diff --git a/crates/burn-ndarray/src/ops/base.rs b/crates/burn-ndarray/src/ops/base.rs
index f23e3d7aa0..11cb234fd5 100644
--- a/crates/burn-ndarray/src/ops/base.rs
+++ b/crates/burn-ndarray/src/ops/base.rs
@@ -33,7 +33,7 @@ use crate::ops::simd::{
 use crate::reshape;
 use crate::{
     IntNdArrayElement, ShapeOps,
-    ops::macros::{cumsum_dim, keepdim, mean_dim, prod_dim, sum_dim},
+    ops::macros::{cummin_dim, cumsum_dim, keepdim, mean_dim, prod_dim, sum_dim},
 };
 use crate::{SharedArray, element::NdArrayElement};
 use burn_tensor::Shape;
@@ -616,6 +616,10 @@ where
         cumsum_dim(tensor, dim)
     }
 
+    pub fn cummin(tensor: SharedArray, dim: usize) -> SharedArray {
+        cummin_dim(tensor, dim)
+    }
+
     pub fn gather(
         dim: usize,
         mut tensor: SharedArray,
diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs
index b1ee86bfcc..efa066ff98 100644
--- a/crates/burn-ndarray/src/ops/int_tensor.rs
+++ b/crates/burn-ndarray/src/ops/int_tensor.rs
@@ -212,6 +212,10 @@ where
         execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::cumsum(tensor, dim))
     }
 
+    fn int_cummin(tensor: NdArrayTensor, dim: usize) -> NdArrayTensor {
+        execute_with_int_dtype!(tensor, |tensor| NdArrayMathOps::cummin(tensor, dim))
+    }
+
     fn int_gather(dim: usize, tensor: NdArrayTensor, indices: NdArrayTensor) -> NdArrayTensor {
         execute_with_int_dtype!(tensor, E, |tensor: SharedArray| -> NdArrayTensor {
             execute_with_int_dtype!(indices, |indices| NdArrayMathOps::gather(
diff --git a/crates/burn-ndarray/src/ops/macros.rs b/crates/burn-ndarray/src/ops/macros.rs
index f93fd6d5c8..f16c2b79de 100644
--- a/crates/burn-ndarray/src/ops/macros.rs
+++ b/crates/burn-ndarray/src/ops/macros.rs
@@ -68,3 +68,23 @@ pub(crate) fn cumsum_dim(tensor: SharedArray, dim: usize)
     result.into_shared()
 }
+
+pub(crate) fn cummin_dim(tensor: SharedArray, dim: usize) -> SharedArray {
+    let axis = Axis(dim);
+    let shape = tensor.shape().to_vec();
+    let mut result = tensor.to_owned();
+
+    // Compute cumulative minimum along the specified axis
+    let dim_size = shape[dim];
+    for i in 1..dim_size {
+        let prev = result.index_axis(axis, i - 1).to_owned();
+        let mut current = result.index_axis_mut(axis, i);
+        Zip::from(&mut current).and(&prev).for_each(|c, &p| {
+            if p < *c {
+                *c = p;
+            }
+        });
+    }
+
+    result.into_shared()
+}
diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs
index 606d78729c..6092ad8d9c 100644
--- a/crates/burn-ndarray/src/ops/tensor.rs
+++ b/crates/burn-ndarray/src/ops/tensor.rs
@@ -323,6 +323,10 @@ where
         execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cumsum(tensor, dim))
     }
 
+    fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor {
+        execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cummin(tensor, dim))
+    }
+
     fn float_sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor {
         execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::sum_dim(tensor, dim))
     }
diff --git a/crates/burn-router/src/ops/op_float.rs b/crates/burn-router/src/ops/op_float.rs
index 9e051f539e..48a80329b2 100644
--- a/crates/burn-router/src/ops/op_float.rs
+++ b/crates/burn-router/src/ops/op_float.rs
@@ -962,6 +962,23 @@ impl FloatTensorOps for BackendRouter {
         out
     }
 
+    fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor {
+        let client = tensor.client.clone();
+        let dtype = tensor.dtype;
+        let shape = tensor.shape.clone();
+        let out = client.register_empty_tensor(shape, dtype);
+
+        let desc = DimOpIr {
+            input: tensor.into_ir(),
+            axis: dim,
+            out: out.to_ir_out(),
+        };
+
+        client.register(OperationIr::BaseFloat(BaseOperationIr::CumMin(desc)));
+
+        out
+    }
+
     fn float_exp(lhs: FloatTensor) -> FloatTensor {
         let client = lhs.client.clone();
         let dtype = lhs.dtype;
diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs
index 3b32ba8bba..78e8e7ffc9 100644
--- a/crates/burn-router/src/ops/op_int.rs
+++ b/crates/burn-router/src/ops/op_int.rs
@@ -865,6 +865,23 @@ impl IntTensorOps for BackendRouter {
         out
     }
 
+    fn int_cummin(tensor: IntTensor, dim: usize) -> IntTensor {
+        let client = tensor.client.clone();
+        let dtype = tensor.dtype;
+        let shape = tensor.shape.clone();
+        let out = client.register_empty_tensor(shape, dtype);
+
+        let desc = DimOpIr {
+            input: tensor.into_ir(),
+            axis: dim,
+            out: out.to_ir_out(),
+        };
+
+        client.register(OperationIr::BaseInt(BaseOperationIr::CumMin(desc)));
+
+        out
+    }
+
     fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor {
         let client = tensor.client.clone();
         let dtype = tensor.dtype;
diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs
index b0c6dc16bf..e4dd52d122 100644
--- a/crates/burn-router/src/runner.rs
+++ b/crates/burn-router/src/runner.rs
@@ -234,6 +234,11 @@ impl RunnerClient for Runner {
                 let output = B::float_cumsum(tensor, desc.axis);
                 handles.register_float_tensor::(&desc.out.id, output);
             }
+            BaseOperationIr::CumMin(desc) => {
+                let tensor = handles.get_float_tensor::(&desc.input);
+                let output = B::float_cummin(tensor, desc.axis);
+                handles.register_float_tensor::(&desc.out.id, output);
+            }
             BaseOperationIr::Empty(desc) => {
                 let shape = Shape::from(desc.shape.clone());
                 let output = B::float_empty(shape, &self.device, desc.dtype.into());
@@ -316,6 +321,11 @@ impl RunnerClient for Runner {
                 let output = B::int_cumsum(tensor, desc.axis);
                 handles.register_int_tensor::(&desc.out.id, output);
             }
+            BaseOperationIr::CumMin(desc) => {
+                let tensor = handles.get_int_tensor::(&desc.input);
+                let output = B::int_cummin(tensor, desc.axis);
+                handles.register_int_tensor::(&desc.out.id, output);
+            }
             BaseOperationIr::Empty(desc) => {
                 let shape = Shape::from(desc.shape.clone());
                 let output = B::int_empty(shape, &self.device, desc.dtype.into());
@@ -398,6 +408,7 @@ impl RunnerClient for Runner {
             }
             BaseOperationIr::Cast(_) => unreachable!(),
             BaseOperationIr::CumSum(_) => unreachable!("cumsum not supported for bool tensors"),
+            BaseOperationIr::CumMin(_) => unreachable!("cummin not supported for bool tensors"),
             BaseOperationIr::Empty(desc) => {
                 let shape = Shape::from(desc.shape.clone());
                 let output = B::bool_empty(shape, &self.device);
diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs
index 324ceaaeb2..6536d2a29c 100644
--- a/crates/burn-tch/src/ops/base.rs
+++ b/crates/burn-tch/src/ops/base.rs
@@ -458,6 +458,11 @@ impl TchOps {
         )
     }
 
+    pub fn cummin(tensor: TchTensor, dim: usize) -> TchTensor {
+        let (values, _indices) = tensor.tensor.cummin(dim as i64);
+        TchTensor::from_existing(values, tensor.storage)
+    }
+
     pub fn argmax(tensor: TchTensor, dim: usize) -> TchTensor {
         let storage = tensor.storage.clone();
         let tensor = tensor.tensor.argmax(dim as i64, true);
diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs
index c7e9fd00fc..61731dd405 100644
--- a/crates/burn-tch/src/ops/int_tensor.rs
+++ b/crates/burn-tch/src/ops/int_tensor.rs
@@ -281,6 +281,10 @@ impl IntTensorOps for LibTorch {
         TchOps::cumsum(tensor, dim)
     }
 
+    fn int_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
+        TchOps::cummin(tensor, dim)
+    }
+
     fn int_gather(dim: usize, tensor: TchTensor, indices: TchTensor) -> TchTensor {
         TchOps::gather(dim, tensor, indices)
     }
diff --git a/crates/burn-tch/src/ops/tensor.rs b/crates/burn-tch/src/ops/tensor.rs
index fd5e40e4f4..32f558f6c6 100644
--- a/crates/burn-tch/src/ops/tensor.rs
+++ b/crates/burn-tch/src/ops/tensor.rs
@@ -315,6 +315,10 @@ impl FloatTensorOps for LibTorch {
         TchOps::cumsum(tensor, dim)
     }
 
+    fn float_cummin(tensor: TchTensor, dim: usize) -> TchTensor {
+        TchOps::cummin(tensor, dim)
+    }
+
     fn float_prod(tensor: TchTensor) -> TchTensor {
         TchOps::prod(tensor)
     }
diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs
index dbf637fd73..ad08c1e3dc 100644
--- a/crates/burn-tensor/src/tensor/api/numeric.rs
+++ b/crates/burn-tensor/src/tensor/api/numeric.rs
@@ -695,6 +695,39 @@ where
         Self::new(K::cumsum(self.primitive, dim))
     }
 
+    /// Computes the cumulative minimum of elements along the given *dimension* or *axis*.
+    ///
+    /// # Arguments
+    ///
+    /// * `dim` - The dimension or axis along which to compute the cumulative minimum.
+    ///
+    /// # Note
+    ///
+    /// This operation is **not supported for the autodiff backend** and will panic.
+    /// Proper gradient computation requires scatter_add which is not yet implemented.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Shape};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let device = B::Device::default();
+    ///     let tensor = Tensor::<B, 2>::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device);
+    ///     let result = tensor.clone().cummin(0);
+    ///     println!("{result}");
+    ///     // [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]]
+    ///     let result = tensor.cummin(1);
+    ///     println!("{result}");
+    ///     // [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]]
+    /// }
+    /// ```
+    pub fn cummin(self, dim: usize) -> Self {
+        check!(TensorCheck::aggregate_dim::<D>("CumMin", dim));
+        Self::new(K::cummin(self.primitive, dim))
+    }
+
     ///
     /// # Arguments
     ///
@@ -2794,6 +2827,28 @@ where
     /// the [Tensor::cumsum](Tensor::cumsum) function, which is more high-level and designed for public use.
     fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
 
+    /// Computes the cumulative minimum of elements along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the cumulative minimum of.
+    /// * `dim` - The dimension along which to compute the cumulative minimum.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape as the input tensor, where each element is the minimum
+    /// of all elements up to and including that position along the specified dimension.
+    ///
+    /// # Remarks
+    ///
+    /// This is a low-level function used internally by the library to call different backend functions
+    /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
+    /// or use this function directly.
+    ///
+    /// For computing the cumulative minimum of elements along a dimension, users should prefer
+    /// the [Tensor::cummin](Tensor::cummin) function, which is more high-level and designed for public use.
+    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive;
+
     /// Element-wise equality between two tensors.
     ///
     /// # Arguments
@@ -3642,6 +3697,10 @@ impl Numeric for Int {
         B::int_cumsum(tensor, dim)
     }
 
+    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
+        B::int_cummin(tensor, dim)
+    }
+
     fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
         B::int_equal_elem(lhs, rhs)
     }
@@ -3962,6 +4021,13 @@ impl Numeric for Float {
         }
     }
 
+    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
+        match tensor {
+            TensorPrimitive::Float(tensor) => TensorPrimitive::Float(B::float_cummin(tensor, dim)),
+            TensorPrimitive::QFloat(tensor) => B::q_cummin(tensor, dim),
+        }
+    }
+
     fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
         B::float_equal_elem(lhs.tensor(), rhs)
     }
diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs
index 8f02240b18..0478c1e9f8 100644
--- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs
+++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs
@@ -767,6 +767,19 @@ pub trait IntTensorOps {
     /// of all elements up to and including that position along the dimension.
     fn int_cumsum(tensor: IntTensor, dim: usize) -> IntTensor;
 
+    /// Computes the cumulative minimum of elements along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the cumulative minimum of.
+    /// * `dim` - The dimension along which to compute the cumulative minimum.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape where each element is the minimum
+    /// of all elements up to and including that position along the dimension.
+    fn int_cummin(tensor: IntTensor, dim: usize) -> IntTensor;
+
     /// Gets the indices of the maximum elements along a dimension.
     ///
     /// # Arguments
diff --git a/crates/burn-tensor/src/tensor/ops/qtensor.rs b/crates/burn-tensor/src/tensor/ops/qtensor.rs
index 3bb8024efe..d8ef43abbe 100644
--- a/crates/burn-tensor/src/tensor/ops/qtensor.rs
+++ b/crates/burn-tensor/src/tensor/ops/qtensor.rs
@@ -690,6 +690,25 @@ pub trait QTensorOps {
         )
     }
 
+    /// Computes the cumulative minimum of elements along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the cumulative minimum of.
+    /// * `dim` - The dimension along which to compute the cumulative minimum.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape where each element is the minimum
+    /// of all elements up to and including that position along the dimension.
+    fn q_cummin(tensor: QuantizedTensor, dim: usize) -> TensorPrimitive {
+        dequant_op_flow!(
+            ty Self,
+            float_op |tensor| B::float_cummin(tensor, dim),
+            tensor
+        )
+    }
+
     /// Returns a new tensor with exponential values.
     ///
     /// # Arguments
diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs
index 0fed3751a2..523d45f887 100644
--- a/crates/burn-tensor/src/tensor/ops/tensor.rs
+++ b/crates/burn-tensor/src/tensor/ops/tensor.rs
@@ -801,6 +801,19 @@ pub trait FloatTensorOps {
     /// of all elements up to and including that position along the dimension.
     fn float_cumsum(tensor: FloatTensor, dim: usize) -> FloatTensor;
 
+    /// Computes the cumulative minimum of elements along a dimension.
+    ///
+    /// # Arguments
+    ///
+    /// * `tensor` - The tensor to compute the cumulative minimum of.
+    /// * `dim` - The dimension along which to compute the cumulative minimum.
+    ///
+    /// # Returns
+    ///
+    /// A tensor with the same shape where each element is the minimum
+    /// of all elements up to and including that position along the dimension.
+    fn float_cummin(tensor: FloatTensor, dim: usize) -> FloatTensor;
+
     /// Converts a tensor to another floating point data type.
     ///
     /// # Arguments
diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs
index a6d20d3899..f5eae342b1 100644
--- a/crates/burn-tensor/src/tests/mod.rs
+++ b/crates/burn-tensor/src/tests/mod.rs
@@ -208,6 +208,7 @@ macro_rules! testgen_with_float_param {
         burn_tensor::testgen_cosh!();
         burn_tensor::testgen_create_like!();
         burn_tensor::testgen_cross!();
+        burn_tensor::testgen_cummin!();
         burn_tensor::testgen_cumsum!();
         burn_tensor::testgen_div!();
         burn_tensor::testgen_dot!();
@@ -297,6 +298,7 @@ macro_rules! testgen_with_int_param {
         burn_tensor::testgen_cast!();
         burn_tensor::testgen_bool!();
         burn_tensor::testgen_cat!();
+        burn_tensor::testgen_cummin!();
        burn_tensor::testgen_cumsum!();
         burn_tensor::testgen_div!();
         burn_tensor::testgen_expand!();
diff --git a/crates/burn-tensor/src/tests/ops/cummin.rs b/crates/burn-tensor/src/tests/ops/cummin.rs
new file mode 100644
index 0000000000..1781c8a8db
--- /dev/null
+++ b/crates/burn-tensor/src/tests/ops/cummin.rs
@@ -0,0 +1,61 @@
+#[burn_tensor_testgen::testgen(cummin)]
+mod tests {
+    use super::*;
+    use burn_tensor::{Tensor, TensorData};
+
+    #[test]
+    fn test_cummin_float_dim_0() {
+        let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [2.0, 5.0, 1.0]]);
+
+        let output = tensor.cummin(0);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[3.0, 1.0, 4.0], [2.0, 1.0, 1.0]]), false);
+    }
+
+    #[test]
+    fn test_cummin_float_dim_1() {
+        let tensor = TestTensor::<2>::from([[3.0, 1.0, 4.0], [2.0, 5.0, 1.0]]);
+
+        let output = tensor.cummin(1);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[3.0, 1.0, 1.0], [2.0, 2.0, 1.0]]), false);
+    }
+
+    #[test]
+    fn test_cummin_int_dim_0() {
+        let tensor = TestTensorInt::<2>::from([[3, 1, 4], [2, 5, 1]]);
+
+        let output = tensor.cummin(0);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[3, 1, 4], [2, 1, 1]]), false);
+    }
+
+    #[test]
+    fn test_cummin_int_dim_1() {
+        let tensor = TestTensorInt::<2>::from([[3, 1, 4], [2, 5, 1]]);
+
+        let output = tensor.cummin(1);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[3, 1, 1], [2, 2, 1]]), false);
+    }
+
+    #[test]
+    fn test_cummin_float_3d() {
+        let tensor = TestTensor::<3>::from([[[4.0, 2.0], [3.0, 1.0]], [[5.0, 6.0], [7.0, 8.0]]]);
+
+        let output = tensor.cummin(2);
+
+        output.into_data().assert_eq(
+            &TensorData::from([[[4.0, 2.0], [3.0, 1.0]], [[5.0, 5.0], [7.0, 7.0]]]),
+            false,
+        );
+    }
+}
diff --git a/crates/burn-tensor/src/tests/ops/mod.rs b/crates/burn-tensor/src/tests/ops/mod.rs
index 2342557f81..3907ec44d1 100644
--- a/crates/burn-tensor/src/tests/ops/mod.rs
+++ b/crates/burn-tensor/src/tests/ops/mod.rs
@@ -20,6 +20,7 @@ mod cos;
 mod cosh;
 mod create_like;
 mod cross;
+mod cummin;
 mod cumsum;
 mod div;
 mod dot;
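
Not part of the diff: a minimal end-to-end usage sketch of the new op, assuming a `burn` build with the `ndarray` feature enabled so the `NdArray` backend is available. The input values mirror the doc example added in `numeric.rs`, and the expected outputs follow the semantics exercised by the new tests.

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let tensor =
        Tensor::<NdArray, 2>::from_data([[3.0, 5.0, 2.0], [4.0, 1.0, 6.0]], &device);

    // Running minimum down each column (dim 0): [[3.0, 5.0, 2.0], [3.0, 1.0, 2.0]]
    let along_dim0 = tensor.clone().cummin(0);
    println!("{along_dim0}");

    // Running minimum across each row (dim 1): [[3.0, 3.0, 2.0], [4.0, 1.0, 1.0]]
    let along_dim1 = tensor.cummin(1);
    println!("{along_dim1}");
}
```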
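
Also not part of the diff: a self-contained sketch of the running-minimum scan the backends implement, written against a plain `Vec<f32>` in row-major layout with a hypothetical helper name (`cummin_reference`). It uses the same outer/dim/inner index decomposition as the CubeCL kernel (`dim_stride` being the product of the trailing dimensions), but performs a single O(n) forward pass per lane, which is the kind of work an optimized parallel-scan kernel would aim for instead of the naive kernel's O(n²) reads.

```rust
/// Running minimum along `dim` of a row-major tensor stored in `data` with `shape`.
/// Hypothetical reference helper for illustration only.
fn cummin_reference(data: &[f32], shape: &[usize], dim: usize) -> Vec<f32> {
    let dim_size = shape[dim];
    // Contiguous elements "inside" the scanned dimension (the kernel's `dim_stride`).
    let dim_stride: usize = shape[dim + 1..].iter().product();
    // Independent slices "outside" the scanned dimension.
    let outer: usize = shape[..dim].iter().product();

    let mut out = data.to_vec();
    for o in 0..outer {
        for inner in 0..dim_stride {
            let base = o * dim_size * dim_stride + inner;
            // One forward pass per lane: each element takes the min of itself
            // and the running result at the previous position along `dim`.
            for d in 1..dim_size {
                let prev = out[base + (d - 1) * dim_stride];
                let cur = &mut out[base + d * dim_stride];
                if prev < *cur {
                    *cur = prev;
                }
            }
        }
    }
    out
}

fn main() {
    // Same values as the doc example: shape [2, 3], scanned along dim 1.
    let data = [3.0, 5.0, 2.0, 4.0, 1.0, 6.0];
    let result = cummin_reference(&data, &[2, 3], 1);
    assert_eq!(result, vec![3.0, 3.0, 2.0, 4.0, 1.0, 1.0]);
}
```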