diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -1681,6 +1681,119 @@
   return nullptr;
 }
 
+// Variety of transform for (urem/srem (mul/shl X, Y), (mul/shl X, Z)) where
+// Y and Z are constants (or splat vector constants). A shl is treated as a
+// multiply by the power of two it represents, so X must be the shifted value
+// (operand 0) of any shl; `shl Y, X` does not multiply X by Y.
+static Instruction *simplifyIRemMulShl(BinaryOperator &I,
+                                       InstCombinerImpl &IC) {
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A, *B, *C, *D;
+  if (!(match(Op0, m_Mul(m_Value(A), m_Value(B))) ||
+        match(Op0, m_Shl(m_Value(A), m_Value(B)))) ||
+      !(match(Op1, m_Mul(m_Value(C), m_Value(D))) ||
+        match(Op1, m_Shl(m_Value(C), m_Value(D)))))
+    return nullptr;
+
+  Value *X = nullptr, *Y = nullptr, *Z = nullptr;
+  // Do this by hand as opposed to using m_Specific because either A/B (or
+  // C/D) can be our X.
+  if (A == C || A == D) {
+    X = A;
+    Y = B;
+    Z = A == C ? D : C;
+  } else if (B == C || B == D) {
+    X = B;
+    Y = A;
+    Z = B == C ? D : C;
+  }
+
+  // Bail if there is no common factor X. If X is constant 1, then we avoid
+  // both in the mul and shl case.
+  if (!X || match(X, m_One()))
+    return nullptr;
+
+  // Y and Z must be integer constants or splat vector constants.
+  const APInt *YC, *ZC;
+  if (!match(Y, m_APInt(YC)) || !match(Z, m_APInt(ZC)))
+    return nullptr;
+
+  // m_Mul/m_Shl only match mul/shl instructions or constant expressions, all
+  // of which are OverflowingBinaryOperators.
+  // BO0 = X * Y
+  auto *BO0 = cast<OverflowingBinaryOperator>(Op0);
+  // BO1 = X * Z
+  auto *BO1 = cast<OverflowingBinaryOperator>(Op1);
+
+  bool IsSRem = I.getOpcode() == Instruction::SRem;
+
+  // Compute the constant that BO multiplies X by, treating (shl X, C) as
+  // (mul X, 1 << C). Fails if the shift is not of X, if the shift amount is
+  // out of range (the shl is poison), or, for srem, if the shift is by
+  // BitWidth - 1: the multiplier 1 << (BitWidth - 1) would be read as INT_MIN
+  // rather than the positive power of two the shift actually multiplies by.
+  auto GetMultiplier = [&](OverflowingBinaryOperator *BO, const APInt &Cst,
+                           APInt &Mul) -> bool {
+    if (BO->getOpcode() != Instruction::Shl) {
+      Mul = Cst;
+      return true;
+    }
+    unsigned BitWidth = Cst.getBitWidth();
+    if (BO->getOperand(0) != X || Cst.uge(IsSRem ? BitWidth - 1 : BitWidth))
+      return false;
+    Mul = APInt::getOneBitSet(BitWidth, Cst.getZExtValue());
+    return true;
+  };
+
+  // Just treat the shifts as mul, we may end up returning a mul by power
+  // of 2 but that will be cleaned up later.
+  APInt APIntY, APIntZ;
+  if (!GetMultiplier(BO0, *YC, APIntY) || !GetMultiplier(BO1, *ZC, APIntZ))
+    return nullptr;
+
+  // A remainder by zero is UB; leave it alone rather than tripping the
+  // divide-by-zero assertion in APInt::urem/srem.
+  if (APIntZ.isZero())
+    return nullptr;
+
+  bool BO0HasNSW = BO0->hasNoSignedWrap();
+  bool BO1HasNSW = BO1->hasNoSignedWrap();
+  bool BO0HasNUW = BO0->hasNoUnsignedWrap();
+  bool BO1HasNUW = BO1->hasNoUnsignedWrap();
+
+  APInt RemYZ = IsSRem ? APIntY.srem(APIntZ) : APIntY.urem(APIntZ);
+
+  // (rem (mul nuw/nsw X, Y), (mul X, Z))
+  //      if (rem Y, Z) == 0
+  //          -> 0
+  if (RemYZ.isZero() && (IsSRem ? BO0HasNSW : BO0HasNUW))
+    return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType()));
+
+  // (rem (mul X, Y), (mul nuw/nsw X, Z))
+  //      if (rem Y, Z) == Y
+  //          -> (mul nuw/nsw X, Y)
+  if (RemYZ == APIntY && (IsSRem ? BO1HasNSW : BO1HasNUW)) {
+    BinaryOperator *BO =
+        BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), APIntY));
+    // Copy any overflow flags from Op0.
+    BO->setHasNoSignedWrap(IsSRem || BO0HasNSW);
+    BO->setHasNoUnsignedWrap(!IsSRem || BO0HasNUW);
+    return BO;
+  }
+
+  // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z))
+  //      if Y >= Z
+  //          -> (mul {nuw} nsw X, (rem Y, Z))
+  if (APIntY.uge(APIntZ) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) {
+    BinaryOperator *BO =
+        BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), RemYZ));
+    BO->setHasNoSignedWrap();
+    BO->setHasNoUnsignedWrap(BO0HasNUW);
+    return BO;
+  }
+
+  return nullptr;
+}
+
 /// This function implements the transforms common to both integer remainder
 /// instructions (urem and srem). It is called by the visitors to those integer
 /// remainder instructions.
@@ -1733,6 +1846,9 @@
     }
   }
 
+  if (Instruction *R = simplifyIRemMulShl(I, *this))
+    return R;
+
   return nullptr;
 }
 
diff --git a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll --- a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll +++ b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll @@ -31,10 +31,7 @@ define @urem_XY_XZ_with_CY_rem_CZ_eq_0_scalable( %X) { ; CHECK-LABEL: @urem_XY_XZ_with_CY_rem_CZ_eq_0_scalable( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw [[X:%.*]], shufflevector ( insertelement ( poison, i8 15, i64 0), poison, zeroinitializer) -; CHECK-NEXT: [[BO1:%.*]] = mul [[X]], shufflevector ( insertelement ( poison, i8 5, i64 0), poison, zeroinitializer) -; CHECK-NEXT: [[R:%.*]] = urem [[BO0]], [[BO1]] -; CHECK-NEXT: ret [[R]] +; CHECK-NEXT: ret zeroinitializer ; %BO0 = mul nuw %X, shufflevector( insertelement( poison, i8 15, i64 0) , poison, zeroinitializer) %BO1 = mul %X, shufflevector( insertelement( poison, i8 5, i64 0) , poison, zeroinitializer) @@ -44,10 +41,7 @@ define i8 @urem_XY_XZ_with_CY_rem_CZ_eq_0(i8 %X) { ; CHECK-LABEL: @urem_XY_XZ_with_CY_rem_CZ_eq_0( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 15 -; CHECK-NEXT: [[BO1:%.*]] = mul i8 [[X]], 5 -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] -; CHECK-NEXT: ret i8 [[R]] +; CHECK-NEXT: ret i8 0 ; %BO0 = mul nuw i8 %X, 15 %BO1 = mul i8 %X, 5 @@ -70,9 +64,7 @@ define i8 @urem_XY_XZ_with_CY_lt_CZ(i8 %X) { ; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ( -; CHECK-NEXT: [[BO0:%.*]] = mul i8 [[X:%.*]], 3 -; CHECK-NEXT: [[BO1:%.*]] =
mul nuw i8 [[X]], 12 -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nuw i8 [[X:%.*]], 3 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul i8 %X, 3 @@ -83,9 +75,7 @@ define <2 x i8> @urem_XY_XZ_with_CY_lt_CZ_with_nsw_out(<2 x i8> %X) { ; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ_with_nsw_out( -; CHECK-NEXT: [[BO0:%.*]] = shl nsw <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[BO1:%.*]] = mul nuw <2 x i8> [[X]], -; CHECK-NEXT: [[R:%.*]] = urem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nuw nsw <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = shl nsw <2 x i8> %X, @@ -96,9 +86,7 @@ define i8 @urem_XY_XZ_with_CY_lt_CZ_no_nsw_out(i8 %X) { ; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ_no_nsw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 3 -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw i8 [[X]], 3 -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nuw i8 [[X:%.*]], 3 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw i8 %X, 3 @@ -122,9 +110,7 @@ define i8 @urem_XY_XZ_with_CY_gt_CZ(i8 %X) { ; CHECK-LABEL: @urem_XY_XZ_with_CY_gt_CZ( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 21 -; CHECK-NEXT: [[BO1:%.*]] = mul i8 [[X]], 6 -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nuw nsw i8 [[X:%.*]], 3 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw i8 %X, 21 @@ -242,10 +228,7 @@ ;; Signed Verions define @srem_XY_XZ_with_CY_rem_CZ_eq_0_scalable( %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_rem_CZ_eq_0_scalable( -; CHECK-NEXT: [[BO0:%.*]] = mul nsw [[X:%.*]], shufflevector ( insertelement ( poison, i8 15, i64 0), poison, zeroinitializer) -; CHECK-NEXT: [[BO1:%.*]] = mul [[X]], shufflevector ( insertelement ( poison, i8 5, i64 0), poison, zeroinitializer) -; CHECK-NEXT: [[R:%.*]] = srem [[BO0]], [[BO1]] -; CHECK-NEXT: ret [[R]] +; CHECK-NEXT: ret zeroinitializer ; %BO0 = mul nsw %X, shufflevector( insertelement( poison, i8 15, i64 0) , poison, zeroinitializer) %BO1 = mul %X, shufflevector( 
insertelement( poison, i8 5, i64 0) , poison, zeroinitializer) @@ -255,10 +238,7 @@ define i8 @srem_XY_XZ_with_CY_rem_CZ_eq_0(i8 %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_rem_CZ_eq_0( -; CHECK-NEXT: [[BO0:%.*]] = mul nsw i8 [[X:%.*]], 9 -; CHECK-NEXT: [[BO1:%.*]] = mul i8 [[X]], 3 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] -; CHECK-NEXT: ret i8 [[R]] +; CHECK-NEXT: ret i8 0 ; %BO0 = mul nsw i8 %X, 9 %BO1 = mul i8 %X, 3 @@ -281,9 +261,7 @@ define <2 x i8> @srem_XY_XZ_with_CY_lt_CZ(<2 x i8> %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_lt_CZ( -; CHECK-NEXT: [[BO0:%.*]] = shl <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[BO1:%.*]] = mul nsw <2 x i8> [[X]], -; CHECK-NEXT: [[R:%.*]] = srem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nsw <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = shl <2 x i8> %X, @@ -294,9 +272,7 @@ define i8 @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out(i8 %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 5 -; CHECK-NEXT: [[BO1:%.*]] = mul nsw i8 [[X]], 15 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nuw nsw i8 [[X:%.*]], 5 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw i8 %X, 5 @@ -307,9 +283,7 @@ define i8 @srem_XY_XZ_with_CY_lt_CZ_no_nsw_out(i8 %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_lt_CZ_no_nsw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nsw i8 [[X:%.*]], 5 -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw i8 [[X]], 4 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nsw i8 [[X:%.*]], 5 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nsw i8 %X, 5 @@ -333,9 +307,7 @@ define i8 @srem_XY_XZ_with_CY_gt_CZ(i8 %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_gt_CZ( -; CHECK-NEXT: [[BO0:%.*]] = shl nsw i8 [[X:%.*]], 3 -; CHECK-NEXT: [[BO1:%.*]] = mul nsw i8 [[X]], 6 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nsw i8 [[X:%.*]], 1 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = shl nsw i8 %X, 3 @@ -346,9 +318,7 @@ 
define i8 @srem_XY_XZ_with_CY_gt_CZ_with_nuw_out(i8 %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_gt_CZ_with_nuw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw nsw i8 [[X:%.*]], 10 -; CHECK-NEXT: [[BO1:%.*]] = mul nsw i8 [[X]], 6 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nuw nsw i8 [[X:%.*]], 2 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nsw nuw i8 %X, 10 @@ -359,9 +329,7 @@ define <2 x i8> @srem_XY_XZ_with_CY_gt_CZ_no_nuw_out(<2 x i8> %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_gt_CZ_no_nuw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nsw <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw <2 x i8> [[X]], -; CHECK-NEXT: [[R:%.*]] = srem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nsw <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = mul nsw <2 x i8> %X,