diff --git a/compiler/rustc_smir/src/rustc_internal/internal.rs b/compiler/rustc_smir/src/rustc_internal/internal.rs index 7cfdbbbf703..5bb3c1a0d4c 100644 --- a/compiler/rustc_smir/src/rustc_internal/internal.rs +++ b/compiler/rustc_smir/src/rustc_internal/internal.rs @@ -6,11 +6,23 @@ // Prefer importing stable_mir over internal rustc constructs to make this file more readable. use crate::rustc_smir::Tables; use rustc_middle::ty::{self as rustc_ty, Ty as InternalTy}; -use stable_mir::ty::{Const, GenericArgKind, GenericArgs, Region, Ty}; -use stable_mir::DefId; +use rustc_span::Symbol; +use stable_mir::mir::mono::{Instance, MonoItem, StaticDef}; +use stable_mir::ty::{ + Binder, BoundRegionKind, BoundTyKind, BoundVariableKind, ClosureKind, Const, GenericArgKind, + GenericArgs, Region, TraitRef, Ty, +}; +use stable_mir::{AllocId, CrateItem, DefId}; use super::RustcInternal; +impl<'tcx> RustcInternal<'tcx> for CrateItem { + type T = rustc_span::def_id::DefId; + fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T { + self.0.internal(tables) + } +} + impl<'tcx> RustcInternal<'tcx> for DefId { type T = rustc_span::def_id::DefId; fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T { @@ -38,8 +50,9 @@ impl<'tcx> RustcInternal<'tcx> for GenericArgKind { impl<'tcx> RustcInternal<'tcx> for Region { type T = rustc_ty::Region<'tcx>; - fn internal(&self, _tables: &mut Tables<'tcx>) -> Self::T { - todo!() + fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T { + // Cannot recover region. Use erased instead. + tables.tcx.lifetimes.re_erased } } @@ -65,3 +78,118 @@ impl<'tcx> RustcInternal<'tcx> for Const { tables.constants[self.id] } } + +impl<'tcx> RustcInternal<'tcx> for MonoItem { + type T = rustc_middle::mir::mono::MonoItem<'tcx>; + + fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T { + use rustc_middle::mir::mono as rustc_mono; + match self { + MonoItem::Fn(instance) => rustc_mono::MonoItem::Fn(instance.internal(tables)), + MonoItem::Static(def) => rustc_mono::MonoItem::Static(def.internal(tables)), + MonoItem::GlobalAsm(_) => { + unimplemented!() + } + } + } +} + +impl<'tcx> RustcInternal<'tcx> for Instance { + type T = rustc_ty::Instance<'tcx>; + + fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T { + tables.instances[self.def] + } +} + +impl<'tcx> RustcInternal<'tcx> for StaticDef { + type T = rustc_span::def_id::DefId; + + fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T { + self.0.internal(tables) + } +} + +#[allow(rustc::usage_of_qualified_ty)] +impl<'tcx, T> RustcInternal<'tcx> for Binder +where + T: RustcInternal<'tcx>, + T::T: rustc_ty::TypeVisitable>, +{ + type T = rustc_ty::Binder<'tcx, T::T>; + + fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T { + rustc_ty::Binder::bind_with_vars( + self.value.internal(tables), + tables.tcx.mk_bound_variable_kinds_from_iter( + self.bound_vars.iter().map(|bound| bound.internal(tables)), + ), + ) + } +} + +impl<'tcx> RustcInternal<'tcx> for BoundVariableKind { + type T = rustc_ty::BoundVariableKind; + + fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T { + match self { + BoundVariableKind::Ty(kind) => rustc_ty::BoundVariableKind::Ty(match kind { + BoundTyKind::Anon => rustc_ty::BoundTyKind::Anon, + BoundTyKind::Param(def, symbol) => { + rustc_ty::BoundTyKind::Param(def.0.internal(tables), Symbol::intern(&symbol)) + } + }), + BoundVariableKind::Region(kind) => rustc_ty::BoundVariableKind::Region(match kind { + BoundRegionKind::BrAnon => rustc_ty::BoundRegionKind::BrAnon, + BoundRegionKind::BrNamed(def, symbol) => rustc_ty::BoundRegionKind::BrNamed( + def.0.internal(tables), + Symbol::intern(&symbol), + ), + BoundRegionKind::BrEnv => rustc_ty::BoundRegionKind::BrEnv, + }), + BoundVariableKind::Const => rustc_ty::BoundVariableKind::Const, + } + } +} + +impl<'tcx> RustcInternal<'tcx> for TraitRef { + type T = rustc_ty::TraitRef<'tcx>; + + fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T { + rustc_ty::TraitRef::new( + tables.tcx, + self.def_id.0.internal(tables), + self.args().internal(tables), + ) + } +} + +impl<'tcx> RustcInternal<'tcx> for AllocId { + type T = rustc_middle::mir::interpret::AllocId; + fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T { + tables.alloc_ids[*self] + } +} + +impl<'tcx> RustcInternal<'tcx> for ClosureKind { + type T = rustc_ty::ClosureKind; + + fn internal(&self, _tables: &mut Tables<'tcx>) -> Self::T { + match self { + ClosureKind::Fn => rustc_ty::ClosureKind::Fn, + ClosureKind::FnMut => rustc_ty::ClosureKind::FnMut, + ClosureKind::FnOnce => rustc_ty::ClosureKind::FnOnce, + } + } +} + +impl<'tcx, T> RustcInternal<'tcx> for &T +where + T: RustcInternal<'tcx>, +{ + type T = T::T; + + fn internal(&self, tables: &mut Tables<'tcx>) -> Self::T { + (*self).internal(tables) + } +} diff --git a/compiler/rustc_smir/src/rustc_internal/mod.rs b/compiler/rustc_smir/src/rustc_internal/mod.rs index f0b368bec39..c82f948f195 100644 --- a/compiler/rustc_smir/src/rustc_internal/mod.rs +++ b/compiler/rustc_smir/src/rustc_internal/mod.rs @@ -13,6 +13,7 @@ use rustc_span::def_id::{CrateNum, DefId}; use rustc_span::Span; use scoped_tls::scoped_thread_local; use stable_mir::ty::IndexedVal; +use stable_mir::Error; use std::cell::Cell; use std::cell::RefCell; use std::fmt::Debug; @@ -21,11 +22,11 @@ use std::ops::Index; mod internal; -pub fn stable<'tcx, S: Stable<'tcx>>(item: &S) -> S::T { +pub fn stable<'tcx, S: Stable<'tcx>>(item: S) -> S::T { with_tables(|tables| item.stable(tables)) } -pub fn internal<'tcx, S: RustcInternal<'tcx>>(item: &S) -> S::T { +pub fn internal<'tcx, S: RustcInternal<'tcx>>(item: S) -> S::T { with_tables(|tables| item.internal(tables)) } @@ -144,12 +145,13 @@ pub fn crate_num(item: &stable_mir::Crate) -> CrateNum { // datastructures and stable MIR datastructures scoped_thread_local! (static TLV: Cell<*const ()>); -pub(crate) fn init<'tcx>(tables: &TablesWrapper<'tcx>, f: impl FnOnce()) { +pub(crate) fn init<'tcx, F, T>(tables: &TablesWrapper<'tcx>, f: F) -> T +where + F: FnOnce() -> T, +{ assert!(!TLV.is_set()); let ptr = tables as *const _ as *const (); - TLV.set(&Cell::new(ptr), || { - f(); - }); + TLV.set(&Cell::new(ptr), || f()) } /// Loads the current context and calls a function with it. @@ -165,7 +167,10 @@ pub(crate) fn with_tables<'tcx, R>(f: impl FnOnce(&mut Tables<'tcx>) -> R) -> R }) } -pub fn run(tcx: TyCtxt<'_>, f: impl FnOnce()) { +pub fn run(tcx: TyCtxt<'_>, f: F) -> Result +where + F: FnOnce() -> T, +{ let tables = TablesWrapper(RefCell::new(Tables { tcx, def_ids: IndexMap::default(), @@ -175,7 +180,7 @@ pub fn run(tcx: TyCtxt<'_>, f: impl FnOnce()) { instances: IndexMap::default(), constants: IndexMap::default(), })); - stable_mir::run(&tables, || init(&tables, f)); + stable_mir::run(&tables, || init(&tables, f)) } #[macro_export] @@ -241,7 +246,8 @@ macro_rules! run { queries.global_ctxt().unwrap().enter(|tcx| { rustc_internal::run(tcx, || { self.result = Some((self.callback)(tcx)); - }); + }) + .unwrap(); if self.result.as_ref().is_some_and(|val| val.is_continue()) { Compilation::Continue } else { diff --git a/compiler/stable_mir/src/lib.rs b/compiler/stable_mir/src/lib.rs index f316671b278..63e9d54544b 100644 --- a/compiler/stable_mir/src/lib.rs +++ b/compiler/stable_mir/src/lib.rs @@ -47,7 +47,7 @@ pub type Symbol = String; pub type CrateNum = usize; /// A unique identification number for each item accessible for the current compilation unit. -#[derive(Clone, Copy, PartialEq, Eq)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] pub struct DefId(usize); impl Debug for DefId { @@ -240,12 +240,16 @@ pub trait Context { // datastructures and stable MIR datastructures scoped_thread_local! (static TLV: Cell<*const ()>); -pub fn run(context: &dyn Context, f: impl FnOnce()) { - assert!(!TLV.is_set()); - let ptr: *const () = &context as *const &_ as _; - TLV.set(&Cell::new(ptr), || { - f(); - }); +pub fn run(context: &dyn Context, f: F) -> Result +where + F: FnOnce() -> T, +{ + if TLV.is_set() { + Err(Error::from("StableMIR already running")) + } else { + let ptr: *const () = &context as *const &_ as _; + TLV.set(&Cell::new(ptr), || Ok(f())) + } } /// Loads the current context and calls a function with it. @@ -260,7 +264,7 @@ pub fn with(f: impl FnOnce(&dyn Context) -> R) -> R { } /// A type that provides internal information but that can still be used for debug purpose. -#[derive(Clone, Eq, PartialEq)] +#[derive(Clone, PartialEq, Eq, Hash)] pub struct Opaque(String); impl std::fmt::Display for Opaque {