diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h
--- a/llvm/include/llvm/CodeGen/SelectionDAG.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAG.h
@@ -2020,6 +2020,24 @@
     return computeOverflowForSub(IsSigned, N0, N1) == OFK_Never;
   }
 
+  /// Determine if the result of the signed mul of 2 nodes can overflow.
+  OverflowKind computeOverflowForSignedMul(SDValue N0, SDValue N1) const;
+
+  /// Determine if the result of the unsigned mul of 2 nodes can overflow.
+  OverflowKind computeOverflowForUnsignedMul(SDValue N0, SDValue N1) const;
+
+  /// Determine if the result of the mul of 2 nodes can overflow.
+  OverflowKind computeOverflowForMul(bool IsSigned, SDValue N0,
+                                     SDValue N1) const {
+    return IsSigned ? computeOverflowForSignedMul(N0, N1)
+                    : computeOverflowForUnsignedMul(N0, N1);
+  }
+
+  /// Determine if the result of the mul of 2 nodes can never overflow.
+  bool willNotOverflowMul(bool IsSigned, SDValue N0, SDValue N1) const {
+    return computeOverflowForMul(IsSigned, N0, N1) == OFK_Never;
+  }
+
   /// Test if the given value is known to have exactly one bit set. This differs
   /// from computeKnownBits in that it doesn't necessarily determine which bit
   /// is set.
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -5432,34 +5432,18 @@
     return DAG.getNode(IsSigned ? ISD::SADDO : ISD::UADDO, DL, N->getVTList(),
                        N0, N0);
 
-  if (IsSigned) {
-    // A 1 bit SMULO overflows if both inputs are 1.
-    if (VT.getScalarSizeInBits() == 1) {
-      SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
-      return CombineTo(N, And,
-                       DAG.getSetCC(DL, CarryVT, And,
-                                    DAG.getConstant(0, DL, VT), ISD::SETNE));
-    }
-
-    // Multiplying n * m significant bits yields a result of n + m significant
-    // bits. If the total number of significant bits does not exceed the
-    // result bit width (minus 1), there is no overflow.
-    unsigned SignBits = DAG.ComputeNumSignBits(N0);
-    if (SignBits > 1)
-      SignBits += DAG.ComputeNumSignBits(N1);
-    if (SignBits > VT.getScalarSizeInBits() + 1)
-      return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
-                       DAG.getConstant(0, DL, CarryVT));
-  } else {
-    KnownBits N1Known = DAG.computeKnownBits(N1);
-    KnownBits N0Known = DAG.computeKnownBits(N0);
-    bool Overflow;
-    (void)N0Known.getMaxValue().umul_ov(N1Known.getMaxValue(), Overflow);
-    if (!Overflow)
-      return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
-                       DAG.getConstant(0, DL, CarryVT));
+  // A 1 bit SMULO overflows if both inputs are 1.
+  if (IsSigned && VT.getScalarSizeInBits() == 1) {
+    SDValue And = DAG.getNode(ISD::AND, DL, VT, N0, N1);
+    SDValue Cmp = DAG.getSetCC(DL, CarryVT, And,
+                               DAG.getConstant(0, DL, VT), ISD::SETNE);
+    return CombineTo(N, And, Cmp);
   }
 
+  // If it cannot overflow, transform into a mul.
+  if (DAG.willNotOverflowMul(IsSigned, N0, N1))
+    return CombineTo(N, DAG.getNode(ISD::MUL, DL, VT, N0, N1),
+                     DAG.getConstant(0, DL, CarryVT));
   return SDValue();
 }
 
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -4099,6 +4099,49 @@
   return OFK_Sometime;
 }
 
+SelectionDAG::OverflowKind
+SelectionDAG::computeOverflowForUnsignedMul(SDValue N0, SDValue N1) const {
+  // X * 0 and X * 1 never overflow.
+  if (isNullConstant(N1) || isOneConstant(N1))
+    return OFK_Never;
+
+  KnownBits N0Known = computeKnownBits(N0);
+  KnownBits N1Known = computeKnownBits(N1);
+  ConstantRange N0Range = ConstantRange::fromKnownBits(N0Known, false);
+  ConstantRange N1Range = ConstantRange::fromKnownBits(N1Known, false);
+  return mapOverflowResult(N0Range.unsignedMulMayOverflow(N1Range));
+}
+
+SelectionDAG::OverflowKind
+SelectionDAG::computeOverflowForSignedMul(SDValue N0, SDValue N1) const {
+  // X * 0 and X * 1 never overflow.
+  if (isNullConstant(N1) || isOneConstant(N1))
+    return OFK_Never;
+
+  // Get the size of the result.
+  unsigned BitWidth = N0.getScalarValueSizeInBits();
+
+  // Sum of the sign bits.
+  unsigned SignBits = ComputeNumSignBits(N0) + ComputeNumSignBits(N1);
+
+  // If we have enough sign bits, then there's no overflow.
+  if (SignBits > BitWidth + 1)
+    return OFK_Never;
+
+  if (SignBits == BitWidth + 1) {
+    // The overflow occurs when the true multiplication of
+    // the operands is the minimum negative number.
+    KnownBits N0Known = computeKnownBits(N0);
+    KnownBits N1Known = computeKnownBits(N1);
+    // If one of the operands is non-negative, then there's no
+    // overflow.
+    if (N0Known.isNonNegative() || N1Known.isNonNegative())
+      return OFK_Never;
+  }
+
+  return OFK_Sometime;
+}
+
 bool SelectionDAG::isKnownToBeAPowerOfTwo(SDValue Val, unsigned Depth) const {
   if (Depth >= MaxRecursionDepth)
     return false; // Limit search depth.
diff --git a/llvm/test/CodeGen/X86/combine-mulo.ll b/llvm/test/CodeGen/X86/combine-mulo.ll
--- a/llvm/test/CodeGen/X86/combine-mulo.ll
+++ b/llvm/test/CodeGen/X86/combine-mulo.ll
@@ -96,7 +96,7 @@
 ; CHECK-NEXT:    andl $4095, %edi # imm = 0xFFF
 ; CHECK-NEXT:    andl $524287, %eax # imm = 0x7FFFF
 ; CHECK-NEXT:    imull %edi, %eax
-; CHECK-NEXT:    seto %dl
+; CHECK-NEXT:    xorl %edx, %edx
 ; CHECK-NEXT:    retq
   %aa = and i32 %a, 4095 ; 0xfff
   %bb = and i32 %b, 524287; 0x7ffff
@@ -109,19 +109,8 @@
 ; SSE:       # %bb.0:
 ; SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm0
 ; SSE-NEXT:    pand {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %xmm1
-; SSE-NEXT:    pshufd {{.*#+}} xmm2 = xmm1[1,1,3,3]
-; SSE-NEXT:    pshufd {{.*#+}} xmm3 = xmm0[1,1,3,3]
-; SSE-NEXT:    pmuldq %xmm2, %xmm3
-; SSE-NEXT:    movdqa %xmm0, %xmm2
-; SSE-NEXT:    pmuldq %xmm1, %xmm2
-; SSE-NEXT:    pshufd {{.*#+}} xmm2 = xmm2[1,1,3,3]
-; SSE-NEXT:    pblendw {{.*#+}} xmm2 = xmm2[0,1],xmm3[2,3],xmm2[4,5],xmm3[6,7]
-; SSE-NEXT:    pxor %xmm3, %xmm3
-; SSE-NEXT:    pcmpeqd %xmm2, %xmm3
-; SSE-NEXT:    pcmpeqd %xmm2, %xmm2
-; SSE-NEXT:    pxor %xmm3, %xmm2
 ; SSE-NEXT:    pmulld %xmm1, %xmm0
-; SSE-NEXT:    movdqa %xmm2, %xmm1
+; SSE-NEXT:    pxor %xmm1, %xmm1
 ; SSE-NEXT:    retq
 ;
 ; AVX-LABEL: combine_vec_smul_nsw:
@@ -129,18 +118,9 @@
 ; AVX-NEXT:    vpbroadcastd {{.*#+}} xmm2 = [4095,4095,4095,4095]
 ; AVX-NEXT:    vpand %xmm2, %xmm0, %xmm0
 ; AVX-NEXT:    vpbroadcastd {{.*#+}} xmm2 = [524287,524287,524287,524287]
-; AVX-NEXT:    vpand %xmm2, %xmm1, %xmm2
-; AVX-NEXT:    vpshufd {{.*#+}} xmm1 = xmm2[1,1,3,3]
-; AVX-NEXT:    vpshufd {{.*#+}} xmm3 = xmm0[1,1,3,3]
-; AVX-NEXT:    vpmuldq %xmm1, %xmm3, %xmm1
-; AVX-NEXT:    vpmuldq %xmm2, %xmm0, %xmm3
-; AVX-NEXT:    vpshufd {{.*#+}} xmm3 = xmm3[1,1,3,3]
-; AVX-NEXT:    vpblendd {{.*#+}} xmm1 = xmm3[0],xmm1[1],xmm3[2],xmm1[3]
-; AVX-NEXT:    vpxor %xmm3, %xmm3, %xmm3
-; AVX-NEXT:    vpcmpeqd %xmm3, %xmm1, %xmm1
-; AVX-NEXT:    vpcmpeqd %xmm3, %xmm3, %xmm3
-; AVX-NEXT:    vpxor %xmm3, %xmm1, %xmm1
-; AVX-NEXT:    vpmulld %xmm2, %xmm0, %xmm0
+; AVX-NEXT:    vpand %xmm2, %xmm1, %xmm1
+; AVX-NEXT:    vpmulld %xmm1, %xmm0, %xmm0
+; AVX-NEXT:    vpxor %xmm1, %xmm1, %xmm1
 ; AVX-NEXT:    retq
   %aa = and <4 x i32> %a, <i32 4095, i32 4095, i32 4095, i32 4095>
   %bb = and <4 x i32> %b, <i32 524287, i32 524287, i32 524287, i32 524287>