diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -37,6 +37,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/Traits/SemanticTrait.h" namespace llvm { @@ -60,6 +61,7 @@ /// InstrInfoQuery provides an interface to query additional information for /// instructions like metadata or keywords like nsw, which provides conservative /// results if the users specified it is safe to use. +/// FIXME: Incorporate this into trait framework. struct InstrInfoQuery { InstrInfoQuery(bool UMD) : UseInstrInfo(UMD) {} InstrInfoQuery() : UseInstrInfo(true) {} @@ -83,9 +85,9 @@ return false; } - bool isExact(const BinaryOperator *Op) const { - if (UseInstrInfo && isa(Op)) - return cast(Op)->isExact(); + bool isExact(const Instruction *I) const { + if (UseInstrInfo && isa(I)) + return cast(I)->isExact(); return false; } }; @@ -142,20 +144,37 @@ // Please use the SimplifyQuery versions in new code. /// Given operand for an FNeg, fold the result or return null. -Value *SimplifyFNegInst(Value *Op, FastMathFlags FMF, - const SimplifyQuery &Q); +Value *SimplifyFNegInst(Value *Op, FastMathFlags FMF, const SimplifyQuery &Q); /// Given operands for an Add, fold the result or return null. +template Value *SimplifyAddInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW, - const SimplifyQuery &Q); + const SimplifyQuery &Q, MatcherContext &Matcher); + +inline Value *SimplifyAddInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW, + const SimplifyQuery &Q) { + MatcherContext Matcher; + return SimplifyAddInst(LHS, RHS, isNSW, isNUW, Q, Matcher); +} /// Given operands for a Sub, fold the result or return null. Value *SimplifySubInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW, const SimplifyQuery &Q); -/// Given operands for an FAdd, fold the result or return null. +/// Given operands for an FAdd, and a matcher context that was initialized for +/// the actual instruction, fold the result or return null. +template Value *SimplifyFAddInst(Value *LHS, Value *RHS, FastMathFlags FMF, - const SimplifyQuery &Q); + const SimplifyQuery &Q, MatcherContext &Matcher); + +/// Given operands for an FAdd, fold the result or return null. +/// We don't have any information about the traits of the 'fadd' so run this +/// with the unassuming default trait. +inline Value *SimplifyFAddInst(Value *LHS, Value *RHS, FastMathFlags FMF, + const SimplifyQuery &Q) { + MatcherContext Matcher; + return SimplifyFAddInst(LHS, RHS, FMF, Q, Matcher); +} /// Given operands for an FSub, fold the result or return null. Value *SimplifyFSubInst(Value *LHS, Value *RHS, FastMathFlags FMF, @@ -272,13 +291,27 @@ const SimplifyQuery &Q); /// Given operands for a BinaryOperator, fold the result or return null. +template Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const SimplifyQuery &Q); + const SimplifyQuery &Q, MatcherContext &Matcher); + +inline Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, + const SimplifyQuery &Q) { + MatcherContext Matcher; + return SimplifyBinOp<>(Opcode, LHS, RHS, Q, Matcher); +} /// Given operands for a BinaryOperator, fold the result or return null. /// Try to use FastMathFlags when folding the result. 
-Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - FastMathFlags FMF, const SimplifyQuery &Q); +template +Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, FastMathFlags FMF, + const SimplifyQuery &Q, MatcherContext &Matcher); + +inline Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, + FastMathFlags FMF, const SimplifyQuery &Q) { + MatcherContext Matcher; + return SimplifyBinOp(Opcode, LHS, RHS, FMF, Q, Matcher); +} /// Given a callsite, fold the result or return null. Value *SimplifyCall(CallBase *Call, const SimplifyQuery &Q); @@ -289,6 +322,10 @@ /// See if we can compute a simplified version of this instruction. If not, /// return null. +template +Value *SimplifyInstructionWithTrait(Instruction *I, const SimplifyQuery &Q, + OptimizationRemarkEmitter *ORE = nullptr); + Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q, OptimizationRemarkEmitter *ORE = nullptr); @@ -336,4 +373,3 @@ } // end namespace llvm #endif - diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h --- a/llvm/include/llvm/IR/IntrinsicInst.h +++ b/llvm/include/llvm/IR/IntrinsicInst.h @@ -278,6 +278,15 @@ unsigned getFunctionalOpcode() const { return GetFunctionalOpcodeForVP(getIntrinsicID()); } + bool isFunctionalCommutative() const { + return Instruction::isCommutative(getFunctionalOpcode()); + } + bool isFunctionalUnaryOp() const { + return Instruction::isUnaryOp(getFunctionalOpcode()); + } + bool isFunctionalBinaryOp() const { + return Instruction::isBinaryOp(getFunctionalOpcode()); + } // Equivalent non-predicated opcode static unsigned GetFunctionalOpcodeForVP(Intrinsic::ID ID); @@ -288,9 +297,21 @@ public: bool isUnaryOp() const; bool isTernaryOp() const; + bool hasRoundingMode() const; Optional getRoundingMode() const; Optional getExceptionBehavior() const; + unsigned getFunctionalOpcode() const; + bool isFunctionalCommutative() const { + return Instruction::isCommutative(getFunctionalOpcode()); + } + bool isFunctionalUnaryOp() const { + return Instruction::isUnaryOp(getFunctionalOpcode()); + } + bool isFunctionalBinaryOp() const { + return Instruction::isBinaryOp(getFunctionalOpcode()); + } + // Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const IntrinsicInst *I); static bool classof(const Value *V) { diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -39,17 +39,40 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/Traits/SemanticTrait.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" + #include namespace llvm { namespace PatternMatch { +// Trait-match pattern in a given context and update the context. template bool match(Val *V, const Pattern &P) { + // TODO: Use this instead of the Trait-less Pattern::match() functions. This + // single function does the same as the Trait-less 'Pattern::match()' + // functions that are replicated once for every Pattern. return const_cast(P).match(V); } +// Trait-match pattern in a given context and update the context. +template +bool match(Val *V, const Pattern &P, MatcherContext &MContext) { + return const_cast(P).match(V, MContext); +} + +// Trait-match pattern and update the context on match. 
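+// try_match() speculates on a match: it runs the pattern on a copy of the
+// context and only commits the copy back if the whole pattern matched, so a
+// failed attempt cannot leave a half-updated MContext behind.
+// A hedged usage sketch (any enabled trait works in place of EmptyTrait):
+//
+//   Value *X;
+//   MatcherContext<EmptyTrait> Ctx;
+//   if (try_match(V, m_Add(m_Value(X), m_Zero()), Ctx)) {
+//     // Ctx now reflects everything accepted while matching the add.
+//   }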
+template +bool try_match(Val *V, const Pattern &P, MatcherContext &MContext) { + MatcherContext CopyCtx(MContext); + if (const_cast(P).match(V, CopyCtx)) { + MContext = CopyCtx; + return true; + } + return false; +} + template bool match(ArrayRef Mask, const Pattern &P) { return const_cast(P).match(Mask); } @@ -59,8 +82,14 @@ OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {} - template bool match(OpTy *V) { - return V->hasOneUse() && SubPattern.match(V); + template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(OpTy *V, MatcherContext &MContext) { + return V->hasOneUse() && SubPattern.match(V, MContext); } }; @@ -69,7 +98,16 @@ } template struct class_match { - template bool match(ITy *V) { return isa(V); } + + template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + return trait_isa(V); + } }; /// Match an arbitrary value and ignore it. @@ -110,7 +148,15 @@ match_unless(const Ty &Matcher) : M(Matcher) {} - template bool match(ITy *V) { return !M.match(V); } + template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + return !M.match(V, MContext); + } }; /// Match if the inner matcher does *NOT* match. @@ -126,11 +172,16 @@ match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} template bool match(ITy *V) { - if (L.match(V)) - return true; - if (R.match(V)) + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (try_match(V, L, MContext)) { return true; - return false; + } + return R.match(V, MContext); } }; @@ -141,10 +192,13 @@ match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} template bool match(ITy *V) { - if (L.match(V)) - if (R.match(V)) - return true; - return false; + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + return L.match(V, MContext) && R.match(V, MContext); } }; @@ -165,17 +219,23 @@ bool AllowUndef; apint_match(const APInt *&Res, bool AllowUndef) - : Res(Res), AllowUndef(AllowUndef) {} + : Res(Res), AllowUndef(AllowUndef) {} template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValue(); return true; } if (V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) - if (auto *CI = dyn_cast_or_null( - C->getSplatValue(AllowUndef))) { + if (auto *CI = + dyn_cast_or_null(C->getSplatValue(AllowUndef))) { Res = &CI->getValue(); return true; } @@ -193,14 +253,20 @@ : Res(Res), AllowUndef(AllowUndef) {} template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValueAPF(); return true; } if (V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) - if (auto *CI = dyn_cast_or_null( - C->getSplatValue(AllowUndef))) { + if (auto *CI = + dyn_cast_or_null(C->getSplatValue(AllowUndef))) { Res = &CI->getValueAPF(); return true; } @@ -244,6 +310,12 @@ template struct constantint_match { template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (const auto *CI = 
dyn_cast(V)) { const APInt &CIV = CI->getValue(); if (Val >= 0) @@ -262,14 +334,20 @@ return constantint_match(); } -/// This helper class is used to match constant scalars, vector splats, -/// and fixed width vectors that satisfy a specified predicate. -/// For fixed width vector constants, undefined elements are ignored. +/// This helper class is used to match scalar and fixed width vector integer +/// constants that satisfy a specified predicate. +/// For vector constants, undefined elements are ignored. template struct cstval_pred_ty : public Predicate { template bool match(ITy *V) { - if (const auto *CV = dyn_cast(V)) - return this->isValue(CV->getValue()); + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (const auto *CI = dyn_cast(V)) + return this->isValue(CI->getValue()); if (const auto *VTy = dyn_cast(V->getType())) { if (const auto *C = dyn_cast(V)) { if (const auto *CV = dyn_cast_or_null(C->getSplatValue())) @@ -318,6 +396,12 @@ api_pred_ty(const APInt *&R) : Res(R) {} template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (const auto *CI = dyn_cast(V)) if (this->isValue(CI->getValue())) { Res = &CI->getValue(); @@ -344,6 +428,12 @@ apf_pred_ty(const APFloat *&R) : Res(R) {} template bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (const auto *CI = dyn_cast(V)) if (this->isValue(CI->getValue())) { Res = &CI->getValue(); @@ -410,9 +500,7 @@ inline cst_pred_ty m_Negative() { return cst_pred_ty(); } -inline api_pred_ty m_Negative(const APInt *&V) { - return V; -} +inline api_pred_ty m_Negative(const APInt *&V) { return V; } struct is_nonnegative { bool isValue(const APInt &C) { return C.isNonNegative(); } @@ -422,9 +510,7 @@ inline cst_pred_ty m_NonNegative() { return cst_pred_ty(); } -inline api_pred_ty m_NonNegative(const APInt *&V) { - return V; -} +inline api_pred_ty m_NonNegative(const APInt *&V) { return V; } struct is_strictlypositive { bool isValue(const APInt &C) { return C.isStrictlyPositive(); } @@ -453,9 +539,7 @@ }; /// Match an integer 1 or a vector with all elements equal to 1. /// For vectors, this includes constants with undefined elements. -inline cst_pred_ty m_One() { - return cst_pred_ty(); -} +inline cst_pred_ty m_One() { return cst_pred_ty(); } struct is_zero_int { bool isValue(const APInt &C) { return C.isNullValue(); } @@ -468,6 +552,12 @@ struct is_zero { template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { auto *C = dyn_cast(V); // FIXME: this should be able to do something for scalable vectors return C && (C->isNullValue() || cst_pred_ty().match(C)); @@ -475,21 +565,15 @@ }; /// Match any null constant or a vector with all elements equal to 0. /// For vectors, this includes constants with undefined elements. -inline is_zero m_Zero() { - return is_zero(); -} +inline is_zero m_Zero() { return is_zero(); } struct is_power2 { bool isValue(const APInt &C) { return C.isPowerOf2(); } }; /// Match an integer or vector power-of-2. /// For vectors, this includes constants with undefined elements. 
-inline cst_pred_ty m_Power2() { - return cst_pred_ty(); -} -inline api_pred_ty m_Power2(const APInt *&V) { - return V; -} +inline cst_pred_ty m_Power2() { return cst_pred_ty(); } +inline api_pred_ty m_Power2(const APInt *&V) { return V; } struct is_negated_power2 { bool isValue(const APInt &C) { return (-C).isPowerOf2(); } @@ -578,9 +662,7 @@ }; /// Match an arbitrary NaN constant. This includes quiet and signalling nans. /// For vectors, this includes constants with undefined elements. -inline cstfp_pred_ty m_NaN() { - return cstfp_pred_ty(); -} +inline cstfp_pred_ty m_NaN() { return cstfp_pred_ty(); } struct is_nonnan { bool isValue(const APFloat &C) { return !C.isNaN(); } @@ -596,9 +678,7 @@ }; /// Match a positive or negative infinity FP constant. /// For vectors, this includes constants with undefined elements. -inline cstfp_pred_ty m_Inf() { - return cstfp_pred_ty(); -} +inline cstfp_pred_ty m_Inf() { return cstfp_pred_ty(); } struct is_noninf { bool isValue(const APFloat &C) { return !C.isInfinity(); } @@ -675,6 +755,13 @@ bind_ty(Class *&V) : VR(V) {} template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + template + bool match(ITy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; if (auto *CV = dyn_cast(V)) { VR = CV; return true; @@ -694,7 +781,9 @@ /// Match a binary operator, capturing it if we match. inline bind_ty m_BinOp(BinaryOperator *&I) { return I; } /// Match a with overflow intrinsic, capturing it if we match. -inline bind_ty m_WithOverflowInst(WithOverflowInst *&I) { return I; } +inline bind_ty m_WithOverflowInst(WithOverflowInst *&I) { + return I; +} /// Match a ConstantInt, capturing the value if we match. inline bind_ty m_ConstantInt(ConstantInt *&CI) { return CI; } @@ -717,7 +806,15 @@ specificval_ty(const Value *V) : Val(V) {} - template bool match(ITy *V) { return V == Val; } + template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + return V == Val; + } }; /// Match if we have a specific specified value. @@ -730,7 +827,15 @@ deferredval_ty(Class *const &V) : Val(V) {} - template bool match(ITy *const V) { return V == Val; } + template bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *const V, MatcherContext &MContext) { + return V == Val; + } }; /// A commutative-friendly version of m_Specific(). @@ -747,6 +852,12 @@ specific_fpval(double V) : Val(V) {} template bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (const auto *CFP = dyn_cast(V)) return CFP->isExactlyValue(Val); if (V->getType()->isVectorTy()) @@ -770,6 +881,14 @@ bind_const_intval_ty(uint64_t &V) : VR(V) {} template bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; if (const auto *CV = dyn_cast(V)) if (CV->getValue().ule(UINT64_MAX)) { VR = CV->getZExtValue(); @@ -781,13 +900,20 @@ /// Match a specified integer value or vector of all elements of that /// value. 
-template -struct specific_intval { +template struct specific_intval { APInt Val; specific_intval(APInt V) : Val(std::move(V)) {} template bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; const auto *CI = dyn_cast(V); if (!CI && V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) @@ -826,6 +952,14 @@ specific_bbval(BasicBlock *Val) : Val(Val) {} template bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; const auto *BB = dyn_cast(V); return BB && BB == Val; } @@ -858,10 +992,28 @@ AnyBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &MContext) { + auto *I = trait_dyn_cast(V); + if (!I) + return false; + + if (!MContext.accept(I)) + return false; + + MatcherContext LRContext(MContext); + if (L.match(I->getOperand(0), LRContext) && + R.match(I->getOperand(1), LRContext)) { + MContext = LRContext; + return true; + } + if (Commutable && (L.match(I->getOperand(1), MContext) && + R.match(I->getOperand(0), MContext))) + return true; return false; } }; @@ -881,8 +1033,14 @@ AnyUnaryOp_match(const OP_t &X) : X(X) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return X.match(I->getOperand(0)); + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &MContext) { + if (auto *I = trait_dyn_cast(V)) + return X.match(I->getOperand(0), MContext); return false; } }; @@ -906,11 +1064,22 @@ BinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *I = trait_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + if (try_match(I->getOperand(0), L, MContext) && + try_match(I->getOperand(1), R, MContext)) { + return true; + } + return Commutable && (L.match(I->getOperand(1), MContext) && + R.match(I->getOperand(0), MContext)); } if (auto *CE = dyn_cast(V)) return CE->getOpcode() == Opcode && @@ -950,24 +1119,36 @@ FNeg_match(const Op_t &Op) : X(Op) {} template bool match(OpTy *V) { - auto *FPMO = dyn_cast(V); - if (!FPMO) return false; + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *FPMO = trait_dyn_cast(V); + if (!FPMO) + return false; - if (FPMO->getOpcode() == Instruction::FNeg) - return X.match(FPMO->getOperand(0)); + auto OPC = trait_cast(V)->getOpcode(); - if (FPMO->getOpcode() == Instruction::FSub) { + if (OPC == Instruction::FNeg) + return X.match(FPMO->getOperand(0), MContext); + + if (OPC == Instruction::FSub) { if (FPMO->hasNoSignedZeros()) { // With 'nsz', any zero goes. 
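+        // ('fsub nsz 0.0, X' and 'fsub nsz -0.0, X' both act as fneg, since
+        // nsz lets us ignore the sign of a zero operand.)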
- if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!try_match(FPMO->getOperand(0), cstfp_pred_ty(), + MContext)) return false; } else { // Without 'nsz', we need fsub -0.0, X exactly. - if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!try_match(FPMO->getOperand(0), cstfp_pred_ty(), + MContext)) return false; } - return X.match(FPMO->getOperand(1)); + return X.match(FPMO->getOperand(1), MContext); } return false; @@ -975,9 +1156,7 @@ }; /// Match 'fneg X' as 'fsub -0.0, X'. -template -inline FNeg_match -m_FNeg(const OpTy &X) { +template inline FNeg_match m_FNeg(const OpTy &X) { return FNeg_match(X); } @@ -1082,7 +1261,14 @@ : L(LHS), R(RHS) {} template bool match(OpTy *V) { - if (auto *Op = dyn_cast(V)) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *Op = trait_dyn_cast(V)) { if (Op->getOpcode() != Opcode) return false; if (WrapFlags & OverflowingBinaryOperator::NoUnsignedWrap && @@ -1091,7 +1277,8 @@ if (WrapFlags & OverflowingBinaryOperator::NoSignedWrap && !Op->hasNoSignedWrap()) return false; - return L.match(Op->getOperand(0)) && R.match(Op->getOperand(1)); + return L.match(Op->getOperand(0), MContext) && + R.match(Op->getOperand(1), MContext); } return false; } @@ -1102,32 +1289,32 @@ OverflowingBinaryOperator::NoSignedWrap> m_NSWAdd(const LHS &L, const RHS &R) { return OverflowingBinaryOp_match( - L, R); + OverflowingBinaryOperator::NoSignedWrap>(L, + R); } template inline OverflowingBinaryOp_match m_NSWSub(const LHS &L, const RHS &R) { return OverflowingBinaryOp_match( - L, R); + OverflowingBinaryOperator::NoSignedWrap>(L, + R); } template inline OverflowingBinaryOp_match m_NSWMul(const LHS &L, const RHS &R) { return OverflowingBinaryOp_match( - L, R); + OverflowingBinaryOperator::NoSignedWrap>(L, + R); } template inline OverflowingBinaryOp_match m_NSWShl(const LHS &L, const RHS &R) { return OverflowingBinaryOp_match( - L, R); + OverflowingBinaryOperator::NoSignedWrap>(L, + R); } template @@ -1174,10 +1361,18 @@ BinOpPred_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return this->isOpType(I->getOpcode()) && L.match(I->getOperand(0)) && - R.match(I->getOperand(1)); - if (auto *CE = dyn_cast(V)) + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *I = trait_dyn_cast(V)) + return this->isOpType(I->getOpcode()) && + L.match(I->getOperand(0), MContext) && + R.match(I->getOperand(1), MContext); + if (auto *CE = trait_dyn_cast(V)) return this->isOpType(CE->getOpcode()) && L.match(CE->getOperand(0)) && R.match(CE->getOperand(1)); return false; @@ -1269,8 +1464,15 @@ Exact_match(const SubPattern_t &SP) : SubPattern(SP) {} template bool match(OpTy *V) { - if (auto *PEO = dyn_cast(V)) - return PEO->isExact() && SubPattern.match(V); + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *PEO = trait_dyn_cast(V)) + return PEO->isExact() && SubPattern.match(V, MContext); return false; } }; @@ -1296,12 +1498,27 @@ : Predicate(Pred), L(LHS), R(RHS) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool 
match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *I = trait_dyn_cast(V)) { + MatcherContext LRContext(MContext); + if (L.match(I->getOperand(0), LRContext) && + R.match(I->getOperand(1), LRContext)) { + MContext = LRContext; Predicate = I->getPredicate(); return true; - } else if (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))) { + } + + if (!Commutable) + return false; + + if (L.match(I->getOperand(1), MContext) && + R.match(I->getOperand(0), MContext)) { Predicate = I->getSwappedPredicate(); return true; } @@ -1339,9 +1556,16 @@ OneOps_match(const T0 &Op1) : Op1(Op1) {} template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)); + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *I = trait_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + return Op1.match(I->getOperand(0), MContext); } return false; } @@ -1355,9 +1579,17 @@ TwoOps_match(const T0 &Op1, const T1 &Op2) : Op1(Op1), Op2(Op2) {} template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)); + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *I = trait_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + return Op1.match(I->getOperand(0), MContext) && + Op2.match(I->getOperand(1), MContext); } return false; } @@ -1374,10 +1606,18 @@ : Op1(Op1), Op2(Op2), Op3(Op3) {} template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && - Op3.match(I->getOperand(2)); + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *I = trait_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + return Op1.match(I->getOperand(0), MContext) && + Op2.match(I->getOperand(1), MContext) && + Op3.match(I->getOperand(2), MContext); } return false; } @@ -1430,8 +1670,16 @@ : Op1(Op1), Op2(Op2), Mask(Mask) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *I = trait_dyn_cast(V)) { + return Op1.match(I->getOperand(0), MContext) && + Op2.match(I->getOperand(1), MContext) && Mask.match(I->getShuffleMask()); } return false; @@ -1509,8 +1757,15 @@ CastClass_match(const Op_t &OpMatch) : Op(OpMatch) {} template bool match(OpTy *V) { - if (auto *O = dyn_cast(V)) - return O->getOpcode() == Opcode && Op.match(O->getOperand(0)); + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &Matcher) { + if (!Matcher.accept(V)) + return false; + if (auto O = trait_dyn_cast(V)) + return O->getOpcode() == Opcode && Op.match(O->getOperand(0), Matcher); return false; } }; @@ -1625,7 +1880,14 @@ br_match(BasicBlock *&Succ) : Succ(Succ) {} template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) + MatcherContext Matcher; + return 
match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *BI = trait_dyn_cast(V)) if (BI->isUnconditional()) { Succ = BI->getSuccessor(0); return true; @@ -1646,9 +1908,16 @@ : Cond(C), T(t), F(f) {} template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) - if (BI->isConditional() && Cond.match(BI->getCondition())) - return T.match(BI->getSuccessor(0)) && F.match(BI->getSuccessor(1)); + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (auto *BI = trait_dyn_cast(V)) + if (BI->isConditional() && Cond.match(BI->getCondition(), MContext)) { + return T.match(BI->getSuccessor(0), MContext) && + F.match(BI->getSuccessor(1), MContext); + } return false; } }; @@ -1681,23 +1950,34 @@ MaxMin_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { - if (auto *II = dyn_cast(V)) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *II = trait_dyn_cast(V)) { Intrinsic::ID IID = II->getIntrinsicID(); if ((IID == Intrinsic::smax && Pred_t::match(ICmpInst::ICMP_SGT)) || (IID == Intrinsic::smin && Pred_t::match(ICmpInst::ICMP_SLT)) || (IID == Intrinsic::umax && Pred_t::match(ICmpInst::ICMP_UGT)) || (IID == Intrinsic::umin && Pred_t::match(ICmpInst::ICMP_ULT))) { Value *LHS = II->getOperand(0), *RHS = II->getOperand(1); - return (L.match(LHS) && R.match(RHS)) || - (Commutable && L.match(RHS) && R.match(LHS)); + MatcherContext LRContext(MContext); + if (L.match(LHS, LRContext) && R.match(RHS, LRContext)) { + MContext = LRContext; + return true; + } + return Commutable && L.match(RHS, MContext) && R.match(LHS, MContext); } } // Look for "(x pred y) ? x : y" or "(x pred y) ? y : x". - auto *SI = dyn_cast(V); - if (!SI) + auto *SI = trait_dyn_cast(V); + if (!SI || !MContext.accept(SI)) return false; - auto *Cmp = dyn_cast(SI->getCondition()); - if (!Cmp) + auto *Cmp = trait_dyn_cast(SI->getCondition()); + if (!Cmp || !MContext.accept(Cmp)) return false; // At this point we have a select conditioned on a comparison. Check that // it is the values returned by the select that are being compared. @@ -1713,9 +1993,15 @@ // Does "(x pred y) ? x : y" represent the desired max/min operation? if (!Pred_t::match(Pred)) return false; + // It does! Bind the operands. - return (L.match(LHS) && R.match(RHS)) || - (Commutable && L.match(RHS) && R.match(LHS)); + // TODO factor out commutative matching! 
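+    // The (L, R) attempt runs on a scratch copy of the context so that a
+    // partially matched first ordering cannot leak accepted operations into
+    // the commuted (R, L) attempt below.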
+ MatcherContext LRContext(MContext); + if (L.match(LHS, LRContext) && R.match(RHS, LRContext)) { + MContext = LRContext; + return true; + } + return Commutable && L.match(RHS, MContext) && R.match(LHS, MContext); } }; @@ -1885,49 +2171,77 @@ : L(L), R(R), S(S) {} template bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { Value *ICmpLHS, *ICmpRHS; ICmpInst::Predicate Pred; - if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V)) + if (!try_match(V, m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)), + MContext)) return false; Value *AddLHS, *AddRHS; auto AddExpr = m_Add(m_Value(AddLHS), m_Value(AddRHS)); // (a + b) u< a, (a + b) u< b - if (Pred == ICmpInst::ICMP_ULT) - if (AddExpr.match(ICmpLHS) && (ICmpRHS == AddLHS || ICmpRHS == AddRHS)) - return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpLHS); + if (Pred == ICmpInst::ICMP_ULT) { + if (try_match(ICmpLHS, AddExpr, MContext) && + (ICmpRHS == AddLHS || ICmpRHS == AddRHS)) { + return L.match(AddLHS, MContext) && R.match(AddRHS, MContext) && + S.match(ICmpLHS, MContext); + } + } // a >u (a + b), b >u (a + b) - if (Pred == ICmpInst::ICMP_UGT) - if (AddExpr.match(ICmpRHS) && (ICmpLHS == AddLHS || ICmpLHS == AddRHS)) - return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpRHS); + if (Pred == ICmpInst::ICMP_UGT) { + if (try_match(ICmpRHS, AddExpr, MContext) && + (ICmpLHS == AddLHS || ICmpLHS == AddRHS)) { + return L.match(AddLHS, MContext) && R.match(AddRHS, MContext) && + S.match(ICmpRHS, MContext); + } + } Value *Op1; auto XorExpr = m_OneUse(m_Xor(m_Value(Op1), m_AllOnes())); // (a ^ -1) u (a ^ -1) if (Pred == ICmpInst::ICMP_UGT) { - if (XorExpr.match(ICmpRHS)) - return L.match(Op1) && R.match(ICmpLHS) && S.match(ICmpRHS); + if (try_match(ICmpRHS, XorExpr, MContext)) { + return L.match(Op1, MContext) && R.match(ICmpLHS, MContext) && + S.match(ICmpRHS, MContext); + } } // Match special-case for increment-by-1. if (Pred == ICmpInst::ICMP_EQ) { // (a + 1) == 0 // (1 + a) == 0 - if (AddExpr.match(ICmpLHS) && m_ZeroInt().match(ICmpRHS) && - (m_One().match(AddLHS) || m_One().match(AddRHS))) - return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpLHS); + MatcherContext CopyCtx(MContext); + if (AddExpr.match(ICmpLHS, CopyCtx) && + m_ZeroInt().match(ICmpRHS, CopyCtx)) { + if (try_match(AddLHS, m_One(), CopyCtx) || + try_match(AddRHS, m_One(), CopyCtx)) { + MContext = CopyCtx; + return L.match(AddLHS, MContext) && R.match(AddRHS, MContext) && + S.match(ICmpLHS, MContext); + } + } // 0 == (a + 1) // 0 == (1 + a) - if (m_ZeroInt().match(ICmpLHS) && AddExpr.match(ICmpRHS) && - (m_One().match(AddLHS) || m_One().match(AddRHS))) - return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpRHS); + if (m_ZeroInt().match(ICmpLHS, MContext) && + AddExpr.match(ICmpRHS, MContext) && + (try_match(AddLHS, m_One(), MContext) || + m_One().match(AddRHS, MContext))) + return L.match(AddLHS, MContext) && R.match(AddRHS, MContext) && + S.match(ICmpRHS, MContext); } return false; @@ -1951,9 +2265,14 @@ Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) {} template bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { // FIXME: Should likely be switched to use `CallBase`. 
- if (const auto *CI = dyn_cast(V)) - return Val.match(CI->getArgOperand(OpI)); + if (const auto *CI = trait_dyn_cast(V)) + return Val.match(CI->getArgOperand(OpI), MContext); return false; } }; @@ -1971,7 +2290,12 @@ IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) {} template bool match(OpTy *V) { - if (const auto *CI = dyn_cast(V)) + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (const auto *CI = trait_dyn_cast(V)) if (const auto *F = CI->getCalledFunction()) return F->getIntrinsicID() == ID; return false; @@ -1996,15 +2320,13 @@ }; template struct m_Intrinsic_Ty { - using Ty = - match_combine_and::Ty, - Argument_match>; + using Ty = match_combine_and::Ty, + Argument_match>; }; template struct m_Intrinsic_Ty { - using Ty = - match_combine_and::Ty, - Argument_match>; + using Ty = match_combine_and::Ty, + Argument_match>; }; template @@ -2013,7 +2335,8 @@ Argument_match>; }; -template +template struct m_Intrinsic_Ty { using Ty = match_combine_and::Ty, Argument_match>; @@ -2244,6 +2567,11 @@ Signum_match(const Opnd_t &V) : Val(V) {} template bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { unsigned TypeSize = V->getType()->getScalarSizeInBits(); if (TypeSize == 0) return false; @@ -2265,7 +2593,7 @@ auto RHS = m_LShr(m_Neg(m_Value(OpR)), m_SpecificInt(ShiftWidth)); auto Signum = m_Or(LHS, RHS); - return Signum.match(V) && OpL == OpR && Val.match(OpL); + return Signum.match(V, MContext) && OpL == OpR && Val.match(OpL, MContext); } }; @@ -2284,9 +2612,14 @@ ExtractValue_match(const Opnd_t &V) : Val(V) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (auto *I = trait_dyn_cast(V)) return I->getNumIndices() == 1 && I->getIndices()[0] == Ind && - Val.match(I->getAggregateOperand()); + Val.match(I->getAggregateOperand(), MContext); return false; } }; @@ -2306,9 +2639,16 @@ InsertValue_match(const T0 &Op0, const T1 &Op1) : Op0(Op0), Op1(Op1) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - return Op0.match(I->getOperand(0)) && Op1.match(I->getOperand(1)) && - I->getNumIndices() == 1 && Ind == I->getIndices()[0]; + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &Matcher) { + if (auto *I = trait_dyn_cast(V)) { + return I->getNumIndices() == 1 && Ind == I->getIndices()[0] && + Op0.match(I->getOperand(0), Matcher) && + Op1.match(I->getOperand(1), Matcher); } return false; } @@ -2338,11 +2678,17 @@ VScaleVal_match(const DataLayout &DL) : DL(DL) {} template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + template + bool match(ITy *V, MatcherContext &MContext) { if (m_Intrinsic().match(V)) return true; - if (m_PtrToInt(m_OffsetGep(m_Zero(), m_SpecificInt(1))).match(V)) { - Type *PtrTy = cast(V)->getOperand(0)->getType(); + if (m_PtrToInt(m_OffsetGep(m_Zero(), m_SpecificInt(1))) + .match(V, MContext)) { + Type *PtrTy = trait_cast(V)->getOperand(0)->getType(); auto *DerefTy = PtrTy->getPointerElementType(); if (isa(DerefTy) && DL.getTypeAllocSizeInBits(DerefTy).getKnownMinSize() == 8) diff --git a/llvm/include/llvm/IR/Traits/EnabledTraits.def b/llvm/include/llvm/IR/Traits/EnabledTraits.def new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/Traits/EnabledTraits.def @@ -0,0 +1,4 @@ 
+ENABLE_TRAIT(EmptyTrait)
+ENABLE_TRAIT(CFPTrait)
+ENABLE_TRAIT(VPTrait)
+#undef ENABLE_TRAIT
diff --git a/llvm/include/llvm/IR/Traits/SemanticTrait.h b/llvm/include/llvm/IR/Traits/SemanticTrait.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/IR/Traits/SemanticTrait.h
@@ -0,0 +1,149 @@
+//===- llvm/IR/Traits/SemanticTrait.h - Basic trait definitions -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines semantic traits.
+// Some intrinsics in LLVM can be described as being a regular operation (such
+// as an fadd) at their core with an additional semantic trait. We do this to
+// lift optimizations that are defined in terms of standard IR operations (e.g.
+// fadd, fmul) to these intrinsics. We keep the existing patterns and rewrite
+// machinery and transparently check that the rewrite is consistent with the
+// semantic trait(s) that are attached to the operations.
+//
+// For example:
+//
+//   @llvm.vp.fadd(<256 x double> %a, <256 x double> %b,
+//                 <256 x i1> %m, i32 %avl)
+//
+// This is a vector-predicated fadd instruction with a %mask and %evl
+// parameter
+// (https://llvm.org/docs/LangRef.html#vector-predication-intrinsics).
+// However, at its core it is just an 'fadd'.
+//
+// Consider the fma-fusion rewrite pattern (fadd (fmul x,y) z) --> (fma x, y,
+// z). If the 'fadd' is actually an 'llvm.vp.fadd' and the 'fmul' is actually
+// an 'llvm.vp.fmul', we can perform the rewrite using the %mask and %evl of
+// the 'fadd' node.
+//
+//
+//   @llvm.experimental.constrained.fadd(double %a, double %b,
+//                                       metadata <rounding mode>,
+//                                       metadata <exception behavior>)
+//
+// This is an fadd with a possibly non-default rounding mode and exception
+// behavior.
+// (https://llvm.org/docs/LangRef.html#constrained-floating-point-intrinsics).
+// In this case, the operation matches the semantics of a regular 'fadd'
+// exactly, if the rounding mode is 'round.tonearest' and the exception
+// behavior is 'fpexcept.ignore'.
+// Reconsider the case of fma fusion, this time with two constrained fp
+// intrinsics. If the rounding mode is tonearest for both and neither the
+// 'llvm.experimental.constrained.fmul' nor the 'fadd' has 'fpexcept.strict',
+// we are good to apply the rewrite and emit a constrained fma with the
+// exception flag of the 'fadd'.
+//
+// There is also a proposal to add complex arithmetic intrinsics to LLVM. In
+// that case, the operation is semantically an 'fadd', if we consider the space
+// of complex floating-point numbers and their operations.
+//
+//===----------------------------------------------------------------------===//
+
+// Look for comments starting with "TODO(new trait)" to see what to implement
+// to establish a new instruction trait.
+
+#ifndef LLVM_IR_TRAIT_SEMANTICTRAIT_H
+#define LLVM_IR_TRAIT_SEMANTICTRAIT_H
+
+#include
+#include
+#include
+#include
+
+namespace llvm {
+
+/// Type Casting {
+/// These cast operators allow you to side-step the first-class type hierarchy
+/// of LLVM (Value, Instruction, BinaryOperator, ..) into your custom type
+/// hierarchy.
+///
+///   trait_cast<Instruction, Trait>(V)
+///
+/// actually casts \p V to ExtInstruction<Trait>.
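+///
+/// A hedged usage sketch (the exact extended type depends on the TraitCast
+/// specialization that the trait provides; see Traits.h):
+///
+///   if (auto *EBO = trait_dyn_cast<BinaryOperator, VPTrait>(V))
+///     unsigned Opc = EBO->getOpcode(); // functional opcode of a vp.* call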
+template <typename ValueDerivedType, typename Trait> struct TraitCast {
+  using ExtType = ValueDerivedType;
+};
+
+// This has to happen after all traits are defined since we are referring to
+// members of the template specialization for each Trait (the
+// TraitCast<ValueDerivedType, Trait>::ExtType).
+#define CASTING_TEMPLATE(CASTFUNC, PREFIXMOD, REFMOD)                          \
+  template <typename ValueDerivedType, typename Trait>                        \
+  static auto trait_##CASTFUNC(PREFIXMOD Value REFMOD V)                      \
+      ->decltype(                                                             \
+          CASTFUNC<typename TraitCast<ValueDerivedType, Trait>::ExtType>(V)) {\
+    using TraitExtendedType =                                                 \
+        typename TraitCast<ValueDerivedType, Trait>::ExtType;                 \
+    return CASTFUNC<TraitExtendedType>(V);                                    \
+  }
+
+#define CONST_CAST_TEMPLATE(CASTFUNC, REFMOD)                                  \
+  CASTING_TEMPLATE(CASTFUNC, const, REFMOD)                                    \
+  CASTING_TEMPLATE(CASTFUNC, , REFMOD)
+
+// 'dyn_cast' (allow [const] Value*)
+CONST_CAST_TEMPLATE(dyn_cast, *)
+
+// 'cast' (allow [const] Value(*|&))
+CONST_CAST_TEMPLATE(cast, *)
+CONST_CAST_TEMPLATE(cast, &)
+
+// 'isa'
+CONST_CAST_TEMPLATE(isa, *)
+CONST_CAST_TEMPLATE(isa, &)
+/// } Type Casting
+
+// TODO
+// The trait builder is a specialized IRBuilder that emits trait-compatible
+// instructions.
+template <typename Trait> struct TraitBuilder {};
+
+// This is used in pattern matching to check that all instructions in the
+// pattern are trait-compatible.
+template <typename Trait> struct MatcherContext {
+  // Check whether \p Val is compatible with this context and merge its
+  // properties. \returns whether \p Val is compatible with the current state
+  // of the context.
+  bool accept(const Value *Val) { return Val; }
+
+  // Like accept() but does not modify the context.
+  bool check(const Value *Val) const { return Val; }
+
+  // Whether to allow constant folding with the currently accepted operators
+  // and their operands.
+  bool allowConstantFolding() const { return true; }
+};
+
+/// Empty Trait {
+///
+/// This defines the empty trait without properties. Type casting stays in the
+/// standard llvm::Value type hierarchy.
+
+// Trait without any difference to standard IR
+struct EmptyTrait {
+  // Traits that do not support reassociation use this flag to block it; the
+  // empty trait allows it.
+  static constexpr bool AllowReassociation = true;
+
+  // Whether \p V should be considered at all with this trait.
+  static bool consider(const Value *) { return true; }
+};
+
+using DefaultTrait = EmptyTrait;
+
+/// } Empty Trait
+
+} // end namespace llvm
+
+#endif // LLVM_IR_TRAIT_SEMANTICTRAIT_H
diff --git a/llvm/include/llvm/IR/Traits/Traits.h b/llvm/include/llvm/IR/Traits/Traits.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/IR/Traits/Traits.h
@@ -0,0 +1,330 @@
+//===- llvm/IR/Traits/Traits.h - Semantic trait implementations -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines semantic traits.
+// Some intrinsics in LLVM can be described as being a regular operation (such
+// as an fadd) at their core with an additional semantic trait. We do this to
+// lift optimizations that are defined in terms of standard IR operations (e.g.
+// fadd, fmul) to these intrinsics. We keep the existing patterns and rewrite
+// machinery and transparently check that the rewrite is consistent with the
+// semantic trait(s) that are attached to the operations.
+//
+// For example:
+//
+//   @llvm.vp.fadd(<256 x double> %a, <256 x double> %b,
+//                 <256 x i1> %m, i32 %avl)
+//
+// This is a vector-predicated fadd instruction with a %mask and %evl
+// parameter
+// (https://llvm.org/docs/LangRef.html#vector-predication-intrinsics).
+// However, at its core it is just an 'fadd'.
+//
+// Consider the simplification (add (sub x,y), y) --> x. If the 'add' is
+// actually an 'llvm.vp.add' and the 'sub' is really an 'llvm.vp.sub', we can
+// do the simplification.
+//
+//
+//   @llvm.experimental.constrained.fadd(double %a, double %b,
+//                                       metadata <rounding mode>,
+//                                       metadata <exception behavior>)
+//
+// This is an fadd with a possibly non-default rounding mode and exception
+// behavior.
+// (https://llvm.org/docs/LangRef.html#constrained-floating-point-intrinsics).
+// The constrained fp intrinsic has exactly the semantics of a regular 'fadd',
+// if the rounding mode is 'round.tonearest' and the exception behavior is
+// 'fpexcept.ignore'.
+// We can use all simplifying rewrites for regular fp arithmetic also for
+// constrained fp arithmetic where this applies.
+//
+// There is also a proposal to add complex arithmetic intrinsics to LLVM. In
+// that case, the operation is semantically an 'fadd', if we consider the space
+// of complex floating-point numbers and their operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Traits/SemanticTrait.h"
+
+#ifndef LLVM_IR_TRAITS_TRAITS_H
+#define LLVM_IR_TRAITS_TRAITS_H
+
+namespace llvm {
+
+/// Base classes {
+// Shared functionality for extended instructions.
+// These define all functions used in PatternMatch as 'virtual' to remind you
+// to implement them.
+
+// Make sure no member of the Ext.. hierarchy can be constructed.
+struct ExtBase {
+  ExtBase() = delete;
+  ~ExtBase() = delete;
+  ExtBase &operator=(const ExtBase &) = delete;
+  void *operator new(size_t s) = delete;
+};
+
+// Mirror generic functionality of llvm::Instruction.
+struct ExtInstructionBase : public ExtBase, public User {
+  BasicBlock *getParent() { return cast<Instruction>(this)->getParent(); }
+  const BasicBlock *getParent() const {
+    return cast<Instruction>(this)->getParent();
+  }
+
+  void copyIRFlags(const Value *V, bool IncludeWrapFlags) {
+    cast<Instruction>(this)->copyIRFlags(V, IncludeWrapFlags);
+  }
+};
+
+/// } Base classes
+
+/// Intrinsic-Trait Template {
+// Generic instantiation for traits that want to masquerade their intrinsic as
+// a regular IR instruction.
+
+// Pretend to be a llvm::Operator.
+template <typename Trait> struct ExtOperator : public ExtBase, public User {
+  unsigned getOpcode() const {
+    // Use the intrinsic override.
+    if (const auto *Intrin = dyn_cast<typename Trait::Intrinsic>(this)) {
+      return Intrin->getFunctionalOpcode();
+    }
+    // Default Operator opcode.
+    return cast<Operator>(this)->getOpcode();
+  }
+
+  bool hasNoSignedWrap() const {
+    if (const auto *OverflowingBOp = dyn_cast<OverflowingBinaryOperator>(this))
+      return OverflowingBOp->hasNoSignedWrap();
+    return false;
+  }
+
+  bool hasNoUnsignedWrap() const {
+    if (const auto *OverflowingBOp = dyn_cast<OverflowingBinaryOperator>(this))
+      return OverflowingBOp->hasNoUnsignedWrap();
+    return false;
+  }
+
+  // Every operator is also an extended operator.
+  static bool classof(const Value *V) { return isa<Operator>(V); }
+};
+
+// Pretend to be a llvm::Instruction.
+template <typename Trait>
+struct ExtInstruction final : public ExtInstructionBase {
+  unsigned getOpcode() const {
+    // Use the intrinsic override.
+    if (const auto *Intrin = dyn_cast<typename Trait::Intrinsic>(this)) {
+      return Intrin->getFunctionalOpcode();
+    }
+    // Default opcode.
+    return cast<Instruction>(this)->getOpcode();
+  }
+
+  static bool classof(const Value *V) { return isa<Instruction>(V); }
+};
+
+// Pretend to be a (different) llvm::IntrinsicInst.
+template <typename Trait>
+struct ExtIntrinsic final : public ExtInstructionBase {
+  Intrinsic::ID getIntrinsicID() const {
+    // Use the intrinsic override.
+    if (const auto *TraitIntrin = dyn_cast<typename Trait::Intrinsic>(this))
+      return TraitIntrin->getFunctionalIntrinsic();
+    // Default intrinsic opcode.
+    return cast<IntrinsicInst>(this)->getIntrinsicID();
+  }
+
+  unsigned getOpcode() const {
+    // We are looking at this as an intrinsic -> do not hide this.
+    return cast<Instruction>(this)->getOpcode();
+  }
+
+  bool isCommutative() const {
+    // The underlying intrinsic may not specify whether it is commutative.
+    // Query our own interface to be sure this is done right.
+    // Use the intrinsic override.
+    if (const auto *TraitIntrin = dyn_cast<typename Trait::Intrinsic>(this))
+      return TraitIntrin->isFunctionalCommutative();
+    return cast<typename Trait::Intrinsic>(this)->isFunctionalCommutative();
+  }
+
+  static bool classof(const Value *V) { return IntrinsicInst::classof(V); }
+};
+
+template <typename Trait, typename CmpInstType, unsigned OPC>
+struct ExtCmpBase : public ExtInstructionBase {
+  unsigned getOpcode() const { return OPC; }
+
+  FCmpInst::Predicate getPredicate() const {
+    // Use the intrinsic override.
+    if (const auto *Intrin = dyn_cast<typename Trait::Intrinsic>(this)) {
+      return Intrin->getPredicate();
+    }
+
+    // Default opcode.
+    return cast<CmpInstType>(this)->getPredicate();
+  }
+};
+
+template <typename Trait, typename CmpInstType, unsigned OPC>
+static bool classofExtCmpBase(const Value *V) {
+  return isa<CmpInstType>(V) ||
+         cast<typename Trait::Intrinsic>(V)->getFunctionalOpcode() == OPC;
+}
+
+// Pretend to be a llvm::FCmpInst.
+template <typename Trait>
+struct ExtFCmpInst final
+    : public ExtCmpBase<Trait, FCmpInst, Instruction::FCmp> {
+  static bool classof(const Value *V) {
+    return classofExtCmpBase<Trait, FCmpInst, Instruction::FCmp>(V);
+  }
+};
+
+// Pretend to be a llvm::ICmpInst.
+template <typename Trait>
+struct ExtICmpInst final
+    : public ExtCmpBase<Trait, ICmpInst, Instruction::ICmp> {
+  static bool classof(const Value *V) {
+    return classofExtCmpBase<Trait, ICmpInst, Instruction::ICmp>(V);
+  }
+};
+
+// Pretend to be a BinaryOperator.
+template <typename Trait>
+struct ExtBinaryOperator final : public ExtOperator<Trait> {
+  using BinaryOps = Instruction::BinaryOps;
+
+  static bool classof(const Instruction *I) {
+    if (isa<BinaryOperator>(I))
+      return true;
+    const auto *Intrin = dyn_cast<typename Trait::Intrinsic>(I);
+    return Intrin && Intrin->isFunctionalBinaryOp();
+  }
+  static bool classof(const ConstantExpr *CE) {
+    return isa<BinaryOperator>(CE);
+  }
+  static bool classof(const Value *V) {
+    if (const auto *I = dyn_cast<Instruction>(V))
+      return classof(I);
+
+    if (const auto *CE = dyn_cast<ConstantExpr>(V))
+      return classof(CE);
+
+    return false;
+  }
+};
+
+// Pretend to be a UnaryOperator.
+template <typename Trait>
+struct ExtUnaryOperator final : public ExtOperator<Trait> {
+  using UnaryOps = Instruction::UnaryOps;
+
+  static bool classof(const Instruction *I) {
+    if (isa<UnaryOperator>(I))
+      return true;
+    const auto *Intrin = dyn_cast<typename Trait::Intrinsic>(I);
+    return Intrin && Intrin->isFunctionalUnaryOp();
+  }
+};
+
+// TODO Implement other extended types.
+
+/// Template-specialization for the Ext type hierarchy {
+/// Enable the ExtSOMETHING for your trait.
+#define INTRINSIC_TRAIT_SPECIALIZE(TRAIT, TYPE)                                \
+  template <> struct TraitCast<TYPE, TRAIT> {                                  \
+    using ExtType = Ext##TYPE<TRAIT>;                                          \
+  };                                                                           \
+  template <> struct TraitCast<const TYPE, TRAIT> {                            \
+    using ExtType = const Ext##TYPE<TRAIT>;                                    \
+  };
+
+/// } Trait Template Classes
+
+// Constrained fp trait.
+struct CFPTrait {
+  using Intrinsic = ConstrainedFPIntrinsic;
+  static constexpr bool AllowReassociation = false;
+
+  // Whether \p V should be considered at all with this trait.
+  // It is not possible to mix constrained and unconstrained ops.
+  // Only apply this trait with the constrained variant.
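+  // E.g. an 'llvm.experimental.constrained.fadd' with 'round.tonearest' and
+  // 'fpexcept.ignore' behaves exactly like a plain 'fadd' and is let through
+  // by MatcherContext<CFPTrait>::check() below.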
+  static bool consider(const Value *V) {
+    return isa<ConstrainedFPIntrinsic>(V);
+  }
+};
+INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, Instruction)
+INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, Operator)
+INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, BinaryOperator)
+INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, UnaryOperator)
+// Deflect queries for the Predicate to the ConstrainedFPCmpIntrinsic.
+template <> struct ExtFCmpInst<CFPTrait> : public ExtInstructionBase {
+  unsigned getOpcode() const { return Instruction::FCmp; }
+
+  FCmpInst::Predicate getPredicate() const {
+    return cast<ConstrainedFPCmpIntrinsic>(this)->getPredicate();
+  }
+
+  static bool classof(const Value *V) {
+    return isa<ConstrainedFPCmpIntrinsic>(V);
+  }
+};
+INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, FCmpInst)
+
+// Accept all constrained fp intrinsics that are actually not constrained.
+template <> struct MatcherContext<CFPTrait> {
+  bool accept(const Value *Val) { return check(Val); }
+  bool check(const Value *Val) const {
+    if (!Val)
+      return false;
+    const auto *CFP = dyn_cast<ConstrainedFPIntrinsic>(Val);
+    if (!CFP)
+      return true;
+    auto RoundingOpt = CFP->hasRoundingMode() ? CFP->getRoundingMode() : None;
+    auto ExceptOpt = CFP->getExceptionBehavior();
+    return (!ExceptOpt || ExceptOpt == fp::ExceptionBehavior::ebIgnore) &&
+           (!RoundingOpt || (RoundingOpt == RoundingMode::NearestTiesToEven));
+  }
+};
+
+// Vector-predicated trait.
+struct VPTrait {
+  using Intrinsic = VPIntrinsic;
+  // TODO Enable re-association.
+  static constexpr bool AllowReassociation = false;
+  // VP intrinsics mix with regular IR instructions.
+  // TODO: Adapt this to work with other than arithmetic VP ops.
+  static bool consider(const Value *V) {
+    return V->getType()->isVectorTy() &&
+           V->getType()->getScalarType()->isIntegerTy();
+  }
+};
+INTRINSIC_TRAIT_SPECIALIZE(VPTrait, Instruction)
+INTRINSIC_TRAIT_SPECIALIZE(VPTrait, Operator)
+INTRINSIC_TRAIT_SPECIALIZE(VPTrait, BinaryOperator)
+INTRINSIC_TRAIT_SPECIALIZE(VPTrait, UnaryOperator)
+
+// Accept everything that passes as a VPIntrinsic.
+template <> struct MatcherContext<VPTrait> {
+  // TODO: pick up %mask and %evl here and use them to generate code again. We
+  // only remove instructions for the moment.
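+  // A hedged sketch of what a mask/evl-aware context might collect
+  // (hypothetical members and helper; nothing below exists yet):
+  //
+  //   Value *CommonMask = nullptr, *CommonEVL = nullptr;
+  //   if (auto *VPI = dyn_cast<VPIntrinsic>(Val))
+  //     return updateOrMatch(CommonMask, VPI->getMaskParam()) &&
+  //            updateOrMatch(CommonEVL, VPI->getVectorLengthParam());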
+ bool accept(const Value *Val) { return Val; } + bool check(const Value *Val) const { return Val; } +}; + +} // namespace llvm + +#undef INTRINSIC_TRAIT_SPECIALIZE + +#endif // LLVM_IR_TRAITS_TRAITS_H diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -37,6 +37,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Traits/Traits.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/KnownBits.h" #include @@ -47,25 +48,32 @@ enum { RecursionLimit = 3 }; -STATISTIC(NumExpand, "Number of expansions"); +STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumReassoc, "Number of reassociations"); -static Value *SimplifyAndInst(Value *, Value *, const SimplifyQuery &, unsigned); +static Value *SimplifyAndInst(Value *, Value *, const SimplifyQuery &, + unsigned); static Value *simplifyUnOp(unsigned, Value *, const SimplifyQuery &, unsigned); static Value *simplifyFPUnOp(unsigned, Value *, const FastMathFlags &, const SimplifyQuery &, unsigned); static Value *SimplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); +template +static Value *SimplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, + MatcherContext &, unsigned); +template static Value *SimplifyBinOp(unsigned, Value *, Value *, const FastMathFlags &, - const SimplifyQuery &, unsigned); + const SimplifyQuery &, MatcherContext &, + unsigned); static Value *SimplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse); static Value *SimplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned); -static Value *SimplifyXorInst(Value *, Value *, const SimplifyQuery &, unsigned); -static Value *SimplifyCastInst(unsigned, Value *, Type *, - const SimplifyQuery &, unsigned); +static Value *SimplifyXorInst(Value *, Value *, const SimplifyQuery &, + unsigned); +static Value *SimplifyCastInst(unsigned, Value *, Type *, const SimplifyQuery &, + unsigned); static Value *SimplifyGEPInst(Type *, ArrayRef, const SimplifyQuery &, unsigned); @@ -113,15 +121,11 @@ /// For a boolean type or a vector of boolean type, return false or a vector /// with every element false. -static Constant *getFalse(Type *Ty) { - return ConstantInt::getFalse(Ty); -} +static Constant *getFalse(Type *Ty) { return ConstantInt::getFalse(Ty); } /// For a boolean type or a vector of boolean type, return true or a vector /// with every element true. -static Constant *getTrue(Type *Ty) { - return ConstantInt::getTrue(Ty); -} +static Constant *getTrue(Type *Ty) { return ConstantInt::getTrue(Ty); } /// isSameCompare - Is V equivalent to the comparison "LHS Pred RHS"? 
static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS, @@ -134,7 +138,7 @@ if (CPred == Pred && CLHS == LHS && CRHS == RHS) return true; return CPred == CmpInst::getSwappedPredicate(Pred) && CLHS == RHS && - CRHS == LHS; + CRHS == LHS; } /// Simplify comparison with true or false branch of select: @@ -238,12 +242,12 @@ if (!B || B->getOpcode() != OpcodeToExpand) return nullptr; Value *B0 = B->getOperand(0), *B1 = B->getOperand(1); - Value *L = SimplifyBinOp(Opcode, B0, OtherOp, Q.getWithoutUndef(), - MaxRecurse); + Value *L = + SimplifyBinOp(Opcode, B0, OtherOp, Q.getWithoutUndef(), MaxRecurse); if (!L) return nullptr; - Value *R = SimplifyBinOp(Opcode, B1, OtherOp, Q.getWithoutUndef(), - MaxRecurse); + Value *R = + SimplifyBinOp(Opcode, B1, OtherOp, Q.getWithoutUndef(), MaxRecurse); if (!R) return nullptr; @@ -265,8 +269,8 @@ /// Try to simplify binops of form "A op (B op' C)" or the commuted variant by /// distributing op over op'. -static Value *expandCommutativeBinOp(Instruction::BinaryOps Opcode, - Value *L, Value *R, +static Value *expandCommutativeBinOp(Instruction::BinaryOps Opcode, Value *L, + Value *R, Instruction::BinaryOps OpcodeToExpand, const SimplifyQuery &Q, unsigned MaxRecurse) { @@ -283,32 +287,47 @@ /// Generic simplifications for associative binary operations. /// Returns the simpler value, or null if none was found. -static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, - Value *LHS, Value *RHS, - const SimplifyQuery &Q, - unsigned MaxRecurse) { +template +static Value * +SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, + const SimplifyQuery &Q, MatcherContext &Matcher, + unsigned MaxRecurse) { + // This trait blocks re-association. + // Eg. any trait that adds side-effects may clash with free reassociation + // (FIXME are 'fpexcept.strict', 'fast' fp ops a thing?) + // FIXME associativity may depend on a trait parameter of this specific + // instance. + if (!Trait::AllowReassociation) + return nullptr; + assert(Instruction::isAssociative(Opcode) && "Not an associative operation!"); // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; - BinaryOperator *Op0 = dyn_cast(LHS); - BinaryOperator *Op1 = dyn_cast(RHS); + auto *Op0 = trait_dyn_cast(LHS); + auto *Op1 = trait_dyn_cast(RHS); // Transform: "(A op B) op C" ==> "A op (B op C)" if it simplifies completely. - if (Op0 && Op0->getOpcode() == Opcode) { + MatcherContext Op0Matcher(Matcher); + if (Op0Matcher.accept(Op0) && Op0->getOpcode() == Opcode) { Value *A = Op0->getOperand(0); Value *B = Op0->getOperand(1); Value *C = RHS; // Does "B op C" simplify? - if (Value *V = SimplifyBinOp(Opcode, B, C, Q, MaxRecurse)) { + if (Value *V = + SimplifyBinOp(Opcode, B, C, Q, Op0Matcher, MaxRecurse)) { // It does! Return "A op V" if it simplifies or is already available. // If V equals B then "A op V" is just the LHS. - if (V == B) return LHS; + if (V == B) { + Matcher = Op0Matcher; + return LHS; + } // Otherwise return "A op V" if it simplifies. - if (Value *W = SimplifyBinOp(Opcode, A, V, Q, MaxRecurse)) { + if (Value *W = SimplifyBinOp(Opcode, A, V, Q, Op0Matcher, MaxRecurse)) { + Matcher = Op0Matcher; ++NumReassoc; return W; } @@ -316,18 +335,23 @@ } // Transform: "A op (B op C)" ==> "(A op B) op C" if it simplifies completely. 
- if (Op1 && Op1->getOpcode() == Opcode) { + MatcherContext Op1Matcher(Matcher); + if (Op1Matcher.accept(Op1) && Op1->getOpcode() == Opcode) { Value *A = LHS; Value *B = Op1->getOperand(0); Value *C = Op1->getOperand(1); // Does "A op B" simplify? - if (Value *V = SimplifyBinOp(Opcode, A, B, Q, MaxRecurse)) { + if (Value *V = SimplifyBinOp(Opcode, A, B, Q, Op1Matcher, MaxRecurse)) { // It does! Return "V op C" if it simplifies or is already available. // If V equals B then "V op C" is just the RHS. - if (V == B) return RHS; + if (V == B) { + Matcher = Op1Matcher; + return RHS; + } // Otherwise return "V op C" if it simplifies. - if (Value *W = SimplifyBinOp(Opcode, V, C, Q, MaxRecurse)) { + if (Value *W = SimplifyBinOp(Opcode, V, C, Q, Op1Matcher, MaxRecurse)) { + Matcher = Op1Matcher; ++NumReassoc; return W; } @@ -335,22 +359,30 @@ } // The remaining transforms require commutativity as well as associativity. + // FIXME commutativity may depend on a trait parameter of this specific + // instance. Eg, matrix multiplication is associative but not commutative. if (!Instruction::isCommutative(Opcode)) return nullptr; // Transform: "(A op B) op C" ==> "(C op A) op B" if it simplifies completely. + MatcherContext CommOp0Matcher(Matcher); if (Op0 && Op0->getOpcode() == Opcode) { Value *A = Op0->getOperand(0); Value *B = Op0->getOperand(1); Value *C = RHS; // Does "C op A" simplify? - if (Value *V = SimplifyBinOp(Opcode, C, A, Q, MaxRecurse)) { + if (Value *V = SimplifyBinOp(Opcode, C, A, Q, CommOp0Matcher, MaxRecurse)) { // It does! Return "V op B" if it simplifies or is already available. // If V equals A then "V op B" is just the LHS. - if (V == A) return LHS; + if (V == A) { + Matcher = CommOp0Matcher; + return LHS; + } // Otherwise return "V op B" if it simplifies. - if (Value *W = SimplifyBinOp(Opcode, V, B, Q, MaxRecurse)) { + MatcherContext VContext(Matcher); + if (Value *W = SimplifyBinOp(Opcode, V, B, Q, VContext, MaxRecurse)) { + Matcher = VContext; ++NumReassoc; return W; } @@ -364,12 +396,14 @@ Value *C = Op1->getOperand(1); // Does "C op A" simplify? - if (Value *V = SimplifyBinOp(Opcode, C, A, Q, MaxRecurse)) { + if (Value *V = SimplifyBinOp(Opcode, C, A, Q, Matcher, MaxRecurse)) { // It does! Return "B op V" if it simplifies or is already available. // If V equals C then "B op V" is just the RHS. - if (V == C) return RHS; + if (V == C) { + return RHS; + } // Otherwise return "B op V" if it simplifies. - if (Value *W = SimplifyBinOp(Opcode, B, V, Q, MaxRecurse)) { + if (Value *W = SimplifyBinOp(Opcode, B, V, Q, Matcher, MaxRecurse)) { ++NumReassoc; return W; } @@ -379,6 +413,14 @@ return nullptr; } +static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, + Value *LHS, Value *RHS, + const SimplifyQuery &Q, + unsigned MaxRecurse) { + MatcherContext MContext; + return SimplifyAssociativeBinOp(Opcode, LHS, RHS, Q, MContext, MaxRecurse); +} + /// In the case of a binary operation with a select instruction as an operand, /// try to simplify the binop by seeing whether evaluating it on both branches /// of the select results in the same value. Returns the common value if so, @@ -532,10 +574,10 @@ Value *CommonValue = nullptr; for (Value *Incoming : PI->incoming_values()) { // If the incoming value is the phi node itself, it can safely be skipped. - if (Incoming == PI) continue; - Value *V = PI == LHS ? - SimplifyBinOp(Opcode, Incoming, RHS, Q, MaxRecurse) : - SimplifyBinOp(Opcode, LHS, Incoming, Q, MaxRecurse); + if (Incoming == PI) + continue; + Value *V = PI == LHS ? 
SimplifyBinOp(Opcode, Incoming, RHS, Q, MaxRecurse) + : SimplifyBinOp(Opcode, LHS, Incoming, Q, MaxRecurse); // If the operation failed to simplify, or simplified to a different value // to previously, then give up. if (!V || (CommonValue && V != CommonValue)) @@ -574,7 +616,8 @@ Value *Incoming = PI->getIncomingValue(u); Instruction *InTI = PI->getIncomingBlock(u)->getTerminator(); // If the incoming value is the phi node itself, it can safely be skipped. - if (Incoming == PI) continue; + if (Incoming == PI) + continue; // Change the context instruction to the "edge" that flows into the phi. // This is important because that is where incoming is actually "evaluated" // even though it is used later somewhere else. @@ -605,9 +648,12 @@ } /// Given operands for an Add, see if we can fold the result. /// If not, this returns null. +template static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, - const SimplifyQuery &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, + MatcherContext &Matcher, + unsigned MaxRecurse) { if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q)) return C; @@ -616,7 +662,7 @@ return Op1; // X + 0 -> X - if (match(Op1, m_Zero())) + if (try_match(Op1, m_Zero(), Matcher)) return Op0; // If two operands are negative, return 0. @@ -627,35 +673,38 @@ // (Y - X) + X -> Y // Eg: X + -X -> 0 Value *Y = nullptr; - if (match(Op1, m_Sub(m_Value(Y), m_Specific(Op0))) || - match(Op0, m_Sub(m_Value(Y), m_Specific(Op1)))) + if (try_match(Op1, m_Sub(m_Value(Y), m_Specific(Op0)), Matcher) || + try_match(Op0, m_Sub(m_Value(Y), m_Specific(Op1)), Matcher)) return Y; // X + ~X -> -1 since ~X = -X-1 Type *Ty = Op0->getType(); - if (match(Op0, m_Not(m_Specific(Op1))) || - match(Op1, m_Not(m_Specific(Op0)))) + if (try_match(Op0, m_Not(m_Specific(Op1)), Matcher) || + try_match(Op1, m_Not(m_Specific(Op0)), Matcher)) return Constant::getAllOnesValue(Ty); // add nsw/nuw (xor Y, signmask), signmask --> Y // The no-wrapping add guarantees that the top bit will be set by the add. // Therefore, the xor must be clearing the already set sign bit of Y. - if ((IsNSW || IsNUW) && match(Op1, m_SignMask()) && - match(Op0, m_Xor(m_Value(Y), m_SignMask()))) + MatcherContext CopyMatch(Matcher); + if ((IsNSW || IsNUW) && match(Op1, m_SignMask(), CopyMatch) && + match(Op0, m_Xor(m_Value(Y), m_SignMask()), CopyMatch)) { + Matcher = CopyMatch; return Y; + } // add nuw %x, -1 -> -1, because %x can only be 0. - if (IsNUW && match(Op1, m_AllOnes())) + if (IsNUW && try_match(Op1, m_AllOnes(), Matcher)) return Op1; // Which is -1. /// i1 add -> xor. if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) - if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) + if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse - 1)) return V; // Try some generic simplifications for associative operations. if (Value *V = SimplifyAssociativeBinOp(Instruction::Add, Op0, Op1, Q, - MaxRecurse)) + Matcher, MaxRecurse)) return V; // Threading Add over selects and phi nodes is pointless, so don't bother. @@ -670,9 +719,12 @@ return nullptr; } +template Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, - const SimplifyQuery &Query) { - return ::SimplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, RecursionLimit); + const SimplifyQuery &Query, + MatcherContext &Matcher) { + return ::SimplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, Matcher, + RecursionLimit); } /// Compute the base pointer and cumulative constant offsets for V.
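The SimplifyAddInst hunk above uses one idiom in two guises: either an explicit MatcherContext copy (CopyMatch) that is assigned back to Matcher once the whole multi-pattern match has succeeded, or the try_match helper, which packages the same speculate-then-commit step for a single pattern. A minimal sketch of the assumed try_match semantics follows; the authoritative definition lives in PatternMatch.h, and the name try_match_sketch is used here to stress that this is illustrative, not the patch's literal code:

// Sketch: match against a scratch copy of the context and commit it only on
// success, so a failed sub-match cannot leave stale trait state behind.
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Traits/SemanticTrait.h"

template <typename Trait, typename Val, typename Pattern>
bool try_match_sketch(Val *V, const Pattern &P,
                      llvm::MatcherContext<Trait> &MContext) {
  llvm::MatcherContext<Trait> Speculative(MContext); // speculate on a copy
  if (!llvm::PatternMatch::match(V, P, Speculative))
    return false;         // caller's context is untouched on failure
  MContext = Speculative; // commit the state gathered during the match
  return true;
}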
@@ -766,17 +818,17 @@ Value *X = nullptr, *Y = nullptr, *Z = Op1; if (MaxRecurse && match(Op0, m_Add(m_Value(X), m_Value(Y)))) { // (X + Y) - Z // See if "V === Y - Z" simplifies. - if (Value *V = SimplifyBinOp(Instruction::Sub, Y, Z, Q, MaxRecurse-1)) + if (Value *V = SimplifyBinOp(Instruction::Sub, Y, Z, Q, MaxRecurse - 1)) // It does! Now see if "X + V" simplifies. - if (Value *W = SimplifyBinOp(Instruction::Add, X, V, Q, MaxRecurse-1)) { + if (Value *W = SimplifyBinOp(Instruction::Add, X, V, Q, MaxRecurse - 1)) { // It does, we successfully reassociated! ++NumReassoc; return W; } // See if "V === X - Z" simplifies. - if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse-1)) + if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse - 1)) // It does! Now see if "Y + V" simplifies. - if (Value *W = SimplifyBinOp(Instruction::Add, Y, V, Q, MaxRecurse-1)) { + if (Value *W = SimplifyBinOp(Instruction::Add, Y, V, Q, MaxRecurse - 1)) { // It does, we successfully reassociated! ++NumReassoc; return W; @@ -788,17 +840,17 @@ X = Op0; if (MaxRecurse && match(Op1, m_Add(m_Value(Y), m_Value(Z)))) { // X - (Y + Z) // See if "V === X - Y" simplifies. - if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse-1)) + if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse - 1)) // It does! Now see if "V - Z" simplifies. - if (Value *W = SimplifyBinOp(Instruction::Sub, V, Z, Q, MaxRecurse-1)) { + if (Value *W = SimplifyBinOp(Instruction::Sub, V, Z, Q, MaxRecurse - 1)) { // It does, we successfully reassociated! ++NumReassoc; return W; } // See if "V === X - Z" simplifies. - if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse-1)) + if (Value *V = SimplifyBinOp(Instruction::Sub, X, Z, Q, MaxRecurse - 1)) // It does! Now see if "V - Y" simplifies. - if (Value *W = SimplifyBinOp(Instruction::Sub, V, Y, Q, MaxRecurse-1)) { + if (Value *W = SimplifyBinOp(Instruction::Sub, V, Y, Q, MaxRecurse - 1)) { // It does, we successfully reassociated! ++NumReassoc; return W; @@ -810,9 +862,9 @@ Z = Op0; if (MaxRecurse && match(Op1, m_Sub(m_Value(X), m_Value(Y)))) // Z - (X - Y) // See if "V === Z - X" simplifies. - if (Value *V = SimplifyBinOp(Instruction::Sub, Z, X, Q, MaxRecurse-1)) + if (Value *V = SimplifyBinOp(Instruction::Sub, Z, X, Q, MaxRecurse - 1)) // It does! Now see if "V + Y" simplifies. - if (Value *W = SimplifyBinOp(Instruction::Add, V, Y, Q, MaxRecurse-1)) { + if (Value *W = SimplifyBinOp(Instruction::Add, V, Y, Q, MaxRecurse - 1)) { // It does, we successfully reassociated! ++NumReassoc; return W; @@ -823,7 +875,7 @@ match(Op1, m_Trunc(m_Value(Y)))) if (X->getType() == Y->getType()) // See if "V === X - Y" simplifies. - if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse-1)) + if (Value *V = SimplifyBinOp(Instruction::Sub, X, Y, Q, MaxRecurse - 1)) // It does! Now see if "trunc V" simplifies. if (Value *W = SimplifyCastInst(Instruction::Trunc, V, Op0->getType(), Q, MaxRecurse - 1)) @@ -831,14 +883,13 @@ return W; // Variations on GEP(base, I, ...) - GEP(base, i, ...) -> GEP(null, I-i, ...). - if (match(Op0, m_PtrToInt(m_Value(X))) && - match(Op1, m_PtrToInt(m_Value(Y)))) + if (match(Op0, m_PtrToInt(m_Value(X))) && match(Op1, m_PtrToInt(m_Value(Y)))) if (Constant *Result = computePointerDifference(Q.DL, X, Y)) return ConstantExpr::getIntegerCast(Result, Op0->getType(), true); // i1 sub -> xor. 
if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) - if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) + if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse - 1)) return V; // Threading Sub over selects and phi nodes is pointless, so don't bother. @@ -884,12 +935,12 @@ // i1 mul -> and. if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) - if (Value *V = SimplifyAndInst(Op0, Op1, Q, MaxRecurse-1)) + if (Value *V = SimplifyAndInst(Op0, Op1, Q, MaxRecurse - 1)) return V; // Try some generic simplifications for associative operations. - if (Value *V = SimplifyAssociativeBinOp(Instruction::Mul, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + SimplifyAssociativeBinOp(Instruction::Mul, Op0, Op1, Q, MaxRecurse)) return V; // Mul distributes over Add. Try some generic simplifications based on this. @@ -900,15 +951,15 @@ // If the operation is with the result of a select instruction, check whether // operating on either branch of the select always yields the same value. if (isa(Op0) || isa(Op1)) - if (Value *V = ThreadBinOpOverSelect(Instruction::Mul, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + ThreadBinOpOverSelect(Instruction::Mul, Op0, Op1, Q, MaxRecurse)) return V; // If the operation is with the result of a phi instruction, check whether // operating on all incoming values of the phi always yields the same value. if (isa(Op0) || isa(Op1)) - if (Value *V = ThreadBinOpOverPHI(Instruction::Mul, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + ThreadBinOpOverPHI(Instruction::Mul, Op0, Op1, Q, MaxRecurse)) return V; return nullptr; @@ -1229,7 +1280,8 @@ /// Given operands for an Shl, LShr or AShr, see if we can fold the result. /// If not, this returns null. static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, - Value *Op1, const SimplifyQuery &Q, unsigned MaxRecurse) { + Value *Op1, const SimplifyQuery &Q, + unsigned MaxRecurse) { if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) return C; @@ -1279,8 +1331,8 @@ /// Given operands for an Shl, LShr or AShr, see if we can /// fold the result. If not, this returns null. static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, - Value *Op1, bool isExact, const SimplifyQuery &Q, - unsigned MaxRecurse) { + Value *Op1, bool isExact, + const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse)) return V; @@ -1295,7 +1347,8 @@ // The low bit cannot be shifted out of an exact shift if it is set. if (isExact) { - KnownBits Op0Known = computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + KnownBits Op0Known = + computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); if (Op0Known.One[0]) return Op0; } @@ -1341,7 +1394,7 @@ const SimplifyQuery &Q, unsigned MaxRecurse) { if (Value *V = SimplifyRightShift(Instruction::LShr, Op0, Op1, isExact, Q, MaxRecurse)) - return V; + return V; // (X << A) >> A -> X Value *X; @@ -1518,7 +1571,7 @@ /// with the parameters swapped. static Value *simplifyAndOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { ICmpInst::Predicate Pred0, Pred1; - Value *A ,*B; + Value *A, *B; if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) return nullptr; @@ -1543,7 +1596,7 @@ /// with the parameters swapped. 
static Value *simplifyOrOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) { ICmpInst::Predicate Pred0, Pred1; - Value *A ,*B; + Value *A, *B; if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) || !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B)))) return nullptr; @@ -1854,8 +1907,8 @@ return nullptr; } -static Value *simplifyAndOrOfFCmps(const TargetLibraryInfo *TLI, - FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) { +static Value *simplifyAndOrOfFCmps(const TargetLibraryInfo *TLI, FCmpInst *LHS, + FCmpInst *RHS, bool IsAnd) { Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); if (LHS0->getType() != RHS0->getType()) @@ -1892,8 +1945,8 @@ return nullptr; } -static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q, - Value *Op0, Value *Op1, bool IsAnd) { +static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q, Value *Op0, + Value *Op1, bool IsAnd) { // Look through casts of the 'and' operands to find compares. auto *Cast0 = dyn_cast(Op0); auto *Cast1 = dyn_cast(Op1); @@ -2023,8 +2076,7 @@ return Op0; // A & ~A = ~A & A = 0 - if (match(Op0, m_Not(m_Specific(Op1))) || - match(Op1, m_Not(m_Specific(Op0)))) + if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0)))) return Constant::getNullValue(Op0->getType()); // (A | ?) & A = A @@ -2086,8 +2138,8 @@ return V; // Try some generic simplifications for associative operations. - if (Value *V = SimplifyAssociativeBinOp(Instruction::And, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + SimplifyAssociativeBinOp(Instruction::And, Op0, Op1, Q, MaxRecurse)) return V; // And distributes over Or. Try some generic simplifications based on this. @@ -2103,15 +2155,15 @@ // If the operation is with the result of a select instruction, check whether // operating on either branch of the select always yields the same value. if (isa(Op0) || isa(Op1)) - if (Value *V = ThreadBinOpOverSelect(Instruction::And, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + ThreadBinOpOverSelect(Instruction::And, Op0, Op1, Q, MaxRecurse)) return V; // If the operation is with the result of a phi instruction, check whether // operating on all incoming values of the phi always yields the same value. if (isa(Op0) || isa(Op1)) - if (Value *V = ThreadBinOpOverPHI(Instruction::And, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + ThreadBinOpOverPHI(Instruction::And, Op0, Op1, Q, MaxRecurse)) return V; // Assuming the effective width of Y is not larger than A, i.e. all bits @@ -2134,8 +2186,7 @@ const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); const unsigned EffWidthY = Width - YKnown.countMinLeadingZeros(); if (EffWidthY <= ShftCnt) { - const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, - Q.DT); + const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); const unsigned EffWidthX = Width - XKnown.countMinLeadingZeros(); const APInt EffBitsY = APInt::getLowBitsSet(Width, EffWidthY); const APInt EffBitsX = APInt::getLowBitsSet(Width, EffWidthX) << ShftCnt; @@ -2174,8 +2225,7 @@ return Op0; // A | ~A = ~A | A = -1 - if (match(Op0, m_Not(m_Specific(Op1))) || - match(Op1, m_Not(m_Specific(Op0)))) + if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0)))) return Constant::getAllOnesValue(Op0->getType()); // (A & ?) | A = A @@ -2244,8 +2294,8 @@ return V; // Try some generic simplifications for associative operations. 
- if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q, MaxRecurse)) return V; // Or distributes over And. Try some generic simplifications based on this. @@ -2256,8 +2306,8 @@ // If the operation is with the result of a select instruction, check whether // operating on either branch of the select always yields the same value. if (isa(Op0) || isa(Op1)) - if (Value *V = ThreadBinOpOverSelect(Instruction::Or, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + ThreadBinOpOverSelect(Instruction::Or, Op0, Op1, Q, MaxRecurse)) return V; // (A & C1)|(B & C2) @@ -2277,8 +2327,7 @@ return A; } // Or commutes, try both ways. - if (C1->isMask() && - match(B, m_c_Add(m_Specific(A), m_Value(N)))) { + if (C1->isMask() && match(B, m_c_Add(m_Specific(A), m_Value(N)))) { // Add commutes, try both ways. if (MaskedValueIsZero(N, *C1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return B; @@ -2319,13 +2368,12 @@ return Constant::getNullValue(Op0->getType()); // A ^ ~A = ~A ^ A = -1 - if (match(Op0, m_Not(m_Specific(Op1))) || - match(Op1, m_Not(m_Specific(Op0)))) + if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0)))) return Constant::getAllOnesValue(Op0->getType()); // Try some generic simplifications for associative operations. - if (Value *V = SimplifyAssociativeBinOp(Instruction::Xor, Op0, Op1, Q, - MaxRecurse)) + if (Value *V = + SimplifyAssociativeBinOp(Instruction::Xor, Op0, Op1, Q, MaxRecurse)) return V; // Threading Xor over selects and phi nodes is pointless, so don't bother. @@ -2344,7 +2392,6 @@ return ::SimplifyXorInst(Op0, Op1, Q, RecursionLimit); } - static Type *GetCompareTy(Value *Op) { return CmpInst::makeCmpResultType(Op->getType()); } @@ -2410,8 +2457,7 @@ if (isa(RHS) && ICmpInst::isEquality(Pred) && llvm::isKnownNonZero(LHS, DL, 0, nullptr, nullptr, nullptr, IIQ.UseInstrInfo)) - return ConstantInt::get(GetCompareTy(LHS), - !CmpInst::isTrueWhenEqual(Pred)); + return ConstantInt::get(GetCompareTy(LHS), !CmpInst::isTrueWhenEqual(Pred)); // We can only fold certain predicates on pointer comparisons. switch (Pred) { @@ -2493,10 +2539,8 @@ getObjectSize(RHS, RHSSize, DL, TLI, Opts)) { const APInt &LHSOffsetValue = LHSOffsetCI->getValue(); const APInt &RHSOffsetValue = RHSOffsetCI->getValue(); - if (!LHSOffsetValue.isNegative() && - !RHSOffsetValue.isNegative() && - LHSOffsetValue.ult(LHSSize) && - RHSOffsetValue.ult(RHSSize)) { + if (!LHSOffsetValue.isNegative() && !RHSOffsetValue.isNegative() && + LHSOffsetValue.ult(LHSSize) && RHSOffsetValue.ult(RHSSize)) { return ConstantInt::get(GetCompareTy(LHS), !CmpInst::isTrueWhenEqual(Pred)); } @@ -2506,8 +2550,7 @@ // or being able to compute a precise size. if (!cast(LHS->getType())->isEmptyTy() && !cast(RHS->getType())->isEmptyTy() && - LHSOffset->isNullValue() && - RHSOffset->isNullValue()) + LHSOffset->isNullValue() && RHSOffset->isNullValue()) return ConstantInt::get(GetCompareTy(LHS), !CmpInst::isTrueWhenEqual(Pred)); } @@ -2559,8 +2602,8 @@ if ((IsNAC(LHSUObjs) && IsAllocDisjoint(RHSUObjs)) || (IsNAC(RHSUObjs) && IsAllocDisjoint(LHSUObjs))) - return ConstantInt::get(GetCompareTy(LHS), - !CmpInst::isTrueWhenEqual(Pred)); + return ConstantInt::get(GetCompareTy(LHS), + !CmpInst::isTrueWhenEqual(Pred)); // Fold comparisons for non-escaping pointer even if the allocation call // cannot be elided. We cannot fold malloc comparison to null. 
Also, the @@ -2609,7 +2652,8 @@ case CmpInst::ICMP_SLE: // X <=s 0 -> true return getTrue(ITy); - default: break; + default: + break; } } else if (match(RHS, m_One())) { switch (Pred) { @@ -2626,7 +2670,8 @@ case CmpInst::ICMP_SGE: // X >=s -1 -> true return getTrue(ITy); - default: break; + default: + break; } } @@ -2768,9 +2813,10 @@ return nullptr; } -static Value *simplifyICmpWithBinOpOnLHS( - CmpInst::Predicate Pred, BinaryOperator *LBO, Value *RHS, - const SimplifyQuery &Q, unsigned MaxRecurse) { +static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, + BinaryOperator *LBO, Value *RHS, + const SimplifyQuery &Q, + unsigned MaxRecurse) { Type *ITy = GetCompareTy(RHS); // The return type. Value *Y = nullptr; @@ -2851,7 +2897,6 @@ return nullptr; } - // If only one of the icmp's operands has NSW flags, try to prove that: // // icmp slt (x + C1), (x +nsw C2) @@ -2886,7 +2931,6 @@ (C2->slt(*C1) && C1->isNonPositive()); } - /// TODO: A large part of this logic is duplicated in InstCombine's /// foldICmpBinOp(). We should be able to share that and avoid the code /// duplication. @@ -3038,7 +3082,7 @@ break; if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), RBO->getOperand(0), Q, MaxRecurse - 1)) - return V; + return V; break; case Instruction::SDiv: if (!ICmpInst::isEquality(Pred) || !Q.IIQ.isExact(LBO) || @@ -3272,9 +3316,8 @@ continue; CallInst *Assume = cast(AssumeVH); - if (Optional Imp = - isImpliedCondition(Assume->getArgOperand(0), Predicate, LHS, RHS, - Q.DL)) + if (Optional Imp = isImpliedCondition(Assume->getArgOperand(0), + Predicate, LHS, RHS, Q.DL)) if (isValidAssumeForContext(Assume, Q.CxtI, Q.DT)) return ConstantInt::get(GetCompareTy(LHS), *Imp); } @@ -3340,7 +3383,7 @@ return ConstantInt::getTrue(RHS->getContext()); auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion( - CmpInst::getInversePredicate(Pred), RHS_CR); + CmpInst::getInversePredicate(Pred), RHS_CR); if (InversedSatisfied_CR.contains(LHS_CR)) return ConstantInt::getFalse(RHS->getContext()); } @@ -3361,13 +3404,13 @@ // Transfer the cast to the constant. if (Value *V = SimplifyICmpInst(Pred, SrcOp, ConstantExpr::getIntToPtr(RHSC, SrcTy), - Q, MaxRecurse-1)) + Q, MaxRecurse - 1)) return V; } else if (PtrToIntInst *RI = dyn_cast(RHS)) { if (RI->getOperand(0)->getType() == SrcTy) // Compare without the cast. - if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), - Q, MaxRecurse-1)) + if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), Q, + MaxRecurse - 1)) return V; } } @@ -3378,9 +3421,9 @@ if (ZExtInst *RI = dyn_cast(RHS)) { if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) // Compare X and Y. Note that signed predicates become unsigned. - if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), - SrcOp, RI->getOperand(0), Q, - MaxRecurse-1)) + if (Value *V = + SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), SrcOp, + RI->getOperand(0), Q, MaxRecurse - 1)) return V; } // Fold (zext X) ule (sext X), (zext X) sge (sext X) to true. @@ -3404,14 +3447,15 @@ // also a case of comparing two zero-extended values. if (RExt == CI && MaxRecurse) if (Value *V = SimplifyICmpInst(ICmpInst::getUnsignedPredicate(Pred), - SrcOp, Trunc, Q, MaxRecurse-1)) + SrcOp, Trunc, Q, MaxRecurse - 1)) return V; // Otherwise the upper bits of LHS are zero while RHS has a non-zero bit // there. Use this to work out the result of the comparison. 
if (RExt != CI) { switch (Pred) { - default: llvm_unreachable("Unknown ICmp predicate!"); + default: + llvm_unreachable("Unknown ICmp predicate!"); // LHS <u RHS. case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: return ConstantInt::getFalse(CI->getContext()); case ICmpInst::ICMP_NE: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: return ConstantInt::getTrue(CI->getContext()); // LHS is non-negative. If RHS is negative then LHS >s RHS. If RHS is non-negative then LHS <s RHS. case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - return CI->getValue().isNegative() ? - ConstantInt::getTrue(CI->getContext()) : - ConstantInt::getFalse(CI->getContext()); + return CI->getValue().isNegative() + ? ConstantInt::getTrue(CI->getContext()) + : ConstantInt::getFalse(CI->getContext()); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - return CI->getValue().isNegative() ? - ConstantInt::getFalse(CI->getContext()) : - ConstantInt::getTrue(CI->getContext()); + return CI->getValue().isNegative() + ? ConstantInt::getFalse(CI->getContext()) + : ConstantInt::getTrue(CI->getContext()); } } @@ -3447,8 +3491,8 @@ if (SExtInst *RI = dyn_cast(RHS)) { if (MaxRecurse && SrcTy == RI->getOperand(0)->getType()) // Compare X and Y. Note that the predicate does not change. - if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), - Q, MaxRecurse-1)) + if (Value *V = SimplifyICmpInst(Pred, SrcOp, RI->getOperand(0), Q, + MaxRecurse - 1)) return V; } // Fold (sext X) uge (zext X), (sext X) sle (zext X) to true. @@ -3471,14 +3515,16 @@ // If the re-extended constant didn't change then this is effectively // also a case of comparing two sign-extended values. if (RExt == CI && MaxRecurse) - if (Value *V = SimplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse-1)) + if (Value *V = + SimplifyICmpInst(Pred, SrcOp, Trunc, Q, MaxRecurse - 1)) return V; // Otherwise the upper bits of LHS are all equal, while RHS has varying // bits there. Use this to work out the result of the comparison. if (RExt != CI) { switch (Pred) { - default: llvm_unreachable("Unknown ICmp predicate!"); + default: + llvm_unreachable("Unknown ICmp predicate!"); case ICmpInst::ICMP_EQ: return ConstantInt::getFalse(CI->getContext()); case ICmpInst::ICMP_NE: @@ -3488,14 +3534,14 @@ // LHS >s RHS. case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - return CI->getValue().isNegative() ? - ConstantInt::getTrue(CI->getContext()) : - ConstantInt::getFalse(CI->getContext()); + return CI->getValue().isNegative() + ? ConstantInt::getTrue(CI->getContext()) + : ConstantInt::getFalse(CI->getContext()); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - return CI->getValue().isNegative() ? - ConstantInt::getFalse(CI->getContext()) : - ConstantInt::getTrue(CI->getContext()); + return CI->getValue().isNegative() + ? ConstantInt::getFalse(CI->getContext()) + : ConstantInt::getTrue(CI->getContext()); // If LHS is non-negative then LHS <u RHS. If LHS is negative then LHS >u RHS. @@ -3504,8 +3550,8 @@ // Comparison is true iff the LHS >=s 0. if (MaxRecurse) if (Value *V = SimplifyICmpInst(ICmpInst::ICMP_SGE, SrcOp, - Constant::getNullValue(SrcTy), - Q, MaxRecurse-1)) + Constant::getNullValue(SrcTy), Q, + MaxRecurse - 1)) return V; break; } @@ -3732,23 +3778,29 @@ // The ordered relationship and minnum/maxnum guarantee that we do not // have NaN constants, so ordered/unordered preds are handled the same.
switch (Pred) { - case FCmpInst::FCMP_OEQ: case FCmpInst::FCMP_UEQ: + case FCmpInst::FCMP_OEQ: + case FCmpInst::FCMP_UEQ: // minnum(X, LesserC) == C --> false // maxnum(X, GreaterC) == C --> false return getFalse(RetTy); - case FCmpInst::FCMP_ONE: case FCmpInst::FCMP_UNE: + case FCmpInst::FCMP_ONE: + case FCmpInst::FCMP_UNE: // minnum(X, LesserC) != C --> true // maxnum(X, GreaterC) != C --> true return getTrue(RetTy); - case FCmpInst::FCMP_OGE: case FCmpInst::FCMP_UGE: - case FCmpInst::FCMP_OGT: case FCmpInst::FCMP_UGT: + case FCmpInst::FCMP_OGE: + case FCmpInst::FCMP_UGE: + case FCmpInst::FCMP_OGT: + case FCmpInst::FCMP_UGT: // minnum(X, LesserC) >= C --> false // minnum(X, LesserC) > C --> false // maxnum(X, GreaterC) >= C --> true // maxnum(X, GreaterC) > C --> true return ConstantInt::get(RetTy, IsMaxNum); - case FCmpInst::FCMP_OLE: case FCmpInst::FCMP_ULE: - case FCmpInst::FCMP_OLT: case FCmpInst::FCMP_ULT: + case FCmpInst::FCMP_OLE: + case FCmpInst::FCMP_ULE: + case FCmpInst::FCMP_OLT: + case FCmpInst::FCMP_ULT: // minnum(X, LesserC) <= C --> true // minnum(X, LesserC) < C --> true // maxnum(X, GreaterC) <= C --> false @@ -3847,13 +3899,11 @@ if (auto *B = dyn_cast(I)) { if (MaxRecurse) { if (B->getOperand(0) == Op) - return PreventSelfSimplify(SimplifyBinOp(B->getOpcode(), RepOp, - B->getOperand(1), Q, - MaxRecurse - 1)); + return PreventSelfSimplify(SimplifyBinOp( + B->getOpcode(), RepOp, B->getOperand(1), Q, MaxRecurse - 1)); if (B->getOperand(1) == Op) - return PreventSelfSimplify(SimplifyBinOp(B->getOpcode(), - B->getOperand(0), RepOp, Q, - MaxRecurse - 1)); + return PreventSelfSimplify(SimplifyBinOp( + B->getOpcode(), B->getOperand(0), RepOp, Q, MaxRecurse - 1)); } } @@ -3861,13 +3911,11 @@ if (CmpInst *C = dyn_cast(I)) { if (MaxRecurse) { if (C->getOperand(0) == Op) - return PreventSelfSimplify(SimplifyCmpInst(C->getPredicate(), RepOp, - C->getOperand(1), Q, - MaxRecurse - 1)); + return PreventSelfSimplify(SimplifyCmpInst( + C->getPredicate(), RepOp, C->getOperand(1), Q, MaxRecurse - 1)); if (C->getOperand(1) == Op) - return PreventSelfSimplify(SimplifyCmpInst(C->getPredicate(), - C->getOperand(0), RepOp, Q, - MaxRecurse - 1)); + return PreventSelfSimplify(SimplifyCmpInst( + C->getPredicate(), C->getOperand(0), RepOp, Q, MaxRecurse - 1)); } } @@ -3974,7 +4022,8 @@ /// Try to simplify a select instruction when its condition operand is an /// integer comparison. static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, - Value *FalseVal, const SimplifyQuery &Q, + Value *FalseVal, + const SimplifyQuery &Q, unsigned MaxRecurse) { ICmpInst::Predicate Pred; Value *CmpLHS, *CmpRHS; @@ -4029,8 +4078,8 @@ } // Check for other compares that behave like bit test. - if (Value *V = simplifySelectWithFakeICmpEq(CmpLHS, CmpRHS, Pred, - TrueVal, FalseVal)) + if (Value *V = + simplifySelectWithFakeICmpEq(CmpLHS, CmpRHS, Pred, TrueVal, FalseVal)) return V; // If we have an equality comparison, then we know the value in one of the @@ -4038,18 +4087,18 @@ // simplifying the result yields the same value as the other arm. 
if (Pred == ICmpInst::ICMP_EQ) { if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, - /* AllowRefinement */ false, MaxRecurse) == - TrueVal || + /* AllowRefinement */ false, + MaxRecurse) == TrueVal || SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, - /* AllowRefinement */ false, MaxRecurse) == - TrueVal) + /* AllowRefinement */ false, + MaxRecurse) == TrueVal) return FalseVal; if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, - /* AllowRefinement */ true, MaxRecurse) == - FalseVal || + /* AllowRefinement */ true, + MaxRecurse) == FalseVal || SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, - /* AllowRefinement */ true, MaxRecurse) == - FalseVal) + /* AllowRefinement */ true, + MaxRecurse) == FalseVal) return FalseVal; } @@ -4068,11 +4117,11 @@ // This transform is safe if we do not have (do not care about) -0.0 or if // at least one operand is known to not be -0.0. Otherwise, the select can // change the sign of a zero operand. - bool HasNoSignedZeros = Q.CxtI && isa(Q.CxtI) && - Q.CxtI->hasNoSignedZeros(); + bool HasNoSignedZeros = + Q.CxtI && isa(Q.CxtI) && Q.CxtI->hasNoSignedZeros(); const APFloat *C; if (HasNoSignedZeros || (match(T, m_APFloat(C)) && C->isNonZero()) || - (match(F, m_APFloat(C)) && C->isNonZero())) { + (match(F, m_APFloat(C)) && C->isNonZero())) { // (T == F) ? T : F --> F // (F == T) ? T : F --> F if (Pred == FCmpInst::FCMP_OEQ) @@ -4115,8 +4164,8 @@ "Select must have bool or bool vector condition"); assert(TrueVal->getType() == FalseVal->getType() && "Select must have same types for true/false ops"); - if (Cond->getType() == TrueVal->getType() && - match(TrueVal, m_One()) && match(FalseVal, m_ZeroInt())) + if (Cond->getType() == TrueVal->getType() && match(TrueVal, m_One()) && + match(FalseVal, m_ZeroInt())) return Cond; // select ?, X, X -> X @@ -4152,11 +4201,9 @@ // one element is undef, choose the defined element as the safe result. if (TEltC == FEltC) NewC.push_back(TEltC); - else if (Q.isUndefValue(TEltC) && - isGuaranteedNotToBeUndefOrPoison(FEltC)) + else if (Q.isUndefValue(TEltC) && isGuaranteedNotToBeUndefOrPoison(FEltC)) NewC.push_back(FEltC); - else if (Q.isUndefValue(FEltC) && - isGuaranteedNotToBeUndefOrPoison(TEltC)) + else if (Q.isUndefValue(FEltC) && isGuaranteedNotToBeUndefOrPoison(TEltC)) NewC.push_back(TEltC); else break; @@ -4309,8 +4356,8 @@ /// Given operands for an InsertValueInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, - ArrayRef Idxs, const SimplifyQuery &Q, - unsigned) { + ArrayRef Idxs, + const SimplifyQuery &Q, unsigned) { if (Constant *CAgg = dyn_cast(Agg)) if (Constant *CVal = dyn_cast(Val)) return ConstantFoldInsertValueInstruction(CAgg, CVal, Idxs); @@ -4363,8 +4410,7 @@ // If the scalar is undef, and there is no risk of propagating poison from the // vector value, simplify to the vector value. - if (Q.isUndefValue(Val) && - isGuaranteedNotToBeUndefOrPoison(Vec)) + if (Q.isUndefValue(Val) && isGuaranteedNotToBeUndefOrPoison(Vec)) return Vec; // If we are extracting a value from a vector, then inserting it into the same @@ -4459,14 +4505,15 @@ bool HasUndefInput = false; for (Value *Incoming : PN->incoming_values()) { // If the incoming value is the phi node itself, it can safely be skipped. - if (Incoming == PN) continue; + if (Incoming == PN) + continue; if (Q.isUndefValue(Incoming)) { // Remember that we saw an undef value, but otherwise ignore them. 
HasUndefInput = true; continue; } if (CommonValue && Incoming != CommonValue) - return nullptr; // Not the same, bail out. + return nullptr; // Not the same, bail out. CommonValue = Incoming; } @@ -4484,8 +4531,8 @@ return CommonValue; } -static Value *SimplifyCastInst(unsigned CastOpc, Value *Op, - Type *Ty, const SimplifyQuery &Q, unsigned MaxRecurse) { +static Value *SimplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, + const SimplifyQuery &Q, unsigned MaxRecurse) { if (auto *C = dyn_cast(Op)) return ConstantFoldCastOperand(CastOpc, C, Ty, Q.DL); @@ -4696,8 +4743,8 @@ return ::SimplifyShuffleVectorInst(Op0, Op1, Mask, RetTy, Q, RecursionLimit); } -static Constant *foldConstant(Instruction::UnaryOps Opcode, - Value *&Op, const SimplifyQuery &Q) { +static Constant *foldConstant(Instruction::UnaryOps Opcode, Value *&Op, + const SimplifyQuery &Q) { if (auto *C = dyn_cast(Op)) return ConstantFoldUnaryOpOperand(Opcode, C, Q.DL); return nullptr; @@ -4736,8 +4783,7 @@ /// Perform folds that are common to any floating-point operation. This implies /// transforms based on undef/NaN because the operation itself makes no /// difference to the result. -static Constant *simplifyFPOp(ArrayRef Ops, - FastMathFlags FMF, +static Constant *simplifyFPOp(ArrayRef Ops, FastMathFlags FMF, const SimplifyQuery &Q) { for (Value *V : Ops) { bool IsNan = match(V, m_NaN()); @@ -4760,8 +4806,12 @@ /// Given operands for an FAdd, see if we can fold the result. If not, this /// returns null. +template static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, + MatcherContext &MContext, + unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q)) return C; @@ -4769,11 +4819,11 @@ return C; // fadd X, -0 ==> X - if (match(Op1, m_NegZeroFP())) + if (try_match(Op1, m_NegZeroFP(), MContext)) return Op0; // fadd X, 0 ==> X, when we know X is not -0 - if (match(Op1, m_PosZeroFP()) && + if (try_match(Op1, m_PosZeroFP(), MContext) && (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) return Op0; @@ -4785,12 +4835,12 @@ // X = 0.0: (-0.0 - ( 0.0)) + ( 0.0) == (-0.0) + ( 0.0) == 0.0 // X = 0.0: ( 0.0 - ( 0.0)) + ( 0.0) == ( 0.0) + ( 0.0) == 0.0 if (FMF.noNaNs()) { - if (match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1))) || - match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0)))) + if (try_match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1)), MContext) || + try_match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0)), MContext)) return ConstantFP::getNullValue(Op0->getType()); - if (match(Op0, m_FNeg(m_Specific(Op1))) || - match(Op1, m_FNeg(m_Specific(Op0)))) + if (try_match(Op0, m_FNeg(m_Specific(Op1)), MContext) || + try_match(Op1, m_FNeg(m_Specific(Op0)), MContext)) return ConstantFP::getNullValue(Op0->getType()); } @@ -4798,8 +4848,8 @@ // Y + (X - Y) --> X Value *X; if (FMF.noSignedZeros() && FMF.allowReassoc() && - (match(Op0, m_FSub(m_Value(X), m_Specific(Op1))) || - match(Op1, m_FSub(m_Value(X), m_Specific(Op0))))) + (try_match(Op0, m_FSub(m_Value(X), m_Specific(Op1)), MContext) || + try_match(Op1, m_FSub(m_Value(X), m_Specific(Op0)), MContext))) return X; return nullptr; @@ -4827,8 +4877,7 @@ // fsub -0.0, (fsub -0.0, X) ==> X // fsub -0.0, (fneg X) ==> X Value *X; - if (match(Op0, m_NegZeroFP()) && - match(Op1, m_FNeg(m_Value(X)))) + if (match(Op0, m_NegZeroFP()) && match(Op1, m_FNeg(m_Value(X)))) return X; // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored. 
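With the templated SimplifyFAddInst above, every piece of the trait surface has now appeared at least once: Trait::AllowReassociation (the reassociation gate), Trait::consider (the whole-instruction pre-filter used further below), and the MatcherContext copy/assign/accept operations. A hedged reconstruction of that implicit concept, inferred purely from these call sites since the authoritative definitions live under llvm/IR/Traits and are not part of these hunks:

// Inferred trait "concept"; the real definitions may differ in detail.
#include "llvm/IR/Instruction.h"

struct ExampleTrait {
  // SimplifyAssociativeBinOp bails out immediately when this is false.
  static constexpr bool AllowReassociation = true;
  // Cheap whole-instruction pre-filter consulted by SimplifyInstruction.
  static bool consider(const llvm::Instruction *) { return true; }
};
// In addition, MatcherContext<ExampleTrait> must be default- and
// copy-constructible and assignable, and expose accept(const Value *)
// (stateful; may refine the context) plus check(const Value *) (a const
// query), mirroring the default-trait members at the top of Traits.h.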
@@ -4895,12 +4944,13 @@ return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse); } +template Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q) { - return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit); + const SimplifyQuery &Q, + MatcherContext &Matcher) { + return ::SimplifyFAddInst(Op0, Op1, FMF, Q, Matcher, RecursionLimit); } - Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q) { return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); @@ -5007,8 +5057,8 @@ /// If not, this returns null. /// Try to use FastMathFlags when folding the result. static Value *simplifyFPUnOp(unsigned Opcode, Value *Op, - const FastMathFlags &FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { + const FastMathFlags &FMF, const SimplifyQuery &Q, + unsigned MaxRecurse) { switch (Opcode) { case Instruction::FNeg: return simplifyFNegInst(Op, FMF, Q, MaxRecurse); @@ -5028,11 +5078,14 @@ /// Given operands for a BinaryOperator, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const SimplifyQuery &Q, unsigned MaxRecurse) { +template +static Value * +SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const SimplifyQuery &Q, + MatcherContext &MContext, unsigned MaxRecurse) { switch (Opcode) { case Instruction::Add: - return SimplifyAddInst(LHS, RHS, false, false, Q, MaxRecurse); + return SimplifyAddInst(LHS, RHS, false, false, Q, MContext, + MaxRecurse); case Instruction::Sub: return SimplifySubInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::Mul: @@ -5058,7 +5111,8 @@ case Instruction::Xor: return SimplifyXorInst(LHS, RHS, Q, MaxRecurse); case Instruction::FAdd: - return SimplifyFAddInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + return SimplifyFAddInst(LHS, RHS, FastMathFlags(), Q, MContext, + MaxRecurse); case Instruction::FSub: return SimplifyFSubInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); case Instruction::FMul: @@ -5072,15 +5126,23 @@ } } +static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse) { + MatcherContext Matcher; + return SimplifyBinOp<>(Opcode, LHS, RHS, Q, Matcher, MaxRecurse); +} + /// Given operands for a BinaryOperator, see if we can fold the result. /// If not, this returns null. /// Try to use FastMathFlags when folding the result. 
+template static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const FastMathFlags &FMF, const SimplifyQuery &Q, + MatcherContext &Matcher, unsigned MaxRecurse) { switch (Opcode) { case Instruction::FAdd: - return SimplifyFAddInst(LHS, RHS, FMF, Q, MaxRecurse); + return SimplifyFAddInst(LHS, RHS, FMF, Q, Matcher, MaxRecurse); case Instruction::FSub: return SimplifyFSubInst(LHS, RHS, FMF, Q, MaxRecurse); case Instruction::FMul: @@ -5088,19 +5150,31 @@ case Instruction::FDiv: return SimplifyFDivInst(LHS, RHS, FMF, Q, MaxRecurse); default: - return SimplifyBinOp(Opcode, LHS, RHS, Q, MaxRecurse); + return SimplifyBinOp<>(Opcode, LHS, RHS, Q, Matcher, MaxRecurse); } } +template Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const SimplifyQuery &Q) { - return ::SimplifyBinOp(Opcode, LHS, RHS, Q, RecursionLimit); + const SimplifyQuery &Q, + MatcherContext &Matcher) { + return ::SimplifyBinOp<>(Opcode, LHS, RHS, Q, Matcher, RecursionLimit); } +template Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - FastMathFlags FMF, const SimplifyQuery &Q) { - return ::SimplifyBinOp(Opcode, LHS, RHS, FMF, Q, RecursionLimit); -} + FastMathFlags FMF, const SimplifyQuery &Q, + MatcherContext &Matcher) { + return ::SimplifyBinOp<>(Opcode, LHS, RHS, FMF, Q, Matcher, RecursionLimit); +} +#define ENABLE_TRAIT(TRAIT) \ + template Value *llvm::SimplifyBinOp(unsigned, Value *, Value *, \ + const SimplifyQuery &, \ + MatcherContext &); \ + template Value *llvm::SimplifyBinOp(unsigned, Value *, Value *, \ + FastMathFlags, const SimplifyQuery &, \ + MatcherContext &); +#include "llvm/IR/Traits/EnabledTraits.def" /// Given operands for a CmpInst, see if we can fold the result. static Value *SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, @@ -5117,7 +5191,8 @@ static bool IsIdempotent(Intrinsic::ID ID) { switch (ID) { - default: return false; + default: + return false; // Unary idempotent: f(f(x)) = f(x) case Intrinsic::fabs: @@ -5201,43 +5276,50 @@ Value *X; switch (IID) { case Intrinsic::fabs: - if (SignBitMustBeZero(Op0, Q.TLI)) return Op0; + if (SignBitMustBeZero(Op0, Q.TLI)) + return Op0; break; case Intrinsic::bswap: // bswap(bswap(x)) -> x - if (match(Op0, m_BSwap(m_Value(X)))) return X; + if (match(Op0, m_BSwap(m_Value(X)))) + return X; break; case Intrinsic::bitreverse: // bitreverse(bitreverse(x)) -> x - if (match(Op0, m_BitReverse(m_Value(X)))) return X; + if (match(Op0, m_BitReverse(m_Value(X)))) + return X; break; case Intrinsic::exp: // exp(log(x)) -> x if (Q.CxtI->hasAllowReassoc() && - match(Op0, m_Intrinsic(m_Value(X)))) return X; + match(Op0, m_Intrinsic(m_Value(X)))) + return X; break; case Intrinsic::exp2: // exp2(log2(x)) -> x if (Q.CxtI->hasAllowReassoc() && - match(Op0, m_Intrinsic(m_Value(X)))) return X; + match(Op0, m_Intrinsic(m_Value(X)))) + return X; break; case Intrinsic::log: // log(exp(x)) -> x if (Q.CxtI->hasAllowReassoc() && - match(Op0, m_Intrinsic(m_Value(X)))) return X; + match(Op0, m_Intrinsic(m_Value(X)))) + return X; break; case Intrinsic::log2: // log2(exp2(x)) -> x if (Q.CxtI->hasAllowReassoc() && (match(Op0, m_Intrinsic(m_Value(X))) || - match(Op0, m_Intrinsic(m_SpecificFP(2.0), - m_Value(X))))) return X; + match(Op0, + m_Intrinsic(m_SpecificFP(2.0), m_Value(X))))) + return X; break; case Intrinsic::log10: // log10(pow(10.0, x)) -> x if (Q.CxtI->hasAllowReassoc() && - match(Op0, m_Intrinsic(m_SpecificFP(10.0), - m_Value(X)))) return X; + match(Op0, m_Intrinsic(m_SpecificFP(10.0), m_Value(X)))) + return X; break; 
case Intrinsic::floor: case Intrinsic::trunc: @@ -5265,31 +5347,46 @@ static Intrinsic::ID getMaxMinOpposite(Intrinsic::ID IID) { switch (IID) { - case Intrinsic::smax: return Intrinsic::smin; - case Intrinsic::smin: return Intrinsic::smax; - case Intrinsic::umax: return Intrinsic::umin; - case Intrinsic::umin: return Intrinsic::umax; - default: llvm_unreachable("Unexpected intrinsic"); + case Intrinsic::smax: + return Intrinsic::smin; + case Intrinsic::smin: + return Intrinsic::smax; + case Intrinsic::umax: + return Intrinsic::umin; + case Intrinsic::umin: + return Intrinsic::umax; + default: + llvm_unreachable("Unexpected intrinsic"); } } static APInt getMaxMinLimit(Intrinsic::ID IID, unsigned BitWidth) { switch (IID) { - case Intrinsic::smax: return APInt::getSignedMaxValue(BitWidth); - case Intrinsic::smin: return APInt::getSignedMinValue(BitWidth); - case Intrinsic::umax: return APInt::getMaxValue(BitWidth); - case Intrinsic::umin: return APInt::getMinValue(BitWidth); - default: llvm_unreachable("Unexpected intrinsic"); + case Intrinsic::smax: + return APInt::getSignedMaxValue(BitWidth); + case Intrinsic::smin: + return APInt::getSignedMinValue(BitWidth); + case Intrinsic::umax: + return APInt::getMaxValue(BitWidth); + case Intrinsic::umin: + return APInt::getMinValue(BitWidth); + default: + llvm_unreachable("Unexpected intrinsic"); } } static ICmpInst::Predicate getMaxMinPredicate(Intrinsic::ID IID) { switch (IID) { - case Intrinsic::smax: return ICmpInst::ICMP_SGE; - case Intrinsic::smin: return ICmpInst::ICMP_SLE; - case Intrinsic::umax: return ICmpInst::ICMP_UGE; - case Intrinsic::umin: return ICmpInst::ICMP_ULE; - default: llvm_unreachable("Unexpected intrinsic"); + case Intrinsic::smax: + return ICmpInst::ICMP_SGE; + case Intrinsic::smin: + return ICmpInst::ICMP_SLE; + case Intrinsic::umax: + return ICmpInst::ICMP_UGE; + case Intrinsic::umin: + return ICmpInst::ICMP_ULE; + default: + llvm_unreachable("Unexpected intrinsic"); } } @@ -5491,7 +5588,8 @@ case Intrinsic::maximum: case Intrinsic::minimum: { // If the arguments are the same, this is a no-op. - if (Op0 == Op1) return Op0; + if (Op0 == Op1) + return Op0; // Canonicalize constant operand as Op1. if (isa(Op0)) @@ -5605,7 +5703,7 @@ Value *Op0 = Call->getArgOperand(0); Value *Op1 = Call->getArgOperand(1); Value *Op2 = Call->getArgOperand(2); - if (Value *V = simplifyFPOp({ Op0, Op1, Op2 }, {}, Q)) + if (Value *V = simplifyFPOp({Op0, Op1, Op2}, {}, Q)) return V; return nullptr; } @@ -5674,12 +5772,28 @@ /// See if we can compute a simplified version of this instruction. /// If not, this returns null. -Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, - OptimizationRemarkEmitter *ORE) { +// FIXME: this will break if masquerading intrinsics do not pass muster. +template +Value *llvm::SimplifyInstructionWithTrait(Instruction *I, + const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I); Value *Result; - switch (I->getOpcode()) { + // Allow Traits to bail for cases we do not want to implement. + if (!Trait::consider(I)) + return nullptr; + + // Create an initial context rooted at I. + MatcherContext Matcher; + if (!Matcher.accept(I)) + return nullptr; + + // Cast into the Trait type hierarchy since I may have a different opcode + // there. + // Eg llvm.*.constrained.fadd(%x, %y, %fpround, %fpexcept) is an 'fadd'. 
+ const auto *TraitInst = trait_cast(I); + switch (TraitInst->getOpcode()) { default: Result = ConstantFoldInstruction(I, Q.DL, Q.TLI); break; @@ -5687,24 +5801,25 @@ Result = SimplifyFNegInst(I->getOperand(0), I->getFastMathFlags(), Q); break; case Instruction::FAdd: - Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), + I->getFastMathFlags(), Q, Matcher); break; case Instruction::Add: - Result = - SimplifyAddInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + Result = SimplifyAddInst( + I->getOperand(0), I->getOperand(1), + Q.IIQ.hasNoSignedWrap(trait_cast(I)), + Q.IIQ.hasNoUnsignedWrap(trait_cast(I)), Q, + Matcher); break; case Instruction::FSub: Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1), I->getFastMathFlags(), Q); break; case Instruction::Sub: - Result = - SimplifySubInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + Result = SimplifySubInst( + I->getOperand(0), I->getOperand(1), + Q.IIQ.hasNoSignedWrap(trait_cast(I)), + Q.IIQ.hasNoUnsignedWrap(trait_cast(I)), Q); break; case Instruction::FMul: Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1), @@ -5734,18 +5849,18 @@ I->getFastMathFlags(), Q); break; case Instruction::Shl: - Result = - SimplifyShlInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + Result = SimplifyShlInst( + I->getOperand(0), I->getOperand(1), + Q.IIQ.hasNoSignedWrap(trait_cast(I)), + Q.IIQ.hasNoUnsignedWrap(trait_cast(I)), Q); break; case Instruction::LShr: Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.isExact(cast(I)), Q); + Q.IIQ.isExact(I), Q); break; case Instruction::AShr: Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.isExact(cast(I)), Q); + Q.IIQ.isExact(I), Q); break; case Instruction::And: Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), Q); @@ -5820,8 +5935,8 @@ #define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: #include "llvm/IR/Instruction.def" #undef HANDLE_CAST_INST - Result = - SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), Q); + Result = SimplifyCastInst(TraitInst->getOpcode(), TraitInst->getOperand(0), + TraitInst->getType(), Q); break; case Instruction::Alloca: // No simplifications for Alloca and it can't be constant folded. @@ -5834,6 +5949,28 @@ /// detecting that case here, returning a safe value instead. return Result == I ? UndefValue::get(I->getType()) : Result; } +#define ENABLE_TRAIT(TRAIT) \ + template Value *llvm::SimplifyInstructionWithTrait( \ + Instruction *, const SimplifyQuery &, OptimizationRemarkEmitter *); +#include "llvm/IR/Traits/EnabledTraits.def" + +Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { + // Either all or no fp operations in a function are constrained. + if (CFPTrait::consider(I)) { + if (auto *Result = SimplifyInstructionWithTrait(I, SQ, ORE)) + return Result; + } + + // Vector-predicated code. + // FIXME: We use a quick heuristic (is this a vector type?) for now. + if (VPTrait::consider(I)) { + if (auto *Result = SimplifyInstructionWithTrait(I, SQ, ORE)) + return Result; + } + + return SimplifyInstructionWithTrait(I, SQ, ORE); +} /// Implementation of recursive simplification through an instruction's /// uses.
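To make the dispatch concrete: for %r = call float @llvm.experimental.constrained.fadd.f32(float %x, float -0.0, metadata !"round.tonearest", metadata !"fpexcept.ignore"), CFPTrait::consider accepts, trait_cast presents the call with opcode Instruction::FAdd, and the generic "fadd X, -0 ==> X" rule in the templated SimplifyFAddInst folds it to %x; that is exactly what the new fpadd_constrained.ll test below checks. One plausible shape for the two pre-filters, stated as an assumption since their bodies live in the Traits headers rather than in these hunks:

// Hypothetical bodies derived from the comments above; illustrative only.
#include "llvm/IR/Function.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

static bool considerCFPSketch(const Instruction *I) {
  // "Either all or no fp operations in a function are constrained", so the
  // enclosing function's strictfp attribute is a sufficient trigger.
  const Function *F = I->getFunction();
  return F && F->hasFnAttribute(Attribute::StrictFP);
}

static bool considerVPSketch(const Instruction *I) {
  // The quick heuristic named in the FIXME: vector-typed instructions may
  // be, or feed, vector-predicated operations.
  return I->getType()->isVectorTy();
}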
@@ -5950,4 +6087,4 @@ } template const SimplifyQuery getBestSimplifyQuery(AnalysisManager &, Function &); -} +} // namespace llvm diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp --- a/llvm/lib/IR/IntrinsicInst.cpp +++ b/llvm/lib/IR/IntrinsicInst.cpp @@ -104,12 +104,24 @@ return ConstantInt::get(Type::getInt64Ty(Context), 1); } +bool ConstrainedFPIntrinsic::hasRoundingMode() const { + switch (getIntrinsicID()) { + default: + return false; +#define INSTRUCTION(N, A, R, I) \ + case Intrinsic::I: \ + return R; +#include "llvm/IR/ConstrainedOps.def" + } +} + Optional ConstrainedFPIntrinsic::getRoundingMode() const { + if (!hasRoundingMode()) + return None; unsigned NumOperands = getNumArgOperands(); Metadata *MD = cast(getArgOperand(NumOperands - 2))->getMetadata(); - if (!MD || !isa(MD)) - return None; + assert(MD && isa(MD)); return StrToRoundingMode(cast(MD)->getString()); } @@ -145,6 +157,21 @@ .Default(FCmpInst::BAD_FCMP_PREDICATE); } +unsigned ConstrainedFPIntrinsic::getFunctionalOpcode() const { + switch (getIntrinsicID()) { + default: + // Just some intrinsic call + return Instruction::Call; + +#define DAG_FUNCTION(OPC, FPEXCEPT, FPROUND, INTRIN, SD) + +#define DAG_INSTRUCTION(OPC, FPEXCEPT, FPROUND, INTRIN, SD) \ + case Intrinsic::INTRIN: \ + return OPC; +#include "llvm/IR/ConstrainedOps.def" + } +} + bool ConstrainedFPIntrinsic::isUnaryOp() const { switch (getIntrinsicID()) { default: diff --git a/llvm/test/Transforms/InstSimplify/add_vp.ll b/llvm/test/Transforms/InstSimplify/add_vp.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstSimplify/add_vp.ll @@ -0,0 +1,76 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -instsimplify -S | FileCheck %s + +declare <2 x i32> @llvm.vp.add.v2i32(<2 x i32>, <2 x i32>, <2 x i1>, i32) +declare <2 x i32> @llvm.vp.sub.v2i32(<2 x i32>, <2 x i32>, <2 x i1>, i32) + +declare <2 x i8> @llvm.vp.add.v2i8(<2 x i8>, <2 x i8>, <2 x i1>, i32) +declare <2 x i8> @llvm.vp.sub.v2i8(<2 x i8>, <2 x i8>, <2 x i1>, i32) + +; Constant folding should just work. +define <2 x i32> @constant_vp_add(<2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @constant_vp_add( +; CHECK-NEXT: ret <2 x i32> +; + %Q = call <2 x i32> @llvm.vp.add.v2i32(<2 x i32> , <2 x i32> , <2 x i1> %mask, i32 %evl) + ret <2 x i32> %Q +} + +; Simplifying pure VP intrinsic patterns. +define <2 x i32> @common_sub_operand(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @common_sub_operand( +; CHECK-NEXT: ret <2 x i32> [[X:%.*]] +; + ; %Z = sub i32 %X, %Y, vp(%mask, %evl) + %Z = call <2 x i32> @llvm.vp.sub.v2i32(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) + ; %Q = add i32 %Z, %Y, vp(%mask, %evl) + %Q = call <2 x i32> @llvm.vp.add.v2i32(<2 x i32> %Z, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) + ret <2 x i32> %Q +} + +; Mixing regular SIMD with vp intrinsics (vp add match root). +define <2 x i32> @common_sub_operand_vproot(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @common_sub_operand_vproot( +; CHECK-NEXT: ret <2 x i32> [[X:%.*]] +; + %Z = sub <2 x i32> %X, %Y + ; %Q = add i32 %Z, %Y, vp(%mask, %evl) + %Q = call <2 x i32> @llvm.vp.add.v2i32(<2 x i32> %Z, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) + ret <2 x i32> %Q +} + +; Mixing regular SIMD with vp intrinsics (vp inside pattern, regular instruction root). 
+define <2 x i32> @common_sub_operand_vpinner(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @common_sub_operand_vpinner( +; CHECK-NEXT: ret <2 x i32> [[X:%.*]] +; + ; %Z = sub i32 %X, %Y, vp(%mask, %evl) + %Z = call <2 x i32> @llvm.vp.sub.v2i32(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) + %Q = add <2 x i32> %Z, %Y + ret <2 x i32> %Q +} + +define <2 x i32> @negated_operand(<2 x i32> %x, <2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @negated_operand( +; CHECK-NEXT: ret <2 x i32> zeroinitializer +; + ; %negx = sub i32 0, %x + %negx = call <2 x i32> @llvm.vp.sub.v2i32(<2 x i32> zeroinitializer, <2 x i32> %x, <2 x i1> %mask, i32 %evl) + ; %r = add i32 %negx, %x + %r = call <2 x i32> @llvm.vp.add.v2i32(<2 x i32> %negx, <2 x i32> %x, <2 x i1> %mask, i32 %evl) + ret <2 x i32> %r +} + +; TODO Lift InstSimplify::SimplifyAdd to the trait framework to optimize this. +define <2 x i8> @knownnegation(<2 x i8> %x, <2 x i8> %y, <2 x i1> %mask, i32 %evl) { +; TODO-CHECK-LABEL: @knownnegation( +; TODO-CHECK-NEXT: ret <2 x i8> zeroinitializer +; + ; %xy = sub i8 %x, %y + %xy = call <2 x i8> @llvm.vp.sub.v2i8(<2 x i8> %x, <2 x i8> %y, <2 x i1> %mask, i32 %evl) + ; %yx = sub i8 %y, %x + %yx = call <2 x i8> @llvm.vp.sub.v2i8(<2 x i8> %y, <2 x i8> %x, <2 x i1> %mask, i32 %evl) + ; %r = add i8 %xy, %yx + %r = call <2 x i8> @llvm.vp.add.v2i8(<2 x i8> %xy, <2 x i8> %yx, <2 x i1> %mask, i32 %evl) + ret <2 x i8> %r +} diff --git a/llvm/test/Transforms/InstSimplify/fpadd_constrained.ll b/llvm/test/Transforms/InstSimplify/fpadd_constrained.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstSimplify/fpadd_constrained.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -instsimplify -S | FileCheck %s + +declare float @llvm.experimental.constrained.fadd.f32(float, float, metadata, metadata) +declare <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float>, <2 x float>, metadata, metadata) +declare float @llvm.experimental.constrained.fsub.f32(float, float, metadata, metadata) + +; fadd X, -0 ==> X +define float @fadd_x_n0(float %a) { +; CHECK-LABEL: @fadd_x_n0( +; CHECK-NEXT: ret float [[A:%.*]] +; + %ret = call float @llvm.experimental.constrained.fadd.f32(float %a, float -0.0, + metadata !"round.tonearest", + metadata !"fpexcept.ignore") #0 + ret float %ret +} + +define <2 x float> @fadd_x_n0_vec_undef_elt(<2 x float> %a) { +; CHECK-LABEL: @fadd_x_n0_vec_undef_elt( +; CHECK-NEXT: ret <2 x float> %a +; + %ret = call <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> %a, <2 x float> <float -0.0, float undef>, + metadata !"round.tonearest", + metadata !"fpexcept.ignore") #0 + ret <2 x float> %ret +} + +; We can't optimize away the fadd in this test because the input +; value to the function and subsequently to the fadd may be -0.0. +; In that one special case, the result of the fadd should be +0.0 +; rather than the first parameter of the fadd. + +; Fragile test warning: We need 6 sqrt calls to trigger the bug +; because the internal logic has a magic recursion limit of 6. +; This is presented without any explanation or ability to customize.
+ +declare float @sqrtf(float) + +define float @PR22688(float %x) { +; CHECK-LABEL: @PR22688( +; CHECK-NEXT: [[TMP1:%.*]] = call float @sqrtf(float [[X:%.*]]) +; CHECK-NEXT: [[TMP2:%.*]] = call float @sqrtf(float [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = call float @sqrtf(float [[TMP2]]) +; CHECK-NEXT: [[TMP4:%.*]] = call float @sqrtf(float [[TMP3]]) +; CHECK-NEXT: [[TMP5:%.*]] = call float @sqrtf(float [[TMP4]]) +; CHECK-NEXT: [[TMP6:%.*]] = call float @sqrtf(float [[TMP5]]) +; CHECK-NEXT: [[TMP7:%.*]] = call float @llvm.experimental.constrained.fadd.f32(float [[TMP6]], float 0.000000e+00, +; CHECK-NEXT: ret float [[TMP7]] +; + %1 = call float @sqrtf(float %x) + %2 = call float @sqrtf(float %1) + %3 = call float @sqrtf(float %2) + %4 = call float @sqrtf(float %3) + %5 = call float @sqrtf(float %4) + %6 = call float @sqrtf(float %5) + %7 = call float @llvm.experimental.constrained.fadd.f32(float %6, float 0.0, + metadata !"round.tonearest", + metadata !"fpexcept.ignore") #0 + ret float %7 +} + +attributes #0 = { strictfp nounwind readnone }
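Stepping back from the tests to the IntrinsicInst.cpp hunk: the new hasRoundingMode/getFunctionalOpcode helpers are what let a trait view a constrained call as its plain IR opcode. A caller-side sketch, under the stated assumption that fpexcept.ignore is the only exception behavior safe to fold through (the function name is hypothetical; the helpers are the ones declared in the header hunk):

// Returns the plain IR opcode we may reason with, or Instruction::Call when
// the constrained call must keep its side effects.
#include "llvm/IR/IntrinsicInst.h"
using namespace llvm;

static unsigned functionalViewSketch(const Instruction *I) {
  if (const auto *CFP = dyn_cast<ConstrainedFPIntrinsic>(I)) {
    if (CFP->getExceptionBehavior() == fp::ebIgnore) {
      // Rounding still matters for value-changing folds whenever
      // hasRoundingMode() is true, but identity folds such as
      // "fadd X, -0.0 ==> X" are rounding-invariant.
      return CFP->getFunctionalOpcode(); // e.g. Instruction::FAdd
    }
  }
  return Instruction::Call;
}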