Skip to content

Commit a581615

Browse files
committed
Fix endianness correction
1 parent e75a8d8 commit a581615

File tree

3 files changed

+58
-26
lines changed

3 files changed

+58
-26
lines changed

crates/core_simd/src/masks.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -296,23 +296,23 @@ where
296296
}
297297

298298
// TODO modify simd_bitmask to zero-extend output, making this unnecessary
299-
let bitmask = if N <= 8 {
299+
let (bitmask, n_used) = if N <= 8 {
300300
// Safety: bitmask matches length
301-
unsafe { to_bitmask_impl::<T, u8, 8, N>(self) as u64 }
301+
unsafe { (to_bitmask_impl::<T, u8, 8, N>(self) as u64, 8) }
302302
} else if N <= 16 {
303303
// Safety: bitmask matches length
304-
unsafe { to_bitmask_impl::<T, u16, 16, N>(self) as u64 }
304+
unsafe { (to_bitmask_impl::<T, u16, 16, N>(self) as u64, 16) }
305305
} else if N <= 32 {
306306
// Safety: bitmask matches length
307-
unsafe { to_bitmask_impl::<T, u32, 32, N>(self) as u64 }
307+
unsafe { (to_bitmask_impl::<T, u32, 32, N>(self) as u64, 32) }
308308
} else {
309309
// Safety: bitmask matches length
310-
unsafe { to_bitmask_impl::<T, u64, 64, N>(self) }
310+
unsafe { (to_bitmask_impl::<T, u64, 64, N>(self), 64) }
311311
};
312312

313313
// LLVM assumes bit order should match endianness
314314
if cfg!(target_endian = "big") {
315-
bitmask.reverse_bits() >> (64 - N.min(64))
315+
bitmask.reverse_bits() >> (n_used - N.min(64))
316316
} else {
317317
bitmask
318318
}

crates/core_simd/src/select.rs

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -82,22 +82,15 @@ where
8282
assert!(N <= 64, "number of elements can't be greater than 64");
8383
}
8484

85-
// LLVM assumes bit order should match endianness
86-
let bitmask = if cfg!(target_endian = "big") {
87-
let rev = self.reverse_bits();
88-
if N < 64 {
89-
// Shift things back to the right
90-
rev >> (64 - N)
91-
} else {
92-
rev
93-
}
94-
} else {
95-
self
96-
};
97-
9885
#[inline]
99-
unsafe fn select_impl<T, U, const M: usize, const N: usize>(
86+
unsafe fn select_impl<
87+
T,
88+
U: core::ops::Shr<usize, Output = U>,
89+
const M: usize,
90+
const N: usize,
91+
>(
10092
bitmask: U,
93+
bitmask_reversed: U,
10194
true_values: Simd<T, N>,
10295
false_values: Simd<T, N>,
10396
) -> Simd<T, N>
@@ -110,6 +103,13 @@ where
110103
let true_values = true_values.resize::<M>(default);
111104
let false_values = false_values.resize::<M>(default);
112105

106+
// LLVM assumes bit order should match endianness
107+
let bitmask = if cfg!(target_endian = "big") {
108+
bitmask_reversed >> (M - N)
109+
} else {
110+
bitmask
111+
};
112+
113113
// Safety: the caller guarantees that the size of U matches M
114114
let selected = unsafe {
115115
core::intrinsics::simd::simd_select_bitmask(bitmask, true_values, false_values)
@@ -120,17 +120,49 @@ where
120120

121121
// TODO modify simd_bitmask_select to truncate input, making this unnecessary
122122
if N <= 8 {
123+
let bitmask = self as u8;
123124
// Safety: bitmask matches length
124-
unsafe { select_impl::<T, u8, 8, N>(bitmask as u8, true_values, false_values) }
125+
unsafe {
126+
select_impl::<T, u8, 8, N>(
127+
bitmask,
128+
bitmask.reverse_bits(),
129+
true_values,
130+
false_values,
131+
)
132+
}
125133
} else if N <= 16 {
134+
let bitmask = self as u16;
126135
// Safety: bitmask matches length
127-
unsafe { select_impl::<T, u16, 16, N>(bitmask as u16, true_values, false_values) }
136+
unsafe {
137+
select_impl::<T, u16, 16, N>(
138+
bitmask,
139+
bitmask.reverse_bits(),
140+
true_values,
141+
false_values,
142+
)
143+
}
128144
} else if N <= 32 {
145+
let bitmask = self as u32;
129146
// Safety: bitmask matches length
130-
unsafe { select_impl::<T, u32, 32, N>(bitmask as u32, true_values, false_values) }
147+
unsafe {
148+
select_impl::<T, u32, 32, N>(
149+
bitmask,
150+
bitmask.reverse_bits(),
151+
true_values,
152+
false_values,
153+
)
154+
}
131155
} else {
156+
let bitmask = self as u64;
132157
// Safety: bitmask matches length
133-
unsafe { select_impl::<T, u64, 64, N>(bitmask, true_values, false_values) }
158+
unsafe {
159+
select_impl::<T, u64, 64, N>(
160+
bitmask,
161+
bitmask.reverse_bits(),
162+
true_values,
163+
false_values,
164+
)
165+
}
134166
}
135167
}
136168
}

crates/core_simd/src/swizzle_dyn.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ unsafe fn armv7_neon_swizzle_u8x16(bytes: Simd<u8, 16>, idxs: Simd<u8, 16>) -> S
139139
#[inline]
140140
#[allow(clippy::let_and_return)]
141141
unsafe fn avx2_pshufb(bytes: Simd<u8, 32>, idxs: Simd<u8, 32>) -> Simd<u8, 32> {
142-
use crate::simd::{cmp::SimdPartialOrd, Select};
142+
use crate::simd::{Select, cmp::SimdPartialOrd};
143143
#[cfg(target_arch = "x86")]
144144
use core::arch::x86;
145145
#[cfg(target_arch = "x86_64")]
@@ -200,7 +200,7 @@ fn zeroing_idxs<const N: usize>(idxs: Simd<u8, N>) -> Simd<u8, N>
200200
where
201201
LaneCount<N>: SupportedLaneCount,
202202
{
203-
use crate::simd::{cmp::SimdPartialOrd, Select};
203+
use crate::simd::{Select, cmp::SimdPartialOrd};
204204
idxs.simd_lt(Simd::splat(N as u8))
205205
.select(idxs, Simd::splat(u8::MAX))
206206
}

0 commit comments

Comments
 (0)