diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -1732,58 +1732,90 @@ } } - if (!ConstY || !ConstZ) - return nullptr; - - APInt APIntY = ConstY->getValue(); - APInt APIntZ = ConstZ->getValue(); - - // Just treat the shifts as mul, we may end up returning a mul by power - // of 2 but that will be cleaned up later. - if (BO0->getOpcode() == Instruction::Shl) - APIntY = APInt(APIntY.getBitWidth(), 1) << APIntY; - if (BO1->getOpcode() == Instruction::Shl) - APIntZ = APInt(APIntZ.getBitWidth(), 1) << APIntZ; - bool IsSRem = I.getOpcode() == Instruction::SRem; - APInt RemYZ = IsSRem ? APIntY.srem(APIntZ) : APIntY.urem(APIntZ); bool BO0HasNSW = BO0->hasNoSignedWrap(); - bool BO0HasNUW = BO0->hasNoUnsignedWrap(); - - // (rem (mul nuw/nsw X, Y), (mul X, Z)) - // if (rem Y, Z) == 0 - // -> 0 - if (RemYZ.isZero() && (IsSRem ? BO0HasNSW : BO0HasNUW)) - return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType())); - bool BO1HasNSW = BO1->hasNoSignedWrap(); + bool BO0HasNUW = BO0->hasNoUnsignedWrap(); bool BO1HasNUW = BO1->hasNoUnsignedWrap(); - // (rem (mul X, Y), (mul nuw/nsw X, Z)) - // if (rem Y, Z) == Y - // -> (mul nuw/nsw X, Y) - if (RemYZ == APIntY && (IsSRem ? BO1HasNSW : BO1HasNUW)) { - BinaryOperator *BO = - BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), APIntY)); - // Copy any overflow flags from Op0. - if (IsSRem || BO0HasNSW) + // Try to handle constant cases first. + if (ConstY && ConstZ) { + APInt APIntY = ConstY->getValue(); + APInt APIntZ = ConstZ->getValue(); + + // Just treat the shifts as mul, we may end up returning a mul by power + // of 2 but that will be cleaned up later. 
+ if (BO0->getOpcode() == Instruction::Shl) + APIntY = APInt(APIntY.getBitWidth(), 1) << APIntY; + if (BO1->getOpcode() == Instruction::Shl) + APIntZ = APInt(APIntZ.getBitWidth(), 1) << APIntZ; + + APInt RemYZ = IsSRem ? APIntY.srem(APIntZ) : APIntY.urem(APIntZ); + + // (rem (mul nuw/nsw X, Y), (mul X, Z)) + // if (rem Y, Z) == 0 + // -> 0 + if (RemYZ.isZero() && (IsSRem ? BO0HasNSW : BO0HasNUW)) + return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType())); + + // (rem (mul X, Y), (mul nuw/nsw X, Z)) + // if (rem Y, Z) == Y + // -> (mul nuw/nsw X, Y) + if (RemYZ == APIntY && (IsSRem ? BO1HasNSW : BO1HasNUW)) { + BinaryOperator *BO = + BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), APIntY)); + // Copy any overflow flags from Op0. + if (IsSRem || BO0HasNSW) + BO->setHasNoSignedWrap(); + if (!IsSRem || BO0HasNUW) + BO->setHasNoUnsignedWrap(); + return BO; + } + + // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z)) + // if Y >= Z + // -> (mul {nuw} nsw X, (rem Y, Z)) + if (APIntY.uge(APIntZ) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) { + BinaryOperator *BO = + BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), RemYZ)); BO->setHasNoSignedWrap(); - if (!IsSRem || BO0HasNUW) - BO->setHasNoUnsignedWrap(); - return BO; - } - - // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z)) - // if Y >= Z - // -> (mul {nuw} nsw X, (rem Y, Z)) - if (APIntY.uge(APIntZ) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) { - BinaryOperator *BO = - BinaryOperator::CreateMul(X, ConstantInt::get(I.getType(), RemYZ)); - BO->setHasNoSignedWrap(); - if (BO0HasNUW) - BO->setHasNoUnsignedWrap(); - return BO; + if (BO0HasNUW) + BO->setHasNoUnsignedWrap(); + return BO; + } + } + + // Check if desirable to do generic replacement. 
+ // NB: It may be beneficial to do this if we have X << Z even if there are + // multiple uses of Op0/Op1 as it will eliminate the urem (urem of a power + // of 2 is converted to add/and) and urem is pretty expensive (maybe more + // sense in DAGCombiner). + if ((ConstY && ConstZ) || + (Op0->hasOneUse() && Op1->hasOneUse() && + (IsSRem ? (BO0->getOpcode() != Instruction::Shl && + BO1->getOpcode() != Instruction::Shl) + : (BO0->getOpcode() != Instruction::Shl || + BO1->getOpcode() == Instruction::Shl)))) { + // (rem (mul nuw/nsw X, Y), (mul nuw {nsw} X, Z) + // -> (mul nuw/nsw X, (rem Y, Z)) + if (IsSRem ? (BO0HasNSW && BO1HasNSW && BO1HasNUW) + : (BO0HasNUW && BO1HasNUW)) { + // Convert the shifts to multiplies, cleaned up elsewhere. + if (BO0->getOpcode() == Instruction::Shl) + Y = IC.Builder.CreateShl(ConstantInt::get(I.getType(), 1), Y); + if (BO1->getOpcode() == Instruction::Shl) + Z = IC.Builder.CreateShl(ConstantInt::get(I.getType(), 1), Z); + BinaryOperator *BO = + BinaryOperator::CreateMul(X, IsSRem ? 
IC.Builder.CreateSRem(Y, Z) + : IC.Builder.CreateURem(Y, Z)); + + if (BO0HasNSW || BO1HasNSW) + BO->setHasNoSignedWrap(); + if (!IsSRem || (BO0HasNUW && BO1HasNUW)) + BO->setHasNoUnsignedWrap(); + return BO; + } } return nullptr; diff --git a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll --- a/llvm/test/Transforms/InstCombine/rem-mul-shl.ll +++ b/llvm/test/Transforms/InstCombine/rem-mul-shl.ll @@ -134,9 +134,8 @@ define i8 @urem_XY_XZ_with_Y_Z_is_mul_X_RemYZ(i8 %X, i8 %Y, i8 %Z) { ; CHECK-LABEL: @urem_XY_XZ_with_Y_Z_is_mul_X_RemYZ( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[BO1:%.*]] = mul nuw i8 [[Z:%.*]], [[X]] -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[TMP1:%.*]] = urem i8 [[Y:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[R:%.*]] = mul nuw i8 [[TMP1]], [[X:%.*]] ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw i8 %X, %Y @@ -147,9 +146,10 @@ define i8 @urem_XY_XZ_with_CX_Y_Z_is_mul_X_RemYZ(i8 %Y, i8 %Z) { ; CHECK-LABEL: @urem_XY_XZ_with_CX_Y_Z_is_mul_X_RemYZ( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw i8 [[Y:%.*]], 10 -; CHECK-NEXT: [[BO1:%.*]] = shl nuw i8 10, [[Z:%.*]] -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[NOTMASK:%.*]] = shl nsw i8 -1, [[Z:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i8 [[NOTMASK]], -1 +; CHECK-NEXT: [[TMP2:%.*]] = and i8 [[TMP1]], [[Y:%.*]] +; CHECK-NEXT: [[R:%.*]] = mul nuw i8 [[TMP2]], 10 ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw i8 10, %Y @@ -160,9 +160,10 @@ define i8 @urem_XY_XZ_with_Y_Z_is_mul_X_RemYZ_with_nsw_out1(i8 %X, i8 %Y, i8 %Z) { ; CHECK-LABEL: @urem_XY_XZ_with_Y_Z_is_mul_X_RemYZ_with_nsw_out1( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw nsw i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[BO1:%.*]] = shl nuw i8 [[X]], [[Z:%.*]] -; CHECK-NEXT: [[R:%.*]] = urem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[NOTMASK:%.*]] = shl nsw i8 -1, [[Z:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = xor i8 [[NOTMASK]], -1 +; CHECK-NEXT: [[TMP2:%.*]] = and i8 
[[TMP1]], [[Y:%.*]] +; CHECK-NEXT: [[R:%.*]] = mul nuw nsw i8 [[TMP2]], [[X:%.*]] ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nuw nsw i8 %X, %Y @@ -173,9 +174,11 @@ define <2 x i8> @urem_XY_XZ_with_Y_Z_is_mul_X_RemYZ_with_nsw_out2(<2 x i8> %X, <2 x i8> %Y, <2 x i8> %Z) { ; CHECK-LABEL: @urem_XY_XZ_with_Y_Z_is_mul_X_RemYZ_with_nsw_out2( -; CHECK-NEXT: [[BO0:%.*]] = shl nuw <2 x i8> [[Y:%.*]], [[X:%.*]] -; CHECK-NEXT: [[BO1:%.*]] = shl nuw nsw <2 x i8> [[Z:%.*]], [[X]] -; CHECK-NEXT: [[R:%.*]] = urem <2 x i8> [[BO0]], [[BO1]] +; CHECK-NEXT: [[TMP1:%.*]] = shl nuw <2 x i8> <i8 1, i8 1>, [[Y:%.*]] +; CHECK-NEXT: [[NOTMASK:%.*]] = shl nsw <2 x i8> <i8 -1, i8 -1>, [[Z:%.*]] +; CHECK-NEXT: [[TMP2:%.*]] = xor <2 x i8> [[NOTMASK]], <i8 -1, i8 -1> +; CHECK-NEXT: [[TMP3:%.*]] = and <2 x i8> [[TMP1]], [[TMP2]] +; CHECK-NEXT: [[R:%.*]] = mul nuw nsw <2 x i8> [[TMP3]], [[X:%.*]] ; CHECK-NEXT: ret <2 x i8> [[R]] ; %BO0 = shl nuw <2 x i8> %Y, %X @@ -366,9 +369,8 @@ define i8 @srem_XY_XZ_with_Y_Z_is_mul_X_RemYZ(i8 %X, i8 %Y, i8 %Z) { ; CHECK-LABEL: @srem_XY_XZ_with_Y_Z_is_mul_X_RemYZ( -; CHECK-NEXT: [[BO0:%.*]] = mul nsw i8 [[Y:%.*]], [[X:%.*]] -; CHECK-NEXT: [[BO1:%.*]] = mul nuw nsw i8 [[X]], [[Z:%.*]] -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[TMP1:%.*]] = srem i8 [[Y:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[R:%.*]] = mul nsw i8 [[TMP1]], [[X:%.*]] ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nsw i8 %Y, %X @@ -379,9 +381,8 @@ define i8 @srem_XY_XZ_with_Y_Z_is_mul_X_RemYZ_with_nuw_out(i8 %X, i8 %Y, i8 %Z) { ; CHECK-LABEL: @srem_XY_XZ_with_Y_Z_is_mul_X_RemYZ_with_nuw_out( -; CHECK-NEXT: [[BO0:%.*]] = mul nuw nsw i8 [[Y:%.*]], [[X:%.*]] -; CHECK-NEXT: [[BO1:%.*]] = mul nuw nsw i8 [[Z:%.*]], [[X]] -; CHECK-NEXT: [[R:%.*]] = srem i8 [[BO0]], [[BO1]] +; CHECK-NEXT: [[TMP1:%.*]] = srem i8 [[Y:%.*]], [[Z:%.*]] +; CHECK-NEXT: [[R:%.*]] = mul nuw nsw i8 [[TMP1]], [[X:%.*]] ; CHECK-NEXT: ret i8 [[R]] ; %BO0 = mul nsw nuw i8 %Y, %X