diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -1759,6 +1759,27 @@
   return nullptr;
 }
 
+/// Check that the Op1 is in expected form, i.e.:
+///   %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???)
+///   %Op1 = extractvalue { i4, i1 } %Agg, 1
+static bool omitCheckForZeroBeforeMulWithOverflowInternal(Value *Op1,
+                                                          Value *X) {
+  auto *Extract = dyn_cast<ExtractValueInst>(Op1);
+  // We should only be extracting the overflow bit.
+  if (!Extract || !Extract->getIndices().equals(1))
+    return false;
+  Value *Agg = Extract->getAggregateOperand();
+  // This should be a multiplication-with-overflow intrinsic.
+  if (!match(Agg, m_CombineOr(m_Intrinsic<Intrinsic::umul_with_overflow>(),
+                              m_Intrinsic<Intrinsic::smul_with_overflow>())))
+    return false;
+  // One of its multipliers should be the value we checked for zero before.
+  if (!match(Agg, m_CombineOr(m_Argument<0>(m_Specific(X)),
+                              m_Argument<1>(m_Specific(X)))))
+    return false;
+  return true;
+}
+
 /// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some
 /// other form of check, e.g. one that was using division; it may have been
 /// guarded against division-by-zero. We can drop that check now.
@@ -1774,23 +1795,41 @@
   if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) ||
       Pred != ICmpInst::Predicate::ICMP_NE)
     return nullptr;
-  auto *Extract = dyn_cast<ExtractValueInst>(Op1);
-  // We should only be extracting the overflow bit.
-  if (!Extract || !Extract->getIndices().equals(1))
-    return nullptr;
-  Value *Agg = Extract->getAggregateOperand();
-  // This should be a multiplication-with-overflow intrinsic.
-  if (!match(Agg, m_CombineOr(m_Intrinsic<Intrinsic::umul_with_overflow>(),
-                              m_Intrinsic<Intrinsic::smul_with_overflow>())))
-    return nullptr;
-  // One of its multipliers should be the value we checked for zero before.
-  if (!match(Agg, m_CombineOr(m_Argument<0>(m_Specific(X)),
-                              m_Argument<1>(m_Specific(X)))))
+  // Is Op1 in expected form?
+  if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X))
     return nullptr;
   // Can omit 'and', and just return the overflow bit.
   return Op1;
 }
 
+/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some
+/// other form of check, e.g. one that was using division; it may have been
+/// guarded against division-by-zero. We can drop that check now.
+/// Look for:
+///   %Op0 = icmp eq i4 %X, 0
+///   %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???)
+///   %Op1 = extractvalue { i4, i1 } %Agg, 1
+///   %NotOp1 = xor i1 %Op1, true
+///   %or = or i1 %Op0, %NotOp1
+/// We can just return %NotOp1
+static Value *omitCheckForZeroBeforeInvertedMulWithOverflow(Value *Op0,
+                                                            Value *NotOp1) {
+  ICmpInst::Predicate Pred;
+  Value *X;
+  if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) ||
+      Pred != ICmpInst::Predicate::ICMP_EQ)
+    return nullptr;
+  // We expect the other hand of an 'or' to be a 'not'.
+  Value *Op1;
+  if (!match(NotOp1, m_Not(m_Value(Op1))))
+    return nullptr;
+  // Is Op1 in expected form?
+  if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X))
+    return nullptr;
+  // Can omit 'or', and just return the inverted overflow bit.
+  return NotOp1;
+}
+
 /// Given operands for an And, see if we can fold the result.
 /// If not, this returns null.
 static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
@@ -2027,6 +2066,14 @@
   if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, false))
     return V;
 
+  // If we have an inverted multiplication overflow check that is being 'or'ed
+  // with a check that one of the multipliers is zero, we can omit the 'or',
+  // and only keep the inverted overflow check.
+  if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op0, Op1))
+    return V;
+  if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op1, Op0))
+    return V;
+
   // Try some generic simplifications for associative operations.
   if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q,
                                           MaxRecurse))
diff --git a/llvm/test/Transforms/InstSimplify/div-by-0-guard-before-smul_ov-not.ll b/llvm/test/Transforms/InstSimplify/div-by-0-guard-before-smul_ov-not.ll
--- a/llvm/test/Transforms/InstSimplify/div-by-0-guard-before-smul_ov-not.ll
+++ b/llvm/test/Transforms/InstSimplify/div-by-0-guard-before-smul_ov-not.ll
@@ -5,12 +5,10 @@
 
 define i1 @t0_umul(i4 %size, i4 %nmemb) {
 ; CHECK-LABEL: @t0_umul(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i4 [[SIZE:%.*]], 0
-; CHECK-NEXT:    [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]])
+; CHECK-NEXT:    [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE:%.*]], i4 [[NMEMB:%.*]])
 ; CHECK-NEXT:    [[SMUL_OV:%.*]] = extractvalue { i4, i1 } [[SMUL]], 1
 ; CHECK-NEXT:    [[PHITMP:%.*]] = xor i1 [[SMUL_OV]], true
-; CHECK-NEXT:    [[OR:%.*]] = or i1 [[CMP]], [[PHITMP]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    ret i1 [[PHITMP]]
 ;
   %cmp = icmp eq i4 %size, 0
   %smul = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 %size, i4 %nmemb)
@@ -22,12 +20,10 @@
 
 define i1 @t1_commutative(i4 %size, i4 %nmemb) {
 ; CHECK-LABEL: @t1_commutative(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i4 [[SIZE:%.*]], 0
-; CHECK-NEXT:    [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]])
+; CHECK-NEXT:    [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE:%.*]], i4 [[NMEMB:%.*]])
 ; CHECK-NEXT:    [[SMUL_OV:%.*]] = extractvalue { i4, i1 } [[SMUL]], 1
 ; CHECK-NEXT:    [[PHITMP:%.*]] = xor i1 [[SMUL_OV]], true
-; CHECK-NEXT:    [[OR:%.*]] = or i1 [[PHITMP]], [[CMP]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    ret i1 [[PHITMP]]
 ;
   %cmp = icmp eq i4 %size, 0
   %smul = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 %size, i4 %nmemb)
diff --git a/llvm/test/Transforms/InstSimplify/div-by-0-guard-before-umul_ov-not.ll b/llvm/test/Transforms/InstSimplify/div-by-0-guard-before-umul_ov-not.ll
--- a/llvm/test/Transforms/InstSimplify/div-by-0-guard-before-umul_ov-not.ll
+++ b/llvm/test/Transforms/InstSimplify/div-by-0-guard-before-umul_ov-not.ll
@@ -5,12 +5,10 @@
 
 define i1 @t0_umul(i4 %size, i4 %nmemb) {
 ; CHECK-LABEL: @t0_umul(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i4 [[SIZE:%.*]], 0
-; CHECK-NEXT:    [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]])
+; CHECK-NEXT:    [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE:%.*]], i4 [[NMEMB:%.*]])
 ; CHECK-NEXT:    [[UMUL_OV:%.*]] = extractvalue { i4, i1 } [[UMUL]], 1
 ; CHECK-NEXT:    [[PHITMP:%.*]] = xor i1 [[UMUL_OV]], true
-; CHECK-NEXT:    [[OR:%.*]] = or i1 [[CMP]], [[PHITMP]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    ret i1 [[PHITMP]]
 ;
   %cmp = icmp eq i4 %size, 0
   %umul = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 %size, i4 %nmemb)
@@ -22,12 +20,10 @@
 
 define i1 @t1_commutative(i4 %size, i4 %nmemb) {
 ; CHECK-LABEL: @t1_commutative(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i4 [[SIZE:%.*]], 0
-; CHECK-NEXT:    [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]])
+; CHECK-NEXT:    [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE:%.*]], i4 [[NMEMB:%.*]])
 ; CHECK-NEXT:    [[UMUL_OV:%.*]] = extractvalue { i4, i1 } [[UMUL]], 1
 ; CHECK-NEXT:    [[PHITMP:%.*]] = xor i1 [[UMUL_OV]], true
-; CHECK-NEXT:    [[OR:%.*]] = or i1 [[PHITMP]], [[CMP]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    ret i1 [[PHITMP]]
 ;
   %cmp = icmp eq i4 %size, 0
   %umul = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 %size, i4 %nmemb)
diff --git a/llvm/test/Transforms/PhaseOrdering/unsigned-multiply-overflow-check.ll b/llvm/test/Transforms/PhaseOrdering/unsigned-multiply-overflow-check.ll
--- a/llvm/test/Transforms/PhaseOrdering/unsigned-multiply-overflow-check.ll
+++ b/llvm/test/Transforms/PhaseOrdering/unsigned-multiply-overflow-check.ll
@@ -124,12 +124,10 @@
 ;
 ; INSTCOMBINESIMPLIFYCFGINSTCOMBINE-LABEL: @will_overflow(
 ; INSTCOMBINESIMPLIFYCFGINSTCOMBINE-NEXT:  bb:
-; INSTCOMBINESIMPLIFYCFGINSTCOMBINE-NEXT:    [[T0:%.*]] = icmp eq i64 [[ARG:%.*]], 0
-; INSTCOMBINESIMPLIFYCFGINSTCOMBINE-NEXT:    [[UMUL:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[ARG]], i64 [[ARG1:%.*]])
+; INSTCOMBINESIMPLIFYCFGINSTCOMBINE-NEXT:    [[UMUL:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[ARG:%.*]], i64 [[ARG1:%.*]])
 ; INSTCOMBINESIMPLIFYCFGINSTCOMBINE-NEXT:    [[UMUL_OV:%.*]] = extractvalue { i64, i1 } [[UMUL]], 1
 ; INSTCOMBINESIMPLIFYCFGINSTCOMBINE-NEXT:    [[PHITMP:%.*]] = xor i1 [[UMUL_OV]], true
-; INSTCOMBINESIMPLIFYCFGINSTCOMBINE-NEXT:    [[T6:%.*]] = or i1 [[T0]], [[PHITMP]]
-; INSTCOMBINESIMPLIFYCFGINSTCOMBINE-NEXT:    ret i1 [[T6]]
+; INSTCOMBINESIMPLIFYCFGINSTCOMBINE-NEXT:    ret i1 [[PHITMP]]
 ;
 ; INSTCOMBINESIMPLIFYCFGCOSTLYONLY-LABEL: @will_overflow(
 ; INSTCOMBINESIMPLIFYCFGCOSTLYONLY-NEXT:  bb:
@@ -142,12 +140,10 @@
 ;
 ; INSTCOMBINESIMPLIFYCFGCOSTLYINSTCOMBINE-LABEL: @will_overflow(
 ; INSTCOMBINESIMPLIFYCFGCOSTLYINSTCOMBINE-NEXT:  bb:
-; INSTCOMBINESIMPLIFYCFGCOSTLYINSTCOMBINE-NEXT:    [[T0:%.*]] = icmp eq i64 [[ARG:%.*]], 0
-; INSTCOMBINESIMPLIFYCFGCOSTLYINSTCOMBINE-NEXT:    [[UMUL:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[ARG]], i64 [[ARG1:%.*]])
+; INSTCOMBINESIMPLIFYCFGCOSTLYINSTCOMBINE-NEXT:    [[UMUL:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[ARG:%.*]], i64 [[ARG1:%.*]])
 ; INSTCOMBINESIMPLIFYCFGCOSTLYINSTCOMBINE-NEXT:    [[UMUL_OV:%.*]] = extractvalue { i64, i1 } [[UMUL]], 1
 ; INSTCOMBINESIMPLIFYCFGCOSTLYINSTCOMBINE-NEXT:    [[PHITMP:%.*]] = xor i1 [[UMUL_OV]], true
-; INSTCOMBINESIMPLIFYCFGCOSTLYINSTCOMBINE-NEXT:    [[T6:%.*]] = or i1 [[T0]], [[PHITMP]]
-; INSTCOMBINESIMPLIFYCFGCOSTLYINSTCOMBINE-NEXT:    ret i1 [[T6]]
+; INSTCOMBINESIMPLIFYCFGCOSTLYINSTCOMBINE-NEXT:    ret i1 [[PHITMP]]
 ;
 bb:
   %t0 = icmp eq i64 %arg, 0