From 40f41e7e896e7d68534d05e955c20330b13f6d0b Mon Sep 17 00:00:00 2001
From: Matthew Maurer <mmaurer@google.com>
Date: Sat, 23 Mar 2024 01:19:43 +0000
Subject: [PATCH] CFI: Support arbitrary receivers

Previously, we only rewrote `&self` and `&mut self` receivers. By
instantiating the method from the trait definition, we can make this
work work with arbitrary legal receivers instead.
---
 Cargo.lock                                    |   1 +
 compiler/rustc_symbol_mangling/Cargo.toml     |   1 +
 compiler/rustc_symbol_mangling/src/lib.rs     |   1 +
 .../src/typeid/typeid_itanium_cxx_abi.rs      | 109 +++++++++++-------
 tests/ui/sanitizer/cfi-complex-receiver.rs    |  42 +++++++
 5 files changed, 115 insertions(+), 39 deletions(-)
 create mode 100644 tests/ui/sanitizer/cfi-complex-receiver.rs

diff --git a/Cargo.lock b/Cargo.lock
index b8fe1ebaf80..ad48620bd90 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4634,6 +4634,7 @@ dependencies = [
  "rustc_session",
  "rustc_span",
  "rustc_target",
+ "rustc_trait_selection",
  "tracing",
  "twox-hash",
 ]
diff --git a/compiler/rustc_symbol_mangling/Cargo.toml b/compiler/rustc_symbol_mangling/Cargo.toml
index 0ce522c9cab..1c8f1d03670 100644
--- a/compiler/rustc_symbol_mangling/Cargo.toml
+++ b/compiler/rustc_symbol_mangling/Cargo.toml
@@ -15,6 +15,7 @@ rustc_middle = { path = "../rustc_middle" }
 rustc_session = { path = "../rustc_session" }
 rustc_span = { path = "../rustc_span" }
 rustc_target = { path = "../rustc_target" }
+rustc_trait_selection = { path = "../rustc_trait_selection" }
 tracing = "0.1"
 twox-hash = "1.6.3"
 # tidy-alphabetical-end
diff --git a/compiler/rustc_symbol_mangling/src/lib.rs b/compiler/rustc_symbol_mangling/src/lib.rs
index 02bb1fde75c..0588af9bda7 100644
--- a/compiler/rustc_symbol_mangling/src/lib.rs
+++ b/compiler/rustc_symbol_mangling/src/lib.rs
@@ -90,6 +90,7 @@
 #![doc(html_root_url = "https://doc.rust-lang.org/nightly/nightly-rustc/")]
 #![doc(rust_logo)]
 #![feature(rustdoc_internals)]
+#![feature(let_chains)]
 #![allow(internal_features)]
 
 #[macro_use]
diff --git a/compiler/rustc_symbol_mangling/src/typeid/typeid_itanium_cxx_abi.rs b/compiler/rustc_symbol_mangling/src/typeid/typeid_itanium_cxx_abi.rs
index 367fec0e8fc..9406038232d 100644
--- a/compiler/rustc_symbol_mangling/src/typeid/typeid_itanium_cxx_abi.rs
+++ b/compiler/rustc_symbol_mangling/src/typeid/typeid_itanium_cxx_abi.rs
@@ -11,6 +11,7 @@ use rustc_data_structures::base_n;
 use rustc_data_structures::fx::FxHashMap;
 use rustc_hir as hir;
 use rustc_middle::ty::layout::IntegerExt;
+use rustc_middle::ty::TypeVisitableExt;
 use rustc_middle::ty::{
     self, Const, ExistentialPredicate, FloatTy, FnSig, Instance, IntTy, List, Region, RegionKind,
     TermKind, Ty, TyCtxt, UintTy,
@@ -21,7 +22,9 @@ use rustc_span::sym;
 use rustc_target::abi::call::{Conv, FnAbi, PassMode};
 use rustc_target::abi::Integer;
 use rustc_target::spec::abi::Abi;
+use rustc_trait_selection::traits;
 use std::fmt::Write as _;
+use std::iter;
 
 use crate::typeid::TypeIdOptions;
 
@@ -1115,51 +1118,46 @@ pub fn typeid_for_instance<'tcx>(
         instance.args = strip_receiver_auto(tcx, instance.args)
     }
 
+    if let Some(impl_id) = tcx.impl_of_method(instance.def_id())
+        && let Some(trait_ref) = tcx.impl_trait_ref(impl_id)
+    {
+        let impl_method = tcx.associated_item(instance.def_id());
+        let method_id = impl_method
+            .trait_item_def_id
+            .expect("Part of a trait implementation, but not linked to the def_id?");
+        let trait_method = tcx.associated_item(method_id);
+        if traits::is_vtable_safe_method(tcx, trait_ref.skip_binder().def_id, trait_method) {
+            // Trait methods will have a Self polymorphic parameter, where the concreteized
+            // implementatation will not. We need to walk back to the more general trait method
+            let trait_ref = tcx.instantiate_and_normalize_erasing_regions(
+                instance.args,
+                ty::ParamEnv::reveal_all(),
+                trait_ref,
+            );
+            let invoke_ty = trait_object_ty(tcx, ty::Binder::dummy(trait_ref));
+
+            // At the call site, any call to this concrete function through a vtable will be
+            // `Virtual(method_id, idx)` with appropriate arguments for the method. Since we have the
+            // original method id, and we've recovered the trait arguments, we can make the callee
+            // instance we're computing the alias set for match the caller instance.
+            //
+            // Right now, our code ignores the vtable index everywhere, so we use 0 as a placeholder.
+            // If we ever *do* start encoding the vtable index, we will need to generate an alias set
+            // based on which vtables we are putting this method into, as there will be more than one
+            // index value when supertraits are involved.
+            instance.def = ty::InstanceDef::Virtual(method_id, 0);
+            let abstract_trait_args =
+                tcx.mk_args_trait(invoke_ty, trait_ref.args.into_iter().skip(1));
+            instance.args = instance.args.rebase_onto(tcx, impl_id, abstract_trait_args);
+        }
+    }
+
     let fn_abi = tcx
         .fn_abi_of_instance(tcx.param_env(instance.def_id()).and((instance, ty::List::empty())))
         .unwrap_or_else(|instance| {
             bug!("typeid_for_instance: couldn't get fn_abi of instance {:?}", instance)
         });
 
-    // If this instance is a method and self is a reference, get the impl it belongs to
-    let impl_def_id = tcx.impl_of_method(instance.def_id());
-    if impl_def_id.is_some() && !fn_abi.args.is_empty() && fn_abi.args[0].layout.ty.is_ref() {
-        // If this impl is not an inherent impl, get the trait it implements
-        if let Some(trait_ref) = tcx.impl_trait_ref(impl_def_id.unwrap()) {
-            // Transform the concrete self into a reference to a trait object
-            let existential_predicate = trait_ref.map_bound(|trait_ref| {
-                ty::ExistentialPredicate::Trait(ty::ExistentialTraitRef::erase_self_ty(
-                    tcx, trait_ref,
-                ))
-            });
-            let existential_predicates = tcx.mk_poly_existential_predicates(&[ty::Binder::dummy(
-                existential_predicate.skip_binder(),
-            )]);
-            // Is the concrete self mutable?
-            let self_ty = if fn_abi.args[0].layout.ty.is_mutable_ptr() {
-                Ty::new_mut_ref(
-                    tcx,
-                    tcx.lifetimes.re_erased,
-                    Ty::new_dynamic(tcx, existential_predicates, tcx.lifetimes.re_erased, ty::Dyn),
-                )
-            } else {
-                Ty::new_imm_ref(
-                    tcx,
-                    tcx.lifetimes.re_erased,
-                    Ty::new_dynamic(tcx, existential_predicates, tcx.lifetimes.re_erased, ty::Dyn),
-                )
-            };
-
-            // Replace the concrete self in an fn_abi clone by the reference to a trait object
-            let mut fn_abi = fn_abi.clone();
-            // HACK(rcvalle): It is okay to not replace or update the entire ArgAbi here because the
-            //   other fields are never used.
-            fn_abi.args[0].layout.ty = self_ty;
-
-            return typeid_for_fnabi(tcx, &fn_abi, options);
-        }
-    }
-
     typeid_for_fnabi(tcx, fn_abi, options)
 }
 
@@ -1182,3 +1180,36 @@ fn strip_receiver_auto<'tcx>(
     let new_rcvr = Ty::new_dynamic(tcx, filtered_preds, *lifetime, *kind);
     tcx.mk_args_trait(new_rcvr, args.into_iter().skip(1))
 }
+
+fn trait_object_ty<'tcx>(tcx: TyCtxt<'tcx>, poly_trait_ref: ty::PolyTraitRef<'tcx>) -> Ty<'tcx> {
+    assert!(!poly_trait_ref.has_non_region_param());
+    let principal_pred = poly_trait_ref.map_bound(|trait_ref| {
+        ty::ExistentialPredicate::Trait(ty::ExistentialTraitRef::erase_self_ty(tcx, trait_ref))
+    });
+    let mut assoc_preds: Vec<_> = traits::supertraits(tcx, poly_trait_ref)
+        .flat_map(|super_poly_trait_ref| {
+            tcx.associated_items(super_poly_trait_ref.def_id())
+                .in_definition_order()
+                .filter(|item| item.kind == ty::AssocKind::Type)
+                .map(move |assoc_ty| {
+                    super_poly_trait_ref.map_bound(|super_trait_ref| {
+                        let alias_ty = ty::AliasTy::new(tcx, assoc_ty.def_id, super_trait_ref.args);
+                        let resolved = tcx.normalize_erasing_regions(
+                            ty::ParamEnv::reveal_all(),
+                            alias_ty.to_ty(tcx),
+                        );
+                        ty::ExistentialPredicate::Projection(ty::ExistentialProjection {
+                            def_id: assoc_ty.def_id,
+                            args: ty::ExistentialTraitRef::erase_self_ty(tcx, super_trait_ref).args,
+                            term: resolved.into(),
+                        })
+                    })
+                })
+        })
+        .collect();
+    assoc_preds.sort_by(|a, b| a.skip_binder().stable_cmp(tcx, &b.skip_binder()));
+    let preds = tcx.mk_poly_existential_predicates_from_iter(
+        iter::once(principal_pred).chain(assoc_preds.into_iter()),
+    );
+    Ty::new_dynamic(tcx, preds, tcx.lifetimes.re_erased, ty::Dyn)
+}
diff --git a/tests/ui/sanitizer/cfi-complex-receiver.rs b/tests/ui/sanitizer/cfi-complex-receiver.rs
new file mode 100644
index 00000000000..c3e59258db2
--- /dev/null
+++ b/tests/ui/sanitizer/cfi-complex-receiver.rs
@@ -0,0 +1,42 @@
+// Check that more complex receivers work:
+// * Arc<dyn Foo> as for custom receivers
+// * &dyn Bar<T=Baz> for type constraints
+
+//@ needs-sanitizer-cfi
+// FIXME(#122848) Remove only-linux once OSX CFI binaries work
+//@ only-linux
+//@ compile-flags: --crate-type=bin -Cprefer-dynamic=off -Clto -Zsanitizer=cfi
+//@ compile-flags: -C target-feature=-crt-static -C codegen-units=1 -C opt-level=0
+//@ run-pass
+
+use std::sync::Arc;
+
+trait Foo {
+    fn foo(self: Arc<Self>);
+}
+
+struct FooImpl;
+
+impl Foo for FooImpl {
+    fn foo(self: Arc<Self>) {}
+}
+
+trait Bar {
+    type T;
+    fn bar(&self) -> Self::T;
+}
+
+struct BarImpl;
+
+impl Bar for BarImpl {
+    type T = i32;
+    fn bar(&self) -> Self::T { 7 }
+}
+
+fn main() {
+    let foo: Arc<dyn Foo> = Arc::new(FooImpl);
+    foo.foo();
+
+    let bar: &dyn Bar<T=i32> = &BarImpl;
+    assert_eq!(bar.bar(), 7);
+}