diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -1759,14 +1759,20 @@ return nullptr; } -// Variety of transform for (urem/srem (mul/shl X, Y), (mul/shl X, Z)) +// Variety of transform for: +// (urem/srem (mul X, Y), (mul X, Z)) +// (urem/srem (shl X, Y), (shl X, Z)) +// (urem/srem (shl Y, X), (shl Z, X)) +// NB: The shift cases are really just extensions of the mul case. We treat +// shift as Val * (1 << Amt). static Instruction *simplifyIRemMulShl(BinaryOperator &I, InstCombinerImpl &IC) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr; APInt Y, Z; + bool ShiftByX = false; // If V is not nullptr, it will be matched using m_Specific. - auto MatchShiftOrMul = [](Value *Op, Value *&V, APInt &C) -> bool { + auto MatchShiftOrMulXC = [](Value *Op, Value *&V, APInt &C) -> bool { const APInt *Tmp = nullptr; if ((!V && match(Op, m_Mul(m_Value(V), m_APInt(Tmp)))) || (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp))))) @@ -1774,11 +1780,34 @@ else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) || (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp))))) C = APInt(Tmp->getBitWidth(), 1) << *Tmp; - return Tmp != nullptr; + if (Tmp != nullptr) + return true; + + // Reset `V` so we don't start with specific value on next match attempt. + V = nullptr; + return false; + }; + + auto MatchShiftCX = [](Value *Op, APInt &C, Value *&V) -> bool { + const APInt *Tmp = nullptr; + if ((!V && match(Op, m_Shl(m_APInt(Tmp), m_Value(V)))) || + (V && match(Op, m_Shl(m_APInt(Tmp), m_Specific(V))))) { + C = *Tmp; + return true; + } + + // Reset `V` so we don't start with specific value on next match attempt. + V = nullptr; + return false; }; - if (!MatchShiftOrMul(Op0, X, Y) || !MatchShiftOrMul(Op1, X, Z)) + if (MatchShiftOrMulXC(Op0, X, Y) && MatchShiftOrMulXC(Op1, X, Z)) { + // pass + } else if (MatchShiftCX(Op0, Y, X) && MatchShiftCX(Op1, Z, X)) { + ShiftByX = true; + } else { return nullptr; + } bool IsSRem = I.getOpcode() == Instruction::SRem; @@ -1796,6 +1825,17 @@ if (RemYZ.isZero() && BO0NoWrap) return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType())); + // Helper function to emit either (RemSimplificationC << X) or + // (RemSimplificationC * X) depending on whether we matched Op0/Op1 as + // (shl V, X) or (mul V, X) respectively. + auto CreateMulOrShift = + [&](const APInt &RemSimplificationC) -> BinaryOperator * { + Value *RemSimplification = + ConstantInt::get(I.getType(), RemSimplificationC); + return ShiftByX ? BinaryOperator::CreateShl(RemSimplification, X) + : BinaryOperator::CreateMul(X, RemSimplification); + }; + OverflowingBinaryOperator *BO1 = cast(Op1); bool BO1HasNSW = BO1->hasNoSignedWrap(); bool BO1HasNUW = BO1->hasNoUnsignedWrap(); @@ -1804,8 +1844,7 @@ // if (rem Y, Z) == Y // -> (mul nuw/nsw X, Y) if (RemYZ == Y && BO1NoWrap) { - BinaryOperator *BO = - BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), Y)); + BinaryOperator *BO = CreateMulOrShift(Y); // Copy any overflow flags from Op0. BO->setHasNoSignedWrap(IsSRem || BO0HasNSW); BO->setHasNoUnsignedWrap(!IsSRem || BO0HasNUW); @@ -1816,8 +1855,7 @@ // if Y >= Z // -> (mul {nuw} nsw X, (rem Y, Z)) if (Y.uge(Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) { - BinaryOperator *BO = - BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), RemYZ)); + BinaryOperator *BO = CreateMulOrShift(RemYZ); BO->setHasNoSignedWrap(); BO->setHasNoUnsignedWrap(BO0HasNUW); return BO; diff --git a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll --- a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll +++ b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll @@ -51,10 +51,7 @@ define i8 @urem_XY_XZ_with_CY_rem_CZ_eq_0_with_shl(i8 %X) { ; CHECK-LABEL: @urem_XY_XZ_with_CY_rem_CZ_eq_0_with_shl( -; CHECK-NEXT: [[BO0:%.*]] = shl nuw i8 15, [[X:%.*]] -; CHECK-NEXT: [[BO1:%.*]] = shl i8 5, [[X]] -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] -; CHECK-NEXT: ret i8 [[R]] +; CHECK-NEXT: ret i8 0 ; %BO0 = shl nuw i8 15, %X %BO1 = shl i8 5, %X @@ -88,9 +85,7 @@ define i8 @urem_XY_XZ_with_CY_lt_CZ_with_shl(i8 %X) { ; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ_with_shl( -; CHECK-NEXT: [[BO0:%.*]] = shl i8 3, [[X:%.*]] -; CHECK-NEXT: [[BO1:%.*]] = shl nuw i8 12, [[X]] -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nuw i8 3, [[X:%.*]] ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = shl i8 3, %X @@ -309,9 +304,7 @@ define <2 x i8> @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out_with_shl(<2 x i8> %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_lt_CZ_with_nuw_out_with_shl( -; CHECK-NEXT: [[BO0:%.*]] = shl nuw <2 x i8> , [[X:%.*]] -; CHECK-NEXT: [[BO1:%.*]] = shl nsw <2 x i8> , [[X]] -; CHECK-NEXT: [[R:%.*]] = srem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nuw nsw <2 x i8> , [[X:%.*]] ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = shl nuw <2 x i8> , %X