From 2480a0f3f69fbca417990d68d816ed5423ad2cd6 Mon Sep 17 00:00:00 2001 From: Arpad Borsos Date: Thu, 28 Dec 2023 13:57:43 +0100 Subject: [PATCH] Merge Coroutine lowering functions Instead of having separate `make_async/etc_expr` functions, this merges them them into one, reducing code duplication a bit. --- compiler/rustc_ast_lowering/src/expr.rs | 274 +++++++----------------- compiler/rustc_ast_lowering/src/item.rs | 39 ++-- 2 files changed, 85 insertions(+), 228 deletions(-) diff --git a/compiler/rustc_ast_lowering/src/expr.rs b/compiler/rustc_ast_lowering/src/expr.rs index ccc6644923a..e568da9bbc0 100644 --- a/compiler/rustc_ast_lowering/src/expr.rs +++ b/compiler/rustc_ast_lowering/src/expr.rs @@ -183,14 +183,6 @@ impl<'hir> LoweringContext<'_, 'hir> { self.arena.alloc_from_iter(arms.iter().map(|x| self.lower_arm(x))), hir::MatchSource::Normal, ), - ExprKind::Gen(capture_clause, block, GenBlockKind::Async) => self.make_async_expr( - *capture_clause, - e.id, - None, - e.span, - hir::CoroutineSource::Block, - |this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)), - ), ExprKind::Await(expr, await_kw_span) => self.lower_expr_await(*await_kw_span, expr), ExprKind::Closure(box Closure { binder, @@ -226,6 +218,22 @@ impl<'hir> LoweringContext<'_, 'hir> { *fn_arg_span, ), }, + ExprKind::Gen(capture_clause, block, genblock_kind) => { + let desugaring_kind = match genblock_kind { + GenBlockKind::Async => hir::CoroutineDesugaring::Async, + GenBlockKind::Gen => hir::CoroutineDesugaring::Gen, + GenBlockKind::AsyncGen => hir::CoroutineDesugaring::AsyncGen, + }; + self.make_desugared_coroutine_expr( + *capture_clause, + e.id, + None, + e.span, + desugaring_kind, + hir::CoroutineSource::Block, + |this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)), + ) + } ExprKind::Block(blk, opt_label) => { let opt_label = self.lower_label(*opt_label); hir::ExprKind::Block(self.lower_block(blk, opt_label.is_some()), opt_label) @@ -313,23 +321,6 @@ impl<'hir> LoweringContext<'_, 'hir> { rest, ) } - ExprKind::Gen(capture_clause, block, GenBlockKind::Gen) => self.make_gen_expr( - *capture_clause, - e.id, - None, - e.span, - hir::CoroutineSource::Block, - |this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)), - ), - ExprKind::Gen(capture_clause, block, GenBlockKind::AsyncGen) => self - .make_async_gen_expr( - *capture_clause, - e.id, - None, - e.span, - hir::CoroutineSource::Block, - |this| this.with_new_scopes(e.span, |this| this.lower_block_expr(block)), - ), ExprKind::Yield(opt_expr) => self.lower_expr_yield(e.span, opt_expr.as_deref()), ExprKind::Err => { hir::ExprKind::Err(self.dcx().span_delayed_bug(e.span, "lowered ExprKind::Err")) @@ -612,213 +603,91 @@ impl<'hir> LoweringContext<'_, 'hir> { hir::Arm { hir_id, pat, guard, body, span } } - /// Lower an `async` construct to a coroutine that implements `Future`. + /// Lower/desugar a coroutine construct. + /// + /// In particular, this creates the correct async resume argument and `_task_context`. /// /// This results in: /// /// ```text - /// static move? |_task_context| -> { + /// static move? |<_task_context?>| -> { /// /// } /// ``` - pub(super) fn make_async_expr( + pub(super) fn make_desugared_coroutine_expr( &mut self, capture_clause: CaptureBy, closure_node_id: NodeId, - ret_ty: Option>, - span: Span, - async_coroutine_source: hir::CoroutineSource, - body: impl FnOnce(&mut Self) -> hir::Expr<'hir>, - ) -> hir::ExprKind<'hir> { - let output = ret_ty.unwrap_or_else(|| hir::FnRetTy::DefaultReturn(self.lower_span(span))); - - // Resume argument type: `ResumeTy` - let unstable_span = self.mark_span_with_reason( - DesugaringKind::Async, - self.lower_span(span), - Some(self.allow_gen_future.clone()), - ); - let resume_ty = self.make_lang_item_qpath(hir::LangItem::ResumeTy, unstable_span); - let input_ty = hir::Ty { - hir_id: self.next_id(), - kind: hir::TyKind::Path(resume_ty), - span: unstable_span, - }; - - // The closure/coroutine `FnDecl` takes a single (resume) argument of type `input_ty`. - let fn_decl = self.arena.alloc(hir::FnDecl { - inputs: arena_vec![self; input_ty], - output, - c_variadic: false, - implicit_self: hir::ImplicitSelfKind::None, - lifetime_elision_allowed: false, - }); - - // Lower the argument pattern/ident. The ident is used again in the `.await` lowering. - let (pat, task_context_hid) = self.pat_ident_binding_mode( - span, - Ident::with_dummy_span(sym::_task_context), - hir::BindingAnnotation::MUT, - ); - let param = hir::Param { - hir_id: self.next_id(), - pat, - ty_span: self.lower_span(span), - span: self.lower_span(span), - }; - let params = arena_vec![self; param]; - - let coroutine_kind = - hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Async, async_coroutine_source); - let body = self.lower_body(move |this| { - this.coroutine_kind = Some(coroutine_kind); - - let old_ctx = this.task_context; - this.task_context = Some(task_context_hid); - let res = body(this); - this.task_context = old_ctx; - (params, res) - }); - - // `static |_task_context| -> { body }`: - hir::ExprKind::Closure(self.arena.alloc(hir::Closure { - def_id: self.local_def_id(closure_node_id), - binder: hir::ClosureBinder::Default, - capture_clause, - bound_generic_params: &[], - fn_decl, - body, - fn_decl_span: self.lower_span(span), - fn_arg_span: None, - kind: hir::ClosureKind::Coroutine(coroutine_kind), - constness: hir::Constness::NotConst, - })) - } - - /// Lower a `gen` construct to a generator that implements `Iterator`. - /// - /// This results in: - /// - /// ```text - /// static move? |()| -> () { - /// - /// } - /// ``` - pub(super) fn make_gen_expr( - &mut self, - capture_clause: CaptureBy, - closure_node_id: NodeId, - _yield_ty: Option>, + return_ty: Option>, span: Span, + desugaring_kind: hir::CoroutineDesugaring, coroutine_source: hir::CoroutineSource, body: impl FnOnce(&mut Self) -> hir::Expr<'hir>, ) -> hir::ExprKind<'hir> { - let output = hir::FnRetTy::DefaultReturn(self.lower_span(span)); + let coroutine_kind = hir::CoroutineKind::Desugared(desugaring_kind, coroutine_source); + + // The `async` desugaring takes a resume argument and maintains a `task_context`, + // whereas a generator does not. + let (inputs, params, task_context): (&[_], &[_], _) = match desugaring_kind { + hir::CoroutineDesugaring::Async | hir::CoroutineDesugaring::AsyncGen => { + // Resume argument type: `ResumeTy` + let unstable_span = self.mark_span_with_reason( + DesugaringKind::Async, + self.lower_span(span), + Some(self.allow_gen_future.clone()), + ); + let resume_ty = self.make_lang_item_qpath(hir::LangItem::ResumeTy, unstable_span); + let input_ty = hir::Ty { + hir_id: self.next_id(), + kind: hir::TyKind::Path(resume_ty), + span: unstable_span, + }; + let inputs = arena_vec![self; input_ty]; + + // Lower the argument pattern/ident. The ident is used again in the `.await` lowering. + let (pat, task_context_hid) = self.pat_ident_binding_mode( + span, + Ident::with_dummy_span(sym::_task_context), + hir::BindingAnnotation::MUT, + ); + let param = hir::Param { + hir_id: self.next_id(), + pat, + ty_span: self.lower_span(span), + span: self.lower_span(span), + }; + let params = arena_vec![self; param]; + + (inputs, params, Some(task_context_hid)) + } + hir::CoroutineDesugaring::Gen => (&[], &[], None), + }; + + let output = + return_ty.unwrap_or_else(|| hir::FnRetTy::DefaultReturn(self.lower_span(span))); - // The closure/generator `FnDecl` takes a single (resume) argument of type `input_ty`. let fn_decl = self.arena.alloc(hir::FnDecl { - inputs: &[], + inputs, output, c_variadic: false, implicit_self: hir::ImplicitSelfKind::None, lifetime_elision_allowed: false, }); - let coroutine_kind = - hir::CoroutineKind::Desugared(hir::CoroutineDesugaring::Gen, coroutine_source); - let body = self.lower_body(move |this| { - this.coroutine_kind = Some(coroutine_kind); - - let res = body(this); - (&[], res) - }); - - // `static |()| -> () { body }`: - hir::ExprKind::Closure(self.arena.alloc(hir::Closure { - def_id: self.local_def_id(closure_node_id), - binder: hir::ClosureBinder::Default, - capture_clause, - bound_generic_params: &[], - fn_decl, - body, - fn_decl_span: self.lower_span(span), - fn_arg_span: None, - kind: hir::ClosureKind::Coroutine(coroutine_kind), - constness: hir::Constness::NotConst, - })) - } - - /// Lower a `async gen` construct to a generator that implements `AsyncIterator`. - /// - /// This results in: - /// - /// ```text - /// static move? |_task_context| -> () { - /// - /// } - /// ``` - pub(super) fn make_async_gen_expr( - &mut self, - capture_clause: CaptureBy, - closure_node_id: NodeId, - _yield_ty: Option>, - span: Span, - async_coroutine_source: hir::CoroutineSource, - body: impl FnOnce(&mut Self) -> hir::Expr<'hir>, - ) -> hir::ExprKind<'hir> { - let output = hir::FnRetTy::DefaultReturn(self.lower_span(span)); - - // Resume argument type: `ResumeTy` - let unstable_span = self.mark_span_with_reason( - DesugaringKind::Async, - self.lower_span(span), - Some(self.allow_gen_future.clone()), - ); - let resume_ty = self.make_lang_item_qpath(hir::LangItem::ResumeTy, unstable_span); - let input_ty = hir::Ty { - hir_id: self.next_id(), - kind: hir::TyKind::Path(resume_ty), - span: unstable_span, - }; - - // The closure/coroutine `FnDecl` takes a single (resume) argument of type `input_ty`. - let fn_decl = self.arena.alloc(hir::FnDecl { - inputs: arena_vec![self; input_ty], - output, - c_variadic: false, - implicit_self: hir::ImplicitSelfKind::None, - lifetime_elision_allowed: false, - }); - - // Lower the argument pattern/ident. The ident is used again in the `.await` lowering. - let (pat, task_context_hid) = self.pat_ident_binding_mode( - span, - Ident::with_dummy_span(sym::_task_context), - hir::BindingAnnotation::MUT, - ); - let param = hir::Param { - hir_id: self.next_id(), - pat, - ty_span: self.lower_span(span), - span: self.lower_span(span), - }; - let params = arena_vec![self; param]; - - let coroutine_kind = hir::CoroutineKind::Desugared( - hir::CoroutineDesugaring::AsyncGen, - async_coroutine_source, - ); let body = self.lower_body(move |this| { this.coroutine_kind = Some(coroutine_kind); let old_ctx = this.task_context; - this.task_context = Some(task_context_hid); + if task_context.is_some() { + this.task_context = task_context; + } let res = body(this); this.task_context = old_ctx; + (params, res) }); - // `static |_task_context| -> { body }`: + // `static |<_task_context?>| -> { }`: hir::ExprKind::Closure(self.arena.alloc(hir::Closure { def_id: self.local_def_id(closure_node_id), binder: hir::ClosureBinder::Default, @@ -1203,11 +1072,12 @@ impl<'hir> LoweringContext<'_, 'hir> { None }; - let async_body = this.make_async_expr( + let async_body = this.make_desugared_coroutine_expr( capture_clause, inner_closure_id, async_ret_ty, body.span, + hir::CoroutineDesugaring::Async, hir::CoroutineSource::Closure, |this| this.with_new_scopes(fn_decl_span, |this| this.lower_expr_mut(body)), ); diff --git a/compiler/rustc_ast_lowering/src/item.rs b/compiler/rustc_ast_lowering/src/item.rs index 3848f3b7782..45357aca533 100644 --- a/compiler/rustc_ast_lowering/src/item.rs +++ b/compiler/rustc_ast_lowering/src/item.rs @@ -1209,33 +1209,20 @@ impl<'hir> LoweringContext<'_, 'hir> { this.expr_block(body) }; - // FIXME(gen_blocks): Consider unifying the `make_*_expr` functions. - let coroutine_expr = match coroutine_kind { - CoroutineKind::Async { .. } => this.make_async_expr( - CaptureBy::Value { move_kw: rustc_span::DUMMY_SP }, - closure_id, - None, - body.span, - hir::CoroutineSource::Fn, - mkbody, - ), - CoroutineKind::Gen { .. } => this.make_gen_expr( - CaptureBy::Value { move_kw: rustc_span::DUMMY_SP }, - closure_id, - None, - body.span, - hir::CoroutineSource::Fn, - mkbody, - ), - CoroutineKind::AsyncGen { .. } => this.make_async_gen_expr( - CaptureBy::Value { move_kw: rustc_span::DUMMY_SP }, - closure_id, - None, - body.span, - hir::CoroutineSource::Fn, - mkbody, - ), + let desugaring_kind = match coroutine_kind { + CoroutineKind::Async { .. } => hir::CoroutineDesugaring::Async, + CoroutineKind::Gen { .. } => hir::CoroutineDesugaring::Gen, + CoroutineKind::AsyncGen { .. } => hir::CoroutineDesugaring::AsyncGen, }; + let coroutine_expr = this.make_desugared_coroutine_expr( + CaptureBy::Value { move_kw: rustc_span::DUMMY_SP }, + closure_id, + None, + body.span, + desugaring_kind, + hir::CoroutineSource::Fn, + mkbody, + ); let hir_id = this.lower_node_id(closure_id); this.maybe_forward_track_caller(body.span, fn_id, hir_id);