libsyntax: short-circuit on non-matching variants in deriving code.

Allow a deriving instance using the generic code to short-circuit for
any non-matching enum variants (grouping them all into a _ match),
reducing the number of arms required. Use this to speed up the Eq &
TotalEq implementations.
This commit is contained in:
Huon Wilson 2013-04-02 22:02:46 +11:00
parent 99492796dc
commit bff3748731
11 changed files with 348 additions and 155 deletions

View File

@ -29,6 +29,7 @@ pub fn expand_deriving_clone(cx: @ext_ctxt,
name: ~"clone", name: ~"clone",
nargs: 0, nargs: 0,
output_type: None, // return Self output_type: None, // return Self
const_nonmatching: false,
combine_substructure: cs_clone combine_substructure: cs_clone
} }
] ]

View File

@ -31,24 +31,24 @@ pub fn expand_deriving_eq(cx: @ext_ctxt,
cs_or(|cx, span, _| build::mk_bool(cx, span, true), cs_or(|cx, span, _| build::mk_bool(cx, span, true),
cx, span, substr) cx, span, substr)
} }
macro_rules! md (
($name:expr, $f:ident) => {
MethodDef {
name: $name,
output_type: Some(~[~"bool"]),
nargs: 1,
const_nonmatching: true,
combine_substructure: $f
},
}
)
let trait_def = TraitDef { let trait_def = TraitDef {
path: ~[~"core", ~"cmp", ~"Eq"], path: ~[~"core", ~"cmp", ~"Eq"],
additional_bounds: ~[], additional_bounds: ~[],
methods: ~[ methods: ~[
MethodDef { md!(~"eq", cs_eq),
name: ~"ne", md!(~"ne", cs_ne)
output_type: Some(~[~"bool"]),
nargs: 1,
combine_substructure: cs_ne
},
MethodDef {
name: ~"eq",
output_type: Some(~[~"bool"]),
nargs: 1,
combine_substructure: cs_eq
}
] ]
}; };

View File

@ -16,10 +16,16 @@ use ext::build;
use ext::deriving::generic::*; use ext::deriving::generic::*;
use core::option::Some; use core::option::Some;
macro_rules! mk_cso { macro_rules! md {
($less:expr, $equal:expr) => { ($name:expr, $less:expr, $equal:expr) => {
|cx, span, substr| MethodDef {
cs_ord($less, $equal, cx, span, substr) name: $name,
output_type: Some(~[~"bool"]),
nargs: 1,
const_nonmatching: false,
combine_substructure: |cx, span, substr|
cs_ord($less, $equal, cx, span, substr)
}
} }
} }
@ -32,30 +38,10 @@ pub fn expand_deriving_ord(cx: @ext_ctxt,
// XXX: Ord doesn't imply Eq yet // XXX: Ord doesn't imply Eq yet
additional_bounds: ~[~[~"core", ~"cmp", ~"Eq"]], additional_bounds: ~[~[~"core", ~"cmp", ~"Eq"]],
methods: ~[ methods: ~[
MethodDef { md!(~"lt", true, false),
name: ~"lt", md!(~"le", true, true),
output_type: Some(~[~"bool"]), md!(~"gt", false, false),
nargs: 1, md!(~"ge", false, true)
combine_substructure: mk_cso!(true, false)
},
MethodDef {
name: ~"le",
output_type: Some(~[~"bool"]),
nargs: 1,
combine_substructure: mk_cso!(true, true)
},
MethodDef {
name: ~"gt",
output_type: Some(~[~"bool"]),
nargs: 1,
combine_substructure: mk_cso!(false, false)
},
MethodDef {
name: ~"ge",
output_type: Some(~[~"bool"]),
nargs: 1,
combine_substructure: mk_cso!(false, true)
},
] ]
}; };

View File

@ -35,6 +35,7 @@ pub fn expand_deriving_totaleq(cx: @ext_ctxt,
name: ~"equals", name: ~"equals",
output_type: Some(~[~"bool"]), output_type: Some(~[~"bool"]),
nargs: 1, nargs: 1,
const_nonmatching: true,
combine_substructure: cs_equals combine_substructure: cs_equals
} }
] ]

View File

@ -28,6 +28,7 @@ pub fn expand_deriving_totalord(cx: @ext_ctxt,
name: ~"cmp", name: ~"cmp",
output_type: Some(~[~"core", ~"cmp", ~"Ordering"]), output_type: Some(~[~"core", ~"cmp", ~"Ordering"]),
nargs: 1, nargs: 1,
const_nonmatching: false,
combine_substructure: cs_cmp combine_substructure: cs_cmp
} }
] ]

View File

@ -40,7 +40,8 @@ arguments:
- `EnumMatching`, when `Self` is an enum and all the arguments are the - `EnumMatching`, when `Self` is an enum and all the arguments are the
same variant of the enum (e.g. `Some(1)`, `Some(3)` and `Some(4)`) same variant of the enum (e.g. `Some(1)`, `Some(3)` and `Some(4)`)
- `EnumNonMatching` when `Self` is an enum and the arguments are not - `EnumNonMatching` when `Self` is an enum and the arguments are not
the same variant (e.g. `None`, `Some(1)` and `None`) the same variant (e.g. `None`, `Some(1)` and `None`). If
`const_nonmatching` is true, this will contain an empty list.
In the first two cases, the values from the corresponding fields in In the first two cases, the values from the corresponding fields in
all the arguments are grouped together. In the `EnumNonMatching` case all the arguments are grouped together. In the `EnumNonMatching` case
@ -129,9 +130,11 @@ use core::prelude::*;
use ast; use ast;
use ast::{ use ast::{
and, binop, deref, enum_def, expr, expr_match, ident, impure_fn, and, binop, deref, enum_def, expr, expr_match, ident, impure_fn,
item, Generics, m_imm, meta_item, method, named_field, or, public, item, Generics, m_imm, meta_item, method, named_field, or,
struct_def, sty_region, ty_rptr, ty_path, variant}; pat_wild, public, struct_def, sty_region, ty_rptr, ty_path,
variant};
use ast_util; use ast_util;
use ext::base::ext_ctxt; use ext::base::ext_ctxt;
@ -177,6 +180,10 @@ pub struct MethodDef<'self> {
/// Number of arguments other than `self` (all of type `&Self`) /// Number of arguments other than `self` (all of type `&Self`)
nargs: uint, nargs: uint,
/// if the value of the nonmatching enums is independent of the
/// actual enums, i.e. can use _ => .. match.
const_nonmatching: bool,
combine_substructure: CombineSubstructureFunc<'self> combine_substructure: CombineSubstructureFunc<'self>
} }
@ -555,12 +562,13 @@ impl<'self> MethodDef<'self> {
enum_def: &enum_def, enum_def: &enum_def,
type_ident: ident) type_ident: ident)
-> @expr { -> @expr {
self.build_enum_match(cx, span, enum_def, type_ident, ~[]) self.build_enum_match(cx, span, enum_def, type_ident,
None, ~[], 0)
} }
/** /**
Creates the nested matches for an enum definition, i.e. Creates the nested matches for an enum definition recursively, i.e.
``` ```
match self { match self {
@ -575,14 +583,20 @@ impl<'self> MethodDef<'self> {
the tree are the same. Hopefully the optimisers get rid of any the tree are the same. Hopefully the optimisers get rid of any
repetition, otherwise derived methods with many Self arguments will be repetition, otherwise derived methods with many Self arguments will be
exponentially large. exponentially large.
`matching` is Some(n) if all branches in the tree above the
current position are variant `n`, `None` otherwise (including on
the first call).
*/ */
fn build_enum_match(&self, fn build_enum_match(&self,
cx: @ext_ctxt, span: span, cx: @ext_ctxt, span: span,
enum_def: &enum_def, enum_def: &enum_def,
type_ident: ident, type_ident: ident,
matching: Option<uint>,
matches_so_far: ~[(uint, variant, matches_so_far: ~[(uint, variant,
~[(Option<ident>, @expr)])]) -> @expr { ~[(Option<ident>, @expr)])],
if matches_so_far.len() == self.nargs + 1 { match_count: uint) -> @expr {
if match_count == self.nargs + 1 {
// we've matched against all arguments, so make the final // we've matched against all arguments, so make the final
// expression at the bottom of the match tree // expression at the bottom of the match tree
match matches_so_far { match matches_so_far {
@ -594,41 +608,44 @@ impl<'self> MethodDef<'self> {
// vec of tuples, where each tuple represents a // vec of tuples, where each tuple represents a
// field. // field.
// `ref` inside let matches is buggy. Causes havoc wih rusc.
// let (variant_index, ref self_vec) = matches_so_far[0];
let (variant_index, variant, self_vec) = match matches_so_far[0] {
(i, v, ref s) => (i, v, s)
};
let substructure; let substructure;
// most arms don't have matching variants, so do a // most arms don't have matching variants, so do a
// quick check to see if they match (even though // quick check to see if they match (even though
// this means iterating twice) instead of being // this means iterating twice) instead of being
// optimistic and doing a pile of allocations etc. // optimistic and doing a pile of allocations etc.
if matches_so_far.all(|&(v_i, _, _)| v_i == variant_index) { match matching {
let mut enum_matching_fields = vec::from_elem(self_vec.len(), ~[]); Some(variant_index) => {
// `ref` inside let matches is buggy. Causes havoc wih rusc.
// let (variant_index, ref self_vec) = matches_so_far[0];
let (variant, self_vec) = match matches_so_far[0] {
(_, v, ref s) => (v, s)
};
for matches_so_far.tail().each |&(_, _, other_fields)| { let mut enum_matching_fields = vec::from_elem(self_vec.len(), ~[]);
for other_fields.eachi |i, &(_, other_field)| {
enum_matching_fields[i].push(other_field); for matches_so_far.tail().each |&(_, _, other_fields)| {
for other_fields.eachi |i, &(_, other_field)| {
enum_matching_fields[i].push(other_field);
}
} }
let field_tuples =
do vec::map2(*self_vec,
enum_matching_fields) |&(id, self_f), &other| {
(id, self_f, other)
};
substructure = EnumMatching(variant_index, variant, field_tuples);
}
None => {
substructure = EnumNonMatching(matches_so_far);
} }
let field_tuples =
do vec::map2(*self_vec,
enum_matching_fields) |&(id, self_f), &other| {
(id, self_f, other)
};
substructure = EnumMatching(variant_index, variant, field_tuples);
} else {
substructure = EnumNonMatching(matches_so_far);
} }
self.call_substructure_method(cx, span, type_ident, &substructure) self.call_substructure_method(cx, span, type_ident, &substructure)
} }
} }
} else { // there are still matches to create } else { // there are still matches to create
let (current_match_ident, current_match_str) = if matches_so_far.is_empty() { let (current_match_ident, current_match_str) = if match_count == 0 {
(cx.ident_of(~"self"), ~"__self") (cx.ident_of(~"self"), ~"__self")
} else { } else {
let s = fmt!("__other_%u", matches_so_far.len() - 1); let s = fmt!("__other_%u", matches_so_far.len() - 1);
@ -640,8 +657,32 @@ impl<'self> MethodDef<'self> {
// this is used as a stack // this is used as a stack
let mut matches_so_far = matches_so_far; let mut matches_so_far = matches_so_far;
// create an arm matching on each variant macro_rules! mk_arm(
for enum_def.variants.eachi |index, variant| { ($pat:expr, $expr:expr) => {
{
let blk = build::mk_simple_block(cx, span, $expr);
let arm = ast::arm {
pats: ~[$ pat ],
guard: None,
body: blk
};
arm
}
}
)
// the code for nonmatching variants only matters when
// we've seen at least one other variant already
if self.const_nonmatching && match_count > 0 {
// make a matching-variant match, and a _ match.
let index = match matching {
Some(i) => i,
None => cx.span_bug(span, ~"Non-matching variants when required to\
be matching in `deriving_generic`")
};
// matching-variant match
let variant = &enum_def.variants[index];
let pattern = create_enum_variant_pattern(cx, span, let pattern = create_enum_variant_pattern(cx, span,
variant, variant,
current_match_str); current_match_str);
@ -653,23 +694,63 @@ impl<'self> MethodDef<'self> {
} }
}; };
matches_so_far.push((index, *variant, idents)); matches_so_far.push((index, *variant, idents));
let arm_expr = self.build_enum_match(cx, span, let arm_expr = self.build_enum_match(cx, span,
enum_def, enum_def,
type_ident, type_ident,
matches_so_far); matching,
matches_so_far,
match_count + 1);
matches_so_far.pop(); matches_so_far.pop();
let arm = mk_arm!(pattern, arm_expr);
let arm_block = build::mk_simple_block(cx, span, arm_expr);
let arm = ast::arm {
pats: ~[ pattern ],
guard: None,
body: arm_block
};
arms.push(arm); arms.push(arm);
}
if enum_def.variants.len() > 1 {
// _ match, if necessary
let wild_pat = @ast::pat {
id: cx.next_id(),
node: pat_wild,
span: span
};
let wild_expr = self.call_substructure_method(cx, span, type_ident,
&EnumNonMatching(~[]));
let wild_arm = mk_arm!(wild_pat, wild_expr);
arms.push(wild_arm);
}
} else {
// create an arm matching on each variant
for enum_def.variants.eachi |index, variant| {
let pattern = create_enum_variant_pattern(cx, span,
variant,
current_match_str);
let idents = do vec::build |push| {
for each_variant_arg_ident(cx, span, variant) |i, field_id| {
let id = cx.ident_of(fmt!("%s_%u", current_match_str, i));
push((field_id, build::mk_path(cx, span, ~[ id ])));
}
};
matches_so_far.push((index, *variant, idents));
let new_matching =
match matching {
_ if match_count == 0 => Some(index),
Some(i) if index == i => Some(i),
_ => None
};
let arm_expr = self.build_enum_match(cx, span,
enum_def,
type_ident,
new_matching,
matches_so_far,
match_count + 1);
matches_so_far.pop();
let arm = mk_arm!(pattern, arm_expr);
arms.push(arm);
}
}
let deref_expr = build::mk_unary(cx, span, deref, let deref_expr = build::mk_unary(cx, span, deref,
build::mk_path(cx, span, build::mk_path(cx, span,
~[ current_match_ident ])); ~[ current_match_ident ]));

View File

@ -0,0 +1,50 @@
// Copyright 2013 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
#[deriving(Eq, TotalEq, Ord, TotalOrd)]
enum E<T> {
E0,
E1(T),
E2(T,T)
}
pub fn main() {
let e0 = E0, e11 = E1(1), e12 = E1(2), e21 = E2(1,1), e22 = E2(1, 2);
// in order for both Ord and TotalOrd
let es = [e0, e11, e12, e21, e22];
for es.eachi |i, e1| {
for es.eachi |j, e2| {
let ord = i.cmp(&j);
let eq = i == j;
let lt = i < j, le = i <= j;
let gt = i > j, ge = i >= j;
// Eq
assert_eq!(*e1 == *e2, eq);
assert_eq!(*e1 != *e2, !eq);
// TotalEq
assert_eq!(e1.equals(e2), eq);
// Ord
assert_eq!(*e1 < *e2, lt);
assert_eq!(*e1 > *e2, gt);
assert_eq!(*e1 <= *e2, le);
assert_eq!(*e1 >= *e2, ge);
// TotalOrd
assert_eq!(e1.cmp(e2), ord);
}
}
}

View File

@ -0,0 +1,52 @@
// xfail-test #5530
// Copyright 2013 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
#[deriving(Eq, TotalEq, Ord, TotalOrd)]
enum ES<T> {
ES1 { x: T },
ES2 { x: T, y: T }
}
pub fn main() {
let es11 = ES1 {x: 1}, es12 = ES1 {x: 2}, es21 = ES2 {x: 1, y: 1}, es22 = ES2 {x: 1, y: 2};
// in order for both Ord and TotalOrd
let ess = [es11, es12, es21, es22];
for ess.eachi |i, es1| {
for ess.eachi |j, es2| {
let ord = i.cmp(&j);
let eq = i == j;
let lt = i < j, le = i <= j;
let gt = i > j, ge = i >= j;
// Eq
assert_eq!(*es1 == *es2, eq);
assert_eq!(*es1 != *es2, !eq);
// TotalEq
assert_eq!(es1.equals(es2), eq);
// Ord
assert_eq!(*es1 < *es2, lt);
assert_eq!(*es1 > *es2, gt);
assert_eq!(*es1 <= *es2, le);
assert_eq!(*es1 >= *es2, ge);
// TotalOrd
assert_eq!(es1.cmp(es2), ord);
}
}
}

View File

@ -0,0 +1,49 @@
// Copyright 2013 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
#[deriving(Eq, TotalEq, Ord, TotalOrd)]
struct S<T> {
x: T,
y: T
}
pub fn main() {
let s1 = S {x: 1, y: 1}, s2 = S {x: 1, y: 2};
// in order for both Ord and TotalOrd
let ss = [s1, s2];
for ss.eachi |i, s1| {
for ss.eachi |j, s2| {
let ord = i.cmp(&j);
let eq = i == j;
let lt = i < j, le = i <= j;
let gt = i > j, ge = i >= j;
// Eq
assert_eq!(*s1 == *s2, eq);
assert_eq!(*s1 != *s2, !eq);
// TotalEq
assert_eq!(s1.equals(s2), eq);
// Ord
assert_eq!(*s1 < *s2, lt);
assert_eq!(*s1 > *s2, gt);
assert_eq!(*s1 <= *s2, le);
assert_eq!(*s1 >= *s2, ge);
// TotalOrd
assert_eq!(s1.cmp(s2), ord);
}
}
}

View File

@ -0,0 +1,47 @@
// Copyright 2013 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
#[deriving(Eq, TotalEq, Ord, TotalOrd)]
struct TS<T>(T,T);
pub fn main() {
let ts1 = TS(1, 1), ts2 = TS(1,2);
// in order for both Ord and TotalOrd
let tss = [ts1, ts2];
for tss.eachi |i, ts1| {
for tss.eachi |j, ts2| {
let ord = i.cmp(&j);
let eq = i == j;
let lt = i < j, le = i <= j;
let gt = i > j, ge = i >= j;
// Eq
assert_eq!(*ts1 == *ts2, eq);
assert_eq!(*ts1 != *ts2, !eq);
// TotalEq
assert_eq!(ts1.equals(ts2), eq);
// Ord
assert_eq!(*ts1 < *ts2, lt);
assert_eq!(*ts1 > *ts2, gt);
assert_eq!(*ts1 <= *ts2, le);
assert_eq!(*ts1 >= *ts2, ge);
// TotalOrd
assert_eq!(ts1.cmp(ts2), ord);
}
}
}

View File

@ -1,75 +0,0 @@
// Copyright 2013 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
#[deriving(Eq, TotalEq, Ord, TotalOrd)]
struct S<T> {
x: T,
y: T
}
#[deriving(Eq, TotalEq, Ord, TotalOrd)]
struct TS<T>(T,T);
#[deriving(Eq, TotalEq, Ord, TotalOrd)]
enum E<T> {
E0,
E1(T),
E2(T,T)
}
#[deriving(Eq, TotalEq, Ord, TotalOrd)]
enum ES<T> {
ES1 { x: T },
ES2 { x: T, y: T }
}
pub fn main() {
let s1 = S {x: 1, y: 1}, s2 = S {x: 1, y: 2};
let ts1 = TS(1, 1), ts2 = TS(1,2);
let e0 = E0, e11 = E1(1), e12 = E1(2), e21 = E2(1,1), e22 = E2(1, 2);
let es11 = ES1 {x: 1}, es12 = ES1 {x: 2}, es21 = ES2 {x: 1, y: 1}, es22 = ES2 {x: 1, y: 2};
test([s1, s2]);
test([ts1, ts2]);
test([e0, e11, e12, e21, e22]);
test([es11, es12, es21, es22]);
}
fn test<T: Eq+TotalEq+Ord+TotalOrd>(ts: &[T]) {
// compare each element against all other elements. The list
// should be in sorted order, so that if i < j, then ts[i] <
// ts[j], etc.
for ts.eachi |i, t1| {
for ts.eachi |j, t2| {
let ord = i.cmp(&j);
let eq = i == j;
let lt = i < j, le = i <= j;
let gt = i > j, ge = i >= j;
// Eq
assert_eq!(*t1 == *t2, eq);
// TotalEq
assert_eq!(t1.equals(t2), eq);
// Ord
assert_eq!(*t1 < *t2, lt);
assert_eq!(*t1 > *t2, gt);
assert_eq!(*t1 <= *t2, le);
assert_eq!(*t1 >= *t2, ge);
// TotalOrd
assert_eq!(t1.cmp(t2), ord);
}
}
}