diff --git a/crates/cubecl-core/src/frontend/element/float.rs b/crates/cubecl-core/src/frontend/element/float.rs index ca483c703..43d80984b 100644 --- a/crates/cubecl-core/src/frontend/element/float.rs +++ b/crates/cubecl-core/src/frontend/element/float.rs @@ -25,7 +25,19 @@ pub trait Float: + Log1p + Cos + Sin + + Tan + Tanh + + Sinh + + Cosh + + ArcCos + + ArcSin + + ArcTan + + ArcSinh + + ArcCosh + + ArcTanh + + Degrees + + Radians + + ArcTan2 + Powf + Powi + Sqrt diff --git a/crates/cubecl-core/src/frontend/element/float/typemap.rs b/crates/cubecl-core/src/frontend/element/float/typemap.rs index f027346ff..265c4422e 100644 --- a/crates/cubecl-core/src/frontend/element/float/typemap.rs +++ b/crates/cubecl-core/src/frontend/element/float/typemap.rs @@ -243,7 +243,19 @@ impl Log for ElemExpand {} impl Log1p for ElemExpand {} impl Cos for ElemExpand {} impl Sin for ElemExpand {} +impl Tan for ElemExpand {} impl Tanh for ElemExpand {} +impl Sinh for ElemExpand {} +impl Cosh for ElemExpand {} +impl ArcCos for ElemExpand {} +impl ArcSin for ElemExpand {} +impl ArcTan for ElemExpand {} +impl ArcSinh for ElemExpand {} +impl ArcCosh for ElemExpand {} +impl ArcTanh for ElemExpand {} +impl Degrees for ElemExpand {} +impl Radians for ElemExpand {} +impl ArcTan2 for ElemExpand {} impl Powf for ElemExpand {} impl Powi for ElemExpand {} impl Sqrt for ElemExpand {} diff --git a/crates/cubecl-core/src/frontend/operation/binary.rs b/crates/cubecl-core/src/frontend/operation/binary.rs index 336ba4ef1..53be63601 100644 --- a/crates/cubecl-core/src/frontend/operation/binary.rs +++ b/crates/cubecl-core/src/frontend/operation/binary.rs @@ -252,6 +252,17 @@ impl_binary_func!( f32, f64 ); +impl_binary_func!( + ArcTan2, + atan2, + Arithmetic::ArcTan2, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); impl_binary_func!( Max, max, diff --git a/crates/cubecl-core/src/frontend/operation/unary.rs b/crates/cubecl-core/src/frontend/operation/unary.rs index 382e64698..221226643 100644 --- 
a/crates/cubecl-core/src/frontend/operation/unary.rs +++ b/crates/cubecl-core/src/frontend/operation/unary.rs @@ -175,6 +175,18 @@ impl_unary_func!( f32, f64 ); +impl_unary_func!( + Tan, + tan, + __expand_tan, + Arithmetic::Tan, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); impl_unary_func!( Tanh, tanh, @@ -187,6 +199,126 @@ impl_unary_func!( f32, f64 ); +impl_unary_func!( + Sinh, + sinh, + __expand_sinh, + Arithmetic::Sinh, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + Cosh, + cosh, + __expand_cosh, + Arithmetic::Cosh, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + ArcCos, + acos, + __expand_acos, + Arithmetic::ArcCos, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + ArcSin, + asin, + __expand_asin, + Arithmetic::ArcSin, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + ArcTan, + atan, + __expand_atan, + Arithmetic::ArcTan, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + ArcSinh, + asinh, + __expand_asinh, + Arithmetic::ArcSinh, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + ArcCosh, + acosh, + __expand_acosh, + Arithmetic::ArcCosh, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + ArcTanh, + atanh, + __expand_atanh, + Arithmetic::ArcTanh, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + Degrees, + to_degrees, + __expand_to_degrees, + Arithmetic::Degrees, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); +impl_unary_func!( + Radians, + to_radians, + __expand_to_radians, + Arithmetic::Radians, + f16, + bf16, + flex32, + tf32, + f32, + f64 +); impl_unary_func!( Sqrt, sqrt, diff --git a/crates/cubecl-core/src/runtime_tests/binary.rs b/crates/cubecl-core/src/runtime_tests/binary.rs index c27fcce3c..df88cb787 100644 --- a/crates/cubecl-core/src/runtime_tests/binary.rs +++ b/crates/cubecl-core/src/runtime_tests/binary.rs @@ -150,6 +150,35 @@ test_binary_impl!( ] ); +test_binary_impl!( + 
test_atan2, + F, + F::atan2, + [ + { + input_vectorization: 1, + out_vectorization: 1, + lhs: as_type![F: 0., 1., -1., 1., -1.], + rhs: as_type![F: 1., 0., 0., 1., -1.], + expected: as_type![F: 0., 1.57079632679, -1.57079632679, 0.78539816339, -2.35619449019] + }, + { + input_vectorization: 2, + out_vectorization: 2, + lhs: as_type![F: 0., 1., -1., 1.], + rhs: as_type![F: 1., 0., 0., 1.], + expected: as_type![F: 0., 1.57079632679, -1.57079632679, 0.78539816339] + }, + { + input_vectorization: 4, + out_vectorization: 4, + lhs: as_type![F: 0., 1., -1., 1.], + rhs: as_type![F: 1., 0., 0., 1.], + expected: as_type![F: 0., 1.57079632679, -1.57079632679, 0.78539816339] + } + ] +); + #[cube(launch_unchecked)] fn test_powi_kernel( lhs: &Array>, @@ -321,6 +350,7 @@ macro_rules! testgen_binary { add_test!(test_dot); add_test!(test_powf); add_test!(test_powi); + add_test!(test_atan2); } }; } diff --git a/crates/cubecl-core/src/runtime_tests/unary.rs b/crates/cubecl-core/src/runtime_tests/unary.rs index 714ee5ddd..1c9d75df0 100644 --- a/crates/cubecl-core/src/runtime_tests/unary.rs +++ b/crates/cubecl-core/src/runtime_tests/unary.rs @@ -1,3 +1,4 @@ +use std::f32::consts::PI; use std::fmt::Display; use crate::{self as cubecl, as_type}; @@ -46,6 +47,24 @@ macro_rules! 
test_unary_impl { input: $input:expr, expected: $expected:expr }),*]) => { + test_unary_impl!($test_name, $float_type, $unary_func, [$({ + input_vectorization: $input_vectorization, + out_vectorization: $out_vectorization, + input: $input, + expected: $expected + }),*], 0.02); + }; + ( + $test_name:ident, + $float_type:ident, + $unary_func:expr, + [$({ + input_vectorization: $input_vectorization:expr, + out_vectorization: $out_vectorization:expr, + input: $input:expr, + expected: $expected:expr + }),*], + $epsilon:expr) => { pub fn $test_name(client: ComputeClient) { #[cube(launch_unchecked)] fn test_function<$float_type: Float>(input: &Array<$float_type>, output: &mut Array<$float_type>) { @@ -70,7 +89,7 @@ macro_rules! test_unary_impl { ) }; - assert_equals_approx::(&client, output_handle, $expected, $float_type::new(0.02)); + assert_equals_approx::(&client, output_handle, $expected, $float_type::new($epsilon)); } )* } @@ -214,6 +233,279 @@ macro_rules! test_unary_impl_int_fixed { }; } +test_unary_impl!(test_sin, F, F::sin, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 1.57079632679, 3.14159265359, -1.57079632679], + expected: as_type![F: 0., 1., 0., -1.] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1.57079632679, 3.14159265359, -1.57079632679], + expected: as_type![F: 0., 1., 0., -1.] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1.57079632679, 3.14159265359, -1.57079632679], + expected: as_type![F: 0., 1., 0., -1.] + } +]); + +test_unary_impl!(test_cos, F, F::cos, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 1.57079632679, 3.14159265359, -1.57079632679], + expected: as_type![F: 1., 0., -1., 0.] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1.57079632679, 3.14159265359, -1.57079632679], + expected: as_type![F: 1., 0., -1., 0.] 
+ }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1.57079632679, 3.14159265359, -1.57079632679], + expected: as_type![F: 1., 0., -1., 0.] + } +]); + +test_unary_impl!(test_tan, F, F::tan, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 0.78539816339, 1.04719755119, -0.78539816339], + expected: as_type![F: 0., 1., 1.73205080757, -1.] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 0.78539816339, 1.04719755119, -0.78539816339], + expected: as_type![F: 0., 1., 1.73205080757, -1.] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 0.78539816339, 1.04719755119, -0.78539816339], + expected: as_type![F: 0., 1., 1.73205080757, -1.] + } +]); + +test_unary_impl!(test_asin, F, F::asin, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 0.5, 1., -0.5, -1.], + expected: as_type![F: 0., 0.52359877559, 1.57079632679, -0.52359877559, -1.57079632679] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 0.5, 1., -0.5], + expected: as_type![F: 0., 0.52359877559, 1.57079632679, -0.52359877559] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 0.5, 1., -0.5], + expected: as_type![F: 0., 0.52359877559, 1.57079632679, -0.52359877559] + } +]); + +test_unary_impl!(test_acos, F, F::acos, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 1., 0.5, 0., -0.5, -1.], + expected: as_type![F: 0., 1.04719755119, 1.57079632679, 2.09439510239, 3.14159265359] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 1., 0.5, 0., -0.5], + expected: as_type![F: 0., 1.04719755119, 1.57079632679, 2.09439510239] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 1., 0.5, 0., -0.5], + expected: as_type![F: 0., 1.04719755119, 1.57079632679, 2.09439510239] + } +]); + +test_unary_impl!(test_atan, F, F::atan, 
[ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 1., -1., 1000., -1000.], + expected: as_type![F: 0., 0.78539816339, -0.78539816339, 1.56979632472, -1.56979632472] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1., -1., 1000.], + expected: as_type![F: 0., 0.78539816339, -0.78539816339, 1.56979632472] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1., -1., 1000.], + expected: as_type![F: 0., 0.78539816339, -0.78539816339, 1.56979632472] + } +]); + +test_unary_impl!(test_sinh, F, F::sinh, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 1., -1., 2., -2.], + expected: as_type![F: 0., 1.1752011936, -1.1752011936, 3.6268604078, -3.6268604078] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 0., 1.1752011936, -1.1752011936, 3.6268604078] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 0., 1.1752011936, -1.1752011936, 3.6268604078] + } +]); + +test_unary_impl!(test_cosh, F, F::cosh, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 1., -1., 2., -2.], + expected: as_type![F: 1., 1.5430806348, 1.5430806348, 3.7621956911, 3.7621956911] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 1., 1.5430806348, 1.5430806348, 3.7621956911] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 1., 1.5430806348, 1.5430806348, 3.7621956911] + } +]); + +test_unary_impl!(test_asinh, F, F::asinh, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 1., -1., 2., -2.], + expected: as_type![F: 0., 0.88137358702, -0.88137358702, 1.44363547517, -1.44363547517] + }, + { + input_vectorization: 2, + out_vectorization: 2, + 
input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 0., 0.88137358702, -0.88137358702, 1.44363547517] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 1., -1., 2.], + expected: as_type![F: 0., 0.88137358702, -0.88137358702, 1.44363547517] + } +]); + +test_unary_impl!(test_acosh, F, F::acosh, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 1., 2., 3., 10.], + expected: as_type![F: 0., 1.31695789692, 1.76274717404, 2.99322284612] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 1., 2., 3., 10.], + expected: as_type![F: 0., 1.31695789692, 1.76274717404, 2.99322284612] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 1., 2., 3., 10.], + expected: as_type![F: 0., 1.31695789692, 1.76274717404, 2.99322284612] + } +]); + +test_unary_impl!(test_atanh, F, F::atanh, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 0.5, -0.5, 0.9, -0.9], + expected: as_type![F: 0., 0.54930614433, -0.54930614433, 1.47221948958, -1.47221948958] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 0.5, -0.5, 0.9], + expected: as_type![F: 0., 0.54930614433, -0.54930614433, 1.47221948958] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 0.5, -0.5, 0.9], + expected: as_type![F: 0., 0.54930614433, -0.54930614433, 1.47221948958] + } +]); + +test_unary_impl!(test_degrees, F, F::to_degrees, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., PI / 2., PI, PI * 2., -PI / 2., -PI, -PI * 2.], + expected: as_type![F: 0., 90., 180., 360., -90., -180., -360.] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., PI / 2., PI, -PI / 2.], + expected: as_type![F: 0., 90., 180., -90.] 
+ }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., PI / 2., PI, -PI / 2.], + expected: as_type![F: 0., 90., 180., -90.] + } +], 0.3); + +test_unary_impl!(test_radians, F, F::to_radians, [ + { + input_vectorization: 1, + out_vectorization: 1, + input: as_type![F: 0., 90., 180., 360., -90., -180., -360.], + expected: as_type![F: 0., PI / 2., PI, PI * 2., -PI / 2., -PI, -PI * 2.] + }, + { + input_vectorization: 2, + out_vectorization: 2, + input: as_type![F: 0., 90., 180., -90.], + expected: as_type![F: 0., PI / 2., PI, -PI / 2.] + }, + { + input_vectorization: 4, + out_vectorization: 4, + input: as_type![F: 0., 90., 180., -90.], + expected: as_type![F: 0., PI / 2., PI, -PI / 2.] + } +]); + test_unary_impl!( test_magnitude, F, @@ -476,6 +768,19 @@ macro_rules! testgen_unary { }; } + add_test!(test_sin); + add_test!(test_cos); + add_test!(test_tan); + add_test!(test_sinh); + add_test!(test_cosh); + add_test!(test_asin); + add_test!(test_acos); + add_test!(test_atan); + add_test!(test_asinh); + add_test!(test_acosh); + add_test!(test_atanh); + add_test!(test_degrees); + add_test!(test_radians); add_test!(test_normalize); add_test!(test_magnitude); add_test!(test_abs); diff --git a/crates/cubecl-cpp/src/metal/dialect.rs b/crates/cubecl-cpp/src/metal/dialect.rs index 0f226f9e2..9ba1a9316 100644 --- a/crates/cubecl-cpp/src/metal/dialect.rs +++ b/crates/cubecl-cpp/src/metal/dialect.rs @@ -825,6 +825,10 @@ impl DialectInstructions for MslDialect { write!(f, "safe_tanh_scalar({input})") } + fn compile_instruction_atan2(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "atan2") + } + // unary fn compile_instruction_find_first_set>( f: &mut std::fmt::Formatter<'_>, diff --git a/crates/cubecl-cpp/src/shared/base.rs b/crates/cubecl-cpp/src/shared/base.rs index cbd7a759f..7b531c9c8 100644 --- a/crates/cubecl-cpp/src/shared/base.rs +++ b/crates/cubecl-cpp/src/shared/base.rs @@ -953,11 +953,69 @@ impl CppCompiler { gpu::Arithmetic::Sin(op) => { 
instructions.push(Instruction::Sin(self.compile_unary(op, out))) } + gpu::Arithmetic::Tan(op) => { + instructions.push(Instruction::Tan(self.compile_unary(op, out))) + } gpu::Arithmetic::Tanh(op) => { let instruction = Instruction::Tanh(self.compile_unary(op, out)); D::register_instruction_extension(&mut self.extensions, &instruction); instructions.push(instruction) } + gpu::Arithmetic::Sinh(op) => { + let instruction = Instruction::Sinh(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::Cosh(op) => { + let instruction = Instruction::Cosh(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::ArcCos(op) => { + let instruction = Instruction::ArcCos(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::ArcSin(op) => { + let instruction = Instruction::ArcSin(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::ArcTan(op) => { + let instruction = Instruction::ArcTan(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::ArcSinh(op) => { + let instruction = Instruction::ArcSinh(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::ArcCosh(op) => { + let instruction = Instruction::ArcCosh(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::ArcTanh(op) => { + let instruction = Instruction::ArcTanh(self.compile_unary(op, out)); + 
D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::Degrees(op) => { + let instruction = Instruction::Degrees(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::Radians(op) => { + let instruction = Instruction::Radians(self.compile_unary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } + gpu::Arithmetic::ArcTan2(op) => { + let instruction = Instruction::ArcTan2(self.compile_binary(op, out)); + D::register_instruction_extension(&mut self.extensions, &instruction); + instructions.push(instruction) + } gpu::Arithmetic::Powf(op) => { instructions.push(Instruction::Powf(self.compile_binary(op, out))) } diff --git a/crates/cubecl-cpp/src/shared/binary.rs b/crates/cubecl-cpp/src/shared/binary.rs index 98a64f2c0..3f7c28f3b 100644 --- a/crates/cubecl-cpp/src/shared/binary.rs +++ b/crates/cubecl-cpp/src/shared/binary.rs @@ -297,6 +297,54 @@ impl Binary for Powi { } } +pub struct ArcTan2; + +impl Binary for ArcTan2 { + // ArcTan2 doesn't support half and no half equivalent exists + fn format_scalar( + f: &mut std::fmt::Formatter<'_>, + lhs: Lhs, + rhs: Rhs, + item: Item, + ) -> std::fmt::Result { + let elem = item.elem; + match elem { + Elem::F16 | Elem::F16x2 | Elem::BF16 | Elem::BF16x2 => { + write!(f, "{elem}(")?; + D::compile_instruction_atan2(f)?; + write!(f, "(float({lhs}), float({rhs})))") + } + _ => { + D::compile_instruction_atan2(f)?; + write!(f, "({lhs}, {rhs})") + } + } + } + + // ArcTan2 doesn't support half and no half equivalent exists + fn unroll_vec( + f: &mut Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + ) -> core::fmt::Result { + let item_out = out.item(); + let index = out.item().vectorization; + + let out = out.fmt_left(); + writeln!(f, "{out} = {item_out}{{")?; + for i in 
0..index { + let lhsi = lhs.index(i); + let rhsi = rhs.index(i); + + Self::format_scalar(f, lhsi, rhsi, item_out)?; + f.write_str(", ")?; + } + + f.write_str("};\n") + } +} + pub struct Max; impl Binary for Max { diff --git a/crates/cubecl-cpp/src/shared/dialect.rs b/crates/cubecl-cpp/src/shared/dialect.rs index 98a736752..ad2b3dace 100644 --- a/crates/cubecl-cpp/src/shared/dialect.rs +++ b/crates/cubecl-cpp/src/shared/dialect.rs @@ -633,6 +633,10 @@ pub trait DialectInstructions { } } + fn compile_instruction_atan2(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "atan2") + } + fn compile_instruction_half_function_name_prefix() -> &'static str { "h" } diff --git a/crates/cubecl-cpp/src/shared/instruction.rs b/crates/cubecl-cpp/src/shared/instruction.rs index 70d450757..78bf5be77 100644 --- a/crates/cubecl-cpp/src/shared/instruction.rs +++ b/crates/cubecl-cpp/src/shared/instruction.rs @@ -164,7 +164,19 @@ pub enum Instruction { Log1p(UnaryInstruction), Cos(UnaryInstruction), Sin(UnaryInstruction), + Tan(UnaryInstruction), Tanh(UnaryInstruction), + Sinh(UnaryInstruction), + Cosh(UnaryInstruction), + ArcCos(UnaryInstruction), + ArcSin(UnaryInstruction), + ArcTan(UnaryInstruction), + ArcSinh(UnaryInstruction), + ArcCosh(UnaryInstruction), + ArcTanh(UnaryInstruction), + Degrees(UnaryInstruction), + Radians(UnaryInstruction), + ArcTan2(BinaryInstruction), Powf(BinaryInstruction), Powi(BinaryInstruction), Sqrt(UnaryInstruction), @@ -515,7 +527,19 @@ for ({i_ty} {i} = {start}; {i} {cmp} {end}; {increment}) {{ Instruction::Log1p(it) => Log1p::format(f, &it.input, &it.out), Instruction::Cos(it) => Cos::format(f, &it.input, &it.out), Instruction::Sin(it) => Sin::format(f, &it.input, &it.out), + Instruction::Tan(it) => Tan::format(f, &it.input, &it.out), Instruction::Tanh(it) => Tanh::format(f, &it.input, &it.out), + Instruction::Sinh(it) => Sinh::format(f, &it.input, &it.out), + Instruction::Cosh(it) => Cosh::format(f, &it.input, &it.out), + 
Instruction::ArcCos(it) => ArcCos::format(f, &it.input, &it.out), + Instruction::ArcSin(it) => ArcSin::format(f, &it.input, &it.out), + Instruction::ArcTan(it) => ArcTan::format(f, &it.input, &it.out), + Instruction::ArcSinh(it) => ArcSinh::format(f, &it.input, &it.out), + Instruction::ArcCosh(it) => ArcCosh::format(f, &it.input, &it.out), + Instruction::ArcTanh(it) => ArcTanh::format(f, &it.input, &it.out), + Instruction::Degrees(it) => Degrees::format(f, &it.input, &it.out), + Instruction::Radians(it) => Radians::format(f, &it.input, &it.out), + Instruction::ArcTan2(it) => ArcTan2::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Powf(it) => Powf::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Powi(it) => Powi::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Sqrt(it) => Sqrt::format(f, &it.input, &it.out), diff --git a/crates/cubecl-cpp/src/shared/unary.rs b/crates/cubecl-cpp/src/shared/unary.rs index b23fb6490..e80d79ce0 100644 --- a/crates/cubecl-cpp/src/shared/unary.rs +++ b/crates/cubecl-cpp/src/shared/unary.rs @@ -151,6 +151,15 @@ macro_rules! 
function { function!(Log, "log"); function!(Cos, "cos"); function!(Sin, "sin"); +function!(Tan, "tan", false); +function!(Sinh, "sinh", false); +function!(Cosh, "cosh", false); +function!(ArcCos, "acos", false); +function!(ArcSin, "asin", false); +function!(ArcTan, "atan", false); +function!(ArcSinh, "asinh", false); +function!(ArcCosh, "acosh", false); +function!(ArcTanh, "atanh", false); function!(Sqrt, "sqrt"); function!(Exp, "exp"); function!(Ceil, "ceil"); @@ -192,6 +201,38 @@ impl Unary for Tanh { } } +pub struct Degrees; + +impl Unary for Degrees { + fn format_scalar>( + f: &mut std::fmt::Formatter<'_>, + input: Input, + _out_elem: Elem, + ) -> std::fmt::Result { + write!(f, "{input}*57.29577951308232f") + } + + fn can_optimize() -> bool { + false + } +} + +pub struct Radians; + +impl Unary for Radians { + fn format_scalar>( + f: &mut std::fmt::Formatter<'_>, + input: Input, + _out_elem: Elem, + ) -> std::fmt::Result { + write!(f, "{input}*0.017453292519943295f") + } + + fn can_optimize() -> bool { + false + } +} + pub fn zero_extend(input: impl Component) -> String { match input.elem() { Elem::I8 => format!("{}({}({input}))", Elem::::U32, Elem::::U8), diff --git a/crates/cubecl-cpu/src/compiler/module.rs b/crates/cubecl-cpu/src/compiler/module.rs index 587a9ea7a..c07b9a3ad 100644 --- a/crates/cubecl-cpu/src/compiler/module.rs +++ b/crates/cubecl-cpu/src/compiler/module.rs @@ -73,6 +73,7 @@ impl<'a> Module<'a> { pass_manager.add_pass(pass::conversion::create_vector_to_llvm()); pass_manager.add_pass(pass::conversion::create_arith_to_llvm()); pass_manager.add_pass(pass::conversion::create_func_to_llvm()); + pass_manager.add_pass(pass::conversion::create_math_to_llvm()); pass_manager.add_pass(pass::transform::create_inliner()); pass_manager.add_pass(pass::conversion::create_reconcile_unrealized_casts()); pass_manager.add_pass(pass::transform::create_sccp()); diff --git a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs 
b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs index 647bed3ca..5835c7809 100644 --- a/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs +++ b/crates/cubecl-cpu/src/compiler/visitor/operation/arithmetic.rs @@ -3,7 +3,7 @@ use tracel_llvm::mlir_rs::{ dialect::{ arith::{self}, llvm, - ods::{llvm as llvm_ods, vector}, + ods::{llvm as llvm_ods, math as math_ods, vector}, }, ir::Attribute, }; @@ -28,6 +28,70 @@ impl<'a> Visitor<'a> { let result = self.append_operation_with_result(operation); self.insert_variable(out, result); } + Arithmetic::ArcCos(acos) => { + let value = self.get_variable(acos.input); + let result = self.append_operation_with_result(math_ods::acos( + self.context, + value, + self.location, + )); + self.insert_variable(out, result); + } + Arithmetic::ArcSin(asin) => { + let value = self.get_variable(asin.input); + let result = self.append_operation_with_result(math_ods::asin( + self.context, + value, + self.location, + )); + self.insert_variable(out, result); + } + Arithmetic::ArcTan(atan) => { + let value = self.get_variable(atan.input); + let result = self.append_operation_with_result(math_ods::atan( + self.context, + value, + self.location, + )); + self.insert_variable(out, result); + } + Arithmetic::ArcSinh(asinh) => { + let value = self.get_variable(asinh.input); + let result = self.append_operation_with_result(math_ods::asinh( + self.context, + value, + self.location, + )); + self.insert_variable(out, result); + } + Arithmetic::ArcCosh(acosh) => { + let value = self.get_variable(acosh.input); + let result = self.append_operation_with_result(math_ods::acosh( + self.context, + value, + self.location, + )); + self.insert_variable(out, result); + } + Arithmetic::ArcTanh(atanh) => { + let value = self.get_variable(atanh.input); + let result = self.append_operation_with_result(math_ods::atanh( + self.context, + value, + self.location, + )); + self.insert_variable(out, result); + } + Arithmetic::ArcTan2(atan2) => { + let 
(lhs, rhs) = self.get_binary_op_variable(atan2.lhs, atan2.rhs); + let result = self.append_operation_with_result(math_ods::atan_2( + self.context, + lhs, + rhs, + self.location, + )); + self.insert_variable(out, result); + } Arithmetic::SaturatingAdd(_) => { unreachable!("Should be removed by preprocessor") } @@ -97,6 +161,23 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result); } + Arithmetic::Cosh(cosh) => { + let value = self.get_variable(cosh.input); + let result = self.append_operation_with_result(llvm_ods::intr_cosh( + self.context, + value, + self.location, + )); + self.insert_variable(out, result); + } + Arithmetic::Degrees(degrees) => { + let value = self.get_variable(degrees.input); + // 180 / pi + let f = self.create_float_constant_from_item(degrees.input.ty, 57.29577951308232); + let result = + self.append_operation_with_result(arith::mulf(value, f, self.location)); + self.insert_variable(out, result); + } Arithmetic::Div(div) => { let (lhs, rhs) = self.get_binary_op_variable(div.lhs, div.rhs); let operation = if div.lhs.storage_type().is_signed_int() { @@ -347,6 +428,15 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, result); } + Arithmetic::Radians(radians) => { + let value = self.get_variable(radians.input); + // pi / 180 + let f = + self.create_float_constant_from_item(radians.input.ty, 0.017453292519943295); + let result = + self.append_operation_with_result(arith::mulf(value, f, self.location)); + self.insert_variable(out, result); + } Arithmetic::Recip(recip) => { let value = self.get_variable(recip.input); let one = self.create_float_constant_from_item(recip.input.ty, 1.0); @@ -398,6 +488,15 @@ impl<'a> Visitor<'a> { )); self.insert_variable(out, output); } + Arithmetic::Sinh(sinh) => { + let value = self.get_variable(sinh.input); + let result = self.append_operation_with_result(llvm_ods::intr_sinh( + self.context, + value, + self.location, + )); + self.insert_variable(out, result); + } Arithmetic::Sqrt(sqrt) => { let input = 
self.get_variable(sqrt.input); let output = self.append_operation_with_result(llvm_ods::intr_sqrt( @@ -417,6 +516,15 @@ impl<'a> Visitor<'a> { let result = self.append_operation_with_result(operation); self.insert_variable(out, result); } + Arithmetic::Tan(tan) => { + let value = self.get_variable(tan.input); + let result = self.append_operation_with_result(math_ods::tan( + self.context, + value, + self.location, + )); + self.insert_variable(out, result); + } Arithmetic::SaturatingSub(_) => { unreachable!("Should be removed by preprocessor") } diff --git a/crates/cubecl-ir/src/arithmetic.rs b/crates/cubecl-ir/src/arithmetic.rs index 67f0abfc9..9002f1d70 100644 --- a/crates/cubecl-ir/src/arithmetic.rs +++ b/crates/cubecl-ir/src/arithmetic.rs @@ -25,7 +25,19 @@ pub enum Arithmetic { Log1p(UnaryOperator), Cos(UnaryOperator), Sin(UnaryOperator), + Tan(UnaryOperator), Tanh(UnaryOperator), + Sinh(UnaryOperator), + Cosh(UnaryOperator), + ArcCos(UnaryOperator), + ArcSin(UnaryOperator), + ArcTan(UnaryOperator), + ArcSinh(UnaryOperator), + ArcCosh(UnaryOperator), + ArcTanh(UnaryOperator), + Degrees(UnaryOperator), + Radians(UnaryOperator), + ArcTan2(BinaryOperator), Powf(BinaryOperator), Powi(BinaryOperator), Sqrt(UnaryOperator), @@ -66,8 +78,20 @@ impl Display for Arithmetic { Arithmetic::Log1p(op) => write!(f, "{}.log_1p()", op.input), Arithmetic::Cos(op) => write!(f, "{}.cos()", op.input), Arithmetic::Sin(op) => write!(f, "{}.sin()", op.input), + Arithmetic::Tan(op) => write!(f, "{}.tan()", op.input), Arithmetic::Tanh(op) => write!(f, "{}.tanh()", op.input), - Arithmetic::Powf(op) => write!(f, "{}.powf({})", op.lhs, op.rhs), + Arithmetic::Sinh(op) => write!(f, "{}.sinh()", op.input), + Arithmetic::Cosh(op) => write!(f, "{}.cosh()", op.input), + Arithmetic::ArcCos(op) => write!(f, "{}.acos()", op.input), + Arithmetic::ArcSin(op) => write!(f, "{}.asin()", op.input), + Arithmetic::ArcTan(op) => write!(f, "{}.atan()", op.input), + Arithmetic::ArcSinh(op) => write!(f, 
"{}.asinh()", op.input), + Arithmetic::ArcCosh(op) => write!(f, "{}.acosh()", op.input), + Arithmetic::ArcTanh(op) => write!(f, "{}.atanh()", op.input), + Arithmetic::Degrees(op) => write!(f, "{}.degrees()", op.input), + Arithmetic::Radians(op) => write!(f, "{}.radians()", op.input), + Arithmetic::ArcTan2(op) => write!(f, "{}.atan2({})", op.lhs, op.rhs), + Arithmetic::Powf(op) => write!(f, "{}.pow({})", op.lhs, op.rhs), Arithmetic::Powi(op) => write!(f, "{}.powi({})", op.lhs, op.rhs), Arithmetic::Sqrt(op) => write!(f, "{}.sqrt()", op.input), Arithmetic::Round(op) => write!(f, "{}.round()", op.input), diff --git a/crates/cubecl-ir/src/processing.rs b/crates/cubecl-ir/src/processing.rs index 8f2968c29..fdb3c2681 100644 --- a/crates/cubecl-ir/src/processing.rs +++ b/crates/cubecl-ir/src/processing.rs @@ -109,9 +109,46 @@ impl ScopeProcessing { Arithmetic::Sin(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } + Arithmetic::Tan(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } Arithmetic::Tanh(op) => { sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); } + Arithmetic::Sinh(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::Cosh(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::ArcCos(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::ArcSin(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::ArcTan(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::ArcSinh(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::ArcCosh(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::ArcTanh(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + 
Arithmetic::Degrees(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::Radians(op) => { + sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap()); + } + Arithmetic::ArcTan2(op) => { + sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap()); + sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap()); + } Arithmetic::Powf(op) => { sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap()); sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap()); diff --git a/crates/cubecl-opt/src/instructions.rs b/crates/cubecl-opt/src/instructions.rs index 38dfe129d..0ed147ea0 100644 --- a/crates/cubecl-opt/src/instructions.rs +++ b/crates/cubecl-opt/src/instructions.rs @@ -84,7 +84,8 @@ impl Optimizer { | Arithmetic::Min(binary_operator) | Arithmetic::Remainder(binary_operator) | Arithmetic::Dot(binary_operator) - | Arithmetic::MulHi(binary_operator) => self.visit_binop(binary_operator, visit_read), + | Arithmetic::MulHi(binary_operator) + | Arithmetic::ArcTan2(binary_operator) => self.visit_binop(binary_operator, visit_read), Arithmetic::Abs(unary_operator) | Arithmetic::Exp(unary_operator) @@ -92,7 +93,18 @@ impl Optimizer { | Arithmetic::Log1p(unary_operator) | Arithmetic::Cos(unary_operator) | Arithmetic::Sin(unary_operator) + | Arithmetic::Tan(unary_operator) | Arithmetic::Tanh(unary_operator) + | Arithmetic::Sinh(unary_operator) + | Arithmetic::Cosh(unary_operator) + | Arithmetic::ArcCos(unary_operator) + | Arithmetic::ArcSin(unary_operator) + | Arithmetic::ArcTan(unary_operator) + | Arithmetic::ArcSinh(unary_operator) + | Arithmetic::ArcCosh(unary_operator) + | Arithmetic::ArcTanh(unary_operator) + | Arithmetic::Degrees(unary_operator) + | Arithmetic::Radians(unary_operator) | Arithmetic::Sqrt(unary_operator) | Arithmetic::Round(unary_operator) | Arithmetic::Floor(unary_operator) diff --git a/crates/cubecl-opt/src/passes/constant_prop.rs 
b/crates/cubecl-opt/src/passes/constant_prop.rs index 9c8fba0a6..d22cfe29f 100644 --- a/crates/cubecl-opt/src/passes/constant_prop.rs +++ b/crates/cubecl-opt/src/passes/constant_prop.rs @@ -416,7 +416,32 @@ fn try_const_eval_arithmetic(op: &mut Arithmetic) -> Option Arithmetic::Log1p(op) => const_eval_float!(op.input; num::Float::ln_1p), Arithmetic::Cos(op) => const_eval_float!(op.input; num::Float::cos), Arithmetic::Sin(op) => const_eval_float!(op.input; num::Float::sin), + Arithmetic::Tan(op) => const_eval_float!(op.input; num::Float::tan), Arithmetic::Tanh(op) => const_eval_float!(op.input; num::Float::tanh), + Arithmetic::Sinh(op) => const_eval_float!(op.input; num::Float::sinh), + Arithmetic::Cosh(op) => const_eval_float!(op.input; num::Float::cosh), + Arithmetic::ArcCos(op) => const_eval_float!(op.input; num::Float::acos), + Arithmetic::ArcSin(op) => const_eval_float!(op.input; num::Float::asin), + Arithmetic::ArcTan(op) => const_eval_float!(op.input; num::Float::atan), + Arithmetic::ArcSinh(op) => const_eval_float!(op.input; num::Float::asinh), + Arithmetic::ArcCosh(op) => const_eval_float!(op.input; num::Float::acosh), + Arithmetic::ArcTanh(op) => const_eval_float!(op.input; num::Float::atanh), + Arithmetic::Degrees(op) => const_eval_float!(op.input; num::Float::to_degrees), + Arithmetic::Radians(op) => const_eval_float!(op.input; num::Float::to_radians), + Arithmetic::ArcTan2(op) => { + use ConstantScalarValue::*; + if let (Some(lhs), Some(rhs)) = (op.lhs.as_const(), op.rhs.as_const()) { + let rhs = rhs.cast_to(lhs.storage_type()); + Some(match (lhs, rhs) { + (Float(lhs, kind), Float(rhs, _)) => { + ConstantScalarValue::Float(lhs.atan2(rhs), kind) + } + _ => unreachable!(), + }) + } else { + None + } + } Arithmetic::Sqrt(op) => const_eval_float!(op.input; num::Float::sqrt), Arithmetic::Round(op) => const_eval_float!(op.input; num::Float::round), Arithmetic::Floor(op) => const_eval_float!(op.input; num::Float::floor), diff --git 
a/crates/cubecl-spirv/src/arithmetic.rs b/crates/cubecl-spirv/src/arithmetic.rs index 642351e6d..f96de0c07 100644 --- a/crates/cubecl-spirv/src/arithmetic.rs +++ b/crates/cubecl-spirv/src/arithmetic.rs @@ -302,6 +302,14 @@ impl SpirvCompiler { } }) } + Arithmetic::Tan(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::tan(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } Arithmetic::Tanh(op) => { self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { T::tanh(b, ty, input, out); @@ -310,6 +318,94 @@ impl SpirvCompiler { } }) } + Arithmetic::Sinh(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::sinh(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::Cosh(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::cosh(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::ArcCos(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::acos(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::ArcSin(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::asin(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::ArcTan(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::atan(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::ArcSinh(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, 
out_ty, ty, input, out| { + T::asinh(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::ArcCosh(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::acosh(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::ArcTanh(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::atanh(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::Degrees(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::degrees(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::Radians(op) => { + self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| { + T::radians(b, ty, input, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } + Arithmetic::ArcTan2(op) => { + self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| { + T::atan2(b, ty, lhs, rhs, out); + if matches!(out_ty.elem(), Elem::Relaxed) { + b.decorate(out, Decoration::RelaxedPrecision, []); + } + }) + } // No powi for Vulkan, just auto-cast to float Arithmetic::Powf(op) | Arithmetic::Powi(op) => { self.compile_binary_op(op, out, uniform, |b, out_ty, ty, lhs, rhs, out| { diff --git a/crates/cubecl-spirv/src/extensions.rs b/crates/cubecl-spirv/src/extensions.rs index 303d69119..c2093b0fd 100644 --- a/crates/cubecl-spirv/src/extensions.rs +++ b/crates/cubecl-spirv/src/extensions.rs @@ -12,7 +12,19 @@ pub trait TargetExtensions { fn ceil(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn sin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn 
cos(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn tan(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn tanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn sinh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn cosh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn asin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn acos(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn atan(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn asinh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn acosh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn atanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn degrees(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn radians(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); + fn atan2(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word); fn pow(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word); fn exp(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); fn log(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word); @@ -65,10 +77,58 @@ pub mod glcompute { b.gl_cos_id(ty, Some(out), input).unwrap(); } + fn tan(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.cl_tan_id(ty, Some(out), input).unwrap(); + } + fn tanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { b.gl_tanh_id(ty, Some(out), input).unwrap(); } + fn sinh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.cl_sinh_id(ty, Some(out), input).unwrap(); + } + + fn cosh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.cl_cosh_id(ty, Some(out), input).unwrap(); + } + + fn asin(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.cl_asin_id(ty, Some(out), input).unwrap(); + } + + fn acos(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.cl_acos_id(ty, Some(out), input).unwrap(); + } + + fn atan(b: &mut 
SpirvCompiler, ty: Word, input: Word, out: Word) { + b.cl_atan_id(ty, Some(out), input).unwrap(); + } + + fn asinh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.cl_asinh_id(ty, Some(out), input).unwrap(); + } + + fn acosh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.cl_acosh_id(ty, Some(out), input).unwrap(); + } + + fn atanh(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.cl_atanh_id(ty, Some(out), input).unwrap(); + } + + fn degrees(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.cl_degrees_id(ty, Some(out), input).unwrap(); + } + + fn radians(b: &mut SpirvCompiler, ty: Word, input: Word, out: Word) { + b.cl_radians_id(ty, Some(out), input).unwrap(); + } + + fn atan2(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word) { + b.cl_atan2_id(ty, Some(out), lhs, rhs).unwrap(); + } + fn pow(b: &mut SpirvCompiler, ty: Word, lhs: Word, rhs: Word, out: Word) { b.gl_pow_id(ty, Some(out), lhs, rhs).unwrap(); } diff --git a/crates/cubecl-std/src/lib.rs b/crates/cubecl-std/src/lib.rs index 91090e546..604be5075 100644 --- a/crates/cubecl-std/src/lib.rs +++ b/crates/cubecl-std/src/lib.rs @@ -6,6 +6,9 @@ pub use reinterpret_slice::*; mod fast_math; pub use fast_math::*; +mod trigonometry; +pub use trigonometry::*; + mod option; pub use option::*; diff --git a/crates/cubecl-std/src/tests/mod.rs b/crates/cubecl-std/src/tests/mod.rs index cd9c92a32..bcd799585 100644 --- a/crates/cubecl-std/src/tests/mod.rs +++ b/crates/cubecl-std/src/tests/mod.rs @@ -1,5 +1,6 @@ pub mod reinterpret_slice; pub mod tensor; +pub mod trigonometry; #[macro_export] macro_rules! testgen { @@ -9,6 +10,7 @@ macro_rules! 
testgen { use half::{bf16, f16}; cubecl_std::testgen_reinterpret_slice!(); + cubecl_std::testgen_trigonometry!(); } }; } diff --git a/crates/cubecl-std/src/tests/trigonometry.rs b/crates/cubecl-std/src/tests/trigonometry.rs new file mode 100644 index 000000000..97c0e3af1 --- /dev/null +++ b/crates/cubecl-std/src/tests/trigonometry.rs @@ -0,0 +1,150 @@ +use cubecl::prelude::*; +use cubecl_core as cubecl; +use std::f32::consts::{PI, TAU}; + +use crate::trigonometry::*; + +#[cube(launch_unchecked)] +fn kernel_to_degrees(input: &Array, output: &mut Array) { + if UNIT_POS < input.len() { + output[UNIT_POS] = to_degrees::(input[UNIT_POS]); + } +} + +pub fn test_to_degrees(client: ComputeClient) { + let input_data = vec![0.0, PI / 6.0, PI / 4.0, PI / 2.0, PI, TAU]; + let expected = vec![0.0, 30.0, 45.0, 90.0, 180.0, 360.0]; + + let input = client.create(f32::as_bytes(&input_data)); + let output = client.empty(input_data.len() * core::mem::size_of::()); + + unsafe { + kernel_to_degrees::launch_unchecked::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(input_data.len() as u32, 1, 1), + ArrayArg::from_raw_parts::(&input, input_data.len(), 1), + ArrayArg::from_raw_parts::(&output, input_data.len(), 1), + ); + } + + let actual = client.read_one(output); + let actual = f32::from_bytes(&actual); + + for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() { + assert!( + (expected_val - actual_val).abs() < 1e-5, + "Test {} failed: expected {}, got {}", + i, + expected_val, + actual_val + ); + } +} + +#[cube(launch_unchecked)] +fn kernel_to_radians(input: &Array, output: &mut Array) { + if UNIT_POS < input.len() { + output[UNIT_POS] = to_radians::(input[UNIT_POS]); + } +} + +pub fn test_to_radians(client: ComputeClient) { + let input_data = vec![0.0, 30.0, 45.0, 90.0, 180.0, 360.0]; + let expected = vec![0.0, PI / 6.0, PI / 4.0, PI / 2.0, PI, TAU]; + + let input = client.create(f32::as_bytes(&input_data)); + let output = 
client.empty(input_data.len() * core::mem::size_of::()); + + unsafe { + kernel_to_radians::launch_unchecked::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(input_data.len() as u32, 1, 1), + ArrayArg::from_raw_parts::(&input, input_data.len(), 1), + ArrayArg::from_raw_parts::(&output, input_data.len(), 1), + ); + } + + let actual = client.read_one(output); + let actual = f32::from_bytes(&actual); + + for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() { + assert!( + (expected_val - actual_val).abs() < 1e-5, + "Test {} failed: expected {}, got {}", + i, + expected_val, + actual_val + ); + } +} + +#[cube(launch_unchecked)] +fn kernel_hypot(x: &Array, y: &Array, output: &mut Array) { + if UNIT_POS < x.len() { + output[UNIT_POS] = hypot::(x[UNIT_POS], y[UNIT_POS]); + } +} + +pub fn test_hypot(client: ComputeClient) { + let x_data = vec![3.0, 0.0, 1.0, 5.0, 0.0]; + let y_data = vec![4.0, 1.0, 1.0, 12.0, 0.0]; + let expected = vec![5.0, 1.0, 1.4142135623730951, 13.0, 0.0]; + + let x = client.create(f32::as_bytes(&x_data)); + let y = client.create(f32::as_bytes(&y_data)); + let output = client.empty(x_data.len() * core::mem::size_of::()); + + unsafe { + kernel_hypot::launch_unchecked::( + &client, + CubeCount::Static(1, 1, 1), + CubeDim::new(x_data.len() as u32, 1, 1), + ArrayArg::from_raw_parts::(&x, x_data.len(), 1), + ArrayArg::from_raw_parts::(&y, y_data.len(), 1), + ArrayArg::from_raw_parts::(&output, x_data.len(), 1), + ); + } + + let actual = client.read_one(output); + let actual = f32::from_bytes(&actual); + + for (i, (&expected_val, &actual_val)) in expected.iter().zip(actual.iter()).enumerate() { + assert!( + (expected_val - actual_val).abs() < 1e-5, + "Hypot test {} failed: expected {}, got {}", + i, + expected_val, + actual_val + ); + } +} + +#[macro_export] +macro_rules! 
testgen_trigonometry { + () => { + mod trigonometry { + use super::*; + use $crate::tests::trigonometry::*; + + #[test] + fn test_to_degrees_conversion() { + let client = TestRuntime::client(&Default::default()); + test_to_degrees::(client); + } + + #[test] + fn test_to_radians_conversion() { + let client = TestRuntime::client(&Default::default()); + test_to_radians::(client); + } + + #[test] + fn test_hypot_computation() { + let client = TestRuntime::client(&Default::default()); + test_hypot::(client); + } + } + }; +} diff --git a/crates/cubecl-std/src/trigonometry.rs b/crates/cubecl-std/src/trigonometry.rs new file mode 100644 index 000000000..355bc101b --- /dev/null +++ b/crates/cubecl-std/src/trigonometry.rs @@ -0,0 +1,61 @@ +//! Trigonometric functions and utilities for CubeCL. +//! +//! This module provides basic trigonometric operations and angle conversion utilities +//! that can be used in all GPU kernels. + +use core::f32; +use cubecl::prelude::*; +use cubecl_core as cubecl; + +/// Converts an angle from radians to degrees. +/// +/// # Example +/// +/// ```rust,ignore +/// let radians = F::new(std::f32::consts::PI); +/// let degrees = to_degrees(radians); +/// assert!((degrees - F::new(180.0)).abs() < F::new(1e-6)); +/// ``` +#[cube] +pub fn to_degrees(val: F) -> F { + val * F::new(180.0 / f32::consts::PI) +} + +/// Converts an angle from degrees to radians. +/// +/// # Example +/// +/// ```rust,ignore +/// let degrees = F::new(180.0); +/// let radians = to_radians(degrees); +/// assert!((radians - F::new(std::f32::consts::PI)).abs() < F::new(1e-6)); +/// ``` +#[cube] +pub fn to_radians(val: F) -> F { + val * F::new(f32::consts::PI / 180.0) +} + +/// Computes the hypotenuse of a right triangle given the lengths of the other two sides. +/// +/// This function computes `sqrt(x² + y²)` in a numerically stable way that avoids +/// overflow and underflow issues. 
+/// +/// # Arguments +/// +/// * `x` - Length of one side +/// * `y` - Length of the other side +/// +/// # Returns +/// +/// The length of the hypotenuse +/// +/// # Example +/// +/// ```rust,ignore +/// let hyp = hypot(F::new(3.0), F::new(4.0)); +/// assert!((hyp - F::new(5.0)).abs() < F::new(1e-6)); +/// ``` +#[cube] +pub fn hypot(x: F, y: F) -> F { + F::sqrt(x * x + y * y) +} diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 19c57f9c1..f3eebe272 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -759,10 +759,59 @@ impl WgslCompiler { input: self.compile_variable(op.input), out: self.compile_variable(out), }), + cube::Arithmetic::Tan(op) => instructions.push(wgsl::Instruction::Tan { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), cube::Arithmetic::Tanh(op) => instructions.push(wgsl::Instruction::Tanh { input: self.compile_variable(op.input), out: self.compile_variable(out), }), + cube::Arithmetic::Sinh(op) => instructions.push(wgsl::Instruction::Sinh { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::Cosh(op) => instructions.push(wgsl::Instruction::Cosh { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::ArcCos(op) => instructions.push(wgsl::Instruction::ArcCos { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::ArcSin(op) => instructions.push(wgsl::Instruction::ArcSin { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::ArcTan(op) => instructions.push(wgsl::Instruction::ArcTan { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::ArcSinh(op) => instructions.push(wgsl::Instruction::ArcSinh { + input: self.compile_variable(op.input), 
+ out: self.compile_variable(out), + }), + cube::Arithmetic::ArcCosh(op) => instructions.push(wgsl::Instruction::ArcCosh { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::ArcTanh(op) => instructions.push(wgsl::Instruction::ArcTanh { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::Degrees(op) => instructions.push(wgsl::Instruction::Degrees { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::Radians(op) => instructions.push(wgsl::Instruction::Radians { + input: self.compile_variable(op.input), + out: self.compile_variable(out), + }), + cube::Arithmetic::ArcTan2(op) => instructions.push(wgsl::Instruction::ArcTan2 { + lhs: self.compile_variable(op.lhs), + rhs: self.compile_variable(op.rhs), + out: self.compile_variable(out), + }), // No powi in WGSL cube::Arithmetic::Powf(op) | cube::Arithmetic::Powi(op) => { instructions.push(wgsl::Instruction::Powf { diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index 2d252d167..0f6837901 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -117,10 +117,59 @@ pub enum Instruction { input: Variable, out: Variable, }, + Tan { + input: Variable, + out: Variable, + }, Tanh { input: Variable, out: Variable, }, + Sinh { + input: Variable, + out: Variable, + }, + Cosh { + input: Variable, + out: Variable, + }, + ArcCos { + input: Variable, + out: Variable, + }, + ArcSin { + input: Variable, + out: Variable, + }, + ArcTan { + input: Variable, + out: Variable, + }, + ArcSinh { + input: Variable, + out: Variable, + }, + ArcCosh { + input: Variable, + out: Variable, + }, + ArcTanh { + input: Variable, + out: Variable, + }, + Degrees { + input: Variable, + out: Variable, + }, + Radians { + input: Variable, + out: Variable, + }, + ArcTan2 
{ + lhs: Variable, + rhs: Variable, + out: Variable, + }, Powf { lhs: Variable, rhs: Variable, @@ -611,6 +660,10 @@ impl Display for Instruction { let out = out.fmt_left(); writeln!(f, "{out} = sin({input});") } + Instruction::Tan { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = tan({input});") + } Instruction::Tanh { input, out } => { #[cfg(target_os = "macos")] let result = super::call_safe_tanh(f, input, out); @@ -622,6 +675,50 @@ impl Display for Instruction { result } + Instruction::Sinh { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = sinh({input});") + } + Instruction::Cosh { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = cosh({input});") + } + Instruction::ArcCos { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = acos({input});") + } + Instruction::ArcSin { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = asin({input});") + } + Instruction::ArcTan { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = atan({input});") + } + Instruction::ArcSinh { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = asinh({input});") + } + Instruction::ArcCosh { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = acosh({input});") + } + Instruction::ArcTanh { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = atanh({input});") + } + Instruction::Degrees { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = degrees({input});") + } + Instruction::Radians { input, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = radians({input});") + } + Instruction::ArcTan2 { lhs, rhs, out } => { + let out = out.fmt_left(); + writeln!(f, "{out} = atan2({lhs}, {rhs});") + } Instruction::Recip { input, out } => { let item = input.item(); let out = out.fmt_left();