Index: llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/trunk/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -726,6 +726,15 @@
   }
 }
 
+// APInts must be the same size for most operations. This helper zero extends
+// both values to a common width: the wider of the two plus Offset extra bits,
+// so that e.g. Offset = 1 leaves room for the carry of a subsequent add.
+static void zeroExtendToMatch(APInt &LHS, APInt &RHS, unsigned Offset = 0) {
+  unsigned Bits = Offset + std::max(LHS.getBitWidth(), RHS.getBitWidth());
+  LHS = LHS.zextOrSelf(Bits);
+  RHS = RHS.zextOrSelf(Bits);
+}
+
 // Return true if this node is a setcc, or is a select_cc
 // that selects between the target values used for true and false, making it
 // equivalent to a setcc. Also, set the incoming LHS, RHS, and CC references to
@@ -4464,13 +4473,18 @@
   // fold (shl (shl x, c1), c2) -> 0 or (shl x, (add c1, c2))
   if (N1C && N0.getOpcode() == ISD::SHL) {
     if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
-      uint64_t c1 = N0C1->getZExtValue();
-      uint64_t c2 = N1C->getZExtValue();
       SDLoc DL(N);
-      if (c1 + c2 >= OpSizeInBits)
+      APInt c1 = N0C1->getAPIntValue();
+      APInt c2 = N1C->getAPIntValue();
+      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+
+      APInt Sum = c1 + c2;
+      if (Sum.uge(OpSizeInBits))
         return DAG.getConstant(0, DL, VT);
-      return DAG.getNode(ISD::SHL, DL, VT, N0.getOperand(0),
-                         DAG.getConstant(c1 + c2, DL, N1.getValueType()));
+
+      return DAG.getNode(
+          ISD::SHL, DL, VT, N0.getOperand(0),
+          DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType()));
     }
   }
 
@@ -4656,13 +4670,19 @@
 
   // fold (sra (sra x, c1), c2) -> (sra x, (add c1, c2))
   if (N1C && N0.getOpcode() == ISD::SRA) {
-    if (ConstantSDNode *C1 = isConstOrConstSplat(N0.getOperand(1))) {
-      unsigned Sum = N1C->getZExtValue() + C1->getZExtValue();
-      if (Sum >= OpSizeInBits)
-        Sum = OpSizeInBits - 1;
+    if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
       SDLoc DL(N);
-      return DAG.getNode(ISD::SRA, DL, VT, N0.getOperand(0),
-                         DAG.getConstant(Sum, DL, N1.getValueType()));
+      APInt c1 = N0C1->getAPIntValue();
+      APInt c2 = N1C->getAPIntValue();
+      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+
+      APInt Sum = c1 + c2;
+      if (Sum.uge(OpSizeInBits))
+        Sum = APInt(OpSizeInBits, OpSizeInBits - 1);
+
+      return DAG.getNode(
+          ISD::SRA, DL, VT, N0.getOperand(0),
+          DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType()));
     }
   }
 
@@ -4790,14 +4810,19 @@
 
   // fold (srl (srl x, c1), c2) -> 0 or (srl x, (add c1, c2))
   if (N1C && N0.getOpcode() == ISD::SRL) {
-    if (ConstantSDNode *N01C = isConstOrConstSplat(N0.getOperand(1))) {
-      uint64_t c1 = N01C->getZExtValue();
-      uint64_t c2 = N1C->getZExtValue();
+    if (ConstantSDNode *N0C1 = isConstOrConstSplat(N0.getOperand(1))) {
       SDLoc DL(N);
-      if (c1 + c2 >= OpSizeInBits)
+      APInt c1 = N0C1->getAPIntValue();
+      APInt c2 = N1C->getAPIntValue();
+      zeroExtendToMatch(c1, c2, 1 /* Overflow Bit */);
+
+      APInt Sum = c1 + c2;
+      if (Sum.uge(OpSizeInBits))
         return DAG.getConstant(0, DL, VT);
-      return DAG.getNode(ISD::SRL, DL, VT, N0.getOperand(0),
-                         DAG.getConstant(c1 + c2, DL, N1.getValueType()));
+
+      return DAG.getNode(
+          ISD::SRL, DL, VT, N0.getOperand(0),
+          DAG.getConstant(Sum.getZExtValue(), DL, N1.getValueType()));
     }
   }
 
Index: llvm/trunk/test/CodeGen/X86/shift-i128.ll
===================================================================
--- llvm/trunk/test/CodeGen/X86/shift-i128.ll
+++ llvm/trunk/test/CodeGen/X86/shift-i128.ll
@@ -92,3 +92,29 @@
 	store <2 x i128> %0, <2 x i128>* %r, align 16
 	ret void
 }
+
+; Regression tests: in each pair of shifts the two amounts sum to 2^128, which
+; overflows a 128-bit APInt and does not fit in a uint64_t.
+define void @test_lshr_v2i128_outofrange_sum(<2 x i128> %x, <2 x i128>* nocapture %r) nounwind {
+entry:
+	%0 = lshr <2 x i128> %x, <i128 -1, i128 -1>
+	%1 = lshr <2 x i128> %0, <i128 1, i128 1>
+	store <2 x i128> %1, <2 x i128>* %r, align 16
+	ret void
+}
+
+define void @test_ashr_v2i128_outofrange_sum(<2 x i128> %x, <2 x i128>* nocapture %r) nounwind {
+entry:
+	%0 = ashr <2 x i128> %x, <i128 -1, i128 -1>
+	%1 = ashr <2 x i128> %0, <i128 1, i128 1>
+	store <2 x i128> %1, <2 x i128>* %r, align 16
+	ret void
+}
+
+define void @test_shl_v2i128_outofrange_sum(<2 x i128> %x, <2 x i128>* nocapture %r) nounwind {
+entry:
+	%0 = shl <2 x i128> %x, <i128 -1, i128 -1>
+	%1 = shl <2 x i128> %0, <i128 1, i128 1>
+	store <2 x i128> %1, <2 x i128>* %r, align 16
+	ret void
+}