diff --git a/compiler/rustc_smir/src/rustc_internal/internal.rs b/compiler/rustc_smir/src/rustc_internal/internal.rs index 17162d0de25..5689e8f3b3d 100644 --- a/compiler/rustc_smir/src/rustc_internal/internal.rs +++ b/compiler/rustc_smir/src/rustc_internal/internal.rs @@ -17,7 +17,7 @@ use stable_mir::ty::{ GenericArgKind, GenericArgs, IndexedVal, IntTy, Movability, Region, RigidTy, Span, TermKind, TraitRef, Ty, UintTy, VariantDef, VariantIdx, }; -use stable_mir::{CrateItem, DefId}; +use stable_mir::{CrateItem, CrateNum, DefId}; use super::RustcInternal; @@ -28,6 +28,13 @@ impl<'tcx> RustcInternal<'tcx> for CrateItem { } } +impl<'tcx> RustcInternal<'tcx> for CrateNum { + type T = rustc_span::def_id::CrateNum; + fn internal(&self, _tables: &mut Tables<'tcx>) -> Self::T { + rustc_span::def_id::CrateNum::from_usize(*self) + } +} + impl<'tcx> RustcInternal<'tcx> for DefId { type T = rustc_span::def_id::DefId; fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T { diff --git a/compiler/rustc_smir/src/rustc_smir/context.rs b/compiler/rustc_smir/src/rustc_smir/context.rs index f84c466cc44..fffc454804d 100644 --- a/compiler/rustc_smir/src/rustc_smir/context.rs +++ b/compiler/rustc_smir/src/rustc_smir/context.rs @@ -25,8 +25,9 @@ use stable_mir::ty::{ AdtDef, AdtKind, Allocation, ClosureDef, ClosureKind, Const, FieldDef, FnDef, GenericArgs, LineInfo, PolyFnSig, RigidTy, Span, Ty, TyKind, VariantDef, }; -use stable_mir::{Crate, CrateItem, DefId, Error, Filename, ItemKind, Symbol}; +use stable_mir::{Crate, CrateItem, CrateNum, DefId, Error, Filename, ItemKind, Symbol}; use std::cell::RefCell; +use std::iter; use crate::rustc_internal::{internal, RustcInternal}; use crate::rustc_smir::builder::BodyBuilder; @@ -67,10 +68,15 @@ impl<'tcx> Context for TablesWrapper<'tcx> { } fn all_trait_decls(&self) -> stable_mir::TraitDecls { + let mut tables = self.0.borrow_mut(); + tables.tcx.all_traits().map(|trait_def_id| tables.trait_def(trait_def_id)).collect() + } + + fn trait_decls(&self, crate_num: CrateNum) -> stable_mir::TraitDecls { let mut tables = self.0.borrow_mut(); tables .tcx - .traits(LOCAL_CRATE) + .traits(crate_num.internal(&mut *tables)) .iter() .map(|trait_def_id| tables.trait_def(*trait_def_id)) .collect() @@ -84,10 +90,20 @@ impl<'tcx> Context for TablesWrapper<'tcx> { } fn all_trait_impls(&self) -> stable_mir::ImplTraitDecls { + let mut tables = self.0.borrow_mut(); + let tcx = tables.tcx; + iter::once(LOCAL_CRATE) + .chain(tables.tcx.crates(()).iter().copied()) + .flat_map(|cnum| tcx.trait_impls_in_crate(cnum).iter()) + .map(|impl_def_id| tables.impl_def(*impl_def_id)) + .collect() + } + + fn trait_impls(&self, crate_num: CrateNum) -> stable_mir::ImplTraitDecls { let mut tables = self.0.borrow_mut(); tables .tcx - .trait_impls_in_crate(LOCAL_CRATE) + .trait_impls_in_crate(crate_num.internal(&mut *tables)) .iter() .map(|impl_def_id| tables.impl_def(*impl_def_id)) .collect() diff --git a/compiler/stable_mir/src/compiler_interface.rs b/compiler/stable_mir/src/compiler_interface.rs index f52e506059b..fb83dae5714 100644 --- a/compiler/stable_mir/src/compiler_interface.rs +++ b/compiler/stable_mir/src/compiler_interface.rs @@ -16,8 +16,8 @@ use crate::ty::{ TraitDef, Ty, TyKind, VariantDef, }; use crate::{ - mir, Crate, CrateItem, CrateItems, DefId, Error, Filename, ImplTraitDecls, ItemKind, Symbol, - TraitDecls, + mir, Crate, CrateItem, CrateItems, CrateNum, DefId, Error, Filename, ImplTraitDecls, ItemKind, + Symbol, TraitDecls, }; /// This trait defines the interface between stable_mir and the Rust compiler. @@ -32,8 +32,10 @@ pub trait Context { /// Check whether the body of a function is available. fn has_body(&self, item: DefId) -> bool; fn all_trait_decls(&self) -> TraitDecls; + fn trait_decls(&self, crate_num: CrateNum) -> TraitDecls; fn trait_decl(&self, trait_def: &TraitDef) -> TraitDecl; fn all_trait_impls(&self) -> ImplTraitDecls; + fn trait_impls(&self, crate_num: CrateNum) -> ImplTraitDecls; fn trait_impl(&self, trait_impl: &ImplDef) -> ImplTrait; fn generics_of(&self, def_id: DefId) -> Generics; fn predicates_of(&self, def_id: DefId) -> GenericPredicates; diff --git a/compiler/stable_mir/src/lib.rs b/compiler/stable_mir/src/lib.rs index 9194f1e6bdb..de5dfcdf207 100644 --- a/compiler/stable_mir/src/lib.rs +++ b/compiler/stable_mir/src/lib.rs @@ -31,7 +31,7 @@ pub use crate::error::*; use crate::mir::pretty::function_name; use crate::mir::Body; use crate::mir::Mutability; -use crate::ty::{ImplDef, ImplTrait, IndexedVal, Span, TraitDecl, TraitDef, Ty}; +use crate::ty::{ImplDef, IndexedVal, Span, TraitDef, Ty}; pub mod abi; #[macro_use] @@ -86,6 +86,18 @@ pub struct Crate { pub is_local: bool, } +impl Crate { + /// The list of traits declared in this crate. + pub fn trait_decls(&self) -> TraitDecls { + with(|cx| cx.trait_decls(self.id)) + } + + /// The list of trait implementations in this crate. + pub fn trait_impls(&self) -> ImplTraitDecls { + with(|cx| cx.trait_impls(self.id)) + } +} + #[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] pub enum ItemKind { Fn, @@ -169,18 +181,10 @@ pub fn all_trait_decls() -> TraitDecls { with(|cx| cx.all_trait_decls()) } -pub fn trait_decl(trait_def: &TraitDef) -> TraitDecl { - with(|cx| cx.trait_decl(trait_def)) -} - pub fn all_trait_impls() -> ImplTraitDecls { with(|cx| cx.all_trait_impls()) } -pub fn trait_impl(trait_impl: &ImplDef) -> ImplTrait { - with(|cx| cx.trait_impl(trait_impl)) -} - /// A type that provides internal information but that can still be used for debug purpose. #[derive(Clone, PartialEq, Eq, Hash)] pub struct Opaque(String); diff --git a/compiler/stable_mir/src/ty.rs b/compiler/stable_mir/src/ty.rs index 9e6ecbe8315..eba2ac57012 100644 --- a/compiler/stable_mir/src/ty.rs +++ b/compiler/stable_mir/src/ty.rs @@ -714,9 +714,16 @@ crate_def! { } crate_def! { + /// A trait's definition. pub TraitDef; } +impl TraitDef { + pub fn declaration(trait_def: &TraitDef) -> TraitDecl { + with(|cx| cx.trait_decl(trait_def)) + } +} + crate_def! { pub GenericDef; } @@ -726,9 +733,17 @@ crate_def! { } crate_def! { + /// A trait impl definition. pub ImplDef; } +impl ImplDef { + /// Retrieve information about this implementation. + pub fn trait_impl(&self) -> ImplTrait { + with(|cx| cx.trait_impl(self)) + } +} + crate_def! { pub RegionDef; } diff --git a/tests/ui-fulldeps/stable-mir/check_trait_queries.rs b/tests/ui-fulldeps/stable-mir/check_trait_queries.rs new file mode 100644 index 00000000000..fb1197e4ecc --- /dev/null +++ b/tests/ui-fulldeps/stable-mir/check_trait_queries.rs @@ -0,0 +1,125 @@ +// run-pass +//! Test that users are able to retrieve information about trait declarations and implementations. + +// ignore-stage1 +// ignore-cross-compile +// ignore-remote +// ignore-windows-gnu mingw has troubles with linking https://github.com/rust-lang/rust/pull/116837 +// edition: 2021 + +#![feature(rustc_private)] +#![feature(assert_matches)] +#![feature(control_flow_enum)] + +extern crate rustc_middle; +#[macro_use] +extern crate rustc_smir; +extern crate rustc_driver; +extern crate rustc_interface; +extern crate stable_mir; + +use rustc_middle::ty::TyCtxt; +use rustc_smir::rustc_internal; +use stable_mir::CrateDef; +use std::collections::HashSet; +use std::io::Write; +use std::ops::ControlFlow; + +const CRATE_NAME: &str = "trait_test"; + +/// This function uses the Stable MIR APIs to get information about the test crate. +fn test_traits() -> ControlFlow<()> { + let local_crate = stable_mir::local_crate(); + let local_traits = local_crate.trait_decls(); + assert_eq!(local_traits.len(), 1, "Expected `Max` trait, but found {:?}", local_traits); + assert_eq!(&local_traits[0].name(), "Max"); + + let local_impls = local_crate.trait_impls(); + let impl_names = local_impls.iter().map(|trait_impl| trait_impl.name()).collect::>(); + assert_impl(&impl_names, ""); + assert_impl(&impl_names, ""); + assert_impl(&impl_names, ""); + assert_impl(&impl_names, ""); + assert_impl(&impl_names, ""); + assert_impl(&impl_names, ""); + assert_impl(&impl_names, ">"); + assert_impl(&impl_names, ""); + assert_impl(&impl_names, " for u64>"); + + let all_traits = stable_mir::all_trait_decls(); + assert!(all_traits.len() > local_traits.len()); + assert!( + local_traits.iter().all(|t| all_traits.contains(t)), + "Local: {local_traits:#?}, All: {all_traits:#?}" + ); + + let all_impls = stable_mir::all_trait_impls(); + assert!(all_impls.len() > local_impls.len()); + assert!( + local_impls.iter().all(|t| all_impls.contains(t)), + "Local: {local_impls:#?}, All: {all_impls:#?}" + ); + ControlFlow::Continue(()) +} + +fn assert_impl(impl_names: &HashSet, target: &str) { + assert!( + impl_names.contains(target), + "Failed to find `{target}`. Implementations available: {impl_names:?}", + ); +} + +/// This test will generate and analyze a dummy crate using the stable mir. +/// For that, it will first write the dummy crate into a file. +/// Then it will create a `StableMir` using custom arguments and then +/// it will run the compiler. +fn main() { + let path = "trait_queries.rs"; + generate_input(&path).unwrap(); + let args = vec![ + "rustc".to_string(), + "--crate-type=lib".to_string(), + "--crate-name".to_string(), + CRATE_NAME.to_string(), + path.to_string(), + ]; + run!(args, test_traits()).unwrap(); +} + +fn generate_input(path: &str) -> std::io::Result<()> { + let mut file = std::fs::File::create(path)?; + write!( + file, + r#" + use std::convert::TryFrom; + + #[derive(Copy, Clone, Debug, PartialEq, Eq)] + pub struct Positive(u64); + + impl TryFrom for Positive {{ + type Error = (); + fn try_from(val: u64) -> Result {{ + if val > 0 {{ Ok(Positive(val)) }} else {{ Err(()) }} + }} + }} + + impl From for u64 {{ + fn from(val: Positive) -> u64 {{ val.0 }} + }} + + pub trait Max {{ + fn is_max(&self) -> bool; + }} + + impl Max for u64 {{ + fn is_max(&self) -> bool {{ *self == u64::MAX }} + }} + + impl Max for Positive {{ + fn is_max(&self) -> bool {{ self.0.is_max() }} + }} + + "# + )?; + Ok(()) +}