Infer correct expected type for generic struct fields

This commit is contained in:
Florian Diebold 2021-05-23 18:10:40 +02:00
parent 4a6cdd776d
commit 7a0c93c58a
4 changed files with 46 additions and 19 deletions

View File

@ -513,9 +513,9 @@ impl Field {
} }
/// Returns the type as in the signature of the struct (i.e., with /// Returns the type as in the signature of the struct (i.e., with
/// placeholder types for type parameters). This is good for showing /// placeholder types for type parameters). Only use this in the context of
/// signature help, but not so good to actually get the type of the field /// the field *definition*; if you've already got a variable of the struct
/// when you actually have a variable of the struct. /// type, use `Type::field_type` to get to the field type.
pub fn ty(&self, db: &dyn HirDatabase) -> Type { pub fn ty(&self, db: &dyn HirDatabase) -> Type {
let var_id = self.parent.into(); let var_id = self.parent.into();
let generic_def_id: GenericDefId = match self.parent { let generic_def_id: GenericDefId = match self.parent {
@ -1944,6 +1944,18 @@ impl Type {
} }
} }
pub fn field_type(&self, db: &dyn HirDatabase, field: Field) -> Option<Type> {
let (adt_id, substs) = self.ty.as_adt()?;
let variant_id: hir_def::VariantId = field.parent.into();
if variant_id.adt_id() != adt_id {
return None;
}
let ty = db.field_types(variant_id).get(field.id)?.clone();
let ty = ty.substitute(&Interner, substs);
Some(self.derived(ty))
}
pub fn fields(&self, db: &dyn HirDatabase) -> Vec<(Field, Type)> { pub fn fields(&self, db: &dyn HirDatabase) -> Vec<(Field, Type)> {
let (variant_id, substs) = match self.ty.kind(&Interner) { let (variant_id, substs) = match self.ty.kind(&Interner) {
&TyKind::Adt(hir_ty::AdtId(AdtId::StructId(s)), ref substs) => (s.into(), substs), &TyKind::Adt(hir_ty::AdtId(AdtId::StructId(s)), ref substs) => (s.into(), substs),

View File

@ -485,6 +485,14 @@ impl VariantId {
VariantId::UnionId(it) => it.lookup(db).id.file_id(), VariantId::UnionId(it) => it.lookup(db).id.file_id(),
} }
} }
pub fn adt_id(self) -> AdtId {
match self {
VariantId::EnumVariantId(it) => it.parent.into(),
VariantId::StructId(it) => it.into(),
VariantId::UnionId(it) => it.into(),
}
}
} }
trait Intern { trait Intern {

View File

@ -337,25 +337,25 @@ impl<'a> CompletionContext<'a> {
}, },
ast::RecordExprFieldList(_it) => { ast::RecordExprFieldList(_it) => {
cov_mark::hit!(expected_type_struct_field_without_leading_char); cov_mark::hit!(expected_type_struct_field_without_leading_char);
self.token.prev_sibling_or_token() // wouldn't try {} be nice...
.and_then(|se| se.into_node()) (|| {
.and_then(|node| ast::RecordExprField::cast(node)) let record_ty = self.sema.type_of_expr(&ast::Expr::cast(node.parent()?)?)?;
.and_then(|rf| self.sema.resolve_record_field(&rf).zip(Some(rf))) let expr_field = self.token.prev_sibling_or_token()?
.map(|(f, rf)|( .into_node()
Some(f.0.ty(self.db)), .and_then(|node| ast::RecordExprField::cast(node))?;
rf.field_name().map(NameOrNameRef::NameRef), let field = self.sema.resolve_record_field(&expr_field)?.0;
Some((
record_ty.field_type(self.db, field),
expr_field.field_name().map(NameOrNameRef::NameRef),
)) ))
.unwrap_or((None, None)) })().unwrap_or((None, None))
}, },
ast::RecordExprField(it) => { ast::RecordExprField(it) => {
cov_mark::hit!(expected_type_struct_field_with_leading_char); cov_mark::hit!(expected_type_struct_field_with_leading_char);
self.sema (
.resolve_record_field(&it) it.expr().as_ref().and_then(|e| self.sema.type_of_expr(e)),
.map(|f|( it.field_name().map(NameOrNameRef::NameRef),
Some(f.0.ty(self.db)), )
it.field_name().map(NameOrNameRef::NameRef),
))
.unwrap_or((None, None))
}, },
ast::MatchExpr(it) => { ast::MatchExpr(it) => {
cov_mark::hit!(expected_type_match_arm_without_leading_char); cov_mark::hit!(expected_type_match_arm_without_leading_char);
@ -910,7 +910,7 @@ fn foo() -> u32 {
} }
#[test] #[test]
fn expected_type_closure_param() { fn expected_type_closure_param_return() {
check_expected_type_and_name( check_expected_type_and_name(
r#" r#"
fn foo() { fn foo() {

View File

@ -667,6 +667,13 @@ fn foo() { A { the$0 } }
), ),
detail: "u32", detail: "u32",
deprecated: true, deprecated: true,
relevance: CompletionRelevance {
exact_name_match: false,
type_match: Some(
CouldUnify,
),
is_local: false,
},
}, },
] ]
"#]], "#]],