impl the proper partial order between fn types

This commit is contained in:
Niko Matsakis 2011-12-16 13:50:22 -08:00
parent 1bc6e72b97
commit 98cbbbb642
3 changed files with 52 additions and 39 deletions

View File

@ -1918,20 +1918,19 @@ mod unify {
actual_inputs: [arg], actual_output: t, actual_inputs: [arg], actual_output: t,
variance: variance) -> variance: variance) ->
fn_common_res { fn_common_res {
let expected_len = vec::len::<arg>(expected_inputs); if !vec::same_length(expected_inputs, actual_inputs) {
let actual_len = vec::len::<arg>(actual_inputs);
if expected_len != actual_len {
ret fn_common_res_err(ures_err(terr_arg_count)); ret fn_common_res_err(ures_err(terr_arg_count));
} }
// TODO: as above, we should have an iter2 iterator.
let result_ins: [arg] = []; // Would use vec::map2(), but for the need to return in case of
let i = 0u; // error:
while i < expected_len { let i = 0u, n = vec::len(expected_inputs);
let result_ins = [];
while i < n {
let expected_input = expected_inputs[i]; let expected_input = expected_inputs[i];
let actual_input = actual_inputs[i]; let actual_input = actual_inputs[i];
// Unify the result modes.
// Unify the result modes.
let result_mode = if expected_input.mode == ast::mode_infer { let result_mode = if expected_input.mode == ast::mode_infer {
actual_input.mode actual_input.mode
} else if actual_input.mode == ast::mode_infer { } else if actual_input.mode == ast::mode_infer {
@ -1941,6 +1940,7 @@ mod unify {
(ures_err(terr_mode_mismatch(expected_input.mode, (ures_err(terr_mode_mismatch(expected_input.mode,
actual_input.mode))); actual_input.mode)));
} else { expected_input.mode }; } else { expected_input.mode };
// The variance changes (flips basically) when descending // The variance changes (flips basically) when descending
// into arguments of function types // into arguments of function types
let result = unify_step( let result = unify_step(
@ -1949,11 +1949,11 @@ mod unify {
alt result { alt result {
ures_ok(rty) { result_ins += [{mode: result_mode, ty: rty}]; } ures_ok(rty) { result_ins += [{mode: result_mode, ty: rty}]; }
_ { ret fn_common_res_err(result); } _ { ret fn_common_res_err(result); }
} };
i += 1u; i += 1u;
} }
// Check the output.
// Check the output.
let result = unify_step(cx, expected_output, actual_output, variance); let result = unify_step(cx, expected_output, actual_output, variance);
alt result { alt result {
ures_ok(rty) { ret fn_common_res_ok(result_ins, rty); } ures_ok(rty) { ret fn_common_res_ok(result_ins, rty); }
@ -1962,38 +1962,33 @@ mod unify {
} }
fn unify_fn_proto(e_proto: ast::proto, a_proto: ast::proto, fn unify_fn_proto(e_proto: ast::proto, a_proto: ast::proto,
variance: variance) -> option::t<result> { variance: variance) -> option::t<result> {
fn rank(proto: ast::proto) -> int { // Prototypes form a diamond-shaped partial order:
ret alt proto { //
ast::proto_block. { 0 } // block
ast::proto_shared(_) { 1 } // ^ ^
ast::proto_send. { 2 } // shared send
ast::proto_bare. { 3 } // ^ ^
// bare
//
// where "^" means "subtype of" (forgive the abuse of the term
// subtype).
fn sub_proto(p_sub: ast::proto, p_sup: ast::proto) -> bool {
ret alt (p_sub, p_sup) {
(_, ast::proto_block.) { true }
(ast::proto_bare., _) { true }
// Equal prototypes (modulo sugar) are always subprotos:
(ast::proto_shared(_), ast::proto_shared(_)) { true }
(_, _) { p_sub == p_sup }
}; };
} }
fn gt(e_proto: ast::proto, a_proto: ast::proto) -> bool { ret alt variance {
ret rank(e_proto) > rank(a_proto); invariant. when e_proto == a_proto { none }
} covariant. when sub_proto(a_proto, e_proto) { none }
contravariant. when sub_proto(e_proto, a_proto) { none }
ret if e_proto == a_proto { _ { some(ures_err(terr_mismatch)) }
none };
} else if variance == invariant {
some(ures_err(terr_mismatch))
} else if variance == covariant {
if gt(e_proto, a_proto) {
some(ures_err(terr_mismatch))
} else {
none
}
} else if variance == contravariant {
if gt(a_proto, e_proto) {
some(ures_err(terr_mismatch))
} else {
none
}
} else {
fail
}
} }
fn unify_fn(cx: @ctxt, e_proto: ast::proto, a_proto: ast::proto, fn unify_fn(cx: @ctxt, e_proto: ast::proto, a_proto: ast::proto,
expected: t, actual: t, expected_inputs: [arg], expected: t, actual: t, expected_inputs: [arg],

View File

@ -0,0 +1,10 @@
// error-pattern: mismatched types: expected lambda(++uint) -> uint
fn test(f: lambda(uint) -> uint) -> uint {
ret f(22u);
}
fn main() {
let f = sendfn(x: uint) -> uint { ret 4u; };
log test(f);
}

View File

@ -0,0 +1,8 @@
fn test(f: block(uint) -> uint) -> uint {
ret f(22u);
}
fn main() {
let y = test(sendfn(x: uint) -> uint { ret 4u * x; });
assert y == 88u;
}