Skip to content

Commit d83012d

Browse files
committed
Switch mlk_polyvec and mlk_polymat to struct wrappers
- Change mlk_polyvec back to struct { mlk_poly vec[MLKEM_K]; } - Change mlk_polymat to struct { mlk_polyvec vec[MLKEM_K]; } - Update all function signatures to use pointer style - Fix all implementations to use struct member access - Update tests, benchmarks, and CBMC harnesses - Add consistent const annotations Signed-off-by: Hanno Becker <[email protected]>
1 parent 1ab3bad commit d83012d

File tree

21 files changed

+184
-158
lines changed

21 files changed

+184
-158
lines changed

mlkem/src/indcpa.c

Lines changed: 75 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@
5959
* Implements @[FIPS203, Algorithm 13 (K-PKE.KeyGen), L19]
6060
*
6161
**************************************************/
62-
static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], mlk_polyvec pk,
62+
static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES],
63+
const mlk_polyvec *pk,
6364
const uint8_t seed[MLKEM_SYMBYTES])
6465
{
6566
mlk_assert_bound_2d(pk, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
@@ -83,7 +84,7 @@ static void mlk_pack_pk(uint8_t r[MLKEM_INDCPA_PUBLICKEYBYTES], mlk_polyvec pk,
8384
* Implements @[FIPS203, Algorithm 14 (K-PKE.Encrypt), L2-3]
8485
*
8586
**************************************************/
86-
static void mlk_unpack_pk(mlk_polyvec pk, uint8_t seed[MLKEM_SYMBYTES],
87+
static void mlk_unpack_pk(mlk_polyvec *pk, uint8_t seed[MLKEM_SYMBYTES],
8788
const uint8_t packedpk[MLKEM_INDCPA_PUBLICKEYBYTES])
8889
{
8990
mlk_polyvec_frombytes(pk, packedpk);
@@ -108,7 +109,8 @@ static void mlk_unpack_pk(mlk_polyvec pk, uint8_t seed[MLKEM_SYMBYTES],
108109
* Implements @[FIPS203, Algorithm 13 (K-PKE.KeyGen), L20]
109110
*
110111
**************************************************/
111-
static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], mlk_polyvec sk)
112+
static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES],
113+
const mlk_polyvec *sk)
112114
{
113115
mlk_assert_bound_2d(sk, MLKEM_K, MLKEM_N, 0, MLKEM_Q);
114116
mlk_polyvec_tobytes(r, sk);
@@ -128,7 +130,7 @@ static void mlk_pack_sk(uint8_t r[MLKEM_INDCPA_SECRETKEYBYTES], mlk_polyvec sk)
128130
* Implements @[FIPS203, Algorithm 15 (K-PKE.Decrypt), L5]
129131
*
130132
**************************************************/
131-
static void mlk_unpack_sk(mlk_polyvec sk,
133+
static void mlk_unpack_sk(mlk_polyvec *sk,
132134
const uint8_t packedsk[MLKEM_INDCPA_SECRETKEYBYTES])
133135
{
134136
mlk_polyvec_frombytes(sk, packedsk);
@@ -149,8 +151,8 @@ static void mlk_unpack_sk(mlk_polyvec sk,
149151
* Implements @[FIPS203, Algorithm 14 (K-PKE.Encrypt), L22-23]
150152
*
151153
**************************************************/
152-
static void mlk_pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], mlk_polyvec b,
153-
mlk_poly *v)
154+
static void mlk_pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES],
155+
const mlk_polyvec *b, mlk_poly *v)
154156
{
155157
mlk_polyvec_compress_du(r, b);
156158
mlk_poly_compress_dv(r + MLKEM_POLYVECCOMPRESSEDBYTES_DU, v);
@@ -170,7 +172,7 @@ static void mlk_pack_ciphertext(uint8_t r[MLKEM_INDCPA_BYTES], mlk_polyvec b,
170172
* Implements @[FIPS203, Algorithm 15 (K-PKE.Decrypt), L1-4]
171173
*
172174
**************************************************/
173-
static void mlk_unpack_ciphertext(mlk_polyvec b, mlk_poly *v,
175+
static void mlk_unpack_ciphertext(mlk_polyvec *b, mlk_poly *v,
174176
const uint8_t c[MLKEM_INDCPA_BYTES])
175177
{
176178
mlk_polyvec_decompress_du(b, c);
@@ -201,7 +203,7 @@ __contract__(
201203
*
202204
* Not static for benchmarking */
203205
MLK_INTERNAL_API
204-
void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
206+
void mlk_gen_matrix(mlk_polymat *a, const uint8_t seed[MLKEM_SYMBYTES],
205207
int transposed)
206208
{
207209
unsigned i, j;
@@ -238,7 +240,11 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
238240
}
239241
}
240242

241-
mlk_poly_rej_uniform_x4(&a[i], &a[i + 1], &a[i + 2], &a[i + 3], seed_ext);
243+
mlk_poly_rej_uniform_x4(&a->vec[i / MLKEM_K].vec[i % MLKEM_K],
244+
&a->vec[(i + 1) / MLKEM_K].vec[(i + 1) % MLKEM_K],
245+
&a->vec[(i + 2) / MLKEM_K].vec[(i + 2) % MLKEM_K],
246+
&a->vec[(i + 3) / MLKEM_K].vec[(i + 3) % MLKEM_K],
247+
seed_ext);
242248
}
243249

244250
/* For MLKEM_K == 3, sample the last entry individually. */
@@ -259,7 +265,7 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
259265
seed_ext[0][MLKEM_SYMBYTES + 1] = x;
260266
}
261267

262-
mlk_poly_rej_uniform(&a[i], seed_ext[0]);
268+
mlk_poly_rej_uniform(&a->vec[i / MLKEM_K].vec[i % MLKEM_K], seed_ext[0]);
263269
i++;
264270
}
265271

@@ -271,7 +277,8 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
271277
*/
272278
for (i = 0; i < MLKEM_K * MLKEM_K; i++)
273279
{
274-
mlk_poly_permute_bitrev_to_custom(a[i].coeffs);
280+
mlk_poly_permute_bitrev_to_custom(
281+
a->vec[i / MLKEM_K].vec[i % MLKEM_K].coeffs);
275282
}
276283

277284
/* Specification: Partially implements
@@ -296,15 +303,16 @@ void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
296303
* Specification: Implements @[FIPS203, Section 2.4.7, Eq (2.12), (2.13)]
297304
*
298305
**************************************************/
299-
static void mlk_matvec_mul(mlk_polyvec out, const mlk_polymat a,
300-
const mlk_polyvec v, const mlk_polyvec_mulcache vc)
306+
static void mlk_matvec_mul(mlk_polyvec *out, const mlk_polymat *a,
307+
const mlk_polyvec *v, const mlk_polyvec_mulcache *vc)
301308
__contract__(
302309
requires(memory_no_alias(out, sizeof(mlk_polyvec)))
303310
requires(memory_no_alias(a, sizeof(mlk_polymat)))
304311
requires(memory_no_alias(v, sizeof(mlk_polyvec)))
305312
requires(memory_no_alias(vc, sizeof(mlk_polyvec_mulcache)))
306-
requires(forall(k0, 0, MLKEM_K * MLKEM_K,
307-
array_bound(a[k0].coeffs, 0, MLKEM_N, 0, MLKEM_UINT12_LIMIT)))
313+
requires(forall(k0, 0, MLKEM_K,
314+
forall(k1, 0, MLKEM_K,
315+
array_bound(a->vec[k0].vec[k1].coeffs, 0, MLKEM_N, 0, MLKEM_UINT12_LIMIT))))
308316
assigns(object_whole(out)))
309317
{
310318
unsigned i;
@@ -313,7 +321,7 @@ __contract__(
313321
assigns(i, object_whole(out))
314322
invariant(i <= MLKEM_K))
315323
{
316-
mlk_polyvec_basemul_acc_montgomery_cached(&out[i], &a[MLKEM_K * i], v, vc);
324+
mlk_polyvec_basemul_acc_montgomery_cached(&out->vec[i], &a->vec[i], v, vc);
317325
}
318326
}
319327

@@ -352,47 +360,49 @@ void mlk_indcpa_keypair_derand(uint8_t pk[MLKEM_INDCPA_PUBLICKEYBYTES],
352360
*/
353361
MLK_CT_TESTING_DECLASSIFY(publicseed, MLKEM_SYMBYTES);
354362

355-
mlk_gen_matrix(a, publicseed, 0 /* no transpose */);
363+
mlk_gen_matrix(&a, publicseed, 0 /* no transpose */);
356364

357365
#if MLKEM_K == 2
358-
mlk_poly_getnoise_eta1_4x(&skpv[0], &skpv[1], &e[0], &e[1], noiseseed, 0, 1,
359-
2, 3);
366+
mlk_poly_getnoise_eta1_4x(&skpv.vec[0], &skpv.vec[1], &e.vec[0], &e.vec[1],
367+
noiseseed, 0, 1, 2, 3);
360368
#elif MLKEM_K == 3
361369
/*
362370
* Only the first three output buffers are needed.
363371
* The laster parameter is a dummy that's overwritten later.
364372
*/
365-
mlk_poly_getnoise_eta1_4x(&skpv[0], &skpv[1], &skpv[2],
366-
&pkpv[0] /* irrelevant */, noiseseed, 0, 1, 2,
373+
mlk_poly_getnoise_eta1_4x(&skpv.vec[0], &skpv.vec[1], &skpv.vec[2],
374+
&pkpv.vec[0] /* irrelevant */, noiseseed, 0, 1, 2,
367375
0xFF /* irrelevant */);
368376
/* Same here */
369-
mlk_poly_getnoise_eta1_4x(&e[0], &e[1], &e[2], &pkpv[0] /* irrelevant */,
370-
noiseseed, 3, 4, 5, 0xFF /* irrelevant */);
377+
mlk_poly_getnoise_eta1_4x(&e.vec[0], &e.vec[1], &e.vec[2],
378+
&pkpv.vec[0] /* irrelevant */, noiseseed, 3, 4, 5,
379+
0xFF /* irrelevant */);
371380
#elif MLKEM_K == 4
372-
mlk_poly_getnoise_eta1_4x(&skpv[0], &skpv[1], &skpv[2], &skpv[3], noiseseed,
373-
0, 1, 2, 3);
374-
mlk_poly_getnoise_eta1_4x(&e[0], &e[1], &e[2], &e[3], noiseseed, 4, 5, 6, 7);
375-
#endif
381+
mlk_poly_getnoise_eta1_4x(&skpv.vec[0], &skpv.vec[1], &skpv.vec[2],
382+
&skpv.vec[3], noiseseed, 0, 1, 2, 3);
383+
mlk_poly_getnoise_eta1_4x(&e.vec[0], &e.vec[1], &e.vec[2], &e.vec[3],
384+
noiseseed, 4, 5, 6, 7);
385+
#endif /* MLKEM_K == 4 */
376386

377-
mlk_polyvec_ntt(skpv);
378-
mlk_polyvec_ntt(e);
387+
mlk_polyvec_ntt(&skpv);
388+
mlk_polyvec_ntt(&e);
379389

380-
mlk_polyvec_mulcache_compute(skpv_cache, skpv);
381-
mlk_matvec_mul(pkpv, a, skpv, skpv_cache);
382-
mlk_polyvec_tomont(pkpv);
390+
mlk_polyvec_mulcache_compute(&skpv_cache, &skpv);
391+
mlk_matvec_mul(&pkpv, &a, &skpv, &skpv_cache);
392+
mlk_polyvec_tomont(&pkpv);
383393

384-
mlk_polyvec_add(pkpv, e);
385-
mlk_polyvec_reduce(pkpv);
386-
mlk_polyvec_reduce(skpv);
394+
mlk_polyvec_add(&pkpv, &e);
395+
mlk_polyvec_reduce(&pkpv);
396+
mlk_polyvec_reduce(&skpv);
387397

388-
mlk_pack_sk(sk, skpv);
389-
mlk_pack_pk(pk, pkpv, publicseed);
398+
mlk_pack_sk(sk, &skpv);
399+
mlk_pack_pk(pk, &pkpv, publicseed);
390400

391401
/* Specification: Partially implements
392402
* @[FIPS203, Section 3.3, Destruction of intermediate values] */
393403
mlk_zeroize(buf, sizeof(buf));
394404
mlk_zeroize(coins_with_domain_separator, sizeof(coins_with_domain_separator));
395-
mlk_zeroize(a, sizeof(a));
405+
mlk_zeroize(&a, sizeof(a));
396406
mlk_zeroize(&e, sizeof(e));
397407
mlk_zeroize(&skpv, sizeof(skpv));
398408
mlk_zeroize(&skpv_cache, sizeof(skpv_cache));
@@ -418,7 +428,7 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
418428
mlk_poly v, k, epp;
419429
mlk_polyvec_mulcache sp_cache;
420430

421-
mlk_unpack_pk(pkpv, seed, pk);
431+
mlk_unpack_pk(&pkpv, seed, pk);
422432
mlk_poly_frommsg(&k, m);
423433

424434
/*
@@ -429,44 +439,47 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
429439
*/
430440
MLK_CT_TESTING_DECLASSIFY(seed, MLKEM_SYMBYTES);
431441

432-
mlk_gen_matrix(at, seed, 1 /* transpose */);
442+
mlk_gen_matrix(&at, seed, 1 /* transpose */);
433443

434444
#if MLKEM_K == 2
435-
mlk_poly_getnoise_eta1122_4x(&sp[0], &sp[1], &ep[0], &ep[1], coins, 0, 1, 2,
436-
3);
445+
mlk_poly_getnoise_eta1122_4x(&sp.vec[0], &sp.vec[1], &ep.vec[0], &ep.vec[1],
446+
coins, 0, 1, 2, 3);
437447
mlk_poly_getnoise_eta2(&epp, coins, 4);
438448
#elif MLKEM_K == 3
439449
/*
440450
* In this call, only the first three output buffers are needed.
441451
* The last parameter is a dummy that's overwritten later.
442452
*/
443-
mlk_poly_getnoise_eta1_4x(&sp[0], &sp[1], &sp[2], &b[0], coins, 0, 1, 2,
444-
0xFF);
453+
mlk_poly_getnoise_eta1_4x(&sp.vec[0], &sp.vec[1], &sp.vec[2], &b.vec[0],
454+
coins, 0, 1, 2, 0xFF);
445455
/* The fourth output buffer in this call _is_ used. */
446-
mlk_poly_getnoise_eta2_4x(&ep[0], &ep[1], &ep[2], &epp, coins, 3, 4, 5, 6);
456+
mlk_poly_getnoise_eta2_4x(&ep.vec[0], &ep.vec[1], &ep.vec[2], &epp, coins, 3,
457+
4, 5, 6);
447458
#elif MLKEM_K == 4
448-
mlk_poly_getnoise_eta1_4x(&sp[0], &sp[1], &sp[2], &sp[3], coins, 0, 1, 2, 3);
449-
mlk_poly_getnoise_eta2_4x(&ep[0], &ep[1], &ep[2], &ep[3], coins, 4, 5, 6, 7);
459+
mlk_poly_getnoise_eta1_4x(&sp.vec[0], &sp.vec[1], &sp.vec[2], &sp.vec[3],
460+
coins, 0, 1, 2, 3);
461+
mlk_poly_getnoise_eta2_4x(&ep.vec[0], &ep.vec[1], &ep.vec[2], &ep.vec[3],
462+
coins, 4, 5, 6, 7);
450463
mlk_poly_getnoise_eta2(&epp, coins, 8);
451-
#endif
464+
#endif /* MLKEM_K == 4 */
452465

453-
mlk_polyvec_ntt(sp);
466+
mlk_polyvec_ntt(&sp);
454467

455-
mlk_polyvec_mulcache_compute(sp_cache, sp);
456-
mlk_matvec_mul(b, at, sp, sp_cache);
457-
mlk_polyvec_basemul_acc_montgomery_cached(&v, pkpv, sp, sp_cache);
468+
mlk_polyvec_mulcache_compute(&sp_cache, &sp);
469+
mlk_matvec_mul(&b, &at, &sp, &sp_cache);
470+
mlk_polyvec_basemul_acc_montgomery_cached(&v, &pkpv, &sp, &sp_cache);
458471

459-
mlk_polyvec_invntt_tomont(b);
472+
mlk_polyvec_invntt_tomont(&b);
460473
mlk_poly_invntt_tomont(&v);
461474

462-
mlk_polyvec_add(b, ep);
475+
mlk_polyvec_add(&b, &ep);
463476
mlk_poly_add(&v, &epp);
464477
mlk_poly_add(&v, &k);
465478

466-
mlk_polyvec_reduce(b);
479+
mlk_polyvec_reduce(&b);
467480
mlk_poly_reduce(&v);
468481

469-
mlk_pack_ciphertext(c, b, &v);
482+
mlk_pack_ciphertext(c, &b, &v);
470483

471484
/* Specification: Partially implements
472485
* @[FIPS203, Section 3.3, Destruction of intermediate values] */
@@ -475,7 +488,7 @@ void mlk_indcpa_enc(uint8_t c[MLKEM_INDCPA_BYTES],
475488
mlk_zeroize(&sp_cache, sizeof(sp_cache));
476489
mlk_zeroize(&b, sizeof(b));
477490
mlk_zeroize(&v, sizeof(v));
478-
mlk_zeroize(at, sizeof(at));
491+
mlk_zeroize(&at, sizeof(at));
479492
mlk_zeroize(&k, sizeof(k));
480493
mlk_zeroize(&ep, sizeof(ep));
481494
mlk_zeroize(&epp, sizeof(epp));
@@ -493,12 +506,12 @@ void mlk_indcpa_dec(uint8_t m[MLKEM_INDCPA_MSGBYTES],
493506
mlk_poly v, sb;
494507
mlk_polyvec_mulcache b_cache;
495508

496-
mlk_unpack_ciphertext(b, &v, c);
497-
mlk_unpack_sk(skpv, sk);
509+
mlk_unpack_ciphertext(&b, &v, c);
510+
mlk_unpack_sk(&skpv, sk);
498511

499-
mlk_polyvec_ntt(b);
500-
mlk_polyvec_mulcache_compute(b_cache, b);
501-
mlk_polyvec_basemul_acc_montgomery_cached(&sb, skpv, b, b_cache);
512+
mlk_polyvec_ntt(&b);
513+
mlk_polyvec_mulcache_compute(&b_cache, &b);
514+
mlk_polyvec_basemul_acc_montgomery_cached(&sb, &skpv, &b, &b_cache);
502515
mlk_poly_invntt_tomont(&sb);
503516

504517
mlk_poly_sub(&v, &sb);

mlkem/src/indcpa.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,15 @@
3939
*
4040
**************************************************/
4141
MLK_INTERNAL_API
42-
void mlk_gen_matrix(mlk_polymat a, const uint8_t seed[MLKEM_SYMBYTES],
42+
void mlk_gen_matrix(mlk_polymat *a, const uint8_t seed[MLKEM_SYMBYTES],
4343
int transposed)
4444
__contract__(
4545
requires(memory_no_alias(a, sizeof(mlk_polymat)))
4646
requires(memory_no_alias(seed, MLKEM_SYMBYTES))
4747
requires(transposed == 0 || transposed == 1)
4848
assigns(object_whole(a))
49-
ensures(forall(x, 0, MLKEM_K * MLKEM_K,
50-
array_bound(a[x].coeffs, 0, MLKEM_N, 0, MLKEM_Q)))
49+
ensures(forall(x, 0, MLKEM_K, forall(y, 0, MLKEM_K,
50+
array_bound(a->vec[x].vec[y].coeffs, 0, MLKEM_N, 0, MLKEM_Q))))
5151
);
5252

5353
#define mlk_indcpa_keypair_derand MLK_NAMESPACE_K(indcpa_keypair_derand)

mlkem/src/kem.c

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ int crypto_kem_check_pk(const uint8_t pk[MLKEM_INDCCA_PUBLICKEYBYTES])
5858
mlk_polyvec p;
5959
uint8_t p_reencoded[MLKEM_POLYVECBYTES];
6060

61-
mlk_polyvec_frombytes(p, pk);
62-
mlk_polyvec_reduce(p);
63-
mlk_polyvec_tobytes(p_reencoded, p);
61+
mlk_polyvec_frombytes(&p, pk);
62+
mlk_polyvec_reduce(&p);
63+
mlk_polyvec_tobytes(p_reencoded, &p);
6464

6565
/* We use a constant-time memcmp here to avoid having to
6666
* declassify the PK before the PCT has succeeded. */

0 commit comments

Comments
 (0)