diff --git a/crates/core_simd/src/swizzle_dyn.rs b/crates/core_simd/src/swizzle_dyn.rs
index 773bd028bae..b2db63f930f 100644
--- a/crates/core_simd/src/swizzle_dyn.rs
+++ b/crates/core_simd/src/swizzle_dyn.rs
@@ -80,6 +80,8 @@ where
                 };
                 transize(swizzler, self, idxs)
             }
+            #[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi")))]
+            64 => transize(avx2_pshufb512, self, idxs),
             // Notable absence: avx512bw pshufb shuffle
             #[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))]
             64 => {
@@ -171,6 +173,67 @@ unsafe fn avx2_pshufb(bytes: Simd<u8, 32>, idxs: Simd<u8, 32>) -> Simd<u8, 32> {
     }
 }
 
+/// The above function but for 64 bytes
+///
+/// # Safety
+/// This requires AVX2 to work
+#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+#[target_feature(enable = "avx2")]
+#[allow(unused)]
+#[inline]
+#[allow(clippy::let_and_return)]
+unsafe fn avx2_pshufb512(bytes: Simd<u8, 64>, idxs: Simd<u8, 64>) -> Simd<u8, 64> {
+    use crate::simd::cmp::SimdPartialOrd;
+    #[cfg(target_arch = "x86")]
+    use core::arch::x86;
+    #[cfg(target_arch = "x86_64")]
+    use core::arch::x86_64 as x86;
+    use x86::_mm256_blendv_epi8 as avx2_blend;
+    use x86::_mm256_permute2x128_si256 as avx2_cross_shuffle;
+    use x86::_mm256_shuffle_epi8 as avx2_half_pshufb;
+    let high = Simd::splat(64u8);
+    // SAFETY: Caller promised AVX2
+    unsafe {
+        // Computes one 32-byte half of the result from all 64 input bytes.
+        // `_mm256_blendv_epi8` selects on bit 7 of each mask byte, so shifting
+        // an index left moves the relevant index bit into the sign bit:
+        // bit 5 (idxs << 2) picks bytes0 vs bytes1, and bit 4 (idxs << 3)
+        // picks the low vs high 128-bit lane of the chosen 32-byte source.
+        let half_swizzler = |bytes0: Simd<u8, 32>, bytes1: Simd<u8, 32>, idxs: Simd<u8, 32>| {
+            let mask0 = idxs << 2;
+            let mask1 = idxs << 3;
+
+            let lolo0 = avx2_cross_shuffle::<0x00>(bytes0.into(), bytes0.into());
+            let hihi0 = avx2_cross_shuffle::<0x11>(bytes0.into(), bytes0.into());
+            let lolo0 = avx2_half_pshufb(lolo0, idxs.into());
+            let hihi0 = avx2_half_pshufb(hihi0, idxs.into());
+            let x = avx2_blend(lolo0, hihi0, mask1.into());
+
+            let lolo1 = avx2_cross_shuffle::<0x00>(bytes1.into(), bytes1.into());
+            let hihi1 = avx2_cross_shuffle::<0x11>(bytes1.into(), bytes1.into());
+            let lolo1 = avx2_half_pshufb(lolo1, idxs.into());
+            let hihi1 = avx2_half_pshufb(hihi1, idxs.into());
+            let y = avx2_blend(lolo1, hihi1, mask1.into());
+
+            avx2_blend(x, y, mask0.into())
+        };
+
+        let bytes0 = bytes.extract::<0, 32>();
+        let bytes1 = bytes.extract::<32, 32>();
+        let idxs0 = idxs.extract::<0, 32>();
+        let idxs1 = idxs.extract::<32, 32>();
+
+        let z0 = half_swizzler(bytes0, bytes1, idxs0);
+        let z1 = half_swizzler(bytes0, bytes1, idxs1);
+
+        // SAFETY: Concatenation of two 32-element vectors to one 64-element vector
+        let z = mem::transmute::<[Simd<u8, 32>; 2], Simd<u8, 64>>([z0.into(), z1.into()]);
+
+        // Indices >= 64 are out of range and must produce zero, matching
+        // the zeroing behavior of the other swizzle_dyn backends.
+        idxs.simd_lt(high).select(z, Simd::splat(0u8))
+    }
+}
+
 /// This sets up a call to an architecture-specific function, and in doing so
 /// it persuades rustc that everything is the correct size. Which it is.
 /// This would not be needed if one could convince Rust that, by matching on N,