Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/api/include/migraphx/migraphx.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \
m(fp8e5m2_type, migraphx::fp8::fp8e5m2) \
m(bf16_type, bf16) \
m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz)
m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz) \
m(fp8e8m0_type, fp8e8m0)
// clang-format on

#ifdef __cplusplus
Expand Down
39 changes: 39 additions & 0 deletions src/include/migraphx/fp8e8m0.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#ifndef MIGRAPHX_GUARD_RTGLIB_FP8E8M0FN_HPP
#define MIGRAPHX_GUARD_RTGLIB_FP8E8M0FN_HPP

#include <migraphx/generic_float.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

using fp8e8m0 = migraphx::generic_float<0, 8, 0>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A template parameter needs to be added for the sign. So it should become:

template <unsigned int MantissaSize, unsigned int ExponentSize, bool Signed = true, unsigned int Flags = 0>
struct __attribute__((packed, may_alias)) generic_float;

And the typedef should then be using fp8e8m0 = migraphx::generic_float<0, 8, false>.


} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
228 changes: 193 additions & 35 deletions src/include/migraphx/generic_float.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,53 @@
using type = std::uint64_t;
};

// CRTP base for operators
template <class Derived>
struct generic_float_operators
{

// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(op) \
friend constexpr Derived& operator op(Derived & lhs, const Derived & rhs) \
{ \
float self = lhs; \
float frhs = rhs; \
self op frhs; \
lhs = self; \
return lhs; \
}
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(*=)
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(-=)
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(+=)
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(/=)

// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_BINARY_OP(op) \
friend constexpr Derived operator op(const Derived& x, const Derived& y) \
{ \
return Derived(float(x) op float(y)); \
}
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(*)
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(-)
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(+)
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(/)

// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(op) \
friend constexpr bool operator op(const Derived& x, const Derived& y) \
{ \
return float(x) op float(y); \
}
MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<)
MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<=)
MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>)
MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>=)

protected:
// prohibit creation of this base object
generic_float_operators() = default;
};

struct float32_parts
{
unsigned int mantissa : 23;
Expand All @@ -92,6 +139,7 @@

template <unsigned int MantissaSize, unsigned int ExponentSize, unsigned int Flags = 0>
struct __attribute__((packed, may_alias)) generic_float
: generic_float_operators<generic_float<MantissaSize, ExponentSize, Flags>>
{
using type = typename unsigned_type<bit_ceil(
integer_divide_ceil(MantissaSize + ExponentSize + 1, 8))>::type;
Expand Down Expand Up @@ -228,6 +276,8 @@
return exponent == all_ones<ExponentSize>() and mantissa != 0;
}

constexpr bool has_infinity() const noexcept { return true; }

constexpr bool is_finite() const noexcept { return exponent != all_ones<ExponentSize>(); }

constexpr operator float() const noexcept { return this->to_float(); }
Expand Down Expand Up @@ -296,40 +346,6 @@
x.mantissa++;
return generic_float{x.to_float() - 1.0f};
}
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(op) \
constexpr generic_float& operator op(const generic_float & rhs) \
{ \
float self = *this; \
float frhs = rhs; \
self op frhs; \
*this = generic_float(self); \
return *this; \
}
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(*=)
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(-=)
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(+=)
MIGRAPHX_GENERIC_FLOAT_ASSIGN_OP(/=)
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_BINARY_OP(op) \
friend constexpr generic_float operator op(const generic_float& x, const generic_float& y) \
{ \
return generic_float(float(x) op float(y)); \
}
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(*)
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(-)
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(+)
MIGRAPHX_GENERIC_FLOAT_BINARY_OP(/)
// NOLINTNEXTLINE
#define MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(op) \
friend constexpr bool operator op(const generic_float& x, const generic_float& y) \
{ \
return float(x) op float(y); \
}
MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<)
MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(<=)
MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>)
MIGRAPHX_GENERIC_FLOAT_COMPARE_OP(>=)

friend constexpr bool operator==(const generic_float& x, const generic_float& y)
{
Expand Down Expand Up @@ -363,6 +379,148 @@
}
};

template <unsigned int Flags>
struct __attribute__((packed, may_alias)) generic_float<0, 8, Flags>
Copy link
Contributor

@lakhinderwalia lakhinderwalia Oct 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is Flags used somewhere in E8M0 case here? I must have missed.

: generic_float_operators<generic_float<0, 8, Flags>>
{
uint8_t exponent;

static constexpr int exponent_bias() { return all_ones<7>(); }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: The bias is just an integer, per the standard. it should ideally be just referred to by its value: 127, or something like that (a named constant). But defining it all_ones<7> is confusing.


explicit constexpr generic_float(float f = 1.0) noexcept { from_float(get_parts(f)); }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you be defining a generic_float for other formats with a different default value?


constexpr generic_float& operator=(float f) noexcept
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this assign from a negative float value?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can, I set it up to ignore the sign bit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I meant logically speaking, the answer to this question should be: "negative" :-)

from_float(get_parts(f));
return *this;
}

// No sign for this type
constexpr generic_float operator-() const noexcept { return snan(); }
Copy link
Contributor

@lakhinderwalia lakhinderwalia Oct 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is no unary - op available, then logically speaking a binary - or -= op is also the same category. One could do a binary x - y, which is say 2 - 2, and get some strange results!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, fundamentally, a - b == a + ( - b )

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you can allow an assignment from a negative float, then operator-() should just be a nop, but not a snan(). But that totally messes up a - b.

Therefore, no assignment from a negative value should be allowed.

An interesting problem :-)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should probably only be the equality binary operation.


constexpr generic_float operator+() const noexcept { return *this; }

constexpr float to_float() const noexcept
{
float32_parts f{};
f.sign = 0;
if(exponent == 0)
{
// 2^(-127) is a fp32 denormal number
f.mantissa = 1;
f.mantissa = f.mantissa << (float32_parts::mantissa_width() - 1);
}
else if(exponent == all_ones<8>())
{
// setting to fp32 qNaN
f.mantissa = (1 << (float32_parts::mantissa_width() - 1)) + 1;

Check warning on line 416 in src/include/migraphx/generic_float.hpp

View workflow job for this annotation

GitHub Actions / tidy

use of a signed integer operand with a binary bitwise operator [hicpp-signed-bitwise,-warnings-as-errors]

Check warning on line 416 in src/include/migraphx/generic_float.hpp

View workflow job for this annotation

GitHub Actions / tidy

use of a signed integer operand with a binary bitwise operator [hicpp-signed-bitwise,-warnings-as-errors]

Check warning on line 416 in src/include/migraphx/generic_float.hpp

View workflow job for this annotation

GitHub Actions / tidy

use of a signed integer operand with a binary bitwise operator [hicpp-signed-bitwise,-warnings-as-errors]

Check warning on line 416 in src/include/migraphx/generic_float.hpp

View workflow job for this annotation

GitHub Actions / tidy

use of a signed integer operand with a binary bitwise operator [hicpp-signed-bitwise,-warnings-as-errors]

Check warning on line 416 in src/include/migraphx/generic_float.hpp

View workflow job for this annotation

GitHub Actions / tidy

use of a signed integer operand with a binary bitwise operator [hicpp-signed-bitwise,-warnings-as-errors]
}
else
{
f.mantissa = 0;
}
f.exponent = exponent;
return f.to_float();
}

/**
* Extracts only exponent bits from float.
* All fp32 denorm numbers will go to fp8e8m0{2^(-127)}.
* All fp32 NaN and infinity go to fp8e8m0{NaN}.
*/
constexpr void from_float(float32_parts f) noexcept { exponent = f.exponent; }

// No denorm numbers in fp8e8m0.
constexpr bool is_normal() const noexcept { return not is_nan(); }

// No infinity numbers in fp8e8m0.
constexpr bool is_inf() const noexcept { return false; }

constexpr bool is_nan() const noexcept { return exponent == all_ones<8>(); }

constexpr bool is_finite() const noexcept { return not is_nan(); }

constexpr bool has_infinity() const noexcept
{
return false;
;
}

constexpr operator float() const noexcept { return this->to_float(); }

// doesn't have infinity, returning 2**0
static constexpr generic_float infinity()
{
generic_float x{};
x.exponent = all_ones<8>() >> 1u;
Copy link
Contributor

@lakhinderwalia lakhinderwalia Oct 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This I don't get. In fp32, infinity has its exponent defined as 0xff. why is the above line shifting right? Looks like you are converting infinity to the max() here. Maybe just use max() in that case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no infinity for E8M0 so I tried to follow the convention integral types have for this https://en.cppreference.com/w/cpp/types/numeric_limits/infinity.html. Unfortunately, there is also no zero in E8M0 so this decision is arbitrary.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree there is no zero involved here :-)
Should an exception be generated? Because it is the wrong math otherwise for E8M0.
Ideally the test could verify that an exception is being thrown.
These are hairy corner cases.

return x;
}

// only one NaN value
static constexpr generic_float snan()
{
generic_float x{};
x.exponent = all_ones<8>();
return x;
}

// only one NaN value
static constexpr generic_float qnan() { return snan(); }

// min value = 2**(-127)
static constexpr generic_float min()
{
generic_float x{};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not define generic_float natively -- rather than go through a default fp32 parameter in its constructor, subsequently its exponent is set!

x.exponent = 0;
return x;
}

// No subnormal numbers in FP8E8M0
static constexpr generic_float denorm_min() { return min(); }

static constexpr generic_float lowest() { return min(); }

// max value = 2**(127)
static constexpr generic_float max()
{
generic_float x{};
x.exponent = all_ones<8>() - 1;
return x;
}

// next number from 2**0 is 2**1 so epsilon is 2**0
static constexpr generic_float epsilon()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I were to nit-pick, there is no epsilon or a concept like that for the weight specific E8M0 format.

{
generic_float x{};
x.exponent = all_ones<8>() >> 1u;
return x;
}

friend constexpr bool operator==(const generic_float& x, const generic_float& y)
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just say x == y. Is that not compliable?


return x.exponent == y.exponent;
}

friend constexpr bool operator!=(const generic_float& x, const generic_float& y)
{
return not(x == y);
}

constexpr generic_float& operator++() noexcept
{
++exponent;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You perhaps should check for a number to not become a NaN. It should roll over.

return *this;
}

const generic_float operator++(int) noexcept // NOLINT(readability-const-return-type)
{
generic_float temp = *this;
operator++(this->exponent);
return temp;
}
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

Expand All @@ -373,7 +531,7 @@
class numeric_limits<migraphx::generic_float<M, E, F>>
{
public:
static constexpr bool has_infinity = true;
static constexpr bool has_infinity = not(M == 0 and E == 0);

Check warning on line 534 in src/include/migraphx/generic_float.hpp

View workflow job for this annotation

GitHub Actions / tidy

boolean expression can be simplified by DeMorgan's theorem [readability-simplify-boolean-expr,-warnings-as-errors]

Check warning on line 534 in src/include/migraphx/generic_float.hpp

View workflow job for this annotation

GitHub Actions / tidy

boolean expression can be simplified by DeMorgan's theorem [readability-simplify-boolean-expr,-warnings-as-errors]

Check warning on line 534 in src/include/migraphx/generic_float.hpp

View workflow job for this annotation

GitHub Actions / tidy

boolean expression can be simplified by DeMorgan's theorem [readability-simplify-boolean-expr,-warnings-as-errors]

Check warning on line 534 in src/include/migraphx/generic_float.hpp

View workflow job for this annotation

GitHub Actions / tidy

boolean expression can be simplified by DeMorgan's theorem [readability-simplify-boolean-expr,-warnings-as-errors]

Check warning on line 534 in src/include/migraphx/generic_float.hpp

View workflow job for this annotation

GitHub Actions / tidy

boolean expression can be simplified by DeMorgan's theorem [readability-simplify-boolean-expr,-warnings-as-errors]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tidy is being nitpicky here but It saves you the extra not if you make this (M == 0 or E == 0)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's actually a bug. Should probably just be M != 0.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tidy is not suggesting (M == 0 or E == 0), it is suggesting M != 0 or E != 0.

static constexpr migraphx::generic_float<M, E, F> epsilon()
{
return migraphx::generic_float<M, E, F>::epsilon();
Expand Down
5 changes: 4 additions & 1 deletion src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <migraphx/half.hpp>
#include <migraphx/bf16.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/fp8e8m0.hpp>
#include <migraphx/serialize.hpp>
#include <migraphx/config.hpp>

Expand Down Expand Up @@ -68,7 +69,9 @@ struct MIGRAPHX_EXPORT shape
m(fp8e4m3fn_type, migraphx::fp8::fp8e4m3fn) \
m(fp8e5m2_type, migraphx::fp8::fp8e5m2) \
m(bf16_type, bf16) \
m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz) // clang-format on
m(fp8e5m2fnuz_type, migraphx::fp8::fp8e5m2fnuz) \
m(fp8e8m0_type, fp8e8m0)
// clang-format on

#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t
Expand Down
4 changes: 4 additions & 0 deletions src/include/migraphx/type_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <migraphx/bf16.hpp>
#include <migraphx/config.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/fp8e8m0.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -74,6 +75,9 @@ MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, migraphx::fp8::fp8e5m2)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_signed, migraphx::fp8::fp8e5m2)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, migraphx::fp8::fp8e5m2)

MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_floating_point, fp8e8m0)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(is_arithmetic, fp8e8m0)

template <class T>
using accumulator_type =
std::conditional_t<is_floating_point<T>{},
Expand Down
3 changes: 2 additions & 1 deletion src/netron_output.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ int get_onnx_type(shape::type_t s_type)
case shape::fp8e5m2_type: return 19;
case shape::fp8e5m2fnuz_type: return 20;
case shape::tuple_type: return 0;
case shape::fp4x2_type: return 21; // TODO update this when the type is added
case shape::fp4x2_type: return 23;
case shape::fp8e8m0_type: return 24;
}
MIGRAPHX_THROW("MIGraphX type " + std::to_string(s_type) + " not supported");
}
Expand Down
1 change: 1 addition & 0 deletions src/targets/gpu/gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ static rocblas_datatype get_type(shape::type_t type)
case shape::int64_type:
case shape::uint64_type: MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
case shape::bf16_type: return rocblas_datatype_bf16_r;
case shape::fp8e8m0_type:
}

MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
Expand Down
1 change: 1 addition & 0 deletions src/targets/gpu/hip_gemm_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ static hipDataType get_type_hipblas(shape::type_t type)
case shape::int64_type:
case shape::uint64_type: MIGRAPHX_THROW("HIPBLAS_GEMM: data type not supported!");
case shape::bf16_type: return HIP_R_16BF;
case shape::fp8e8m0_type:
}

MIGRAPHX_THROW("HIPBLAS_GEMM: data type not supported!");
Expand Down
Loading
Loading