Skip to content

Commit a6c3cc0

Browse files
committed
Add runtime dispatch (mld_polyz_unpack_17/19_native)
Signed-off-by: willieyz <[email protected]>
1 parent cfb84c5 commit a6c3cc0

File tree

6 files changed

+32
-28
lines changed

6 files changed

+32
-28
lines changed

dev/aarch64_clean/meta.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,12 @@ static MLD_INLINE int mld_poly_chknorm_native(const int32_t *a, int32_t B)
151151
return mld_poly_chknorm_asm(a, B) == 0 ? 0 : 1;
152152
}
153153

154-
static MLD_INLINE void mld_polyz_unpack_17_native(int32_t *r,
155-
const uint8_t *buf)
154+
static MLD_INLINE int mld_polyz_unpack_17_native(int32_t *r, const uint8_t *buf)
156155
{
157156
mld_polyz_unpack_17_asm(r, buf, mld_polyz_unpack_17_indices);
158157
}
159158

160-
static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r,
161-
const uint8_t *buf)
159+
static MLD_INLINE int mld_polyz_unpack_19_native(int32_t *r, const uint8_t *buf)
162160
{
163161
mld_polyz_unpack_19_asm(r, buf, mld_polyz_unpack_19_indices);
164162
}

dev/x86_64/meta.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,12 @@ static MLD_INLINE int mld_poly_chknorm_native(const int32_t *a, int32_t B)
196196
return mld_poly_chknorm_avx2((const __m256i *)a, B) == 0 ? 0 : 1;
197197
}
198198

199-
static MLD_INLINE void mld_polyz_unpack_17_native(int32_t *r, const uint8_t *a)
199+
static MLD_INLINE int mld_polyz_unpack_17_native(int32_t *r, const uint8_t *a)
200200
{
201201
mld_polyz_unpack_17_avx2((__m256i *)r, a);
202202
}
203203

204-
static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a)
204+
static MLD_INLINE int mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a)
205205
{
206206
mld_polyz_unpack_19_avx2((__m256i *)r, a);
207207
}

mldsa/src/native/aarch64/meta.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,12 @@ static MLD_INLINE int mld_poly_chknorm_native(const int32_t *a, int32_t B)
151151
return mld_poly_chknorm_asm(a, B) == 0 ? 0 : 1;
152152
}
153153

154-
static MLD_INLINE void mld_polyz_unpack_17_native(int32_t *r,
155-
const uint8_t *buf)
154+
static MLD_INLINE int mld_polyz_unpack_17_native(int32_t *r, const uint8_t *buf)
156155
{
157156
mld_polyz_unpack_17_asm(r, buf, mld_polyz_unpack_17_indices);
158157
}
159158

160-
static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r,
161-
const uint8_t *buf)
159+
static MLD_INLINE int mld_polyz_unpack_19_native(int32_t *r, const uint8_t *buf)
162160
{
163161
mld_polyz_unpack_19_asm(r, buf, mld_polyz_unpack_19_indices);
164162
}

mldsa/src/native/api.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ static MLD_INLINE int mld_poly_chknorm_native(const int32_t *a, int32_t B);
289289
* Arguments: - int32_t *r: pointer to output polynomial
290290
* - const uint8_t *a: byte array with bit-packed polynomial
291291
**************************************************/
292-
static MLD_INLINE void mld_polyz_unpack_17_native(int32_t *r, const uint8_t *a);
292+
static MLD_INLINE int mld_polyz_unpack_17_native(int32_t *r, const uint8_t *a);
293293
#endif /* MLD_USE_NATIVE_POLYZ_UNPACK_17 */
294294

295295
#if defined(MLD_USE_NATIVE_POLYZ_UNPACK_19)
@@ -303,7 +303,7 @@ static MLD_INLINE void mld_polyz_unpack_17_native(int32_t *r, const uint8_t *a);
303303
* Arguments: - int32_t *r: pointer to output polynomial
304304
* - const uint8_t *a: byte array with bit-packed polynomial
305305
**************************************************/
306-
static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a);
306+
static MLD_INLINE int mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a);
307307
#endif /* MLD_USE_NATIVE_POLYZ_UNPACK_19 */
308308

309309
#if defined(MLD_USE_NATIVE_POINTWISE_MONTGOMERY)

mldsa/src/native/x86_64/meta.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,12 +196,12 @@ static MLD_INLINE int mld_poly_chknorm_native(const int32_t *a, int32_t B)
196196
return mld_poly_chknorm_avx2((const __m256i *)a, B) == 0 ? 0 : 1;
197197
}
198198

199-
static MLD_INLINE void mld_polyz_unpack_17_native(int32_t *r, const uint8_t *a)
199+
static MLD_INLINE int mld_polyz_unpack_17_native(int32_t *r, const uint8_t *a)
200200
{
201201
mld_polyz_unpack_17_avx2((__m256i *)r, a);
202202
}
203203

204-
static MLD_INLINE void mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a)
204+
static MLD_INLINE int mld_polyz_unpack_19_native(int32_t *r, const uint8_t *a)
205205
{
206206
mld_polyz_unpack_19_avx2((__m256i *)r, a);
207207
}

mldsa/src/poly_kl.c

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -745,15 +745,31 @@ void mld_polyz_pack(uint8_t *r, const mld_poly *a)
745745
MLD_INTERNAL_API
746746
void mld_polyz_unpack(mld_poly *r, const uint8_t *a)
747747
{
748+
unsigned int i;
748749
#if defined(MLD_USE_NATIVE_POLYZ_UNPACK_17) && MLD_CONFIG_PARAMETER_SET == 44
749750
/* TODO: proof */
750-
mld_polyz_unpack_17_native(r->coeffs, a);
751+
int ret;
752+
ret = mld_polyz_unpack_17_native(r->coeffs, a);
753+
if (ret == MLD_NATIVE_FUNC_SUCCESS)
754+
{
755+
mld_assert_bound(r->coeffs, MLDSA_N, -(MLDSA_GAMMA1 - 1), MLDSA_GAMMA1 + 1);
756+
return;
757+
}
751758
#elif defined(MLD_USE_NATIVE_POLYZ_UNPACK_19) && \
752759
(MLD_CONFIG_PARAMETER_SET == 65 || MLD_CONFIG_PARAMETER_SET == 87)
753760
/* TODO: proof */
754-
mld_polyz_unpack_19_native(r->coeffs, a);
755-
#elif MLD_CONFIG_PARAMETER_SET == 44
756-
unsigned int i;
761+
int ret;
762+
ret = mld_polyz_unpack_19_native(r->coeffs, a);
763+
if (ret == MLD_NATIVE_FUNC_SUCCESS)
764+
{
765+
mld_assert_bound(r->coeffs, MLDSA_N, -(MLDSA_GAMMA1 - 1), MLDSA_GAMMA1 + 1);
766+
return;
767+
}
768+
#endif /* !(MLD_USE_NATIVE_POLYZ_UNPACK_17 && MLD_CONFIG_PARAMETER_SET == 44) \
769+
&& MLD_USE_NATIVE_POLYZ_UNPACK_19 && (MLD_CONFIG_PARAMETER_SET == 65 \
770+
|| MLD_CONFIG_PARAMETER_SET == 87) */
771+
772+
#if MLD_CONFIG_PARAMETER_SET == 44
757773
for (i = 0; i < MLDSA_N / 4; ++i)
758774
__loop__(
759775
invariant(i <= MLDSA_N/4)
@@ -784,11 +800,7 @@ void mld_polyz_unpack(mld_poly *r, const uint8_t *a)
784800
r->coeffs[4 * i + 2] = MLDSA_GAMMA1 - r->coeffs[4 * i + 2];
785801
r->coeffs[4 * i + 3] = MLDSA_GAMMA1 - r->coeffs[4 * i + 3];
786802
}
787-
#else /* !(MLD_USE_NATIVE_POLYZ_UNPACK_17 && MLD_CONFIG_PARAMETER_SET == 44) \
788-
&& !(MLD_USE_NATIVE_POLYZ_UNPACK_19 && (MLD_CONFIG_PARAMETER_SET == \
789-
65 || MLD_CONFIG_PARAMETER_SET == 87)) && MLD_CONFIG_PARAMETER_SET == \
790-
44 */
791-
unsigned int i;
803+
#else /* MLD_CONFIG_PARAMETER_SET == 44 */
792804
for (i = 0; i < MLDSA_N / 2; ++i)
793805
__loop__(
794806
invariant(i <= MLDSA_N/2)
@@ -808,11 +820,7 @@ void mld_polyz_unpack(mld_poly *r, const uint8_t *a)
808820
r->coeffs[2 * i + 0] = MLDSA_GAMMA1 - r->coeffs[2 * i + 0];
809821
r->coeffs[2 * i + 1] = MLDSA_GAMMA1 - r->coeffs[2 * i + 1];
810822
}
811-
#endif /* !(MLD_USE_NATIVE_POLYZ_UNPACK_17 && MLD_CONFIG_PARAMETER_SET == 44) \
812-
&& !(MLD_USE_NATIVE_POLYZ_UNPACK_19 && (MLD_CONFIG_PARAMETER_SET == \
813-
65 || MLD_CONFIG_PARAMETER_SET == 87)) && MLD_CONFIG_PARAMETER_SET \
814-
!= 44 */
815-
823+
#endif /* MLD_CONFIG_PARAMETER_SET != 44 */
816824
mld_assert_bound(r->coeffs, MLDSA_N, -(MLDSA_GAMMA1 - 1), MLDSA_GAMMA1 + 1);
817825
}
818826

0 commit comments

Comments
 (0)