Factor out some repetitive code.

This commit is contained in:
Nicholas Nethercote 2024-08-28 13:47:22 +10:00
parent 408481f4d8
commit 590a02173b

View File

@ -63,7 +63,9 @@ use rustc_index::bit_set::{BitMatrix, BitSet, GrowableBitSet};
use rustc_index::{Idx, IndexVec}; use rustc_index::{Idx, IndexVec};
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
use rustc_middle::mir::*; use rustc_middle::mir::*;
use rustc_middle::ty::{self, CoroutineArgs, CoroutineArgsExt, InstanceKind, Ty, TyCtxt}; use rustc_middle::ty::{
self, CoroutineArgs, CoroutineArgsExt, GenericArgsRef, InstanceKind, Ty, TyCtxt,
};
use rustc_middle::{bug, span_bug}; use rustc_middle::{bug, span_bug};
use rustc_mir_dataflow::impls::{ use rustc_mir_dataflow::impls::{
MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive, MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
@ -210,14 +212,10 @@ impl<'tcx> TransformVisitor<'tcx> {
// `gen` continues return `None` // `gen` continues return `None`
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => { CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None); let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
Rvalue::Aggregate( make_aggregate_adt(
Box::new(AggregateKind::Adt( option_def_id,
option_def_id, VariantIdx::ZERO,
VariantIdx::ZERO, self.tcx.mk_args(&[self.old_yield_ty.into()]),
self.tcx.mk_args(&[self.old_yield_ty.into()]),
None,
None,
)),
IndexVec::new(), IndexVec::new(),
) )
} }
@ -266,64 +264,28 @@ impl<'tcx> TransformVisitor<'tcx> {
is_return: bool, is_return: bool,
statements: &mut Vec<Statement<'tcx>>, statements: &mut Vec<Statement<'tcx>>,
) { ) {
const ZERO: VariantIdx = VariantIdx::ZERO;
const ONE: VariantIdx = VariantIdx::from_usize(1);
let rvalue = match self.coroutine_kind { let rvalue = match self.coroutine_kind {
CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => { CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None); let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, None);
let args = self.tcx.mk_args(&[self.old_ret_ty.into()]); let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
if is_return { let (variant_idx, operands) = if is_return {
// Poll::Ready(val) (ZERO, IndexVec::from_raw(vec![val])) // Poll::Ready(val)
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
poll_def_id,
VariantIdx::ZERO,
args,
None,
None,
)),
IndexVec::from_raw(vec![val]),
)
} else { } else {
// Poll::Pending (ONE, IndexVec::new()) // Poll::Pending
Rvalue::Aggregate( };
Box::new(AggregateKind::Adt( make_aggregate_adt(poll_def_id, variant_idx, args, operands)
poll_def_id,
VariantIdx::from_usize(1),
args,
None,
None,
)),
IndexVec::new(),
)
}
} }
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => { CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
let option_def_id = self.tcx.require_lang_item(LangItem::Option, None); let option_def_id = self.tcx.require_lang_item(LangItem::Option, None);
let args = self.tcx.mk_args(&[self.old_yield_ty.into()]); let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
if is_return { let (variant_idx, operands) = if is_return {
// None (ZERO, IndexVec::new()) // None
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
option_def_id,
VariantIdx::ZERO,
args,
None,
None,
)),
IndexVec::new(),
)
} else { } else {
// Some(val) (ONE, IndexVec::from_raw(vec![val])) // Some(val)
Rvalue::Aggregate( };
Box::new(AggregateKind::Adt( make_aggregate_adt(option_def_id, variant_idx, args, operands)
option_def_id,
VariantIdx::from_usize(1),
args,
None,
None,
)),
IndexVec::from_raw(vec![val]),
)
}
} }
CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => { CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
if is_return { if is_return {
@ -349,31 +311,17 @@ impl<'tcx> TransformVisitor<'tcx> {
let coroutine_state_def_id = let coroutine_state_def_id =
self.tcx.require_lang_item(LangItem::CoroutineState, None); self.tcx.require_lang_item(LangItem::CoroutineState, None);
let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]); let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]);
if is_return { let variant_idx = if is_return {
// CoroutineState::Complete(val) ONE // CoroutineState::Complete(val)
Rvalue::Aggregate(
Box::new(AggregateKind::Adt(
coroutine_state_def_id,
VariantIdx::from_usize(1),
args,
None,
None,
)),
IndexVec::from_raw(vec![val]),
)
} else { } else {
// CoroutineState::Yielded(val) ZERO // CoroutineState::Yielded(val)
Rvalue::Aggregate( };
Box::new(AggregateKind::Adt( make_aggregate_adt(
coroutine_state_def_id, coroutine_state_def_id,
VariantIdx::ZERO, variant_idx,
args, args,
None, IndexVec::from_raw(vec![val]),
None, )
)),
IndexVec::from_raw(vec![val]),
)
}
} }
}; };
@ -509,6 +457,15 @@ impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
} }
} }
fn make_aggregate_adt<'tcx>(
def_id: DefId,
variant_idx: VariantIdx,
args: GenericArgsRef<'tcx>,
operands: IndexVec<FieldIdx, Operand<'tcx>>,
) -> Rvalue<'tcx> {
Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands)
}
fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let coroutine_ty = body.local_decls.raw[1].ty; let coroutine_ty = body.local_decls.raw[1].ty;