diff --git a/crates/ra_hir/src/marks.rs b/crates/ra_hir/src/marks.rs index bbf57004d5b..5b640004288 100644 --- a/crates/ra_hir/src/marks.rs +++ b/crates/ra_hir/src/marks.rs @@ -8,4 +8,5 @@ test_utils::marks!( glob_enum glob_across_crates std_prelude + match_ergonomics_ref ); diff --git a/crates/ra_hir/src/ty/infer.rs b/crates/ra_hir/src/ty/infer.rs index c9a5bc7a100..735cdecb910 100644 --- a/crates/ra_hir/src/ty/infer.rs +++ b/crates/ra_hir/src/ty/infer.rs @@ -63,6 +63,30 @@ enum ExprOrPatId { impl_froms!(ExprOrPatId: ExprId, PatId); +/// Binding modes inferred for patterns. +/// https://doc.rust-lang.org/reference/patterns.html#binding-modes +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +enum BindingMode { + Move, + Ref(Mutability), +} + +impl BindingMode { + pub fn convert(annotation: &BindingAnnotation) -> BindingMode { + match annotation { + BindingAnnotation::Unannotated | BindingAnnotation::Mutable => BindingMode::Move, + BindingAnnotation::Ref => BindingMode::Ref(Mutability::Shared), + BindingAnnotation::RefMut => BindingMode::Ref(Mutability::Mut), + } + } +} + +impl Default for BindingMode { + fn default() -> Self { + BindingMode::Move + } +} + /// The result of type inference: A mapping from expressions and patterns to types. #[derive(Clone, PartialEq, Eq, Debug)] pub struct InferenceResult { @@ -530,6 +554,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { path: Option<&Path>, subpats: &[PatId], expected: &Ty, + default_bm: BindingMode, ) -> Ty { let (ty, def) = self.resolve_variant(path); @@ -542,13 +567,19 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { .and_then(|d| d.field(self.db, &Name::tuple_field_name(i))) .map_or(Ty::Unknown, |field| field.ty(self.db)) .subst(&substs); - self.infer_pat(subpat, &expected_ty); + self.infer_pat(subpat, &expected_ty, default_bm); } ty } - fn infer_struct_pat(&mut self, path: Option<&Path>, subpats: &[FieldPat], expected: &Ty) -> Ty { + fn infer_struct_pat( + &mut self, + path: Option<&Path>, + subpats: &[FieldPat], + expected: &Ty, + default_bm: BindingMode, + ) -> Ty { let (ty, def) = self.resolve_variant(path); self.unify(&ty, expected); @@ -559,15 +590,45 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { let matching_field = def.and_then(|it| it.field(self.db, &subpat.name)); let expected_ty = matching_field.map_or(Ty::Unknown, |field| field.ty(self.db)).subst(&substs); - self.infer_pat(subpat.pat, &expected_ty); + self.infer_pat(subpat.pat, &expected_ty, default_bm); } ty } - fn infer_pat(&mut self, pat: PatId, expected: &Ty) -> Ty { + fn infer_pat(&mut self, pat: PatId, mut expected: &Ty, mut default_bm: BindingMode) -> Ty { let body = Arc::clone(&self.body); // avoid borrow checker problem + let is_non_ref_pat = match &body[pat] { + Pat::Tuple(..) + | Pat::TupleStruct { .. } + | Pat::Struct { .. } + | Pat::Range { .. } + | Pat::Slice { .. } => true, + // TODO: Path/Lit might actually evaluate to ref, but inference is unimplemented. + Pat::Path(..) | Pat::Lit(..) => true, + Pat::Wild | Pat::Bind { .. } | Pat::Ref { .. } | Pat::Missing => false, + }; + if is_non_ref_pat { + while let Ty::Ref(inner, mutability) = expected { + expected = inner; + default_bm = match default_bm { + BindingMode::Move => BindingMode::Ref(*mutability), + BindingMode::Ref(Mutability::Shared) => BindingMode::Ref(Mutability::Shared), + BindingMode::Ref(Mutability::Mut) => BindingMode::Ref(*mutability), + } + } + } else if let Pat::Ref { .. } = &body[pat] { + tested_by!(match_ergonomics_ref); + // When you encounter a `&pat` pattern, reset to Move. + // This is so that `w` is by value: `let (_, &w) = &(1, &2);` + default_bm = BindingMode::Move; + } + + // Lose mutability. + let default_bm = default_bm; + let expected = expected; + let ty = match &body[pat] { Pat::Tuple(ref args) => { let expectations = match *expected { @@ -579,7 +640,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { let inner_tys = args .iter() .zip(expectations_iter) - .map(|(&pat, ty)| self.infer_pat(pat, ty)) + .map(|(&pat, ty)| self.infer_pat(pat, ty, default_bm)) .collect::<Vec<_>>() .into(); @@ -595,14 +656,14 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { } _ => &Ty::Unknown, }; - let subty = self.infer_pat(*pat, expectation); + let subty = self.infer_pat(*pat, expectation, default_bm); Ty::Ref(subty.into(), *mutability) } Pat::TupleStruct { path: ref p, args: ref subpats } => { - self.infer_tuple_struct_pat(p.as_ref(), subpats, expected) + self.infer_tuple_struct_pat(p.as_ref(), subpats, expected, default_bm) } Pat::Struct { path: ref p, args: ref fields } => { - self.infer_struct_pat(p.as_ref(), fields, expected) + self.infer_struct_pat(p.as_ref(), fields, expected, default_bm) } Pat::Path(path) => { // TODO use correct resolver for the surrounding expression @@ -610,17 +671,21 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { self.infer_path_expr(&resolver, &path, pat.into()).unwrap_or(Ty::Unknown) } Pat::Bind { mode, name: _name, subpat } => { + let mode = if mode == &BindingAnnotation::Unannotated { + default_bm + } else { + BindingMode::convert(mode) + }; let inner_ty = if let Some(subpat) = subpat { - self.infer_pat(*subpat, expected) + self.infer_pat(*subpat, expected, default_bm) } else { expected.clone() }; let inner_ty = self.insert_type_vars_shallow(inner_ty); let bound_ty = match mode { - BindingAnnotation::Ref => Ty::Ref(inner_ty.clone().into(), Mutability::Shared), - BindingAnnotation::RefMut => Ty::Ref(inner_ty.clone().into(), Mutability::Mut), - BindingAnnotation::Mutable | BindingAnnotation::Unannotated => inner_ty.clone(), + BindingMode::Ref(mutability) => Ty::Ref(inner_ty.clone().into(), mutability), + BindingMode::Move => inner_ty.clone(), }; let bound_ty = self.resolve_ty_as_possible(&mut vec![], bound_ty); self.write_pat_ty(pat, bound_ty); @@ -700,7 +765,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { } Expr::For { iterable, body, pat } => { let _iterable_ty = self.infer_expr(*iterable, &Expectation::none()); - self.infer_pat(*pat, &Ty::Unknown); + self.infer_pat(*pat, &Ty::Unknown, BindingMode::default()); self.infer_expr(*body, &Expectation::has_type(Ty::unit())); Ty::unit() } @@ -714,7 +779,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { } else { Ty::Unknown }; - self.infer_pat(*arg_pat, &expected); + self.infer_pat(*arg_pat, &expected, BindingMode::default()); } // TODO: infer lambda type etc. @@ -807,7 +872,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { for arm in arms { for &pat in &arm.pats { - let _pat_ty = self.infer_pat(pat, &input_ty); + let _pat_ty = self.infer_pat(pat, &input_ty, BindingMode::default()); } if let Some(guard_expr) = arm.guard { self.infer_expr(guard_expr, &Expectation::has_type(Ty::Bool)); @@ -1007,7 +1072,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { decl_ty }; - self.infer_pat(*pat, &ty); + self.infer_pat(*pat, &ty, BindingMode::default()); } Statement::Expr(expr) => { self.infer_expr(*expr, &Expectation::none()); @@ -1023,7 +1088,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { for (type_ref, pat) in signature.params().iter().zip(body.params()) { let ty = self.make_ty(type_ref); - self.infer_pat(*pat, &ty); + self.infer_pat(*pat, &ty, BindingMode::default()); } self.return_ty = self.make_ty(signature.ret_type()); } diff --git a/crates/ra_hir/src/ty/tests.rs b/crates/ra_hir/src/ty/tests.rs index acae71c266d..0f2172ddfc7 100644 --- a/crates/ra_hir/src/ty/tests.rs +++ b/crates/ra_hir/src/ty/tests.rs @@ -830,6 +830,60 @@ fn test(x: &i32) { ); } +#[test] +fn infer_pattern_match_ergonomics() { + assert_snapshot_matches!( + infer(r#" +struct A<T>(T); + +fn test() { + let A(n) = &A(1); + let A(n) = &mut A(1); +} +"#), + @r###" +[28; 79) '{ ...(1); }': () +[38; 42) 'A(n)': A<i32> +[40; 41) 'n': &i32 +[45; 50) '&A(1)': &A<i32> +[46; 47) 'A': A<i32>(T) -> A<T> +[46; 50) 'A(1)': A<i32> +[48; 49) '1': i32 +[60; 64) 'A(n)': A<i32> +[62; 63) 'n': &mut i32 +[67; 76) '&mut A(1)': &mut A<i32> +[72; 73) 'A': A<i32>(T) -> A<T> +[72; 76) 'A(1)': A<i32> +[74; 75) '1': i32"### + ); +} + +#[test] +fn infer_pattern_match_ergonomics_ref() { + covers!(match_ergonomics_ref); + assert_snapshot_matches!( + infer(r#" +fn test() { + let v = &(1, &2); + let (_, &w) = v; +} +"#), + @r###" +[11; 57) '{ ...= v; }': () +[21; 22) 'v': &(i32, &i32) +[25; 33) '&(1, &2)': &(i32, &i32) +[26; 33) '(1, &2)': (i32, &i32) +[27; 28) '1': i32 +[30; 32) '&2': &i32 +[31; 32) '2': i32 +[43; 50) '(_, &w)': (i32, &i32) +[44; 45) '_': i32 +[47; 49) '&w': &i32 +[48; 49) 'w': i32 +[53; 54) 'v': &(i32, &i32)"### + ); +} + #[test] fn infer_adt_pattern() { assert_snapshot_matches!(