From 114693074eb905016a5539d8559b8d56fbcf19d1 Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Wed, 10 Sep 2025 00:17:47 -0400 Subject: [PATCH 1/6] Guarantee Mask has the same layout as Simd. Implement select as a trait that also supports bitmasks. --- crates/core_simd/src/masks.rs | 16 +- crates/core_simd/src/mod.rs | 1 + crates/core_simd/src/ops.rs | 2 +- crates/core_simd/src/select.rs | 197 +++++++++++++++++++------ crates/core_simd/src/simd/cmp/ord.rs | 6 +- crates/core_simd/src/simd/num/float.rs | 2 +- crates/core_simd/src/simd/num/int.rs | 2 +- crates/core_simd/src/swizzle_dyn.rs | 2 +- 8 files changed, 167 insertions(+), 61 deletions(-) diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs index ca1e3db8b46..9a81320b44a 100644 --- a/crates/core_simd/src/masks.rs +++ b/crates/core_simd/src/masks.rs @@ -2,17 +2,10 @@ //! Types representing #![allow(non_camel_case_types)] -#[cfg_attr( - not(all(target_arch = "x86_64", target_feature = "avx512f")), - path = "masks/full_masks.rs" -)] -#[cfg_attr( - all(target_arch = "x86_64", target_feature = "avx512f"), - path = "masks/bitmask.rs" -)] +#[path = "masks/full_masks.rs"] mod mask_impl; -use crate::simd::{LaneCount, Simd, SimdCast, SimdElement, SupportedLaneCount}; +use crate::simd::{LaneCount, Select, Simd, SimdCast, SimdElement, SupportedLaneCount}; use core::cmp::Ordering; use core::{fmt, mem}; @@ -105,9 +98,8 @@ impl_element! { isize, usize } /// /// Masks represent boolean inclusion/exclusion on a per-element basis. /// -/// The layout of this type is unspecified, and may change between platforms -/// and/or Rust versions, and code should not assume that it is equivalent to -/// `[T; N]`. +/// The layout of this type is equivalent to `Simd`, but elements +/// are guaranteed to be either 0 or -1. #[repr(transparent)] pub struct Mask(mask_impl::Mask) where diff --git a/crates/core_simd/src/mod.rs b/crates/core_simd/src/mod.rs index 45b1a0f9751..14fe70df4ed 100644 --- a/crates/core_simd/src/mod.rs +++ b/crates/core_simd/src/mod.rs @@ -29,6 +29,7 @@ pub mod simd { pub use crate::core_simd::cast::*; pub use crate::core_simd::lane_count::{LaneCount, SupportedLaneCount}; pub use crate::core_simd::masks::*; + pub use crate::core_simd::select::*; pub use crate::core_simd::swizzle::*; pub use crate::core_simd::to_bytes::ToBytes; pub use crate::core_simd::vector::*; diff --git a/crates/core_simd/src/ops.rs b/crates/core_simd/src/ops.rs index f36e8d01a73..f36e360fadf 100644 --- a/crates/core_simd/src/ops.rs +++ b/crates/core_simd/src/ops.rs @@ -1,4 +1,4 @@ -use crate::simd::{LaneCount, Simd, SimdElement, SupportedLaneCount, cmp::SimdPartialEq}; +use crate::simd::{LaneCount, Select, Simd, SimdElement, SupportedLaneCount, cmp::SimdPartialEq}; use core::ops::{Add, Mul}; use core::ops::{BitAnd, BitOr, BitXor}; use core::ops::{Div, Rem, Sub}; diff --git a/crates/core_simd/src/select.rs b/crates/core_simd/src/select.rs index a2db455a526..d3dc64c7f0f 100644 --- a/crates/core_simd/src/select.rs +++ b/crates/core_simd/src/select.rs @@ -1,54 +1,167 @@ use crate::simd::{LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount}; -impl Mask +/// Choose elements from two vectors using a mask. +/// +/// For each element in the mask, choose the corresponding element from `true_values` if +/// that element mask is true, and `false_values` if that element mask is false. +/// +/// If the mask is `u64`, it's treated as a bitmask with the least significant bit +/// corresponding to the first element. +/// +/// # Examples +/// +/// ## Selecting values from `Simd` +/// ``` +/// # #![feature(portable_simd)] +/// # #[cfg(feature = "as_crate")] use core_simd::simd; +/// # #[cfg(not(feature = "as_crate"))] use core::simd; +/// # use simd::{Simd, Mask, Select}; +/// let a = Simd::from_array([0, 1, 2, 3]); +/// let b = Simd::from_array([4, 5, 6, 7]); +/// let mask = Mask::::from_array([true, false, false, true]); +/// let c = mask.select(a, b); +/// assert_eq!(c.to_array(), [0, 5, 6, 3]); +/// ``` +/// +/// ## Selecting values from `Mask` +/// ``` +/// # #![feature(portable_simd)] +/// # #[cfg(feature = "as_crate")] use core_simd::simd; +/// # #[cfg(not(feature = "as_crate"))] use core::simd; +/// # use simd::{Mask, Select}; +/// let a = Mask::::from_array([true, true, false, false]); +/// let b = Mask::::from_array([false, false, true, true]); +/// let mask = Mask::::from_array([true, false, false, true]); +/// let c = mask.select(a, b); +/// assert_eq!(c.to_array(), [true, false, true, false]); +/// ``` +/// +/// ## Selecting with a bitmask +/// ``` +/// # #![feature(portable_simd)] +/// # #[cfg(feature = "as_crate")] use core_simd::simd; +/// # #[cfg(not(feature = "as_crate"))] use core::simd; +/// # use simd::{Mask, Select}; +/// let a = Mask::::from_array([true, true, false, false]); +/// let b = Mask::::from_array([false, false, true, true]); +/// let mask = 0b1001; +/// let c = mask.select(a, b); +/// assert_eq!(c.to_array(), [true, false, true, false]); +/// ``` +pub trait Select { + /// Choose elements + fn select(self, true_values: T, false_values: T) -> T; +} + +impl Select> for Mask +where + T: SimdElement, + U: MaskElement, + LaneCount: SupportedLaneCount, +{ + #[inline] + fn select(self, true_values: Simd, false_values: Simd) -> Simd { + // Safety: + // simd_as between masks is always safe (they're vectors of ints). + // simd_select uses a mask that matches the width and number of elements + unsafe { + let mask: Simd = core::intrinsics::simd::simd_as(self.to_simd()); + core::intrinsics::simd::simd_select(mask, true_values, false_values) + } + } +} + +impl Select> for u64 +where + T: SimdElement, + LaneCount: SupportedLaneCount, +{ + #[inline] + fn select(self, true_values: Simd, false_values: Simd) -> Simd { + const { + assert!(N <= 64, "number of elements can't be greater than 64"); + } + + // LLVM assumes bit order should match endianness + let bitmask = if cfg!(target_endian = "big") { + let rev = self.reverse_bits(); + if N < 64 { + // Shift things back to the right + rev >> (64 - N) + } else { + rev + } + } else { + self + }; + + #[inline] + unsafe fn select_impl( + bitmask: U, + true_values: Simd, + false_values: Simd, + ) -> Simd + where + T: SimdElement, + LaneCount: SupportedLaneCount, + LaneCount: SupportedLaneCount, + { + let default = true_values[0]; + let true_values = true_values.resize::(default); + let false_values = false_values.resize::(default); + + // Safety: the caller guarantees that the size of U matches M + let selected = unsafe { + core::intrinsics::simd::simd_select_bitmask(bitmask, true_values, false_values) + }; + + selected.resize::(default) + } + + // TODO modify simd_bitmask_select to truncate input, making this unnecessary + if N <= 8 { + // Safety: bitmask matches length + unsafe { select_impl::(bitmask as u8, true_values, false_values) } + } else if N <= 16 { + // Safety: bitmask matches length + unsafe { select_impl::(bitmask as u16, true_values, false_values) } + } else if N <= 32 { + // Safety: bitmask matches length + unsafe { select_impl::(bitmask as u32, true_values, false_values) } + } else { + // Safety: bitmask matches length + unsafe { select_impl::(bitmask, true_values, false_values) } + } + } +} + +impl Select> for Mask where T: MaskElement, + U: MaskElement, LaneCount: SupportedLaneCount, { - /// Choose elements from two vectors. - /// - /// For each element in the mask, choose the corresponding element from `true_values` if - /// that element mask is true, and `false_values` if that element mask is false. - /// - /// # Examples - /// ``` - /// # #![feature(portable_simd)] - /// # use core::simd::{Simd, Mask}; - /// let a = Simd::from_array([0, 1, 2, 3]); - /// let b = Simd::from_array([4, 5, 6, 7]); - /// let mask = Mask::from_array([true, false, false, true]); - /// let c = mask.select(a, b); - /// assert_eq!(c.to_array(), [0, 5, 6, 3]); - /// ``` #[inline] - #[must_use = "method returns a new vector and does not mutate the original inputs"] - pub fn select(self, true_values: Simd, false_values: Simd) -> Simd - where - U: SimdElement, - { - // Safety: The mask has been cast to a vector of integers, - // and the operands to select between are vectors of the same type and length. - unsafe { core::intrinsics::simd::simd_select(self.to_simd(), true_values, false_values) } + fn select(self, true_values: Mask, false_values: Mask) -> Mask { + let selected: Simd = + Select::select(self, true_values.to_simd(), false_values.to_simd()); + + // Safety: all values come from masks + unsafe { Mask::from_simd_unchecked(selected) } } +} - /// Choose elements from two masks. - /// - /// For each element in the mask, choose the corresponding element from `true_values` if - /// that element mask is true, and `false_values` if that element mask is false. - /// - /// # Examples - /// ``` - /// # #![feature(portable_simd)] - /// # use core::simd::Mask; - /// let a = Mask::::from_array([true, true, false, false]); - /// let b = Mask::::from_array([false, false, true, true]); - /// let mask = Mask::::from_array([true, false, false, true]); - /// let c = mask.select_mask(a, b); - /// assert_eq!(c.to_array(), [true, false, true, false]); - /// ``` +impl Select> for u64 +where + T: MaskElement, + LaneCount: SupportedLaneCount, +{ #[inline] - #[must_use = "method returns a new mask and does not mutate the original inputs"] - pub fn select_mask(self, true_values: Self, false_values: Self) -> Self { - self & true_values | !self & false_values + fn select(self, true_values: Mask, false_values: Mask) -> Mask { + let selected: Simd = + Select::select(self, true_values.to_simd(), false_values.to_simd()); + + // Safety: all values come from masks + unsafe { Mask::from_simd_unchecked(selected) } } } diff --git a/crates/core_simd/src/simd/cmp/ord.rs b/crates/core_simd/src/simd/cmp/ord.rs index 4b2d0b55feb..1b1c689ad45 100644 --- a/crates/core_simd/src/simd/cmp/ord.rs +++ b/crates/core_simd/src/simd/cmp/ord.rs @@ -1,5 +1,5 @@ use crate::simd::{ - LaneCount, Mask, Simd, SupportedLaneCount, + LaneCount, Mask, Select, Simd, SupportedLaneCount, cmp::SimdPartialEq, ptr::{SimdConstPtr, SimdMutPtr}, }; @@ -194,12 +194,12 @@ macro_rules! impl_mask { { #[inline] fn simd_max(self, other: Self) -> Self { - self.simd_gt(other).select_mask(other, self) + self.simd_gt(other).select(other, self) } #[inline] fn simd_min(self, other: Self) -> Self { - self.simd_lt(other).select_mask(other, self) + self.simd_lt(other).select(other, self) } #[inline] diff --git a/crates/core_simd/src/simd/num/float.rs b/crates/core_simd/src/simd/num/float.rs index b5972c47373..76ab5748c63 100644 --- a/crates/core_simd/src/simd/num/float.rs +++ b/crates/core_simd/src/simd/num/float.rs @@ -1,6 +1,6 @@ use super::sealed::Sealed; use crate::simd::{ - LaneCount, Mask, Simd, SimdCast, SimdElement, SupportedLaneCount, + LaneCount, Mask, Select, Simd, SimdCast, SimdElement, SupportedLaneCount, cmp::{SimdPartialEq, SimdPartialOrd}, }; diff --git a/crates/core_simd/src/simd/num/int.rs b/crates/core_simd/src/simd/num/int.rs index d25050c3e4b..5a292407d05 100644 --- a/crates/core_simd/src/simd/num/int.rs +++ b/crates/core_simd/src/simd/num/int.rs @@ -1,6 +1,6 @@ use super::sealed::Sealed; use crate::simd::{ - LaneCount, Mask, Simd, SimdCast, SimdElement, SupportedLaneCount, cmp::SimdOrd, + LaneCount, Mask, Select, Simd, SimdCast, SimdElement, SupportedLaneCount, cmp::SimdOrd, cmp::SimdPartialOrd, num::SimdUint, }; diff --git a/crates/core_simd/src/swizzle_dyn.rs b/crates/core_simd/src/swizzle_dyn.rs index 773bd028bae..016422b7060 100644 --- a/crates/core_simd/src/swizzle_dyn.rs +++ b/crates/core_simd/src/swizzle_dyn.rs @@ -1,4 +1,4 @@ -use crate::simd::{LaneCount, Simd, SupportedLaneCount}; +use crate::simd::{LaneCount, Select, Simd, SupportedLaneCount}; use core::mem; impl Simd From b8df2c96a16675ab49fcdebb3f9ffde619b6730c Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Wed, 10 Sep 2025 01:45:04 -0400 Subject: [PATCH 2/6] Remove mask backing implementations --- crates/core_simd/src/masks.rs | 88 ++++--- crates/core_simd/src/masks/bitmask.rs | 231 ------------------ crates/core_simd/src/masks/full_masks.rs | 297 ----------------------- 3 files changed, 59 insertions(+), 557 deletions(-) delete mode 100644 crates/core_simd/src/masks/bitmask.rs delete mode 100644 crates/core_simd/src/masks/full_masks.rs diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs index 9a81320b44a..1cdfe013ce3 100644 --- a/crates/core_simd/src/masks.rs +++ b/crates/core_simd/src/masks.rs @@ -2,9 +2,6 @@ //! Types representing #![allow(non_camel_case_types)] -#[path = "masks/full_masks.rs"] -mod mask_impl; - use crate::simd::{LaneCount, Select, Simd, SimdCast, SimdElement, SupportedLaneCount}; use core::cmp::Ordering; use core::{fmt, mem}; @@ -101,7 +98,7 @@ impl_element! { isize, usize } /// The layout of this type is equivalent to `Simd`, but elements /// are guaranteed to be either 0 or -1. #[repr(transparent)] -pub struct Mask(mask_impl::Mask) +pub struct Mask(Simd) where T: MaskElement, LaneCount: SupportedLaneCount; @@ -133,7 +130,7 @@ where #[inline] #[rustc_const_unstable(feature = "portable_simd", issue = "86656")] pub const fn splat(value: bool) -> Self { - Self(mask_impl::Mask::splat(value)) + Self(Simd::splat(if value { T::TRUE } else { T::FALSE })) } /// Converts an array of bools to a SIMD mask. @@ -184,8 +181,8 @@ where // Safety: the caller must confirm this invariant unsafe { core::intrinsics::assume(::valid(value)); - Self(mask_impl::Mask::from_simd_unchecked(value)) } + Self(value) } /// Converts a vector of integers to a mask, where 0 represents `false` and -1 @@ -207,14 +204,15 @@ where #[inline] #[must_use = "method returns a new vector and does not mutate the original value"] pub fn to_simd(self) -> Simd { - self.0.to_simd() + self.0 } /// Converts the mask to a mask of any other element size. #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] pub fn cast(self) -> Mask { - Mask(self.0.convert()) + // Safety: mask elements are integers + unsafe { Mask(core::intrinsics::simd::simd_as(self.0)) } } /// Tests the value of the specified element. @@ -225,7 +223,7 @@ where #[must_use = "method returns a new bool and does not mutate the original value"] pub unsafe fn test_unchecked(&self, index: usize) -> bool { // Safety: the caller must confirm this invariant - unsafe { self.0.test_unchecked(index) } + unsafe { T::eq(*self.0.as_array().get_unchecked(index), T::TRUE) } } /// Tests the value of the specified element. @@ -236,9 +234,7 @@ where #[must_use = "method returns a new bool and does not mutate the original value"] #[track_caller] pub fn test(&self, index: usize) -> bool { - assert!(index < N, "element index out of range"); - // Safety: the element index has been checked - unsafe { self.test_unchecked(index) } + T::eq(self.0[index], T::TRUE) } /// Sets the value of the specified element. @@ -249,7 +245,7 @@ where pub unsafe fn set_unchecked(&mut self, index: usize, value: bool) { // Safety: the caller must confirm this invariant unsafe { - self.0.set_unchecked(index, value); + *self.0.as_mut_array().get_unchecked_mut(index) = if value { T::TRUE } else { T::FALSE } } } @@ -260,25 +256,23 @@ where #[inline] #[track_caller] pub fn set(&mut self, index: usize, value: bool) { - assert!(index < N, "element index out of range"); - // Safety: the element index has been checked - unsafe { - self.set_unchecked(index, value); - } + self.0[index] = if value { T::TRUE } else { T::FALSE } } /// Returns true if any element is set, or false otherwise. #[inline] #[must_use = "method returns a new bool and does not mutate the original value"] pub fn any(self) -> bool { - self.0.any() + // Safety: `self` is a mask vector + unsafe { core::intrinsics::simd::simd_reduce_any(self.0) } } /// Returns true if all elements are set, or false otherwise. #[inline] #[must_use = "method returns a new bool and does not mutate the original value"] pub fn all(self) -> bool { - self.0.all() + // Safety: `self` is a mask vector + unsafe { core::intrinsics::simd::simd_reduce_all(self.0) } } /// Creates a bitmask from a mask. @@ -288,7 +282,40 @@ where #[inline] #[must_use = "method returns a new integer and does not mutate the original value"] pub fn to_bitmask(self) -> u64 { - self.0.to_bitmask_integer() + #[inline] + unsafe fn to_bitmask_impl(mask: Mask) -> U + where + T: MaskElement, + LaneCount: SupportedLaneCount, + LaneCount: SupportedLaneCount, + { + let resized = mask.resize::(false); + + // Safety: `resized` is an integer vector with length M, which must match T + unsafe { core::intrinsics::simd::simd_bitmask(resized.0) } + } + + // TODO modify simd_bitmask to zero-extend output, making this unnecessary + let bitmask = if N <= 8 { + // Safety: bitmask matches length + unsafe { to_bitmask_impl::(self) as u64 } + } else if N <= 16 { + // Safety: bitmask matches length + unsafe { to_bitmask_impl::(self) as u64 } + } else if N <= 32 { + // Safety: bitmask matches length + unsafe { to_bitmask_impl::(self) as u64 } + } else { + // Safety: bitmask matches length + unsafe { to_bitmask_impl::(self) } + }; + + // LLVM assumes bit order should match endianness + if cfg!(target_endian = "big") { + bitmask.reverse_bits() >> (64 - N.min(64)) + } else { + bitmask + } } /// Creates a mask from a bitmask. @@ -298,7 +325,7 @@ where #[inline] #[must_use = "method returns a new mask and does not mutate the original value"] pub fn from_bitmask(bitmask: u64) -> Self { - Self(mask_impl::Mask::from_bitmask_integer(bitmask)) + Self(bitmask.select(Simd::splat(T::TRUE), Simd::splat(T::FALSE))) } /// Finds the index of the first set element. @@ -442,7 +469,8 @@ where type Output = Self; #[inline] fn bitand(self, rhs: Self) -> Self { - Self(self.0 & rhs.0) + // Safety: `self` is an integer vector + unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) } } } @@ -478,7 +506,8 @@ where type Output = Self; #[inline] fn bitor(self, rhs: Self) -> Self { - Self(self.0 | rhs.0) + // Safety: `self` is an integer vector + unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) } } } @@ -514,7 +543,8 @@ where type Output = Self; #[inline] fn bitxor(self, rhs: Self) -> Self::Output { - Self(self.0 ^ rhs.0) + // Safety: `self` is an integer vector + unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) } } } @@ -550,7 +580,7 @@ where type Output = Mask; #[inline] fn not(self) -> Self::Output { - Self(!self.0) + Self::splat(true) ^ self } } @@ -561,7 +591,7 @@ where { #[inline] fn bitand_assign(&mut self, rhs: Self) { - self.0 = self.0 & rhs.0; + *self = *self & rhs; } } @@ -583,7 +613,7 @@ where { #[inline] fn bitor_assign(&mut self, rhs: Self) { - self.0 = self.0 | rhs.0; + *self = *self | rhs; } } @@ -605,7 +635,7 @@ where { #[inline] fn bitxor_assign(&mut self, rhs: Self) { - self.0 = self.0 ^ rhs.0; + *self = *self ^ rhs; } } diff --git a/crates/core_simd/src/masks/bitmask.rs b/crates/core_simd/src/masks/bitmask.rs deleted file mode 100644 index 83ee88c372a..00000000000 --- a/crates/core_simd/src/masks/bitmask.rs +++ /dev/null @@ -1,231 +0,0 @@ -#![allow(unused_imports)] -use super::MaskElement; -use crate::simd::{LaneCount, Simd, SupportedLaneCount}; -use core::marker::PhantomData; - -/// A mask where each lane is represented by a single bit. -#[repr(transparent)] -pub(crate) struct Mask( - as SupportedLaneCount>::BitMask, - PhantomData, -) -where - T: MaskElement, - LaneCount: SupportedLaneCount; - -impl Copy for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ -} - -impl Clone for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn clone(&self) -> Self { - *self - } -} - -impl PartialEq for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn eq(&self, other: &Self) -> bool { - self.0.as_ref() == other.0.as_ref() - } -} - -impl PartialOrd for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn partial_cmp(&self, other: &Self) -> Option { - self.0.as_ref().partial_cmp(other.0.as_ref()) - } -} - -impl Eq for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ -} - -impl Ord for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - self.0.as_ref().cmp(other.0.as_ref()) - } -} - -impl Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - #[rustc_const_unstable(feature = "portable_simd", issue = "86656")] - pub(crate) const fn splat(value: bool) -> Self { - Self( - if value { - as SupportedLaneCount>::FULL_BIT_MASK - } else { - as SupportedLaneCount>::EMPTY_BIT_MASK - }, - PhantomData, - ) - } - - #[inline] - #[must_use = "method returns a new bool and does not mutate the original value"] - pub(crate) unsafe fn test_unchecked(&self, lane: usize) -> bool { - (self.0.as_ref()[lane / 8] >> (lane % 8)) & 0x1 > 0 - } - - #[inline] - pub(crate) unsafe fn set_unchecked(&mut self, lane: usize, value: bool) { - unsafe { - self.0.as_mut()[lane / 8] ^= ((value ^ self.test_unchecked(lane)) as u8) << (lane % 8) - } - } - - #[inline] - #[must_use = "method returns a new vector and does not mutate the original value"] - pub(crate) fn to_simd(self) -> Simd { - unsafe { - core::intrinsics::simd::simd_select_bitmask( - self.0, - Simd::splat(T::TRUE), - Simd::splat(T::FALSE), - ) - } - } - - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub(crate) unsafe fn from_simd_unchecked(value: Simd) -> Self { - unsafe { Self(core::intrinsics::simd::simd_bitmask(value), PhantomData) } - } - - #[inline] - pub(crate) fn to_bitmask_integer(self) -> u64 { - let mut bitmask = [0u8; 8]; - bitmask[..self.0.as_ref().len()].copy_from_slice(self.0.as_ref()); - u64::from_ne_bytes(bitmask) - } - - #[inline] - pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self { - let mut bytes = as SupportedLaneCount>::EMPTY_BIT_MASK; - let len = bytes.as_mut().len(); - bytes - .as_mut() - .copy_from_slice(&bitmask.to_ne_bytes()[..len]); - Self(bytes, PhantomData) - } - - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub(crate) fn convert(self) -> Mask - where - U: MaskElement, - { - // Safety: bitmask layout does not depend on the element width - unsafe { core::mem::transmute_copy(&self) } - } - - #[inline] - #[must_use = "method returns a new bool and does not mutate the original value"] - pub(crate) fn any(self) -> bool { - self != Self::splat(false) - } - - #[inline] - #[must_use = "method returns a new bool and does not mutate the original value"] - pub(crate) fn all(self) -> bool { - self == Self::splat(true) - } -} - -impl core::ops::BitAnd for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, - as SupportedLaneCount>::BitMask: AsRef<[u8]> + AsMut<[u8]>, -{ - type Output = Self; - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - fn bitand(mut self, rhs: Self) -> Self { - for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) { - *l &= r; - } - self - } -} - -impl core::ops::BitOr for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, - as SupportedLaneCount>::BitMask: AsRef<[u8]> + AsMut<[u8]>, -{ - type Output = Self; - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - fn bitor(mut self, rhs: Self) -> Self { - for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) { - *l |= r; - } - self - } -} - -impl core::ops::BitXor for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - type Output = Self; - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - fn bitxor(mut self, rhs: Self) -> Self::Output { - for (l, r) in self.0.as_mut().iter_mut().zip(rhs.0.as_ref().iter()) { - *l ^= r; - } - self - } -} - -impl core::ops::Not for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - type Output = Self; - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - fn not(mut self) -> Self::Output { - for x in self.0.as_mut() { - *x = !*x; - } - if N % 8 > 0 { - *self.0.as_mut().last_mut().unwrap() &= u8::MAX >> (8 - N % 8); - } - self - } -} diff --git a/crates/core_simd/src/masks/full_masks.rs b/crates/core_simd/src/masks/full_masks.rs deleted file mode 100644 index 5ad2c1d1eaf..00000000000 --- a/crates/core_simd/src/masks/full_masks.rs +++ /dev/null @@ -1,297 +0,0 @@ -//! Masks that take up full SIMD vector registers. - -use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount}; - -#[repr(transparent)] -pub(crate) struct Mask(Simd) -where - T: MaskElement, - LaneCount: SupportedLaneCount; - -impl Copy for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ -} - -impl Clone for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn clone(&self) -> Self { - *self - } -} - -impl PartialEq for Mask -where - T: MaskElement + PartialEq, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn eq(&self, other: &Self) -> bool { - self.0.eq(&other.0) - } -} - -impl PartialOrd for Mask -where - T: MaskElement + PartialOrd, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) - } -} - -impl Eq for Mask -where - T: MaskElement + Eq, - LaneCount: SupportedLaneCount, -{ -} - -impl Ord for Mask -where - T: MaskElement + Ord, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn cmp(&self, other: &Self) -> core::cmp::Ordering { - self.0.cmp(&other.0) - } -} - -// Used for bitmask bit order workaround -pub(crate) trait ReverseBits { - // Reverse the least significant `n` bits of `self`. - // (Remaining bits must be 0.) - fn reverse_bits(self, n: usize) -> Self; -} - -macro_rules! impl_reverse_bits { - { $($int:ty),* } => { - $( - impl ReverseBits for $int { - #[inline(always)] - fn reverse_bits(self, n: usize) -> Self { - let rev = <$int>::reverse_bits(self); - let bitsize = size_of::<$int>() * 8; - if n < bitsize { - // Shift things back to the right - rev >> (bitsize - n) - } else { - rev - } - } - } - )* - } -} - -impl_reverse_bits! { u8, u16, u32, u64 } - -impl Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - #[rustc_const_unstable(feature = "portable_simd", issue = "86656")] - pub(crate) const fn splat(value: bool) -> Self { - Self(Simd::splat(if value { T::TRUE } else { T::FALSE })) - } - - #[inline] - #[must_use = "method returns a new bool and does not mutate the original value"] - pub(crate) unsafe fn test_unchecked(&self, lane: usize) -> bool { - T::eq(self.0[lane], T::TRUE) - } - - #[inline] - pub(crate) unsafe fn set_unchecked(&mut self, lane: usize, value: bool) { - self.0[lane] = if value { T::TRUE } else { T::FALSE } - } - - #[inline] - #[must_use = "method returns a new vector and does not mutate the original value"] - pub(crate) fn to_simd(self) -> Simd { - self.0 - } - - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub(crate) unsafe fn from_simd_unchecked(value: Simd) -> Self { - Self(value) - } - - #[inline] - #[must_use = "method returns a new mask and does not mutate the original value"] - pub(crate) fn convert(self) -> Mask - where - U: MaskElement, - { - // Safety: masks are simply integer vectors of 0 and -1, and we can cast the element type. - unsafe { Mask(core::intrinsics::simd::simd_cast(self.0)) } - } - - #[inline] - unsafe fn to_bitmask_impl(self) -> U - where - LaneCount: SupportedLaneCount, - { - let resized = self.to_simd().resize::(T::FALSE); - - // Safety: `resized` is an integer vector with length M, which must match T - let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized) }; - - // LLVM assumes bit order should match endianness - if cfg!(target_endian = "big") { - bitmask.reverse_bits(M) - } else { - bitmask - } - } - - #[inline] - unsafe fn from_bitmask_impl(bitmask: U) -> Self - where - LaneCount: SupportedLaneCount, - { - // LLVM assumes bit order should match endianness - let bitmask = if cfg!(target_endian = "big") { - bitmask.reverse_bits(M) - } else { - bitmask - }; - - // SAFETY: `mask` is the correct bitmask type for a u64 bitmask - let mask: Simd = unsafe { - core::intrinsics::simd::simd_select_bitmask( - bitmask, - Simd::::splat(T::TRUE), - Simd::::splat(T::FALSE), - ) - }; - - // SAFETY: `mask` only contains `T::TRUE` or `T::FALSE` - unsafe { Self::from_simd_unchecked(mask.resize::(T::FALSE)) } - } - - #[inline] - pub(crate) fn to_bitmask_integer(self) -> u64 { - // TODO modify simd_bitmask to zero-extend output, making this unnecessary - if N <= 8 { - // Safety: bitmask matches length - unsafe { self.to_bitmask_impl::() as u64 } - } else if N <= 16 { - // Safety: bitmask matches length - unsafe { self.to_bitmask_impl::() as u64 } - } else if N <= 32 { - // Safety: bitmask matches length - unsafe { self.to_bitmask_impl::() as u64 } - } else { - // Safety: bitmask matches length - unsafe { self.to_bitmask_impl::() } - } - } - - #[inline] - pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self { - // TODO modify simd_bitmask_select to truncate input, making this unnecessary - if N <= 8 { - // Safety: bitmask matches length - unsafe { Self::from_bitmask_impl::(bitmask as u8) } - } else if N <= 16 { - // Safety: bitmask matches length - unsafe { Self::from_bitmask_impl::(bitmask as u16) } - } else if N <= 32 { - // Safety: bitmask matches length - unsafe { Self::from_bitmask_impl::(bitmask as u32) } - } else { - // Safety: bitmask matches length - unsafe { Self::from_bitmask_impl::(bitmask) } - } - } - - #[inline] - #[must_use = "method returns a new bool and does not mutate the original value"] - pub(crate) fn any(self) -> bool { - // Safety: use `self` as an integer vector - unsafe { core::intrinsics::simd::simd_reduce_any(self.to_simd()) } - } - - #[inline] - #[must_use = "method returns a new bool and does not mutate the original value"] - pub(crate) fn all(self) -> bool { - // Safety: use `self` as an integer vector - unsafe { core::intrinsics::simd::simd_reduce_all(self.to_simd()) } - } -} - -impl From> for Simd -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - #[inline] - fn from(value: Mask) -> Self { - value.0 - } -} - -impl core::ops::BitAnd for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - type Output = Self; - #[inline] - fn bitand(self, rhs: Self) -> Self { - // Safety: `self` is an integer vector - unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) } - } -} - -impl core::ops::BitOr for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - type Output = Self; - #[inline] - fn bitor(self, rhs: Self) -> Self { - // Safety: `self` is an integer vector - unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) } - } -} - -impl core::ops::BitXor for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - type Output = Self; - #[inline] - fn bitxor(self, rhs: Self) -> Self { - // Safety: `self` is an integer vector - unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) } - } -} - -impl core::ops::Not for Mask -where - T: MaskElement, - LaneCount: SupportedLaneCount, -{ - type Output = Self; - #[inline] - fn not(self) -> Self::Output { - Self::splat(true) ^ self - } -} From e75a8d8b92b1a7b0a628a78b88bb8e0154f45be1 Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Wed, 10 Sep 2025 10:00:10 -0400 Subject: [PATCH 3/6] Fix clippy --- crates/core_simd/src/swizzle_dyn.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/crates/core_simd/src/swizzle_dyn.rs b/crates/core_simd/src/swizzle_dyn.rs index 016422b7060..46c570cf4f4 100644 --- a/crates/core_simd/src/swizzle_dyn.rs +++ b/crates/core_simd/src/swizzle_dyn.rs @@ -1,4 +1,4 @@ -use crate::simd::{LaneCount, Select, Simd, SupportedLaneCount}; +use crate::simd::{LaneCount, Simd, SupportedLaneCount}; use core::mem; impl Simd @@ -139,7 +139,7 @@ unsafe fn armv7_neon_swizzle_u8x16(bytes: Simd, idxs: Simd) -> S #[inline] #[allow(clippy::let_and_return)] unsafe fn avx2_pshufb(bytes: Simd, idxs: Simd) -> Simd { - use crate::simd::cmp::SimdPartialOrd; + use crate::simd::{cmp::SimdPartialOrd, Select}; #[cfg(target_arch = "x86")] use core::arch::x86; #[cfg(target_arch = "x86_64")] @@ -200,7 +200,7 @@ fn zeroing_idxs(idxs: Simd) -> Simd where LaneCount: SupportedLaneCount, { - use crate::simd::cmp::SimdPartialOrd; + use crate::simd::{cmp::SimdPartialOrd, Select}; idxs.simd_lt(Simd::splat(N as u8)) .select(idxs, Simd::splat(u8::MAX)) } From 728f375f207bf4c5db4e3f5e978259442472e56a Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Fri, 12 Sep 2025 00:35:18 -0400 Subject: [PATCH 4/6] Fix endianness correction --- crates/core_simd/src/masks.rs | 41 ++++++++++++++++++++++------- crates/core_simd/src/select.rs | 32 ++++++++++------------ crates/core_simd/src/swizzle_dyn.rs | 4 +-- 3 files changed, 47 insertions(+), 30 deletions(-) diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs index 1cdfe013ce3..eac2316c187 100644 --- a/crates/core_simd/src/masks.rs +++ b/crates/core_simd/src/masks.rs @@ -6,6 +6,29 @@ use crate::simd::{LaneCount, Select, Simd, SimdCast, SimdElement, SupportedLaneC use core::cmp::Ordering; use core::{fmt, mem}; +pub(crate) trait FixEndianness { + fn fix_endianness(self) -> Self; +} + +macro_rules! impl_fix_endianness { + { $($int:ty),* } => { + $( + impl FixEndianness for $int { + #[inline(always)] + fn fix_endianness(self) -> Self { + if cfg!(target_endian = "big") { + <$int>::reverse_bits(self) + } else { + self + } + } + } + )* + } +} + +impl_fix_endianness! { u8, u16, u32, u64 } + mod sealed { use super::*; @@ -283,7 +306,9 @@ where #[must_use = "method returns a new integer and does not mutate the original value"] pub fn to_bitmask(self) -> u64 { #[inline] - unsafe fn to_bitmask_impl(mask: Mask) -> U + unsafe fn to_bitmask_impl( + mask: Mask, + ) -> U where T: MaskElement, LaneCount: SupportedLaneCount, @@ -292,11 +317,14 @@ where let resized = mask.resize::(false); // Safety: `resized` is an integer vector with length M, which must match T - unsafe { core::intrinsics::simd::simd_bitmask(resized.0) } + let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized.0) }; + + // LLVM assumes bit order should match endianness + bitmask.fix_endianness() } // TODO modify simd_bitmask to zero-extend output, making this unnecessary - let bitmask = if N <= 8 { + if N <= 8 { // Safety: bitmask matches length unsafe { to_bitmask_impl::(self) as u64 } } else if N <= 16 { @@ -308,13 +336,6 @@ where } else { // Safety: bitmask matches length unsafe { to_bitmask_impl::(self) } - }; - - // LLVM assumes bit order should match endianness - if cfg!(target_endian = "big") { - bitmask.reverse_bits() >> (64 - N.min(64)) - } else { - bitmask } } diff --git a/crates/core_simd/src/select.rs b/crates/core_simd/src/select.rs index d3dc64c7f0f..5240b9b0c71 100644 --- a/crates/core_simd/src/select.rs +++ b/crates/core_simd/src/select.rs @@ -1,4 +1,6 @@ -use crate::simd::{LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount}; +use crate::simd::{ + FixEndianness, LaneCount, Mask, MaskElement, Simd, SimdElement, SupportedLaneCount, +}; /// Choose elements from two vectors using a mask. /// @@ -82,21 +84,8 @@ where assert!(N <= 64, "number of elements can't be greater than 64"); } - // LLVM assumes bit order should match endianness - let bitmask = if cfg!(target_endian = "big") { - let rev = self.reverse_bits(); - if N < 64 { - // Shift things back to the right - rev >> (64 - N) - } else { - rev - } - } else { - self - }; - #[inline] - unsafe fn select_impl( + unsafe fn select_impl( bitmask: U, true_values: Simd, false_values: Simd, @@ -110,6 +99,9 @@ where let true_values = true_values.resize::(default); let false_values = false_values.resize::(default); + // LLVM assumes bit order should match endianness + let bitmask = bitmask.fix_endianness(); + // Safety: the caller guarantees that the size of U matches M let selected = unsafe { core::intrinsics::simd::simd_select_bitmask(bitmask, true_values, false_values) @@ -120,15 +112,19 @@ where // TODO modify simd_bitmask_select to truncate input, making this unnecessary if N <= 8 { + let bitmask = self as u8; // Safety: bitmask matches length - unsafe { select_impl::(bitmask as u8, true_values, false_values) } + unsafe { select_impl::(bitmask, true_values, false_values) } } else if N <= 16 { + let bitmask = self as u16; // Safety: bitmask matches length - unsafe { select_impl::(bitmask as u16, true_values, false_values) } + unsafe { select_impl::(bitmask, true_values, false_values) } } else if N <= 32 { + let bitmask = self as u32; // Safety: bitmask matches length - unsafe { select_impl::(bitmask as u32, true_values, false_values) } + unsafe { select_impl::(bitmask, true_values, false_values) } } else { + let bitmask = self; // Safety: bitmask matches length unsafe { select_impl::(bitmask, true_values, false_values) } } diff --git a/crates/core_simd/src/swizzle_dyn.rs b/crates/core_simd/src/swizzle_dyn.rs index 46c570cf4f4..73b18595d0a 100644 --- a/crates/core_simd/src/swizzle_dyn.rs +++ b/crates/core_simd/src/swizzle_dyn.rs @@ -139,7 +139,7 @@ unsafe fn armv7_neon_swizzle_u8x16(bytes: Simd, idxs: Simd) -> S #[inline] #[allow(clippy::let_and_return)] unsafe fn avx2_pshufb(bytes: Simd, idxs: Simd) -> Simd { - use crate::simd::{cmp::SimdPartialOrd, Select}; + use crate::simd::{Select, cmp::SimdPartialOrd}; #[cfg(target_arch = "x86")] use core::arch::x86; #[cfg(target_arch = "x86_64")] @@ -200,7 +200,7 @@ fn zeroing_idxs(idxs: Simd) -> Simd where LaneCount: SupportedLaneCount, { - use crate::simd::{cmp::SimdPartialOrd, Select}; + use crate::simd::{Select, cmp::SimdPartialOrd}; idxs.simd_lt(Simd::splat(N as u8)) .select(idxs, Simd::splat(u8::MAX)) } From 45a18ae52ce0dfc1980d277e8ada620cb28099d0 Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Fri, 12 Sep 2025 00:37:21 -0400 Subject: [PATCH 5/6] Revert layout guarantee --- crates/core_simd/src/masks.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs index eac2316c187..87494797832 100644 --- a/crates/core_simd/src/masks.rs +++ b/crates/core_simd/src/masks.rs @@ -118,8 +118,9 @@ impl_element! { isize, usize } /// /// Masks represent boolean inclusion/exclusion on a per-element basis. /// -/// The layout of this type is equivalent to `Simd`, but elements -/// are guaranteed to be either 0 or -1. +/// The layout of this type is unspecified, and may change between platforms +/// and/or Rust versions, and code should not assume that it is equivalent to +/// `[T; N]`. #[repr(transparent)] pub struct Mask(Simd) where From cb429ef62cfa5226243658f16e4ef9fc40ec7780 Mon Sep 17 00:00:00 2001 From: Caleb Zulawski Date: Fri, 12 Sep 2025 20:50:02 -0400 Subject: [PATCH 6/6] Add a const error for to_bitmask > 64 elements --- crates/core_simd/src/masks.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/crates/core_simd/src/masks.rs b/crates/core_simd/src/masks.rs index 87494797832..7baa9647591 100644 --- a/crates/core_simd/src/masks.rs +++ b/crates/core_simd/src/masks.rs @@ -302,10 +302,13 @@ where /// Creates a bitmask from a mask. /// /// Each bit is set if the corresponding element in the mask is `true`. - /// If the mask contains more than 64 elements, the bitmask is truncated to the first 64. #[inline] #[must_use = "method returns a new integer and does not mutate the original value"] pub fn to_bitmask(self) -> u64 { + const { + assert!(N <= 64, "number of elements can't be greater than 64"); + } + #[inline] unsafe fn to_bitmask_impl( mask: Mask,