Skip to content

Commit 25fdb1c

Browse files
committed
make rsa types generic over u8/U8 and reorganize trait implementation
1 parent ce87a1f commit 25fdb1c

File tree

2 files changed

+119
-171
lines changed

2 files changed

+119
-171
lines changed

rsa/src/impl_hacl.rs

Lines changed: 106 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ pub struct PublicKey<const LEN: usize> {
77
}
88

99
/// An RSA Private Key that is `LEN` bytes long.
10-
pub struct PrivateKey<const LEN: usize> {
10+
pub struct PrivateKey<const LEN: usize, PrivateKeyByte> {
1111
pub(crate) pk: PublicKey<LEN>,
12-
pub(crate) d: [u8; LEN],
12+
pub(crate) d: [PrivateKeyByte; LEN],
1313
}
1414

15-
impl<const LEN: usize> alloc::fmt::Debug for PrivateKey<LEN> {
15+
impl<const LEN: usize, PrivateKeyByte> alloc::fmt::Debug for PrivateKey<LEN, PrivateKeyByte> {
1616
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1717
f.debug_struct("PrivateKey")
1818
.field("pk", &self.pk)
@@ -49,7 +49,7 @@ impl VarLenPublicKey<'_> {
4949
self.n
5050
}
5151
}
52-
impl alloc::fmt::Debug for VarLenPrivateKey<'_> {
52+
impl<PrivateKeyByte> alloc::fmt::Debug for VarLenPrivateKey<'_, PrivateKeyByte> {
5353
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
5454
f.debug_struct("PrivateKey")
5555
.field("pk", &self.pk)
@@ -59,40 +59,48 @@ impl alloc::fmt::Debug for VarLenPrivateKey<'_> {
5959
}
6060

6161
/// An RSA Private Key backed by slices. Use if the length is not known at compile time.
62-
pub struct VarLenPrivateKey<'a> {
62+
pub struct VarLenPrivateKey<'a, PrivateKeyByte> {
6363
pub(crate) pk: VarLenPublicKey<'a>,
64-
pub(crate) d: &'a [u8],
64+
pub(crate) d: &'a [PrivateKeyByte],
6565
}
6666

67-
impl<'a> VarLenPrivateKey<'a> {
68-
/// Constructor for the private key based on `n` and `d`.
69-
pub fn from_components(n: &'a [u8], d: &'a [u8]) -> Result<Self, Error> {
70-
if n.len() != d.len() {
71-
return Err(Error::KeyLengthMismatch);
72-
}
67+
macro_rules! impl_var_len_private_key {
68+
($sk_byte:ty) => {
69+
impl<'a> VarLenPrivateKey<'a, $sk_byte> {
70+
/// Constructor for the private key based on `n` and `d`.
71+
pub fn from_components(n: &'a [u8], d: &'a [$sk_byte]) -> Result<Self, Error> {
72+
if n.len() != d.len() {
73+
return Err(Error::KeyLengthMismatch);
74+
}
7375

74-
Ok(Self {
75-
pk: n.try_into()?,
76-
d,
77-
})
78-
}
76+
Ok(Self {
77+
pk: n.try_into()?,
78+
d,
79+
})
80+
}
7981

80-
/// Returns the public key of the private key.
81-
pub fn pk(&self) -> &VarLenPublicKey<'_> {
82-
&self.pk
83-
}
82+
/// Returns the public key of the private key.
83+
pub fn pk(&self) -> &VarLenPublicKey<'_> {
84+
&self.pk
85+
}
8486

85-
/// Returns the length of the keys
86-
pub fn key_len(&self) -> usize {
87-
self.d.len()
88-
}
87+
/// Returns the length of the keys
88+
pub fn key_len(&self) -> usize {
89+
self.d.len()
90+
}
8991

90-
/// Returns the private exponent of the keys
91-
pub fn d(&self) -> &[u8] {
92-
self.d
93-
}
92+
/// Returns the private exponent of the keys
93+
pub fn d(&self) -> &[$sk_byte] {
94+
self.d
95+
}
96+
}
97+
};
9498
}
9599

100+
impl_var_len_private_key!(u8);
101+
#[cfg(feature = "check-secret-independence")]
102+
impl_var_len_private_key!(libcrux_secrets::U8);
103+
96104
const E_BITS: u32 = 17;
97105
const E: [u8; 3] = [1, 0, 1];
98106

@@ -104,9 +112,19 @@ fn hacl_hash_alg(alg: crate::DigestAlgorithm) -> libcrux_hacl_rs::streaming_type
104112
}
105113
}
106114

115+
#[cfg(feature = "check-secret-independence")]
116+
impl<'a, const LEN: usize> libcrux_secrets::DeclassifyRef
117+
for &'a PrivateKey<LEN, libcrux_secrets::U8>
118+
{
119+
type DeclassifiedRef = &'a PrivateKey<LEN, u8>;
120+
fn declassify_ref(self) -> Self::DeclassifiedRef {
121+
unsafe { core::mem::transmute(self) }
122+
}
123+
}
124+
107125
// next up: generate these in macros
108126

109-
macro_rules! impl_rsapss {
127+
macro_rules! impl_rsapss_base {
110128
($sign_fn:ident, $verify_fn:ident, $bits:literal, $bytes:literal) => {
111129
impl From<[u8; $bytes]> for PublicKey<$bytes> {
112130
fn from(n: [u8; $bytes]) -> Self {
@@ -126,43 +144,12 @@ macro_rules! impl_rsapss {
126144
}
127145
}
128146

129-
impl PrivateKey<$bytes> {
130-
/// Constructor for the private key based on `n` and `d`.
131-
pub fn from_components(n: [u8; $bytes], d: [u8; $bytes]) -> Self {
132-
Self { pk: n.into(), d }
133-
}
134-
135-
/// Returns the public key of the private key.
136-
pub fn pk(&self) -> &PublicKey<$bytes> {
137-
&self.pk
138-
}
139-
140-
/// Returns the slice-based private key
141-
pub fn as_var_len(&self) -> VarLenPrivateKey<'_> {
142-
VarLenPrivateKey {
143-
pk: self.pk.as_var_len(),
144-
d: &self.d,
145-
}
146-
}
147-
148-
/// Returns the private exponent as bytes.
149-
pub fn d(&self) -> &[u8; $bytes] {
150-
&self.d
151-
}
152-
}
153-
154147
impl<'a> From<&'a PublicKey<$bytes>> for VarLenPublicKey<'a> {
155148
fn from(value: &'a PublicKey<$bytes>) -> Self {
156149
value.as_var_len()
157150
}
158151
}
159152

160-
impl<'a> From<&'a PrivateKey<$bytes>> for VarLenPrivateKey<'a> {
161-
fn from(value: &'a PrivateKey<$bytes>) -> Self {
162-
value.as_var_len()
163-
}
164-
}
165-
166153
/// Computes a signature over `msg` using `sk` and writes it to `sig`.
167154
/// Returns `Ok(())` on success.
168155
///
@@ -172,7 +159,7 @@ macro_rules! impl_rsapss {
172159
/// - `salt_len` exceeds `u32::MAX - alg.hash_len() - 8`
173160
pub fn $sign_fn(
174161
alg: crate::DigestAlgorithm,
175-
sk: &PrivateKey<$bytes>,
162+
sk: &PrivateKey<$bytes, u8>,
176163
msg: &[u8],
177164
salt: &[u8],
178165
sig: &mut [u8; $bytes],
@@ -199,11 +186,61 @@ macro_rules! impl_rsapss {
199186
};
200187
}
201188

202-
impl_rsapss!(sign_2048, verify_2048, 2048, 256);
203-
impl_rsapss!(sign_3072, verify_3072, 3072, 384);
204-
impl_rsapss!(sign_4096, verify_4096, 4096, 512);
205-
impl_rsapss!(sign_6144, verify_6144, 6144, 768);
206-
impl_rsapss!(sign_8192, verify_8192, 8192, 1024);
189+
macro_rules! impl_rsapss_private {
190+
($sign_fn:ident, $verify_fn:ident, $bits:literal, $bytes:literal, $sk_byte:ty) => {
191+
impl PrivateKey<$bytes, $sk_byte> {
192+
/// Constructor for the private key based on `n` and `d`.
193+
pub fn from_components(n: [u8; $bytes], d: [$sk_byte; $bytes]) -> Self {
194+
Self { pk: n.into(), d }
195+
}
196+
197+
/// Returns the public key of the private key.
198+
pub fn pk(&self) -> &PublicKey<$bytes> {
199+
&self.pk
200+
}
201+
202+
/// Returns the slice-based private key
203+
pub fn as_var_len(&self) -> VarLenPrivateKey<'_, $sk_byte> {
204+
VarLenPrivateKey {
205+
pk: self.pk.as_var_len(),
206+
d: &self.d,
207+
}
208+
}
209+
210+
/// Returns the private exponent as bytes.
211+
pub fn d(&self) -> &[$sk_byte; $bytes] {
212+
&self.d
213+
}
214+
}
215+
216+
impl<'a> From<&'a PrivateKey<$bytes, $sk_byte>> for VarLenPrivateKey<'a, $sk_byte> {
217+
fn from(value: &'a PrivateKey<$bytes, $sk_byte>) -> Self {
218+
value.as_var_len()
219+
}
220+
}
221+
};
222+
}
223+
224+
impl_rsapss_base!(sign_2048, verify_2048, 2048, 256);
225+
impl_rsapss_base!(sign_3072, verify_3072, 3072, 384);
226+
impl_rsapss_base!(sign_4096, verify_4096, 4096, 512);
227+
impl_rsapss_base!(sign_6144, verify_6144, 6144, 768);
228+
impl_rsapss_base!(sign_8192, verify_8192, 8192, 1024);
229+
impl_rsapss_private!(sign_2048, verify_2048, 2048, 256, u8);
230+
impl_rsapss_private!(sign_3072, verify_3072, 3072, 384, u8);
231+
impl_rsapss_private!(sign_4096, verify_4096, 4096, 512, u8);
232+
impl_rsapss_private!(sign_6144, verify_6144, 6144, 768, u8);
233+
impl_rsapss_private!(sign_8192, verify_8192, 8192, 1024, u8);
234+
235+
#[cfg(feature = "check-secret-independence")]
236+
mod secret_integer_impl {
237+
use super::*;
238+
impl_rsapss_private!(sign_2048, verify_2048, 2048, 256, libcrux_secrets::U8);
239+
impl_rsapss_private!(sign_3072, verify_3072, 3072, 384, libcrux_secrets::U8);
240+
impl_rsapss_private!(sign_4096, verify_4096, 4096, 512, libcrux_secrets::U8);
241+
impl_rsapss_private!(sign_6144, verify_6144, 6144, 768, libcrux_secrets::U8);
242+
impl_rsapss_private!(sign_8192, verify_8192, 8192, 1024, libcrux_secrets::U8);
243+
}
207244

208245
/// Computes a signature over `msg` using `sk` and writes it to `sig`.
209246
/// Returns `Ok(())` on success.
@@ -215,7 +252,7 @@ impl_rsapss!(sign_8192, verify_8192, 8192, 1024);
215252
/// - the length of `sig` does not match the length of `sk`
216253
pub fn sign(
217254
alg: crate::DigestAlgorithm,
218-
sk: &VarLenPrivateKey<'_>,
255+
sk: &VarLenPrivateKey<'_, u8>,
219256
msg: &[u8],
220257
salt: &[u8],
221258
sig: &mut [u8],
@@ -270,7 +307,7 @@ pub fn verify(
270307
/// - follows from the check that messages are shorter than `u32::MAX`.
271308
pub fn sign_varlen(
272309
alg: crate::DigestAlgorithm,
273-
sk: &VarLenPrivateKey<'_>,
310+
sk: &VarLenPrivateKey<'_, u8>,
274311
msg: &[u8],
275312
salt: &[u8],
276313
sig: &mut [u8],

0 commit comments

Comments
 (0)