9863: feat: Generate default trait fn impl when generating `PartialEq` r=yoshuawuyts a=yoshuawuyts

Implements a default trait function body when generating the `PartialEq` trait for a type. Thanks!

r? `@veykril`

Co-authored-by: Yoshua Wuyts <yoshuawuyts@gmail.com>
This commit is contained in:
bors[bot] 2021-08-12 10:18:02 +00:00 committed by GitHub
commit 1376ece497
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 418 additions and 20 deletions

View File

@ -606,6 +606,177 @@ impl Clone for Foo {
}
}
}
"#,
)
}
#[test]
fn add_custom_impl_partial_eq_record_struct() {
check_assist(
replace_derive_with_manual_impl,
r#"
//- minicore: eq
#[derive(Partial$0Eq)]
struct Foo {
bin: usize,
bar: usize,
}
"#,
r#"
struct Foo {
bin: usize,
bar: usize,
}
impl PartialEq for Foo {
$0fn eq(&self, other: &Self) -> bool {
self.bin == other.bin && self.bar == other.bar
}
}
"#,
)
}
#[test]
fn add_custom_impl_partial_eq_tuple_struct() {
check_assist(
replace_derive_with_manual_impl,
r#"
//- minicore: eq
#[derive(Partial$0Eq)]
struct Foo(usize, usize);
"#,
r#"
struct Foo(usize, usize);
impl PartialEq for Foo {
$0fn eq(&self, other: &Self) -> bool {
self.0 == other.0 && self.1 == other.1
}
}
"#,
)
}
#[test]
fn add_custom_impl_partial_eq_empty_struct() {
check_assist(
replace_derive_with_manual_impl,
r#"
//- minicore: eq
#[derive(Partial$0Eq)]
struct Foo;
"#,
r#"
struct Foo;
impl PartialEq for Foo {
$0fn eq(&self, other: &Self) -> bool {
true
}
}
"#,
)
}
#[test]
fn add_custom_impl_partial_eq_enum() {
check_assist(
replace_derive_with_manual_impl,
r#"
//- minicore: eq
#[derive(Partial$0Eq)]
enum Foo {
Bar,
Baz,
}
"#,
r#"
enum Foo {
Bar,
Baz,
}
impl PartialEq for Foo {
$0fn eq(&self, other: &Self) -> bool {
core::mem::discriminant(self) == core::mem::discriminant(other)
}
}
"#,
)
}
#[test]
fn add_custom_impl_partial_eq_tuple_enum() {
check_assist(
replace_derive_with_manual_impl,
r#"
//- minicore: eq
#[derive(Partial$0Eq)]
enum Foo {
Bar(String),
Baz,
}
"#,
r#"
enum Foo {
Bar(String),
Baz,
}
impl PartialEq for Foo {
$0fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Bar(l0), Self::Bar(r0)) => l0 == r0,
_ => core::mem::discriminant(self) == core::mem::discriminant(other),
}
}
}
"#,
)
}
#[test]
fn add_custom_impl_partial_eq_record_enum() {
check_assist(
replace_derive_with_manual_impl,
r#"
//- minicore: eq
#[derive(Partial$0Eq)]
enum Foo {
Bar {
bin: String,
},
Baz {
qux: String,
fez: String,
},
Qux {},
Bin,
}
"#,
r#"
enum Foo {
Bar {
bin: String,
},
Baz {
qux: String,
fez: String,
},
Qux {},
Bin,
}
impl PartialEq for Foo {
$0fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
(Self::Baz { qux: l_qux, fez: l_fez }, Self::Baz { qux: r_qux, fez: r_fez }) => l_qux == r_qux && l_fez == r_fez,
_ => core::mem::discriminant(self) == core::mem::discriminant(other),
}
}
}
"#,
)
}

View File

@ -20,6 +20,7 @@ pub(crate) fn gen_trait_fn_body(
"Debug" => gen_debug_impl(adt, func),
"Default" => gen_default_impl(adt, func),
"Hash" => gen_hash_impl(adt, func),
"PartialEq" => gen_partial_eq(adt, func),
_ => None,
}
}
@ -38,9 +39,7 @@ fn gen_clone_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
let mut arms = vec![];
for variant in list.variants() {
let name = variant.name()?;
let left = make::ext::ident_path("Self");
let right = make::ext::ident_path(&format!("{}", name));
let variant_name = make::path_concat(left, right);
let variant_name = make::ext::path_from_idents(["Self", &format!("{}", name)])?;
match variant.field_list() {
// => match self { Self::Name { x } => Self::Name { x: x.clone() } }
@ -150,9 +149,8 @@ fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
let mut arms = vec![];
for variant in list.variants() {
let name = variant.name()?;
let left = make::ext::ident_path("Self");
let right = make::ext::ident_path(&format!("{}", name));
let variant_name = make::path_pat(make::path_concat(left, right));
let variant_name =
make::path_pat(make::ext::path_from_idents(["Self", &format!("{}", name)])?);
let target = make::expr_path(make::ext::ident_path("f").into());
let fmt_string = make::expr_literal(&(format!("\"{}\"", name))).into();
@ -224,11 +222,9 @@ fn gen_debug_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
/// Generate a `Debug` impl based on the fields and members of the target type.
fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
fn gen_default_call() -> ast::Expr {
let trait_name = make::ext::ident_path("Default");
let method_name = make::ext::ident_path("default");
let fn_name = make::expr_path(make::path_concat(trait_name, method_name));
make::expr_call(fn_name, make::arg_list(None))
fn gen_default_call() -> Option<ast::Expr> {
let fn_name = make::ext::path_from_idents(["Default", "default"])?;
Some(make::expr_call(make::expr_path(fn_name), make::arg_list(None)))
}
match adt {
// `Debug` cannot be derived for unions, so no default impl can be provided.
@ -240,7 +236,7 @@ fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
Some(ast::FieldList::RecordFieldList(field_list)) => {
let mut fields = vec![];
for field in field_list.fields() {
let method_call = gen_default_call();
let method_call = gen_default_call()?;
let name_ref = make::name_ref(&field.name()?.to_string());
let field = make::record_expr_field(name_ref, Some(method_call));
fields.push(field);
@ -251,7 +247,10 @@ fn gen_default_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
}
Some(ast::FieldList::TupleFieldList(field_list)) => {
let struct_name = make::expr_path(make::ext::ident_path("Self"));
let fields = field_list.fields().map(|_| gen_default_call());
let fields = field_list
.fields()
.map(|_| gen_default_call())
.collect::<Option<Vec<ast::Expr>>>()?;
make::expr_call(struct_name, make::arg_list(fields))
}
None => {
@ -273,8 +272,7 @@ fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
let method = make::name_ref("hash");
let arg = make::expr_path(make::ext::ident_path("state"));
let expr = make::expr_method_call(target, method, make::arg_list(Some(arg)));
let stmt = make::expr_stmt(expr);
stmt.into()
make::expr_stmt(expr).into()
}
let body = match adt {
@ -283,11 +281,7 @@ fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
// => std::mem::discriminant(self).hash(state);
ast::Adt::Enum(_) => {
let root = make::ext::ident_path("core");
let submodule = make::ext::ident_path("mem");
let fn_name = make::ext::ident_path("discriminant");
let fn_name = make::path_concat(submodule, fn_name);
let fn_name = make::expr_path(make::path_concat(root, fn_name));
let fn_name = make_discriminant()?;
let arg = make::expr_path(make::ext::ident_path("self"));
let fn_call = make::expr_call(fn_name, make::arg_list(Some(arg)));
@ -326,3 +320,173 @@ fn gen_hash_impl(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
Some(())
}
/// Generate a `PartialEq` impl based on the fields and members of the target type.
fn gen_partial_eq(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
fn gen_eq_chain(expr: Option<ast::Expr>, cmp: ast::Expr) -> Option<ast::Expr> {
match expr {
Some(expr) => Some(make::expr_op(ast::BinOp::BooleanAnd, expr, cmp)),
None => Some(cmp),
}
}
fn gen_record_pat_field(field_name: &str, pat_name: &str) -> ast::RecordPatField {
let pat = make::ext::simple_ident_pat(make::name(&pat_name));
let name_ref = make::name_ref(field_name);
make::record_pat_field(name_ref, pat.into())
}
fn gen_record_pat(record_name: ast::Path, fields: Vec<ast::RecordPatField>) -> ast::RecordPat {
let list = make::record_pat_field_list(fields);
make::record_pat_with_fields(record_name, list)
}
fn gen_variant_path(variant: &ast::Variant) -> Option<ast::Path> {
make::ext::path_from_idents(["Self", &variant.name()?.to_string()])
}
fn gen_tuple_field(field_name: &String) -> ast::Pat {
ast::Pat::IdentPat(make::ident_pat(false, false, make::name(field_name)))
}
// FIXME: return `None` if the trait carries a generic type; we can only
// generate this code `Self` for the time being.
let body = match adt {
// `Hash` cannot be derived for unions, so no default impl can be provided.
ast::Adt::Union(_) => return None,
ast::Adt::Enum(enum_) => {
// => std::mem::discriminant(self) == std::mem::discriminant(other)
let lhs_name = make::expr_path(make::ext::ident_path("self"));
let lhs = make::expr_call(make_discriminant()?, make::arg_list(Some(lhs_name.clone())));
let rhs_name = make::expr_path(make::ext::ident_path("other"));
let rhs = make::expr_call(make_discriminant()?, make::arg_list(Some(rhs_name.clone())));
let eq_check = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
let mut case_count = 0;
let mut arms = vec![];
for variant in enum_.variant_list()?.variants() {
case_count += 1;
match variant.field_list() {
// => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
Some(ast::FieldList::RecordFieldList(list)) => {
let mut expr = None;
let mut l_fields = vec![];
let mut r_fields = vec![];
for field in list.fields() {
let field_name = field.name()?.to_string();
let l_name = &format!("l_{}", field_name);
l_fields.push(gen_record_pat_field(&field_name, &l_name));
let r_name = &format!("r_{}", field_name);
r_fields.push(gen_record_pat_field(&field_name, &r_name));
let lhs = make::expr_path(make::ext::ident_path(l_name));
let rhs = make::expr_path(make::ext::ident_path(r_name));
let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
expr = gen_eq_chain(expr, cmp);
}
let left = gen_record_pat(gen_variant_path(&variant)?, l_fields);
let right = gen_record_pat(gen_variant_path(&variant)?, r_fields);
let tuple = make::tuple_pat(vec![left.into(), right.into()]);
if let Some(expr) = expr {
arms.push(make::match_arm(Some(tuple.into()), None, expr));
}
}
Some(ast::FieldList::TupleFieldList(list)) => {
let mut expr = None;
let mut l_fields = vec![];
let mut r_fields = vec![];
for (i, _) in list.fields().enumerate() {
let field_name = format!("{}", i);
let l_name = format!("l{}", field_name);
l_fields.push(gen_tuple_field(&l_name));
let r_name = format!("r{}", field_name);
r_fields.push(gen_tuple_field(&r_name));
let lhs = make::expr_path(make::ext::ident_path(&l_name));
let rhs = make::expr_path(make::ext::ident_path(&r_name));
let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
expr = gen_eq_chain(expr, cmp);
}
let left = make::tuple_struct_pat(gen_variant_path(&variant)?, l_fields);
let right = make::tuple_struct_pat(gen_variant_path(&variant)?, r_fields);
let tuple = make::tuple_pat(vec![left.into(), right.into()]);
if let Some(expr) = expr {
arms.push(make::match_arm(Some(tuple.into()), None, expr));
}
}
None => continue,
}
}
let expr = match arms.len() {
0 => eq_check,
_ => {
if case_count > arms.len() {
let lhs = make::wildcard_pat().into();
arms.push(make::match_arm(Some(lhs), None, eq_check));
}
let match_target = make::expr_tuple(vec![lhs_name, rhs_name]);
let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
make::expr_match(match_target, list)
}
};
make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
}
ast::Adt::Struct(strukt) => match strukt.field_list() {
Some(ast::FieldList::RecordFieldList(field_list)) => {
let mut expr = None;
for field in field_list.fields() {
let lhs = make::expr_path(make::ext::ident_path("self"));
let lhs = make::expr_field(lhs, &field.name()?.to_string());
let rhs = make::expr_path(make::ext::ident_path("other"));
let rhs = make::expr_field(rhs, &field.name()?.to_string());
let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
expr = gen_eq_chain(expr, cmp);
}
make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
}
Some(ast::FieldList::TupleFieldList(field_list)) => {
let mut expr = None;
for (i, _) in field_list.fields().enumerate() {
let idx = format!("{}", i);
let lhs = make::expr_path(make::ext::ident_path("self"));
let lhs = make::expr_field(lhs, &idx);
let rhs = make::expr_path(make::ext::ident_path("other"));
let rhs = make::expr_field(rhs, &idx);
let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
expr = gen_eq_chain(expr, cmp);
}
make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
}
// No fields in the body means there's nothing to hash.
None => {
let expr = make::expr_literal("true").into();
make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
}
},
};
ted::replace(func.body()?.syntax(), body.clone_for_update().syntax());
Some(())
}
fn make_discriminant() -> Option<ast::Expr> {
Some(make::expr_path(make::ext::path_from_idents(["core", "mem", "discriminant"])?))
}

View File

@ -160,6 +160,18 @@ pub enum ElseBranch {
IfExpr(ast::IfExpr),
}
impl From<ast::BlockExpr> for ElseBranch {
fn from(block_expr: ast::BlockExpr) -> Self {
Self::Block(block_expr)
}
}
impl From<ast::IfExpr> for ElseBranch {
fn from(if_expr: ast::IfExpr) -> Self {
Self::IfExpr(if_expr)
}
}
impl ast::IfExpr {
pub fn then_branch(&self) -> Option<ast::BlockExpr> {
self.blocks().next()
@ -350,6 +362,42 @@ impl ast::BinExpr {
}
}
impl std::fmt::Display for BinOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
BinOp::BooleanOr => write!(f, "||"),
BinOp::BooleanAnd => write!(f, "&&"),
BinOp::EqualityTest => write!(f, "=="),
BinOp::NegatedEqualityTest => write!(f, "!="),
BinOp::LesserEqualTest => write!(f, "<="),
BinOp::GreaterEqualTest => write!(f, ">="),
BinOp::LesserTest => write!(f, "<"),
BinOp::GreaterTest => write!(f, ">"),
BinOp::Addition => write!(f, "+"),
BinOp::Multiplication => write!(f, "*"),
BinOp::Subtraction => write!(f, "-"),
BinOp::Division => write!(f, "/"),
BinOp::Remainder => write!(f, "%"),
BinOp::LeftShift => write!(f, "<<"),
BinOp::RightShift => write!(f, ">>"),
BinOp::BitwiseXor => write!(f, "^"),
BinOp::BitwiseOr => write!(f, "|"),
BinOp::BitwiseAnd => write!(f, "&"),
BinOp::Assignment => write!(f, "="),
BinOp::AddAssign => write!(f, "+="),
BinOp::DivAssign => write!(f, "/="),
BinOp::MulAssign => write!(f, "*="),
BinOp::RemAssign => write!(f, "%="),
BinOp::ShrAssign => write!(f, ">>="),
BinOp::ShlAssign => write!(f, "<<="),
BinOp::SubAssign => write!(f, "-"),
BinOp::BitOrAssign => write!(f, "|="),
BinOp::BitAndAssign => write!(f, "&="),
BinOp::BitXorAssign => write!(f, "^="),
}
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum RangeOp {
/// `..`

View File

@ -32,6 +32,18 @@ pub mod ext {
path_unqualified(path_segment(name_ref(ident)))
}
pub fn path_from_idents<'a>(
parts: impl std::iter::IntoIterator<Item = &'a str>,
) -> Option<ast::Path> {
let mut iter = parts.into_iter();
let base = ext::ident_path(iter.next()?);
let path = iter.fold(base, |base, s| {
let path = ext::ident_path(s);
path_concat(base, path)
});
Some(path)
}
pub fn expr_unreachable() -> ast::Expr {
expr_from_text("unreachable!()")
}
@ -264,6 +276,9 @@ pub fn expr_path(path: ast::Path) -> ast::Expr {
pub fn expr_continue() -> ast::Expr {
expr_from_text("continue")
}
pub fn expr_op(op: ast::BinOp, lhs: ast::Expr, rhs: ast::Expr) -> ast::Expr {
expr_from_text(&format!("{} {} {}", lhs, op, rhs))
}
pub fn expr_break(expr: Option<ast::Expr>) -> ast::Expr {
match expr {
Some(expr) => expr_from_text(&format!("break {}", expr)),