Skip to content

Commit ee4df4a

Browse files
authored
crypto - threaded K12: separate context computation from thread spawning (#25793)
* threaded K12: separate context computation from thread spawning

  Compute all contexts and store them in a pre-allocated array, then spawn threads using the pre-computed contexts. This ensures each context is fully materialized in memory with the correct values before any thread tries to access it.

* kt128: unroll the permutation rounds only twice

  This appears to deliver the best performance thanks to improved cache utilization, and it's consistent with what we already do for SHA3.
1 parent afdd043 commit ee4df4a

File tree

1 file changed

+100
-89
lines changed

1 file changed

+100
-89
lines changed

lib/std/crypto/kangarootwelve.zig

Lines changed: 100 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -230,58 +230,61 @@ fn keccakP1600timesN(comptime N: usize, states: *[5][5]@Vector(N, u64)) void {
230230
break :blk offsets;
231231
};
232232

233-
inline for (RC) |rc| {
234-
// θ (theta)
235-
var C: [5]@Vector(N, u64) = undefined;
236-
inline for (0..5) |x| {
237-
C[x] = states[x][0] ^ states[x][1] ^ states[x][2] ^ states[x][3] ^ states[x][4];
238-
}
233+
var round: usize = 0;
234+
while (round < 12) : (round += 2) {
235+
inline for (0..2) |i| {
236+
// θ (theta)
237+
var C: [5]@Vector(N, u64) = undefined;
238+
inline for (0..5) |x| {
239+
C[x] = states[x][0] ^ states[x][1] ^ states[x][2] ^ states[x][3] ^ states[x][4];
240+
}
239241

240-
var D: [5]@Vector(N, u64) = undefined;
241-
inline for (0..5) |x| {
242-
D[x] = C[(x + 4) % 5] ^ rol64Vec(N, C[(x + 1) % 5], 1);
243-
}
242+
var D: [5]@Vector(N, u64) = undefined;
243+
inline for (0..5) |x| {
244+
D[x] = C[(x + 4) % 5] ^ rol64Vec(N, C[(x + 1) % 5], 1);
245+
}
244246

245-
// Apply D to all lanes
246-
inline for (0..5) |x| {
247-
states[x][0] ^= D[x];
248-
states[x][1] ^= D[x];
249-
states[x][2] ^= D[x];
250-
states[x][3] ^= D[x];
251-
states[x][4] ^= D[x];
252-
}
247+
// Apply D to all lanes
248+
inline for (0..5) |x| {
249+
states[x][0] ^= D[x];
250+
states[x][1] ^= D[x];
251+
states[x][2] ^= D[x];
252+
states[x][3] ^= D[x];
253+
states[x][4] ^= D[x];
254+
}
253255

254-
// ρ (rho) and π (pi) - optimized with pre-computed offsets
255-
var current = states[1][0];
256-
var px: usize = 1;
257-
var py: usize = 0;
258-
inline for (rho_offsets) |rot| {
259-
const next_y = (2 * px + 3 * py) % 5;
260-
const next = states[py][next_y];
261-
states[py][next_y] = rol64Vec(N, current, rot);
262-
current = next;
263-
px = py;
264-
py = next_y;
265-
}
256+
// ρ (rho) and π (pi) - optimized with pre-computed offsets
257+
var current = states[1][0];
258+
var px: usize = 1;
259+
var py: usize = 0;
260+
inline for (rho_offsets) |rot| {
261+
const next_y = (2 * px + 3 * py) % 5;
262+
const next = states[py][next_y];
263+
states[py][next_y] = rol64Vec(N, current, rot);
264+
current = next;
265+
px = py;
266+
py = next_y;
267+
}
266268

267-
// χ (chi) - optimized with better register usage
268-
inline for (0..5) |y| {
269-
const t0 = states[0][y];
270-
const t1 = states[1][y];
271-
const t2 = states[2][y];
272-
const t3 = states[3][y];
273-
const t4 = states[4][y];
274-
275-
states[0][y] = t0 ^ (~t1 & t2);
276-
states[1][y] = t1 ^ (~t2 & t3);
277-
states[2][y] = t2 ^ (~t3 & t4);
278-
states[3][y] = t3 ^ (~t4 & t0);
279-
states[4][y] = t4 ^ (~t0 & t1);
280-
}
269+
// χ (chi) - optimized with better register usage
270+
inline for (0..5) |y| {
271+
const t0 = states[0][y];
272+
const t1 = states[1][y];
273+
const t2 = states[2][y];
274+
const t3 = states[3][y];
275+
const t4 = states[4][y];
276+
277+
states[0][y] = t0 ^ (~t1 & t2);
278+
states[1][y] = t1 ^ (~t2 & t3);
279+
states[2][y] = t2 ^ (~t3 & t4);
280+
states[3][y] = t3 ^ (~t4 & t0);
281+
states[4][y] = t4 ^ (~t0 & t1);
282+
}
281283

282-
// ι (iota)
283-
const rc_splat: @Vector(N, u64) = @splat(rc);
284-
states[0][0] ^= rc_splat;
284+
// ι (iota)
285+
const rc_splat: @Vector(N, u64) = @splat(RC[round + i]);
286+
states[0][0] ^= rc_splat;
287+
}
285288
}
286289
}
287290

@@ -323,46 +326,49 @@ fn keccakP(state: *[200]u8) void {
323326
}
324327

325328
// Apply 12 rounds
326-
inline for (RC) |rc| {
327-
// θ
328-
var C: [5]u64 = undefined;
329-
inline for (0..5) |x| {
330-
C[x] = lanes[x][0] ^ lanes[x][1] ^ lanes[x][2] ^ lanes[x][3] ^ lanes[x][4];
331-
}
332-
var D: [5]u64 = undefined;
333-
inline for (0..5) |x| {
334-
D[x] = C[(x + 4) % 5] ^ std.math.rotl(u64, C[(x + 1) % 5], 1);
335-
}
336-
inline for (0..5) |x| {
337-
inline for (0..5) |y| {
338-
lanes[x][y] ^= D[x];
329+
var round: usize = 0;
330+
while (round < 12) : (round += 2) {
331+
inline for (0..2) |i| {
332+
// θ
333+
var C: [5]u64 = undefined;
334+
inline for (0..5) |x| {
335+
C[x] = lanes[x][0] ^ lanes[x][1] ^ lanes[x][2] ^ lanes[x][3] ^ lanes[x][4];
336+
}
337+
var D: [5]u64 = undefined;
338+
inline for (0..5) |x| {
339+
D[x] = C[(x + 4) % 5] ^ std.math.rotl(u64, C[(x + 1) % 5], 1);
340+
}
341+
inline for (0..5) |x| {
342+
inline for (0..5) |y| {
343+
lanes[x][y] ^= D[x];
344+
}
339345
}
340-
}
341346

342-
// ρ and π
343-
var current = lanes[1][0];
344-
var px: usize = 1;
345-
var py: usize = 0;
346-
inline for (0..24) |t| {
347-
const temp = lanes[py][(2 * px + 3 * py) % 5];
348-
const rot_amount = ((t + 1) * (t + 2) / 2) % 64;
349-
lanes[py][(2 * px + 3 * py) % 5] = std.math.rotl(u64, current, @as(u6, @intCast(rot_amount)));
350-
current = temp;
351-
const temp_x = py;
352-
py = (2 * px + 3 * py) % 5;
353-
px = temp_x;
354-
}
347+
// ρ and π
348+
var current = lanes[1][0];
349+
var px: usize = 1;
350+
var py: usize = 0;
351+
inline for (0..24) |t| {
352+
const temp = lanes[py][(2 * px + 3 * py) % 5];
353+
const rot_amount = ((t + 1) * (t + 2) / 2) % 64;
354+
lanes[py][(2 * px + 3 * py) % 5] = std.math.rotl(u64, current, @as(u6, @intCast(rot_amount)));
355+
current = temp;
356+
const temp_x = py;
357+
py = (2 * px + 3 * py) % 5;
358+
px = temp_x;
359+
}
355360

356-
// χ
357-
inline for (0..5) |y| {
358-
const T = [5]u64{ lanes[0][y], lanes[1][y], lanes[2][y], lanes[3][y], lanes[4][y] };
359-
inline for (0..5) |x| {
360-
lanes[x][y] = T[x] ^ (~T[(x + 1) % 5] & T[(x + 2) % 5]);
361+
// χ
362+
inline for (0..5) |y| {
363+
const T = [5]u64{ lanes[0][y], lanes[1][y], lanes[2][y], lanes[3][y], lanes[4][y] };
364+
inline for (0..5) |x| {
365+
lanes[x][y] = T[x] ^ (~T[(x + 1) % 5] & T[(x + 2) % 5]);
366+
}
361367
}
362-
}
363368

364-
// ι
365-
lanes[0][0] ^= rc;
369+
// ι
370+
lanes[0][0] ^= RC[round + i];
371+
}
366372
}
367373

368374
// Store lanes back to state
@@ -759,32 +765,37 @@ fn ktMultiThreaded(
759765
const all_scratch = try allocator.alloc(u8, thread_count * scratch_size);
760766
defer allocator.free(all_scratch);
761767

762-
var group: Io.Group = .init;
768+
const contexts = try allocator.alloc(LeafBatchContext, thread_count);
769+
defer allocator.free(contexts);
770+
763771
var leaves_assigned: usize = 0;
764-
var thread_idx: usize = 0;
772+
var context_count: usize = 0;
765773

766774
while (leaves_assigned < total_leaves) {
767775
const batch_count = @min(leaves_per_thread, total_leaves - leaves_assigned);
768776
const batch_start = chunk_size + leaves_assigned * chunk_size;
769777
const cvs_offset = leaves_assigned * cv_size;
770778

771-
const ctx = LeafBatchContext{
779+
contexts[context_count] = LeafBatchContext{
772780
.output_cvs = cvs[cvs_offset .. cvs_offset + batch_count * cv_size],
773781
.batch_start = batch_start,
774782
.batch_count = batch_count,
775783
.view = view,
776-
.scratch_buffer = all_scratch[thread_idx * scratch_size .. (thread_idx + 1) * scratch_size],
784+
.scratch_buffer = all_scratch[context_count * scratch_size .. (context_count + 1) * scratch_size],
777785
.total_len = total_len,
778786
};
779787

788+
leaves_assigned += batch_count;
789+
context_count += 1;
790+
}
791+
792+
var group: Io.Group = .init;
793+
for (contexts[0..context_count]) |ctx| {
780794
group.async(io, struct {
781795
fn process(c: LeafBatchContext) void {
782796
processLeafBatch(Variant, c);
783797
}
784798
}.process, .{ctx});
785-
786-
leaves_assigned += batch_count;
787-
thread_idx += 1;
788799
}
789800

790801
// Wait for all threads to complete

0 commit comments

Comments
 (0)