Skip to content

Commit db714e3

Browse files
committed
AVX512 exp/log
1 parent 1419d9c commit db714e3

File tree

2 files changed

+81
-12
lines changed

2 files changed

+81
-12
lines changed

shared/libebm/compute/avx512f_ebm/avx512f_32.cpp

Lines changed: 81 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "Registration.hpp"
2626
#include "Objective.hpp"
2727

28+
#include "math.hpp"
2829
#include "approximate_math.hpp"
2930
#include "compute_wrapper.hpp"
3031

@@ -102,6 +103,10 @@ struct alignas(k_cAlignment) Avx512f_32_Int final {
102103
return Avx512f_32_Int(_mm512_add_epi32(m_data, other.m_data));
103104
}
104105

106+
inline Avx512f_32_Int operator-(const Avx512f_32_Int& other) const noexcept {
107+
return Avx512f_32_Int(_mm512_sub_epi32(m_data, other.m_data));
108+
}
109+
105110
inline Avx512f_32_Int operator*(const T& other) const noexcept {
106111
return Avx512f_32_Int(_mm512_mullo_epi32(m_data, _mm512_set1_epi32(other)));
107112
}
@@ -118,6 +123,16 @@ struct alignas(k_cAlignment) Avx512f_32_Int final {
118123
return Avx512f_32_Int(_mm512_and_si512(m_data, other.m_data));
119124
}
120125

126+
inline Avx512f_32_Int operator|(const Avx512f_32_Int& other) const noexcept {
127+
return Avx512f_32_Int(_mm512_or_si512(m_data, other.m_data));
128+
}
129+
130+
friend inline Avx512f_32_Int IfThenElse(
131+
const __mmask16& cmp, const Avx512f_32_Int& trueVal, const Avx512f_32_Int& falseVal) noexcept {
132+
return Avx512f_32_Int(_mm512_castps_si512(
133+
_mm512_mask_blend_ps(cmp, _mm512_castsi512_ps(falseVal.m_data), _mm512_castsi512_ps(trueVal.m_data))));
134+
}
135+
121136
friend inline Avx512f_32_Int PermuteForInterleaf(const Avx512f_32_Int& val) noexcept {
122137
// this function permutes the values into positions that the Interleaf function expects
123138
// but for any SIMD implementation the positions can be variable as long as they work together
@@ -137,7 +152,28 @@ struct alignas(k_cAlignment) Avx512f_32_Int final {
137152
static_assert(std::is_standard_layout<Avx512f_32_Int>::value && std::is_trivially_copyable<Avx512f_32_Int>::value,
138153
"This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");
139154

155+
template<bool bNegateInput = false,
156+
bool bNaNPossible = true,
157+
bool bUnderflowPossible = true,
158+
bool bOverflowPossible = true>
159+
inline Avx512f_32_Float Exp(const Avx512f_32_Float& val) noexcept;
160+
template<bool bNegateOutput = false,
161+
bool bNaNPossible = true,
162+
bool bNegativePossible = true,
163+
bool bZeroPossible = true,
164+
bool bPositiveInfinityPossible = true>
165+
inline Avx512f_32_Float Log(const Avx512f_32_Float& val) noexcept;
166+
140167
struct alignas(k_cAlignment) Avx512f_32_Float final {
168+
template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
169+
friend Avx512f_32_Float Exp(const Avx512f_32_Float& val) noexcept;
170+
template<bool bNegateOutput,
171+
bool bNaNPossible,
172+
bool bNegativePossible,
173+
bool bZeroPossible,
174+
bool bPositiveInfinityPossible>
175+
friend Avx512f_32_Float Log(const Avx512f_32_Float& val) noexcept;
176+
141177
using T = float;
142178
using TPack = __m512;
143179
using TInt = Avx512f_32_Int;
@@ -155,6 +191,7 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
155191
inline Avx512f_32_Float(const double val) noexcept : m_data(_mm512_set1_ps(static_cast<T>(val))) {}
156192
inline Avx512f_32_Float(const float val) noexcept : m_data(_mm512_set1_ps(static_cast<T>(val))) {}
157193
inline Avx512f_32_Float(const int val) noexcept : m_data(_mm512_set1_ps(static_cast<T>(val))) {}
194+
explicit Avx512f_32_Float(const Avx512f_32_Int& val) : m_data(_mm512_cvtepi32_ps(val.m_data)) {}
158195

159196
inline Avx512f_32_Float operator+() const noexcept { return *this; }
160197

@@ -231,6 +268,10 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
231268
return Avx512f_32_Float(val) / other;
232269
}
233270

271+
friend inline __mmask16 operator<=(const Avx512f_32_Float& left, const Avx512f_32_Float& right) noexcept {
272+
return _mm512_cmp_ps_mask(left.m_data, right.m_data, _CMP_LE_OQ);
273+
}
274+
234275
inline static Avx512f_32_Float Load(const T* const a) noexcept { return Avx512f_32_Float(_mm512_load_ps(a)); }
235276

236277
inline void Store(T* const a) const noexcept { _mm512_store_ps(a, m_data); }
@@ -545,6 +586,11 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
545586
return Avx512f_32_Float(_mm512_mask_blend_ps(mask, falseVal.m_data, trueVal.m_data));
546587
}
547588

589+
friend inline Avx512f_32_Float IfThenElse(
590+
const __mmask16& cmp, const Avx512f_32_Float& trueVal, const Avx512f_32_Float& falseVal) noexcept {
591+
return Avx512f_32_Float(_mm512_mask_blend_ps(cmp, falseVal.m_data, trueVal.m_data));
592+
}
593+
548594
friend inline Avx512f_32_Float IfEqual(const Avx512f_32_Float& cmp1,
549595
const Avx512f_32_Float& cmp2,
550596
const Avx512f_32_Float& trueVal,
@@ -572,6 +618,20 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
572618
return Avx512f_32_Float(_mm512_mask_blend_ps(mask, falseVal.m_data, trueVal.m_data));
573619
}
574620

621+
static inline __mmask16 ReinterpretInt(const __mmask16& val) noexcept { return val; }
622+
623+
static inline Avx512f_32_Int ReinterpretInt(const Avx512f_32_Float& val) noexcept {
624+
return Avx512f_32_Int(_mm512_castps_si512(val.m_data));
625+
}
626+
627+
static inline Avx512f_32_Float ReinterpretFloat(const Avx512f_32_Int& val) noexcept {
628+
return Avx512f_32_Float(_mm512_castsi512_ps(val.m_data));
629+
}
630+
631+
friend inline Avx512f_32_Float Round(const Avx512f_32_Float& val) noexcept {
632+
return Avx512f_32_Float(_mm512_roundscale_ps(val.m_data, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
633+
}
634+
575635
friend inline Avx512f_32_Float Abs(const Avx512f_32_Float& val) noexcept {
576636
return Avx512f_32_Float(
577637
_mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(val.m_data), _mm512_set1_epi32(0x7FFFFFFF))));
@@ -609,14 +669,6 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
609669
return Avx512f_32_Float(_mm512_sqrt_ps(val.m_data));
610670
}
611671

612-
friend inline Avx512f_32_Float Exp(const Avx512f_32_Float& val) noexcept {
613-
return ApplyFunc([](T x) { return std::exp(x); }, val);
614-
}
615-
616-
friend inline Avx512f_32_Float Log(const Avx512f_32_Float& val) noexcept {
617-
return ApplyFunc([](T x) { return std::log(x); }, val);
618-
}
619-
620672
template<bool bDisableApprox,
621673
bool bNegateInput = false,
622674
bool bNaNPossible = true,
@@ -627,7 +679,7 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
627679
static inline Avx512f_32_Float ApproxExp(const Avx512f_32_Float& val,
628680
const int32_t addExpSchraudolphTerm = k_expTermZeroMeanErrorForSoftmaxWithZeroedLogit) noexcept {
629681
UNUSED(addExpSchraudolphTerm);
630-
return Exp(bNegateInput ? -val : val);
682+
return Exp<bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
631683
}
632684

633685
template<bool bDisableApprox,
@@ -687,8 +739,7 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
687739
static inline Avx512f_32_Float ApproxLog(
688740
const Avx512f_32_Float& val, const float addLogSchraudolphTerm = k_logTermLowerBoundInputCloseToOne) noexcept {
689741
UNUSED(addLogSchraudolphTerm);
690-
Avx512f_32_Float ret = Log(val);
691-
return bNegateOutput ? -ret : ret;
742+
return Log<bNegateOutput, bNaNPossible, bNegativePossible, bZeroPossible, bPositiveInfinityPossible>(val);
692743
}
693744

694745
template<bool bDisableApprox,
@@ -772,6 +823,25 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
772823
static_assert(std::is_standard_layout<Avx512f_32_Float>::value && std::is_trivially_copyable<Avx512f_32_Float>::value,
773824
"This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");
774825

826+
template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
827+
inline Avx512f_32_Float Exp(const Avx512f_32_Float& val) noexcept {
828+
return Exp32<Avx512f_32_Float, bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
829+
}
830+
831+
template<bool bNegateOutput,
832+
bool bNaNPossible,
833+
bool bNegativePossible,
834+
bool bZeroPossible,
835+
bool bPositiveInfinityPossible>
836+
inline Avx512f_32_Float Log(const Avx512f_32_Float& val) noexcept {
837+
return Log32<Avx512f_32_Float,
838+
bNegateOutput,
839+
bNaNPossible,
840+
bNegativePossible,
841+
bZeroPossible,
842+
bPositiveInfinityPossible>(val);
843+
}
844+
775845
INTERNAL_IMPORT_EXPORT_BODY ErrorEbm ApplyUpdate_Avx512f_32(
776846
const ObjectiveWrapper* const pObjectiveWrapper, ApplyUpdateBridge* const pData) {
777847
const Objective* const pObjective = static_cast<const Objective*>(pObjectiveWrapper->m_pObjective);

shared/libebm/compute/math.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ static INLINE_ALWAYS TFloat Exp32(const TFloat val) {
9393

9494
ret = (ret + TFloat{1}) * rounded2;
9595

96-
// TODO: handling overflow/underflow possible faster see vectormath version2 code
9796
if(bOverflowPossible) {
9897
if(bNegateInput) {
9998
ret = IfLess(val,

0 commit comments

Comments
 (0)