[InstSimplify] Fold X * Y % Y --> 0 when the multiply cannot overflow.

Moves the (X * Y) / Y -> X fold from simplifyDiv into the shared
simplifyDivRem helper (which now takes the opcode instead of an IsDiv flag)
and extends it to remainders: (X * Y) % Y -> 0. The multiply is known not to
overflow when it carries nsw (signed ops) / nuw (unsigned ops), or when
X == A / Y for some A with the matching signedness.

diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -923,8 +923,11 @@
 
 /// Check for common or similar folds of integer division or integer remainder.
 /// This applies to all 4 opcodes (sdiv/udiv/srem/urem).
-static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv,
-                             const SimplifyQuery &Q) {
+static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0,
+                             Value *Op1, const SimplifyQuery &Q) {
+  bool IsDiv = (Opcode == Instruction::SDiv || Opcode == Instruction::UDiv);
+  bool IsSigned = (Opcode == Instruction::SDiv || Opcode == Instruction::SRem);
+
   Type *Ty = Op0->getType();
 
   // X / undef -> poison
@@ -976,6 +979,21 @@
       (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))
     return IsDiv ? Op0 : Constant::getNullValue(Ty);
 
+  // If X * Y does not overflow, then:
+  //   X * Y / Y -> X
+  //   X * Y % Y -> 0
+  if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) {
+    auto *Mul = cast<OverflowingBinaryOperator>(Op0);
+    // The multiplication can't overflow if it is defined not to, or if
+    // X == A / Y for some A (then |X * Y| <= |A|, so it cannot wrap).
+    if ((IsSigned && Q.IIQ.hasNoSignedWrap(Mul)) ||
+        (!IsSigned && Q.IIQ.hasNoUnsignedWrap(Mul)) ||
+        (IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) ||
+        (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1))))) {
+      return IsDiv ? X : Constant::getNullValue(Ty);
+    }
+  }
+
   return nullptr;
 }
 
@@ -1047,25 +1065,11 @@
   if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
     return C;
 
-  if (Value *V = simplifyDivRem(Op0, Op1, true, Q))
+  if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q))
     return V;
 
   bool IsSigned = Opcode == Instruction::SDiv;
 
-  // (X * Y) / Y -> X if the multiplication does not overflow.
-  Value *X;
-  if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) {
-    auto *Mul = cast<OverflowingBinaryOperator>(Op0);
-    // If the Mul does not overflow, then we are good to go.
-    if ((IsSigned && Q.IIQ.hasNoSignedWrap(Mul)) ||
-        (!IsSigned && Q.IIQ.hasNoUnsignedWrap(Mul)))
-      return X;
-    // If X has the form X = A / Y, then X * Y cannot overflow.
-    if ((IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) ||
-        (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1)))))
-      return X;
-  }
-
   // (X rem Y) / Y -> 0
   if ((IsSigned && match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) ||
       (!IsSigned && match(Op0, m_URem(m_Value(), m_Specific(Op1)))))
@@ -1073,7 +1077,7 @@
 
   // (X /u C1) /u C2 -> 0 if C1 * C2 overflow
   ConstantInt *C1, *C2;
-  if (!IsSigned && match(Op0, m_UDiv(m_Value(X), m_ConstantInt(C1))) &&
+  if (!IsSigned && match(Op0, m_UDiv(m_Value(), m_ConstantInt(C1))) &&
       match(Op1, m_ConstantInt(C2))) {
     bool Overflow;
     (void)C1->getValue().umul_ov(C2->getValue(), Overflow);
@@ -1105,7 +1109,7 @@
   if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
     return C;
 
-  if (Value *V = simplifyDivRem(Op0, Op1, false, Q))
+  if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q))
     return V;
 
   // (X % Y) % Y -> X % Y
diff --git a/llvm/test/Transforms/InstSimplify/rem.ll b/llvm/test/Transforms/InstSimplify/rem.ll
--- a/llvm/test/Transforms/InstSimplify/rem.ll
+++ b/llvm/test/Transforms/InstSimplify/rem.ll
@@ -335,9 +335,7 @@
 
 define i32 @srem_of_mul_nsw(i32 %x, i32 %y) {
 ; CHECK-LABEL: @srem_of_mul_nsw(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[MOD:%.*]] = srem i32 [[MUL]], [[Y]]
-; CHECK-NEXT:    ret i32 [[MOD]]
+; CHECK-NEXT:    ret i32 0
 ;
   %mul = mul nsw i32 %x, %y
   %mod = srem i32 %mul, %y
@@ -349,9 +347,7 @@
 ;   - vector types
 define <2 x i32> @srem_of_mul_nsw_vec_commuted(<2 x i32> %x, <2 x i32> %y) {
 ; CHECK-LABEL: @srem_of_mul_nsw_vec_commuted(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nsw <2 x i32> [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[MOD:%.*]] = srem <2 x i32> [[MUL]], [[Y]]
-; CHECK-NEXT:    ret <2 x i32> [[MOD]]
+; CHECK-NEXT:    ret <2 x i32> zeroinitializer
 ;
   %mul = mul nsw <2 x i32> %y, %x
   %mod = srem <2 x i32> %mul, %y
@@ -393,9 +389,7 @@
 
 define i32 @urem_of_mul_nuw(i32 %x, i32 %y) {
 ; CHECK-LABEL: @urem_of_mul_nuw(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw i32 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[MOD:%.*]] = urem i32 [[MUL]], [[Y]]
-; CHECK-NEXT:    ret i32 [[MOD]]
+; CHECK-NEXT:    ret i32 0
 ;
   %mul = mul nuw i32 %x, %y
   %mod = urem i32 %mul, %y
@@ -404,9 +398,7 @@
 
 define <2 x i32> @srem_of_mul_nuw_vec_commuted(<2 x i32> %x, <2 x i32> %y) {
 ; CHECK-LABEL: @srem_of_mul_nuw_vec_commuted(
-; CHECK-NEXT:    [[MUL:%.*]] = mul nuw <2 x i32> [[Y:%.*]], [[X:%.*]]
-; CHECK-NEXT:    [[MOD:%.*]] = urem <2 x i32> [[MUL]], [[Y]]
-; CHECK-NEXT:    ret <2 x i32> [[MOD]]
+; CHECK-NEXT:    ret <2 x i32> zeroinitializer
 ;
   %mul = mul nuw <2 x i32> %y, %x
   %mod = urem <2 x i32> %mul, %y
@@ -426,10 +418,7 @@
 
 define i4 @srem_mul_sdiv(i4 %x, i4 %y) {
 ; CHECK-LABEL: @srem_mul_sdiv(
-; CHECK-NEXT:    [[D:%.*]] = sdiv i4 [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[MUL:%.*]] = mul i4 [[D]], [[Y]]
-; CHECK-NEXT:    [[MOD:%.*]] = srem i4 [[MUL]], [[Y]]
-; CHECK-NEXT:    ret i4 [[MOD]]
+; CHECK-NEXT:    ret i4 0
 ;
   %d = sdiv i4 %x, %y
   %mul = mul i4 %d, %y
@@ -452,10 +441,7 @@
 
 define <3 x i7> @urem_mul_udiv_vec_commuted(<3 x i7> %x, <3 x i7> %y) {
 ; CHECK-LABEL: @urem_mul_udiv_vec_commuted(
-; CHECK-NEXT:    [[D:%.*]] = udiv <3 x i7> [[X:%.*]], [[Y:%.*]]
-; CHECK-NEXT:    [[MUL:%.*]] = mul <3 x i7> [[Y]], [[D]]
-; CHECK-NEXT:    [[MOD:%.*]] = urem <3 x i7> [[MUL]], [[Y]]
-; CHECK-NEXT:    ret <3 x i7> [[MOD]]
+; CHECK-NEXT:    ret <3 x i7> zeroinitializer
 ;
   %d = udiv <3 x i7> %x, %y
   %mul = mul <3 x i7> %y, %d