Skip to content

Commit 3e7a817

Browse files
committed
Add runtime dispatch (mld_polyvecl_pointwise_acc_montgomery_l4/l5/l7_native)
Signed-off-by: willieyz <[email protected]>
1 parent c97c822 commit 3e7a817

File tree

6 files changed

+91
-46
lines changed

6 files changed

+91
-46
lines changed

dev/aarch64_clean/meta.h

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -171,28 +171,31 @@ static MLD_INLINE int mld_poly_pointwise_montgomery_native(
171171
return MLD_NATIVE_FUNC_SUCCESS;
172172
}
173173

174-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
174+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l4_native(
175175
int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N],
176176
const int32_t v[4][MLDSA_N])
177177
{
178178
mld_polyvecl_pointwise_acc_montgomery_l4_asm(w, (const int32_t *)u,
179179
(const int32_t *)v);
180+
return MLD_NATIVE_FUNC_SUCCESS;
180181
}
181182

182-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
183+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l5_native(
183184
int32_t w[MLDSA_N], const int32_t u[5][MLDSA_N],
184185
const int32_t v[5][MLDSA_N])
185186
{
186187
mld_polyvecl_pointwise_acc_montgomery_l5_asm(w, (const int32_t *)u,
187188
(const int32_t *)v);
189+
return MLD_NATIVE_FUNC_SUCCESS;
188190
}
189191

190-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
192+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l7_native(
191193
int32_t w[MLDSA_N], const int32_t u[7][MLDSA_N],
192194
const int32_t v[7][MLDSA_N])
193195
{
194196
mld_polyvecl_pointwise_acc_montgomery_l7_asm(w, (const int32_t *)u,
195197
(const int32_t *)v);
198+
return MLD_NATIVE_FUNC_SUCCESS;
196199
}
197200

198201
#endif /* !__ASSEMBLER__ */

dev/x86_64/meta.h

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -228,28 +228,43 @@ static MLD_INLINE int mld_poly_pointwise_montgomery_native(
228228
return MLD_NATIVE_FUNC_SUCCESS;
229229
}
230230

231-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
231+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l4_native(
232232
int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N],
233233
const int32_t v[4][MLDSA_N])
234234
{
235+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
236+
{
237+
return MLD_NATIVE_FUNC_FALLBACK;
238+
}
235239
mld_pointwise_acc_l4_avx2((__m256i *)w, (const __m256i *)u,
236240
(const __m256i *)v, mld_qdata.vec);
241+
return MLD_NATIVE_FUNC_SUCCESS;
237242
}
238243

239-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
244+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l5_native(
240245
int32_t w[MLDSA_N], const int32_t u[5][MLDSA_N],
241246
const int32_t v[5][MLDSA_N])
242247
{
248+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
249+
{
250+
return MLD_NATIVE_FUNC_FALLBACK;
251+
}
243252
mld_pointwise_acc_l5_avx2((__m256i *)w, (const __m256i *)u,
244253
(const __m256i *)v, mld_qdata.vec);
254+
return MLD_NATIVE_FUNC_SUCCESS;
245255
}
246256

247-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
257+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l7_native(
248258
int32_t w[MLDSA_N], const int32_t u[7][MLDSA_N],
249259
const int32_t v[7][MLDSA_N])
250260
{
261+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
262+
{
263+
return MLD_NATIVE_FUNC_FALLBACK;
264+
}
251265
mld_pointwise_acc_l7_avx2((__m256i *)w, (const __m256i *)u,
252266
(const __m256i *)v, mld_qdata.vec);
267+
return MLD_NATIVE_FUNC_SUCCESS;
253268
}
254269

255270
#endif /* !__ASSEMBLER__ */

mldsa/src/native/aarch64/meta.h

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -171,28 +171,31 @@ static MLD_INLINE int mld_poly_pointwise_montgomery_native(
171171
return MLD_NATIVE_FUNC_SUCCESS;
172172
}
173173

174-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
174+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l4_native(
175175
int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N],
176176
const int32_t v[4][MLDSA_N])
177177
{
178178
mld_polyvecl_pointwise_acc_montgomery_l4_asm(w, (const int32_t *)u,
179179
(const int32_t *)v);
180+
return MLD_NATIVE_FUNC_SUCCESS;
180181
}
181182

182-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
183+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l5_native(
183184
int32_t w[MLDSA_N], const int32_t u[5][MLDSA_N],
184185
const int32_t v[5][MLDSA_N])
185186
{
186187
mld_polyvecl_pointwise_acc_montgomery_l5_asm(w, (const int32_t *)u,
187188
(const int32_t *)v);
189+
return MLD_NATIVE_FUNC_SUCCESS;
188190
}
189191

190-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
192+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l7_native(
191193
int32_t w[MLDSA_N], const int32_t u[7][MLDSA_N],
192194
const int32_t v[7][MLDSA_N])
193195
{
194196
mld_polyvecl_pointwise_acc_montgomery_l7_asm(w, (const int32_t *)u,
195197
(const int32_t *)v);
198+
return MLD_NATIVE_FUNC_SUCCESS;
196199
}
197200

198201
#endif /* !__ASSEMBLER__ */

mldsa/src/native/api.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -339,7 +339,7 @@ static MLD_INLINE int mld_poly_pointwise_montgomery_native(
339339
* - const int32_t u[MLDSA_L][MLDSA_N]: first input vector
340340
* - const int32_t v[MLDSA_L][MLDSA_N]: second input vector
341341
**************************************************/
342-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
342+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l4_native(
343343
int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N],
344344
const int32_t v[4][MLDSA_N]);
345345
#endif /* MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 */
@@ -359,7 +359,7 @@ static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
359359
* - const int32_t u[MLDSA_L][MLDSA_N]: first input vector
360360
* - const int32_t v[MLDSA_L][MLDSA_N]: second input vector
361361
**************************************************/
362-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
362+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l5_native(
363363
int32_t w[MLDSA_N], const int32_t u[5][MLDSA_N],
364364
const int32_t v[5][MLDSA_N]);
365365
#endif /* MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 */
@@ -379,7 +379,7 @@ static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
379379
* - const int32_t u[MLDSA_L][MLDSA_N]: first input vector
380380
* - const int32_t v[MLDSA_L][MLDSA_N]: second input vector
381381
**************************************************/
382-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
382+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l7_native(
383383
int32_t w[MLDSA_N], const int32_t u[7][MLDSA_N],
384384
const int32_t v[7][MLDSA_N]);
385385
#endif /* MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 */

mldsa/src/native/x86_64/meta.h

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -228,28 +228,43 @@ static MLD_INLINE int mld_poly_pointwise_montgomery_native(
228228
return MLD_NATIVE_FUNC_SUCCESS;
229229
}
230230

231-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l4_native(
231+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l4_native(
232232
int32_t w[MLDSA_N], const int32_t u[4][MLDSA_N],
233233
const int32_t v[4][MLDSA_N])
234234
{
235+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
236+
{
237+
return MLD_NATIVE_FUNC_FALLBACK;
238+
}
235239
mld_pointwise_acc_l4_avx2((__m256i *)w, (const __m256i *)u,
236240
(const __m256i *)v, mld_qdata.vec);
241+
return MLD_NATIVE_FUNC_SUCCESS;
237242
}
238243

239-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l5_native(
244+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l5_native(
240245
int32_t w[MLDSA_N], const int32_t u[5][MLDSA_N],
241246
const int32_t v[5][MLDSA_N])
242247
{
248+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
249+
{
250+
return MLD_NATIVE_FUNC_FALLBACK;
251+
}
243252
mld_pointwise_acc_l5_avx2((__m256i *)w, (const __m256i *)u,
244253
(const __m256i *)v, mld_qdata.vec);
254+
return MLD_NATIVE_FUNC_SUCCESS;
245255
}
246256

247-
static MLD_INLINE void mld_polyvecl_pointwise_acc_montgomery_l7_native(
257+
static MLD_INLINE int mld_polyvecl_pointwise_acc_montgomery_l7_native(
248258
int32_t w[MLDSA_N], const int32_t u[7][MLDSA_N],
249259
const int32_t v[7][MLDSA_N])
250260
{
261+
if (!mld_sys_check_capability(MLD_SYS_CAP_AVX2))
262+
{
263+
return MLD_NATIVE_FUNC_FALLBACK;
264+
}
251265
mld_pointwise_acc_l7_avx2((__m256i *)w, (const __m256i *)u,
252266
(const __m256i *)v, mld_qdata.vec);
267+
return MLD_NATIVE_FUNC_SUCCESS;
253268
}
254269

255270
#endif /* !__ASSEMBLER__ */

mldsa/src/polyvec.c

Lines changed: 40 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -291,42 +291,57 @@ MLD_INTERNAL_API
291291
void mld_polyvecl_pointwise_acc_montgomery(mld_poly *w, const mld_polyvecl *u,
292292
const mld_polyvecl *v)
293293
{
294-
#if defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4) && \
295-
MLD_CONFIG_PARAMETER_SET == 44
296-
/* TODO: proof */
294+
unsigned int i, j;
297295
mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q);
298296
mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND);
299-
mld_polyvecl_pointwise_acc_montgomery_l4_native(
300-
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
301-
(const int32_t(*)[MLDSA_N])v->vec);
302-
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
297+
#if defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4) && \
298+
MLD_CONFIG_PARAMETER_SET == 44
299+
{
300+
/* TODO: proof */
301+
int ret;
302+
ret = mld_polyvecl_pointwise_acc_montgomery_l4_native(
303+
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
304+
(const int32_t(*)[MLDSA_N])v->vec);
305+
if (ret == MLD_NATIVE_FUNC_SUCCESS)
306+
{
307+
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
308+
return;
309+
}
310+
}
303311
#elif defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5) && \
304312
MLD_CONFIG_PARAMETER_SET == 65
305-
/* TODO: proof */
306-
mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q);
307-
mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND);
308-
mld_polyvecl_pointwise_acc_montgomery_l5_native(
309-
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
310-
(const int32_t(*)[MLDSA_N])v->vec);
311-
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
313+
{
314+
/* TODO: proof */
315+
int ret;
316+
ret = mld_polyvecl_pointwise_acc_montgomery_l5_native(
317+
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
318+
(const int32_t(*)[MLDSA_N])v->vec);
319+
if (ret == MLD_NATIVE_FUNC_SUCCESS)
320+
{
321+
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
322+
return;
323+
}
324+
}
312325
#elif defined(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7) && \
313326
MLD_CONFIG_PARAMETER_SET == 87
314-
/* TODO: proof */
315-
mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q);
316-
mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND);
317-
mld_polyvecl_pointwise_acc_montgomery_l7_native(
318-
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
319-
(const int32_t(*)[MLDSA_N])v->vec);
320-
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
321-
#else /* !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 && \
327+
{
328+
/* TODO: proof */
329+
int ret;
330+
ret = mld_polyvecl_pointwise_acc_montgomery_l7_native(
331+
w->coeffs, (const int32_t(*)[MLDSA_N])u->vec,
332+
(const int32_t(*)[MLDSA_N])v->vec);
333+
if (ret == MLD_NATIVE_FUNC_SUCCESS)
334+
{
335+
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
336+
return;
337+
}
338+
}
339+
#endif /* !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 && \
322340
MLD_CONFIG_PARAMETER_SET == 44) && \
323341
!(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 && \
324342
MLD_CONFIG_PARAMETER_SET == 65) && \
325343
MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 && \
326344
MLD_CONFIG_PARAMETER_SET == 87 */
327-
unsigned int i, j;
328-
mld_assert_bound_2d(u->vec, MLDSA_L, MLDSA_N, 0, MLDSA_Q);
329-
mld_assert_abs_bound_2d(v->vec, MLDSA_L, MLDSA_N, MLD_NTT_BOUND);
330345
/* The first input is bounded by [0, Q-1] inclusive
331346
* The second input is bounded by [-9Q+1, 9Q-1] inclusive. Hence, we can
332347
* safely accumulate in 64-bits without intermediate reductions as
@@ -361,12 +376,6 @@ void mld_polyvecl_pointwise_acc_montgomery(mld_poly *w, const mld_polyvecl *u,
361376
}
362377

363378
mld_assert_abs_bound(w->coeffs, MLDSA_N, MLDSA_Q);
364-
#endif /* !(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L4 && \
365-
MLD_CONFIG_PARAMETER_SET == 44) && \
366-
!(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L5 && \
367-
MLD_CONFIG_PARAMETER_SET == 65) && \
368-
!(MLD_USE_NATIVE_POLYVECL_POINTWISE_ACC_MONTGOMERY_L7 && \
369-
MLD_CONFIG_PARAMETER_SET == 87) */
370379
}
371380

372381
MLD_INTERNAL_API

0 commit comments

Comments
 (0)