Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 81 additions & 37 deletions crates/core_simd/src/masks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,33 @@
//! Types representing
#![allow(non_camel_case_types)]

#[cfg_attr(
not(all(target_arch = "x86_64", target_feature = "avx512f")),
path = "masks/full_masks.rs"
)]
#[cfg_attr(
all(target_arch = "x86_64", target_feature = "avx512f"),
path = "masks/bitmask.rs"
)]
mod mask_impl;

use crate::simd::{LaneCount, Simd, SimdCast, SimdElement, SupportedLaneCount};
use crate::simd::{LaneCount, Select, Simd, SimdCast, SimdElement, SupportedLaneCount};
use core::cmp::Ordering;
use core::{fmt, mem};

pub(crate) trait FixEndianness {
fn fix_endianness(self) -> Self;
}

macro_rules! impl_fix_endianness {
{ $($int:ty),* } => {
$(
impl FixEndianness for $int {
#[inline(always)]
fn fix_endianness(self) -> Self {
if cfg!(target_endian = "big") {
<$int>::reverse_bits(self)
} else {
self
}
}
}
)*
}
}

impl_fix_endianness! { u8, u16, u32, u64 }

mod sealed {
use super::*;

Expand Down Expand Up @@ -109,7 +122,7 @@ impl_element! { isize, usize }
/// and/or Rust versions, and code should not assume that it is equivalent to
/// `[T; N]`.
#[repr(transparent)]
pub struct Mask<T, const N: usize>(mask_impl::Mask<T, N>)
pub struct Mask<T, const N: usize>(Simd<T, N>)
where
T: MaskElement,
LaneCount<N>: SupportedLaneCount;
Expand Down Expand Up @@ -141,7 +154,7 @@ where
#[inline]
#[rustc_const_unstable(feature = "portable_simd", issue = "86656")]
pub const fn splat(value: bool) -> Self {
Self(mask_impl::Mask::splat(value))
Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
}

/// Converts an array of bools to a SIMD mask.
Expand Down Expand Up @@ -192,8 +205,8 @@ where
// Safety: the caller must confirm this invariant
unsafe {
core::intrinsics::assume(<T as Sealed>::valid(value));
Self(mask_impl::Mask::from_simd_unchecked(value))
}
Self(value)
}

/// Converts a vector of integers to a mask, where 0 represents `false` and -1
Expand All @@ -215,14 +228,15 @@ where
#[inline]
#[must_use = "method returns a new vector and does not mutate the original value"]
pub fn to_simd(self) -> Simd<T, N> {
self.0.to_simd()
self.0
}

/// Converts the mask to a mask of any other element size.
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn cast<U: MaskElement>(self) -> Mask<U, N> {
Mask(self.0.convert())
// Safety: mask elements are integers
unsafe { Mask(core::intrinsics::simd::simd_as(self.0)) }
}

/// Tests the value of the specified element.
Expand All @@ -233,7 +247,7 @@ where
#[must_use = "method returns a new bool and does not mutate the original value"]
pub unsafe fn test_unchecked(&self, index: usize) -> bool {
// Safety: the caller must confirm this invariant
unsafe { self.0.test_unchecked(index) }
unsafe { T::eq(*self.0.as_array().get_unchecked(index), T::TRUE) }
}

/// Tests the value of the specified element.
Expand All @@ -244,9 +258,7 @@ where
#[must_use = "method returns a new bool and does not mutate the original value"]
#[track_caller]
pub fn test(&self, index: usize) -> bool {
assert!(index < N, "element index out of range");
// Safety: the element index has been checked
unsafe { self.test_unchecked(index) }
T::eq(self.0[index], T::TRUE)
}

/// Sets the value of the specified element.
Expand All @@ -257,7 +269,7 @@ where
pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) {
// Safety: the caller must confirm this invariant
unsafe {
self.0.set_unchecked(index, value);
*self.0.as_mut_array().get_unchecked_mut(index) = if value { T::TRUE } else { T::FALSE }
}
}

Expand All @@ -268,25 +280,23 @@ where
#[inline]
#[track_caller]
pub fn set(&mut self, index: usize, value: bool) {
assert!(index < N, "element index out of range");
// Safety: the element index has been checked
unsafe {
self.set_unchecked(index, value);
}
self.0[index] = if value { T::TRUE } else { T::FALSE }
}

/// Returns true if any element is set, or false otherwise.
#[inline]
#[must_use = "method returns a new bool and does not mutate the original value"]
pub fn any(self) -> bool {
self.0.any()
// Safety: `self` is a mask vector
unsafe { core::intrinsics::simd::simd_reduce_any(self.0) }
}

/// Returns true if all elements are set, or false otherwise.
#[inline]
#[must_use = "method returns a new bool and does not mutate the original value"]
pub fn all(self) -> bool {
self.0.all()
// Safety: `self` is a mask vector
unsafe { core::intrinsics::simd::simd_reduce_all(self.0) }
}

/// Creates a bitmask from a mask.
Expand All @@ -296,7 +306,38 @@ where
#[inline]
#[must_use = "method returns a new integer and does not mutate the original value"]
pub fn to_bitmask(self) -> u64 {
self.0.to_bitmask_integer()
#[inline]
unsafe fn to_bitmask_impl<T, U: FixEndianness, const M: usize, const N: usize>(
mask: Mask<T, N>,
) -> U
where
T: MaskElement,
LaneCount<M>: SupportedLaneCount,
LaneCount<N>: SupportedLaneCount,
{
let resized = mask.resize::<M>(false);

// Safety: `resized` is an integer vector with length M, which must match T
let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized.0) };

// LLVM assumes bit order should match endianness
bitmask.fix_endianness()
}

// TODO modify simd_bitmask to zero-extend output, making this unnecessary
if N <= 8 {
// Safety: bitmask matches length
unsafe { to_bitmask_impl::<T, u8, 8, N>(self) as u64 }
} else if N <= 16 {
// Safety: bitmask matches length
unsafe { to_bitmask_impl::<T, u16, 16, N>(self) as u64 }
} else if N <= 32 {
// Safety: bitmask matches length
unsafe { to_bitmask_impl::<T, u32, 32, N>(self) as u64 }
} else {
// Safety: bitmask matches length
unsafe { to_bitmask_impl::<T, u64, 64, N>(self) }
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe add a todo!() for N > 64?

}

/// Creates a mask from a bitmask.
Expand All @@ -306,7 +347,7 @@ where
#[inline]
#[must_use = "method returns a new mask and does not mutate the original value"]
pub fn from_bitmask(bitmask: u64) -> Self {
Self(mask_impl::Mask::from_bitmask_integer(bitmask))
Self(bitmask.select(Simd::splat(T::TRUE), Simd::splat(T::FALSE)))
}

/// Finds the index of the first set element.
Expand Down Expand Up @@ -450,7 +491,8 @@ where
type Output = Self;
#[inline]
fn bitand(self, rhs: Self) -> Self {
Self(self.0 & rhs.0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can just stay unchanged -- same for all the other bitwise ops.

// Safety: `self` is an integer vector
unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) }
}
}

Expand Down Expand Up @@ -486,7 +528,8 @@ where
type Output = Self;
#[inline]
fn bitor(self, rhs: Self) -> Self {
Self(self.0 | rhs.0)
// Safety: `self` is an integer vector
unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) }
}
}

Expand Down Expand Up @@ -522,7 +565,8 @@ where
type Output = Self;
#[inline]
fn bitxor(self, rhs: Self) -> Self::Output {
Self(self.0 ^ rhs.0)
// Safety: `self` is an integer vector
unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) }
}
}

Expand Down Expand Up @@ -558,7 +602,7 @@ where
type Output = Mask<T, N>;
#[inline]
fn not(self) -> Self::Output {
Self(!self.0)
Self::splat(true) ^ self
}
}

Expand All @@ -569,7 +613,7 @@ where
{
#[inline]
fn bitand_assign(&mut self, rhs: Self) {
self.0 = self.0 & rhs.0;
*self = *self & rhs;
}
}

Expand All @@ -591,7 +635,7 @@ where
{
#[inline]
fn bitor_assign(&mut self, rhs: Self) {
self.0 = self.0 | rhs.0;
*self = *self | rhs;
}
}

Expand All @@ -613,7 +657,7 @@ where
{
#[inline]
fn bitxor_assign(&mut self, rhs: Self) {
self.0 = self.0 ^ rhs.0;
*self = *self ^ rhs;
}
}

Expand Down
Loading
Loading