Build a shim to call async closures with different AsyncFn trait kinds

This commit is contained in:
Michael Goulet 2024-01-25 00:30:55 +00:00
parent a82bae2172
commit fc4fff4038
13 changed files with 175 additions and 11 deletions

View File

@ -545,6 +545,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
ty::InstanceDef::VTableShim(..)
| ty::InstanceDef::ReifyShim(..)
| ty::InstanceDef::ClosureOnceShim { .. }
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
| ty::InstanceDef::FnPtrShim(..)
| ty::InstanceDef::DropGlue(..)
| ty::InstanceDef::CloneShim(..)

View File

@ -402,6 +402,7 @@ impl<'tcx> CodegenUnit<'tcx> {
| InstanceDef::FnPtrShim(..)
| InstanceDef::Virtual(..)
| InstanceDef::ClosureOnceShim { .. }
| InstanceDef::ConstructCoroutineInClosureShim { .. }
| InstanceDef::DropGlue(..)
| InstanceDef::CloneShim(..)
| InstanceDef::ThreadLocalShim(..)

View File

@ -345,6 +345,7 @@ macro_rules! make_mir_visitor {
ty::InstanceDef::Virtual(_def_id, _) |
ty::InstanceDef::ThreadLocalShim(_def_id) |
ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } |
ty::InstanceDef::DropGlue(_def_id, None) => {}
ty::InstanceDef::FnPtrShim(_def_id, ty) |

View File

@ -82,11 +82,25 @@ pub enum InstanceDef<'tcx> {
/// details on that).
Virtual(DefId, usize),
/// `<[FnMut closure] as FnOnce>::call_once`.
/// `<[FnMut/Fn closure] as FnOnce>::call_once`.
///
/// The `DefId` is the ID of the `call_once` method in `FnOnce`.
///
/// This generates a body that will just borrow the (owned) self type,
/// and dispatch to the `FnMut::call_mut` instance for the closure.
ClosureOnceShim { call_once: DefId, track_caller: bool },
/// `<[FnMut/Fn coroutine-closure] as FnOnce>::call_once` or
/// `<[Fn coroutine-closure] as FnMut>::call_mut`.
///
/// The body generated here differs significantly from the `ClosureOnceShim`,
/// since we need to generate a distinct coroutine type that will move the
/// closure's upvars *out* of the closure.
ConstructCoroutineInClosureShim {
coroutine_closure_def_id: DefId,
target_kind: ty::ClosureKind,
},
/// Compiler-generated accessor for thread locals which returns a reference to the thread local
/// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking
/// native support.
@ -168,6 +182,10 @@ impl<'tcx> InstanceDef<'tcx> {
| InstanceDef::Intrinsic(def_id)
| InstanceDef::ThreadLocalShim(def_id)
| InstanceDef::ClosureOnceShim { call_once: def_id, track_caller: _ }
| ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id: def_id,
target_kind: _,
}
| InstanceDef::DropGlue(def_id, _)
| InstanceDef::CloneShim(def_id, _)
| InstanceDef::FnPtrAddrShim(def_id, _) => def_id,
@ -187,6 +205,7 @@ impl<'tcx> InstanceDef<'tcx> {
| InstanceDef::Virtual(..)
| InstanceDef::Intrinsic(..)
| InstanceDef::ClosureOnceShim { .. }
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
| InstanceDef::DropGlue(..)
| InstanceDef::CloneShim(..)
| InstanceDef::FnPtrAddrShim(..) => None,
@ -282,6 +301,7 @@ impl<'tcx> InstanceDef<'tcx> {
| InstanceDef::FnPtrShim(..)
| InstanceDef::DropGlue(_, Some(_)) => false,
InstanceDef::ClosureOnceShim { .. }
| InstanceDef::ConstructCoroutineInClosureShim { .. }
| InstanceDef::DropGlue(..)
| InstanceDef::Item(_)
| InstanceDef::Intrinsic(..)
@ -319,6 +339,7 @@ fn fmt_instance(
InstanceDef::Virtual(_, num) => write!(f, " - virtual#{num}"),
InstanceDef::FnPtrShim(_, ty) => write!(f, " - shim({ty})"),
InstanceDef::ClosureOnceShim { .. } => write!(f, " - shim"),
InstanceDef::ConstructCoroutineInClosureShim { .. } => write!(f, " - shim"),
InstanceDef::DropGlue(_, None) => write!(f, " - shim(None)"),
InstanceDef::DropGlue(_, Some(ty)) => write!(f, " - shim(Some({ty}))"),
InstanceDef::CloneShim(_, ty) => write!(f, " - shim({ty})"),

View File

@ -1680,6 +1680,7 @@ impl<'tcx> TyCtxt<'tcx> {
| ty::InstanceDef::FnPtrShim(..)
| ty::InstanceDef::Virtual(..)
| ty::InstanceDef::ClosureOnceShim { .. }
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
| ty::InstanceDef::DropGlue(..)
| ty::InstanceDef::CloneShim(..)
| ty::InstanceDef::ThreadLocalShim(..)

View File

@ -317,6 +317,7 @@ impl<'tcx> Inliner<'tcx> {
| InstanceDef::ReifyShim(_)
| InstanceDef::FnPtrShim(..)
| InstanceDef::ClosureOnceShim { .. }
| InstanceDef::ConstructCoroutineInClosureShim { .. }
| InstanceDef::DropGlue(..)
| InstanceDef::CloneShim(..)
| InstanceDef::ThreadLocalShim(..)

View File

@ -87,6 +87,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>(
| InstanceDef::ReifyShim(_)
| InstanceDef::FnPtrShim(..)
| InstanceDef::ClosureOnceShim { .. }
| InstanceDef::ConstructCoroutineInClosureShim { .. }
| InstanceDef::ThreadLocalShim { .. }
| InstanceDef::CloneShim(..) => {}

View File

@ -3,8 +3,8 @@ use rustc_hir::def_id::DefId;
use rustc_hir::lang_items::LangItem;
use rustc_middle::mir::*;
use rustc_middle::query::Providers;
use rustc_middle::ty::GenericArgs;
use rustc_middle::ty::{self, CoroutineArgs, EarlyBinder, Ty, TyCtxt};
use rustc_middle::ty::{GenericArgs, CAPTURE_STRUCT_LOCAL};
use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};
use rustc_index::{Idx, IndexVec};
@ -66,6 +66,21 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut))
}
ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
target_kind,
} => match target_kind {
ty::ClosureKind::Fn => unreachable!("shouldn't be building shim for Fn"),
ty::ClosureKind::FnMut => {
let body = build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id);
// No need to optimize the body, it has already been optimized.
return body;
}
ty::ClosureKind::FnOnce => {
build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
}
},
ty::InstanceDef::DropGlue(def_id, ty) => {
// FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
// of this function. Is this intentional?
@ -981,3 +996,107 @@ fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'t
let source = MirSource::from_instance(ty::InstanceDef::FnPtrAddrShim(def_id, self_ty));
new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span)
}
fn build_construct_coroutine_by_move_shim<'tcx>(
tcx: TyCtxt<'tcx>,
coroutine_closure_def_id: DefId,
) -> Body<'tcx> {
let self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
let ty::CoroutineClosure(_, args) = *self_ty.kind() else {
bug!();
};
let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
tcx.mk_fn_sig(
[self_ty].into_iter().chain(sig.tupled_inputs_ty.tuple_fields()),
sig.to_coroutine_given_kind_and_upvars(
tcx,
args.as_coroutine_closure().parent_args(),
tcx.coroutine_for_closure(coroutine_closure_def_id),
ty::ClosureKind::FnOnce,
tcx.lifetimes.re_erased,
args.as_coroutine_closure().tupled_upvars_ty(),
args.as_coroutine_closure().coroutine_captures_by_ref_ty(),
),
sig.c_variadic,
sig.unsafety,
sig.abi,
)
});
let sig = tcx.liberate_late_bound_regions(coroutine_closure_def_id, poly_sig);
let ty::Coroutine(coroutine_def_id, coroutine_args) = *sig.output().kind() else {
bug!();
};
let span = tcx.def_span(coroutine_closure_def_id);
let locals = local_decls_for_sig(&sig, span);
let mut fields = vec![];
for idx in 1..sig.inputs().len() {
fields.push(Operand::Move(Local::from_usize(idx + 1).into()));
}
for (idx, ty) in args.as_coroutine_closure().upvar_tys().iter().enumerate() {
fields.push(Operand::Move(tcx.mk_place_field(
Local::from_usize(1).into(),
FieldIdx::from_usize(idx),
ty,
)));
}
let source_info = SourceInfo::outermost(span);
let rvalue = Rvalue::Aggregate(
Box::new(AggregateKind::Coroutine(coroutine_def_id, coroutine_args)),
IndexVec::from_raw(fields),
);
let stmt = Statement {
source_info,
kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
};
let statements = vec![stmt];
let start_block = BasicBlockData {
statements,
terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }),
is_cleanup: false,
};
let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
target_kind: ty::ClosureKind::FnOnce,
});
new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span)
}
fn build_construct_coroutine_by_mut_shim<'tcx>(
tcx: TyCtxt<'tcx>,
coroutine_closure_def_id: DefId,
) -> Body<'tcx> {
let mut body = tcx.optimized_mir(coroutine_closure_def_id).clone();
let coroutine_closure_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
let ty::CoroutineClosure(_, args) = *coroutine_closure_ty.kind() else {
bug!();
};
let args = args.as_coroutine_closure();
body.local_decls[RETURN_PLACE].ty =
tcx.instantiate_bound_regions_with_erased(args.coroutine_closure_sig().map_bound(|sig| {
sig.to_coroutine_given_kind_and_upvars(
tcx,
args.parent_args(),
tcx.coroutine_for_closure(coroutine_closure_def_id),
ty::ClosureKind::FnMut,
tcx.lifetimes.re_erased,
args.tupled_upvars_ty(),
args.coroutine_captures_by_ref_ty(),
)
}));
body.local_decls[CAPTURE_STRUCT_LOCAL].ty =
Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_closure_ty);
body.source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
target_kind: ty::ClosureKind::FnMut,
});
body
}

View File

@ -983,6 +983,7 @@ fn visit_instance_use<'tcx>(
| ty::InstanceDef::VTableShim(..)
| ty::InstanceDef::ReifyShim(..)
| ty::InstanceDef::ClosureOnceShim { .. }
| InstanceDef::ConstructCoroutineInClosureShim { .. }
| ty::InstanceDef::Item(..)
| ty::InstanceDef::FnPtrShim(..)
| ty::InstanceDef::CloneShim(..)

View File

@ -620,6 +620,7 @@ fn characteristic_def_id_of_mono_item<'tcx>(
| ty::InstanceDef::ReifyShim(..)
| ty::InstanceDef::FnPtrShim(..)
| ty::InstanceDef::ClosureOnceShim { .. }
| InstanceDef::ConstructCoroutineInClosureShim { .. }
| ty::InstanceDef::Intrinsic(..)
| ty::InstanceDef::DropGlue(..)
| ty::InstanceDef::Virtual(..)
@ -783,6 +784,7 @@ fn mono_item_visibility<'tcx>(
| InstanceDef::Virtual(..)
| InstanceDef::Intrinsic(..)
| InstanceDef::ClosureOnceShim { .. }
| InstanceDef::ConstructCoroutineInClosureShim { .. }
| InstanceDef::DropGlue(..)
| InstanceDef::CloneShim(..)
| InstanceDef::FnPtrAddrShim(..) => return Visibility::Hidden,

View File

@ -799,6 +799,7 @@ impl<'tcx> Stable<'tcx> for ty::Instance<'tcx> {
| ty::InstanceDef::ReifyShim(..)
| ty::InstanceDef::FnPtrAddrShim(..)
| ty::InstanceDef::ClosureOnceShim { .. }
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
| ty::InstanceDef::ThreadLocalShim(..)
| ty::InstanceDef::DropGlue(..)
| ty::InstanceDef::CloneShim(..)

View File

@ -111,11 +111,14 @@ fn fn_sig_for_fn_abi<'tcx>(
kind: ty::BoundRegionKind::BrEnv,
};
let env_region = ty::Region::new_bound(tcx, ty::INNERMOST, br);
let env_ty = tcx.closure_env_ty(
Ty::new_coroutine_closure(tcx, def_id, args),
args.as_coroutine_closure().kind(),
env_region,
);
let mut kind = args.as_coroutine_closure().kind();
if let InstanceDef::ConstructCoroutineInClosureShim { target_kind, .. } = instance.def {
kind = target_kind;
}
let env_ty =
tcx.closure_env_ty(Ty::new_coroutine_closure(tcx, def_id, args), kind, env_region);
let sig = sig.skip_binder();
ty::Binder::bind_with_vars(
@ -125,7 +128,7 @@ fn fn_sig_for_fn_abi<'tcx>(
tcx,
args.as_coroutine_closure().parent_args(),
tcx.coroutine_for_closure(def_id),
args.as_coroutine_closure().kind(),
kind,
env_region,
args.as_coroutine_closure().tupled_upvars_ty(),
args.as_coroutine_closure().coroutine_captures_by_ref_ty(),

View File

@ -283,10 +283,21 @@ fn resolve_associated_item<'tcx>(
tcx.item_name(trait_item_id)
),
}
} else if tcx.async_fn_trait_kind_from_def_id(trait_ref.def_id).is_some() {
} else if let Some(target_kind) = tcx.async_fn_trait_kind_from_def_id(trait_ref.def_id)
{
match *rcvr_args.type_at(0).kind() {
ty::CoroutineClosure(closure_def_id, args) => {
Some(Instance::new(closure_def_id, args))
ty::CoroutineClosure(coroutine_closure_def_id, args) => {
if target_kind > args.as_coroutine_closure().kind() {
Some(Instance {
def: ty::InstanceDef::ConstructCoroutineInClosureShim {
coroutine_closure_def_id,
target_kind,
},
args,
})
} else {
Some(Instance::new(coroutine_closure_def_id, args))
}
}
_ => bug!(
"no built-in definition for `{trait_ref}::{}` for non-lending-closure type",