diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs index ca1e3db8b46..7baa9647591 100644 --- a/crates/core_simd/src/masks.rs +++ b/crates/core_simd/src/masks.rs @@ -2,20 +2,33 @@ //! Types representing #![allow(non_camel_case_types)] -#[cfg_attr( - not(all(target_arch = "x86_64", target_feature = "avx512f")), - path = "masks/full_masks.rs" -)] -#[cfg_attr( - all(target_arch = "x86_64", target_feature = "avx512f"), - path = "masks/bitmask.rs" -)] -mod mask_impl; - -use crate::simd::{LaneCount, Simd, SimdCast, SimdElement, SupportedLaneCount}; +use crate::simd::{LaneCount, Select, Simd, SimdCast, SimdElement, SupportedLaneCount}; use core::cmp::Ordering; use core::{fmt, mem}; +pub(crate) trait FixEndianness { + fn fix_endianness(self) -> Self; +} + +macro_rules! impl_fix_endianness { + { $($int:ty),* } => { + $( + impl FixEndianness for $int { + #[inline(always)] + fn fix_endianness(self) -> Self { + if cfg!(target_endian = "big") { + <$int>::reverse_bits(self) + } else { + self + } + } + } + )* + } +} + +impl_fix_endianness! { u8, u16, u32, u64 } + mod sealed { use super::*; @@ -109,7 +122,7 @@ impl_element! { isize, usize } /// and/or Rust versions, and code should not assume that it is equivalent to /// `[T; N]`. #[repr(transparent)] -pub struct Mask(mask_impl::Mask) +pub struct Mask(Simd) where T: MaskElement, LaneCount: SupportedLaneCount; @@ -141,7 +154,7 @@ where #[inline] #[rustc_const_unstable(feature = "portable_simd", issue = "86656")] pub const fn splat(value: bool) -> Self { - Self(mask_impl::Mask::splat(value)) + Self(Simd::splat(if value { T::TRUE } else { T::FALSE })) } /// Converts an array of bools to a SIMD mask. @@ -192,8 +205,8 @@ where // Safety: the caller must confirm this invariant unsafe { core::intrinsics::assume(::valid(value)); - Self(mask_impl::Mask::from_simd_unchecked(value)) } + Self(value) } /// Converts a vector of integers to a mask, where 0 represents `false` and -1 @@ -215,14 +228,15 @@ where #[inline] #[must_use = "method returns a new vector and does not mutate the original value"] pub fn to_simd(self) -> Simd { - self.0.to_simd() + self.0 } /// Converts the mask to a mask of any other element size. #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] pub fn cast(self) -> Mask { - Mask(self.0.convert()) + // Safety: mask elements are integers + unsafe { Mask(core::intrinsics::simd::simd_as(self.0)) } } /// Tests the value of the specified element. @@ -233,7 +247,7 @@ where #[must_use = "method returns a new bool and does not mutate the original value"] pub unsafe fn test_unchecked(&self, index: usize) -> bool { // Safety: the caller must confirm this invariant - unsafe { self.0.test_unchecked(index) } + unsafe { T::eq(*self.0.as_array().get_unchecked(index), T::TRUE) } } /// Tests the value of the specified element. @@ -244,9 +258,7 @@ where #[must_use = "method returns a new bool and does not mutate the original value"] #[track_caller] pub fn test(&self, index: usize) -> bool { - assert!(index < N, "element index out of range"); - // Safety: the element index has been checked - unsafe { self.test_unchecked(index) } + T::eq(self.0[index], T::TRUE) } /// Sets the value of the specified element. @@ -257,7 +269,7 @@ where pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) { // Safety: the caller must confirm this invariant unsafe { - self.0.set_unchecked(index, value); + *self.0.as_mut_array().get_unchecked_mut(index) = if value { T::TRUE } else { T::FALSE } } } @@ -268,35 +280,67 @@ where #[inline] #[track_caller] pub fn set(&mut self, index: usize, value: bool) { - assert!(index < N, "element index out of range"); - // Safety: the element index has been checked - unsafe { - self.set_unchecked(index, value); - } + self.0[index] = if value { T::TRUE } else { T::FALSE } } /// Returns true if any element is set, or false otherwise. #[inline] #[must_use = "method returns a new bool and does not mutate the original value"] pub fn any(self) -> bool { - self.0.any() + // Safety: `self` is a mask vector + unsafe { core::intrinsics::simd::simd_reduce_any(self.0) } } /// Returns true if all elements are set, or false otherwise. #[inline] #[must_use = "method returns a new bool and does not mutate the original value"] pub fn all(self) -> bool { - self.0.all() + // Safety: `self` is a mask vector + unsafe { core::intrinsics::simd::simd_reduce_all(self.0) } } /// Creates a bitmask from a mask. /// /// Each bit is set if the corresponding element in the mask is `true`. - /// If the mask contains more than 64 elements, the bitmask is truncated to the first 64. #[inline] #[must_use = "method returns a new integer and does not mutate the original value"] pub fn to_bitmask(self) -> u64 { - self.0.to_bitmask_integer() + const { + assert!(N <= 64, "number of elements can't be greater than 64"); + } + + #[inline] + unsafe fn to_bitmask_impl( + mask: Mask, + ) -> U + where + T: MaskElement, + LaneCount: SupportedLaneCount, + LaneCount: SupportedLaneCount, + { + let resized = mask.resize::(false); + + // Safety: `resized` is an integer vector with length M, which must match T + let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized.0) }; + + // LLVM assumes bit order should match endianness + bitmask.fix_endianness() + } + + // TODO modify simd_bitmask to zero-extend output, making this unnecessary + if N <= 8 { + // Safety: bitmask matches length + unsafe { to_bitmask_impl::(self) as u64 } + } else if N <= 16 { + // Safety: bitmask matches length + unsafe { to_bitmask_impl::(self) as u64 } + } else if N <= 32 { + // Safety: bitmask matches length + unsafe { to_bitmask_impl::(self) as u64 } + } else { + // Safety: bitmask matches length + unsafe { to_bitmask_impl::(self) } + } } /// Creates a mask from a bitmask. @@ -306,7 +350,7 @@ where #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] pub fn from_bitmask(bitmask: u64) -> Self { - Self(mask_impl::Mask::from_bitmask_integer(bitmask)) + Self(bitmask.select(Simd::splat(T::TRUE), Simd::splat(T::FALSE))) } /// Finds the index of the first set element. @@ -450,7 +494,8 @@ where type Output = Self; #[inline] fn bitand(self, rhs: Self) -> Self { - Self(self.0 & rhs.0) + // Safety: `self` is an integer vector + unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) } } } @@ -486,7 +531,8 @@ where type Output = Self; #[inline] fn bitor(self, rhs: Self) -> Self { - Self(self.0 | rhs.0) + // Safety: `self` is an integer vector + unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) } } } @@ -522,7 +568,8 @@ where type Output = Self; #[inline] fn bitxor(self, rhs: Self) -> Self::Output { - Self(self.0 ^ rhs.0) + // Safety: `self` is an integer vector + unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) } } } @@ -558,7 +605,7 @@ where type Output = Mask; #[inline] fn not(self) -> Self::Output { - Self(!self.0) + Self::splat(true) ^ self } } @@ -569,7 +616,7 @@ where { #[inline] fn bitand_assign(&mut self, rhs: Self) { - self.0 = self.0 & rhs.0; + *self = *self & rhs; } } @@ -591,7 +638,7 @@ where { #[inline] fn bitor_assign(&mut self, rhs: Self) { - self.0 = self.0 | rhs.0; + *self = *self | rhs; } } @@ -613,7 +660,7 @@ where { #[inline] fn bitxor_assign(&mut self, rhs: Self) { - self.0 = self.0 ^ rhs.0; + *self = *self ^ rhs; } } diff --git a/crates/core_simd/src/masks/bitmask.rs b/crates/core_simd/src/masks/bitmask.rs deleted file mode 100644 index 83ee88c372a..00000000000 --- a/crates/core_simd/src/masks/bitmask.rs +++ /dev/null @@ -1,231 +0,0 @@ -#![allow(unused_imports)] -use super::MaskElement; -use crate::simd::{LaneCount, Simd, SupportedLaneCount}; -use core::marker::PhantomData; - -/// A mask where each lane is represented by a single bit. -#[repr(transparent)] -pub(crate) struct Mask( - as SupportedLaneCount>::BitMask, - PhantomData, -) -where - T: MaskElement, - LaneCount: SupportedLaneCount; - -impl Copy for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ -} - -impl Clone for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn clone(&self) -> Self { - *self - } -} - -impl PartialEq for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn eq(&self, other: &Self) -> bool { - self.0.as_ref() == other.0.as_ref() - } -} - -impl PartialOrd for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn partial_cmp(&self, other: &Self) -> Option { - self.0.as_ref().partial_cmp(other.0.as_ref()) - } -} - -impl Eq for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ -} - -impl Ord for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - self.0.as_ref().cmp(other.0.as_ref()) - } -} - -impl Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - #[rustc_const_unstable(feature = "portable_simd", issue = "86656")] - pub(crate) const fn splat(value: bool) -> Self { - Self( - if value { - as SupportedLaneCount>::FULL_BIT_MASK - } else { - as SupportedLaneCount>::EMPTY_BIT_MASK - }, - PhantomData, - ) - } - - #[inline] - #[must_use = "method returns a new bool and does not mutate the original value"] - pub(crate) unsafe fn test_unchecked(&self, lane: usize) -> bool { - (self.0.as_ref()[lane / 8] >> (lane % 8)) & 0x1 > 0 - } - - #[inline] - pub(crate) unsafe fn set_unchecked(&mut self, lane: usize, value: bool) { - unsafe { - self.0.as_mut()[lane / 8] ^= ((value ^ self.test_unchecked(lane)) as u8) << (lane % 8) - } - } - - #[inline] - #[must_use = "method returns a new vector and does not mutate the original value"] - pub(crate) fn to_simd(self) -> Simd { - unsafe { - core::intrinsics::simd::simd_select_bitmask( - self.0, - Simd::splat(T::TRUE), - Simd::splat(T::FALSE), - ) - } - } - - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub(crate) unsafe fn from_simd_unchecked(value: Simd) -> Self { - unsafe { Self(core::intrinsics::simd::simd_bitmask(value), PhantomData) } - } - - #[inline] - pub(crate) fn to_bitmask_integer(self) -> u64 { - let mut bitmask = [0u8; 8]; - bitmask[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref()); - u64::from_ne_bytes(bitmask) - } - - #[inline] - pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self { - let mut bytes = as SupportedLaneCount>::EMPTY_BIT_MASK; - let len = bytes.as_mut().len(); - bytes - .as_mut() - .copy_from_slice(&bitmask.to_ne_bytes()[..len]); - Self(bytes, PhantomData) - } - - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub(crate) fn convert(self) -> Mask - where - U: MaskElement, - { - // Safety: bitmask layout does not depend on the element width - unsafe { core::mem::transmute_copy(&self) } - } - - #[inline] - #[must_use = "method returns a new bool and does not mutate the original value"] - pub(crate) fn any(self) -> bool { - self != Self::splat(false) - } - - #[inline] - #[must_use = "method returns a new bool and does not mutate the original value"] - pub(crate) fn all(self) -> bool { - self == Self::splat(true) - } -} - -impl core::ops::BitAnd for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, - as SupportedLaneCount>::BitMask: AsRef<[u8]> + AsMut<[u8]>, -{ - type Output = Self; - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - fn bitand(mut self, rhs: Self) -> Self { - for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) { - *l &= r; - } - self - } -} - -impl core::ops::BitOr for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, - as SupportedLaneCount>::BitMask: AsRef<[u8]> + AsMut<[u8]>, -{ - type Output = Self; - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - fn bitor(mut self, rhs: Self) -> Self { - for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) { - *l |= r; - } - self - } -} - -impl core::ops::BitXor for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - type Output = Self; - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - fn bitxor(mut self, rhs: Self) -> Self::Output { - for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) { - *l ^= r; - } - self - } -} - -impl core::ops::Not for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - type Output = Self; - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - fn not(mut self) -> Self::Output { - for x in self.0.as_mut() { - *x = !*x; - } - if N % 8 > 0 { - *self.0.as_mut().last_mut().unwrap() &= u8::MAX >> (8 - N % 8); - } - self - } -} diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs deleted file mode 100644 index 5ad2c1d1eaf..00000000000 --- a/crates/core_simd/src/masks/full_masks.rs +++ /dev/null @@ -1,297 +0,0 @@ -//! Masks that take up full SIMD vector registers. - -use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount}; - -#[repr(transparent)] -pub(crate) struct Mask(Simd) -where - T: MaskElement, - LaneCount: SupportedLaneCount; - -impl Copy for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ -} - -impl Clone for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn clone(&self) -> Self { - *self - } -} - -impl PartialEq for Mask -where - T: MaskElement + PartialEq, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn eq(&self, other: &Self) -> bool { - self.0.eq(&other.0) - } -} - -impl PartialOrd for Mask -where - T: MaskElement + PartialOrd, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) - } -} - -impl Eq for Mask -where - T: MaskElement + Eq, - LaneCount: SupportedLaneCount, -{ -} - -impl Ord for Mask -where - T: MaskElement + Ord, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - self.0.cmp(&other.0) - } -} - -// Used for bitmask bit order workaround -pub(crate) trait ReverseBits { - // Reverse the least significant `n` bits of `self`. - // (Remaining bits must be 0.) - fn reverse_bits(self, n: usize) -> Self; -} - -macro_rules! impl_reverse_bits { - { $($int:ty),* } => { - $( - impl ReverseBits for $int { - #[inline(always)] - fn reverse_bits(self, n: usize) -> Self { - let rev = <$int>::reverse_bits(self); - let bitsize = size_of::<$int>() * 8; - if n < bitsize { - // Shift things back to the right - rev >> (bitsize - n) - } else { - rev - } - } - } - )* - } -} - -impl_reverse_bits! { u8, u16, u32, u64 } - -impl Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - #[rustc_const_unstable(feature = "portable_simd", issue = "86656")] - pub(crate) const fn splat(value: bool) -> Self { - Self(Simd::splat(if value { T::TRUE } else { T::FALSE })) - } - - #[inline] - #[must_use = "method returns a new bool and does not mutate the original value"] - pub(crate) unsafe fn test_unchecked(&self, lane: usize) -> bool { - T::eq(self.0[lane], T::TRUE) - } - - #[inline] - pub(crate) unsafe fn set_unchecked(&mut self, lane: usize, value: bool) { - self.0[lane] = if value { T::TRUE } else { T::FALSE } - } - - #[inline] - #[must_use = "method returns a new vector and does not mutate the original value"] - pub(crate) fn to_simd(self) -> Simd { - self.0 - } - - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub(crate) unsafe fn from_simd_unchecked(value: Simd) -> Self { - Self(value) - } - - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub(crate) fn convert(self) -> Mask - where - U: MaskElement, - { - // Safety: masks are simply integer vectors of 0 and -1, and we can cast the element type. - unsafe { Mask(core::intrinsics::simd::simd_cast(self.0)) } - } - - #[inline] - unsafe fn to_bitmask_impl(self) -> U - where - LaneCount: SupportedLaneCount, - { - let resized = self.to_simd().resize::(T::FALSE); - - // Safety: `resized` is an integer vector with length M, which must match T - let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized) }; - - // LLVM assumes bit order should match endianness - if cfg!(target_endian = "big") { - bitmask.reverse_bits(M) - } else { - bitmask - } - } - - #[inline] - unsafe fn from_bitmask_impl(bitmask: U) -> Self - where - LaneCount: SupportedLaneCount, - { - // LLVM assumes bit order should match endianness - let bitmask = if cfg!(target_endian = "big") { - bitmask.reverse_bits(M) - } else { - bitmask - }; - - // SAFETY: `mask` is the correct bitmask type for a u64 bitmask - let mask: Simd = unsafe { - core::intrinsics::simd::simd_select_bitmask( - bitmask, - Simd::::splat(T::TRUE), - Simd::::splat(T::FALSE), - ) - }; - - // SAFETY: `mask` only contains `T::TRUE` or `T::FALSE` - unsafe { Self::from_simd_unchecked(mask.resize::(T::FALSE)) } - } - - #[inline] - pub(crate) fn to_bitmask_integer(self) -> u64 { - // TODO modify simd_bitmask to zero-extend output, making this unnecessary - if N <= 8 { - // Safety: bitmask matches length - unsafe { self.to_bitmask_impl::() as u64 } - } else if N <= 16 { - // Safety: bitmask matches length - unsafe { self.to_bitmask_impl::() as u64 } - } else if N <= 32 { - // Safety: bitmask matches length - unsafe { self.to_bitmask_impl::() as u64 } - } else { - // Safety: bitmask matches length - unsafe { self.to_bitmask_impl::() } - } - } - - #[inline] - pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self { - // TODO modify simd_bitmask_select to truncate input, making this unnecessary - if N <= 8 { - // Safety: bitmask matches length - unsafe { Self::from_bitmask_impl::(bitmask as u8) } - } else if N <= 16 { - // Safety: bitmask matches length - unsafe { Self::from_bitmask_impl::(bitmask as u16) } - } else if N <= 32 { - // Safety: bitmask matches length - unsafe { Self::from_bitmask_impl::(bitmask as u32) } - } else { - // Safety: bitmask matches length - unsafe { Self::from_bitmask_impl::(bitmask) } - } - } - - #[inline] - #[must_use = "method returns a new bool and does not mutate the original value"] - pub(crate) fn any(self) -> bool { - // Safety: use `self` as an integer vector - unsafe { core::intrinsics::simd::simd_reduce_any(self.to_simd()) } - } - - #[inline] - #[must_use = "method returns a new bool and does not mutate the original value"] - pub(crate) fn all(self) -> bool { - // Safety: use `self` as an integer vector - unsafe { core::intrinsics::simd::simd_reduce_all(self.to_simd()) } - } -} - -impl From> for Simd -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn from(value: Mask) -> Self { - value.0 - } -} - -impl core::ops::BitAnd for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - type Output = Self; - #[inline] - fn bitand(self, rhs: Self) -> Self { - // Safety: `self` is an integer vector - unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) } - } -} - -impl core::ops::BitOr for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - type Output = Self; - #[inline] - fn bitor(self, rhs: Self) -> Self { - // Safety: `self` is an integer vector - unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) } - } -} - -impl core::ops::BitXor for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - type Output = Self; - #[inline] - fn bitxor(self, rhs: Self) -> Self { - // Safety: `self` is an integer vector - unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) } - } -} - -impl core::ops::Not for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - type Output = Self; - #[inline] - fn not(self) -> Self::Output { - Self::splat(true) ^ self - } -} diff --git a/crates/core_simd/src/mod.rs b/crates/core_simd/src/mod.rs index 45b1a0f9751..14fe70df4ed 100644 --- a/crates/core_simd/src/mod.rs +++ b/crates/core_simd/src/mod.rs @@ -29,6 +29,7 @@ pub mod simd { pub use crate::core_simd::cast::*; pub use crate::core_simd::lane_count::{LaneCount, SupportedLaneCount}; pub use crate::core_simd::masks::*; + pub use crate::core_simd::select::*; pub use crate::core_simd::swizzle::*; pub use crate::core_simd::to_bytes::ToBytes; pub use crate::core_simd::vector::*; diff --git a/crates/core_simd/src/ops.rs b/crates/core_simd/src/ops.rs index f36e8d01a73..f36e360fadf 100644 --- a/crates/core_simd/src/ops.rs +++ b/crates/core_simd/src/ops.rs @@ -1,4 +1,4 @@ -use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount, cmp::SimdPartialEq}; +use crate::simd::{LaneCount, Select, Simd, SimdElement, SupportedLaneCount, cmp::SimdPartialEq}; use core::ops::{Add, Mul}; use core::ops::{BitAnd, BitOr, BitXor}; use core::ops::{Div, Rem, Sub}; diff --git a/crates/core_simd/src/select.rs b/crates/core_simd/src/select.rs index a2db455a526..5240b9b0c71 100644 --- a/crates/core_simd/src/select.rs +++ b/crates/core_simd/src/select.rs @@ -1,54 +1,163 @@ -use crate::simd::{LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount}; +use crate::simd::{ + FixEndianness, LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount, +}; -impl Mask +/// Choose elements from two vectors using a mask. +/// +/// For each element in the mask, choose the corresponding element from `true_values` if +/// that element mask is true, and `false_values` if that element mask is false. +/// +/// If the mask is `u64`, it's treated as a bitmask with the least significant bit +/// corresponding to the first element. +/// +/// # Examples +/// +/// ## Selecting values from `Simd` +/// ``` +/// # #![feature(portable_simd)] +/// # #[cfg(feature = "as_crate")] use core_simd::simd; +/// # #[cfg(not(feature = "as_crate"))] use core::simd; +/// # use simd::{Simd, Mask, Select}; +/// let a = Simd::from_array([0, 1, 2, 3]); +/// let b = Simd::from_array([4, 5, 6, 7]); +/// let mask = Mask::::from_array([true, false, false, true]); +/// let c = mask.select(a, b); +/// assert_eq!(c.to_array(), [0, 5, 6, 3]); +/// ``` +/// +/// ## Selecting values from `Mask` +/// ``` +/// # #![feature(portable_simd)] +/// # #[cfg(feature = "as_crate")] use core_simd::simd; +/// # #[cfg(not(feature = "as_crate"))] use core::simd; +/// # use simd::{Mask, Select}; +/// let a = Mask::::from_array([true, true, false, false]); +/// let b = Mask::::from_array([false, false, true, true]); +/// let mask = Mask::::from_array([true, false, false, true]); +/// let c = mask.select(a, b); +/// assert_eq!(c.to_array(), [true, false, true, false]); +/// ``` +/// +/// ## Selecting with a bitmask +/// ``` +/// # #![feature(portable_simd)] +/// # #[cfg(feature = "as_crate")] use core_simd::simd; +/// # #[cfg(not(feature = "as_crate"))] use core::simd; +/// # use simd::{Mask, Select}; +/// let a = Mask::::from_array([true, true, false, false]); +/// let b = Mask::::from_array([false, false, true, true]); +/// let mask = 0b1001; +/// let c = mask.select(a, b); +/// assert_eq!(c.to_array(), [true, false, true, false]); +/// ``` +pub trait Select { + /// Choose elements + fn select(self, true_values: T, false_values: T) -> T; +} + +impl Select> for Mask +where + T: SimdElement, + U: MaskElement, + LaneCount: SupportedLaneCount, +{ + #[inline] + fn select(self, true_values: Simd, false_values: Simd) -> Simd { + // Safety: + // simd_as between masks is always safe (they're vectors of ints). + // simd_select uses a mask that matches the width and number of elements + unsafe { + let mask: Simd = core::intrinsics::simd::simd_as(self.to_simd()); + core::intrinsics::simd::simd_select(mask, true_values, false_values) + } + } +} + +impl Select> for u64 +where + T: SimdElement, + LaneCount: SupportedLaneCount, +{ + #[inline] + fn select(self, true_values: Simd, false_values: Simd) -> Simd { + const { + assert!(N <= 64, "number of elements can't be greater than 64"); + } + + #[inline] + unsafe fn select_impl( + bitmask: U, + true_values: Simd, + false_values: Simd, + ) -> Simd + where + T: SimdElement, + LaneCount: SupportedLaneCount, + LaneCount: SupportedLaneCount, + { + let default = true_values[0]; + let true_values = true_values.resize::(default); + let false_values = false_values.resize::(default); + + // LLVM assumes bit order should match endianness + let bitmask = bitmask.fix_endianness(); + + // Safety: the caller guarantees that the size of U matches M + let selected = unsafe { + core::intrinsics::simd::simd_select_bitmask(bitmask, true_values, false_values) + }; + + selected.resize::(default) + } + + // TODO modify simd_bitmask_select to truncate input, making this unnecessary + if N <= 8 { + let bitmask = self as u8; + // Safety: bitmask matches length + unsafe { select_impl::(bitmask, true_values, false_values) } + } else if N <= 16 { + let bitmask = self as u16; + // Safety: bitmask matches length + unsafe { select_impl::(bitmask, true_values, false_values) } + } else if N <= 32 { + let bitmask = self as u32; + // Safety: bitmask matches length + unsafe { select_impl::(bitmask, true_values, false_values) } + } else { + let bitmask = self; + // Safety: bitmask matches length + unsafe { select_impl::(bitmask, true_values, false_values) } + } + } +} + +impl Select> for Mask where T: MaskElement, + U: MaskElement, LaneCount: SupportedLaneCount, { - /// Choose elements from two vectors. - /// - /// For each element in the mask, choose the corresponding element from `true_values` if - /// that element mask is true, and `false_values` if that element mask is false. - /// - /// # Examples - /// ``` - /// # #![feature(portable_simd)] - /// # use core::simd::{Simd, Mask}; - /// let a = Simd::from_array([0, 1, 2, 3]); - /// let b = Simd::from_array([4, 5, 6, 7]); - /// let mask = Mask::from_array([true, false, false, true]); - /// let c = mask.select(a, b); - /// assert_eq!(c.to_array(), [0, 5, 6, 3]); - /// ``` #[inline] - #[must_use = "method returns a new vector and does not mutate the original inputs"] - pub fn select(self, true_values: Simd, false_values: Simd) -> Simd - where - U: SimdElement, - { - // Safety: The mask has been cast to a vector of integers, - // and the operands to select between are vectors of the same type and length. - unsafe { core::intrinsics::simd::simd_select(self.to_simd(), true_values, false_values) } + fn select(self, true_values: Mask, false_values: Mask) -> Mask { + let selected: Simd = + Select::select(self, true_values.to_simd(), false_values.to_simd()); + + // Safety: all values come from masks + unsafe { Mask::from_simd_unchecked(selected) } } +} - /// Choose elements from two masks. - /// - /// For each element in the mask, choose the corresponding element from `true_values` if - /// that element mask is true, and `false_values` if that element mask is false. - /// - /// # Examples - /// ``` - /// # #![feature(portable_simd)] - /// # use core::simd::Mask; - /// let a = Mask::::from_array([true, true, false, false]); - /// let b = Mask::::from_array([false, false, true, true]); - /// let mask = Mask::::from_array([true, false, false, true]); - /// let c = mask.select_mask(a, b); - /// assert_eq!(c.to_array(), [true, false, true, false]); - /// ``` +impl Select> for u64 +where + T: MaskElement, + LaneCount: SupportedLaneCount, +{ #[inline] - #[must_use = "method returns a new mask and does not mutate the original inputs"] - pub fn select_mask(self, true_values: Self, false_values: Self) -> Self { - self & true_values | !self & false_values + fn select(self, true_values: Mask, false_values: Mask) -> Mask { + let selected: Simd = + Select::select(self, true_values.to_simd(), false_values.to_simd()); + + // Safety: all values come from masks + unsafe { Mask::from_simd_unchecked(selected) } } } diff --git a/crates/core_simd/src/simd/cmp/ord.rs b/crates/core_simd/src/simd/cmp/ord.rs index 4b2d0b55feb..1b1c689ad45 100644 --- a/crates/core_simd/src/simd/cmp/ord.rs +++ b/crates/core_simd/src/simd/cmp/ord.rs @@ -1,5 +1,5 @@ use crate::simd::{ - LaneCount, Mask, Simd, SupportedLaneCount, + LaneCount, Mask, Select, Simd, SupportedLaneCount, cmp::SimdPartialEq, ptr::{SimdConstPtr, SimdMutPtr}, }; @@ -194,12 +194,12 @@ macro_rules! impl_mask { { #[inline] fn simd_max(self, other: Self) -> Self { - self.simd_gt(other).select_mask(other, self) + self.simd_gt(other).select(other, self) } #[inline] fn simd_min(self, other: Self) -> Self { - self.simd_lt(other).select_mask(other, self) + self.simd_lt(other).select(other, self) } #[inline] diff --git a/crates/core_simd/src/simd/num/float.rs b/crates/core_simd/src/simd/num/float.rs index b5972c47373..76ab5748c63 100644 --- a/crates/core_simd/src/simd/num/float.rs +++ b/crates/core_simd/src/simd/num/float.rs @@ -1,6 +1,6 @@ use super::sealed::Sealed; use crate::simd::{ - LaneCount, Mask, Simd, SimdCast, SimdElement, SupportedLaneCount, + LaneCount, Mask, Select, Simd, SimdCast, SimdElement, SupportedLaneCount, cmp::{SimdPartialEq, SimdPartialOrd}, }; diff --git a/crates/core_simd/src/simd/num/int.rs b/crates/core_simd/src/simd/num/int.rs index d25050c3e4b..5a292407d05 100644 --- a/crates/core_simd/src/simd/num/int.rs +++ b/crates/core_simd/src/simd/num/int.rs @@ -1,6 +1,6 @@ use super::sealed::Sealed; use crate::simd::{ - LaneCount, Mask, Simd, SimdCast, SimdElement, SupportedLaneCount, cmp::SimdOrd, + LaneCount, Mask, Select, Simd, SimdCast, SimdElement, SupportedLaneCount, cmp::SimdOrd, cmp::SimdPartialOrd, num::SimdUint, }; diff --git a/crates/core_simd/src/swizzle_dyn.rs b/crates/core_simd/src/swizzle_dyn.rs index 773bd028bae..73b18595d0a 100644 --- a/crates/core_simd/src/swizzle_dyn.rs +++ b/crates/core_simd/src/swizzle_dyn.rs @@ -139,7 +139,7 @@ unsafe fn armv7_neon_swizzle_u8x16(bytes: Simd, idxs: Simd) -> S #[inline] #[allow(clippy::let_and_return)] unsafe fn avx2_pshufb(bytes: Simd, idxs: Simd) -> Simd { - use crate::simd::cmp::SimdPartialOrd; + use crate::simd::{Select, cmp::SimdPartialOrd}; #[cfg(target_arch = "x86")] use core::arch::x86; #[cfg(target_arch = "x86_64")] @@ -200,7 +200,7 @@ fn zeroing_idxs(idxs: Simd) -> Simd where LaneCount: SupportedLaneCount, { - use crate::simd::cmp::SimdPartialOrd; + use crate::simd::{Select, cmp::SimdPartialOrd}; idxs.simd_lt(Simd::splat(N as u8)) .select(idxs, Simd::splat(u8::MAX)) }