diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -2585,6 +2585,74 @@ return nullptr; } +/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some +/// other form of check, e.g. one that was using division; it may have been +/// guarded against division-by-zero. We can drop that check now. +/// +/// Look for: +/// %Op0 = icmp ne i4 %X, 0 +/// %Agg = call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %Y) +/// %Op1 = extractvalue { i4, i1 } %Agg, 1 +/// %ret = select i1 %Op0, i1 %Op1, i1 false +/// +/// %Op0 = icmp eq i4 %X, 0 +/// %Agg = call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %Y) +/// %NotOp1 = extractvalue { i4, i1 } %Agg, 1 +/// %Op1 = xor i1 %NotOp1, true +/// %ret = select i1 %Op0, i1 true, i1 %Op1 +/// +/// We can freeze %Y and just return %Op1 in both cases. + +static Value * +omitCheckForZeroBeforeMulWithOverflow(Value *Op0, Value *Op1, bool IsAnd, + InstCombinerImpl &IC) { + ICmpInst::Predicate Pred; + Value *X, *NotOp1; + int XIdx; + IntrinsicInst *II; + + if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero()))) + return nullptr; + + /// %Agg = call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???) + /// %V = extractvalue { i4, i1 } %Agg, 1 + auto matchMulOverflowCheck = [X, &II, &XIdx](Value *V) { + auto *Extract = dyn_cast(V); + // We should only be extracting the overflow bit. 
+ if (!Extract || !Extract->getIndices().equals(1)) + return false; + + II = dyn_cast(Extract->getAggregateOperand()); + if (!match(II, m_CombineOr(m_Intrinsic(), + m_Intrinsic()))) + return false; + + if (II->getArgOperand(0) == X) + XIdx = 0; + else if (II->getArgOperand(1) == X) + XIdx = 1; + else + return false; + return true; + }; + + bool Matched = (IsAnd && Pred == ICmpInst::Predicate::ICMP_NE && + matchMulOverflowCheck(Op1)) || + (!IsAnd && Pred == ICmpInst::Predicate::ICMP_EQ && + match(Op1, m_Not(m_Value(NotOp1))) && + matchMulOverflowCheck(NotOp1)); + + if (!Matched) + return nullptr; + + int IdxToFreeze = !XIdx; + Value *ValueToFr = II->getArgOperand(IdxToFreeze); + FreezeInst *FI = new FreezeInst(ValueToFr, ValueToFr->getName() + ".fr"); + IC.InsertNewInstBefore(FI, *II); + IC.replaceUse(II->getArgOperandUse(IdxToFreeze), FI); + return Op1; +} + Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -2697,6 +2765,14 @@ if (Value *S = SimplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ, /* AllowRefinement */ true)) return replaceOperand(SI, 2, S); + + if (match(FalseVal, m_Zero())) { + if (Value *V = omitCheckForZeroBeforeMulWithOverflow(CondVal, TrueVal, true, *this)) + return replaceInstUsesWith(SI, V); + } else if (match(TrueVal, m_One())) { + if (Value *V = omitCheckForZeroBeforeMulWithOverflow(CondVal, FalseVal, false, *this)) + return replaceInstUsesWith(SI, V); + } } // Selecting between two integer or vector splat integer constants? 
diff --git a/llvm/test/Transforms/PhaseOrdering/unsigned-multiply-overflow-check.ll b/llvm/test/Transforms/PhaseOrdering/unsigned-multiply-overflow-check.ll
--- a/llvm/test/Transforms/PhaseOrdering/unsigned-multiply-overflow-check.ll
+++ b/llvm/test/Transforms/PhaseOrdering/unsigned-multiply-overflow-check.ll
@@ -59,11 +59,10 @@
 ;
 ; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-LABEL: @will_not_overflow(
 ; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:  bb:
-; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    [[T0:%.*]] = icmp ne i64 [[ARG:%.*]], 0
-; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    [[UMUL:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[ARG]], i64 [[ARG1:%.*]])
+; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    [[ARG1_FR:%.*]] = freeze i64 [[ARG1:%.*]]
+; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    [[UMUL:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[ARG:%.*]], i64 [[ARG1_FR]])
 ; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    [[UMUL_OV:%.*]] = extractvalue { i64, i1 } [[UMUL]], 1
-; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    [[T6:%.*]] = select i1 [[T0]], i1 [[UMUL_OV]], i1 false
-; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    ret i1 [[T6]]
+; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    ret i1 [[UMUL_OV]]
 ;
 bb:
   %t0 = icmp eq i64 %arg, 0
@@ -126,12 +125,11 @@
 ;
 ; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-LABEL: @will_overflow(
 ; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:  bb:
-; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    [[T0:%.*]] = icmp eq i64 [[ARG:%.*]], 0
-; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    [[UMUL:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[ARG]], i64 [[ARG1:%.*]])
+; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    [[ARG1_FR:%.*]] = freeze i64 [[ARG1:%.*]]
+; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    [[UMUL:%.*]] = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[ARG:%.*]], i64 [[ARG1_FR]])
 ; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    [[UMUL_OV:%.*]] = extractvalue { i64, i1 } [[UMUL]], 1
 ; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    [[PHI_BO:%.*]] = xor i1 [[UMUL_OV]], true
-; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    [[T6:%.*]] = select i1 [[T0]], i1 true, i1 [[PHI_BO]]
-; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    ret i1 [[T6]]
+; INSTCOMBINESIMPLIFYCFGINSTCOMBINEUNSAFE-NEXT:    ret i1 [[PHI_BO]]
 ;
 bb:
   %t0 = icmp eq i64 %arg, 0