2
2
//! Types representing
3
3
#![ allow( non_camel_case_types) ]
4
4
5
- #[ path = "masks/full_masks.rs" ]
6
- mod mask_impl;
7
-
8
5
use crate :: simd:: { LaneCount , Select , Simd , SimdCast , SimdElement , SupportedLaneCount } ;
9
6
use core:: cmp:: Ordering ;
10
7
use core:: { fmt, mem} ;
@@ -101,7 +98,7 @@ impl_element! { isize, usize }
101
98
/// The layout of this type is equivalent to `Simd<T, N>`, but elements
102
99
/// are guaranteed to be either 0 or -1.
103
100
#[ repr( transparent) ]
104
- pub struct Mask < T , const N : usize > ( mask_impl :: Mask < T , N > )
101
+ pub struct Mask < T , const N : usize > ( Simd < T , N > )
105
102
where
106
103
T : MaskElement ,
107
104
LaneCount < N > : SupportedLaneCount ;
@@ -133,7 +130,7 @@ where
133
130
#[ inline]
134
131
#[ rustc_const_unstable( feature = "portable_simd" , issue = "86656" ) ]
135
132
pub const fn splat ( value : bool ) -> Self {
136
- Self ( mask_impl :: Mask :: splat ( value) )
133
+ Self ( Simd :: splat ( if value { T :: TRUE } else { T :: FALSE } ) )
137
134
}
138
135
139
136
/// Converts an array of bools to a SIMD mask.
@@ -184,8 +181,8 @@ where
184
181
// Safety: the caller must confirm this invariant
185
182
unsafe {
186
183
core:: intrinsics:: assume ( <T as Sealed >:: valid ( value) ) ;
187
- Self ( mask_impl:: Mask :: from_simd_unchecked ( value) )
188
184
}
185
+ Self ( value)
189
186
}
190
187
191
188
/// Converts a vector of integers to a mask, where 0 represents `false` and -1
@@ -207,14 +204,15 @@ where
207
204
#[ inline]
208
205
#[ must_use = "method returns a new vector and does not mutate the original value" ]
209
206
pub fn to_simd ( self ) -> Simd < T , N > {
210
- self . 0 . to_simd ( )
207
+ self . 0
211
208
}
212
209
213
210
/// Converts the mask to a mask of any other element size.
214
211
#[ inline]
215
212
#[ must_use = "method returns a new mask and does not mutate the original value" ]
216
213
pub fn cast < U : MaskElement > ( self ) -> Mask < U , N > {
217
- Mask ( self . 0 . convert ( ) )
214
+ // Safety: mask elements are integers
215
+ unsafe { Mask ( core:: intrinsics:: simd:: simd_as ( self . 0 ) ) }
218
216
}
219
217
220
218
/// Tests the value of the specified element.
@@ -225,7 +223,7 @@ where
225
223
#[ must_use = "method returns a new bool and does not mutate the original value" ]
226
224
pub unsafe fn test_unchecked ( & self , index : usize ) -> bool {
227
225
// Safety: the caller must confirm this invariant
228
- unsafe { self . 0 . test_unchecked ( index) }
226
+ unsafe { T :: eq ( * self . 0 . as_array ( ) . get_unchecked ( index) , T :: TRUE ) }
229
227
}
230
228
231
229
/// Tests the value of the specified element.
@@ -236,9 +234,7 @@ where
236
234
#[ must_use = "method returns a new bool and does not mutate the original value" ]
237
235
#[ track_caller]
238
236
pub fn test ( & self , index : usize ) -> bool {
239
- assert ! ( index < N , "element index out of range" ) ;
240
- // Safety: the element index has been checked
241
- unsafe { self . test_unchecked ( index) }
237
+ T :: eq ( self . 0 [ index] , T :: TRUE )
242
238
}
243
239
244
240
/// Sets the value of the specified element.
@@ -249,7 +245,7 @@ where
249
245
pub unsafe fn set_unchecked ( & mut self , index : usize , value : bool ) {
250
246
// Safety: the caller must confirm this invariant
251
247
unsafe {
252
- self . 0 . set_unchecked ( index, value) ;
248
+ * self . 0 . as_mut_array ( ) . get_unchecked_mut ( index) = if value { T :: TRUE } else { T :: FALSE }
253
249
}
254
250
}
255
251
@@ -260,25 +256,23 @@ where
260
256
#[ inline]
261
257
#[ track_caller]
262
258
pub fn set ( & mut self , index : usize , value : bool ) {
263
- assert ! ( index < N , "element index out of range" ) ;
264
- // Safety: the element index has been checked
265
- unsafe {
266
- self . set_unchecked ( index, value) ;
267
- }
259
+ self . 0 [ index] = if value { T :: TRUE } else { T :: FALSE }
268
260
}
269
261
270
262
/// Returns true if any element is set, or false otherwise.
271
263
#[ inline]
272
264
#[ must_use = "method returns a new bool and does not mutate the original value" ]
273
265
pub fn any ( self ) -> bool {
274
- self . 0 . any ( )
266
+ // Safety: `self` is a mask vector
267
+ unsafe { core:: intrinsics:: simd:: simd_reduce_any ( self . 0 ) }
275
268
}
276
269
277
270
/// Returns true if all elements are set, or false otherwise.
278
271
#[ inline]
279
272
#[ must_use = "method returns a new bool and does not mutate the original value" ]
280
273
pub fn all ( self ) -> bool {
281
- self . 0 . all ( )
274
+ // Safety: `self` is a mask vector
275
+ unsafe { core:: intrinsics:: simd:: simd_reduce_all ( self . 0 ) }
282
276
}
283
277
284
278
/// Creates a bitmask from a mask.
@@ -288,7 +282,40 @@ where
288
282
#[ inline]
289
283
#[ must_use = "method returns a new integer and does not mutate the original value" ]
290
284
pub fn to_bitmask ( self ) -> u64 {
291
- self . 0 . to_bitmask_integer ( )
285
+ #[ inline]
286
+ unsafe fn to_bitmask_impl < T , U , const M : usize , const N : usize > ( mask : Mask < T , N > ) -> U
287
+ where
288
+ T : MaskElement ,
289
+ LaneCount < M > : SupportedLaneCount ,
290
+ LaneCount < N > : SupportedLaneCount ,
291
+ {
292
+ let resized = mask. resize :: < M > ( false ) ;
293
+
294
+ // Safety: `resized` is an integer vector with length M, which must match T
295
+ unsafe { core:: intrinsics:: simd:: simd_bitmask ( resized. 0 ) }
296
+ }
297
+
298
+ // TODO modify simd_bitmask to zero-extend output, making this unnecessary
299
+ let bitmask = if N <= 8 {
300
+ // Safety: bitmask matches length
301
+ unsafe { to_bitmask_impl :: < T , u8 , 8 , N > ( self ) as u64 }
302
+ } else if N <= 16 {
303
+ // Safety: bitmask matches length
304
+ unsafe { to_bitmask_impl :: < T , u16 , 16 , N > ( self ) as u64 }
305
+ } else if N <= 32 {
306
+ // Safety: bitmask matches length
307
+ unsafe { to_bitmask_impl :: < T , u32 , 32 , N > ( self ) as u64 }
308
+ } else {
309
+ // Safety: bitmask matches length
310
+ unsafe { to_bitmask_impl :: < T , u64 , 64 , N > ( self ) }
311
+ } ;
312
+
313
+ // LLVM assumes bit order should match endianness
314
+ if cfg ! ( target_endian = "big" ) {
315
+ bitmask. reverse_bits ( ) >> ( 64 - N . min ( 64 ) )
316
+ } else {
317
+ bitmask
318
+ }
292
319
}
293
320
294
321
/// Creates a mask from a bitmask.
@@ -298,7 +325,7 @@ where
298
325
#[ inline]
299
326
#[ must_use = "method returns a new mask and does not mutate the original value" ]
300
327
pub fn from_bitmask ( bitmask : u64 ) -> Self {
301
- Self ( mask_impl :: Mask :: from_bitmask_integer ( bitmask ) )
328
+ Self ( bitmask . select ( Simd :: splat ( T :: TRUE ) , Simd :: splat ( T :: FALSE ) ) )
302
329
}
303
330
304
331
/// Finds the index of the first set element.
@@ -442,7 +469,8 @@ where
442
469
type Output = Self ;
443
470
#[ inline]
444
471
fn bitand ( self , rhs : Self ) -> Self {
445
- Self ( self . 0 & rhs. 0 )
472
+ // Safety: `self` is an integer vector
473
+ unsafe { Self ( core:: intrinsics:: simd:: simd_and ( self . 0 , rhs. 0 ) ) }
446
474
}
447
475
}
448
476
@@ -478,7 +506,8 @@ where
478
506
type Output = Self ;
479
507
#[ inline]
480
508
fn bitor ( self , rhs : Self ) -> Self {
481
- Self ( self . 0 | rhs. 0 )
509
+ // Safety: `self` is an integer vector
510
+ unsafe { Self ( core:: intrinsics:: simd:: simd_or ( self . 0 , rhs. 0 ) ) }
482
511
}
483
512
}
484
513
@@ -514,7 +543,8 @@ where
514
543
type Output = Self ;
515
544
#[ inline]
516
545
fn bitxor ( self , rhs : Self ) -> Self :: Output {
517
- Self ( self . 0 ^ rhs. 0 )
546
+ // Safety: `self` is an integer vector
547
+ unsafe { Self ( core:: intrinsics:: simd:: simd_xor ( self . 0 , rhs. 0 ) ) }
518
548
}
519
549
}
520
550
@@ -550,7 +580,7 @@ where
550
580
type Output = Mask < T , N > ;
551
581
#[ inline]
552
582
fn not ( self ) -> Self :: Output {
553
- Self ( ! self . 0 )
583
+ Self :: splat ( true ) ^ self
554
584
}
555
585
}
556
586
@@ -561,7 +591,7 @@ where
561
591
{
562
592
#[ inline]
563
593
fn bitand_assign ( & mut self , rhs : Self ) {
564
- self . 0 = self . 0 & rhs. 0 ;
594
+ * self = * self & rhs;
565
595
}
566
596
}
567
597
@@ -583,7 +613,7 @@ where
583
613
{
584
614
#[ inline]
585
615
fn bitor_assign ( & mut self , rhs : Self ) {
586
- self . 0 = self . 0 | rhs. 0 ;
616
+ * self = * self | rhs;
587
617
}
588
618
}
589
619
@@ -605,7 +635,7 @@ where
605
635
{
606
636
#[ inline]
607
637
fn bitxor_assign ( & mut self , rhs : Self ) {
608
- self . 0 = self . 0 ^ rhs. 0 ;
638
+ * self = * self ^ rhs;
609
639
}
610
640
}
611
641
0 commit comments