Skip to content

Commit b8df2c9

Browse files
committed
Remove mask backing implementations
1 parent 1146930 commit b8df2c9

File tree

3 files changed

+59
-557
lines changed

3 files changed

+59
-557
lines changed

crates/core_simd/src/masks.rs

Lines changed: 59 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
//! Types representing
33
#![allow(non_camel_case_types)]
44

5-
#[path = "masks/full_masks.rs"]
6-
mod mask_impl;
7-
85
use crate::simd::{LaneCount, Select, Simd, SimdCast, SimdElement, SupportedLaneCount};
96
use core::cmp::Ordering;
107
use core::{fmt, mem};
@@ -101,7 +98,7 @@ impl_element! { isize, usize }
10198
/// The layout of this type is equivalent to `Simd<T, N>`, but elements
10299
/// are guaranteed to be either 0 or -1.
103100
#[repr(transparent)]
104-
pub struct Mask<T, const N: usize>(mask_impl::Mask<T, N>)
101+
pub struct Mask<T, const N: usize>(Simd<T, N>)
105102
where
106103
T: MaskElement,
107104
LaneCount<N>: SupportedLaneCount;
@@ -133,7 +130,7 @@ where
133130
#[inline]
134131
#[rustc_const_unstable(feature = "portable_simd", issue = "86656")]
135132
pub const fn splat(value: bool) -> Self {
136-
Self(mask_impl::Mask::splat(value))
133+
Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
137134
}
138135

139136
/// Converts an array of bools to a SIMD mask.
@@ -184,8 +181,8 @@ where
184181
// Safety: the caller must confirm this invariant
185182
unsafe {
186183
core::intrinsics::assume(<T as Sealed>::valid(value));
187-
Self(mask_impl::Mask::from_simd_unchecked(value))
188184
}
185+
Self(value)
189186
}
190187

191188
/// Converts a vector of integers to a mask, where 0 represents `false` and -1
@@ -207,14 +204,15 @@ where
207204
#[inline]
208205
#[must_use = "method returns a new vector and does not mutate the original value"]
209206
pub fn to_simd(self) -> Simd<T, N> {
210-
self.0.to_simd()
207+
self.0
211208
}
212209

213210
/// Converts the mask to a mask of any other element size.
214211
#[inline]
215212
#[must_use = "method returns a new mask and does not mutate the original value"]
216213
pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
217-
Mask(self.0.convert())
214+
// Safety: mask elements are integers
215+
unsafe { Mask(core::intrinsics::simd::simd_as(self.0)) }
218216
}
219217

220218
/// Tests the value of the specified element.
@@ -225,7 +223,7 @@ where
225223
#[must_use = "method returns a new bool and does not mutate the original value"]
226224
pub unsafe fn test_unchecked(&self, index: usize) -> bool {
227225
// Safety: the caller must confirm this invariant
228-
unsafe { self.0.test_unchecked(index) }
226+
unsafe { T::eq(*self.0.as_array().get_unchecked(index), T::TRUE) }
229227
}
230228

231229
/// Tests the value of the specified element.
@@ -236,9 +234,7 @@ where
236234
#[must_use = "method returns a new bool and does not mutate the original value"]
237235
#[track_caller]
238236
pub fn test(&self, index: usize) -> bool {
239-
assert!(index < N, "element index out of range");
240-
// Safety: the element index has been checked
241-
unsafe { self.test_unchecked(index) }
237+
T::eq(self.0[index], T::TRUE)
242238
}
243239

244240
/// Sets the value of the specified element.
@@ -249,7 +245,7 @@ where
249245
pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
250246
// Safety: the caller must confirm this invariant
251247
unsafe {
252-
self.0.set_unchecked(index, value);
248+
*self.0.as_mut_array().get_unchecked_mut(index) = if value { T::TRUE } else { T::FALSE }
253249
}
254250
}
255251

@@ -260,25 +256,23 @@ where
260256
#[inline]
261257
#[track_caller]
262258
pub fn set(&mut self, index: usize, value: bool) {
263-
assert!(index < N, "element index out of range");
264-
// Safety: the element index has been checked
265-
unsafe {
266-
self.set_unchecked(index, value);
267-
}
259+
self.0[index] = if value { T::TRUE } else { T::FALSE }
268260
}
269261

270262
/// Returns true if any element is set, or false otherwise.
271263
#[inline]
272264
#[must_use = "method returns a new bool and does not mutate the original value"]
273265
pub fn any(self) -> bool {
274-
self.0.any()
266+
// Safety: `self` is a mask vector
267+
unsafe { core::intrinsics::simd::simd_reduce_any(self.0) }
275268
}
276269

277270
/// Returns true if all elements are set, or false otherwise.
278271
#[inline]
279272
#[must_use = "method returns a new bool and does not mutate the original value"]
280273
pub fn all(self) -> bool {
281-
self.0.all()
274+
// Safety: `self` is a mask vector
275+
unsafe { core::intrinsics::simd::simd_reduce_all(self.0) }
282276
}
283277

284278
/// Creates a bitmask from a mask.
@@ -288,7 +282,40 @@ where
288282
#[inline]
289283
#[must_use = "method returns a new integer and does not mutate the original value"]
290284
pub fn to_bitmask(self) -> u64 {
291-
self.0.to_bitmask_integer()
285+
#[inline]
286+
unsafe fn to_bitmask_impl<T, U, const M: usize, const N: usize>(mask: Mask<T, N>) -> U
287+
where
288+
T: MaskElement,
289+
LaneCount<M>: SupportedLaneCount,
290+
LaneCount<N>: SupportedLaneCount,
291+
{
292+
let resized = mask.resize::<M>(false);
293+
294+
// Safety: `resized` is an integer vector with length M, which must match T
295+
unsafe { core::intrinsics::simd::simd_bitmask(resized.0) }
296+
}
297+
298+
// TODO modify simd_bitmask to zero-extend output, making this unnecessary
299+
let bitmask = if N <= 8 {
300+
// Safety: bitmask matches length
301+
unsafe { to_bitmask_impl::<T, u8, 8, N>(self) as u64 }
302+
} else if N <= 16 {
303+
// Safety: bitmask matches length
304+
unsafe { to_bitmask_impl::<T, u16, 16, N>(self) as u64 }
305+
} else if N <= 32 {
306+
// Safety: bitmask matches length
307+
unsafe { to_bitmask_impl::<T, u32, 32, N>(self) as u64 }
308+
} else {
309+
// Safety: bitmask matches length
310+
unsafe { to_bitmask_impl::<T, u64, 64, N>(self) }
311+
};
312+
313+
// LLVM assumes bit order should match endianness
314+
if cfg!(target_endian = "big") {
315+
bitmask.reverse_bits() >> (64 - N.min(64))
316+
} else {
317+
bitmask
318+
}
292319
}
293320

294321
/// Creates a mask from a bitmask.
@@ -298,7 +325,7 @@ where
298325
#[inline]
299326
#[must_use = "method returns a new mask and does not mutate the original value"]
300327
pub fn from_bitmask(bitmask: u64) -> Self {
301-
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
328+
Self(bitmask.select(Simd::splat(T::TRUE), Simd::splat(T::FALSE)))
302329
}
303330

304331
/// Finds the index of the first set element.
@@ -442,7 +469,8 @@ where
442469
type Output = Self;
443470
#[inline]
444471
fn bitand(self, rhs: Self) -> Self {
445-
Self(self.0 & rhs.0)
472+
// Safety: `self` is an integer vector
473+
unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) }
446474
}
447475
}
448476

@@ -478,7 +506,8 @@ where
478506
type Output = Self;
479507
#[inline]
480508
fn bitor(self, rhs: Self) -> Self {
481-
Self(self.0 | rhs.0)
509+
// Safety: `self` is an integer vector
510+
unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) }
482511
}
483512
}
484513

@@ -514,7 +543,8 @@ where
514543
type Output = Self;
515544
#[inline]
516545
fn bitxor(self, rhs: Self) -> Self::Output {
517-
Self(self.0 ^ rhs.0)
546+
// Safety: `self` is an integer vector
547+
unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) }
518548
}
519549
}
520550

@@ -550,7 +580,7 @@ where
550580
type Output = Mask<T, N>;
551581
#[inline]
552582
fn not(self) -> Self::Output {
553-
Self(!self.0)
583+
Self::splat(true) ^ self
554584
}
555585
}
556586

@@ -561,7 +591,7 @@ where
561591
{
562592
#[inline]
563593
fn bitand_assign(&mut self, rhs: Self) {
564-
self.0 = self.0 & rhs.0;
594+
*self = *self & rhs;
565595
}
566596
}
567597

@@ -583,7 +613,7 @@ where
583613
{
584614
#[inline]
585615
fn bitor_assign(&mut self, rhs: Self) {
586-
self.0 = self.0 | rhs.0;
616+
*self = *self | rhs;
587617
}
588618
}
589619

@@ -605,7 +635,7 @@ where
605635
{
606636
#[inline]
607637
fn bitxor_assign(&mut self, rhs: Self) {
608-
self.0 = self.0 ^ rhs.0;
638+
*self = *self ^ rhs;
609639
}
610640
}
611641

0 commit comments

Comments
 (0)