diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -1762,12 +1762,22 @@ // Variety of transform for (urem/srem (mul/shl X, Y), (mul/shl X, Z)) static Instruction *simplifyIRemMulShl(BinaryOperator &I, InstCombinerImpl &IC) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X; - const APInt *Y, *Z; - if (!(match(Op0, m_Mul(m_Value(X), m_APInt(Y))) && - match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z)))) && - !(match(Op0, m_Mul(m_APInt(Y), m_Value(X))) && - match(Op1, m_c_Mul(m_Specific(X), m_APInt(Z))))) + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr; + APInt Y, Z; + + // If V is not nullptr, it will be matched using m_Specific. + auto MatchShiftOrMul = [](Value *Op, Value *&V, APInt &C) -> bool { + const APInt *Tmp = nullptr; + if ((!V && match(Op, m_Mul(m_Value(V), m_APInt(Tmp)))) || + (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp))))) + C = *Tmp; + else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) || + (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp))))) + C = APInt(Tmp->getBitWidth(), 1) << *Tmp; + return Tmp != nullptr; + }; + + if (!MatchShiftOrMul(Op0, X, Y) || !MatchShiftOrMul(Op1, X, Z)) return nullptr; bool IsSRem = I.getOpcode() == Instruction::SRem; @@ -1779,7 +1789,7 @@ bool BO0HasNUW = BO0->hasNoUnsignedWrap(); bool BO0NoWrap = IsSRem ? BO0HasNSW : BO0HasNUW; - APInt RemYZ = IsSRem ? Y->srem(*Z) : Y->urem(*Z); + APInt RemYZ = IsSRem ? Y.srem(Z) : Y.urem(Z); // (rem (mul nuw/nsw X, Y), (mul X, Z)) // if (rem Y, Z) == 0 // -> 0 @@ -1793,9 +1803,9 @@ // (rem (mul X, Y), (mul nuw/nsw X, Z)) // if (rem Y, Z) == Y // -> (mul nuw/nsw X, Y) - if (RemYZ == *Y && BO1NoWrap) { + if (RemYZ == Y && BO1NoWrap) { BinaryOperator *BO = - BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), *Y)); + BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), Y)); // Copy any overflow flags from Op0. BO->setHasNoSignedWrap(IsSRem || BO0HasNSW); BO->setHasNoUnsignedWrap(!IsSRem || BO0HasNUW); @@ -1805,7 +1815,7 @@ // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z)) // if Y >= Z // -> (mul {nuw} nsw X, (rem Y, Z)) - if (Y->uge(*Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) { + if (Y.uge(Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) { BinaryOperator *BO = BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), RemYZ)); BO->setHasNoSignedWrap(); diff --git a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll --- a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll +++ b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll @@ -75,9 +75,7 @@ define <2 x i8> @urem_XY_XZ_with_CY_lt_CZ_with_nsw_out(<2 x i8> %X) { ; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ_with_nsw_out( -; CHECK-NEXT: [[BO0:%.*]] = shl nsw <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[BO1:%.*]] = mul nuw <2 x i8> [[X]], -; CHECK-NEXT: [[R:%.*]] = urem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nuw nsw <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = shl nsw <2 x i8> %X, @@ -88,9 +86,7 @@ define i8 @urem_XY_XZ_with_CY_lt_CZ_no_nsw_out(i8 %X) { ; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ_no_nsw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 3 -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw i8 [[X]], 3 -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nuw i8 [[X:%.*]], 3 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw i8 %X, 3 @@ -265,9 +261,7 @@ define <2 x i8> @srem_XY_XZ_with_CY_lt_CZ(<2 x i8> %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_lt_CZ( -; CHECK-NEXT: [[BO0:%.*]] = shl <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[BO1:%.*]] = mul nsw <2 x i8> [[X]], -; CHECK-NEXT: [[R:%.*]] = srem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nsw <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = shl <2 x i8> %X, @@ -289,9 +283,7 @@ define i8 @srem_XY_XZ_with_CY_lt_CZ_no_nsw_out(i8 %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_lt_CZ_no_nsw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nsw i8 [[X:%.*]], 5 -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw i8 [[X]], 4 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nsw i8 [[X:%.*]], 5 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nsw i8 %X, 5 @@ -315,9 +307,7 @@ define i8 @srem_XY_XZ_with_CY_gt_CZ(i8 %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_gt_CZ( -; CHECK-NEXT: [[BO0:%.*]] = shl nsw i8 [[X:%.*]], 3 -; CHECK-NEXT: [[BO1:%.*]] = mul nsw i8 [[X]], 6 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nsw i8 [[X:%.*]], 1 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = shl nsw i8 %X, 3 @@ -339,9 +329,7 @@ define <2 x i8> @srem_XY_XZ_with_CY_gt_CZ_no_nuw_out(<2 x i8> %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_gt_CZ_no_nuw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nsw <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw <2 x i8> [[X]], -; CHECK-NEXT: [[R:%.*]] = srem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nsw <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = mul nsw <2 x i8> %X,