@@ -25,6 +25,7 @@
 #include "Registration.hpp"
 #include "Objective.hpp"
 
+#include "math.hpp"
 #include "approximate_math.hpp"
 #include "compute_wrapper.hpp"
 
@@ -102,6 +103,10 @@ struct alignas(k_cAlignment) Avx512f_32_Int final {
       return Avx512f_32_Int(_mm512_add_epi32(m_data, other.m_data));
    }
 
+   inline Avx512f_32_Int operator-(const Avx512f_32_Int& other) const noexcept {
+      return Avx512f_32_Int(_mm512_sub_epi32(m_data, other.m_data));
+   }
+
    inline Avx512f_32_Int operator*(const T& other) const noexcept {
       return Avx512f_32_Int(_mm512_mullo_epi32(m_data, _mm512_set1_epi32(other)));
    }
@@ -118,6 +123,16 @@ struct alignas(k_cAlignment) Avx512f_32_Int final {
       return Avx512f_32_Int(_mm512_and_si512(m_data, other.m_data));
    }
 
+   inline Avx512f_32_Int operator|(const Avx512f_32_Int& other) const noexcept {
+      return Avx512f_32_Int(_mm512_or_si512(m_data, other.m_data));
+   }
+
+   friend inline Avx512f_32_Int IfThenElse(
+         const __mmask16& cmp, const Avx512f_32_Int& trueVal, const Avx512f_32_Int& falseVal) noexcept {
+      return Avx512f_32_Int(_mm512_castps_si512(
+            _mm512_mask_blend_ps(cmp, _mm512_castsi512_ps(falseVal.m_data), _mm512_castsi512_ps(trueVal.m_data))));
+   }
+
    friend inline Avx512f_32_Int PermuteForInterleaf(const Avx512f_32_Int& val) noexcept {
       // this function permutes the values into positions that the Interleaf function expects
       // but for any SIMD implementation the positions can be variable as long as they work together
@@ -137,7 +152,28 @@ struct alignas(k_cAlignment) Avx512f_32_Int final {
 static_assert(std::is_standard_layout<Avx512f_32_Int>::value && std::is_trivially_copyable<Avx512f_32_Int>::value,
       "This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");
 
+template<bool bNegateInput = false,
+      bool bNaNPossible = true,
+      bool bUnderflowPossible = true,
+      bool bOverflowPossible = true>
+inline Avx512f_32_Float Exp(const Avx512f_32_Float& val) noexcept;
+template<bool bNegateOutput = false,
+      bool bNaNPossible = true,
+      bool bNegativePossible = true,
+      bool bZeroPossible = true,
+      bool bPositiveInfinityPossible = true>
+inline Avx512f_32_Float Log(const Avx512f_32_Float& val) noexcept;
+
 struct alignas(k_cAlignment) Avx512f_32_Float final {
+   template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
+   friend Avx512f_32_Float Exp(const Avx512f_32_Float& val) noexcept;
+   template<bool bNegateOutput,
+         bool bNaNPossible,
+         bool bNegativePossible,
+         bool bZeroPossible,
+         bool bPositiveInfinityPossible>
+   friend Avx512f_32_Float Log(const Avx512f_32_Float& val) noexcept;
+
    using T = float;
    using TPack = __m512;
    using TInt = Avx512f_32_Int;
@@ -155,6 +191,7 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
    inline Avx512f_32_Float(const double val) noexcept : m_data(_mm512_set1_ps(static_cast<T>(val))) {}
    inline Avx512f_32_Float(const float val) noexcept : m_data(_mm512_set1_ps(static_cast<T>(val))) {}
    inline Avx512f_32_Float(const int val) noexcept : m_data(_mm512_set1_ps(static_cast<T>(val))) {}
+   explicit Avx512f_32_Float(const Avx512f_32_Int& val) : m_data(_mm512_cvtepi32_ps(val.m_data)) {}
 
    inline Avx512f_32_Float operator+() const noexcept { return *this; }
 
@@ -231,6 +268,10 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
       return Avx512f_32_Float(val) / other;
    }
 
+   friend inline __mmask16 operator<=(const Avx512f_32_Float& left, const Avx512f_32_Float& right) noexcept {
+      return _mm512_cmp_ps_mask(left.m_data, right.m_data, _CMP_LE_OQ);
+   }
+
    inline static Avx512f_32_Float Load(const T* const a) noexcept { return Avx512f_32_Float(_mm512_load_ps(a)); }
 
    inline void Store(T* const a) const noexcept { _mm512_store_ps(a, m_data); }
@@ -545,6 +586,11 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
       return Avx512f_32_Float(_mm512_mask_blend_ps(mask, falseVal.m_data, trueVal.m_data));
    }
 
+   friend inline Avx512f_32_Float IfThenElse(
+         const __mmask16& cmp, const Avx512f_32_Float& trueVal, const Avx512f_32_Float& falseVal) noexcept {
+      return Avx512f_32_Float(_mm512_mask_blend_ps(cmp, falseVal.m_data, trueVal.m_data));
+   }
+
    friend inline Avx512f_32_Float IfEqual(const Avx512f_32_Float& cmp1,
          const Avx512f_32_Float& cmp2,
          const Avx512f_32_Float& trueVal,
@@ -572,6 +618,20 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
       return Avx512f_32_Float(_mm512_mask_blend_ps(mask, falseVal.m_data, trueVal.m_data));
    }
 
+   static inline __mmask16 ReinterpretInt(const __mmask16& val) noexcept { return val; }
+
+   static inline Avx512f_32_Int ReinterpretInt(const Avx512f_32_Float& val) noexcept {
+      return Avx512f_32_Int(_mm512_castps_si512(val.m_data));
+   }
+
+   static inline Avx512f_32_Float ReinterpretFloat(const Avx512f_32_Int& val) noexcept {
+      return Avx512f_32_Float(_mm512_castsi512_ps(val.m_data));
+   }
+
+   friend inline Avx512f_32_Float Round(const Avx512f_32_Float& val) noexcept {
+      return Avx512f_32_Float(_mm512_roundscale_ps(val.m_data, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+   }
+
    friend inline Avx512f_32_Float Abs(const Avx512f_32_Float& val) noexcept {
       return Avx512f_32_Float(
             _mm512_castsi512_ps(_mm512_and_si512(_mm512_castps_si512(val.m_data), _mm512_set1_epi32(0x7FFFFFFF))));
@@ -609,14 +669,6 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
       return Avx512f_32_Float(_mm512_sqrt_ps(val.m_data));
    }
 
-   friend inline Avx512f_32_Float Exp(const Avx512f_32_Float& val) noexcept {
-      return ApplyFunc([](T x) { return std::exp(x); }, val);
-   }
-
-   friend inline Avx512f_32_Float Log(const Avx512f_32_Float& val) noexcept {
-      return ApplyFunc([](T x) { return std::log(x); }, val);
-   }
-
    template<bool bDisableApprox,
          bool bNegateInput = false,
          bool bNaNPossible = true,
@@ -627,7 +679,7 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
    static inline Avx512f_32_Float ApproxExp(const Avx512f_32_Float& val,
          const int32_t addExpSchraudolphTerm = k_expTermZeroMeanErrorForSoftmaxWithZeroedLogit) noexcept {
       UNUSED(addExpSchraudolphTerm);
-      return Exp(bNegateInput ? -val : val);
+      return Exp<bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
    }
 
    template<bool bDisableApprox,
@@ -687,8 +739,7 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
    static inline Avx512f_32_Float ApproxLog(
         const Avx512f_32_Float& val, const float addLogSchraudolphTerm = k_logTermLowerBoundInputCloseToOne) noexcept {
       UNUSED(addLogSchraudolphTerm);
-      Avx512f_32_Float ret = Log(val);
-      return bNegateOutput ? -ret : ret;
+      return Log<bNegateOutput, bNaNPossible, bNegativePossible, bZeroPossible, bPositiveInfinityPossible>(val);
    }
 
    template<bool bDisableApprox,
@@ -772,6 +823,25 @@ struct alignas(k_cAlignment) Avx512f_32_Float final {
 static_assert(std::is_standard_layout<Avx512f_32_Float>::value && std::is_trivially_copyable<Avx512f_32_Float>::value,
       "This allows offsetof, memcpy, memset, inter-language, GPU and cross-machine use where needed");
 
+template<bool bNegateInput, bool bNaNPossible, bool bUnderflowPossible, bool bOverflowPossible>
+inline Avx512f_32_Float Exp(const Avx512f_32_Float& val) noexcept {
+   return Exp32<Avx512f_32_Float, bNegateInput, bNaNPossible, bUnderflowPossible, bOverflowPossible>(val);
+}
+
+template<bool bNegateOutput,
+      bool bNaNPossible,
+      bool bNegativePossible,
+      bool bZeroPossible,
+      bool bPositiveInfinityPossible>
+inline Avx512f_32_Float Log(const Avx512f_32_Float& val) noexcept {
+   return Log32<Avx512f_32_Float,
+         bNegateOutput,
+         bNaNPossible,
+         bNegativePossible,
+         bZeroPossible,
+         bPositiveInfinityPossible>(val);
+}
+
 INTERNAL_IMPORT_EXPORT_BODY ErrorEbm ApplyUpdate_Avx512f_32(
       const ObjectiveWrapper* const pObjectiveWrapper, ApplyUpdateBridge* const pData) {
    const Objective* const pObjective = static_cast<const Objective*>(pObjectiveWrapper->m_pObjective);
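
For orientation, a small usage sketch (not part of the commit) of the members this change introduces: the templated Log<> wrapper, the mask-returning operator<=, and the __mmask16-based IfThenElse overload. It assumes the Avx512f_32_Float type defined above is in scope; ClampedLog and the -87.3f sentinel (roughly log of the smallest normal float) are hypothetical names and values chosen only for this example.

static inline Avx512f_32_Float ClampedLog(const Avx512f_32_Float& val) noexcept {
   // evaluate log through the new templated wrapper: no output negation; NaN,
   // negative, zero and +infinity inputs are all treated as possible
   const Avx512f_32_Float logVal = Log<false, true, true, true, true>(val);
   // the new operator<= yields a 16-lane mask marking non-positive inputs
   const __mmask16 nonPositive = val <= Avx512f_32_Float(0.0f);
   // blend a finite sentinel into those lanes instead of returning -inf or NaN
   return IfThenElse(nonPositive, Avx512f_32_Float(-87.3f), logVal);
}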