diff --git a/compiler/rustc_middle/src/ty/relate.rs b/compiler/rustc_middle/src/ty/relate.rs index 4c7bcb1bf2e..504a3c8a6d8 100644 --- a/compiler/rustc_middle/src/ty/relate.rs +++ b/compiler/rustc_middle/src/ty/relate.rs @@ -1,7 +1,5 @@ use std::iter; -use rustc_hir as hir; -use rustc_target::spec::abi; pub use rustc_type_ir::relate::*; use crate::ty::error::{ExpectedFound, TypeError}; @@ -121,26 +119,6 @@ impl<'tcx> Relate> for &'tcx ty::List Relate> for hir::Safety { - fn relate>>( - _relation: &mut R, - a: hir::Safety, - b: hir::Safety, - ) -> RelateResult<'tcx, hir::Safety> { - if a != b { Err(TypeError::SafetyMismatch(ExpectedFound::new(true, a, b))) } else { Ok(a) } - } -} - -impl<'tcx> Relate> for abi::Abi { - fn relate>>( - _relation: &mut R, - a: abi::Abi, - b: abi::Abi, - ) -> RelateResult<'tcx, abi::Abi> { - if a == b { Ok(a) } else { Err(TypeError::AbiMismatch(ExpectedFound::new(true, a, b))) } - } -} - impl<'tcx> Relate> for ty::GenericArgsRef<'tcx> { fn relate>>( relation: &mut R, diff --git a/compiler/rustc_middle/src/ty/structural_impls.rs b/compiler/rustc_middle/src/ty/structural_impls.rs index cd9ff9b60d8..4872d8c89eb 100644 --- a/compiler/rustc_middle/src/ty/structural_impls.rs +++ b/compiler/rustc_middle/src/ty/structural_impls.rs @@ -264,8 +264,6 @@ TrivialTypeTraversalImpls! { // interners). TrivialTypeTraversalAndLiftImpls! { ::rustc_hir::def_id::DefId, - ::rustc_hir::Safety, - ::rustc_target::spec::abi::Abi, crate::ty::ClosureKind, crate::ty::ParamConst, crate::ty::ParamTy, @@ -276,6 +274,11 @@ TrivialTypeTraversalAndLiftImpls! { rustc_target::abi::Size, } +TrivialLiftImpls! { + ::rustc_hir::Safety, + ::rustc_target::spec::abi::Abi, +} + /////////////////////////////////////////////////////////////////////////// // Lift implementations diff --git a/compiler/rustc_type_ir/src/error.rs b/compiler/rustc_type_ir/src/error.rs index 8a6d37b7d23..72501945721 100644 --- a/compiler/rustc_type_ir/src/error.rs +++ b/compiler/rustc_type_ir/src/error.rs @@ -29,8 +29,8 @@ pub enum TypeError { Mismatch, ConstnessMismatch(ExpectedFound), PolarityMismatch(ExpectedFound), - SafetyMismatch(ExpectedFound), - AbiMismatch(ExpectedFound), + SafetyMismatch(#[type_visitable(ignore)] ExpectedFound), + AbiMismatch(#[type_visitable(ignore)] ExpectedFound), Mutability, ArgumentMutability(usize), TupleSize(ExpectedFound), diff --git a/compiler/rustc_type_ir/src/inherent.rs b/compiler/rustc_type_ir/src/inherent.rs index f7875bb5152..02ec29a7f3d 100644 --- a/compiler/rustc_type_ir/src/inherent.rs +++ b/compiler/rustc_type_ir/src/inherent.rs @@ -208,14 +208,14 @@ pub trait Tys>: fn output(self) -> I::Ty; } -pub trait Abi>: Copy + Debug + Hash + Eq + Relate { +pub trait Abi>: Copy + Debug + Hash + Eq { fn rust() -> Self; /// Whether this ABI is `extern "Rust"`. fn is_rust(self) -> bool; } -pub trait Safety>: Copy + Debug + Hash + Eq + Relate { +pub trait Safety>: Copy + Debug + Hash + Eq { fn safe() -> Self; fn is_safe(self) -> bool; diff --git a/compiler/rustc_type_ir/src/relate.rs b/compiler/rustc_type_ir/src/relate.rs index a0b93064694..ccb8e9fcf7c 100644 --- a/compiler/rustc_type_ir/src/relate.rs +++ b/compiler/rustc_type_ir/src/relate.rs @@ -174,12 +174,17 @@ impl Relate for ty::FnSig { ExpectedFound::new(true, a, b) })); } - let safety = relation.relate(a.safety, b.safety)?; - let abi = relation.relate(a.abi, b.abi)?; + + if a.safety != b.safety { + return Err(TypeError::SafetyMismatch(ExpectedFound::new(true, a.safety, b.safety))); + } + + if a.abi != b.abi { + return Err(TypeError::AbiMismatch(ExpectedFound::new(true, a.abi, b.abi))); + }; let a_inputs = a.inputs(); let b_inputs = b.inputs(); - if a_inputs.len() != b_inputs.len() { return Err(TypeError::ArgCount); } @@ -212,8 +217,8 @@ impl Relate for ty::FnSig { Ok(ty::FnSig { inputs_and_output: cx.mk_type_list_from_iter(inputs_and_output)?, c_variadic: a.c_variadic, - safety, - abi, + safety: a.safety, + abi: a.abi, }) } } diff --git a/compiler/rustc_type_ir/src/ty_kind.rs b/compiler/rustc_type_ir/src/ty_kind.rs index b7f6ef4ffbb..499e6d3dd37 100644 --- a/compiler/rustc_type_ir/src/ty_kind.rs +++ b/compiler/rustc_type_ir/src/ty_kind.rs @@ -861,7 +861,11 @@ pub struct TypeAndMut { pub struct FnSig { pub inputs_and_output: I::Tys, pub c_variadic: bool, + #[type_visitable(ignore)] + #[type_foldable(identity)] pub safety: I::Safety, + #[type_visitable(ignore)] + #[type_foldable(identity)] pub abi: I::Abi, } diff --git a/compiler/rustc_type_ir/src/ty_kind/closure.rs b/compiler/rustc_type_ir/src/ty_kind/closure.rs index 09a43b17955..10b164eae02 100644 --- a/compiler/rustc_type_ir/src/ty_kind/closure.rs +++ b/compiler/rustc_type_ir/src/ty_kind/closure.rs @@ -372,8 +372,12 @@ pub struct CoroutineClosureSignature { /// Always false pub c_variadic: bool, /// Always `Normal` (safe) + #[type_visitable(ignore)] + #[type_foldable(identity)] pub safety: I::Safety, /// Always `RustCall` + #[type_visitable(ignore)] + #[type_foldable(identity)] pub abi: I::Abi, } diff --git a/compiler/rustc_type_ir_macros/src/lib.rs b/compiler/rustc_type_ir_macros/src/lib.rs index 1a0a2479f6f..aaf69e2648d 100644 --- a/compiler/rustc_type_ir_macros/src/lib.rs +++ b/compiler/rustc_type_ir_macros/src/lib.rs @@ -1,18 +1,73 @@ -use quote::quote; -use syn::parse_quote; +use quote::{ToTokens, quote}; use syn::visit_mut::VisitMut; +use syn::{Attribute, parse_quote}; use synstructure::decl_derive; decl_derive!( - [TypeFoldable_Generic] => type_foldable_derive + [TypeVisitable_Generic, attributes(type_visitable)] => type_visitable_derive ); decl_derive!( - [TypeVisitable_Generic] => type_visitable_derive + [TypeFoldable_Generic, attributes(type_foldable)] => type_foldable_derive ); decl_derive!( [Lift_Generic] => lift_derive ); +fn has_ignore_attr(attrs: &[Attribute], name: &'static str, meta: &'static str) -> bool { + let mut ignored = false; + attrs.iter().for_each(|attr| { + if !attr.path().is_ident(name) { + return; + } + let _ = attr.parse_nested_meta(|nested| { + if nested.path.is_ident(meta) { + ignored = true; + } + Ok(()) + }); + }); + + ignored +} + +fn type_visitable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream { + if let syn::Data::Union(_) = s.ast().data { + panic!("cannot derive on union") + } + + if !s.ast().generics.type_params().any(|ty| ty.ident == "I") { + s.add_impl_generic(parse_quote! { I }); + } + + s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_visitable", "ignore")); + + s.add_where_predicate(parse_quote! { I: Interner }); + s.add_bounds(synstructure::AddBounds::Fields); + let body_visit = s.each(|bind| { + quote! { + match ::rustc_ast_ir::visit::VisitorResult::branch( + ::rustc_type_ir::visit::TypeVisitable::visit_with(#bind, __visitor) + ) { + ::core::ops::ControlFlow::Continue(()) => {}, + ::core::ops::ControlFlow::Break(r) => { + return ::rustc_ast_ir::visit::VisitorResult::from_residual(r); + }, + } + } + }); + s.bind_with(|_| synstructure::BindStyle::Move); + + s.bound_impl(quote!(::rustc_type_ir::visit::TypeVisitable), quote! { + fn visit_with<__V: ::rustc_type_ir::visit::TypeVisitor>( + &self, + __visitor: &mut __V + ) -> __V::Result { + match *self { #body_visit } + <__V::Result as ::rustc_ast_ir::visit::VisitorResult>::output() + } + }) +} + fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream { if let syn::Data::Union(_) = s.ast().data { panic!("cannot derive on union") @@ -29,12 +84,23 @@ fn type_foldable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::Toke let bindings = vi.bindings(); vi.construct(|_, index| { let bind = &bindings[index]; - quote! { - ::rustc_type_ir::fold::TypeFoldable::try_fold_with(#bind, __folder)? + + // retain value of fields with #[type_foldable(identity)] + if has_ignore_attr(&bind.ast().attrs, "type_foldable", "identity") { + bind.to_token_stream() + } else { + quote! { + ::rustc_type_ir::fold::TypeFoldable::try_fold_with(#bind, __folder)? + } } }) }); + // We filter fields which get ignored and don't require them to implement + // `TypeFoldable`. We do so after generating `body_fold` as we still need + // to generate code for them. + s.filter(|bi| !has_ignore_attr(&bi.ast().attrs, "type_foldable", "identity")); + s.add_bounds(synstructure::AddBounds::Fields); s.bound_impl(quote!(::rustc_type_ir::fold::TypeFoldable), quote! { fn try_fold_with<__F: ::rustc_type_ir::fold::FallibleTypeFolder>( self, @@ -113,39 +179,3 @@ fn lift(mut ty: syn::Type) -> syn::Type { ty } - -fn type_visitable_derive(mut s: synstructure::Structure<'_>) -> proc_macro2::TokenStream { - if let syn::Data::Union(_) = s.ast().data { - panic!("cannot derive on union") - } - - if !s.ast().generics.type_params().any(|ty| ty.ident == "I") { - s.add_impl_generic(parse_quote! { I }); - } - - s.add_where_predicate(parse_quote! { I: Interner }); - s.add_bounds(synstructure::AddBounds::Fields); - let body_visit = s.each(|bind| { - quote! { - match ::rustc_ast_ir::visit::VisitorResult::branch( - ::rustc_type_ir::visit::TypeVisitable::visit_with(#bind, __visitor) - ) { - ::core::ops::ControlFlow::Continue(()) => {}, - ::core::ops::ControlFlow::Break(r) => { - return ::rustc_ast_ir::visit::VisitorResult::from_residual(r); - }, - } - } - }); - s.bind_with(|_| synstructure::BindStyle::Move); - - s.bound_impl(quote!(::rustc_type_ir::visit::TypeVisitable), quote! { - fn visit_with<__V: ::rustc_type_ir::visit::TypeVisitor>( - &self, - __visitor: &mut __V - ) -> __V::Result { - match *self { #body_visit } - <__V::Result as ::rustc_ast_ir::visit::VisitorResult>::output() - } - }) -}