diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -1733,6 +1733,161 @@ } } + // Variety of transforms for (urem/srem (mul/shl X, Y), (mul/shl X, Z)) + Value *A, *B, *C, *D; + if ((match(Op0, m_Mul(m_Value(A), m_Value(B))) || + match(Op0, m_Shl(m_Value(A), m_Value(B)))) && + (match(Op1, m_Mul(m_Value(C), m_Value(D))) || + match(Op1, m_Shl(m_Value(C), m_Value(D))))) { + Value *X, *Y, *Z; + X = nullptr; + // Match by hand rather than with m_Specific because either A/B (or C/D) + // can be our X. NOTE(review): for shl, B/D is the shift amount - confirm. + if (A == C || A == D) { + X = A; + Y = B; + Z = A == C ? D : C; + } else if (B == C || B == D) { + X = B; + Y = A; + Z = B == C ? D : C; + } + + // BO0 = X * Y + BinaryOperator *BO0 = dyn_cast(Op0); + // BO1 = X * Z + BinaryOperator *BO1 = dyn_cast(Op1); + + // Skip the transform if X is the constant 1, in both the mul and shl case. + Constant *CX = X ? dyn_cast(X) : nullptr; + if (X && BO0 && BO1 && (!CX || !CX->isOneValue())) { + bool NSW0 = BO0->hasNoSignedWrap(); + bool NSW1 = BO1->hasNoSignedWrap(); + bool NUW0 = BO0->hasNoUnsignedWrap(); + bool NUW1 = BO1->hasNoUnsignedWrap(); + + Type *Ty = I.getType(); + bool IsSigned = I.getOpcode() == Instruction::SRem; + + // Check constant folds first; for vectors only splat constants are used. + ConstantInt *CY = nullptr; + ConstantInt *CZ = nullptr; + if (Ty->isVectorTy()) { + auto *VCY = dyn_cast(Y); + auto *VCZ = dyn_cast(Z); + if (VCY && VCZ) { + VCY = VCY->getSplatValue(); + VCZ = VCZ->getSplatValue(); + if (VCY && VCZ) { + CY = dyn_cast(VCY); + CZ = dyn_cast(VCZ); + } + } + } else { + CY = dyn_cast(Y); + CZ = dyn_cast(Z); + } + + if (CY && CZ) { + APInt AY = CY->getValue(); + APInt AZ = CZ->getValue(); + + // Treat the shifts as mul; a returned mul by a power of 2 is cleaned up + // later. NOTE(review): no AZ != 0 check before the rem below - confirm. 
+ if (BO0->getOpcode() == Instruction::Shl) + AY = APInt(AY.getBitWidth(), 1) << AY; + if (BO1->getOpcode() == Instruction::Shl) + AZ = APInt(AZ.getBitWidth(), 1) << AZ; + + APInt RemYZ = IsSigned ? AY.srem(AZ) : AY.urem(AZ); + + // (urem (mul nuw X, Y), (mul X, Z)) + // if (urem Y, Z) == 0 + // -> 0 + // (srem (mul nsw X, Y), (mul X, Z)) + // if (srem Y, Z) == 0 + // -> 0 + if (RemYZ.isZero() && (IsSigned ? NSW0 : NUW0)) + return replaceInstUsesWith(I, ConstantInt::getNullValue(Ty)); + + // (urem (mul X, Y), (mul nuw X, Z)) + // if (urem Y, Z) == Y + // -> (mul nuw X, Y) + // (srem (mul X, Y), (mul nsw X, Z)) + // if (srem Y, Z) == Y + // -> (mul nsw X, Y) + if (RemYZ == AY && (IsSigned ? NSW1 : NUW1)) { + // This is essentially Op0, but we can also add no-wrap flags. + BinaryOperator *BO = + BinaryOperator::CreateMul(X, ConstantInt::get(Ty, AY)); + // We can add nsw/nuw if the remainder op is signed/unsigned; we can + // also copy any overflow flags from Op0. + if (IsSigned || NSW0) + BO->setHasNoSignedWrap(); + if (!IsSigned || NUW0) + BO->setHasNoUnsignedWrap(); + return BO; + } + + // (urem (mul nuw X, Y), (mul X, Z)) + // if Y >= Z + // -> (mul nuw nsw X, (urem Y, Z)) + // (srem (mul nsw X, Y), (mul nsw X, Z)) + // if Y >= Z (NOTE(review): code compares unsigned via uge - confirm) + // -> (mul nsw X, (srem Y, Z)) + if (AY.uge(AZ) && (IsSigned ? (NSW0 && NSW1) : NUW0)) { + BinaryOperator *BO = + BinaryOperator::CreateMul(X, ConstantInt::get(Ty, RemYZ)); + BO->setHasNoSignedWrap(); + // We can add nuw if the remainder is not signed or we have nuw on Op0. + if (!IsSigned || NUW0) + BO->setHasNoUnsignedWrap(); + return BO; + } + } + + // If Op0/Op1 only have one use (here), then we will be saving several + // instructions at no real cost so proceed. Otherwise we will either be + // breaking even (either Op0 or Op1 has one use but not both), or adding an + // additional instruction (neither have one use) so only do so if Y and Z + // are constant. 
As well, don't replace for urem Op0 is shl and Op1 is mul + // or for srem if either Op0/Op1 are shl. In the urem case we break even, and + // the srem case is slightly worse, so there is no real benefit. + + // NB: It may be beneficial to do this if we have X << Z even if there are + // multiple uses of Op0/Op1 as it will eliminate the urem (urem of a power + // of 2 is converted to add/and) and urem is pretty expensive (maybe makes + // more sense in DAGCombiner). + if ((CY && CZ) || (Op0->hasOneUse() && Op1->hasOneUse() && + (IsSigned ? (BO0->getOpcode() != Instruction::Shl && + BO1->getOpcode() != Instruction::Shl) + : (BO0->getOpcode() != Instruction::Shl || + BO1->getOpcode() == Instruction::Shl)))) { + // (urem (mul nuw X, Y), (mul nuw X, Z)) + // -> (mul nuw X, (urem Y, Z)) + // (srem (mul nsw X, Y), (mul nsw nuw X, Z)) + // -> (mul nsw X, (srem Y, Z)) + if (IsSigned ? (NSW0 && NSW1 && NUW1) : (NUW0 && NUW1)) { + // Convert the shifts to multiplies by (1 << amount); cleaned up elsewhere. + if (BO0->getOpcode() == Instruction::Shl) + Y = Builder.CreateShl(ConstantInt::get(Ty, 1), Y); + if (BO1->getOpcode() == Instruction::Shl) + Z = Builder.CreateShl(ConstantInt::get(Ty, 1), Z); + BinaryOperator *BO = + BinaryOperator::CreateMul(X, IsSigned ? Builder.CreateSRem(Y, Z) + : Builder.CreateURem(Y, Z)); + + // We can add extra flags based on signed/unsigned or existing flags. 
+ if (IsSigned || NSW0 || NSW1) + BO->setHasNoSignedWrap(); + if (!IsSigned || (NUW0 && NUW1)) + BO->setHasNoUnsignedWrap(); + return BO; + } + } + } + } + return nullptr; } diff --git a/llvm/test/Transforms/InstCombine/urem-mul.ll b/llvm/test/Transforms/InstCombine/urem-mul.ll --- a/llvm/test/Transforms/InstCombine/urem-mul.ll +++ b/llvm/test/Transforms/InstCombine/urem-mul.ll @@ -31,10 +31,7 @@ define i8 @urem_CY_CZ_is_zero(i8 %X) { ; CHECK-LABEL: @urem_CY_CZ_is_zero( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 15 -; CHECK-NEXT: [[BO1:%.*]] = mul i8 [[X]], 5 -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] -; CHECK-NEXT: ret i8 [[R]] +; CHECK-NEXT: ret i8 0 ; %BO0 = mul nuw i8 %X, 15 %BO1 = mul i8 %X, 5 @@ -57,9 +54,7 @@ define i8 @urem_CY_CZ_is_BO0(i8 %X) { ; CHECK-LABEL: @urem_CY_CZ_is_BO0( -; CHECK-NEXT: [[BO0:%.*]] = mul i8 [[X:%.*]], 3 -; CHECK-NEXT: [[BO1:%.*]] = mul nuw i8 [[X]], 12 -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nuw i8 [[X:%.*]], 3 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul i8 %X, 3 @@ -70,9 +65,7 @@ define <2 x i8> @urem_CY_CZ_is_BO0_with_nsw_out(<2 x i8> %X) { ; CHECK-LABEL: @urem_CY_CZ_is_BO0_with_nsw_out( -; CHECK-NEXT: [[BO0:%.*]] = shl nsw <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[BO1:%.*]] = mul nuw <2 x i8> [[X]], -; CHECK-NEXT: [[R:%.*]] = urem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nuw nsw <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = shl nsw <2 x i8> %X, @@ -83,9 +76,7 @@ define i8 @urem_CY_CZ_is_BO0_no_nsw_out(i8 %X) { ; CHECK-LABEL: @urem_CY_CZ_is_BO0_no_nsw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 3 -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw i8 [[X]], 3 -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nuw i8 [[X:%.*]], 3 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw i8 %X, 3 @@ -109,9 +100,7 @@ define i8 @urem_CY_CZ_is_mul_X_RemYZ(i8 %X) { ; CHECK-LABEL: @urem_CY_CZ_is_mul_X_RemYZ( -; CHECK-NEXT: [[BO0:%.*]] = 
mul nuw i8 [[X:%.*]], 21 -; CHECK-NEXT: [[BO1:%.*]] = mul i8 [[X]], 6 -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nuw nsw i8 [[X:%.*]], 3 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw i8 %X, 21 @@ -135,9 +124,8 @@ define i8 @urem_Y_Z_is_mul_X_RemYZ(i8 %X, i8 %Y, i8 %Z) { ; CHECK-LABEL: @urem_Y_Z_is_mul_X_RemYZ( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[BO1:%.*]] = mul nuw i8 [[Z:%.*]], [[X]] -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[TMP1:%.*]] = urem i8 [[Y:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[R:%.*]] = mul nuw i8 [[TMP1]], [[X:%.*]] ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw i8 %X, %Y @@ -148,9 +136,10 @@ define i8 @urem_CX_Y_Z_is_mul_X_RemYZ(i8 %X, i8 %Y, i8 %Z) { ; CHECK-LABEL: @urem_CX_Y_Z_is_mul_X_RemYZ( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[Y:%.*]], 10 -; CHECK-NEXT: [[BO1:%.*]] = shl nuw i8 10, [[Z:%.*]] -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[NOTMASK:%.*]] = shl nsw i8 -1, [[Z:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i8 [[NOTMASK]], -1 +; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], [[Y:%.*]] +; CHECK-NEXT: [[R:%.*]] = mul nuw i8 [[TMP2]], 10 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw i8 10, %Y @@ -161,9 +150,10 @@ define i8 @urem_Y_Z_is_mul_X_RemYZ_with_nsw_out1(i8 %X, i8 %Y, i8 %Z) { ; CHECK-LABEL: @urem_Y_Z_is_mul_X_RemYZ_with_nsw_out1( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw nsw i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[BO1:%.*]] = shl nuw i8 [[X]], [[Z:%.*]] -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[NOTMASK:%.*]] = shl nsw i8 -1, [[Z:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i8 [[NOTMASK]], -1 +; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], [[Y:%.*]] +; CHECK-NEXT: [[R:%.*]] = mul nuw nsw i8 [[TMP2]], [[X:%.*]] ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw nsw i8 %X, %Y @@ -174,9 +164,11 @@ define <2 x i8> @urem_Y_Z_is_mul_X_RemYZ_with_nsw_out2(<2 x i8> %X, <2 x i8> %Y, <2 x i8> %Z) { ; CHECK-LABEL: 
@urem_Y_Z_is_mul_X_RemYZ_with_nsw_out2( -; CHECK-NEXT: [[BO0:%.*]] = shl nuw <2 x i8> [[Y:%.*]], [[X:%.*]] -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw <2 x i8> [[Z:%.*]], [[X]] -; CHECK-NEXT: [[R:%.*]] = urem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[TMP1:%.*]] = shl nuw <2 x i8> , [[Y:%.*]] +; CHECK-NEXT: [[NOTMASK:%.*]] = shl nsw <2 x i8> , [[Z:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = xor <2 x i8> [[NOTMASK]], +; CHECK-NEXT: [[TMP3:%.*]] = and <2 x i8> [[TMP1]], [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = mul nuw nsw <2 x i8> [[TMP3]], [[X:%.*]] ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = shl nuw <2 x i8> %Y, %X @@ -229,10 +221,7 @@ ;; Signed Verions define i8 @srem_CY_CZ_is_zero(i8 %X) { ; CHECK-LABEL: @srem_CY_CZ_is_zero( -; CHECK-NEXT: [[BO0:%.*]] = mul nsw i8 [[X:%.*]], 9 -; CHECK-NEXT: [[BO1:%.*]] = mul i8 [[X]], 3 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] -; CHECK-NEXT: ret i8 [[R]] +; CHECK-NEXT: ret i8 0 ; %BO0 = mul nsw i8 %X, 9 %BO1 = mul i8 %X, 3 @@ -255,9 +244,7 @@ define <2 x i8> @srem_CY_CZ_is_BO0(<2 x i8> %X) { ; CHECK-LABEL: @srem_CY_CZ_is_BO0( -; CHECK-NEXT: [[BO0:%.*]] = shl <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[BO1:%.*]] = mul nsw <2 x i8> [[X]], -; CHECK-NEXT: [[R:%.*]] = srem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nsw <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = shl <2 x i8> %X, @@ -268,9 +255,7 @@ define i8 @srem_CY_CZ_is_BO0_with_nuw_out(i8 %X) { ; CHECK-LABEL: @srem_CY_CZ_is_BO0_with_nuw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 5 -; CHECK-NEXT: [[BO1:%.*]] = mul nsw i8 [[X]], 15 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nuw nsw i8 [[X:%.*]], 5 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw i8 %X, 5 @@ -281,9 +266,7 @@ define i8 @srem_CY_CZ_is_BO0_no_nsw_out(i8 %X) { ; CHECK-LABEL: @srem_CY_CZ_is_BO0_no_nsw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nsw i8 [[X:%.*]], 5 -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw i8 [[X]], 4 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], 
[[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nsw i8 [[X:%.*]], 5 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nsw i8 %X, 5 @@ -307,9 +290,7 @@ define i8 @srem_CY_CZ_is_mul_X_RemYZ(i8 %X) { ; CHECK-LABEL: @srem_CY_CZ_is_mul_X_RemYZ( -; CHECK-NEXT: [[BO0:%.*]] = shl nsw i8 [[X:%.*]], 3 -; CHECK-NEXT: [[BO1:%.*]] = mul nsw i8 [[X]], 6 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nsw i8 [[X:%.*]], 1 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = shl nsw i8 %X, 3 @@ -320,9 +301,7 @@ define i8 @srem_CY_CZ_is_mul_X_RemYZ_with_nuw_out(i8 %X) { ; CHECK-LABEL: @srem_CY_CZ_is_mul_X_RemYZ_with_nuw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw nsw i8 [[X:%.*]], 10 -; CHECK-NEXT: [[BO1:%.*]] = mul nsw i8 [[X]], 6 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nuw nsw i8 [[X:%.*]], 2 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nsw nuw i8 %X, 10 @@ -333,9 +312,7 @@ define <2 x i8> @srem_CY_CZ_is_mul_X_RemYZ_no_nuw_out(<2 x i8> %X) { ; CHECK-LABEL: @srem_CY_CZ_is_mul_X_RemYZ_no_nuw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nsw <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw <2 x i8> [[X]], -; CHECK-NEXT: [[R:%.*]] = srem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nsw <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = mul nsw <2 x i8> %X, @@ -372,9 +349,8 @@ define i8 @srem_Y_Z_is_mul_X_RemYZ(i8 %X, i8 %Y, i8 %Z) { ; CHECK-LABEL: @srem_Y_Z_is_mul_X_RemYZ( -; CHECK-NEXT: [[BO0:%.*]] = mul nsw i8 [[Y:%.*]], [[X:%.*]] -; CHECK-NEXT: [[BO1:%.*]] = mul nuw nsw i8 [[X]], [[Z:%.*]] -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[TMP1:%.*]] = srem i8 [[Y:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[R:%.*]] = mul nsw i8 [[TMP1]], [[X:%.*]] ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nsw i8 %Y, %X @@ -385,9 +361,8 @@ define i8 @srem_Y_Z_is_mul_X_RemYZ_with_nuw_out(i8 %X, i8 %Y, i8 %Z) { ; CHECK-LABEL: @srem_Y_Z_is_mul_X_RemYZ_with_nuw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw nsw i8 [[Y:%.*]], [[X:%.*]] -; 
CHECK-NEXT: [[BO1:%.*]] = mul nuw nsw i8 [[Z:%.*]], [[X]] -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[TMP1:%.*]] = srem i8 [[Y:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[R:%.*]] = mul nuw nsw i8 [[TMP1]], [[X:%.*]] ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nsw nuw i8 %Y, %X