diff --git a/src/comp/middle/trans/base.rs b/src/comp/middle/trans/base.rs index a86681da8ce..4ec06b83772 100644 --- a/src/comp/middle/trans/base.rs +++ b/src/comp/middle/trans/base.rs @@ -1641,6 +1641,44 @@ fn trans_compare(cx: block, op: ast::binop, lhs: ValueRef, } } +fn cast_shift_expr_rhs(cx: block, op: ast::binop, + lhs: ValueRef, rhs: ValueRef) -> ValueRef { + cast_shift_rhs(op, lhs, rhs, + bind Trunc(cx, _, _), bind ZExt(cx, _, _)) +} + +fn cast_shift_const_rhs(op: ast::binop, + lhs: ValueRef, rhs: ValueRef) -> ValueRef { + cast_shift_rhs(op, lhs, rhs, + llvm::LLVMConstTrunc, llvm::LLVMConstZExt) +} + +fn cast_shift_rhs(op: ast::binop, + lhs: ValueRef, rhs: ValueRef, + trunc: fn(ValueRef, TypeRef) -> ValueRef, + zext: fn(ValueRef, TypeRef) -> ValueRef + ) -> ValueRef { + + // Shifts may have any size int on the rhs + if ast_util::is_shift_binop(op) { + let rhs_llty = val_ty(rhs); + let lhs_llty = val_ty(lhs); + let rhs_sz = llvm::LLVMGetIntTypeWidth(rhs_llty); + let lhs_sz = llvm::LLVMGetIntTypeWidth(lhs_llty); + if lhs_sz < rhs_sz { + trunc(rhs, lhs_llty) + } else if lhs_sz > rhs_sz { + // FIXME: If shifting by negative values becomes not undefined + // then this is wrong. + zext(rhs, lhs_llty) + } else { + rhs + } + } else { + rhs + } +} + // Important to get types for both lhs and rhs, because one might be _|_ // and the other not. fn trans_eager_binop(cx: block, op: ast::binop, lhs: ValueRef, @@ -1651,6 +1689,8 @@ fn trans_eager_binop(cx: block, op: ast::binop, lhs: ValueRef, if ty::type_is_bot(intype) { intype = rhs_t; } let is_float = ty::type_is_fp(intype); + let rhs = cast_shift_expr_rhs(cx, op, lhs, rhs); + if op == ast::add && ty::type_is_sequence(intype) { ret tvec::trans_add(cx, intype, lhs, rhs, dest); } @@ -4059,6 +4099,9 @@ fn trans_const_expr(cx: crate_ctxt, e: @ast::expr) -> ValueRef { ast::expr_binary(b, e1, e2) { let te1 = trans_const_expr(cx, e1); let te2 = trans_const_expr(cx, e2); + + let te2 = cast_shift_const_rhs(b, te1, te2); + /* Neither type is bottom, and we expect them to be unified already, * so the following is safe. */ let ty = ty::expr_ty(cx.tcx, e1); diff --git a/src/comp/middle/typeck.rs b/src/comp/middle/typeck.rs index 75f3deb670f..88753b06091 100644 --- a/src/comp/middle/typeck.rs +++ b/src/comp/middle/typeck.rs @@ -2117,8 +2117,17 @@ fn check_expr_with_unifier(fcx: @fn_ctxt, expr: @ast::expr, unify: unifier, let lhs_t = next_ty_var(fcx); bot = check_expr_with(fcx, lhs, lhs_t); - let rhs_bot = check_expr_with(fcx, rhs, lhs_t); + let rhs_bot = if !ast_util::is_shift_binop(binop) { + check_expr_with(fcx, rhs, lhs_t) + } else { + let rhs_bot = check_expr(fcx, rhs); + let rhs_t = expr_ty(tcx, rhs); + require_integral(fcx, rhs.span, rhs_t); + rhs_bot + }; + if !ast_util::lazy_binop(binop) { bot |= rhs_bot; } + let result = check_binop(fcx, expr, lhs_t, binop, rhs); write_ty(tcx, id, result); } @@ -2572,13 +2581,6 @@ fn check_expr_with_unifier(fcx: @fn_ctxt, expr: @ast::expr, unify: unifier, let base_t = do_autoderef(fcx, expr.span, raw_base_t); bot |= check_expr(fcx, idx); let idx_t = expr_ty(tcx, idx); - fn require_integral(fcx: @fn_ctxt, sp: span, t: ty::t) { - if !type_is_integral(fcx, sp, t) { - fcx.ccx.tcx.sess.span_err(sp, "mismatched types: expected \ - `integer` but found `" - + ty_to_str(fcx.ccx.tcx, t) + "`"); - } - } alt structure_of(fcx, expr.span, base_t) { ty::ty_vec(mt) { require_integral(fcx, idx.span, idx_t); @@ -2612,6 +2614,14 @@ fn check_expr_with_unifier(fcx: @fn_ctxt, expr: @ast::expr, unify: unifier, ret bot; } +fn require_integral(fcx: @fn_ctxt, sp: span, t: ty::t) { + if !type_is_integral(fcx, sp, t) { + fcx.ccx.tcx.sess.span_err(sp, "mismatched types: expected \ + `integer` but found `" + + ty_to_str(fcx.ccx.tcx, t) + "`"); + } +} + fn next_ty_var_id(fcx: @fn_ctxt) -> int { let id = *fcx.next_var_id; *fcx.next_var_id += 1; diff --git a/src/comp/syntax/ast_util.rs b/src/comp/syntax/ast_util.rs index b9d48cb5fb0..446f74c8680 100644 --- a/src/comp/syntax/ast_util.rs +++ b/src/comp/syntax/ast_util.rs @@ -64,6 +64,15 @@ pure fn lazy_binop(b: binop) -> bool { alt b { and { true } or { true } _ { false } } } +pure fn is_shift_binop(b: binop) -> bool { + alt b { + lsl { true } + lsr { true } + asr { true } + _ { false } + } +} + fn unop_to_str(op: unop) -> str { alt op { box(mt) { if mt == m_mutbl { ret "@mut "; } ret "@"; } diff --git a/src/test/run-pass/shift.rs b/src/test/run-pass/shift.rs new file mode 100644 index 00000000000..93b2c29f35a --- /dev/null +++ b/src/test/run-pass/shift.rs @@ -0,0 +1,86 @@ +// Testing shifts for various combinations of integers +// Issue #1570 + +fn main() { + test_misc(); + test_expr(); + test_const(); +} + +fn test_misc() { + assert 1 << 1i8 << 1u8 << 1i16 << 1 as char << 1u64 == 32; +} + +fn test_expr() { + let v10 = 10 as uint; + let v4 = 4 as u8; + let v2 = 2 as u8; + assert (v10 >> v2 == v2 as uint); + assert (v10 >>> v2 == v2 as uint); + assert (v10 << v4 == 160 as uint); + + let v10 = 10 as u8; + let v4 = 4 as uint; + let v2 = 2 as uint; + assert (v10 >> v2 == v2 as u8); + assert (v10 >>> v2 == v2 as u8); + assert (v10 << v4 == 160 as u8); + + let v10 = 10 as int; + let v4 = 4 as i8; + let v2 = 2 as i8; + assert (v10 >> v2 == v2 as int); + assert (v10 >>> v2 == v2 as int); + assert (v10 << v4 == 160 as int); + + let v10 = 10 as i8; + let v4 = 4 as int; + let v2 = 2 as int; + assert (v10 >> v2 == v2 as i8); + assert (v10 >>> v2 == v2 as i8); + assert (v10 << v4 == 160 as i8); + + let v10 = 10 as uint; + let v4 = 4 as int; + let v2 = 2 as int; + assert (v10 >> v2 == v2 as uint); + assert (v10 >>> v2 == v2 as uint); + assert (v10 << v4 == 160 as uint); +} + +fn test_const() { + const r1_1: uint = 10u >> 2u8; + const r2_1: uint = 10u >>> 2u8; + const r3_1: uint = 10u << 4u8; + assert r1_1 == 2 as uint; + assert r2_1 == 2 as uint; + assert r3_1 == 160 as uint; + + const r1_2: u8 = 10u8 >> 2u; + const r2_2: u8 = 10u8 >>> 2u; + const r3_2: u8 = 10u8 << 4u; + assert r1_2 == 2 as u8; + assert r2_2 == 2 as u8; + assert r3_2 == 160 as u8; + + const r1_3: int = 10 >> 2i8; + const r2_3: int = 10 >>> 2i8; + const r3_3: int = 10 << 4i8; + assert r1_3 == 2 as int; + assert r2_3 == 2 as int; + assert r3_3 == 160 as int; + + const r1_4: i8 = 10i8 >> 2; + const r2_4: i8 = 10i8 >>> 2; + const r3_4: i8 = 10i8 << 4; + assert r1_4 == 2 as i8; + assert r2_4 == 2 as i8; + assert r3_4 == 160 as i8; + + const r1_5: uint = 10u >> 2i8; + const r2_5: uint = 10u >>> 2i8; + const r3_5: uint = 10u << 4i8; + assert r1_5 == 2 as uint; + assert r2_5 == 2 as uint; + assert r3_5 == 160 as uint; +}