From bff374873136bacf8352e05f73cb3252761dc2d6 Mon Sep 17 00:00:00 2001 From: Huon Wilson Date: Tue, 2 Apr 2013 22:02:46 +1100 Subject: [PATCH] libsyntax: short-circuit on non-matching variants in deriving code. Allow a deriving instance using the generic code to short-circuit for any non-matching enum variants (grouping them all into a _ match), reducing the number of arms required. Use this to speed up the Eq & TotalEq implementations. --- src/libsyntax/ext/deriving/clone.rs | 1 + src/libsyntax/ext/deriving/cmp/eq.rs | 26 +-- src/libsyntax/ext/deriving/cmp/ord.rs | 42 ++--- src/libsyntax/ext/deriving/cmp/totaleq.rs | 1 + src/libsyntax/ext/deriving/cmp/totalord.rs | 1 + src/libsyntax/ext/deriving/generic.rs | 159 +++++++++++++----- .../run-pass/deriving-cmp-generic-enum.rs | 50 ++++++ .../deriving-cmp-generic-struct-enum.rs | 52 ++++++ .../run-pass/deriving-cmp-generic-struct.rs | 49 ++++++ .../deriving-cmp-generic-tuple-struct.rs | 47 ++++++ src/test/run-pass/deriving-cmp.rs | 75 --------- 11 files changed, 348 insertions(+), 155 deletions(-) create mode 100644 src/test/run-pass/deriving-cmp-generic-enum.rs create mode 100644 src/test/run-pass/deriving-cmp-generic-struct-enum.rs create mode 100644 src/test/run-pass/deriving-cmp-generic-struct.rs create mode 100644 src/test/run-pass/deriving-cmp-generic-tuple-struct.rs delete mode 100644 src/test/run-pass/deriving-cmp.rs diff --git a/src/libsyntax/ext/deriving/clone.rs b/src/libsyntax/ext/deriving/clone.rs index 0c62566702d..d996bca60a3 100644 --- a/src/libsyntax/ext/deriving/clone.rs +++ b/src/libsyntax/ext/deriving/clone.rs @@ -29,6 +29,7 @@ pub fn expand_deriving_clone(cx: @ext_ctxt, name: ~"clone", nargs: 0, output_type: None, // return Self + const_nonmatching: false, combine_substructure: cs_clone } ] diff --git a/src/libsyntax/ext/deriving/cmp/eq.rs b/src/libsyntax/ext/deriving/cmp/eq.rs index 142f0565e14..c0060cc67dc 100644 --- a/src/libsyntax/ext/deriving/cmp/eq.rs +++ b/src/libsyntax/ext/deriving/cmp/eq.rs @@ -31,24 +31,24 @@ pub fn expand_deriving_eq(cx: @ext_ctxt, cs_or(|cx, span, _| build::mk_bool(cx, span, true), cx, span, substr) } - + macro_rules! md ( + ($name:expr, $f:ident) => { + MethodDef { + name: $name, + output_type: Some(~[~"bool"]), + nargs: 1, + const_nonmatching: true, + combine_substructure: $f + }, + } + ) let trait_def = TraitDef { path: ~[~"core", ~"cmp", ~"Eq"], additional_bounds: ~[], methods: ~[ - MethodDef { - name: ~"ne", - output_type: Some(~[~"bool"]), - nargs: 1, - combine_substructure: cs_ne - }, - MethodDef { - name: ~"eq", - output_type: Some(~[~"bool"]), - nargs: 1, - combine_substructure: cs_eq - } + md!(~"eq", cs_eq), + md!(~"ne", cs_ne) ] }; diff --git a/src/libsyntax/ext/deriving/cmp/ord.rs b/src/libsyntax/ext/deriving/cmp/ord.rs index 7f7babab45c..398e27eb3e3 100644 --- a/src/libsyntax/ext/deriving/cmp/ord.rs +++ b/src/libsyntax/ext/deriving/cmp/ord.rs @@ -16,10 +16,16 @@ use ext::build; use ext::deriving::generic::*; use core::option::Some; -macro_rules! mk_cso { - ($less:expr, $equal:expr) => { - |cx, span, substr| - cs_ord($less, $equal, cx, span, substr) +macro_rules! md { + ($name:expr, $less:expr, $equal:expr) => { + MethodDef { + name: $name, + output_type: Some(~[~"bool"]), + nargs: 1, + const_nonmatching: false, + combine_substructure: |cx, span, substr| + cs_ord($less, $equal, cx, span, substr) + } } } @@ -32,30 +38,10 @@ pub fn expand_deriving_ord(cx: @ext_ctxt, // XXX: Ord doesn't imply Eq yet additional_bounds: ~[~[~"core", ~"cmp", ~"Eq"]], methods: ~[ - MethodDef { - name: ~"lt", - output_type: Some(~[~"bool"]), - nargs: 1, - combine_substructure: mk_cso!(true, false) - }, - MethodDef { - name: ~"le", - output_type: Some(~[~"bool"]), - nargs: 1, - combine_substructure: mk_cso!(true, true) - }, - MethodDef { - name: ~"gt", - output_type: Some(~[~"bool"]), - nargs: 1, - combine_substructure: mk_cso!(false, false) - }, - MethodDef { - name: ~"ge", - output_type: Some(~[~"bool"]), - nargs: 1, - combine_substructure: mk_cso!(false, true) - }, + md!(~"lt", true, false), + md!(~"le", true, true), + md!(~"gt", false, false), + md!(~"ge", false, true) ] }; diff --git a/src/libsyntax/ext/deriving/cmp/totaleq.rs b/src/libsyntax/ext/deriving/cmp/totaleq.rs index d71db22591d..fc8ec103a60 100644 --- a/src/libsyntax/ext/deriving/cmp/totaleq.rs +++ b/src/libsyntax/ext/deriving/cmp/totaleq.rs @@ -35,6 +35,7 @@ pub fn expand_deriving_totaleq(cx: @ext_ctxt, name: ~"equals", output_type: Some(~[~"bool"]), nargs: 1, + const_nonmatching: true, combine_substructure: cs_equals } ] diff --git a/src/libsyntax/ext/deriving/cmp/totalord.rs b/src/libsyntax/ext/deriving/cmp/totalord.rs index d82c63e9dd3..9c20a0be87c 100644 --- a/src/libsyntax/ext/deriving/cmp/totalord.rs +++ b/src/libsyntax/ext/deriving/cmp/totalord.rs @@ -28,6 +28,7 @@ pub fn expand_deriving_totalord(cx: @ext_ctxt, name: ~"cmp", output_type: Some(~[~"core", ~"cmp", ~"Ordering"]), nargs: 1, + const_nonmatching: false, combine_substructure: cs_cmp } ] diff --git a/src/libsyntax/ext/deriving/generic.rs b/src/libsyntax/ext/deriving/generic.rs index 23a075ef001..8fe2ca1a1a1 100644 --- a/src/libsyntax/ext/deriving/generic.rs +++ b/src/libsyntax/ext/deriving/generic.rs @@ -40,7 +40,8 @@ arguments: - `EnumMatching`, when `Self` is an enum and all the arguments are the same variant of the enum (e.g. `Some(1)`, `Some(3)` and `Some(4)`) - `EnumNonMatching` when `Self` is an enum and the arguments are not - the same variant (e.g. `None`, `Some(1)` and `None`) + the same variant (e.g. `None`, `Some(1)` and `None`). If + `const_nonmatching` is true, this will contain an empty list. In the first two cases, the values from the corresponding fields in all the arguments are grouped together. In the `EnumNonMatching` case @@ -129,9 +130,11 @@ use core::prelude::*; use ast; use ast::{ + and, binop, deref, enum_def, expr, expr_match, ident, impure_fn, - item, Generics, m_imm, meta_item, method, named_field, or, public, - struct_def, sty_region, ty_rptr, ty_path, variant}; + item, Generics, m_imm, meta_item, method, named_field, or, + pat_wild, public, struct_def, sty_region, ty_rptr, ty_path, + variant}; use ast_util; use ext::base::ext_ctxt; @@ -177,6 +180,10 @@ pub struct MethodDef<'self> { /// Number of arguments other than `self` (all of type `&Self`) nargs: uint, + /// if the value of the nonmatching enums is independent of the + /// actual enums, i.e. can use _ => .. match. + const_nonmatching: bool, + combine_substructure: CombineSubstructureFunc<'self> } @@ -555,12 +562,13 @@ impl<'self> MethodDef<'self> { enum_def: &enum_def, type_ident: ident) -> @expr { - self.build_enum_match(cx, span, enum_def, type_ident, ~[]) + self.build_enum_match(cx, span, enum_def, type_ident, + None, ~[], 0) } /** - Creates the nested matches for an enum definition, i.e. + Creates the nested matches for an enum definition recursively, i.e. ``` match self { @@ -575,14 +583,20 @@ impl<'self> MethodDef<'self> { the tree are the same. Hopefully the optimisers get rid of any repetition, otherwise derived methods with many Self arguments will be exponentially large. + + `matching` is Some(n) if all branches in the tree above the + current position are variant `n`, `None` otherwise (including on + the first call). */ fn build_enum_match(&self, cx: @ext_ctxt, span: span, enum_def: &enum_def, type_ident: ident, + matching: Option, matches_so_far: ~[(uint, variant, - ~[(Option, @expr)])]) -> @expr { - if matches_so_far.len() == self.nargs + 1 { + ~[(Option, @expr)])], + match_count: uint) -> @expr { + if match_count == self.nargs + 1 { // we've matched against all arguments, so make the final // expression at the bottom of the match tree match matches_so_far { @@ -594,41 +608,44 @@ impl<'self> MethodDef<'self> { // vec of tuples, where each tuple represents a // field. - // `ref` inside let matches is buggy. Causes havoc wih rusc. - // let (variant_index, ref self_vec) = matches_so_far[0]; - let (variant_index, variant, self_vec) = match matches_so_far[0] { - (i, v, ref s) => (i, v, s) - }; - let substructure; // most arms don't have matching variants, so do a // quick check to see if they match (even though // this means iterating twice) instead of being // optimistic and doing a pile of allocations etc. - if matches_so_far.all(|&(v_i, _, _)| v_i == variant_index) { - let mut enum_matching_fields = vec::from_elem(self_vec.len(), ~[]); + match matching { + Some(variant_index) => { + // `ref` inside let matches is buggy. Causes havoc wih rusc. + // let (variant_index, ref self_vec) = matches_so_far[0]; + let (variant, self_vec) = match matches_so_far[0] { + (_, v, ref s) => (v, s) + }; - for matches_so_far.tail().each |&(_, _, other_fields)| { - for other_fields.eachi |i, &(_, other_field)| { - enum_matching_fields[i].push(other_field); + let mut enum_matching_fields = vec::from_elem(self_vec.len(), ~[]); + + for matches_so_far.tail().each |&(_, _, other_fields)| { + for other_fields.eachi |i, &(_, other_field)| { + enum_matching_fields[i].push(other_field); + } } + let field_tuples = + do vec::map2(*self_vec, + enum_matching_fields) |&(id, self_f), &other| { + (id, self_f, other) + }; + substructure = EnumMatching(variant_index, variant, field_tuples); + } + None => { + substructure = EnumNonMatching(matches_so_far); } - let field_tuples = - do vec::map2(*self_vec, - enum_matching_fields) |&(id, self_f), &other| { - (id, self_f, other) - }; - substructure = EnumMatching(variant_index, variant, field_tuples); - } else { - substructure = EnumNonMatching(matches_so_far); } self.call_substructure_method(cx, span, type_ident, &substructure) } } } else { // there are still matches to create - let (current_match_ident, current_match_str) = if matches_so_far.is_empty() { + let (current_match_ident, current_match_str) = if match_count == 0 { (cx.ident_of(~"self"), ~"__self") } else { let s = fmt!("__other_%u", matches_so_far.len() - 1); @@ -640,8 +657,32 @@ impl<'self> MethodDef<'self> { // this is used as a stack let mut matches_so_far = matches_so_far; - // create an arm matching on each variant - for enum_def.variants.eachi |index, variant| { + macro_rules! mk_arm( + ($pat:expr, $expr:expr) => { + { + let blk = build::mk_simple_block(cx, span, $expr); + let arm = ast::arm { + pats: ~[$ pat ], + guard: None, + body: blk + }; + arm + } + } + ) + + // the code for nonmatching variants only matters when + // we've seen at least one other variant already + if self.const_nonmatching && match_count > 0 { + // make a matching-variant match, and a _ match. + let index = match matching { + Some(i) => i, + None => cx.span_bug(span, ~"Non-matching variants when required to\ + be matching in `deriving_generic`") + }; + + // matching-variant match + let variant = &enum_def.variants[index]; let pattern = create_enum_variant_pattern(cx, span, variant, current_match_str); @@ -653,23 +694,63 @@ impl<'self> MethodDef<'self> { } }; - matches_so_far.push((index, *variant, idents)); let arm_expr = self.build_enum_match(cx, span, enum_def, type_ident, - matches_so_far); + matching, + matches_so_far, + match_count + 1); matches_so_far.pop(); - - let arm_block = build::mk_simple_block(cx, span, arm_expr); - let arm = ast::arm { - pats: ~[ pattern ], - guard: None, - body: arm_block - }; + let arm = mk_arm!(pattern, arm_expr); arms.push(arm); - } + if enum_def.variants.len() > 1 { + // _ match, if necessary + let wild_pat = @ast::pat { + id: cx.next_id(), + node: pat_wild, + span: span + }; + + let wild_expr = self.call_substructure_method(cx, span, type_ident, + &EnumNonMatching(~[])); + let wild_arm = mk_arm!(wild_pat, wild_expr); + arms.push(wild_arm); + } + } else { + // create an arm matching on each variant + for enum_def.variants.eachi |index, variant| { + let pattern = create_enum_variant_pattern(cx, span, + variant, + current_match_str); + + let idents = do vec::build |push| { + for each_variant_arg_ident(cx, span, variant) |i, field_id| { + let id = cx.ident_of(fmt!("%s_%u", current_match_str, i)); + push((field_id, build::mk_path(cx, span, ~[ id ]))); + } + }; + + matches_so_far.push((index, *variant, idents)); + let new_matching = + match matching { + _ if match_count == 0 => Some(index), + Some(i) if index == i => Some(i), + _ => None + }; + let arm_expr = self.build_enum_match(cx, span, + enum_def, + type_ident, + new_matching, + matches_so_far, + match_count + 1); + matches_so_far.pop(); + + let arm = mk_arm!(pattern, arm_expr); + arms.push(arm); + } + } let deref_expr = build::mk_unary(cx, span, deref, build::mk_path(cx, span, ~[ current_match_ident ])); diff --git a/src/test/run-pass/deriving-cmp-generic-enum.rs b/src/test/run-pass/deriving-cmp-generic-enum.rs new file mode 100644 index 00000000000..a2651ddac3d --- /dev/null +++ b/src/test/run-pass/deriving-cmp-generic-enum.rs @@ -0,0 +1,50 @@ +// Copyright 2013 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#[deriving(Eq, TotalEq, Ord, TotalOrd)] +enum E { + E0, + E1(T), + E2(T,T) +} + +pub fn main() { + let e0 = E0, e11 = E1(1), e12 = E1(2), e21 = E2(1,1), e22 = E2(1, 2); + + // in order for both Ord and TotalOrd + let es = [e0, e11, e12, e21, e22]; + + for es.eachi |i, e1| { + for es.eachi |j, e2| { + let ord = i.cmp(&j); + + let eq = i == j; + let lt = i < j, le = i <= j; + let gt = i > j, ge = i >= j; + + // Eq + assert_eq!(*e1 == *e2, eq); + assert_eq!(*e1 != *e2, !eq); + + // TotalEq + assert_eq!(e1.equals(e2), eq); + + // Ord + assert_eq!(*e1 < *e2, lt); + assert_eq!(*e1 > *e2, gt); + + assert_eq!(*e1 <= *e2, le); + assert_eq!(*e1 >= *e2, ge); + + // TotalOrd + assert_eq!(e1.cmp(e2), ord); + } + } +} diff --git a/src/test/run-pass/deriving-cmp-generic-struct-enum.rs b/src/test/run-pass/deriving-cmp-generic-struct-enum.rs new file mode 100644 index 00000000000..6f6e8d79d8b --- /dev/null +++ b/src/test/run-pass/deriving-cmp-generic-struct-enum.rs @@ -0,0 +1,52 @@ +// xfail-test #5530 + +// Copyright 2013 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#[deriving(Eq, TotalEq, Ord, TotalOrd)] +enum ES { + ES1 { x: T }, + ES2 { x: T, y: T } +} + + +pub fn main() { + let es11 = ES1 {x: 1}, es12 = ES1 {x: 2}, es21 = ES2 {x: 1, y: 1}, es22 = ES2 {x: 1, y: 2}; + + // in order for both Ord and TotalOrd + let ess = [es11, es12, es21, es22]; + + for ess.eachi |i, es1| { + for ess.eachi |j, es2| { + let ord = i.cmp(&j); + + let eq = i == j; + let lt = i < j, le = i <= j; + let gt = i > j, ge = i >= j; + + // Eq + assert_eq!(*es1 == *es2, eq); + assert_eq!(*es1 != *es2, !eq); + + // TotalEq + assert_eq!(es1.equals(es2), eq); + + // Ord + assert_eq!(*es1 < *es2, lt); + assert_eq!(*es1 > *es2, gt); + + assert_eq!(*es1 <= *es2, le); + assert_eq!(*es1 >= *es2, ge); + + // TotalOrd + assert_eq!(es1.cmp(es2), ord); + } + } +} \ No newline at end of file diff --git a/src/test/run-pass/deriving-cmp-generic-struct.rs b/src/test/run-pass/deriving-cmp-generic-struct.rs new file mode 100644 index 00000000000..bd3e02ba29b --- /dev/null +++ b/src/test/run-pass/deriving-cmp-generic-struct.rs @@ -0,0 +1,49 @@ +// Copyright 2013 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#[deriving(Eq, TotalEq, Ord, TotalOrd)] +struct S { + x: T, + y: T +} + +pub fn main() { + let s1 = S {x: 1, y: 1}, s2 = S {x: 1, y: 2}; + + // in order for both Ord and TotalOrd + let ss = [s1, s2]; + + for ss.eachi |i, s1| { + for ss.eachi |j, s2| { + let ord = i.cmp(&j); + + let eq = i == j; + let lt = i < j, le = i <= j; + let gt = i > j, ge = i >= j; + + // Eq + assert_eq!(*s1 == *s2, eq); + assert_eq!(*s1 != *s2, !eq); + + // TotalEq + assert_eq!(s1.equals(s2), eq); + + // Ord + assert_eq!(*s1 < *s2, lt); + assert_eq!(*s1 > *s2, gt); + + assert_eq!(*s1 <= *s2, le); + assert_eq!(*s1 >= *s2, ge); + + // TotalOrd + assert_eq!(s1.cmp(s2), ord); + } + } +} \ No newline at end of file diff --git a/src/test/run-pass/deriving-cmp-generic-tuple-struct.rs b/src/test/run-pass/deriving-cmp-generic-tuple-struct.rs new file mode 100644 index 00000000000..733b19a9ae2 --- /dev/null +++ b/src/test/run-pass/deriving-cmp-generic-tuple-struct.rs @@ -0,0 +1,47 @@ +// Copyright 2013 The Rust Project Developers. See the COPYRIGHT +// file at the top-level directory of this distribution and at +// http://rust-lang.org/COPYRIGHT. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +#[deriving(Eq, TotalEq, Ord, TotalOrd)] +struct TS(T,T); + + +pub fn main() { + let ts1 = TS(1, 1), ts2 = TS(1,2); + + // in order for both Ord and TotalOrd + let tss = [ts1, ts2]; + + for tss.eachi |i, ts1| { + for tss.eachi |j, ts2| { + let ord = i.cmp(&j); + + let eq = i == j; + let lt = i < j, le = i <= j; + let gt = i > j, ge = i >= j; + + // Eq + assert_eq!(*ts1 == *ts2, eq); + assert_eq!(*ts1 != *ts2, !eq); + + // TotalEq + assert_eq!(ts1.equals(ts2), eq); + + // Ord + assert_eq!(*ts1 < *ts2, lt); + assert_eq!(*ts1 > *ts2, gt); + + assert_eq!(*ts1 <= *ts2, le); + assert_eq!(*ts1 >= *ts2, ge); + + // TotalOrd + assert_eq!(ts1.cmp(ts2), ord); + } + } +} \ No newline at end of file diff --git a/src/test/run-pass/deriving-cmp.rs b/src/test/run-pass/deriving-cmp.rs deleted file mode 100644 index 56968fc1210..00000000000 --- a/src/test/run-pass/deriving-cmp.rs +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2013 The Rust Project Developers. See the COPYRIGHT -// file at the top-level directory of this distribution and at -// http://rust-lang.org/COPYRIGHT. -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -#[deriving(Eq, TotalEq, Ord, TotalOrd)] -struct S { - x: T, - y: T -} - -#[deriving(Eq, TotalEq, Ord, TotalOrd)] -struct TS(T,T); - -#[deriving(Eq, TotalEq, Ord, TotalOrd)] -enum E { - E0, - E1(T), - E2(T,T) -} - -#[deriving(Eq, TotalEq, Ord, TotalOrd)] -enum ES { - ES1 { x: T }, - ES2 { x: T, y: T } -} - - -pub fn main() { - let s1 = S {x: 1, y: 1}, s2 = S {x: 1, y: 2}; - let ts1 = TS(1, 1), ts2 = TS(1,2); - let e0 = E0, e11 = E1(1), e12 = E1(2), e21 = E2(1,1), e22 = E2(1, 2); - let es11 = ES1 {x: 1}, es12 = ES1 {x: 2}, es21 = ES2 {x: 1, y: 1}, es22 = ES2 {x: 1, y: 2}; - - test([s1, s2]); - test([ts1, ts2]); - test([e0, e11, e12, e21, e22]); - test([es11, es12, es21, es22]); -} - -fn test(ts: &[T]) { - // compare each element against all other elements. The list - // should be in sorted order, so that if i < j, then ts[i] < - // ts[j], etc. - for ts.eachi |i, t1| { - for ts.eachi |j, t2| { - let ord = i.cmp(&j); - - let eq = i == j; - let lt = i < j, le = i <= j; - let gt = i > j, ge = i >= j; - - // Eq - assert_eq!(*t1 == *t2, eq); - - // TotalEq - assert_eq!(t1.equals(t2), eq); - - // Ord - assert_eq!(*t1 < *t2, lt); - assert_eq!(*t1 > *t2, gt); - - assert_eq!(*t1 <= *t2, le); - assert_eq!(*t1 >= *t2, ge); - - // TotalOrd - assert_eq!(t1.cmp(t2), ord); - } - } -} \ No newline at end of file