From 11375c86571ce58646e84cf47df884a3bc2a9934 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Thu, 7 Dec 2023 17:39:02 +0000
Subject: [PATCH] Add tests

---
 compiler/rustc_ast_lowering/src/lib.rs       | 25 +++--
 tests/ui/coroutine/async_gen_fn.e2024.stderr | 12 +++
 tests/ui/coroutine/async_gen_fn.none.stderr  | 18 ++++
 tests/ui/coroutine/async_gen_fn.rs           | 16 ++--
 tests/ui/coroutine/async_gen_fn_iter.rs      | 96 ++++++++++++++++++++
 5 files changed, 148 insertions(+), 19 deletions(-)
 create mode 100644 tests/ui/coroutine/async_gen_fn.e2024.stderr
 create mode 100644 tests/ui/coroutine/async_gen_fn.none.stderr
 create mode 100644 tests/ui/coroutine/async_gen_fn_iter.rs

diff --git a/compiler/rustc_ast_lowering/src/lib.rs b/compiler/rustc_ast_lowering/src/lib.rs
index c48024a3f4a..753650f7324 100644
--- a/compiler/rustc_ast_lowering/src/lib.rs
+++ b/compiler/rustc_ast_lowering/src/lib.rs
@@ -132,6 +132,7 @@ struct LoweringContext<'a, 'hir> {
 
     allow_try_trait: Lrc<[Symbol]>,
     allow_gen_future: Lrc<[Symbol]>,
+    allow_async_iterator: Lrc<[Symbol]>,
 
     /// Mapping from generics `def_id`s to TAIT generics `def_id`s.
     /// For each captured lifetime (e.g., 'a), we create a new lifetime parameter that is a generic
@@ -176,6 +177,8 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
             } else {
                 [sym::gen_future].into()
             },
+            // FIXME(gen_blocks): how does `closure_track_caller`
+            allow_async_iterator: [sym::gen_future, sym::async_iterator].into(),
             generics_def_id_map: Default::default(),
             host_param_id: None,
         }
@@ -1900,14 +1903,18 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
         fn_span: Span,
     ) -> hir::FnRetTy<'hir> {
         let span = self.lower_span(fn_span);
-        let opaque_ty_span = self.mark_span_with_reason(DesugaringKind::Async, span, None);
 
-        let opaque_ty_node_id = match coro {
-            CoroutineKind::Async { return_impl_trait_id, .. }
-            | CoroutineKind::Gen { return_impl_trait_id, .. }
-            | CoroutineKind::AsyncGen { return_impl_trait_id, .. } => return_impl_trait_id,
+        let (opaque_ty_node_id, allowed_features) = match coro {
+            CoroutineKind::Async { return_impl_trait_id, .. } => (return_impl_trait_id, None),
+            CoroutineKind::Gen { return_impl_trait_id, .. } => (return_impl_trait_id, None),
+            CoroutineKind::AsyncGen { return_impl_trait_id, .. } => {
+                (return_impl_trait_id, Some(self.allow_async_iterator.clone()))
+            }
         };
 
+        let opaque_ty_span =
+            self.mark_span_with_reason(DesugaringKind::Async, span, allowed_features);
+
         let captured_lifetimes: Vec<_> = self
             .resolver
             .take_extra_lifetime_params(opaque_ty_node_id)
@@ -1926,7 +1933,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
                 let bound = this.lower_coroutine_fn_output_type_to_bound(
                     output,
                     coro,
-                    span,
+                    opaque_ty_span,
                     ImplTraitContext::ReturnPositionOpaqueTy {
                         origin: hir::OpaqueTyOrigin::FnReturn(fn_def_id),
                         fn_kind,
@@ -1945,7 +1952,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
         &mut self,
         output: &FnRetTy,
         coro: CoroutineKind,
-        span: Span,
+        opaque_ty_span: Span,
         nested_impl_trait_context: ImplTraitContext,
     ) -> hir::GenericBound<'hir> {
         // Compute the `T` in `Future<Output = T>` from the return type.
@@ -1968,14 +1975,14 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
 
         let future_args = self.arena.alloc(hir::GenericArgs {
             args: &[],
-            bindings: arena_vec![self; self.assoc_ty_binding(assoc_ty_name, span, output_ty)],
+            bindings: arena_vec![self; self.assoc_ty_binding(assoc_ty_name, opaque_ty_span, output_ty)],
             parenthesized: hir::GenericArgsParentheses::No,
             span_ext: DUMMY_SP,
         });
 
         hir::GenericBound::LangItemTrait(
             trait_lang_item,
-            self.lower_span(span),
+            opaque_ty_span,
             self.next_id(),
             future_args,
         )
diff --git a/tests/ui/coroutine/async_gen_fn.e2024.stderr b/tests/ui/coroutine/async_gen_fn.e2024.stderr
new file mode 100644
index 00000000000..d24cdbbc30d
--- /dev/null
+++ b/tests/ui/coroutine/async_gen_fn.e2024.stderr
@@ -0,0 +1,12 @@
+error[E0658]: gen blocks are experimental
+  --> $DIR/async_gen_fn.rs:4:1
+   |
+LL | async gen fn foo() {}
+   | ^^^^^^^^^
+   |
+   = note: see issue #117078 <https://github.com/rust-lang/rust/issues/117078> for more information
+   = help: add `#![feature(gen_blocks)]` to the crate attributes to enable
+
+error: aborting due to 1 previous error
+
+For more information about this error, try `rustc --explain E0658`.
diff --git a/tests/ui/coroutine/async_gen_fn.none.stderr b/tests/ui/coroutine/async_gen_fn.none.stderr
new file mode 100644
index 00000000000..7950251a75d
--- /dev/null
+++ b/tests/ui/coroutine/async_gen_fn.none.stderr
@@ -0,0 +1,18 @@
+error[E0670]: `async fn` is not permitted in Rust 2015
+  --> $DIR/async_gen_fn.rs:4:1
+   |
+LL | async gen fn foo() {}
+   | ^^^^^ to use `async fn`, switch to Rust 2018 or later
+   |
+   = help: pass `--edition 2021` to `rustc`
+   = note: for more on editions, read https://doc.rust-lang.org/edition-guide
+
+error: expected one of `extern`, `fn`, or `unsafe`, found `gen`
+  --> $DIR/async_gen_fn.rs:4:7
+   |
+LL | async gen fn foo() {}
+   |       ^^^ expected one of `extern`, `fn`, or `unsafe`
+
+error: aborting due to 2 previous errors
+
+For more information about this error, try `rustc --explain E0670`.
diff --git a/tests/ui/coroutine/async_gen_fn.rs b/tests/ui/coroutine/async_gen_fn.rs
index f51fef43504..20564106f99 100644
--- a/tests/ui/coroutine/async_gen_fn.rs
+++ b/tests/ui/coroutine/async_gen_fn.rs
@@ -1,13 +1,9 @@
-// edition: 2024
-// compile-flags: -Zunstable-options
-// check-pass
+// revisions: e2024 none
+//[e2024] compile-flags: --edition 2024 -Zunstable-options
 
-#![feature(gen_blocks, async_iterator)]
-
-async fn bar() {}
-
-async gen fn foo() {
-    yield bar().await;
-}
+async gen fn foo() {}
+//[none]~^ ERROR: `async fn` is not permitted in Rust 2015
+//[none]~| ERROR: expected one of `extern`, `fn`, or `unsafe`, found `gen`
+//[e2024]~^^^ ERROR: gen blocks are experimental
 
 fn main() {}
diff --git a/tests/ui/coroutine/async_gen_fn_iter.rs b/tests/ui/coroutine/async_gen_fn_iter.rs
new file mode 100644
index 00000000000..6f8f3feb87e
--- /dev/null
+++ b/tests/ui/coroutine/async_gen_fn_iter.rs
@@ -0,0 +1,96 @@
+// edition: 2024
+// compile-flags: -Zunstable-options
+// run-pass
+
+#![feature(gen_blocks, async_iterator)]
+
+// make sure that a ridiculously simple async gen fn works as an iterator.
+
+async fn pause() {
+    // this doesn't actually do anything, lol
+}
+
+async fn one() -> i32 {
+    1
+}
+
+async fn two() -> i32 {
+    2
+}
+
+async gen fn foo() -> i32 {
+    yield one().await;
+    pause().await;
+    yield two().await;
+    pause().await;
+    yield 3;
+    pause().await;
+}
+
+async fn async_main() {
+    let mut iter = std::pin::pin!(foo());
+    assert_eq!(iter.next().await, Some(1));
+    assert_eq!(iter.as_mut().next().await, Some(2));
+    assert_eq!(iter.as_mut().next().await, Some(3));
+    assert_eq!(iter.as_mut().next().await, None);
+}
+
+// ------------------------------------------------------------------------- //
+// Implementation Details Below...
+
+use std::pin::Pin;
+use std::task::*;
+use std::async_iter::AsyncIterator;
+use std::future::Future;
+
+trait AsyncIterExt {
+    fn next(&mut self) -> Next<'_, Self>;
+}
+
+impl<T> AsyncIterExt for T {
+    fn next(&mut self) -> Next<'_, Self> {
+        Next { s: self }
+    }
+}
+
+struct Next<'s, S: ?Sized> {
+    s: &'s mut S,
+}
+
+impl<'s, S: AsyncIterator> Future for Next<'s, S> where S: Unpin {
+    type Output = Option<S::Item>;
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<S::Item>> {
+        Pin::new(&mut *self.s).poll_next(cx)
+    }
+}
+
+pub fn noop_waker() -> Waker {
+    let raw = RawWaker::new(std::ptr::null(), &NOOP_WAKER_VTABLE);
+
+    // SAFETY: the contracts for RawWaker and RawWakerVTable are upheld
+    unsafe { Waker::from_raw(raw) }
+}
+
+const NOOP_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
+
+unsafe fn noop_clone(_p: *const ()) -> RawWaker {
+    RawWaker::new(std::ptr::null(), &NOOP_WAKER_VTABLE)
+}
+
+unsafe fn noop(_p: *const ()) {}
+
+fn main() {
+    let mut fut = async_main();
+
+    // Poll loop, just to test the future...
+    let waker = noop_waker();
+    let ctx = &mut Context::from_waker(&waker);
+
+    loop {
+        match unsafe { Pin::new_unchecked(&mut fut).poll(ctx) } {
+            Poll::Pending => {}
+            Poll::Ready(()) => break,
+        }
+    }
+}