From a07eb0abbd24e21716d476d7c363270c907f5d2e Mon Sep 17 00:00:00 2001
From: Erik Desjardins <erikdesjardins@users.noreply.github.com>
Date: Mon, 15 May 2023 00:49:12 -0400
Subject: [PATCH] implement vector-containing aggregate alignment for x86
 darwin

---
 compiler/rustc_target/src/abi/call/x86.rs | 82 ++++++++++++++---------
 tests/codegen/align-byval-vector.rs       | 58 ++++++++++++++++
 2 files changed, 109 insertions(+), 31 deletions(-)
 create mode 100644 tests/codegen/align-byval-vector.rs

diff --git a/compiler/rustc_target/src/abi/call/x86.rs b/compiler/rustc_target/src/abi/call/x86.rs
index 58c0717b7d1..d2c604fafa6 100644
--- a/compiler/rustc_target/src/abi/call/x86.rs
+++ b/compiler/rustc_target/src/abi/call/x86.rs
@@ -1,5 +1,5 @@
 use crate::abi::call::{ArgAttribute, FnAbi, PassMode, Reg, RegKind};
-use crate::abi::{Align, HasDataLayout, TyAbiInterface};
+use crate::abi::{Abi, Align, HasDataLayout, TyAbiInterface, TyAndLayout};
 use crate::spec::HasTargetSpec;
 
 #[derive(PartialEq)]
@@ -53,38 +53,58 @@ where
         if arg.is_ignore() {
             continue;
         }
-        if !arg.layout.is_aggregate() {
-            arg.extend_integer_width_to(32);
-            continue;
-        }
 
-        // We need to compute the alignment of the `byval` argument. The rules can be found in
-        // `X86_32ABIInfo::getTypeStackAlignInBytes` in Clang's `TargetInfo.cpp`. Summarized here,
-        // they are:
-        //
-        // 1. If the natural alignment of the type is less than or equal to 4, the alignment is 4.
-        //
-        // 2. Otherwise, on Linux, the alignment of any vector type is the natural alignment.
-        // (This doesn't matter here because we ensure we have an aggregate with the check above.)
-        //
-        // 3. Otherwise, on Apple platforms, the alignment of anything that contains a vector type
-        // is 16.
-        //
-        // 4. If none of these conditions are true, the alignment is 4.
-        let t = cx.target_spec();
-        let align_4 = Align::from_bytes(4).unwrap();
-        let align_16 = Align::from_bytes(16).unwrap();
-        let byval_align = if arg.layout.align.abi < align_4 {
-            align_4
-        } else if t.is_like_osx && arg.layout.align.abi >= align_16 {
-            // FIXME(pcwalton): This is dubious--we should actually be looking inside the type to
-            // determine if it contains SIMD vector values--but I think it's fine?
-            align_16
+        if arg.layout.is_aggregate() {
+            // We need to compute the alignment of the `byval` argument. The rules can be found in
+            // `X86_32ABIInfo::getTypeStackAlignInBytes` in Clang's `TargetInfo.cpp`. Summarized
+            // here, they are:
+            //
+            // 1. If the natural alignment of the type is <= 4, the alignment is 4.
+            //
+            // 2. Otherwise, on Linux, the alignment of any vector type is the natural alignment.
+            // This doesn't matter here because we only pass aggregates via `byval`, not vectors.
+            //
+            // 3. Otherwise, on Apple platforms, the alignment of anything that contains a vector
+            // type is 16.
+            //
+            // 4. If none of these conditions are true, the alignment is 4.
+
+            fn contains_vector<'a, Ty, C>(cx: &C, layout: TyAndLayout<'a, Ty>) -> bool
+            where
+                Ty: TyAbiInterface<'a, C> + Copy,
+            {
+                match layout.abi {
+                    Abi::Uninhabited | Abi::Scalar(_) | Abi::ScalarPair(..) => false,
+                    Abi::Vector { .. } => true,
+                    Abi::Aggregate { .. } => {
+                        for i in 0..layout.fields.count() {
+                            if contains_vector(cx, layout.field(cx, i)) {
+                                return true;
+                            }
+                        }
+                        false
+                    }
+                }
+            }
+
+            let t = cx.target_spec();
+            let align_4 = Align::from_bytes(4).unwrap();
+            let align_16 = Align::from_bytes(16).unwrap();
+            let byval_align = if arg.layout.align.abi < align_4 {
+                // (1.)
+                align_4
+            } else if t.is_like_osx && contains_vector(cx, arg.layout) {
+                // (3.)
+                align_16
+            } else {
+                // (4.)
+                align_4
+            };
+
+            arg.make_indirect_byval(Some(byval_align));
         } else {
-            align_4
-        };
-
-        arg.make_indirect_byval(Some(byval_align));
+            arg.extend_integer_width_to(32);
+        }
     }
 
     if flavor == Flavor::FastcallOrVectorcall {
diff --git a/tests/codegen/align-byval-vector.rs b/tests/codegen/align-byval-vector.rs
new file mode 100644
index 00000000000..3c8be659671
--- /dev/null
+++ b/tests/codegen/align-byval-vector.rs
@@ -0,0 +1,58 @@
+// revisions:x86-linux x86-darwin
+
+//[x86-linux] compile-flags: --target i686-unknown-linux-gnu
+//[x86-linux] needs-llvm-components: x86
+//[x86-darwin] compile-flags: --target i686-apple-darwin
+//[x86-darwin] needs-llvm-components: x86
+
+// Tests that aggregates containing vector types get their alignment increased to 16 on Darwin.
+
+#![feature(no_core, lang_items, repr_simd, simd_ffi)]
+#![crate_type = "lib"]
+#![no_std]
+#![no_core]
+#![allow(non_camel_case_types)]
+
+#[lang = "sized"]
+trait Sized {}
+#[lang = "freeze"]
+trait Freeze {}
+#[lang = "copy"]
+trait Copy {}
+
+#[repr(simd)]
+pub struct i32x4(i32, i32, i32, i32);
+
+#[repr(C)]
+pub struct Foo {
+    a: i32x4,
+    b: i8,
+}
+
+// This tests that we recursively check for vector types, not just at the top level.
+#[repr(C)]
+pub struct DoubleFoo {
+    one: Foo,
+    two: Foo,
+}
+
+extern "C" {
+    // x86-linux: declare void @f({{.*}}byval(%Foo) align 4{{.*}})
+    // x86-darwin: declare void @f({{.*}}byval(%Foo) align 16{{.*}})
+    fn f(foo: Foo);
+
+    // x86-linux: declare void @g({{.*}}byval(%DoubleFoo) align 4{{.*}})
+    // x86-darwin: declare void @g({{.*}}byval(%DoubleFoo) align 16{{.*}})
+    fn g(foo: DoubleFoo);
+}
+
+pub fn main() {
+    unsafe { f(Foo { a: i32x4(1, 2, 3, 4), b: 0 }) }
+
+    unsafe {
+        g(DoubleFoo {
+            one: Foo { a: i32x4(1, 2, 3, 4), b: 0 },
+            two: Foo { a: i32x4(1, 2, 3, 4), b: 0 },
+        })
+    }
+}