-
Notifications
You must be signed in to change notification settings - Fork 112
generic_float for Float8E8M0
#4403
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: mlir_mxfp4_test
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| /* | ||
| * The MIT License (MIT) | ||
| * | ||
| * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. | ||
| * | ||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| * of this software and associated documentation files (the "Software"), to deal | ||
| * in the Software without restriction, including without limitation the rights | ||
| * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| * copies of the Software, and to permit persons to whom the Software is | ||
| * furnished to do so, subject to the following conditions: | ||
| * | ||
| * The above copyright notice and this permission notice shall be included in | ||
| * all copies or substantial portions of the Software. | ||
| * | ||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
| * THE SOFTWARE. | ||
| */ | ||
|
|
||
| #ifndef MIGRAPHX_GUARD_RTGLIB_FP8E8M0FN_HPP | ||
| #define MIGRAPHX_GUARD_RTGLIB_FP8E8M0FN_HPP | ||
|
|
||
| #include <migraphx/generic_float.hpp> | ||
| #include <migraphx/config.hpp> | ||
|
|
||
| namespace migraphx { | ||
| inline namespace MIGRAPHX_INLINE_NS { | ||
|
|
||
| using fp8e8m0 = migraphx::generic_float<0, 8, 0>; | ||
|
|
||
| } // namespace MIGRAPHX_INLINE_NS | ||
| } // namespace migraphx | ||
|
|
||
| #endif | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -71,6 +71,53 @@ | |
| using type = std::uint64_t; | ||
| }; | ||
|
|
||
| // CRTP base for operators | ||
| template <class Derived> | ||
| struct generic_float_operators | ||
| { | ||
|
|
||
| // NOLINTNEXTLINE | ||
| #define MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(op) \ | ||
| friend constexpr Derived& operator op(Derived & lhs, const Derived & rhs) \ | ||
| { \ | ||
| float self = lhs; \ | ||
| float frhs = rhs; \ | ||
| self op frhs; \ | ||
| lhs = self; \ | ||
| return lhs; \ | ||
| } | ||
| MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(*=) | ||
| MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(-=) | ||
| MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(+=) | ||
| MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(/=) | ||
|
|
||
| // NOLINTNEXTLINE | ||
| #define MIGRAPHX_GENERIC_FLOAT_BINARY_OP(op) \ | ||
| friend constexpr Derived operator op(const Derived& x, const Derived& y) \ | ||
| { \ | ||
| return Derived(float(x) op float(y)); \ | ||
| } | ||
| MIGRAPHX_GENERIC_FLOAT_BINARY_OP(*) | ||
| MIGRAPHX_GENERIC_FLOAT_BINARY_OP(-) | ||
| MIGRAPHX_GENERIC_FLOAT_BINARY_OP(+) | ||
| MIGRAPHX_GENERIC_FLOAT_BINARY_OP(/) | ||
|
|
||
| // NOLINTNEXTLINE | ||
| #define MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(op) \ | ||
| friend constexpr bool operator op(const Derived& x, const Derived& y) \ | ||
| { \ | ||
| return float(x) op float(y); \ | ||
| } | ||
| MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<) | ||
| MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<=) | ||
| MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>) | ||
| MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>=) | ||
|
|
||
| protected: | ||
| // prohibit creation of this base object | ||
| generic_float_operators() = default; | ||
| }; | ||
|
|
||
| struct float32_parts | ||
| { | ||
| unsigned int mantissa : 23; | ||
|
|
@@ -92,6 +139,7 @@ | |
|
|
||
| template <unsigned int MantissaSize, unsigned int ExponentSize, unsigned int Flags = 0> | ||
| struct __attribute__((packed, may_alias)) generic_float | ||
| : generic_float_operators<generic_float<MantissaSize, ExponentSize, Flags>> | ||
| { | ||
| using type = typename unsigned_type<bit_ceil( | ||
| integer_divide_ceil(MantissaSize + ExponentSize + 1, 8))>::type; | ||
|
|
@@ -228,6 +276,8 @@ | |
| return exponent == all_ones<ExponentSize>() and mantissa != 0; | ||
| } | ||
|
|
||
| constexpr bool has_infinity() const noexcept { return true; } | ||
|
|
||
| constexpr bool is_finite() const noexcept { return exponent != all_ones<ExponentSize>(); } | ||
|
|
||
| constexpr operator float() const noexcept { return this->to_float(); } | ||
|
|
@@ -296,40 +346,6 @@ | |
| x.mantissa++; | ||
| return generic_float{x.to_float() - 1.0f}; | ||
| } | ||
| // NOLINTNEXTLINE | ||
| #define MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(op) \ | ||
| constexpr generic_float& operator op(const generic_float & rhs) \ | ||
| { \ | ||
| float self = *this; \ | ||
| float frhs = rhs; \ | ||
| self op frhs; \ | ||
| *this = generic_float(self); \ | ||
| return *this; \ | ||
| } | ||
| MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(*=) | ||
| MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(-=) | ||
| MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(+=) | ||
| MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(/=) | ||
| // NOLINTNEXTLINE | ||
| #define MIGRAPHX_GENERIC_FLOAT_BINARY_OP(op) \ | ||
| friend constexpr generic_float operator op(const generic_float& x, const generic_float& y) \ | ||
| { \ | ||
| return generic_float(float(x) op float(y)); \ | ||
| } | ||
| MIGRAPHX_GENERIC_FLOAT_BINARY_OP(*) | ||
| MIGRAPHX_GENERIC_FLOAT_BINARY_OP(-) | ||
| MIGRAPHX_GENERIC_FLOAT_BINARY_OP(+) | ||
| MIGRAPHX_GENERIC_FLOAT_BINARY_OP(/) | ||
| // NOLINTNEXTLINE | ||
| #define MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(op) \ | ||
| friend constexpr bool operator op(const generic_float& x, const generic_float& y) \ | ||
| { \ | ||
| return float(x) op float(y); \ | ||
| } | ||
| MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<) | ||
| MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<=) | ||
| MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>) | ||
| MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>=) | ||
|
|
||
| friend constexpr bool operator==(const generic_float& x, const generic_float& y) | ||
| { | ||
|
|
@@ -363,6 +379,148 @@ | |
| } | ||
| }; | ||
|
|
||
| template <unsigned int Flags> | ||
| struct __attribute__((packed, may_alias)) generic_float<0, 8, Flags> | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is |
||
| : generic_float_operators<generic_float<0, 8, Flags>> | ||
| { | ||
| uint8_t exponent; | ||
|
|
||
| static constexpr int exponent_bias() { return all_ones<7>(); } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: The |
||
|
|
||
| explicit constexpr generic_float(float f = 1.0) noexcept { from_float(get_parts(f)); } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would you be defining a |
||
|
|
||
| constexpr generic_float& operator=(float f) noexcept | ||
| { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this assign from a negative float value? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It can, I set it up to ignore the sign bit. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I meant logically speaking, the answer to this question should be: "negative" :-) |
||
| from_float(get_parts(f)); | ||
| return *this; | ||
| } | ||
|
|
||
| // No sign for this type | ||
| constexpr generic_float operator-() const noexcept { return snan(); } | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If there is no unary There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, fundamentally, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you can allow an assignment from a negative float, then Therefore, no assignment from a negative value should be allowed. An interesting problem :-) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There should probably only be the equality binary operation. |
||
|
|
||
| constexpr generic_float operator+() const noexcept { return *this; } | ||
|
|
||
| constexpr float to_float() const noexcept | ||
| { | ||
| float32_parts f{}; | ||
| f.sign = 0; | ||
| if(exponent == 0) | ||
| { | ||
| // 2^(-127) is a fp32 denormal number | ||
lakhinderwalia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| f.mantissa = 1; | ||
| f.mantissa = f.mantissa << (float32_parts::mantissa_width() - 1); | ||
| } | ||
| else if(exponent == all_ones<8>()) | ||
| { | ||
| // setting to fp32 qNaN | ||
| f.mantissa = (1 << (float32_parts::mantissa_width() - 1)) + 1; | ||
|
Check warning on line 416 in src/include/migraphx/generic_float.hpp
|
||
| } | ||
| else | ||
| { | ||
| f.mantissa = 0; | ||
| } | ||
| f.exponent = exponent; | ||
| return f.to_float(); | ||
| } | ||
|
|
||
| /** | ||
| * Extracts only exponent bits from float. | ||
| * All fp32 denorm numbers will go to fp8e8m0{2^(-127)}. | ||
| * All fp32 NaN and infinity go to fp8e8m0{NaN}. | ||
| */ | ||
| constexpr void from_float(float32_parts f) noexcept { exponent = f.exponent; } | ||
|
|
||
| // No denorm numbers in fp8e8m0. | ||
| constexpr bool is_normal() const noexcept { return not is_nan(); } | ||
|
|
||
| // No infinity numbers in fp8e8m0. | ||
| constexpr bool is_inf() const noexcept { return false; } | ||
|
|
||
| constexpr bool is_nan() const noexcept { return exponent == all_ones<8>(); } | ||
|
|
||
| constexpr bool is_finite() const noexcept { return not is_nan(); } | ||
|
|
||
| constexpr bool has_infinity() const noexcept | ||
| { | ||
| return false; | ||
| ; | ||
| } | ||
|
|
||
| constexpr operator float() const noexcept { return this->to_float(); } | ||
|
|
||
| // doesn't have infinity, returning 2**0 | ||
| static constexpr generic_float infinity() | ||
| { | ||
| generic_float x{}; | ||
| x.exponent = all_ones<8>() >> 1u; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is no infinity for E8M0 so I tried to follow the convention integral types have for this https://en.cppreference.com/w/cpp/types/numeric_limits/infinity.html. Unfortunately, there is also no zero in E8M0 so this decision is arbitrary. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree there is no |
||
| return x; | ||
| } | ||
|
|
||
| // only one NaN value | ||
| static constexpr generic_float snan() | ||
| { | ||
| generic_float x{}; | ||
| x.exponent = all_ones<8>(); | ||
| return x; | ||
| } | ||
|
|
||
| // only one NaN value | ||
| static constexpr generic_float qnan() { return snan(); } | ||
|
|
||
| // min value = 2**(-127) | ||
| static constexpr generic_float min() | ||
| { | ||
| generic_float x{}; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not define |
||
| x.exponent = 0; | ||
| return x; | ||
| } | ||
|
|
||
| // No subnormal numbers in FP8E8M0 | ||
| static constexpr generic_float denorm_min() { return min(); } | ||
|
|
||
| static constexpr generic_float lowest() { return min(); } | ||
|
|
||
| // max value = 2**(127) | ||
| static constexpr generic_float max() | ||
| { | ||
| generic_float x{}; | ||
| x.exponent = all_ones<8>() - 1; | ||
| return x; | ||
| } | ||
|
|
||
| // next number from 2**0 is 2**1 so epsilon is 2**0 | ||
| static constexpr generic_float epsilon() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If I were to nit-pick, there is no |
||
| { | ||
| generic_float x{}; | ||
| x.exponent = all_ones<8>() >> 1u; | ||
| return x; | ||
| } | ||
|
|
||
| friend constexpr bool operator==(const generic_float& x, const generic_float& y) | ||
| { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not just say |
||
|
|
||
| return x.exponent == y.exponent; | ||
| } | ||
|
|
||
| friend constexpr bool operator!=(const generic_float& x, const generic_float& y) | ||
| { | ||
| return not(x == y); | ||
| } | ||
|
|
||
| constexpr generic_float& operator++() noexcept | ||
| { | ||
| ++exponent; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You perhaps should check for a number to not become a |
||
| return *this; | ||
| } | ||
|
|
||
| const generic_float operator++(int) noexcept // NOLINT(readability-const-return-type) | ||
| { | ||
| generic_float temp = *this; | ||
| operator++(this->exponent); | ||
| return temp; | ||
| } | ||
| }; | ||
|
|
||
| } // namespace MIGRAPHX_INLINE_NS | ||
| } // namespace migraphx | ||
|
|
||
|
|
@@ -373,7 +531,7 @@ | |
| class numeric_limits<migraphx::generic_float<M, E, F>> | ||
| { | ||
| public: | ||
| static constexpr bool has_infinity = true; | ||
| static constexpr bool has_infinity = not(M == 0 and E == 0); | ||
|
Check warning on line 534 in src/include/migraphx/generic_float.hpp
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tidy is being nitpicky here but It saves you the extra not if you make this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's actually a bug. Should probably just be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tidy is not suggesting |
||
| static constexpr migraphx::generic_float<M, E, F> epsilon() | ||
| { | ||
| return migraphx::generic_float<M, E, F>::epsilon(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A template parameter needs to be added for the sign. So it should become:
And the typedef should then be
using fp8e8m0 = migraphx::generic_float<0, 8, false>.