-
Notifications
You must be signed in to change notification settings - Fork 110
feat: add trigonometric functions #861
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
feat: add trigonometric functions #861
Conversation
This just makes the compiler happy and is not yet tested!
|
Not sure if it might be a good idea to pull all the trigonometry functions (including the existing ones) into a new enum for the IR, because that's a lot of added stuff, and it's now enough to warrant its own category I think. We should at some point also think about applying a similar separation to the compilers themselves, but that's a larger rework that wouldn't go into this PR. I'm currently looking into the MLIR functions for that stuff - the bindings are auto-generated I believe, so if the functions don't exist, they might be in a different namespace or need a different way to handle them. As for the float types, casting to float for unsupported float types (i.e. F16, BF16) is reasonable. There do appear to be double versions of the functions, so that will work natively. |
let result = self.append_operation_with_result(operation); | ||
self.insert_variable(out, result); | ||
} | ||
Arithmetic::Sinh(_sinh) => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The relevant functions are actually in the dialect::ods::math
module (not dialect::ods::llvm
). So you'll need to import that and adjust as appropriate.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I implemented everything. Now I get this error message when running tests:
error: cannot be converted to LLVM IR: missing LLVMTranslationDialectInterface
registration for dialect for op: math.atan2
However, according to code hints everything seems fine...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll have to tag in @marcantoinem here because I don't understand much about the MLIR backend. I'm guessing the dialect needs to be registered somewhere but not sure about the details.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, I didn't see that there is a dialect other than arith for more complex mathematical operation, I was using llvm intrisic instead. The error you got is because the pass to transform the math dialect to llvm intrisic is not registered. To fix it you just needs to add pass_manager.add_pass(pass::conversion::create_math_to_llvm());
in src/compiler/module.rs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Locally, I still get the same error, even after run cargo clean, deleting the target directory and passing --features mlir-dump to cargo test -p cubecl-cpu. I also checked whether the line was called and it seems like it did - but nothing happend (maybe because I missed another cache?)
I pushed a commit with the added line, maybe it runs on your machine?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried your code and I get the same error the pass seems to do nothing. I didn't use it anywhere else in the MLIR compiler so it must be a problem on the C API of MLIR that is not running correctly this pass. For the moment you can use LLVM intrisic I think, it will needs more investigation to find why registering the math_to_llvm doesn't work. The only problem with using llvm intrisic instead of math is for portability if we want to port the MLIR compiler to GPU,. but it is not a concern right now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The arc-* operations are not available in the llvm intrinsic, hence I disabled them for now until the math module becomes available.
precision for some of the new trigonometric operations
Must I implement this in this PR or will this be part of another refactor?
Can you please double-check my implementation? I don't really understand at which point the right functions are written to the shader for each dialect... |
I think that would be better to do after all the outstanding work has been merged, so I'll do it separately. #[cube]
pub fn to_degrees<F: Float>(val: F) -> F {
val * F::new(180.0 / f32::PI)
} It would keep it in one place rather than implemented separately for each compiler. |
CUDA is already correct, not sure what the WGSL behaviour even is because it's an experimental extension. It might just support f16 overloads on those functions by default, but I can't yet test it because of some issues with features on Vulkan. |
There are intrinsics for WGSL and SPIR-V (at least rspirv-ext), also Rusts f32 and f64 do have to_degrees and to_radians operations. Also, as a user, I would probably search in the Float Namespace for these functions instead of the cubecl-std. That's why I think having them there would make still sense. I removed the dialect specific compile_instruction_xxx_scalar calls and instead hardcode the operations directly for the cubecl-cpp, since all CUDA, HIP and Metal should convert implicit to the right type and share the same syntax. |
Before things are getting stale, I mark this as ready to review. Currently, 21 tests are failing, 20 of them related to the missing CPU implementation of the inverse trigonometric functions (arc-*) and one is that
For both problems I can't do more, since the first one needs to be fixed by @marcantoinem I guess? and the second one needs to be fixed by a design decision whether If I should do more, let me know. |
that epsilon looks too tight for f16, might have another bug in the precision-aware comparison like when I forgot an |
Change the epsilon for to_degree() to 0.3, which checks out with the f16 maximum error for our valid tests.
Move to_degrees and to_radians there
I changed the runtime tests for I moved the |
Since I used a lot of PI and PI*2 etc., especially in the Tests, I wondered whether it would be useful to expose constants under the Float trait. I know that rust f32 and f64 already have a lot of constants under ::consts and that there are also a lot of constants already implemented for the cubecl f16 type. |
Another thing: I wondered whether this function signature here would prevent the creation of super-precise f64 floats, since it is only possible to pass in f32. |
The code looks fine to me, but the tests don't pass on the CI. @relativityhd We may also remove some functions in the trigo modules, unsure they are necessary. |
@nathanielsimard yes the CI is failing because as I mentioned this PR is blocked by the MLIR melior project, because the math module there wont register correctly, as @marcantoinem mentioned.
Sure, just tell which ones I should remove. |
The ones that are not stricly necessary, using if statements is also very bad on most hardware, so we can remove the functions that have if statements. |
Sorry for the long wait - my vacation is over and time limited again... I've removed unnecessary trig functions from the std, only keeping hypot, to_radians and to_degrees since they are also present in rusts f32. This PR is still blocked by the MLIR melior project. I still get the following error when testing (I've removed the dummy code and enabled ods math support again, as @marcantoinem described):
@nathanielsimard What repository is responsible for this? Maybe I can check and fix it there. I am quite confused about the different LLVM repositories and packages used. |
|
Add trigonometric functions (atan2 etc.)
Adds the following functions for all Floats:
sinh
)cosh
)asin
)asinh
)acos
)acosh
)atan
)atanh
)atan2
)degrees
)radians
)Open Questions / Missing parts
LLVM MIR implementation
I have setup a placeholder currently with code which I assumed to work commented out.
It seems that support for these functions needs to be added in the tracel-llvm repository
, but I have no clue where.
Non f64/f32
At some point there is limited support for the "special" floats and for e.g. powf which I used as an example a conversion was needed to normal floats.
I am unsure at which points this is necessary and where I can find out whether these conversions are necessary.
Metal safe operations
Since I used the existing sin, cos, tanh and powf functions as examples on how to add the other functions, i stumbled across the implementation for metal for tanh:
I couldn't find anything about this "safe" version in the metal documentation, but I am clearly not an expert.
For which functions is a "safe" implementation needed and which are fine without?
Validate your PR with burn.
It is important that you make sure that you don't introduce any bugs in burn.
Instructions