From 11db9de7289d0405f40cfd3cb0f9acda7e5b9564 Mon Sep 17 00:00:00 2001 From: Ralf Jung Date: Sun, 3 Dec 2023 11:28:40 +0100 Subject: [PATCH] simd_select_bitmask: support passing the mask as an array --- src/tools/miri/src/shims/intrinsics/simd.rs | 15 ++++++++++++--- src/tools/miri/tests/pass/portable-simd.rs | 16 ++++++++++++++++ 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/tools/miri/src/shims/intrinsics/simd.rs b/src/tools/miri/src/shims/intrinsics/simd.rs index d0a293d5f81..63af0814c8b 100644 --- a/src/tools/miri/src/shims/intrinsics/simd.rs +++ b/src/tools/miri/src/shims/intrinsics/simd.rs @@ -386,7 +386,6 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { let (dest, dest_len) = this.place_to_simd(dest)?; let bitmask_len = dest_len.max(8); - assert!(mask.layout.ty.is_integral()); assert!(bitmask_len <= 64); assert_eq!(bitmask_len, mask.layout.size.bits()); assert_eq!(dest_len, yes_len); @@ -394,8 +393,18 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> { let dest_len = u32::try_from(dest_len).unwrap(); let bitmask_len = u32::try_from(bitmask_len).unwrap(); - let mask: u64 = - this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap(); + // The mask can be a single integer or an array. + let mask: u64 = match mask.layout.ty.kind() { + ty::Int(..) | ty::Uint(..) => + this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap(), + ty::Array(elem, _) if matches!(elem.kind(), ty::Uint(ty::UintTy::U8)) => { + let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap(); + let mask = mask.transmute(mask_ty, this)?; + this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap() + } + _ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty), + }; + for i in 0..dest_len { let mask = mask & 1u64 diff --git a/src/tools/miri/tests/pass/portable-simd.rs b/src/tools/miri/tests/pass/portable-simd.rs index 2179bcf1c38..1ef9d8f38c0 100644 --- a/src/tools/miri/tests/pass/portable-simd.rs +++ b/src/tools/miri/tests/pass/portable-simd.rs @@ -247,6 +247,22 @@ fn simd_mask() { assert_eq!(bitmask2, [0b0001]); } } + + // This used to cause an ICE. + let bitmask = u8x8::from_array([0b01000101, 0, 0, 0, 0, 0, 0, 0]); + assert_eq!( + mask32x8::from_bitmask_vector(bitmask), + mask32x8::from_array([true, false, true, false, false, false, true, false]), + ); + let bitmask = + u8x16::from_array([0b01000101, 0b11110000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + assert_eq!( + mask32x16::from_bitmask_vector(bitmask), + mask32x16::from_array([ + true, false, true, false, false, false, true, false, false, false, false, false, true, + true, true, true, + ]), + ); } fn simd_cast() {