diff --git a/llvm/include/llvm/Analysis/OverflowInstAnalysis.h b/llvm/include/llvm/Analysis/OverflowInstAnalysis.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Analysis/OverflowInstAnalysis.h
@@ -0,0 +1,45 @@
+//===-- OverflowInstAnalysis.h - Utils to fold overflow insts ----*- C++ -*-==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file holds routines to help analyse overflow instructions
+// and fold them into constants or other overflow instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_OVERFLOWINSTANALYSIS_H
+#define LLVM_ANALYSIS_OVERFLOWINSTANALYSIS_H
+
+#include "llvm/IR/InstrTypes.h"
+
+namespace llvm {
+class Value;
+class Use;
+
+/// Match one of the patterns up to the select/logic op:
+///   %Op0 = icmp ne i4 %X, 0
+///   %Agg = call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %Y)
+///   %Op1 = extractvalue { i4, i1 } %Agg, 1
+///   %ret = select i1 %Op0, i1 %Op1, i1 false / %ret = and i1 %Op0, %Op1
+///
+///   %Op0 = icmp eq i4 %X, 0
+///   %Agg = call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %Y)
+///   %NotOp1 = extractvalue { i4, i1 } %Agg, 1
+///   %Op1 = xor i1 %NotOp1, true
+///   %ret = select i1 %Op0, i1 true, i1 %Op1 / %ret = or i1 %Op0, %Op1
+///
+/// Callers are expected to pass the operands of the select/logic op as Op0
+/// and Op1. IsAnd should be true for the first pattern and false for the
+/// second. If Op0 and Op1 match one of the patterns above, return true and
+/// set Y to the use of the multiplication operand that is not X.
+bool isCheckForZeroAndMulWithOverflow(Value *Op0, Value *Op1, bool IsAnd,
+                                      Use *&Y);
+bool isCheckForZeroAndMulWithOverflow(Value *Op0, Value *Op1, bool IsAnd);
+} // end namespace llvm
+
+#endif
diff --git a/llvm/lib/Analysis/CMakeLists.txt b/llvm/lib/Analysis/CMakeLists.txt
--- a/llvm/lib/Analysis/CMakeLists.txt
+++ b/llvm/lib/Analysis/CMakeLists.txt
@@ -101,6 +101,7 @@
   ObjCARCAnalysisUtils.cpp
   ObjCARCInstKind.cpp
   OptimizationRemarkEmitter.cpp
+  OverflowInstAnalysis.cpp
   PHITransAddr.cpp
   PhiValues.cpp
   PostDominators.cpp
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -26,6 +26,7 @@
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/LoopAnalysisManager.h"
 #include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/OverflowInstAnalysis.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/ConstantRange.h"
@@ -1947,77 +1948,6 @@
   return nullptr;
 }
 
-/// Check that the Op1 is in expected form, i.e.:
-///   %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???)
-///   %Op1 = extractvalue { i4, i1 } %Agg, 1
-static bool omitCheckForZeroBeforeMulWithOverflowInternal(Value *Op1,
-                                                          Value *X) {
-  auto *Extract = dyn_cast<ExtractValueInst>(Op1);
-  // We should only be extracting the overflow bit.
-  if (!Extract || !Extract->getIndices().equals(1))
-    return false;
-  Value *Agg = Extract->getAggregateOperand();
-  // This should be a multiplication-with-overflow intrinsic.
-  if (!match(Agg, m_CombineOr(m_Intrinsic<Intrinsic::umul_with_overflow>(),
-                              m_Intrinsic<Intrinsic::smul_with_overflow>())))
-    return false;
-  // One of its multipliers should be the value we checked for zero before.
-  if (!match(Agg, m_CombineOr(m_Argument<0>(m_Specific(X)),
-                              m_Argument<1>(m_Specific(X)))))
-    return false;
-  return true;
-}
-
-/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some
-/// other form of check, e.g. one that was using division; it may have been
-/// guarded against division-by-zero. We can drop that check now.
-/// Look for:
-///   %Op0 = icmp ne i4 %X, 0
-///   %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???)
-///   %Op1 = extractvalue { i4, i1 } %Agg, 1
-///   %??? = and i1 %Op0, %Op1
-/// We can just return %Op1
-static Value *omitCheckForZeroBeforeMulWithOverflow(Value *Op0, Value *Op1) {
-  ICmpInst::Predicate Pred;
-  Value *X;
-  if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) ||
-      Pred != ICmpInst::Predicate::ICMP_NE)
-    return nullptr;
-  // Is Op1 in expected form?
-  if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X))
-    return nullptr;
-  // Can omit 'and', and just return the overflow bit.
-  return Op1;
-}
-
-/// The @llvm.[us]mul.with.overflow intrinsic could have been folded from some
-/// other form of check, e.g. one that was using division; it may have been
-/// guarded against division-by-zero. We can drop that check now.
-/// Look for:
-///   %Op0 = icmp eq i4 %X, 0
-///   %Agg = tail call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???)
-///   %Op1 = extractvalue { i4, i1 } %Agg, 1
-///   %NotOp1 = xor i1 %Op1, true
-///   %or = or i1 %Op0, %NotOp1
-/// We can just return %NotOp1
-static Value *omitCheckForZeroBeforeInvertedMulWithOverflow(Value *Op0,
-                                                            Value *NotOp1) {
-  ICmpInst::Predicate Pred;
-  Value *X;
-  if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())) ||
-      Pred != ICmpInst::Predicate::ICMP_EQ)
-    return nullptr;
-  // We expect the other hand of an 'or' to be a 'not'.
-  Value *Op1;
-  if (!match(NotOp1, m_Not(m_Value(Op1))))
-    return nullptr;
-  // Is Op1 in expected form?
-  if (!omitCheckForZeroBeforeMulWithOverflowInternal(Op1, X))
-    return nullptr;
-  // Can omit 'and', and just return the inverted overflow bit.
-  return NotOp1;
-}
-
 /// Given a bitwise logic op, check if the operands are add/sub with a common
 /// source value and inverted constant (identity: C - X -> ~(X + ~C)).
 static Value *simplifyLogicOfAddSub(Value *Op0, Value *Op1,
@@ -2102,10 +2032,10 @@
   // If we have a multiplication overflow check that is being 'and'ed with a
   // check that one of the multipliers is not zero, we can omit the 'and', and
   // only keep the overflow check.
-  if (Value *V = omitCheckForZeroBeforeMulWithOverflow(Op0, Op1))
-    return V;
-  if (Value *V = omitCheckForZeroBeforeMulWithOverflow(Op1, Op0))
-    return V;
+  if (isCheckForZeroAndMulWithOverflow(Op0, Op1, true))
+    return Op1;
+  if (isCheckForZeroAndMulWithOverflow(Op1, Op0, true))
+    return Op0;
 
   // A & (-A) = A if A is a power of two or zero.
   if (match(Op0, m_Neg(m_Specific(Op1))) ||
@@ -2316,10 +2246,10 @@
   // If we have a multiplication overflow check that is being 'and'ed with a
   // check that one of the multipliers is not zero, we can omit the 'and', and
   // only keep the overflow check.
-  if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op0, Op1))
-    return V;
-  if (Value *V = omitCheckForZeroBeforeInvertedMulWithOverflow(Op1, Op0))
-    return V;
+  if (isCheckForZeroAndMulWithOverflow(Op0, Op1, false))
+    return Op1;
+  if (isCheckForZeroAndMulWithOverflow(Op1, Op0, false))
+    return Op0;
 
   // Try some generic simplifications for associative operations.
   if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q,
diff --git a/llvm/lib/Analysis/OverflowInstAnalysis.cpp b/llvm/lib/Analysis/OverflowInstAnalysis.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Analysis/OverflowInstAnalysis.cpp
@@ -0,0 +1,71 @@
+//==-- OverflowInstAnalysis.cpp - Utils to fold overflow insts ----*- C++ -*-=//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file holds routines to help analyse overflow instructions
+// and fold them into constants or other overflow instructions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/OverflowInstAnalysis.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/PatternMatch.h"
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+bool llvm::isCheckForZeroAndMulWithOverflow(Value *Op0, Value *Op1, bool IsAnd,
+                                            Use *&Y) {
+  ICmpInst::Predicate Pred;
+  Value *X, *NotOp1;
+  int XIdx;
+  IntrinsicInst *II;
+
+  if (!match(Op0, m_ICmp(Pred, m_Value(X), m_Zero())))
+    return false;
+
+  /// %Agg = call { i4, i1 } @llvm.[us]mul.with.overflow.i4(i4 %X, i4 %???)
+  /// %V = extractvalue { i4, i1 } %Agg, 1
+  auto matchMulOverflowCheck = [X, &II, &XIdx](Value *V) {
+    auto *Extract = dyn_cast<ExtractValueInst>(V);
+    // We should only be extracting the overflow bit.
+    if (!Extract || !Extract->getIndices().equals(1))
+      return false;
+
+    II = dyn_cast<IntrinsicInst>(Extract->getAggregateOperand());
+    if (!match(II, m_CombineOr(m_Intrinsic<Intrinsic::umul_with_overflow>(),
+                               m_Intrinsic<Intrinsic::smul_with_overflow>())))
+      return false;
+
+    if (II->getArgOperand(0) == X)
+      XIdx = 0;
+    else if (II->getArgOperand(1) == X)
+      XIdx = 1;
+    else
+      return false;
+    return true;
+  };
+
+  bool Matched =
+      (IsAnd && Pred == ICmpInst::Predicate::ICMP_NE &&
+       matchMulOverflowCheck(Op1)) ||
+      (!IsAnd && Pred == ICmpInst::Predicate::ICMP_EQ &&
+       match(Op1, m_Not(m_Value(NotOp1))) && matchMulOverflowCheck(NotOp1));
+
+  if (!Matched)
+    return false;
+
+  Y = &II->getArgOperandUse(!XIdx);
+  return true;
+}
+
+bool llvm::isCheckForZeroAndMulWithOverflow(Value *Op0, Value *Op1,
+                                            bool IsAnd) {
+  Use *Y;
+  return isCheckForZeroAndMulWithOverflow(Op0, Op1, IsAnd, Y);
+}
\ No newline at end of file
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -18,6 +18,7 @@
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/OverflowInstAnalysis.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -2697,6 +2698,18 @@
     if (Value *S = SimplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ,
                                           /* AllowRefinement */ true))
       return replaceOperand(SI, 2, S);
+
+    if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) {
+      Use *Y = nullptr;
+      bool IsAnd = match(FalseVal, m_Zero());
+      Value *Op1 = IsAnd ? TrueVal : FalseVal;
+      if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) {
+        auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr");
+        InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser()));
+        replaceUse(*Y, FI);
+        return replaceInstUsesWith(SI, Op1);
+      }
+    }
   }
 
   // Selecting between two integer or vector splat integer constants?
diff --git a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov-not.ll b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov-not.ll
--- a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov-not.ll
+++ b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov-not.ll
@@ -5,12 +5,11 @@
 
 define i1 @t0_umul(i4 %size, i4 %nmemb) {
 ; CHECK-LABEL: @t0_umul(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i4 [[SIZE:%.*]], 0
-; CHECK-NEXT:    [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]])
+; CHECK-NEXT:    [[NMEMB_FR:%.*]] = freeze i4 [[NMEMB:%.*]]
+; CHECK-NEXT:    [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE:%.*]], i4 [[NMEMB_FR]])
 ; CHECK-NEXT:    [[SMUL_OV:%.*]] = extractvalue { i4, i1 } [[SMUL]], 1
 ; CHECK-NEXT:    [[PHITMP:%.*]] = xor i1 [[SMUL_OV]], true
-; CHECK-NEXT:    [[OR:%.*]] = select i1 [[CMP]], i1 true, i1 [[PHITMP]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    ret i1 [[PHITMP]]
 ;
   %cmp = icmp eq i4 %size, 0
   %smul = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 %size, i4 %nmemb)
diff --git a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov.ll b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov.ll
--- a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov.ll
+++ b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-smul_ov.ll
@@ -20,11 +20,10 @@
 
 define i1 @t1_commutative(i4 %size, i4 %nmemb) {
 ; CHECK-LABEL: @t1_commutative(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i4 [[SIZE:%.*]], 0
-; CHECK-NEXT:    [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]])
+; CHECK-NEXT:    [[NMEMB_FR:%.*]] = freeze i4 [[NMEMB:%.*]]
+; CHECK-NEXT:    [[SMUL:%.*]] = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 [[SIZE:%.*]], i4 [[NMEMB_FR]])
 ; CHECK-NEXT:    [[SMUL_OV:%.*]] = extractvalue { i4, i1 } [[SMUL]], 1
-; CHECK-NEXT:    [[AND:%.*]] = select i1 [[CMP]], i1 [[SMUL_OV]], i1 false
-; CHECK-NEXT:    ret i1 [[AND]]
+; CHECK-NEXT:    ret i1 [[SMUL_OV]]
 ;
   %cmp = icmp ne i4 %size, 0
   %smul = tail call { i4, i1 } @llvm.smul.with.overflow.i4(i4 %size, i4 %nmemb)
diff --git a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov-not.ll b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov-not.ll
--- a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov-not.ll
+++ b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov-not.ll
@@ -5,12 +5,11 @@
 
 define i1 @t0_umul(i4 %size, i4 %nmemb) {
 ; CHECK-LABEL: @t0_umul(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i4 [[SIZE:%.*]], 0
-; CHECK-NEXT:    [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]])
+; CHECK-NEXT:    [[NMEMB_FR:%.*]] = freeze i4 [[NMEMB:%.*]]
+; CHECK-NEXT:    [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE:%.*]], i4 [[NMEMB_FR]])
 ; CHECK-NEXT:    [[UMUL_OV:%.*]] = extractvalue { i4, i1 } [[UMUL]], 1
 ; CHECK-NEXT:    [[PHITMP:%.*]] = xor i1 [[UMUL_OV]], true
-; CHECK-NEXT:    [[OR:%.*]] = select i1 [[CMP]], i1 true, i1 [[PHITMP]]
-; CHECK-NEXT:    ret i1 [[OR]]
+; CHECK-NEXT:    ret i1 [[PHITMP]]
 ;
   %cmp = icmp eq i4 %size, 0
   %umul = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 %size, i4 %nmemb)
diff --git a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov.ll b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov.ll
--- a/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov.ll
+++ b/llvm/test/Transforms/InstCombine/div-by-0-guard-before-umul_ov.ll
@@ -20,11 +20,10 @@
 
 define i1 @t1_commutative(i4 %size, i4 %nmemb) {
 ; CHECK-LABEL: @t1_commutative(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i4 [[SIZE:%.*]], 0
-; CHECK-NEXT:    [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE]], i4 [[NMEMB:%.*]])
+; CHECK-NEXT:    [[NMEMB_FR:%.*]] = freeze i4 [[NMEMB:%.*]]
+; CHECK-NEXT:    [[UMUL:%.*]] = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 [[SIZE:%.*]], i4 [[NMEMB_FR]])
 ; CHECK-NEXT:    [[UMUL_OV:%.*]] = extractvalue { i4, i1 } [[UMUL]], 1
-; CHECK-NEXT:    [[AND:%.*]] = select i1 [[CMP]], i1 [[UMUL_OV]], i1 false
-; CHECK-NEXT:    ret i1 [[AND]]
+; CHECK-NEXT:    ret i1 [[UMUL_OV]]
 ;
   %cmp = icmp ne i4 %size, 0
   %umul = tail call { i4, i1 } @llvm.umul.with.overflow.i4(i4 %size, i4 %nmemb)
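
Note (reviewer sketch, not part of the patch): the InstCombineSelect.cpp hunk handles the select ("logical and/or") form of the guard, and it is the only path that inserts a freeze. In the original IR the select ignores the overflow bit whenever %size is zero, so returning the overflow bit directly makes the result depend on the other multiplier in cases where it previously did not; freezing the operand returned through Y presumably keeps a poison/undef %nmemb from leaking into the result. The plain and/or case in InstructionSimplify.cpp just returns the existing value and needs no freeze. A minimal LLVM IR before/after sketch follows; the function name @select_guard is an illustrative assumption, and the "after" shape mirrors the updated CHECK lines above.

; Hypothetical reduced input (names assumed; see the tests above for the real ones).
declare { i4, i1 } @llvm.umul.with.overflow.i4(i4, i4)

define i1 @select_guard(i4 %size, i4 %nmemb) {
  %cmp = icmp ne i4 %size, 0
  %umul = call { i4, i1 } @llvm.umul.with.overflow.i4(i4 %size, i4 %nmemb)
  %umul.ov = extractvalue { i4, i1 } %umul, 1
  ; Logical and: %umul.ov (and thus %nmemb) is ignored when %size is 0.
  %ret = select i1 %cmp, i1 %umul.ov, i1 false
  ret i1 %ret
}

; Expected shape after the fold (cf. the CHECK lines in the tests above):
;   %nmemb.fr = freeze i4 %nmemb
;   %umul = call { i4, i1 } @llvm.umul.with.overflow.i4(i4 %size, i4 %nmemb.fr)
;   %umul.ov = extractvalue { i4, i1 } %umul, 1
;   ret i1 %umul.ov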