diff --git a/compiler/rustc_middle/src/mir/terminator.rs b/compiler/rustc_middle/src/mir/terminator.rs index ef1d0c867ea..06e63b0f3d9 100644 --- a/compiler/rustc_middle/src/mir/terminator.rs +++ b/compiler/rustc_middle/src/mir/terminator.rs @@ -149,13 +149,14 @@ impl AssertKind { ResumedAfterReturn(CoroutineKind::Coroutine) => "coroutine resumed after completion", ResumedAfterReturn(CoroutineKind::Async(_)) => "`async fn` resumed after completion", ResumedAfterReturn(CoroutineKind::Gen(_)) => { - bug!("`gen fn` should just keep returning `None` after the first time") + "`gen fn` should just keep returning `None` after completion" } ResumedAfterPanic(CoroutineKind::Coroutine) => "coroutine resumed after panicking", ResumedAfterPanic(CoroutineKind::Async(_)) => "`async fn` resumed after panicking", ResumedAfterPanic(CoroutineKind::Gen(_)) => { - bug!("`gen fn` should just keep returning `None` after panicking") + "`gen fn` should just keep returning `None` after panicking" } + BoundsCheck { .. } | MisalignedPointerDereference { .. } => { bug!("Unexpected AssertKind") } diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs index fa56d59dd80..8fecff16a91 100644 --- a/compiler/rustc_mir_transform/src/coroutine.rs +++ b/compiler/rustc_mir_transform/src/coroutine.rs @@ -224,7 +224,7 @@ struct SuspensionPoint<'tcx> { struct TransformVisitor<'tcx> { tcx: TyCtxt<'tcx>, - is_async_kind: bool, + coroutine_kind: hir::CoroutineKind, state_adt_ref: AdtDef<'tcx>, state_args: GenericArgsRef<'tcx>, @@ -261,31 +261,53 @@ impl<'tcx> TransformVisitor<'tcx> { is_return: bool, statements: &mut Vec>, ) { - let idx = VariantIdx::new(match (is_return, self.is_async_kind) { - (true, false) => 1, // CoroutineState::Complete - (false, false) => 0, // CoroutineState::Yielded - (true, true) => 0, // Poll::Ready - (false, true) => 1, // Poll::Pending + let idx = VariantIdx::new(match (is_return, self.coroutine_kind) { + (true, hir::CoroutineKind::Coroutine) => 1, // CoroutineState::Complete + (false, hir::CoroutineKind::Coroutine) => 0, // CoroutineState::Yielded + (true, hir::CoroutineKind::Async(_)) => 0, // Poll::Ready + (false, hir::CoroutineKind::Async(_)) => 1, // Poll::Pending + (true, hir::CoroutineKind::Gen(_)) => 0, // Option::None + (false, hir::CoroutineKind::Gen(_)) => 1, // Option::Some }); let kind = AggregateKind::Adt(self.state_adt_ref.did(), idx, self.state_args, None, None); - // `Poll::Pending` - if self.is_async_kind && idx == VariantIdx::new(1) { - assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); + match self.coroutine_kind { + // `Poll::Pending` + CoroutineKind::Async(_) => { + if !is_return { + assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); - // FIXME(swatinem): assert that `val` is indeed unit? - statements.push(Statement { - kind: StatementKind::Assign(Box::new(( - Place::return_place(), - Rvalue::Aggregate(Box::new(kind), IndexVec::new()), - ))), - source_info, - }); - return; + // FIXME(swatinem): assert that `val` is indeed unit? + statements.push(Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), IndexVec::new()), + ))), + source_info, + }); + return; + } + } + // `Option::None` + CoroutineKind::Gen(_) => { + if is_return { + assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 0); + + statements.push(Statement { + kind: StatementKind::Assign(Box::new(( + Place::return_place(), + Rvalue::Aggregate(Box::new(kind), IndexVec::new()), + ))), + source_info, + }); + return; + } + } + CoroutineKind::Coroutine => {} } - // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)` or `CoroutineState::Complete(x)` + // else: `Poll::Ready(x)`, `CoroutineState::Yielded(x)`, `CoroutineState::Complete(x)`, or `Option::Some(x)` assert_eq!(self.state_adt_ref.variant(idx).fields.len(), 1); statements.push(Statement { @@ -1439,18 +1461,28 @@ impl<'tcx> MirPass<'tcx> for StateTransform { }; let is_async_kind = matches!(body.coroutine_kind(), Some(CoroutineKind::Async(_))); - let (state_adt_ref, state_args) = if is_async_kind { - // Compute Poll - let poll_did = tcx.require_lang_item(LangItem::Poll, None); - let poll_adt_ref = tcx.adt_def(poll_did); - let poll_args = tcx.mk_args(&[body.return_ty().into()]); - (poll_adt_ref, poll_args) - } else { - // Compute CoroutineState - let state_did = tcx.require_lang_item(LangItem::CoroutineState, None); - let state_adt_ref = tcx.adt_def(state_did); - let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]); - (state_adt_ref, state_args) + let (state_adt_ref, state_args) = match body.coroutine_kind().unwrap() { + CoroutineKind::Async(_) => { + // Compute Poll + let poll_did = tcx.require_lang_item(LangItem::Poll, None); + let poll_adt_ref = tcx.adt_def(poll_did); + let poll_args = tcx.mk_args(&[body.return_ty().into()]); + (poll_adt_ref, poll_args) + } + CoroutineKind::Gen(_) => { + // Compute Option + let option_did = tcx.require_lang_item(LangItem::Option, None); + let option_adt_ref = tcx.adt_def(option_did); + let option_args = tcx.mk_args(&[body.yield_ty().unwrap().into()]); + (option_adt_ref, option_args) + } + CoroutineKind::Coroutine => { + // Compute CoroutineState + let state_did = tcx.require_lang_item(LangItem::CoroutineState, None); + let state_adt_ref = tcx.adt_def(state_did); + let state_args = tcx.mk_args(&[yield_ty.into(), body.return_ty().into()]); + (state_adt_ref, state_args) + } }; let ret_ty = Ty::new_adt(tcx, state_adt_ref, state_args); @@ -1518,7 +1550,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform { // or Poll::Ready(x) and Poll::Pending respectively depending on `is_async_kind`. let mut transform = TransformVisitor { tcx, - is_async_kind, + coroutine_kind: body.coroutine_kind().unwrap(), state_adt_ref, state_args, remap, diff --git a/compiler/rustc_ty_utils/src/instance.rs b/compiler/rustc_ty_utils/src/instance.rs index 0e9d79c15c3..c0fe13ac996 100644 --- a/compiler/rustc_ty_utils/src/instance.rs +++ b/compiler/rustc_ty_utils/src/instance.rs @@ -258,6 +258,19 @@ fn resolve_associated_item<'tcx>( debug_assert!(tcx.defaultness(trait_item_id).has_value()); Some(Instance::new(trait_item_id, rcvr_args)) } + } else if Some(trait_ref.def_id) == lang_items.iterator_trait() { + let ty::Coroutine(coroutine_def_id, args, _) = *rcvr_args.type_at(0).kind() else { + bug!() + }; + if Some(trait_item_id) == tcx.lang_items().next_fn() { + // `Iterator::next` is generated by the compiler. + Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args }) + } else { + // All other methods are default methods of the `Iterator` trait. + // (this assumes that `ImplSource::Builtin` is only used for methods on `Iterator`) + debug_assert!(tcx.defaultness(trait_item_id).has_value()); + Some(Instance::new(trait_item_id, rcvr_args)) + } } else if Some(trait_ref.def_id) == lang_items.gen_trait() { let ty::Coroutine(coroutine_def_id, args, _) = *rcvr_args.type_at(0).kind() else { bug!() diff --git a/tests/ui/coroutine/gen_block_iterate.rs b/tests/ui/coroutine/gen_block_iterate.rs new file mode 100644 index 00000000000..26932723779 --- /dev/null +++ b/tests/ui/coroutine/gen_block_iterate.rs @@ -0,0 +1,18 @@ +// revisions: next old +//compile-flags: --edition 2024 -Zunstable-options +//[next] compile-flags: -Ztrait-solver=next +// run-pass +#![feature(coroutines)] + +fn foo() -> impl Iterator { + gen { yield 42; for x in 3..6 { yield x } } +} + +fn main() { + let mut iter = foo(); + assert_eq!(iter.next(), Some(42)); + assert_eq!(iter.next(), Some(3)); + assert_eq!(iter.next(), Some(4)); + assert_eq!(iter.next(), Some(5)); + assert_eq!(iter.next(), None); +}