Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
aef4168
Add ArcSin, ArcCos, ArcTan and ArcTan2 to float operations
relativityhd Sep 4, 2025
416fa3d
Add tests for acos, asin, atan and atan2
relativityhd Sep 4, 2025
99a5d96
Add sinh, cosh, asinh, acosh and atanh
relativityhd Sep 4, 2025
89b1642
Add degrees and radians function
relativityhd Sep 4, 2025
068d63e
Merge branch 'main' into feature/arc-trigonomic-functions
relativityhd Sep 4, 2025
e6d0410
Implement trigonometric functions in CPU backend Also try to handle half
relativityhd Sep 5, 2025
6735788
Add to_degrees and to_radians functions to cube-std
relativityhd Sep 6, 2025
d07c2b2
Register math to llvm transform of mlir
relativityhd Sep 6, 2025
e33ac87
Merge branch 'main' into feature/arc-trigonomic-functions
relativityhd Sep 6, 2025
7aa1913
Disable all ods_math dependant arithmetics for now
relativityhd Sep 7, 2025
52429f7
Add dummy implementations instead of todo! to satisfy compilation of
relativityhd Sep 7, 2025
831fa97
Merge branch 'main' into feature/arc-trigonomic-functions
relativityhd Sep 9, 2025
11cd568
Fix merge formatting
relativityhd Sep 9, 2025
7977e7f
Rename degrees and radians to to_degrees and to_radians to reflect rusts
relativityhd Sep 9, 2025
5f06df5
Add tan operation
relativityhd Sep 10, 2025
5e1f84d
Make runtime tests for unary epsilon dependent
relativityhd Sep 12, 2025
e55eb53
Add trigonometry module
relativityhd Sep 12, 2025
48302fe
Merge branch 'main' into feature/arc-trigonomic-functions
relativityhd Sep 13, 2025
93dda9b
Update spir-v calls to forked version
relativityhd Sep 13, 2025
1319a30
Remove unnecessary trig function in std
relativityhd Oct 7, 2025
8923d14
Merge branch 'main' into feature/arc-trigonomic-functions
relativityhd Oct 7, 2025
0af20fc
Fix for refactored launch and reenable std trig tests
relativityhd Oct 7, 2025
4eff5a8
remove dummy implementations for ods math arithmetics
relativityhd Oct 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions crates/cubecl-core/src/frontend/element/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,19 @@ pub trait Float:
+ Log1p
+ Cos
+ Sin
+ Tan
+ Tanh
+ Sinh
+ Cosh
+ ArcCos
+ ArcSin
+ ArcTan
+ ArcSinh
+ ArcCosh
+ ArcTanh
+ Degrees
+ Radians
+ ArcTan2
+ Powf
+ Powi<i32>
+ Sqrt
Expand Down
12 changes: 12 additions & 0 deletions crates/cubecl-core/src/frontend/element/float/typemap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,19 @@ impl<const POS: u8> Log for ElemExpand<POS> {}
impl<const POS: u8> Log1p for ElemExpand<POS> {}
impl<const POS: u8> Cos for ElemExpand<POS> {}
impl<const POS: u8> Sin for ElemExpand<POS> {}
impl<const POS: u8> Tan for ElemExpand<POS> {}
impl<const POS: u8> Tanh for ElemExpand<POS> {}
impl<const POS: u8> Sinh for ElemExpand<POS> {}
impl<const POS: u8> Cosh for ElemExpand<POS> {}
impl<const POS: u8> ArcCos for ElemExpand<POS> {}
impl<const POS: u8> ArcSin for ElemExpand<POS> {}
impl<const POS: u8> ArcTan for ElemExpand<POS> {}
impl<const POS: u8> ArcSinh for ElemExpand<POS> {}
impl<const POS: u8> ArcCosh for ElemExpand<POS> {}
impl<const POS: u8> ArcTanh for ElemExpand<POS> {}
impl<const POS: u8> Degrees for ElemExpand<POS> {}
impl<const POS: u8> Radians for ElemExpand<POS> {}
impl<const POS: u8> ArcTan2 for ElemExpand<POS> {}
impl<const POS: u8> Powf for ElemExpand<POS> {}
impl<const POS: u8, I: CubePrimitive> Powi<I> for ElemExpand<POS> {}
impl<const POS: u8> Sqrt for ElemExpand<POS> {}
Expand Down
11 changes: 11 additions & 0 deletions crates/cubecl-core/src/frontend/operation/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,17 @@ impl_binary_func!(
f32,
f64
);
impl_binary_func!(
ArcTan2,
atan2,
Arithmetic::ArcTan2,
f16,
bf16,
flex32,
tf32,
f32,
f64
);
impl_binary_func!(
Max,
max,
Expand Down
132 changes: 132 additions & 0 deletions crates/cubecl-core/src/frontend/operation/unary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,18 @@ impl_unary_func!(
f32,
f64
);
impl_unary_func!(
Tan,
tan,
__expand_tan,
Arithmetic::Tan,
f16,
bf16,
flex32,
tf32,
f32,
f64
);
impl_unary_func!(
Tanh,
tanh,
Expand All @@ -187,6 +199,126 @@ impl_unary_func!(
f32,
f64
);
impl_unary_func!(
Sinh,
sinh,
__expand_sinh,
Arithmetic::Sinh,
f16,
bf16,
flex32,
tf32,
f32,
f64
);
impl_unary_func!(
Cosh,
cosh,
__expand_cosh,
Arithmetic::Cosh,
f16,
bf16,
flex32,
tf32,
f32,
f64
);
impl_unary_func!(
ArcCos,
acos,
__expand_acos,
Arithmetic::ArcCos,
f16,
bf16,
flex32,
tf32,
f32,
f64
);
impl_unary_func!(
ArcSin,
asin,
__expand_asin,
Arithmetic::ArcSin,
f16,
bf16,
flex32,
tf32,
f32,
f64
);
impl_unary_func!(
ArcTan,
atan,
__expand_atan,
Arithmetic::ArcTan,
f16,
bf16,
flex32,
tf32,
f32,
f64
);
impl_unary_func!(
ArcSinh,
asinh,
__expand_asinh,
Arithmetic::ArcSinh,
f16,
bf16,
flex32,
tf32,
f32,
f64
);
impl_unary_func!(
ArcCosh,
acosh,
__expand_acosh,
Arithmetic::ArcCosh,
f16,
bf16,
flex32,
tf32,
f32,
f64
);
impl_unary_func!(
ArcTanh,
atanh,
__expand_atanh,
Arithmetic::ArcTanh,
f16,
bf16,
flex32,
tf32,
f32,
f64
);
impl_unary_func!(
Degrees,
to_degrees,
__expand_to_degrees,
Arithmetic::Degrees,
f16,
bf16,
flex32,
tf32,
f32,
f64
);
impl_unary_func!(
Radians,
to_radians,
__expand_to_radians,
Arithmetic::Radians,
f16,
bf16,
flex32,
tf32,
f32,
f64
);
impl_unary_func!(
Sqrt,
sqrt,
Expand Down
30 changes: 30 additions & 0 deletions crates/cubecl-core/src/runtime_tests/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,35 @@ test_binary_impl!(
]
);

test_binary_impl!(
test_atan2,
F,
F::atan2,
[
{
input_vectorization: 1,
out_vectorization: 1,
lhs: as_type![F: 0., 1., -1., 1., -1.],
rhs: as_type![F: 1., 0., 0., 1., -1.],
expected: as_type![F: 0., 1.57079632679, -1.57079632679, 0.78539816339, -2.35619449019]
},
{
input_vectorization: 2,
out_vectorization: 2,
lhs: as_type![F: 0., 1., -1., 1.],
rhs: as_type![F: 1., 0., 0., 1.],
expected: as_type![F: 0., 1.57079632679, -1.57079632679, 0.78539816339]
},
{
input_vectorization: 4,
out_vectorization: 4,
lhs: as_type![F: 0., 1., -1., 1.],
rhs: as_type![F: 1., 0., 0., 1.],
expected: as_type![F: 0., 1.57079632679, -1.57079632679, 0.78539816339]
}
]
);

#[cube(launch_unchecked)]
fn test_powi_kernel<F: Float>(
lhs: &Array<Line<F>>,
Expand Down Expand Up @@ -321,6 +350,7 @@ macro_rules! testgen_binary {
add_test!(test_dot);
add_test!(test_powf);
add_test!(test_powi);
add_test!(test_atan2);
}
};
}
Expand Down
Loading
Loading