diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -37,6 +37,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/Traits/SemanticTrait.h" namespace llvm { @@ -60,6 +61,7 @@ /// InstrInfoQuery provides an interface to query additional information for /// instructions like metadata or keywords like nsw, which provides conservative /// results if the users specified it is safe to use. +/// FIXME: Incorporate this into trait framework. struct InstrInfoQuery { InstrInfoQuery(bool UMD) : UseInstrInfo(UMD) {} InstrInfoQuery() : UseInstrInfo(true) {} @@ -83,9 +85,9 @@ return false; } - bool isExact(const BinaryOperator *Op) const { - if (UseInstrInfo && isa(Op)) - return cast(Op)->isExact(); + bool isExact(const Instruction *I) const { + if (UseInstrInfo && isa(I)) + return cast(I)->isExact(); return false; } }; @@ -146,16 +148,34 @@ const SimplifyQuery &Q); /// Given operands for an Add, fold the result or return null. +template Value *SimplifyAddInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW, - const SimplifyQuery &Q); + const SimplifyQuery &Q, MatcherContext &Matcher); + +inline Value *SimplifyAddInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW, + const SimplifyQuery &Q) { + MatcherContext Matcher; + return SimplifyAddInst(LHS, RHS, isNSW, isNUW, Q, Matcher); +} /// Given operands for a Sub, fold the result or return null. Value *SimplifySubInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW, const SimplifyQuery &Q); -/// Given operands for an FAdd, fold the result or return null. +/// Given operands for an FAdd, and a matcher context that was initialized for +/// the actual instruction, fold the result or return null. +template Value *SimplifyFAddInst(Value *LHS, Value *RHS, FastMathFlags FMF, - const SimplifyQuery &Q); + const SimplifyQuery &Q, MatcherContext &Matcher); + +/// Given operands for an FAdd, fold the result or return null. +/// We don't have any information about the traits of the 'fadd' so run this +/// with the unassuming default trait. +inline Value *SimplifyFAddInst(Value *LHS, Value *RHS, FastMathFlags FMF, + const SimplifyQuery &Q) { + MatcherContext Matcher; + return SimplifyFAddInst(LHS, RHS, FMF, Q, Matcher); +} /// Given operands for an FSub, fold the result or return null. Value *SimplifyFSubInst(Value *LHS, Value *RHS, FastMathFlags FMF, @@ -272,13 +292,27 @@ const SimplifyQuery &Q); /// Given operands for a BinaryOperator, fold the result or return null. +template Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const SimplifyQuery &Q); + const SimplifyQuery &Q, MatcherContext &Matcher); + +inline Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, + const SimplifyQuery &Q) { + MatcherContext Matcher; + return SimplifyBinOp<>(Opcode, LHS, RHS, Q, Matcher); +} /// Given operands for a BinaryOperator, fold the result or return null. /// Try to use FastMathFlags when folding the result. 
-Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - FastMathFlags FMF, const SimplifyQuery &Q); +template +Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, FastMathFlags FMF, + const SimplifyQuery &Q, MatcherContext &Matcher); + +inline Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, + FastMathFlags FMF, const SimplifyQuery &Q) { + MatcherContext Matcher; + return SimplifyBinOp(Opcode, LHS, RHS, FMF, Q, Matcher); +} /// Given a callsite, fold the result or return null. Value *SimplifyCall(CallBase *Call, const SimplifyQuery &Q); @@ -289,6 +323,10 @@ /// See if we can compute a simplified version of this instruction. If not, /// return null. +template +Value *SimplifyInstructionWithTrait(Instruction *I, const SimplifyQuery &Q, + OptimizationRemarkEmitter *ORE = nullptr); + Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q, OptimizationRemarkEmitter *ORE = nullptr); @@ -336,4 +374,3 @@ } // end namespace llvm #endif - diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h --- a/llvm/include/llvm/IR/IntrinsicInst.h +++ b/llvm/include/llvm/IR/IntrinsicInst.h @@ -278,6 +278,15 @@ unsigned getFunctionalOpcode() const { return GetFunctionalOpcodeForVP(getIntrinsicID()); } + bool isFunctionalCommutative() const { + return Instruction::isCommutative(getFunctionalOpcode()); + } + bool isFunctionalUnaryOp() const { + return Instruction::isUnaryOp(getFunctionalOpcode()); + } + bool isFunctionalBinaryOp() const { + return Instruction::isBinaryOp(getFunctionalOpcode()); + } // Equivalent non-predicated opcode static unsigned GetFunctionalOpcodeForVP(Intrinsic::ID ID); @@ -288,9 +297,21 @@ public: bool isUnaryOp() const; bool isTernaryOp() const; + bool hasRoundingMode() const; Optional getRoundingMode() const; Optional getExceptionBehavior() const; + unsigned getFunctionalOpcode() const; + bool isFunctionalCommutative() const { + return Instruction::isCommutative(getFunctionalOpcode()); + } + bool isFunctionalUnaryOp() const { + return Instruction::isUnaryOp(getFunctionalOpcode()); + } + bool isFunctionalBinaryOp() const { + return Instruction::isBinaryOp(getFunctionalOpcode()); + } + // Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const IntrinsicInst *I); static bool classof(const Value *V) { diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -41,15 +41,39 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" +#include "llvm/IR/Traits/SemanticTrait.h" + #include namespace llvm { namespace PatternMatch { -template bool match(Val *V, const Pattern &P) { +// Trait-match pattern in a given context and update the context. +template +bool match(Val *V, const Pattern &P) { + // TODO: Use this instead of the Trait-less Pattern::match() functions. This + // single function does the same as the Trait-less 'Pattern::match()' + // functions that are replicated once for every Pattern. return const_cast(P).match(V); } +// Trait-match pattern in a given context and update the context. +template +bool match(Val *V, const Pattern &P, MatcherContext &MContext) { + return const_cast(P).match(V, MContext); +} + +// Trait-match pattern and update the context on match. 
+template +bool try_match(Val *V, const Pattern &P, MatcherContext &MContext) { + MatcherContext CopyCtx(MContext); + if (const_cast(P).match(V, CopyCtx)) { + MContext = CopyCtx; + return true; + } + return false; +} + template bool match(ArrayRef Mask, const Pattern &P) { return const_cast(P).match(Mask); } @@ -59,8 +83,15 @@ OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {} - template bool match(OpTy *V) { - return V->hasOneUse() && SubPattern.match(V); + template + bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(OpTy *V, MatcherContext &MContext) { + return V->hasOneUse() && SubPattern.match(V, MContext); } }; @@ -69,7 +100,17 @@ } template struct class_match { - template bool match(ITy *V) { return isa(V); } + + template + bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + return trait_isa(V); + } }; /// Match an arbitrary value and ignore it. @@ -110,7 +151,16 @@ match_unless(const Ty &Matcher) : M(Matcher) {} - template bool match(ITy *V) { return !M.match(V); } + template + bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + return !M.match(V, MContext); + } }; /// Match if the inner matcher does *NOT* match. @@ -125,12 +175,18 @@ match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) - return true; - if (R.match(V)) + template + bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (try_match(V, L, MContext)) { return true; - return false; + } + return R.match(V, MContext); } }; @@ -140,11 +196,15 @@ match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) - if (R.match(V)) - return true; - return false; + template + bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + return L.match(V, MContext) && R.match(V, MContext); } }; @@ -165,17 +225,24 @@ bool AllowUndef; apint_match(const APInt *&Res, bool AllowUndef) - : Res(Res), AllowUndef(AllowUndef) {} + : Res(Res), AllowUndef(AllowUndef) {} - template bool match(ITy *V) { + template + bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValue(); return true; } if (V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) - if (auto *CI = dyn_cast_or_null( - C->getSplatValue(AllowUndef))) { + if (auto *CI = + dyn_cast_or_null(C->getSplatValue(AllowUndef))) { Res = &CI->getValue(); return true; } @@ -192,15 +259,22 @@ apfloat_match(const APFloat *&Res, bool AllowUndef) : Res(Res), AllowUndef(AllowUndef) {} - template bool match(ITy *V) { + template + bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValueAPF(); return true; } if (V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) - if (auto *CI = dyn_cast_or_null( - C->getSplatValue(AllowUndef))) { + if (auto *CI = + dyn_cast_or_null(C->getSplatValue(AllowUndef))) { Res = &CI->getValueAPF(); return true; } @@ -243,7 +317,14 @@ } template 
struct constantint_match { - template bool match(ITy *V) { + template + bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (const auto *CI = dyn_cast(V)) { const APInt &CIV = CI->getValue(); if (Val >= 0) @@ -262,14 +343,21 @@ return constantint_match(); } -/// This helper class is used to match constant scalars, vector splats, -/// and fixed width vectors that satisfy a specified predicate. -/// For fixed width vector constants, undefined elements are ignored. +/// This helper class is used to match scalar and fixed width vector integer +/// constants that satisfy a specified predicate. +/// For vector constants, undefined elements are ignored. template struct cstval_pred_ty : public Predicate { - template bool match(ITy *V) { - if (const auto *CV = dyn_cast(V)) - return this->isValue(CV->getValue()); + template + bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (const auto *CI = dyn_cast(V)) + return this->isValue(CI->getValue()); if (const auto *VTy = dyn_cast(V->getType())) { if (const auto *C = dyn_cast(V)) { if (const auto *CV = dyn_cast_or_null(C->getSplatValue())) @@ -317,7 +405,14 @@ api_pred_ty(const APInt *&R) : Res(R) {} - template bool match(ITy *V) { + template + bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (const auto *CI = dyn_cast(V)) if (this->isValue(CI->getValue())) { Res = &CI->getValue(); @@ -343,7 +438,14 @@ apf_pred_ty(const APFloat *&R) : Res(R) {} - template bool match(ITy *V) { + template + bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (const auto *CI = dyn_cast(V)) if (this->isValue(CI->getValue())) { Res = &CI->getValue(); @@ -410,9 +512,7 @@ inline cst_pred_ty m_Negative() { return cst_pred_ty(); } -inline api_pred_ty m_Negative(const APInt *&V) { - return V; -} +inline api_pred_ty m_Negative(const APInt *&V) { return V; } struct is_nonnegative { bool isValue(const APInt &C) { return C.isNonNegative(); } @@ -422,9 +522,7 @@ inline cst_pred_ty m_NonNegative() { return cst_pred_ty(); } -inline api_pred_ty m_NonNegative(const APInt *&V) { - return V; -} +inline api_pred_ty m_NonNegative(const APInt *&V) { return V; } struct is_strictlypositive { bool isValue(const APInt &C) { return C.isStrictlyPositive(); } @@ -453,9 +551,7 @@ }; /// Match an integer 1 or a vector with all elements equal to 1. /// For vectors, this includes constants with undefined elements. -inline cst_pred_ty m_One() { - return cst_pred_ty(); -} +inline cst_pred_ty m_One() { return cst_pred_ty(); } struct is_zero_int { bool isValue(const APInt &C) { return C.isNullValue(); } @@ -467,7 +563,14 @@ } struct is_zero { - template bool match(ITy *V) { + template + bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { auto *C = dyn_cast(V); // FIXME: this should be able to do something for scalable vectors return C && (C->isNullValue() || cst_pred_ty().match(C)); @@ -475,21 +578,15 @@ }; /// Match any null constant or a vector with all elements equal to 0. /// For vectors, this includes constants with undefined elements. 
-inline is_zero m_Zero() { - return is_zero(); -} +inline is_zero m_Zero() { return is_zero(); } struct is_power2 { bool isValue(const APInt &C) { return C.isPowerOf2(); } }; /// Match an integer or vector power-of-2. /// For vectors, this includes constants with undefined elements. -inline cst_pred_ty m_Power2() { - return cst_pred_ty(); -} -inline api_pred_ty m_Power2(const APInt *&V) { - return V; -} +inline cst_pred_ty m_Power2() { return cst_pred_ty(); } +inline api_pred_ty m_Power2(const APInt *&V) { return V; } struct is_negated_power2 { bool isValue(const APInt &C) { return (-C).isPowerOf2(); } @@ -578,9 +675,7 @@ }; /// Match an arbitrary NaN constant. This includes quiet and signalling nans. /// For vectors, this includes constants with undefined elements. -inline cstfp_pred_ty m_NaN() { - return cstfp_pred_ty(); -} +inline cstfp_pred_ty m_NaN() { return cstfp_pred_ty(); } struct is_nonnan { bool isValue(const APFloat &C) { return !C.isNaN(); } @@ -596,9 +691,7 @@ }; /// Match a positive or negative infinity FP constant. /// For vectors, this includes constants with undefined elements. -inline cstfp_pred_ty m_Inf() { - return cstfp_pred_ty(); -} +inline cstfp_pred_ty m_Inf() { return cstfp_pred_ty(); } struct is_noninf { bool isValue(const APFloat &C) { return !C.isInfinity(); } @@ -674,7 +767,15 @@ bind_ty(Class *&V) : VR(V) {} - template bool match(ITy *V) { + template + bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + template + bool match(ITy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; if (auto *CV = dyn_cast(V)) { VR = CV; return true; @@ -694,7 +795,9 @@ /// Match a binary operator, capturing it if we match. inline bind_ty m_BinOp(BinaryOperator *&I) { return I; } /// Match a with overflow intrinsic, capturing it if we match. -inline bind_ty m_WithOverflowInst(WithOverflowInst *&I) { return I; } +inline bind_ty m_WithOverflowInst(WithOverflowInst *&I) { + return I; +} /// Match a ConstantInt, capturing the value if we match. inline bind_ty m_ConstantInt(ConstantInt *&CI) { return CI; } @@ -717,7 +820,15 @@ specificval_ty(const Value *V) : Val(V) {} - template bool match(ITy *V) { return V == Val; } + template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + return V == Val; + } }; /// Match if we have a specific specified value. @@ -730,7 +841,16 @@ deferredval_ty(Class *const &V) : Val(V) {} - template bool match(ITy *const V) { return V == Val; } + template + bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *const V, MatcherContext &MContext) { + return V == Val; + } }; /// A commutative-friendly version of m_Specific(). 
@@ -746,7 +866,14 @@ specific_fpval(double V) : Val(V) {} - template bool match(ITy *V) { + template + bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (const auto *CFP = dyn_cast(V)) return CFP->isExactlyValue(Val); if (V->getType()->isVectorTy()) @@ -769,7 +896,16 @@ bind_const_intval_ty(uint64_t &V) : VR(V) {} - template bool match(ITy *V) { + template + bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; if (const auto *CV = dyn_cast(V)) if (CV->getValue().ule(UINT64_MAX)) { VR = CV->getZExtValue(); @@ -781,13 +917,21 @@ /// Match a specified integer value or vector of all elements of that /// value. -template -struct specific_intval { +template struct specific_intval { APInt Val; specific_intval(APInt V) : Val(std::move(V)) {} - template bool match(ITy *V) { + template + bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; const auto *CI = dyn_cast(V); if (!CI && V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) @@ -825,7 +969,16 @@ specific_bbval(BasicBlock *Val) : Val(Val) {} - template bool match(ITy *V) { + template + bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; const auto *BB = dyn_cast(V); return BB && BB == Val; } @@ -857,11 +1010,30 @@ // The LHS is always matched first. AnyBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &MContext) { + auto *I = trait_dyn_cast(V); + if (!I) + return false; + + if (!MContext.accept(I)) + return false; + + MatcherContext LRContext(MContext); + if (L.match(I->getOperand(0), LRContext) && + R.match(I->getOperand(1), LRContext)) { + MContext = LRContext; + return true; + } + if (Commutable && (L.match(I->getOperand(1), MContext) && + R.match(I->getOperand(0), MContext))) + return true; return false; } }; @@ -880,9 +1052,16 @@ AnyUnaryOp_match(const OP_t &X) : X(X) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return X.match(I->getOperand(0)); + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &MContext) { + if (auto *I = trait_dyn_cast(V)) + return X.match(I->getOperand(0), MContext); return false; } }; @@ -905,12 +1084,24 @@ // The LHS is always matched first. 
BinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *I = trait_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + if (try_match(I->getOperand(0), L, MContext) && + try_match(I->getOperand(1), R, MContext)) { + return true; + } + return Commutable && (L.match(I->getOperand(1), MContext) && + R.match(I->getOperand(0), MContext)); } if (auto *CE = dyn_cast(V)) return CE->getOpcode() == Opcode && @@ -949,25 +1140,36 @@ Op_t X; FNeg_match(const Op_t &Op) : X(Op) {} - template bool match(OpTy *V) { - auto *FPMO = dyn_cast(V); - if (!FPMO) return false; + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *FPMO = trait_dyn_cast(V); + if (!FPMO) + return false; + + auto OPC = trait_cast(V)->getOpcode(); - if (FPMO->getOpcode() == Instruction::FNeg) - return X.match(FPMO->getOperand(0)); + if (OPC == Instruction::FNeg) + return X.match(FPMO->getOperand(0), MContext); - if (FPMO->getOpcode() == Instruction::FSub) { + if (OPC == Instruction::FSub) { if (FPMO->hasNoSignedZeros()) { // With 'nsz', any zero goes. - if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!try_match(FPMO->getOperand(0), cstfp_pred_ty(), MContext)) return false; } else { // Without 'nsz', we need fsub -0.0, X exactly. - if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!try_match(FPMO->getOperand(0), cstfp_pred_ty(), MContext)) return false; } - return X.match(FPMO->getOperand(1)); + return X.match(FPMO->getOperand(1), MContext); } return false; @@ -975,9 +1177,7 @@ }; /// Match 'fneg X' as 'fsub -0.0, X'. 
-template -inline FNeg_match -m_FNeg(const OpTy &X) { +template inline FNeg_match m_FNeg(const OpTy &X) { return FNeg_match(X); } @@ -1081,8 +1281,16 @@ OverflowingBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *Op = dyn_cast(V)) { + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *Op = trait_dyn_cast(V)) { if (Op->getOpcode() != Opcode) return false; if (WrapFlags & OverflowingBinaryOperator::NoUnsignedWrap && @@ -1091,7 +1299,8 @@ if (WrapFlags & OverflowingBinaryOperator::NoSignedWrap && !Op->hasNoSignedWrap()) return false; - return L.match(Op->getOperand(0)) && R.match(Op->getOperand(1)); + return L.match(Op->getOperand(0), MContext) && + R.match(Op->getOperand(1), MContext); } return false; } @@ -1102,32 +1311,32 @@ OverflowingBinaryOperator::NoSignedWrap> m_NSWAdd(const LHS &L, const RHS &R) { return OverflowingBinaryOp_match( - L, R); + OverflowingBinaryOperator::NoSignedWrap>(L, + R); } template inline OverflowingBinaryOp_match m_NSWSub(const LHS &L, const RHS &R) { return OverflowingBinaryOp_match( - L, R); + OverflowingBinaryOperator::NoSignedWrap>(L, + R); } template inline OverflowingBinaryOp_match m_NSWMul(const LHS &L, const RHS &R) { return OverflowingBinaryOp_match( - L, R); + OverflowingBinaryOperator::NoSignedWrap>(L, + R); } template inline OverflowingBinaryOp_match m_NSWShl(const LHS &L, const RHS &R) { return OverflowingBinaryOp_match( - L, R); + OverflowingBinaryOperator::NoSignedWrap>(L, + R); } template @@ -1173,11 +1382,20 @@ BinOpPred_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return this->isOpType(I->getOpcode()) && L.match(I->getOperand(0)) && - R.match(I->getOperand(1)); - if (auto *CE = dyn_cast(V)) + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *I = trait_dyn_cast(V)) + return this->isOpType(I->getOpcode()) && + L.match(I->getOperand(0), MContext) && + R.match(I->getOperand(1), MContext); + if (auto *CE = trait_dyn_cast(V)) return this->isOpType(CE->getOpcode()) && L.match(CE->getOperand(0)) && R.match(CE->getOperand(1)); return false; @@ -1268,9 +1486,17 @@ Exact_match(const SubPattern_t &SP) : SubPattern(SP) {} - template bool match(OpTy *V) { - if (auto *PEO = dyn_cast(V)) - return PEO->isExact() && SubPattern.match(V); + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *PEO = trait_dyn_cast(V)) + return PEO->isExact() && SubPattern.match(V, MContext); return false; } }; @@ -1295,13 +1521,29 @@ CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) : Predicate(Pred), L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) { + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *I = trait_dyn_cast(V)) { + MatcherContext LRContext(MContext); + if (L.match(I->getOperand(0), LRContext) && + 
R.match(I->getOperand(1), LRContext)) { + MContext = LRContext; Predicate = I->getPredicate(); return true; - } else if (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))) { + } + + if (!Commutable) + return false; + + if (L.match(I->getOperand(1), MContext) && + R.match(I->getOperand(0), MContext)) { Predicate = I->getSwappedPredicate(); return true; } @@ -1338,10 +1580,18 @@ OneOps_match(const T0 &Op1) : Op1(Op1) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)); + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *I = trait_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + return Op1.match(I->getOperand(0), MContext); } return false; } @@ -1354,10 +1604,19 @@ TwoOps_match(const T0 &Op1, const T1 &Op2) : Op1(Op1), Op2(Op2) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)); + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *I = trait_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + return Op1.match(I->getOperand(0), MContext) && + Op2.match(I->getOperand(1), MContext); } return false; } @@ -1373,11 +1632,20 @@ ThreeOps_match(const T0 &Op1, const T1 &Op2, const T2 &Op3) : Op1(Op1), Op2(Op2), Op3(Op3) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && - Op3.match(I->getOperand(2)); + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *I = trait_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + return Op1.match(I->getOperand(0), MContext) && + Op2.match(I->getOperand(1), MContext) && + Op3.match(I->getOperand(2), MContext); } return false; } @@ -1429,10 +1697,18 @@ Shuffle_match(const T0 &Op1, const T1 &Op2, const T2 &Mask) : Op1(Op1), Op2(Op2), Mask(Mask) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && - Mask.match(I->getShuffleMask()); + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *I = trait_dyn_cast(V)) { + return Op1.match(I->getOperand(0), MContext) && + Op2.match(I->getOperand(1), MContext) && Mask.match(I->getShuffleMask()); } return false; } @@ -1508,9 +1784,17 @@ CastClass_match(const Op_t &OpMatch) : Op(OpMatch) {} - template bool match(OpTy *V) { - if (auto *O = dyn_cast(V)) - return O->getOpcode() == Opcode && Op.match(O->getOperand(0)); + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &Matcher) { + if (!Matcher.accept(V)) + return false; + if (auto O = trait_dyn_cast(V)) + return O->getOpcode() == Opcode && Op.match(O->getOperand(0), Matcher); return false; } }; @@ -1624,8 +1908,16 @@ 
br_match(BasicBlock *&Succ) : Succ(Succ) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *BI = trait_dyn_cast(V)) if (BI->isUnconditional()) { Succ = BI->getSuccessor(0); return true; @@ -1645,10 +1937,18 @@ brc_match(const Cond_t &C, const TrueBlock_t &t, const FalseBlock_t &f) : Cond(C), T(t), F(f) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) - if (BI->isConditional() && Cond.match(BI->getCondition())) - return T.match(BI->getSuccessor(0)) && F.match(BI->getSuccessor(1)); + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (auto *BI = trait_dyn_cast(V)) + if (BI->isConditional() && Cond.match(BI->getCondition(), MContext)) { + return T.match(BI->getSuccessor(0), MContext) && + F.match(BI->getSuccessor(1), MContext); + } return false; } }; @@ -1680,24 +1980,36 @@ // The LHS is always matched first. MaxMin_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *II = dyn_cast(V)) { + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *II = trait_dyn_cast(V)) { Intrinsic::ID IID = II->getIntrinsicID(); if ((IID == Intrinsic::smax && Pred_t::match(ICmpInst::ICMP_SGT)) || (IID == Intrinsic::smin && Pred_t::match(ICmpInst::ICMP_SLT)) || (IID == Intrinsic::umax && Pred_t::match(ICmpInst::ICMP_UGT)) || (IID == Intrinsic::umin && Pred_t::match(ICmpInst::ICMP_ULT))) { Value *LHS = II->getOperand(0), *RHS = II->getOperand(1); - return (L.match(LHS) && R.match(RHS)) || - (Commutable && L.match(RHS) && R.match(LHS)); + MatcherContext LRContext(MContext); + if (L.match(LHS, LRContext) && R.match(RHS, LRContext)) { + MContext = LRContext; + return true; + } + return Commutable && L.match(RHS, MContext) && R.match(LHS, MContext); } } // Look for "(x pred y) ? x : y" or "(x pred y) ? y : x". - auto *SI = dyn_cast(V); - if (!SI) + auto *SI = trait_dyn_cast(V); + if (!SI || !MContext.accept(SI)) return false; - auto *Cmp = dyn_cast(SI->getCondition()); - if (!Cmp) + auto *Cmp = trait_dyn_cast(SI->getCondition()); + if (!Cmp || !MContext.accept(Cmp)) return false; // At this point we have a select conditioned on a comparison. Check that // it is the values returned by the select that are being compared. @@ -1713,9 +2025,15 @@ // Does "(x pred y) ? x : y" represent the desired max/min operation? if (!Pred_t::match(Pred)) return false; + // It does! Bind the operands. - return (L.match(LHS) && R.match(RHS)) || - (Commutable && L.match(RHS) && R.match(LHS)); + // TODO factor out commutative matching! 
+ MatcherContext LRContext(MContext); + if (L.match(LHS, LRContext) && R.match(RHS, LRContext)) { + MContext = LRContext; + return true; + } + return Commutable && L.match(RHS, MContext) && R.match(LHS, MContext); } }; @@ -1884,50 +2202,77 @@ UAddWithOverflow_match(const LHS_t &L, const RHS_t &R, const Sum_t &S) : L(L), R(R), S(S) {} - template bool match(OpTy *V) { + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { Value *ICmpLHS, *ICmpRHS; ICmpInst::Predicate Pred; - if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V)) + if (!try_match(V, m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)), MContext)) return false; Value *AddLHS, *AddRHS; auto AddExpr = m_Add(m_Value(AddLHS), m_Value(AddRHS)); // (a + b) u< a, (a + b) u< b - if (Pred == ICmpInst::ICMP_ULT) - if (AddExpr.match(ICmpLHS) && (ICmpRHS == AddLHS || ICmpRHS == AddRHS)) - return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpLHS); + if (Pred == ICmpInst::ICMP_ULT) { + if (try_match(ICmpLHS, AddExpr, MContext) && + (ICmpRHS == AddLHS || ICmpRHS == AddRHS)) { + return L.match(AddLHS, MContext) && R.match(AddRHS, MContext) && + S.match(ICmpLHS, MContext); + } + } // a >u (a + b), b >u (a + b) - if (Pred == ICmpInst::ICMP_UGT) - if (AddExpr.match(ICmpRHS) && (ICmpLHS == AddLHS || ICmpLHS == AddRHS)) - return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpRHS); + if (Pred == ICmpInst::ICMP_UGT) { + if (try_match(ICmpRHS, AddExpr, MContext) && + (ICmpLHS == AddLHS || ICmpLHS == AddRHS)) { + return L.match(AddLHS, MContext) && R.match(AddRHS, MContext) && + S.match(ICmpRHS, MContext); + } + } Value *Op1; auto XorExpr = m_OneUse(m_Xor(m_Value(Op1), m_AllOnes())); // (a ^ -1) u (a ^ -1) if (Pred == ICmpInst::ICMP_UGT) { - if (XorExpr.match(ICmpRHS)) - return L.match(Op1) && R.match(ICmpLHS) && S.match(ICmpRHS); + if (try_match(ICmpRHS, XorExpr, MContext)) { + return L.match(Op1, MContext) && R.match(ICmpLHS, MContext) && S.match(ICmpRHS, MContext); + } } // Match special-case for increment-by-1. if (Pred == ICmpInst::ICMP_EQ) { // (a + 1) == 0 // (1 + a) == 0 - if (AddExpr.match(ICmpLHS) && m_ZeroInt().match(ICmpRHS) && - (m_One().match(AddLHS) || m_One().match(AddRHS))) - return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpLHS); + MatcherContext CopyCtx(MContext); + if (AddExpr.match(ICmpLHS, CopyCtx) && + m_ZeroInt().match(ICmpRHS, CopyCtx)) { + if (try_match(AddLHS, m_One(), CopyCtx) || + try_match(AddRHS, m_One(), CopyCtx)) { + MContext = CopyCtx; + return L.match(AddLHS, MContext) && R.match(AddRHS, MContext) && + S.match(ICmpLHS, MContext); + } + } // 0 == (a + 1) // 0 == (1 + a) - if (m_ZeroInt().match(ICmpLHS) && AddExpr.match(ICmpRHS) && - (m_One().match(AddLHS) || m_One().match(AddRHS))) - return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpRHS); + if (m_ZeroInt().match(ICmpLHS, MContext) && + AddExpr.match(ICmpRHS, MContext) && + (try_match(AddLHS, m_One(), MContext) || + m_One().match(AddRHS, MContext))) + return L.match(AddLHS, MContext) && R.match(AddRHS, MContext) && + S.match(ICmpRHS, MContext); } return false; @@ -1950,10 +2295,16 @@ Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) {} - template bool match(OpTy *V) { + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { // FIXME: Should likely be switched to use `CallBase`. 
- if (const auto *CI = dyn_cast(V)) - return Val.match(CI->getArgOperand(OpI)); + if (const auto *CI = trait_dyn_cast(V)) + return Val.match(CI->getArgOperand(OpI), MContext); return false; } }; @@ -1970,8 +2321,14 @@ IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) {} - template bool match(OpTy *V) { - if (const auto *CI = dyn_cast(V)) + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (const auto *CI = trait_dyn_cast(V)) if (const auto *F = CI->getCalledFunction()) return F->getIntrinsicID() == ID; return false; @@ -1996,15 +2353,13 @@ }; template struct m_Intrinsic_Ty { - using Ty = - match_combine_and::Ty, - Argument_match>; + using Ty = match_combine_and::Ty, + Argument_match>; }; template struct m_Intrinsic_Ty { - using Ty = - match_combine_and::Ty, - Argument_match>; + using Ty = match_combine_and::Ty, + Argument_match>; }; template @@ -2013,7 +2368,8 @@ Argument_match>; }; -template +template struct m_Intrinsic_Ty { using Ty = match_combine_and::Ty, Argument_match>; @@ -2243,7 +2599,13 @@ Opnd_t Val; Signum_match(const Opnd_t &V) : Val(V) {} - template bool match(OpTy *V) { + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { unsigned TypeSize = V->getType()->getScalarSizeInBits(); if (TypeSize == 0) return false; @@ -2265,7 +2627,7 @@ auto RHS = m_LShr(m_Neg(m_Value(OpR)), m_SpecificInt(ShiftWidth)); auto Signum = m_Or(LHS, RHS); - return Signum.match(V) && OpL == OpR && Val.match(OpL); + return Signum.match(V, MContext) && OpL == OpR && Val.match(OpL, MContext); } }; @@ -2283,10 +2645,16 @@ Opnd_t Val; ExtractValue_match(const Opnd_t &V) : Val(V) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) + template + bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (auto *I = trait_dyn_cast(V)) return I->getNumIndices() == 1 && I->getIndices()[0] == Ind && - Val.match(I->getAggregateOperand()); + Val.match(I->getAggregateOperand(), MContext); return false; } }; @@ -2306,9 +2674,16 @@ InsertValue_match(const T0 &Op0, const T1 &Op1) : Op0(Op0), Op1(Op1) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - return Op0.match(I->getOperand(0)) && Op1.match(I->getOperand(1)) && - I->getNumIndices() == 1 && Ind == I->getIndices()[0]; + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &Matcher) { + if (auto *I = trait_dyn_cast(V)) { + return I->getNumIndices() == 1 && Ind == I->getIndices()[0] && + Op0.match(I->getOperand(0), Matcher) && + Op1.match(I->getOperand(1), Matcher); } return false; } @@ -2337,12 +2712,19 @@ const DataLayout &DL; VScaleVal_match(const DataLayout &DL) : DL(DL) {} - template bool match(ITy *V) { + template + bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + template + bool match(ITy *V, MatcherContext &MContext) { if (m_Intrinsic().match(V)) return true; - if (m_PtrToInt(m_OffsetGep(m_Zero(), m_SpecificInt(1))).match(V)) { - Type *PtrTy = cast(V)->getOperand(0)->getType(); + if (m_PtrToInt(m_OffsetGep(m_Zero(), m_SpecificInt(1))) + .match(V, MContext)) { + Type *PtrTy = trait_cast(V)->getOperand(0)->getType(); auto *DerefTy = PtrTy->getPointerElementType(); if (isa(DerefTy) && DL.getTypeAllocSizeInBits(DerefTy).getKnownMinSize() == 8) diff --git 
a/llvm/include/llvm/IR/Traits/EnabledTraits.def b/llvm/include/llvm/IR/Traits/EnabledTraits.def new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/Traits/EnabledTraits.def @@ -0,0 +1,4 @@ +ENABLE_TRAIT(EmptyTrait) +ENABLE_TRAIT(CFPTrait) +ENABLE_TRAIT(VPTrait) +#undef ENABLE_TRAIT diff --git a/llvm/include/llvm/IR/Traits/SemanticTrait.h b/llvm/include/llvm/IR/Traits/SemanticTrait.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/Traits/SemanticTrait.h @@ -0,0 +1,151 @@ +//===- llvm/IR/Traits/SemanticTrait.h - Basic trait definitions -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines semantic traits. +// Some intrinsics in LLVM can be described as being a regular operation (such +// as an fadd) at their core with an additional semantic trait. We do this to +// lift optimizations that are defined in terms of standard IR operations (eg +// fadd, fmul) to these intrinsics. We keep the existing patterns and rewrite +// machinery and transparently check the rewrite is consistent with the +// semantic trait(s) that are attached to the operations. +// +// For example: +// +// @llvm.vp.fadd(<256 x double> %a, <256 x double> %b, +// <256 x i1> %m, i32 %avl) +// +// This is a vector-predicated fadd instruction with a %mask and %evl +// parameter +// (https://llvm.org/docs/LangRef.html#vector-predication-intrinsics). +// However, at its core it is just an 'fadd'. +// +// Consider the fma-fusion rewrite pattern (fadd (fmul x,y) z) --> (fma x, y, +// z). If the 'fadd' is actually an 'llvm.vp.fadd' and the 'fmul' is actually +// an 'llvm.vp.fmul', we can perform the rewrite using the %mask and %evl of +// the 'fadd' node. +// +// +// @llvm.experimental.constrained.fadd(double %a, double %b, +// metadata <rounding mode>, +// metadata <exception behavior>) +// +// This is an fadd with a possibly non-default rounding mode and exception +// behavior. +// (https://llvm.org/docs/LangRef.html#constrained-floating-point-intrinsics). +// In this case, the operation matches the semantics of a regular 'fadd' +// exactly, if the rounding mode is 'round.tonearest' and the exception +// behavior is 'fpexcept.ignore'. +// Reconsider the case of fma fusion, this time with two constrained fp +// intrinsics. If the rounding mode is tonearest for both and neither the +// 'llvm.experimental.constrained.fadd' nor the 'llvm.experimental.constrained.fmul' +// has 'fpexcept.strict', we are good to apply the rewrite and emit a +// constrained fma with the exception flag of the 'fadd'. +// +// There is also a proposal to add complex arithmetic intrinsics to LLVM. In +// that case, the operation is semantically an 'fadd', if we consider the space +// of complex floating-point numbers and their operations. +// +//===----------------------------------------------------------------------===// + +// Look for comments starting with "TODO(new trait)" to see what to implement to +// establish a new instruction trait. + +#ifndef LLVM_IR_TRAIT_SEMANTICTRAIT_H +#define LLVM_IR_TRAIT_SEMANTICTRAIT_H + +#include +#include +#include +#include + +namespace llvm { + +/// Type Casting { +/// These cast operators allow you to side-step the first-class type hierarchy +/// of LLVM (Value, Instruction, BinaryOperator, ..) into your custom type +/// hierarchy.
+/// +/// trait_cast(V) +/// +/// actually casts \p V to ExtInstruction. +template struct TraitCast { + using ExtType = ValueDerivedType; +}; + +// This has to happen after all traits are defined since we are referring to +// members of template specialization for each Trait (The TraitCast::ExtType). +#define CASTING_TEMPLATE(CASTFUNC, PREFIXMOD, REFMOD) \ + template \ + static auto trait_##CASTFUNC(PREFIXMOD Value REFMOD V) \ + ->decltype( \ + CASTFUNC::ExtType>(V)) { \ + using TraitExtendedType = \ + typename TraitCast::ExtType; \ + return CASTFUNC(V); \ + } + +#define CONST_CAST_TEMPLATE(CASTFUNC, REFMOD) \ + CASTING_TEMPLATE(CASTFUNC, const, REFMOD) \ + CASTING_TEMPLATE(CASTFUNC, , REFMOD) + +// 'dyn_cast' (allow [const] Value*) +CONST_CAST_TEMPLATE(dyn_cast, *) + +// 'cast' (allow [const] Value(*|&)) +CONST_CAST_TEMPLATE(cast, *) +CONST_CAST_TEMPLATE(cast, &) + +// 'isa' +CONST_CAST_TEMPLATE(isa, *) +CONST_CAST_TEMPLATE(isa, &) +/// } Type Casting + +// TODO +// The trait builder is a specialized IRBuilder that emits trait-compatible +// instructions. +template struct TraitBuilder {}; + +// This is used in pattern matching to check that all instructions in the +// pattern are trait-compatible. +template struct MatcherContext { + // Check whether \p Val is compatible with this context and merge its + // properties. \returns Whether \p Val is compatible with the current state of + // the context. + bool accept(const Value *Val) { return Val; } + + // Like accept() but does not modify the context. + bool check(const Value *Val) const { return Val; } + + // Whether to allow constant folding with the currently accepted operators and + // their operands. + bool allowConstantFolding() const { + return true; + } +}; + +/// Empty Trait { +/// +/// This defines the empty trait without properties. Type casting stays in the +/// standard llvm::Value type hierarchy. + +// Trait without any difference to standard IR +struct EmptyTrait { + // This is to block reassociation for traits that do not support it. + static constexpr bool AllowReassociation = true; + + // Whether \p V should be considered at all with this trait. + static bool consider(const Value *) { return true; } +}; + +using DefaultTrait = EmptyTrait; + +/// } Empty Trait + +} // end namespace llvm + +#endif // LLVM_IR_TRAIT_SEMANTICTRAIT_H diff --git a/llvm/include/llvm/IR/Traits/Traits.h b/llvm/include/llvm/IR/Traits/Traits.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/Traits/Traits.h @@ -0,0 +1,329 @@ +//===- llvm/IR/Traits/Traits.h - Basic trait definitions --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines semantic traits. +// Some intrinsics in LLVM can be described as being a regular operation (such +// as an fadd) at their core with an additional semantic trait. We do this to +// lift optimizations that are defined in terms of standard IR operations (eg +// fadd, fmul) to these intrinsics. We keep the existing patterns and rewrite +// machinery and transparently check the rewrite is consistent with the +// semantic trait(s) that are attached to the operations. +// +// For example: +// +// @llvm.vp.fadd(<256 x double> %a, <256 x double> %b, +// <256 x i1> %m,
i32 %avl) +// +// This is a vector-predicated fadd instruction with a %mask and %evl +// parameter +// (https://llvm.org/docs/LangRef.html#vector-predication-intrinsics). +// However, at its core it is just an 'fadd'. +// +// Consider the simplification (add (sub x,y), y) --> x. If the 'add' is +// actually an 'llvm.vp.add' and the 'sub' is really an 'llvm.vp.sub', we can +// do the simplification. +// +// +// @llvm.experimental.constrained.fadd(double %a, double %b, +// metadata <rounding mode>, +// metadata <exception behavior>) +// +// This is an fadd with a possibly non-default rounding mode and exception +// behavior. +// (https://llvm.org/docs/LangRef.html#constrained-floating-point-intrinsics). +// The constrained fp intrinsic has exactly the semantics of a regular 'fadd', +// if the rounding mode is 'round.tonearest' and the exception behavior is +// 'fpexcept.ignore'. +// We can use all simplifying rewrites for regular fp arithmetic also for +// constrained fp arithmetic where this applies. +// +// There is also a proposal to add complex arithmetic intrinsics to LLVM. In +// that case, the operation is semantically an 'fadd', if we consider the space +// of complex floating-point numbers and their operations. +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/Instruction.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Traits/SemanticTrait.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Operator.h" + +#ifndef LLVM_IR_TRAITS_TRAITS_H +#define LLVM_IR_TRAITS_TRAITS_H + +namespace llvm { + +/// Base classes { +// Shared functionality for extended instructions +// These define all functions used in PatternMatch as 'virtual' to remind you to +// implement them. + +// Make sure no member of the Ext.. hierarchy can be constructed. +struct ExtBase { + ExtBase() = delete; + ~ExtBase() = delete; + ExtBase &operator=(const ExtBase &) = delete; + void *operator new(size_t s) = delete; +}; + +// Mirror generic functionality of llvm::Instruction. +struct ExtInstructionBase : public ExtBase, public User { + BasicBlock *getParent() { return cast(this)->getParent(); } + const BasicBlock *getParent() const { + return cast(this)->getParent(); + } + + void copyIRFlags(const Value *V, bool IncludeWrapFlags) { + cast(this)->copyIRFlags(V, IncludeWrapFlags); + } +}; + +/// } Base classes + +/// Intrinsic-Trait Template { +// Generic instantiation for traits that want to masquerade their intrinsic as a +// regular IR instruction. + +// Pretend to be a llvm::Operator. +template struct ExtOperator : public ExtBase, public User { + unsigned getOpcode() const { + // Use the intrinsic override. + if (const auto *Intrin = dyn_cast(this)) { + return Intrin->getFunctionalOpcode(); + } + // Default Operator opcode. + return cast(this)->getOpcode(); + } + + bool hasNoSignedWrap() const { + if (const auto *OverflowingBOp = dyn_cast(this)) + return OverflowingBOp->hasNoSignedWrap(); + return false; + } + + bool hasNoUnsignedWrap() const { + if (const auto *OverflowingBOp = dyn_cast(this)) + return OverflowingBOp->hasNoUnsignedWrap(); + return false; + } + + // Every operator is also an extended operator. + static bool classof(const Value *V) { return isa(V); } +}; + +// Pretend to be a llvm::Instruction. +template +struct ExtInstruction final : public ExtInstructionBase { + unsigned getOpcode() const { + // Use the intrinsic override. + if (const auto *Intrin = dyn_cast(this)) { + return Intrin->getFunctionalOpcode(); + } + // Default opcode.
+ return cast(this)->getOpcode(); + } + + static bool classof(const Value *V) { return isa(V); } +}; + +// Pretend to be a (different) llvm::IntrinsicInst. +template +struct ExtIntrinsic final : public ExtInstructionBase { + Intrinsic::ID getIntrinsicID() const { + // Use the intrinsic override. + if (const auto *TraitIntrin = dyn_cast(this)) + return TraitIntrin->getFunctionalIntrinsic(); + // Default intrinsic opcode. + return cast(this)->getIntrinsicID(); + } + + unsigned getOpcode() const { + // We are looking at this as an intrinsic -> do not hide this. + return cast(this)->getOpcode(); + } + + bool isCommutative() const { + // The underlying intrinsic may not specify whether it is commutative. + // Query our own interface to be sure this is done right. + // Use the intrinsic override. + if (const auto *TraitIntrin = dyn_cast(this)) + return TraitIntrin->isFunctionalCommutative(); + return cast(this)->isFunctionalCommutative(); + } + + static bool classof(const Value *V) { return IntrinsicInst::classof(V); } +}; + +template +struct ExtCmpBase : public ExtInstructionBase { + unsigned getOpcode() const { return OPC; } + + FCmpInst::Predicate getPredicate() const { + // Use the intrinsic override. + if (const auto *Intrin = dyn_cast(this)) { + return Intrin->getPredicate(); + } + + // Default opcode. + return cast(this)->getPredicate(); + } +}; + +template +static bool classofExtCmpBase(const Value *V) { + return isa(V) || + cast->getFunctionalOpcode() == OPC; +} + +// Pretend to be a llvm::FCmpInst. +template +struct ExtFCmpInst final + : public ExtCmpBase { + static bool classof(const Value *V) { + return classofExtCmpBase(V); + } +}; + +// Pretend to be a llvm::ICmpInst. +template +struct ExtICmpInst final + : public ExtCmpBase { + static bool classof(const Value *V) { + return classofExtCmpBase(V); + } +}; + +// Pretend to be a BinaryOperator. +template +struct ExtBinaryOperator final : public ExtOperator { + using BinaryOps = Instruction::BinaryOps; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + const auto *Intrin = dyn_cast(I); + return Intrin && Intrin->isFunctionalBinaryOp(); + } + static bool classof(const ConstantExpr *CE) { + return isa(CE); + } + static bool classof(const Value *V) { + if (const auto *I = dyn_cast(V)) + return classof(I); + + if (const auto *CE = dyn_cast(V)) + return classof(CE); + + return false; + } +}; + +// Pretend to be a UnaryOperator. +template +struct ExtUnaryOperator final : public ExtOperator { + using BinaryOps = Instruction::BinaryOps; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + const auto *Intrin = dyn_cast(I); + return Intrin && Intrin->isFunctionalUnaryOp(); + } +}; + +// TODO Implement other extended types. + +/// Template-specialization for the Ext type hierarchy { +//// Enable the ExtSOMETHING for your trait +#define INTRINSIC_TRAIT_SPECIALIZE(TRAIT, TYPE) \ + template <> struct TraitCast { \ + using ExtType = Ext##TYPE; \ + }; \ + template <> struct TraitCast { \ + using ExtType = const Ext##TYPE; \ + }; + +/// } Trait Template Classes + +// Constraint fp trait. +struct CFPTrait { + using Intrinsic = ConstrainedFPIntrinsic; + static constexpr bool AllowReassociation = false; + + // Whether \p V should be considered at all with this trait. + // It is not possible to mix constrained and unconstrained ops. + // Only apply this trait with the constrained variant. 
+ static bool consider(const Value *V) { + return isa(V); + } +}; +INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, Instruction) +INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, Operator) +INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, BinaryOperator) +INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, UnaryOperator) +// Deflect queries for the Predicate to the ConstrainedFPCmpIntrinsic. +template <> struct ExtFCmpInst : public ExtInstructionBase { + unsigned getOpcode() const { return Instruction::FCmp; } + + FCmpInst::Predicate getPredicate() const { + return cast(this)->getPredicate(); + } + + bool classof(const Value *V) { return isa(V); } +}; +INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, FCmpInst) + +// Accept all constrained fp intrinsics that are actually not constrained. +template <> struct MatcherContext { + bool accept(const Value *Val) { return check(Val); } + bool check(const Value *Val) const { + if (!Val) + return false; + const auto *CFP = dyn_cast(Val); + if (!CFP) + return true; + auto RoundingOpt = CFP->hasRoundingMode() ? CFP->getRoundingMode() : None; + auto ExceptOpt = CFP->getExceptionBehavior(); + return (!ExceptOpt || ExceptOpt == fp::ExceptionBehavior::ebIgnore) && + (!RoundingOpt || (RoundingOpt == RoundingMode::NearestTiesToEven)); + } +}; + +// Vector-predicated trait. +struct VPTrait { + using Intrinsic = VPIntrinsic; + // TODO Enable re-association. + static constexpr bool AllowReassociation = false; + // VP intrinsic mix with regular IR instructions. + // TODO: Adapt this to work with other than arithmetic VP ops. + static bool consider(const Value *V) { + return V->getType()->isVectorTy() && + V->getType()->getScalarType()->isIntegerTy(); + } +}; +INTRINSIC_TRAIT_SPECIALIZE(VPTrait, Instruction) +INTRINSIC_TRAIT_SPECIALIZE(VPTrait, Operator) +INTRINSIC_TRAIT_SPECIALIZE(VPTrait, BinaryOperator) +INTRINSIC_TRAIT_SPECIALIZE(VPTrait, UnaryOperator) + +// Accept everything that passes as a VPIntrinsic. +template <> struct MatcherContext { + // TODO: pick up %mask and %evl here and use them to generate code again. We + // only remove instructions for the moment. + bool accept(const Value *Val) { return Val; } + bool check(const Value *Val) const { return Val; } +}; + +} // namespace llvm + +#undef INTRINSIC_TRAIT_SPECIALIZE + +#endif // LLVM_IR_TRAITS_TRAITS_H diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -37,6 +37,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Traits/Traits.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/KnownBits.h" #include @@ -56,8 +57,13 @@ const SimplifyQuery &, unsigned); static Value *SimplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); +template +static Value *SimplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, + MatcherContext &, unsigned); +template static Value *SimplifyBinOp(unsigned, Value *, Value *, const FastMathFlags &, - const SimplifyQuery &, unsigned); + const SimplifyQuery &, MatcherContext &, + unsigned); static Value *SimplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, @@ -283,32 +289,46 @@ /// Generic simplifications for associative binary operations. /// Returns the simpler value, or null if none was found. 
+template static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, const SimplifyQuery &Q, + MatcherContext &Matcher, unsigned MaxRecurse) { + // This trait blocks re-association. + // Eg. any trait that adds side-effects may clash with free reassociation + // (FIXME are 'fpexcept.strict', 'fast' fp ops a thing?) + // FIXME associativity may depend on a trait parameter of this specific instance. + if (!Trait::AllowReassociation) + return nullptr; + assert(Instruction::isAssociative(Opcode) && "Not an associative operation!"); // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; - BinaryOperator *Op0 = dyn_cast(LHS); - BinaryOperator *Op1 = dyn_cast(RHS); + auto *Op0 = trait_dyn_cast(LHS); + auto *Op1 = trait_dyn_cast(RHS); // Transform: "(A op B) op C" ==> "A op (B op C)" if it simplifies completely. - if (Op0 && Op0->getOpcode() == Opcode) { + MatcherContext Op0Matcher(Matcher); + if (Op0Matcher.accept(Op0) && Op0->getOpcode() == Opcode) { Value *A = Op0->getOperand(0); Value *B = Op0->getOperand(1); Value *C = RHS; // Does "B op C" simplify? - if (Value *V = SimplifyBinOp(Opcode, B, C, Q, MaxRecurse)) { + if (Value *V = SimplifyBinOp(Opcode, B, C, Q, Op0Matcher, MaxRecurse)) { // It does! Return "A op V" if it simplifies or is already available. // If V equals B then "A op V" is just the LHS. - if (V == B) return LHS; + if (V == B) { + Matcher = Op0Matcher; + return LHS; + } // Otherwise return "A op V" if it simplifies. - if (Value *W = SimplifyBinOp(Opcode, A, V, Q, MaxRecurse)) { + if (Value *W = SimplifyBinOp(Opcode, A, V, Q, Op0Matcher, MaxRecurse)) { + Matcher = Op0Matcher; ++NumReassoc; return W; } @@ -316,18 +336,23 @@ } // Transform: "A op (B op C)" ==> "(A op B) op C" if it simplifies completely. - if (Op1 && Op1->getOpcode() == Opcode) { + MatcherContext Op1Matcher(Matcher); + if (Op1Matcher.accept(Op1) && Op1->getOpcode() == Opcode) { Value *A = LHS; Value *B = Op1->getOperand(0); Value *C = Op1->getOperand(1); // Does "A op B" simplify? - if (Value *V = SimplifyBinOp(Opcode, A, B, Q, MaxRecurse)) { + if (Value *V = SimplifyBinOp(Opcode, A, B, Q, Op1Matcher, MaxRecurse)) { // It does! Return "V op C" if it simplifies or is already available. // If V equals B then "V op C" is just the RHS. - if (V == B) return RHS; + if (V == B) { + Matcher = Op1Matcher; + return RHS; + } // Otherwise return "V op C" if it simplifies. - if (Value *W = SimplifyBinOp(Opcode, V, C, Q, MaxRecurse)) { + if (Value *W = SimplifyBinOp(Opcode, V, C, Q, Op1Matcher, MaxRecurse)) { + Matcher = Op1Matcher; ++NumReassoc; return W; } @@ -335,22 +360,30 @@ } // The remaining transforms require commutativity as well as associativity. + // FIXME commutativity may depend on a trait parameter of this specific instance. + // Eg, matrix multiplication is associative but not commutative. if (!Instruction::isCommutative(Opcode)) return nullptr; // Transform: "(A op B) op C" ==> "(C op A) op B" if it simplifies completely. + MatcherContext CommOp0Matcher(Matcher); if (Op0 && Op0->getOpcode() == Opcode) { Value *A = Op0->getOperand(0); Value *B = Op0->getOperand(1); Value *C = RHS; // Does "C op A" simplify? - if (Value *V = SimplifyBinOp(Opcode, C, A, Q, MaxRecurse)) { + if (Value *V = SimplifyBinOp(Opcode, C, A, Q, CommOp0Matcher, MaxRecurse)) { // It does! Return "V op B" if it simplifies or is already available. // If V equals A then "V op B" is just the LHS. 
- if (V == A) return LHS; + if (V == A) { + Matcher = CommOp0Matcher; + return LHS; + } // Otherwise return "V op B" if it simplifies. - if (Value *W = SimplifyBinOp(Opcode, V, B, Q, MaxRecurse)) { + MatcherContext VContext(Matcher); + if (Value *W = SimplifyBinOp(Opcode, V, B, Q, VContext, MaxRecurse)) { + Matcher = VContext; ++NumReassoc; return W; } @@ -364,12 +397,14 @@ Value *C = Op1->getOperand(1); // Does "C op A" simplify? - if (Value *V = SimplifyBinOp(Opcode, C, A, Q, MaxRecurse)) { + if (Value *V = SimplifyBinOp(Opcode, C, A, Q, Matcher, MaxRecurse)) { // It does! Return "B op V" if it simplifies or is already available. // If V equals C then "B op V" is just the RHS. - if (V == C) return RHS; + if (V == C) { + return RHS; + } // Otherwise return "B op V" if it simplifies. - if (Value *W = SimplifyBinOp(Opcode, B, V, Q, MaxRecurse)) { + if (Value *W = SimplifyBinOp(Opcode, B, V, Q, Matcher, MaxRecurse)) { ++NumReassoc; return W; } @@ -379,6 +414,14 @@ return nullptr; } +static Value *SimplifyAssociativeBinOp(Instruction::BinaryOps Opcode, + Value *LHS, Value *RHS, + const SimplifyQuery &Q, + unsigned MaxRecurse) { + MatcherContext MContext; + return SimplifyAssociativeBinOp(Opcode, LHS, RHS, Q, MContext, MaxRecurse); +} + /// In the case of a binary operation with a select instruction as an operand, /// try to simplify the binop by seeing whether evaluating it on both branches /// of the select results in the same value. Returns the common value if so, @@ -605,9 +648,12 @@ } /// Given operands for an Add, see if we can fold the result. /// If not, this returns null. +template static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, - const SimplifyQuery &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, + MatcherContext &Matcher, + unsigned MaxRecurse) { if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q)) return C; @@ -616,7 +662,7 @@ return Op1; // X + 0 -> X - if (match(Op1, m_Zero())) + if (try_match(Op1, m_Zero(), Matcher)) return Op0; // If two operands are negative, return 0. @@ -627,25 +673,28 @@ // (Y - X) + X -> Y // Eg: X + -X -> 0 Value *Y = nullptr; - if (match(Op1, m_Sub(m_Value(Y), m_Specific(Op0))) || - match(Op0, m_Sub(m_Value(Y), m_Specific(Op1)))) + if (try_match(Op1, m_Sub(m_Value(Y), m_Specific(Op0)), Matcher) || + try_match(Op0, m_Sub(m_Value(Y), m_Specific(Op1)), Matcher)) return Y; // X + ~X -> -1 since ~X = -X-1 Type *Ty = Op0->getType(); - if (match(Op0, m_Not(m_Specific(Op1))) || - match(Op1, m_Not(m_Specific(Op0)))) + if (try_match(Op0, m_Not(m_Specific(Op1)), Matcher) || + try_match(Op1, m_Not(m_Specific(Op0)), Matcher)) return Constant::getAllOnesValue(Ty); // add nsw/nuw (xor Y, signmask), signmask --> Y // The no-wrapping add guarantees that the top bit will be set by the add. // Therefore, the xor must be clearing the already set sign bit of Y. - if ((IsNSW || IsNUW) && match(Op1, m_SignMask()) && - match(Op0, m_Xor(m_Value(Y), m_SignMask()))) + MatcherContext CopyMatch(Matcher); + if ((IsNSW || IsNUW) && match(Op1, m_SignMask(), CopyMatch) && + match(Op0, m_Xor(m_Value(Y), m_SignMask()), CopyMatch)) { + Matcher = CopyMatch; return Y; + } // add nuw %x, -1 -> -1, because %x can only be 0. - if (IsNUW && match(Op1, m_AllOnes())) + if (IsNUW && try_match(Op1, m_AllOnes(), Matcher)) return Op1; // Which is -1. /// i1 add -> xor. @@ -654,7 +703,7 @@ return V; // Try some generic simplifications for associative operations.
- if (Value *V = SimplifyAssociativeBinOp(Instruction::Add, Op0, Op1, Q, + if (Value *V = SimplifyAssociativeBinOp(Instruction::Add, Op0, Op1, Q, Matcher, MaxRecurse)) return V; @@ -670,9 +719,10 @@ return nullptr; } +template Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, - const SimplifyQuery &Query) { - return ::SimplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, RecursionLimit); + const SimplifyQuery &Query, MatcherContext& Matcher) { + return ::SimplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, Matcher, RecursionLimit); } /// Compute the base pointer and cumulative constant offsets for V. @@ -4760,8 +4810,10 @@ /// Given operands for an FAdd, see if we can fold the result. If not, this /// returns null. +template static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, MatcherContext& MContext, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q)) return C; @@ -4769,11 +4821,11 @@ return C; // fadd X, -0 ==> X - if (match(Op1, m_NegZeroFP())) + if (try_match(Op1, m_NegZeroFP(), MContext)) return Op0; // fadd X, 0 ==> X, when we know X is not -0 - if (match(Op1, m_PosZeroFP()) && + if (try_match(Op1, m_PosZeroFP(), MContext) && (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) return Op0; @@ -4785,12 +4837,12 @@ // X = 0.0: (-0.0 - ( 0.0)) + ( 0.0) == (-0.0) + ( 0.0) == 0.0 // X = 0.0: ( 0.0 - ( 0.0)) + ( 0.0) == ( 0.0) + ( 0.0) == 0.0 if (FMF.noNaNs()) { - if (match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1))) || - match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0)))) + if (try_match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1)), MContext) || + try_match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0)), MContext)) return ConstantFP::getNullValue(Op0->getType()); - if (match(Op0, m_FNeg(m_Specific(Op1))) || - match(Op1, m_FNeg(m_Specific(Op0)))) + if (try_match(Op0, m_FNeg(m_Specific(Op1)), MContext) || + try_match(Op1, m_FNeg(m_Specific(Op0)), MContext)) return ConstantFP::getNullValue(Op0->getType()); } @@ -4798,8 +4850,8 @@ // Y + (X - Y) --> X Value *X; if (FMF.noSignedZeros() && FMF.allowReassoc() && - (match(Op0, m_FSub(m_Value(X), m_Specific(Op1))) || - match(Op1, m_FSub(m_Value(X), m_Specific(Op0))))) + (try_match(Op0, m_FSub(m_Value(X), m_Specific(Op1)), MContext) || + try_match(Op1, m_FSub(m_Value(X), m_Specific(Op0)), MContext))) return X; return nullptr; @@ -4895,12 +4947,12 @@ return SimplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse); } +template Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q) { - return ::SimplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit); + const SimplifyQuery &Q, MatcherContext & Matcher) { + return ::SimplifyFAddInst(Op0, Op1, FMF, Q, Matcher, RecursionLimit); } - Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q) { return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); @@ -5028,11 +5080,12 @@ /// Given operands for a BinaryOperator, see if we can fold the result. /// If not, this returns null. 
+template static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const SimplifyQuery &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, MatcherContext& MContext, unsigned MaxRecurse) { switch (Opcode) { case Instruction::Add: - return SimplifyAddInst(LHS, RHS, false, false, Q, MaxRecurse); + return SimplifyAddInst(LHS, RHS, false, false, Q, MContext, MaxRecurse); case Instruction::Sub: return SimplifySubInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::Mul: @@ -5058,7 +5111,7 @@ case Instruction::Xor: return SimplifyXorInst(LHS, RHS, Q, MaxRecurse); case Instruction::FAdd: - return SimplifyFAddInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + return SimplifyFAddInst(LHS, RHS, FastMathFlags(), Q, MContext, MaxRecurse); case Instruction::FSub: return SimplifyFSubInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); case Instruction::FMul: @@ -5072,15 +5125,22 @@ } } +static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse) { + MatcherContext Matcher; + return SimplifyBinOp<>(Opcode, LHS, RHS, Q, Matcher, MaxRecurse); +} + /// Given operands for a BinaryOperator, see if we can fold the result. /// If not, this returns null. /// Try to use FastMathFlags when folding the result. +template static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const FastMathFlags &FMF, const SimplifyQuery &Q, + const FastMathFlags &FMF, const SimplifyQuery &Q, MatcherContext &Matcher, unsigned MaxRecurse) { switch (Opcode) { case Instruction::FAdd: - return SimplifyFAddInst(LHS, RHS, FMF, Q, MaxRecurse); + return SimplifyFAddInst(LHS, RHS, FMF, Q, Matcher, MaxRecurse); case Instruction::FSub: return SimplifyFSubInst(LHS, RHS, FMF, Q, MaxRecurse); case Instruction::FMul: @@ -5088,19 +5148,29 @@ case Instruction::FDiv: return SimplifyFDivInst(LHS, RHS, FMF, Q, MaxRecurse); default: - return SimplifyBinOp(Opcode, LHS, RHS, Q, MaxRecurse); + return SimplifyBinOp<>(Opcode, LHS, RHS, Q, Matcher, MaxRecurse); } } +template Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const SimplifyQuery &Q) { - return ::SimplifyBinOp(Opcode, LHS, RHS, Q, RecursionLimit); + const SimplifyQuery &Q, MatcherContext & Matcher) { + return ::SimplifyBinOp<>(Opcode, LHS, RHS, Q, Matcher, RecursionLimit); } +template Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - FastMathFlags FMF, const SimplifyQuery &Q) { - return ::SimplifyBinOp(Opcode, LHS, RHS, FMF, Q, RecursionLimit); -} + FastMathFlags FMF, const SimplifyQuery &Q, MatcherContext & Matcher) { + return ::SimplifyBinOp<>(Opcode, LHS, RHS, FMF, Q, Matcher, RecursionLimit); +} +#define ENABLE_TRAIT(TRAIT) \ + template Value *llvm::SimplifyBinOp(unsigned, Value *, Value *, \ + const SimplifyQuery &, \ + MatcherContext &); \ + template Value *llvm::SimplifyBinOp(unsigned, Value *, Value *, \ + FastMathFlags, const SimplifyQuery &, \ + MatcherContext &); +#include "llvm/IR/Traits/EnabledTraits.def" /// Given operands for a CmpInst, see if we can fold the result. static Value *SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, @@ -5674,12 +5744,27 @@ /// See if we can compute a simplified version of this instruction. /// If not, this returns null. -Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, +// FIXME: this will break if masquerading intrinsics do not pass muster. 
+template +Value *llvm::SimplifyInstructionWithTrait(Instruction *I, const SimplifyQuery &SQ, OptimizationRemarkEmitter *ORE) { const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I); Value *Result; - switch (I->getOpcode()) { + // Allow traits to bail out on cases we do not want to implement. + if (!Trait::consider(I)) + return nullptr; + + // Create an initial context rooted at I. + MatcherContext Matcher; + if (!Matcher.accept(I)) + return nullptr; + + // Cast into the Trait type hierarchy since I may have a different opcode + // there. + // E.g., llvm.*.constrained.fadd(%x, %y, %fpround, %fpexcept) is an 'fadd'. + const auto *TraitInst = trait_cast(I); + switch (TraitInst->getOpcode()) { default: Result = ConstantFoldInstruction(I, Q.DL, Q.TLI); break; @@ -5687,14 +5772,14 @@ Result = SimplifyFNegInst(I->getOperand(0), I->getFastMathFlags(), Q); break; case Instruction::FAdd: - Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), + I->getFastMathFlags(), Q, Matcher); break; case Instruction::Add: - Result = - SimplifyAddInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + Result = SimplifyAddInst( + I->getOperand(0), I->getOperand(1), + Q.IIQ.hasNoSignedWrap(trait_cast(I)), + Q.IIQ.hasNoUnsignedWrap(trait_cast(I)), Q, Matcher); break; case Instruction::FSub: Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1), @@ -5703,8 +5788,8 @@ case Instruction::Sub: Result = SimplifySubInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + Q.IIQ.hasNoSignedWrap(trait_cast(I)), + Q.IIQ.hasNoUnsignedWrap(trait_cast(I)), Q); break; case Instruction::FMul: Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1), @@ -5736,16 +5821,16 @@ case Instruction::Shl: Result = SimplifyShlInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + Q.IIQ.hasNoSignedWrap(trait_cast(I)), + Q.IIQ.hasNoUnsignedWrap(trait_cast(I)), Q); break; case Instruction::LShr: Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.isExact(cast(I)), Q); + Q.IIQ.isExact(I), Q); break; case Instruction::AShr: Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.isExact(cast(I)), Q); + Q.IIQ.isExact(I), Q); break; case Instruction::And: Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), Q); @@ -5821,7 +5906,7 @@ #include "llvm/IR/Instruction.def" #undef HANDLE_CAST_INST Result = - SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), Q); + SimplifyCastInst(TraitInst->getOpcode(), TraitInst->getOperand(0), TraitInst->getType(), Q); break; case Instruction::Alloca: // No simplifications for Alloca and it can't be constant folded. @@ -5834,6 +5919,28 @@ /// detecting that case here, returning a safe value instead. return Result == I ? UndefValue::get(I->getType()) : Result; } +#define ENABLE_TRAIT(TRAIT) \ + template Value *llvm::SimplifyInstructionWithTrait( \ + Instruction *, const SimplifyQuery &, OptimizationRemarkEmitter *); +#include "llvm/IR/Traits/EnabledTraits.def" + +Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { + // Either all or no fp operations in a function are constrained.
+ if (CFPTrait::consider(I)) { + if (auto *Result = SimplifyInstructionWithTrait(I, SQ, ORE)) + return Result; + } + + // Vector-predicated code. + // FIXME: We use a quick heuristic (is this an integer vector type?) for now. + if (VPTrait::consider(I)) { + if (auto *Result = SimplifyInstructionWithTrait(I, SQ, ORE)) + return Result; + } + + return SimplifyInstructionWithTrait(I, SQ, ORE); +} /// Implementation of recursive simplification through an instruction's /// uses. diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp --- a/llvm/lib/IR/IntrinsicInst.cpp +++ b/llvm/lib/IR/IntrinsicInst.cpp @@ -104,12 +104,24 @@ return ConstantInt::get(Type::getInt64Ty(Context), 1); } +bool ConstrainedFPIntrinsic::hasRoundingMode() const { + switch (getIntrinsicID()) { + default: + return false; +#define INSTRUCTION(N, A, R, I) \ + case Intrinsic::I: \ + return R; +#include "llvm/IR/ConstrainedOps.def" + } +} + Optional ConstrainedFPIntrinsic::getRoundingMode() const { + if (!hasRoundingMode()) + return None; unsigned NumOperands = getNumArgOperands(); Metadata *MD = cast(getArgOperand(NumOperands - 2))->getMetadata(); - if (!MD || !isa(MD)) - return None; + assert(MD && isa(MD)); return StrToRoundingMode(cast(MD)->getString()); } @@ -145,6 +157,21 @@ .Default(FCmpInst::BAD_FCMP_PREDICATE); } +unsigned ConstrainedFPIntrinsic::getFunctionalOpcode() const { + switch (getIntrinsicID()) { + default: + // Just some intrinsic call + return Instruction::Call; + +#define DAG_FUNCTION(OPC, FPEXCEPT, FPROUND, INTRIN, SD) + +#define DAG_INSTRUCTION(OPC, FPEXCEPT, FPROUND, INTRIN, SD) \ + case Intrinsic::INTRIN: \ + return OPC; +#include "llvm/IR/ConstrainedOps.def" + } +} + bool ConstrainedFPIntrinsic::isUnaryOp() const { switch (getIntrinsicID()) { default: diff --git a/llvm/test/Transforms/InstSimplify/add_vp.ll b/llvm/test/Transforms/InstSimplify/add_vp.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstSimplify/add_vp.ll @@ -0,0 +1,76 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -instsimplify -S | FileCheck %s + +declare <2 x i32> @llvm.vp.add.v2i32(<2 x i32>, <2 x i32>, <2 x i1>, i32) +declare <2 x i32> @llvm.vp.sub.v2i32(<2 x i32>, <2 x i32>, <2 x i1>, i32) + +declare <2 x i8> @llvm.vp.add.v2i8(<2 x i8>, <2 x i8>, <2 x i1>, i32) +declare <2 x i8> @llvm.vp.sub.v2i8(<2 x i8>, <2 x i8>, <2 x i1>, i32) + +; Constant folding should just work. +define <2 x i32> @constant_vp_add(<2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @constant_vp_add( +; CHECK-NEXT: ret <2 x i32> +; + %Q = call <2 x i32> @llvm.vp.add.v2i32(<2 x i32> , <2 x i32> , <2 x i1> %mask, i32 %evl) + ret <2 x i32> %Q +} + +; Simplifying pure VP intrinsic patterns. +define <2 x i32> @common_sub_operand(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @common_sub_operand( +; CHECK-NEXT: ret <2 x i32> [[X:%.*]] +; + ; %Z = sub i32 %X, %Y, vp(%mask, %evl) + %Z = call <2 x i32> @llvm.vp.sub.v2i32(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) + ; %Q = add i32 %Z, %Y, vp(%mask, %evl) + %Q = call <2 x i32> @llvm.vp.add.v2i32(<2 x i32> %Z, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) + ret <2 x i32> %Q +} + +; Mixing regular SIMD with vp intrinsics (vp add match root).
+define <2 x i32> @common_sub_operand_vproot(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @common_sub_operand_vproot( +; CHECK-NEXT: ret <2 x i32> [[X:%.*]] +; + %Z = sub <2 x i32> %X, %Y + ; %Q = add i32 %Z, %Y, vp(%mask, %evl) + %Q = call <2 x i32> @llvm.vp.add.v2i32(<2 x i32> %Z, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) + ret <2 x i32> %Q +} + +; Mixing regular SIMD with vp intrinsics (vp inside pattern, regular instruction root). +define <2 x i32> @common_sub_operand_vpinner(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @common_sub_operand_vpinner( +; CHECK-NEXT: ret <2 x i32> [[X:%.*]] +; + ; %Z = sub i32 %X, %Y, vp(%mask, %evl) + %Z = call <2 x i32> @llvm.vp.sub.v2i32(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) + %Q = add <2 x i32> %Z, %Y + ret <2 x i32> %Q +} + +define <2 x i32> @negated_operand(<2 x i32> %x, <2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @negated_operand( +; CHECK-NEXT: ret <2 x i32> zeroinitializer +; + ; %negx = sub i32 0, %x + %negx = call <2 x i32> @llvm.vp.sub.v2i32(<2 x i32> zeroinitializer, <2 x i32> %x, <2 x i1> %mask, i32 %evl) + ; %r = add i32 %negx, %x + %r = call <2 x i32> @llvm.vp.add.v2i32(<2 x i32> %negx, <2 x i32> %x, <2 x i1> %mask, i32 %evl) + ret <2 x i32> %r +} + +; TODO: Lift InstSimplify::SimplifyAdd to the trait framework to optimize this. +define <2 x i8> @knownnegation(<2 x i8> %x, <2 x i8> %y, <2 x i1> %mask, i32 %evl) { +; TODO-CHECK-LABEL: @knownnegation( +; TODO-CHECK-NEXT: ret <2 x i8> zeroinitializer +; + ; %xy = sub i8 %x, %y + %xy = call <2 x i8> @llvm.vp.sub.v2i8(<2 x i8> %x, <2 x i8> %y, <2 x i1> %mask, i32 %evl) + ; %yx = sub i8 %y, %x + %yx = call <2 x i8> @llvm.vp.sub.v2i8(<2 x i8> %y, <2 x i8> %x, <2 x i1> %mask, i32 %evl) + ; %r = add i8 %xy, %yx + %r = call <2 x i8> @llvm.vp.add.v2i8(<2 x i8> %xy, <2 x i8> %yx, <2 x i1> %mask, i32 %evl) + ret <2 x i8> %r +} diff --git a/llvm/test/Transforms/InstSimplify/fpadd_constrained.ll b/llvm/test/Transforms/InstSimplify/fpadd_constrained.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstSimplify/fpadd_constrained.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -instsimplify -S | FileCheck %s + +declare float @llvm.experimental.constrained.fadd.f32(float, float, metadata, metadata) +declare <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float>, <2 x float>, metadata, metadata) +declare float @llvm.experimental.constrained.fsub.f32(float, float, metadata, metadata) + +; fadd X, -0 ==> X +define float @fadd_x_n0(float %a) { +; CHECK-LABEL: @fadd_x_n0( +; CHECK-NEXT: ret float [[A:%.*]] +; + %ret = call float @llvm.experimental.constrained.fadd.f32(float %a, float -0.0, + metadata !"round.tonearest", + metadata !"fpexcept.ignore") #0 + ret float %ret +} + +define <2 x float> @fadd_x_n0_vec_undef_elt(<2 x float> %a) { +; CHECK-LABEL: @fadd_x_n0_vec_undef_elt( +; CHECK-NEXT: ret <2 x float> %a +; + %ret = call <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> %a, <2 x float> , + metadata !"round.tonearest", + metadata !"fpexcept.ignore") #0 + ret <2 x float> %ret +} + +; We can't optimize away the fadd in this test because the input +; value to the function and subsequently to the fadd may be -0.0. +; In that one special case, the result of the fadd should be +0.0 +; rather than the first parameter of the fadd.
+ +; Fragile test: we need six sqrtf calls to trigger the bug because the +; underlying analysis has a hard-coded recursion limit of 6. + +declare float @sqrtf(float) + +define float @PR22688(float %x) { +; CHECK-LABEL: @PR22688( +; CHECK-NEXT: [[TMP1:%.*]] = call float @sqrtf(float [[X:%.*]]) +; CHECK-NEXT: [[TMP2:%.*]] = call float @sqrtf(float [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = call float @sqrtf(float [[TMP2]]) +; CHECK-NEXT: [[TMP4:%.*]] = call float @sqrtf(float [[TMP3]]) +; CHECK-NEXT: [[TMP5:%.*]] = call float @sqrtf(float [[TMP4]]) +; CHECK-NEXT: [[TMP6:%.*]] = call float @sqrtf(float [[TMP5]]) +; CHECK-NEXT: [[TMP7:%.*]] = call float @llvm.experimental.constrained.fadd.f32(float [[TMP6]], float 0.000000e+00, +; CHECK-NEXT: ret float [[TMP7]] +; + %1 = call float @sqrtf(float %x) + %2 = call float @sqrtf(float %1) + %3 = call float @sqrtf(float %2) + %4 = call float @sqrtf(float %3) + %5 = call float @sqrtf(float %4) + %6 = call float @sqrtf(float %5) + %7 = call float @llvm.experimental.constrained.fadd.f32(float %6, float 0.0, + metadata !"round.tonearest", + metadata !"fpexcept.ignore") #0 + ret float %7 +} + +attributes #0 = { strictfp nounwind readnone }
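+
+; Suggested extra coverage (a minimal sketch with hand-written CHECK lines, not
+; autogenerated): under the trait framework above, a constrained fadd that uses
+; "fpexcept.strict" should be rejected by the CFPTrait matcher context, so the
+; call is expected to stay unsimplified even though it just adds -0.0.
+; The function name @fadd_x_n0_strict is illustrative only.
+define float @fadd_x_n0_strict(float %a) {
+; CHECK-LABEL: @fadd_x_n0_strict(
+; CHECK: call float @llvm.experimental.constrained.fadd.f32(
+; CHECK: ret float
+;
+  %ret = call float @llvm.experimental.constrained.fadd.f32(float %a, float -0.0,
+                                                            metadata !"round.tonearest",
+                                                            metadata !"fpexcept.strict") #0
+  ret float %ret
+}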