Add a special case for align_offset /w stride != 1

This generalizes the previous `stride == 1` special case to apply to any
situation where the requested alignment is divisible by the stride. This
in turn allows the test case from #98809 produce ideal assembly, along
the lines of:

    leaq 15(%rdi), %rax
    andq $-16, %rax

This also produces pretty high quality code for situations where the
alignment of the input pointer isn’t known:

    pub unsafe fn ptr_u32(slice: *const u32) -> *const u32 {
        slice.offset(slice.align_offset(16) as isize)
    }

    // =>

    movl %edi, %eax
    andl $3, %eax
    leaq 15(%rdi), %rcx
    andq $-16, %rcx
    subq %rdi, %rcx
    shrq $2, %rcx
    negq %rax
    sbbq %rax, %rax
    orq  %rcx, %rax
    leaq (%rdi,%rax,4), %rax

Here LLVM is smart enough to replace the `usize::MAX` special case with
a branch-less bitwise-OR approach, where the mask is constructed using
the neg and sbb instructions. This appears to work across various
architectures I’ve tried.

This change ends up introducing more branches and code in situations
where there is less knowledge of the arguments. For example when the
requested alignment is entirely unknown. This use-case was never really
a focus of this function, so I’m not particularly worried, especially
since llvm-mca is saying that the new code is still appreciably faster,
despite all the new branching.

Fixes #98809.
Sadly, this does not help with #72356.
This commit is contained in:
Simonas Kazlauskas 2022-07-04 00:23:31 +03:00
parent 6a10920564
commit 62a182cf7f
3 changed files with 138 additions and 58 deletions

View File

@ -1594,11 +1594,10 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
// FIXME(#75598): Direct use of these intrinsics improves codegen significantly at opt-level <=
// 1, where the method versions of these operations are not inlined.
use intrinsics::{
unchecked_shl, unchecked_shr, unchecked_sub, wrapping_add, wrapping_mul, wrapping_sub,
cttz_nonzero, exact_div, unchecked_rem, unchecked_shl, unchecked_shr, unchecked_sub,
wrapping_add, wrapping_mul, wrapping_sub,
};
let addr = p.addr();
/// Calculate multiplicative modular inverse of `x` modulo `m`.
///
/// This implementation is tailored for `align_offset` and has following preconditions:
@ -1648,36 +1647,61 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
}
}
let addr = p.addr();
let stride = mem::size_of::<T>();
// SAFETY: `a` is a power-of-two, therefore non-zero.
let a_minus_one = unsafe { unchecked_sub(a, 1) };
if stride == 1 {
// `stride == 1` case can be computed more simply through `-p (mod a)`, but doing so
// inhibits LLVM's ability to select instructions like `lea`. Instead we compute
//
// round_up_to_next_alignment(p, a) - p
//
// which distributes operations around the load-bearing, but pessimizing `and` sufficiently
// for LLVM to be able to utilize the various optimizations it knows about.
return wrapping_sub(wrapping_add(addr, a_minus_one) & wrapping_sub(0, a), addr);
if stride == 0 {
// SPECIAL_CASE: handle 0-sized types. No matter how many times we step, the address will
// stay the same, so no offset will be able to align the pointer unless it is already
// aligned. This branch _will_ be optimized out as `stride` is known at compile-time.
let p_mod_a = addr & a_minus_one;
return if p_mod_a == 0 { 0 } else { usize::MAX };
}
let pmoda = addr & a_minus_one;
if pmoda == 0 {
// Already aligned. Yay!
return 0;
} else if stride == 0 {
// If the pointer is not aligned, and the element is zero-sized, then no amount of
// elements will ever align the pointer.
return usize::MAX;
// SAFETY: `stride == 0` case has been handled by the special case above.
let a_mod_stride = unsafe { unchecked_rem(a, stride) };
if a_mod_stride == 0 {
// SPECIAL_CASE: In cases where the `a` is divisible by `stride`, byte offset to align a
// pointer can be computed more simply through `-p (mod a)`. In the off-chance the byte
// offset is not a multiple of `stride`, the input pointer was misaligned and no pointer
// offset will be able to produce a `p` aligned to the specified `a`.
//
// The naive `-p (mod a)` equation inhibits LLVM's ability to select instructions
// like `lea`. We compute `(round_up_to_next_alignment(p, a) - p)` instead. This
// redistributes operations around the load-bearing, but pessimizing `and` instruction
// sufficiently for LLVM to be able to utilize the various optimizations it knows about.
//
// LLVM handles the branch here particularly nicely. If this branch needs to be evaluated
// at runtime, it will produce a mask `if addr_mod_stride == 0 { 0 } else { usize::MAX }`
// in a branch-free way and then bitwise-OR it with whatever result the `-p mod a`
// computation produces.
// SAFETY: `stride == 0` case has been handled by the special case above.
let addr_mod_stride = unsafe { unchecked_rem(addr, stride) };
return if addr_mod_stride == 0 {
let aligned_address = wrapping_add(addr, a_minus_one) & wrapping_sub(0, a);
let byte_offset = wrapping_sub(aligned_address, addr);
// SAFETY: `stride` is non-zero. This is guaranteed to divide exactly as well, because
// addr has been verified to be aligned to the original types alignment requirements.
unsafe { exact_div(byte_offset, stride) }
} else {
usize::MAX
};
}
let smoda = stride & a_minus_one;
// GENERAL_CASE: From here on were handling the very general case where `addr` may be
// misaligned, there isnt an obvious relationship between `stride` and `a` that we can take an
// advantage of, etc. This case produces machine code that isnt particularly high quality,
// compared to the special cases above. The code produced here is still within the realm of
// miracles, given the situations this case has to deal with.
// SAFETY: a is power-of-two hence non-zero. stride == 0 case is handled above.
let gcdpow = unsafe { intrinsics::cttz_nonzero(stride).min(intrinsics::cttz_nonzero(a)) };
let gcdpow = unsafe { cttz_nonzero(stride).min(cttz_nonzero(a)) };
// SAFETY: gcdpow has an upper-bound thats at most the number of bits in a usize.
let gcd = unsafe { unchecked_shl(1usize, gcdpow) };
// SAFETY: gcd is always greater or equal to 1.
if addr & unsafe { unchecked_sub(gcd, 1) } == 0 {
// This branch solves for the following linear congruence equation:
@ -1693,14 +1717,13 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
// ` p' + s'o = 0 mod a' `
// ` o = (a' - (p' mod a')) * (s'^-1 mod a') `
//
// The first term is "the relative alignment of `p` to `a`" (divided by the `g`), the second
// term is "how does incrementing `p` by `s` bytes change the relative alignment of `p`" (again
// divided by `g`).
// Division by `g` is necessary to make the inverse well formed if `a` and `s` are not
// co-prime.
// The first term is "the relative alignment of `p` to `a`" (divided by the `g`), the
// second term is "how does incrementing `p` by `s` bytes change the relative alignment of
// `p`" (again divided by `g`). Division by `g` is necessary to make the inverse well
// formed if `a` and `s` are not co-prime.
//
// Furthermore, the result produced by this solution is not "minimal", so it is necessary
// to take the result `o mod lcm(s, a)`. We can replace `lcm(s, a)` with just a `a'`.
// to take the result `o mod lcm(s, a)`. This `lcm(s, a)` is the same as `a'`.
// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
// `a`.
@ -1710,11 +1733,11 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
let a2minus1 = unsafe { unchecked_sub(a2, 1) };
// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
// `a`.
let s2 = unsafe { unchecked_shr(smoda, gcdpow) };
let s2 = unsafe { unchecked_shr(stride & a_minus_one, gcdpow) };
// SAFETY: `gcdpow` has an upper-bound not greater than the number of trailing 0-bits in
// `a`. Furthermore, the subtraction cannot overflow, because `a2 = a >> gcdpow` will
// always be strictly greater than `(p % a) >> gcdpow`.
let minusp2 = unsafe { unchecked_sub(a2, unchecked_shr(pmoda, gcdpow)) };
let minusp2 = unsafe { unchecked_sub(a2, unchecked_shr(addr & a_minus_one, gcdpow)) };
// SAFETY: `a2` is a power-of-two, as proven above. `s2` is strictly less than `a2`
// because `(s % a) >> gcdpow` is strictly less than `a >> gcdpow`.
return wrapping_mul(minusp2, unsafe { mod_inv(s2, a2) }) & a2minus1;

View File

@ -362,7 +362,7 @@ fn align_offset_zst() {
}
#[test]
fn align_offset_stride1() {
fn align_offset_stride_one() {
// For pointers of stride = 1, the pointer can always be aligned. The offset is equal to
// number of bytes.
let mut align = 1;
@ -383,24 +383,8 @@ fn align_offset_stride1() {
}
#[test]
fn align_offset_weird_strides() {
#[repr(packed)]
struct A3(u16, u8);
struct A4(u32);
#[repr(packed)]
struct A5(u32, u8);
#[repr(packed)]
struct A6(u32, u16);
#[repr(packed)]
struct A7(u32, u16, u8);
#[repr(packed)]
struct A8(u32, u32);
#[repr(packed)]
struct A9(u32, u32, u8);
#[repr(packed)]
struct A10(u32, u32, u16);
unsafe fn test_weird_stride<T>(ptr: *const T, align: usize) -> bool {
fn align_offset_various_strides() {
unsafe fn test_stride<T>(ptr: *const T, align: usize) -> bool {
let numptr = ptr as usize;
let mut expected = usize::MAX;
// Naive but definitely correct way to find the *first* aligned element of stride::<T>.
@ -434,14 +418,39 @@ fn align_offset_weird_strides() {
while align < limit {
for ptr in 1usize..4 * align {
unsafe {
x |= test_weird_stride::<A3>(ptr::invalid::<A3>(ptr), align);
x |= test_weird_stride::<A4>(ptr::invalid::<A4>(ptr), align);
x |= test_weird_stride::<A5>(ptr::invalid::<A5>(ptr), align);
x |= test_weird_stride::<A6>(ptr::invalid::<A6>(ptr), align);
x |= test_weird_stride::<A7>(ptr::invalid::<A7>(ptr), align);
x |= test_weird_stride::<A8>(ptr::invalid::<A8>(ptr), align);
x |= test_weird_stride::<A9>(ptr::invalid::<A9>(ptr), align);
x |= test_weird_stride::<A10>(ptr::invalid::<A10>(ptr), align);
#[repr(packed)]
struct A3(u16, u8);
x |= test_stride::<A3>(ptr::invalid::<A3>(ptr), align);
struct A4(u32);
x |= test_stride::<A4>(ptr::invalid::<A4>(ptr), align);
#[repr(packed)]
struct A5(u32, u8);
x |= test_stride::<A5>(ptr::invalid::<A5>(ptr), align);
#[repr(packed)]
struct A6(u32, u16);
x |= test_stride::<A6>(ptr::invalid::<A6>(ptr), align);
#[repr(packed)]
struct A7(u32, u16, u8);
x |= test_stride::<A7>(ptr::invalid::<A7>(ptr), align);
#[repr(packed)]
struct A8(u32, u32);
x |= test_stride::<A8>(ptr::invalid::<A8>(ptr), align);
#[repr(packed)]
struct A9(u32, u32, u8);
x |= test_stride::<A9>(ptr::invalid::<A9>(ptr), align);
#[repr(packed)]
struct A10(u32, u32, u16);
x |= test_stride::<A10>(ptr::invalid::<A10>(ptr), align);
x |= test_stride::<u32>(ptr::invalid::<u32>(ptr), align);
x |= test_stride::<u128>(ptr::invalid::<u128>(ptr), align);
}
}
align = (align + 1).next_power_of_two();

View File

@ -0,0 +1,48 @@
// assembly-output: emit-asm
// compile-flags: -Copt-level=1
// only-x86_64
// min-llvm-version: 14.0
#![crate_type="rlib"]
// CHECK-LABEL: align_offset_byte_ptr
// CHECK: leaq 31
// CHECK: andq $-32
// CHECK: subq
#[no_mangle]
pub fn align_offset_byte_ptr(ptr: *const u8) -> usize {
ptr.align_offset(32)
}
// CHECK-LABEL: align_offset_byte_slice
// CHECK: leaq 31
// CHECK: andq $-32
// CHECK: subq
#[no_mangle]
pub fn align_offset_byte_slice(slice: &[u8]) -> usize {
slice.as_ptr().align_offset(32)
}
// CHECK-LABEL: align_offset_word_ptr
// CHECK: leaq 31
// CHECK: andq $-32
// CHECK: subq
// CHECK: shrq
// This `ptr` is not known to be aligned, so it is required to check if it is at all possible to
// align. LLVM applies a simple mask.
// CHECK: orq
#[no_mangle]
pub fn align_offset_word_ptr(ptr: *const u32) -> usize {
ptr.align_offset(32)
}
// CHECK-LABEL: align_offset_word_slice
// CHECK: leaq 31
// CHECK: andq $-32
// CHECK: subq
// CHECK: shrq
// `slice` is known to be aligned, so `!0` is not possible as a return
// CHECK-NOT: orq
#[no_mangle]
pub fn align_offset_word_slice(slice: &[u32]) -> usize {
slice.as_ptr().align_offset(32)
}