diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -1685,22 +1685,68 @@ static Instruction *simplifyIRemMulShl(BinaryOperator &I, InstCombinerImpl &IC) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A, *B, *C, *D; - if (!match(Op0, m_Mul(m_Value(A), m_Value(B))) || - !match(Op1, m_Mul(m_Value(C), m_Value(D)))) + if (!(match(Op0, m_Mul(m_Value(A), m_Value(B))) || + match(Op0, m_Shl(m_Value(A), m_Value(B)))) || + !(match(Op1, m_Mul(m_Value(C), m_Value(D))) || + match(Op1, m_Shl(m_Value(C), m_Value(D))))) return nullptr; Value *X = nullptr, *Y, *Z; + BinaryOperator *BO0 = cast(Op0); + BinaryOperator *BO1 = cast(Op1); + + bool ShiftX = false, ShiftY = false, ShiftZ = false; // Do this by hand as opposed to using m_Specific because either A/B (or // C/D) can be our X. + // TODO: There is a special case when `X == Y`, in which case we can choose + // whether we want to treat `X` as a shift or `Y` as a shift. At the moment we + // default to treating `Y` as a shift, but there are some cases where it's + // preferable to treat `X` as a shift if `Z` is not a shift. if (A == C || A == D) { X = A; Y = B; Z = A == C ? D : C; + + // Check if we have: + // 1. (rem (shl X, Y), (shl Z, X)) + // 2. (rem (mul X, Y), (shl Z, X)) + // which don't match as they translate to + // 1. (rem (mul X, 1 << Y), (mul Z, 1 << X)) + // 2. (rem (mul X, Y), (mul Z, 1 << X)) + // respectively and we can't match `X` with `1 << X`. + if (BO1->getOpcode() == Instruction::Shl && Z == C) + return nullptr; + + ShiftY = BO0->getOpcode() == Instruction::Shl; + ShiftZ = BO1->getOpcode() == Instruction::Shl; } else if (B == C || B == D) { X = B; Y = A; Z = B == C ? D : C; + + // Check if we have: + // 1. (rem (shl Y, X), (shl X, Z)) + // 2. (rem (shl Y, X), (mul X/Z, X/Z)) + // 3. 
(rem (mul Y, X), (shl Z, X)) + // which don't match as they translate to + // 1. (rem (mul Y, 1 << X), (mul X, 1 << Z)) + // 2. (rem (mul Y, 1 << X), (mul X/Z, X/Z)) + // 3. (rem (mul Y, X), (mul Z, 1 << X)) + // respectively and we can't match `X` with `1 << X`. + + // Cases 1/2 + if (BO0->getOpcode() == Instruction::Shl) { + if (BO1->getOpcode() != Instruction::Shl || Z != C) + return nullptr; + + ShiftX = true; + } + // Case 3. + else if (BO1->getOpcode() == Instruction::Shl && Z == C) + return nullptr; + + ShiftZ = BO1->getOpcode() == Instruction::Shl && Z != C; } if (!X) @@ -1731,8 +1777,12 @@ APInt APIntY = ConstY->getValue(); APInt APIntZ = ConstZ->getValue(); - BinaryOperator *BO0 = cast(Op0); - BinaryOperator *BO1 = cast(Op1); + // Just treat the shifts as mul, we may end up returning a mul by power + // of 2 but that will be cleaned up later. + if (ShiftY) + APIntY = APInt(APIntY.getBitWidth(), 1) << APIntY; + if (ShiftZ) + APIntZ = APInt(APIntZ.getBitWidth(), 1) << APIntZ; bool BO0HasNSW = BO0->hasNoSignedWrap(); bool BO0HasNUW = BO0->hasNoUnsignedWrap(); @@ -1744,6 +1794,11 @@ if (RemYZ.isZero() && (IsSRem ? BO0HasNSW : BO0HasNUW)) return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType())); + auto GetBinOpOut = [&](Value *RemSimplification) -> BinaryOperator * { + return ShiftX ? BinaryOperator::CreateShl(RemSimplification, X) + : BinaryOperator::CreateMul(X, RemSimplification); + }; + bool BO1HasNSW = BO1->hasNoSignedWrap(); bool BO1HasNUW = BO1->hasNoUnsignedWrap(); @@ -1751,8 +1806,7 @@ // if (rem Y, Z) == Y // -> (mul nuw/nsw X, Y) if (RemYZ == APIntY && (IsSRem ? BO1HasNSW : BO1HasNUW)) { - BinaryOperator *BO = - BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), APIntY)); + BinaryOperator *BO = GetBinOpOut(ConstantInt::get(I.getType(), APIntY)); // Copy any overflow flags from Op0. 
if (IsSRem || BO0HasNSW) BO->setHasNoSignedWrap(); @@ -1765,8 +1819,7 @@ // if Y >= Z // -> (mul {nuw} nsw X, (rem Y, Z)) if (APIntY.uge(APIntZ) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) { - BinaryOperator *BO = - BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), RemYZ)); + BinaryOperator *BO = GetBinOpOut(ConstantInt::get(I.getType(), RemYZ)); BO->setHasNoSignedWrap(); if (BO0HasNUW) BO->setHasNoUnsignedWrap(); diff --git a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll --- a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll +++ b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll @@ -75,9 +75,7 @@ define <2 x i8> @urem_XY_XZ_with_CY_lt_CZ_with_nsw_out(<2 x i8> %X) { ; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ_with_nsw_out( -; CHECK-NEXT: [[BO0:%.*]] = shl nsw <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[BO1:%.*]] = mul nuw <2 x i8> [[X]], -; CHECK-NEXT: [[R:%.*]] = urem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nuw nsw <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = shl nsw <2 x i8> %X, @@ -88,9 +86,7 @@ define i8 @urem_XY_XZ_with_CY_lt_CZ_no_nsw_out(i8 %X) { ; CHECK-LABEL: @urem_XY_XZ_with_CY_lt_CZ_no_nsw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], 3 -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw i8 [[X]], 3 -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nuw i8 [[X:%.*]], 3 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw i8 %X, 3 @@ -265,9 +261,7 @@ define <2 x i8> @srem_XY_XZ_with_CY_lt_CZ(<2 x i8> %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_lt_CZ( -; CHECK-NEXT: [[BO0:%.*]] = shl <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[BO1:%.*]] = mul nsw <2 x i8> [[X]], -; CHECK-NEXT: [[R:%.*]] = srem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nsw <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = shl <2 x i8> %X, @@ -289,9 +283,7 @@ define i8 @srem_XY_XZ_with_CY_lt_CZ_no_nsw_out(i8 %X) { ; CHECK-LABEL: 
@srem_XY_XZ_with_CY_lt_CZ_no_nsw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nsw i8 [[X:%.*]], 5 -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw i8 [[X]], 4 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = mul nsw i8 [[X:%.*]], 5 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nsw i8 %X, 5 @@ -315,9 +307,7 @@ define i8 @srem_XY_XZ_with_CY_gt_CZ(i8 %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_gt_CZ( -; CHECK-NEXT: [[BO0:%.*]] = shl nsw i8 [[X:%.*]], 3 -; CHECK-NEXT: [[BO1:%.*]] = mul nsw i8 [[X]], 6 -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nsw i8 [[X:%.*]], 1 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = shl nsw i8 %X, 3 @@ -339,9 +329,7 @@ define <2 x i8> @srem_XY_XZ_with_CY_gt_CZ_no_nuw_out(<2 x i8> %X) { ; CHECK-LABEL: @srem_XY_XZ_with_CY_gt_CZ_no_nuw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nsw <2 x i8> [[X:%.*]], -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw <2 x i8> [[X]], -; CHECK-NEXT: [[R:%.*]] = srem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[R:%.*]] = shl nsw <2 x i8> [[X:%.*]], ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = mul nsw <2 x i8> %X,