diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -35,7 +35,10 @@ #ifndef LLVM_ANALYSIS_INSTRUCTIONSIMPLIFY_H #define LLVM_ANALYSIS_INSTRUCTIONSIMPLIFY_H +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Traits/SemanticTrait.h" namespace llvm { @@ -60,6 +63,7 @@ /// InstrInfoQuery provides an interface to query additional information for /// instructions like metadata or keywords like nsw, which provides conservative /// results if the users specified it is safe to use. +/// FIXME: Incorporate this into trait framework. struct InstrInfoQuery { InstrInfoQuery(bool UMD) : UseInstrInfo(UMD) {} InstrInfoQuery() = default; @@ -83,9 +87,9 @@ return false; } - bool isExact(const BinaryOperator *Op) const { - if (UseInstrInfo && isa(Op)) - return cast(Op)->isExact(); + bool isExact(const Instruction *I) const { + if (UseInstrInfo && isa(I)) + return cast(I)->isExact(); return false; } }; @@ -147,20 +151,41 @@ Value *simplifyFNegInst(Value *Op, FastMathFlags FMF, const SimplifyQuery &Q); /// Given operands for an Add, fold the result or return null. +template Value *simplifyAddInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW, - const SimplifyQuery &Q); + const SimplifyQuery &Q, MatcherContext &Matcher); + +inline Value *simplifyAddInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW, + const SimplifyQuery &Q) { + MatcherContext Matcher; + return simplifyAddInst(LHS, RHS, isNSW, isNUW, Q, Matcher); +} /// Given operands for a Sub, fold the result or return null. Value *simplifySubInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW, const SimplifyQuery &Q); -/// Given operands for an FAdd, fold the result or return null. +/// Given operands for an FAdd, and a matcher context that was initialized for +/// the actual instruction, fold the result or return null. +template Value * simplifyFAddInst(Value *LHS, Value *RHS, FastMathFlags FMF, - const SimplifyQuery &Q, + const SimplifyQuery &Q, MatcherContext &Matcher, fp::ExceptionBehavior ExBehavior = fp::ebIgnore, RoundingMode Rounding = RoundingMode::NearestTiesToEven); +/// Given operands for an FAdd, fold the result or return null. +/// We don't have any information about the traits of the 'fadd' so run this +/// with the unassuming default trait. +inline Value * +simplifyFAddInst(Value *LHS, Value *RHS, FastMathFlags FMF, + const SimplifyQuery &Q, + fp::ExceptionBehavior ExBehavior = fp::ebIgnore, + RoundingMode Rounding = RoundingMode::NearestTiesToEven) { + MatcherContext Matcher; + return simplifyFAddInst(LHS, RHS, FMF, Q, Matcher, ExBehavior, Rounding); +} + /// Given operands for an FSub, fold the result or return null. Value * simplifyFSubInst(Value *LHS, Value *RHS, FastMathFlags FMF, @@ -290,13 +315,27 @@ const SimplifyQuery &Q); /// Given operands for a BinaryOperator, fold the result or return null. +template Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const SimplifyQuery &Q); + const SimplifyQuery &Q, MatcherContext &Matcher); + +inline Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, + const SimplifyQuery &Q) { + MatcherContext Matcher; + return simplifyBinOp<>(Opcode, LHS, RHS, Q, Matcher); +} /// Given operands for a BinaryOperator, fold the result or return null. /// Try to use FastMathFlags when folding the result. 
+template Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, FastMathFlags FMF, - const SimplifyQuery &Q); + const SimplifyQuery &Q, MatcherContext &Matcher); + +inline Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, + FastMathFlags FMF, const SimplifyQuery &Q) { + MatcherContext Matcher; + return simplifyBinOp(Opcode, LHS, RHS, FMF, Q, Matcher); +} /// Given a callsite, fold the result or return null. Value *simplifyCall(CallBase *Call, const SimplifyQuery &Q); @@ -316,6 +355,11 @@ /// See if we can compute a simplified version of this instruction. If not, /// return null. +template +Value *simplifyInstructionWithOperandsAndTrait( + Instruction *I, ArrayRef NewOps, const SimplifyQuery &Q, + OptimizationRemarkEmitter *ORE = nullptr); + Value *simplifyInstruction(Instruction *I, const SimplifyQuery &Q, OptimizationRemarkEmitter *ORE = nullptr); diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h --- a/llvm/include/llvm/IR/IntrinsicInst.h +++ b/llvm/include/llvm/IR/IntrinsicInst.h @@ -446,6 +446,21 @@ Optional getFunctionalOpcode() const { return getFunctionalOpcodeForVP(getIntrinsicID()); } + bool isFunctionalCommutative() const { + if (auto OpcodeOpt = getFunctionalOpcode()) + return Instruction::isCommutative(*OpcodeOpt); + return false; + } + bool isFunctionalUnaryOp() const { + if (auto OpcodeOpt = getFunctionalOpcode()) + return Instruction::isUnaryOp(*OpcodeOpt); + return false; + } + bool isFunctionalBinaryOp() const { + if (auto OpcodeOpt = getFunctionalOpcode()) + return Instruction::isBinaryOp(*OpcodeOpt); + return false; + } // Equivalent non-predicated opcode static Optional getFunctionalOpcodeForVP(Intrinsic::ID ID); @@ -510,10 +525,28 @@ public: bool isUnaryOp() const; bool isTernaryOp() const; + bool hasRoundingMode() const; Optional getRoundingMode() const; Optional getExceptionBehavior() const; bool isDefaultFPEnvironment() const; + Optional getFunctionalOpcode() const; + bool isFunctionalCommutative() const { + if (auto OpcOpt = getFunctionalOpcode()) + return Instruction::isCommutative(*OpcOpt); + return false; + } + bool isFunctionalUnaryOp() const { + if (auto OpcOpt = getFunctionalOpcode()) + return Instruction::isUnaryOp(*OpcOpt); + return false; + } + bool isFunctionalBinaryOp() const { + if (auto OpcOpt = getFunctionalOpcode()) + return Instruction::isBinaryOp(*OpcOpt); + return false; + } + // Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const IntrinsicInst *I); static bool classof(const Value *V) { diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -39,17 +39,40 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/Traits/SemanticTrait.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" + #include namespace llvm { namespace PatternMatch { +// Trait-match pattern in a given context and update the context. template bool match(Val *V, const Pattern &P) { + // TODO: Use this instead of the Trait-less Pattern::match() functions. This + // single function does the same as the Trait-less 'Pattern::match()' + // functions that are replicated once for every Pattern. return const_cast(P).match(V); } +// Trait-match pattern in a given context and update the context. 
+template +bool match(Val *V, const Pattern &P, MatcherContext &MContext) { + return const_cast(P).match(V, MContext); +} + +// Trait-match pattern and update the context on match. +template +bool try_match(Val *V, const Pattern &P, MatcherContext &MContext) { + MatcherContext CopyCtx(MContext); + if (const_cast(P).match(V, CopyCtx)) { + MContext = CopyCtx; + return true; + } + return false; +} + template bool match(ArrayRef Mask, const Pattern &P) { return const_cast(P).match(Mask); } @@ -59,8 +82,14 @@ OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {} - template bool match(OpTy *V) { - return V->hasOneUse() && SubPattern.match(V); + template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(OpTy *V, MatcherContext &MContext) { + return V->hasOneUse() && SubPattern.match(V, MContext); } }; @@ -69,7 +98,16 @@ } template struct class_match { - template bool match(ITy *V) { return isa(V); } + + template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + return trait_isa(V); + } }; /// Match an arbitrary value and ignore it. @@ -128,6 +166,11 @@ return true; } template bool match(ITy *V) { return check(V); } + template + bool match(ITy *V, MatcherContext &MContext) { + // FIXME: Ok, to ignore the context here? + return check(V); + } }; /// Match an arbitrary undef constant. This matches poison as well. @@ -169,7 +212,15 @@ match_unless(const Ty &Matcher) : M(Matcher) {} - template bool match(ITy *V) { return !M.match(V); } + template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + return !M.match(V, MContext); + } }; /// Match if the inner matcher does *NOT* match. 
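[Editorial note, not part of the patch] A minimal sketch of how the try_match()/MatcherContext entry points introduced above are meant to be used by a client. The helper name simplifySubOfAdd is hypothetical; the headers assumed are the ones touched by this patch.

// Sketch only: (X + Op1) - Op1 --> X, threaded through a trait context.
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Traits/SemanticTrait.h"
using namespace llvm;
using namespace llvm::PatternMatch;

template <typename Trait>
static Value *simplifySubOfAdd(Value *Op0, Value *Op1, MatcherContext<Trait> &MC) {
  Value *X;
  // try_match() runs the sub-pattern on a copy of MC and commits the copy back
  // only on success, so a failed speculative match leaves MC unchanged.
  if (try_match(Op0, m_Add(m_Value(X), m_Specific(Op1)), MC))
    return X;
  return nullptr;
}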
@@ -185,11 +236,16 @@ match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} template bool match(ITy *V) { - if (L.match(V)) - return true; - if (R.match(V)) + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (try_match(V, L, MContext)) { return true; - return false; + } + return R.match(V, MContext); } }; @@ -200,10 +256,13 @@ match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} template bool match(ITy *V) { - if (L.match(V)) - if (R.match(V)) - return true; - return false; + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + return L.match(V, MContext) && R.match(V, MContext); } }; @@ -227,6 +286,12 @@ : Res(Res), AllowUndef(AllowUndef) {} template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValue(); return true; @@ -252,6 +317,12 @@ : Res(Res), AllowUndef(AllowUndef) {} template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValueAPF(); return true; @@ -303,6 +374,12 @@ template struct constantint_match { template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (const auto *CI = dyn_cast(V)) { const APInt &CIV = CI->getValue(); if (Val >= 0) @@ -321,14 +398,20 @@ return constantint_match(); } -/// This helper class is used to match constant scalars, vector splats, -/// and fixed width vectors that satisfy a specified predicate. -/// For fixed width vector constants, undefined elements are ignored. +/// This helper class is used to match scalar and fixed width vector integer +/// constants that satisfy a specified predicate. +/// For vector constants, undefined elements are ignored. 
template struct cstval_pred_ty : public Predicate { template bool match(ITy *V) { - if (const auto *CV = dyn_cast(V)) - return this->isValue(CV->getValue()); + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (const auto *CI = dyn_cast(V)) + return this->isValue(CI->getValue()); if (const auto *VTy = dyn_cast(V->getType())) { if (const auto *C = dyn_cast(V)) { if (const auto *CV = dyn_cast_or_null(C->getSplatValue())) @@ -377,6 +460,12 @@ api_pred_ty(const APInt *&R) : Res(R) {} template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (const auto *CI = dyn_cast(V)) if (this->isValue(CI->getValue())) { Res = &CI->getValue(); @@ -403,6 +492,12 @@ apf_pred_ty(const APFloat *&R) : Res(R) {} template bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (const auto *CI = dyn_cast(V)) if (this->isValue(CI->getValue())) { Res = &CI->getValue(); @@ -521,6 +616,12 @@ struct is_zero { template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { auto *C = dyn_cast(V); // FIXME: this should be able to do something for scalable vectors return C && (C->isNullValue() || cst_pred_ty().match(C)); @@ -694,6 +795,13 @@ bind_ty(Class *&V) : VR(V) {} template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + template + bool match(ITy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; if (auto *CV = dyn_cast(V)) { VR = CV; return true; @@ -759,7 +867,15 @@ specificval_ty(const Value *V) : Val(V) {} - template bool match(ITy *V) { return V == Val; } + template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + return V == Val; + } }; /// Match if we have a specific specified value. 
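[Editorial note, not part of the patch] The MContext.accept() gate that the binding matchers above now run through is what lets a trait veto individual values. A sketch under the assumption that the CFPTrait and MatcherContext<CFPTrait> specialization defined later in this patch are available; the helper name is hypothetical.

// Sketch only: match a (possibly constrained) fadd whose FP environment is the
// default one. A constrained intrinsic with a non-default rounding mode or
// exception behavior fails the accept() check, and with it the whole pattern.
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Traits/Traits.h"
using namespace llvm;
using namespace llvm::PatternMatch;

static bool isDefaultEnvFAdd(Value *V) {
  Value *X, *Y;
  MatcherContext<CFPTrait> MC;
  return match(V, m_FAdd(m_Value(X), m_Value(Y)), MC);
}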
@@ -772,7 +888,15 @@ deferredval_ty(Class *const &V) : Val(V) {} - template bool match(ITy *const V) { return V == Val; } + template bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *const V, MatcherContext &MContext) { + return V == Val; + } }; /// Like m_Specific(), but works if the specific value to match is determined @@ -794,6 +918,12 @@ specific_fpval(double V) : Val(V) {} template bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { if (const auto *CFP = dyn_cast(V)) return CFP->isExactlyValue(Val); if (V->getType()->isVectorTy()) @@ -817,6 +947,14 @@ bind_const_intval_ty(uint64_t &V) : VR(V) {} template bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; if (const auto *CV = dyn_cast(V)) if (CV->getValue().ule(UINT64_MAX)) { VR = CV->getZExtValue(); @@ -834,6 +972,14 @@ specific_intval(APInt V) : Val(std::move(V)) {} template bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; const auto *CI = dyn_cast(V); if (!CI && V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) @@ -872,6 +1018,14 @@ specific_bbval(BasicBlock *Val) : Val(Val) {} template bool match(ITy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(ITy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; const auto *BB = dyn_cast(V); return BB && BB == Val; } @@ -891,6 +1045,18 @@ return BB; } +template +static inline bool commutable_match(bool Commutable, LHS_t &L, RHS_t &R, + OpValueTy OpL, OpValueTy OpR, + MatcherContext &MContext) { + MatcherContext LRContext(MContext); + if (L.match(OpL, LRContext) && R.match(OpR, LRContext)) { + MContext = LRContext; + return true; + } + return Commutable && L.match(OpR, MContext) && R.match(OpL, MContext); +} + //===----------------------------------------------------------------------===// // Matcher for any binary operator. // @@ -904,11 +1070,21 @@ AnyBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); - return false; + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &MContext) { + auto *I = trait_dyn_cast(V); + if (!I) + return false; + + if (!MContext.accept(I)) + return false; + + return commutable_match<>(Commutable, L, R, I->getOperand(0), + I->getOperand(1), MContext); } }; @@ -927,8 +1103,14 @@ AnyUnaryOp_match(const OP_t &X) : X(X) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return X.match(I->getOperand(0)); + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &MContext) { + if (auto *I = trait_dyn_cast(V)) + return X.match(I->getOperand(0), MContext); return false; } }; @@ -951,22 +1133,36 @@ // The LHS is always matched first. 
BinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template inline bool match(unsigned Opc, OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opc) { - auto *I = cast(V); - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template bool match(OpTy *V) { + MatcherContext Matcher; + return match(Opcode, V, Matcher); + } + + template bool match(unsigned Opc, OpTy *V) { + MatcherContext Matcher; + return match(Opc, V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &Matcher) { + return match(Opcode, V, Matcher); + } + + template + bool match(unsigned Opc, OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *I = trait_dyn_cast(V); + if (I && I->getOpcode() == Opc) { + return commutable_match(Commutable, L, R, I->getOperand(0), + I->getOperand(1), MContext); } if (auto *CE = dyn_cast(V)) return CE->getOpcode() == Opc && - ((L.match(CE->getOperand(0)) && R.match(CE->getOperand(1))) || - (Commutable && L.match(CE->getOperand(1)) && - R.match(CE->getOperand(0)))); + commutable_match(Commutable, L, R, CE->getOperand(0), + CE->getOperand(1), MContext); return false; } - - template bool match(OpTy *V) { return match(Opcode, V); } }; template @@ -998,25 +1194,37 @@ FNeg_match(const Op_t &Op) : X(Op) {} template bool match(OpTy *V) { - auto *FPMO = dyn_cast(V); + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *FPMO = trait_dyn_cast(V); if (!FPMO) return false; - if (FPMO->getOpcode() == Instruction::FNeg) - return X.match(FPMO->getOperand(0)); + auto OPC = trait_cast(V)->getOpcode(); - if (FPMO->getOpcode() == Instruction::FSub) { + if (OPC == Instruction::FNeg) + return X.match(FPMO->getOperand(0), MContext); + + if (OPC == Instruction::FSub) { if (FPMO->hasNoSignedZeros()) { // With 'nsz', any zero goes. - if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!try_match(FPMO->getOperand(0), cstfp_pred_ty(), + MContext)) return false; } else { // Without 'nsz', we need fsub -0.0, X exactly. 
- if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!try_match(FPMO->getOperand(0), cstfp_pred_ty(), + MContext)) return false; } - return X.match(FPMO->getOperand(1)); + return X.match(FPMO->getOperand(1), MContext); } return false; @@ -1129,7 +1337,14 @@ : L(LHS), R(RHS) {} template bool match(OpTy *V) { - if (auto *Op = dyn_cast(V)) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *Op = trait_dyn_cast(V)) { if (Op->getOpcode() != Opcode) return false; if ((WrapFlags & OverflowingBinaryOperator::NoUnsignedWrap) && @@ -1138,7 +1353,8 @@ if ((WrapFlags & OverflowingBinaryOperator::NoSignedWrap) && !Op->hasNoSignedWrap()) return false; - return L.match(Op->getOperand(0)) && R.match(Op->getOperand(1)); + return L.match(Op->getOperand(0), MContext) && + R.match(Op->getOperand(1), MContext); } return false; } @@ -1219,7 +1435,15 @@ : BinaryOp_match(LHS, RHS), Opcode(Opcode) {} template bool match(OpTy *V) { - return BinaryOp_match::match(Opcode, V); + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &MContext) { + return BinaryOp_match::template match(Opcode, V, + MContext); } }; @@ -1241,10 +1465,18 @@ BinOpPred_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return this->isOpType(I->getOpcode()) && L.match(I->getOperand(0)) && - R.match(I->getOperand(1)); - if (auto *CE = dyn_cast(V)) + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *I = trait_dyn_cast(V)) + return this->isOpType(I->getOpcode()) && + L.match(I->getOperand(0), MContext) && + R.match(I->getOperand(1), MContext); + if (auto *CE = trait_dyn_cast(V)) return this->isOpType(CE->getOpcode()) && L.match(CE->getOperand(0)) && R.match(CE->getOperand(1)); return false; @@ -1336,8 +1568,15 @@ Exact_match(const SubPattern_t &SP) : SubPattern(SP) {} template bool match(OpTy *V) { - if (auto *PEO = dyn_cast(V)) - return PEO->isExact() && SubPattern.match(V); + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *PEO = trait_dyn_cast(V)) + return PEO->isExact() && SubPattern.match(V, MContext); return false; } }; @@ -1363,12 +1602,27 @@ : Predicate(Pred), L(LHS), R(RHS) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *I = trait_dyn_cast(V)) { + MatcherContext LRContext(MContext); + if (L.match(I->getOperand(0), LRContext) && + R.match(I->getOperand(1), LRContext)) { + MContext = LRContext; Predicate = I->getPredicate(); return true; - } else if (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))) { + } + + if (!Commutable) + return false; + + if (L.match(I->getOperand(1), MContext) && + R.match(I->getOperand(0), MContext)) { Predicate = I->getSwappedPredicate(); return true; } @@ -1406,9 +1660,16 @@ OneOps_match(const T0 &Op1) : Op1(Op1) {} template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return 
Op1.match(I->getOperand(0)); + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *I = trait_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + return Op1.match(I->getOperand(0), MContext); } return false; } @@ -1422,9 +1683,17 @@ TwoOps_match(const T0 &Op1, const T1 &Op2) : Op1(Op1), Op2(Op2) {} template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)); + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *I = trait_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + return Op1.match(I->getOperand(0), MContext) && + Op2.match(I->getOperand(1), MContext); } return false; } @@ -1441,10 +1710,18 @@ : Op1(Op1), Op2(Op2), Op3(Op3) {} template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && - Op3.match(I->getOperand(2)); + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + auto *I = trait_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + return Op1.match(I->getOperand(0), MContext) && + Op2.match(I->getOperand(1), MContext) && + Op3.match(I->getOperand(2), MContext); } return false; } @@ -1497,8 +1774,16 @@ : Op1(Op1), Op2(Op2), Mask(Mask) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *I = trait_dyn_cast(V)) { + return Op1.match(I->getOperand(0), MContext) && + Op2.match(I->getOperand(1), MContext) && Mask.match(I->getShuffleMask()); } return false; @@ -1576,8 +1861,15 @@ CastClass_match(const Op_t &OpMatch) : Op(OpMatch) {} template bool match(OpTy *V) { - if (auto *O = dyn_cast(V)) - return O->getOpcode() == Opcode && Op.match(O->getOperand(0)); + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &Matcher) { + if (!Matcher.accept(V)) + return false; + if (auto O = trait_dyn_cast(V)) + return O->getOpcode() == Opcode && Op.match(O->getOperand(0), Matcher); return false; } }; @@ -1692,7 +1984,14 @@ br_match(BasicBlock *&Succ) : Succ(Succ) {} template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *BI = trait_dyn_cast(V)) if (BI->isUnconditional()) { Succ = BI->getSuccessor(0); return true; @@ -1713,9 +2012,16 @@ : Cond(C), T(t), F(f) {} template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) - if (BI->isConditional() && Cond.match(BI->getCondition())) - return T.match(BI->getSuccessor(0)) && F.match(BI->getSuccessor(1)); + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (auto *BI = trait_dyn_cast(V)) + if (BI->isConditional() && Cond.match(BI->getCondition(), MContext)) { + return T.match(BI->getSuccessor(0), MContext) && + 
F.match(BI->getSuccessor(1), MContext); + } return false; } }; @@ -1749,23 +2055,34 @@ MaxMin_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} template bool match(OpTy *V) { - if (auto *II = dyn_cast(V)) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (!MContext.accept(V)) + return false; + if (auto *II = trait_dyn_cast(V)) { Intrinsic::ID IID = II->getIntrinsicID(); if ((IID == Intrinsic::smax && Pred_t::match(ICmpInst::ICMP_SGT)) || (IID == Intrinsic::smin && Pred_t::match(ICmpInst::ICMP_SLT)) || (IID == Intrinsic::umax && Pred_t::match(ICmpInst::ICMP_UGT)) || (IID == Intrinsic::umin && Pred_t::match(ICmpInst::ICMP_ULT))) { Value *LHS = II->getOperand(0), *RHS = II->getOperand(1); - return (L.match(LHS) && R.match(RHS)) || - (Commutable && L.match(RHS) && R.match(LHS)); + MatcherContext LRContext(MContext); + if (L.match(LHS, LRContext) && R.match(RHS, LRContext)) { + MContext = LRContext; + return true; + } + return Commutable && L.match(RHS, MContext) && R.match(LHS, MContext); } } // Look for "(x pred y) ? x : y" or "(x pred y) ? y : x". - auto *SI = dyn_cast(V); - if (!SI) + auto *SI = trait_dyn_cast(V); + if (!SI || !MContext.accept(SI)) return false; - auto *Cmp = dyn_cast(SI->getCondition()); - if (!Cmp) + auto *Cmp = trait_dyn_cast(SI->getCondition()); + if (!Cmp || !MContext.accept(Cmp)) return false; // At this point we have a select conditioned on a comparison. Check that // it is the values returned by the select that are being compared. @@ -1781,9 +2098,10 @@ // Does "(x pred y) ? x : y" represent the desired max/min operation? if (!Pred_t::match(Pred)) return false; + // It does! Bind the operands. - return (L.match(LHS) && R.match(RHS)) || - (Commutable && L.match(RHS) && R.match(LHS)); + // TODO factor out commutative matching! 
+ return commutable_match(Commutable, L, R, LHS, RHS, MContext); } }; @@ -1953,49 +2271,77 @@ : L(L), R(R), S(S) {} template bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { Value *ICmpLHS, *ICmpRHS; ICmpInst::Predicate Pred; - if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V)) + if (!try_match(V, m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)), + MContext)) return false; Value *AddLHS, *AddRHS; auto AddExpr = m_Add(m_Value(AddLHS), m_Value(AddRHS)); // (a + b) u< a, (a + b) u< b - if (Pred == ICmpInst::ICMP_ULT) - if (AddExpr.match(ICmpLHS) && (ICmpRHS == AddLHS || ICmpRHS == AddRHS)) - return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpLHS); + if (Pred == ICmpInst::ICMP_ULT) { + if (try_match(ICmpLHS, AddExpr, MContext) && + (ICmpRHS == AddLHS || ICmpRHS == AddRHS)) { + return L.match(AddLHS, MContext) && R.match(AddRHS, MContext) && + S.match(ICmpLHS, MContext); + } + } // a >u (a + b), b >u (a + b) - if (Pred == ICmpInst::ICMP_UGT) - if (AddExpr.match(ICmpRHS) && (ICmpLHS == AddLHS || ICmpLHS == AddRHS)) - return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpRHS); + if (Pred == ICmpInst::ICMP_UGT) { + if (try_match(ICmpRHS, AddExpr, MContext) && + (ICmpLHS == AddLHS || ICmpLHS == AddRHS)) { + return L.match(AddLHS, MContext) && R.match(AddRHS, MContext) && + S.match(ICmpRHS, MContext); + } + } Value *Op1; auto XorExpr = m_OneUse(m_Xor(m_Value(Op1), m_AllOnes())); // (a ^ -1) u (a ^ -1) if (Pred == ICmpInst::ICMP_UGT) { - if (XorExpr.match(ICmpRHS)) - return L.match(Op1) && R.match(ICmpLHS) && S.match(ICmpRHS); + if (try_match(ICmpRHS, XorExpr, MContext)) { + return L.match(Op1, MContext) && R.match(ICmpLHS, MContext) && + S.match(ICmpRHS, MContext); + } } // Match special-case for increment-by-1. if (Pred == ICmpInst::ICMP_EQ) { // (a + 1) == 0 // (1 + a) == 0 - if (AddExpr.match(ICmpLHS) && m_ZeroInt().match(ICmpRHS) && - (m_One().match(AddLHS) || m_One().match(AddRHS))) - return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpLHS); + MatcherContext CopyCtx(MContext); + if (AddExpr.match(ICmpLHS, CopyCtx) && + m_ZeroInt().match(ICmpRHS, CopyCtx)) { + if (try_match(AddLHS, m_One(), CopyCtx) || + try_match(AddRHS, m_One(), CopyCtx)) { + MContext = CopyCtx; + return L.match(AddLHS, MContext) && R.match(AddRHS, MContext) && + S.match(ICmpLHS, MContext); + } + } // 0 == (a + 1) // 0 == (1 + a) - if (m_ZeroInt().match(ICmpLHS) && AddExpr.match(ICmpRHS) && - (m_One().match(AddLHS) || m_One().match(AddRHS))) - return L.match(AddLHS) && R.match(AddRHS) && S.match(ICmpRHS); + if (m_ZeroInt().match(ICmpLHS, MContext) && + AddExpr.match(ICmpRHS, MContext) && + (try_match(AddLHS, m_One(), MContext) || + m_One().match(AddRHS, MContext))) + return L.match(AddLHS, MContext) && R.match(AddRHS, MContext) && + S.match(ICmpRHS, MContext); } return false; @@ -2019,9 +2365,14 @@ Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) {} template bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { // FIXME: Should likely be switched to use `CallBase`. 
- if (const auto *CI = dyn_cast(V)) - return Val.match(CI->getArgOperand(OpI)); + if (const auto *CI = trait_dyn_cast(V)) + return Val.match(CI->getArgOperand(OpI), MContext); return false; } }; @@ -2039,7 +2390,12 @@ IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) {} template bool match(OpTy *V) { - if (const auto *CI = dyn_cast(V)) + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (const auto *CI = trait_dyn_cast(V)) if (const auto *F = CI->getCalledFunction()) return F->getIntrinsicID() == ID; return false; @@ -2289,15 +2645,23 @@ NotForbidUndef_match(const ValTy &V) : Val(V) {} template bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &Matcher) { // We do not use m_c_Xor because that could match an arbitrary APInt that is // not -1 as C and then fail to match the other operand if it is -1. // This code should still work even when both operands are constants. Value *X; const APInt *C; - if (m_Xor(m_Value(X), m_APIntForbidUndef(C)).match(V) && C->isAllOnes()) - return Val.match(X); - if (m_Xor(m_APIntForbidUndef(C), m_Value(X)).match(V) && C->isAllOnes()) - return Val.match(X); + if (try_match(V, m_Xor(m_Value(X), m_APIntForbidUndef(C)), Matcher) && + C->isAllOnes()) + return Val.match(X, Matcher); + if (try_match(V, m_Xor(m_APIntForbidUndef(C), m_Value(X)), Matcher) && + C->isAllOnes()) + return Val.match(X, Matcher); return false; } }; @@ -2364,6 +2728,11 @@ Signum_match(const Opnd_t &V) : Val(V) {} template bool match(OpTy *V) { + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { unsigned TypeSize = V->getType()->getScalarSizeInBits(); if (TypeSize == 0) return false; @@ -2385,7 +2754,7 @@ auto RHS = m_LShr(m_Neg(m_Value(OpR)), m_SpecificInt(ShiftWidth)); auto Signum = m_Or(LHS, RHS); - return Signum.match(V) && OpL == OpR && Val.match(OpL); + return Signum.match(V, MContext) && OpL == OpR && Val.match(OpL, MContext); } }; @@ -2404,13 +2773,14 @@ ExtractValue_match(const Opnd_t &V) : Val(V) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - // If Ind is -1, don't inspect indices - if (Ind != -1 && - !(I->getNumIndices() == 1 && I->getIndices()[0] == (unsigned)Ind)) - return false; - return Val.match(I->getAggregateOperand()); - } + MatcherContext Matcher; + return match(V, Matcher); + } + template + bool match(OpTy *V, MatcherContext &MContext) { + if (auto *I = trait_dyn_cast(V)) + return I->getNumIndices() == 1 && I->getIndices()[0] == Ind && + Val.match(I->getAggregateOperand(), MContext); return false; } }; @@ -2437,9 +2807,16 @@ InsertValue_match(const T0 &Op0, const T1 &Op1) : Op0(Op0), Op1(Op1) {} template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) { - return Op0.match(I->getOperand(0)) && Op1.match(I->getOperand(1)) && - I->getNumIndices() == 1 && Ind == I->getIndices()[0]; + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(OpTy *V, MatcherContext &Matcher) { + if (auto *I = trait_dyn_cast(V)) { + return I->getNumIndices() == 1 && Ind == I->getIndices()[0] && + Op0.match(I->getOperand(0), Matcher) && + Op1.match(I->getOperand(1), Matcher); } return false; } @@ -2461,12 +2838,17 @@ VScaleVal_match(const DataLayout &DL) : DL(DL) {} template bool match(ITy *V) { + MatcherContext MContext; + return match(V, MContext); + } + template + bool match(ITy *V, MatcherContext &MContext) { if 
(m_Intrinsic().match(V)) return true; Value *Ptr; - if (m_PtrToInt(m_Value(Ptr)).match(V)) { - if (auto *GEP = dyn_cast(Ptr)) { + if (m_PtrToInt(m_Value(Ptr)).match(V, MContext)) { + if (auto *GEP = trait_dyn_cast(Ptr)) { auto *DerefTy = GEP->getSourceElementType(); if (GEP->getNumIndices() == 1 && isa(DerefTy) && m_Zero().match(GEP->getPointerOperand()) && @@ -2492,15 +2874,20 @@ LogicalOp_match(const LHS &L, const RHS &R) : L(L), R(R) {} template bool match(T *V) { - auto *I = dyn_cast(V); + MatcherContext Matcher; + return match(V, Matcher); + } + + template + bool match(T *V, MatcherContext &MContext) { + auto *I = trait_dyn_cast(V); if (!I || !I->getType()->isIntOrIntVectorTy(1)) return false; if (I->getOpcode() == Opcode) { auto *Op0 = I->getOperand(0); auto *Op1 = I->getOperand(1); - return (L.match(Op0) && R.match(Op1)) || - (Commutable && L.match(Op1) && R.match(Op0)); + return commutable_match(Commutable, L, R, Op0, Op1, MContext); } if (auto *Select = dyn_cast(I)) { @@ -2510,14 +2897,12 @@ if (Opcode == Instruction::And) { auto *C = dyn_cast(FVal); if (C && C->isNullValue()) - return (L.match(Cond) && R.match(TVal)) || - (Commutable && L.match(TVal) && R.match(Cond)); + return commutable_match(Commutable, L, R, Cond, TVal, MContext); } else { assert(Opcode == Instruction::Or); auto *C = dyn_cast(TVal); if (C && C->isOneValue()) - return (L.match(Cond) && R.match(FVal)) || - (Commutable && L.match(FVal) && R.match(Cond)); + return commutable_match(Commutable, L, R, Cond, FVal, MContext); } } diff --git a/llvm/include/llvm/IR/Traits/EnabledTraits.def b/llvm/include/llvm/IR/Traits/EnabledTraits.def new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/Traits/EnabledTraits.def @@ -0,0 +1,4 @@ +ENABLE_TRAIT(EmptyTrait) +ENABLE_TRAIT(CFPTrait) +ENABLE_TRAIT(VPTrait) +#undef ENABLE_TRAIT diff --git a/llvm/include/llvm/IR/Traits/SemanticTrait.h b/llvm/include/llvm/IR/Traits/SemanticTrait.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/Traits/SemanticTrait.h @@ -0,0 +1,149 @@ +//===- llvm/IR/Trait/SemanticTrait.h - Basic trait definitions --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines semantic traits. +// Some intrinsics in LLVM can be described as being a regular operation (such +// as an fadd) at their core with an additional semantic trait. We do this to +// lift optimizations that are defined in terms of standard IR operations (eg +// fadd, fmul) to these intrinsics. We keep the existing patterns and rewrite +// machinery and transparently check the rewrite is consistent with the +// semantical trait(s) that are attached to the operations. +// +// For example: +// +// @llvm.vp.fadd(<256 x double> %a, <256 x double> %b, +// <256 x i1> %m. i32 %avl) +// +// This is a vector-predicated fadd instruction with a %mask and %evl +// parameter +// (https://llvm.org/docs/LangRef.html#vector-predication-intrinsics). +// However, at its core it is just an 'fadd'. +// +// Consider the fma-fusion rewrite pattern (fadd (fmul x,y) z) --> (fma x, y, +// z). If the 'fadd' is actually an 'llvm.vp.fadd" and the 'fmul' is actually +// an 'llvm.vp.fmul', we can perform the rewrite using the %mask and %evl of +// the 'fadd' node. 
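// [Editorial illustration, not part of the patch] To make the fma-fusion
// example above concrete, the intended rewrite is sketched in the same
// abbreviated notation; '@llvm.vp.fma' is assumed to be the fused form:
//
//   %t = @llvm.vp.fmul(<256 x double> %x, <256 x double> %y, <256 x i1> %m, i32 %avl)
//   %r = @llvm.vp.fadd(<256 x double> %t, <256 x double> %z, <256 x i1> %m, i32 %avl)
//     -->
//   %r = @llvm.vp.fma(<256 x double> %x, <256 x double> %y, <256 x double> %z, <256 x i1> %m, i32 %avl)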
+//
+//
+//   @llvm.experimental.constrained.fadd(double %a, double %b,
+//                                       metadata metadata,
+//                                       )
+//
+// This is an fadd with a possibly non-default rounding mode and exception
+// behavior.
+// (https://llvm.org/docs/LangRef.html#constrained-floating-point-intrinsics).
+// In this case, the operation matches the semantics of a regular 'fadd'
+// exactly, if the rounding mode is 'round.tonearest' and the exception
+// behavior is 'fpexcept.ignore'.
+// Re-considering the case of fma fusion, this time with two constrained fp
+// intrinsics: if the rounding mode is tonearest for both and neither the
+// 'llvm.experimental.constrained.fmul' nor the 'fadd' has 'fpexcept.strict',
+// we are good to apply the rewrite and emit a constrained fma with the
+// exception flag of the 'fadd'.
+//
+// There is also a proposal to add complex arithmetic intrinsics to LLVM. In
+// that case, the operation is semantically an 'fadd', if we consider the space
+// of complex floating-point numbers and their operations.
+//
+//===----------------------------------------------------------------------===//
+
+// Look for comments starting with "TODO(new trait)" to see what to implement
+// to establish a new instruction trait.
+
+#ifndef LLVM_IR_TRAIT_SEMANTICTRAIT_H
+#define LLVM_IR_TRAIT_SEMANTICTRAIT_H
+
+#include
+#include
+#include
+#include
+
+namespace llvm {
+
+/// Type Casting {
+/// These cast operators allow you to side-step the first-class type hierarchy
+/// of LLVM (Value, Instruction, BinaryOperator, ..) into your custom type
+/// hierarchy.
+///
+///   trait_cast(V)
+///
+/// actually casts \p V to ExtInstruction.
+template struct TraitCast {
+  using ExtType = ValueDerivedType;
+};
+
+// This has to happen after all traits are defined since we are referring to
+// members of template specialization for each Trait (The TraitCast::ExtType).
+#define CASTING_TEMPLATE(CASTFUNC, PREFIXMOD, REFMOD) \
+  template \
+  static auto trait_##CASTFUNC(PREFIXMOD Value REFMOD V) \
+      ->decltype(CASTFUNC< \
+          typename TraitCast::ExtType>(V)) { \
+    using TraitExtendedType = \
+        typename TraitCast::ExtType; \
+    return CASTFUNC(V); \
+  }
+
+#define CONST_CAST_TEMPLATE(CASTFUNC, REFMOD) \
+  CASTING_TEMPLATE(CASTFUNC, const, REFMOD) \
+  CASTING_TEMPLATE(CASTFUNC, , REFMOD)
+
+// 'dyn_cast' (allow [const] Value*)
+CONST_CAST_TEMPLATE(dyn_cast, *)
+
+// 'cast' (allow [const] Value(*|&))
+CONST_CAST_TEMPLATE(cast, *)
+CONST_CAST_TEMPLATE(cast, &)
+
+// 'isa'
+CONST_CAST_TEMPLATE(isa, *)
+CONST_CAST_TEMPLATE(isa, &)
+/// } Type Casting
+
+// TODO
+// The trait builder is a specialized IRBuilder that emits trait-compatible
+// instructions.
+template struct TraitBuilder {};
+
+// This is used in pattern matching to check that all instructions in the
+// pattern are trait-compatible.
+template struct MatcherContext {
+  // Check whether \p Val is compatible with this context and merge its
+  // properties. \returns Whether \p Val is compatible with the current state
+  // of the context.
+  bool accept(const Value *Val) { return Val; }
+
+  // Like accept() but does not modify the context.
+  bool check(const Value *Val) const { return Val; }
+
+  // Whether to allow constant folding with the currently accepted operators
+  // and their operands.
+  bool allowConstantFolding() const { return true; }
+};
+
+/// Empty Trait {
+///
+/// This defines the empty trait without properties. Type casting stays in the
+/// standard llvm::Value type hierarchy.
+
+// Trait without any difference to standard IR
+struct EmptyTrait {
+  // This is to block reassociation for traits that do not support it.
+  static constexpr bool AllowReassociation = true;
+
+  // Whether \p V should be considered at all with this trait.
+  static bool consider(const Value *) { return true; }
+};
+
+using DefaultTrait = EmptyTrait;
+
+/// } Empty Trait
+
+} // end namespace llvm
+
+#endif // LLVM_IR_TRAIT_SEMANTICTRAIT_H
diff --git a/llvm/include/llvm/IR/Traits/Traits.h b/llvm/include/llvm/IR/Traits/Traits.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/IR/Traits/Traits.h
@@ -0,0 +1,331 @@
+//===- llvm/IR/Traits/Traits.h - Basic trait definitions --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines semantic traits.
+// Some intrinsics in LLVM can be described as being a regular operation (such
+// as an fadd) at their core with an additional semantic trait. We do this to
+// lift optimizations that are defined in terms of standard IR operations (eg
+// fadd, fmul) to these intrinsics. We keep the existing patterns and rewrite
+// machinery and transparently check the rewrite is consistent with the
+// semantic trait(s) that are attached to the operations.
+//
+// For example:
+//
+//   @llvm.vp.fadd(<256 x double> %a, <256 x double> %b,
+//                 <256 x i1> %m, i32 %avl)
+//
+// This is a vector-predicated fadd instruction with a %mask and %evl
+// parameter
+// (https://llvm.org/docs/LangRef.html#vector-predication-intrinsics).
+// However, at its core it is just an 'fadd'.
+//
+// Consider the simplification (add (sub x,y), y) --> x. If the 'add' is
+// actually an 'llvm.vp.add' and the 'sub' is really an 'llvm.vp.sub', we can
+// do the simplification.
+//
+//
+//   @llvm.experimental.constrained.fadd(double %a, double %b,
+//                                       metadata metadata,
+//                                       )
+//
+// This is an fadd with a possibly non-default rounding mode and exception
+// behavior.
+// (https://llvm.org/docs/LangRef.html#constrained-floating-point-intrinsics).
+// The constrained fp intrinsic has exactly the semantics of a regular 'fadd',
+// if the rounding mode is 'round.tonearest' and the exception behavior is
+// 'fpexcept.ignore'.
+// We can use all simplifying rewrites for regular fp arithmetic also for
+// constrained fp arithmetic where this applies.
+//
+// There is also a proposal to add complex arithmetic intrinsics to LLVM. In
+// that case, the operation is semantically an 'fadd', if we consider the space
+// of complex floating-point numbers and their operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Traits/SemanticTrait.h"
+
+#ifndef LLVM_IR_TRAITS_TRAITS_H
+#define LLVM_IR_TRAITS_TRAITS_H
+
+namespace llvm {
+
+/// Base classes {
+// Shared functionality for extended instructions
+// These define all functions used in PatternMatch as 'virtual' to remind you
+// to implement them.
+
+// Make sure no member of the Ext.. hierarchy can be constructed.
+struct ExtBase { + ExtBase() = delete; + ~ExtBase() = delete; + ExtBase &operator=(const ExtBase &) = delete; + void *operator new(size_t s) = delete; +}; + +// Mirror generic functionality of llvm::Instruction. +struct ExtInstructionBase : public ExtBase, public User { + BasicBlock *getParent() { return cast(this)->getParent(); } + const BasicBlock *getParent() const { + return cast(this)->getParent(); + } + + void copyIRFlags(const Value *V, bool IncludeWrapFlags) { + cast(this)->copyIRFlags(V, IncludeWrapFlags); + } +}; + +/// } Base classes + +/// Intrinsic-Trait Template { +// Generic instantiation for traits that want to masquerade their intrinsic as a +// regular IR instruction. + +// Pretend to be a llvm::Operator. +template struct ExtOperator : public ExtBase, public User { + unsigned getOpcode() const { + // Use the intrinsic override. + if (const auto *Intrin = dyn_cast(this)) + if (auto OpcOpt = Intrin->getFunctionalOpcode()) + return *OpcOpt; + + // Default Operator opcode. + return cast(this)->getOpcode(); + } + + bool hasNoSignedWrap() const { + if (const auto *OverflowingBOp = dyn_cast(this)) + return OverflowingBOp->hasNoSignedWrap(); + return false; + } + + bool hasNoUnsignedWrap() const { + if (const auto *OverflowingBOp = dyn_cast(this)) + return OverflowingBOp->hasNoUnsignedWrap(); + return false; + } + + // Every operator is also an extended operator. + static bool classof(const Value *V) { return isa(V); } +}; + +// Pretend to be a llvm::Instruction. +template +struct ExtInstruction final : public ExtInstructionBase { + unsigned getOpcode() const { + // Use the intrinsic override. + if (const auto *Intrin = dyn_cast(this)) + if (auto OpcOpt = Intrin->getFunctionalOpcode()) + return *OpcOpt; + // Default opcode. + return cast(this)->getOpcode(); + } + + static bool classof(const Value *V) { return isa(V); } +}; + +// Pretend to be a (different) llvm::IntrinsicInst. +template +struct ExtIntrinsic final : public ExtInstructionBase { + Intrinsic::ID getIntrinsicID() const { + // Use the intrinsic override. + if (const auto *TraitIntrin = dyn_cast(this)) + return TraitIntrin->getFunctionalIntrinsic(); + // Default intrinsic opcode. + return cast(this)->getIntrinsicID(); + } + + unsigned getOpcode() const { + // We are looking at this as an intrinsic -> do not hide this. + return cast(this)->getOpcode(); + } + + bool isCommutative() const { + // The underlying intrinsic may not specify whether it is commutative. + // Query our own interface to be sure this is done right. + // Use the intrinsic override. + if (const auto *TraitIntrin = dyn_cast(this)) + return TraitIntrin->isFunctionalCommutative(); + return cast(this)->isFunctionalCommutative(); + } + + static bool classof(const Value *V) { return IntrinsicInst::classof(V); } +}; + +template +struct ExtCmpBase : public ExtInstructionBase { + unsigned getOpcode() const { return OPC; } + + FCmpInst::Predicate getPredicate() const { + // Use the intrinsic override. + if (const auto *Intrin = dyn_cast(this)) { + return Intrin->getPredicate(); + } + + // Default opcode. + return cast(this)->getPredicate(); + } +}; + +template +static bool classofExtCmpBase(const Value *V) { + return isa(V) || + cast->getFunctionalOpcode() == OPC; +} + +// Pretend to be a llvm::FCmpInst. +template +struct ExtFCmpInst final + : public ExtCmpBase { + static bool classof(const Value *V) { + return classofExtCmpBase(V); + } +}; + +// Pretend to be a llvm::ICmpInst. 
+template +struct ExtICmpInst final + : public ExtCmpBase { + static bool classof(const Value *V) { + return classofExtCmpBase(V); + } +}; + +// Pretend to be a BinaryOperator. +template +struct ExtBinaryOperator final : public ExtOperator { + using BinaryOps = Instruction::BinaryOps; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + const auto *Intrin = dyn_cast(I); + return Intrin && Intrin->isFunctionalBinaryOp(); + } + static bool classof(const ConstantExpr *CE) { + return isa(CE); + } + static bool classof(const Value *V) { + if (const auto *I = dyn_cast(V)) + return classof(I); + + if (const auto *CE = dyn_cast(V)) + return classof(CE); + + return false; + } +}; + +// Pretend to be a UnaryOperator. +template +struct ExtUnaryOperator final : public ExtOperator { + using BinaryOps = Instruction::BinaryOps; + + static bool classof(const Instruction *I) { + if (isa(I)) + return true; + const auto *Intrin = dyn_cast(I); + return Intrin && Intrin->isFunctionalUnaryOp(); + } +}; + +// TODO Implement other extended types. + +/// Template-specialization for the Ext type hierarchy { +//// Enable the ExtSOMETHING for your trait +#define INTRINSIC_TRAIT_SPECIALIZE(TRAIT, TYPE) \ + template <> struct TraitCast { \ + using ExtType = Ext##TYPE; \ + }; \ + template <> struct TraitCast { \ + using ExtType = const Ext##TYPE; \ + }; + +/// } Trait Template Classes + +// Constraint fp trait. +struct CFPTrait { + using Intrinsic = ConstrainedFPIntrinsic; + static constexpr bool AllowReassociation = false; + + // Whether \p V should be considered at all with this trait. + // It is not possible to mix constrained and unconstrained ops. + // Only apply this trait with the constrained variant. + static bool consider(const Value *V) { + return isa(V); + } +}; +INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, Instruction) +INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, Operator) +INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, BinaryOperator) +INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, UnaryOperator) +// Deflect queries for the Predicate to the ConstrainedFPCmpIntrinsic. +template <> struct ExtFCmpInst : public ExtInstructionBase { + unsigned getOpcode() const { return Instruction::FCmp; } + + FCmpInst::Predicate getPredicate() const { + return cast(this)->getPredicate(); + } + + bool classof(const Value *V) { return isa(V); } +}; +INTRINSIC_TRAIT_SPECIALIZE(CFPTrait, FCmpInst) + +// Accept all constrained fp intrinsics that are actually not constrained. +template <> struct MatcherContext { + bool accept(const Value *Val) { return check(Val); } + bool check(const Value *Val) const { + if (!Val) + return false; + const auto *CFP = dyn_cast(Val); + if (!CFP) + return true; + auto RoundingOpt = CFP->hasRoundingMode() ? CFP->getRoundingMode() : None; + auto ExceptOpt = CFP->getExceptionBehavior(); + return (!ExceptOpt || ExceptOpt == fp::ExceptionBehavior::ebIgnore) && + (!RoundingOpt || (RoundingOpt == RoundingMode::NearestTiesToEven)); + } +}; + +// Vector-predicated trait. +struct VPTrait { + using Intrinsic = VPIntrinsic; + // TODO Enable re-association. + static constexpr bool AllowReassociation = false; + // VP intrinsic mix with regular IR instructions. + // TODO: Adapt this to work with other than arithmetic VP ops. 
+ static bool consider(const Value *V) { + return V->getType()->isVectorTy() && + V->getType()->getScalarType()->isIntegerTy(); + } +}; +INTRINSIC_TRAIT_SPECIALIZE(VPTrait, Instruction) +INTRINSIC_TRAIT_SPECIALIZE(VPTrait, Operator) +INTRINSIC_TRAIT_SPECIALIZE(VPTrait, BinaryOperator) +INTRINSIC_TRAIT_SPECIALIZE(VPTrait, UnaryOperator) + +// Accept everything that passes as a VPIntrinsic. +template <> struct MatcherContext { + // TODO: pick up %mask and %evl here and use them to generate code again. We + // only remove instructions for the moment. + bool accept(const Value *Val) { return Val; } + bool check(const Value *Val) const { return Val; } +}; + +} // namespace llvm + +#undef INTRINSIC_TRAIT_SPECIALIZE + +#endif // LLVM_IR_TRAITS_TRAITS_H diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -39,6 +39,8 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Traits/Traits.h" +#include "llvm/IR/ValueHandle.h" #include "llvm/Support/KnownBits.h" #include using namespace llvm; @@ -58,8 +60,13 @@ const SimplifyQuery &, unsigned); static Value *simplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); +template +static Value *simplifyBinOp(unsigned, Value *, Value *, const SimplifyQuery &, + MatcherContext &, unsigned); +template static Value *simplifyBinOp(unsigned, Value *, Value *, const FastMathFlags &, - const SimplifyQuery &, unsigned); + const SimplifyQuery &, MatcherContext &, + unsigned); static Value *simplifyCmpInst(unsigned, Value *, Value *, const SimplifyQuery &, unsigned); static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, @@ -287,33 +294,47 @@ /// Generic simplifications for associative binary operations. /// Returns the simpler value, or null if none was found. -static Value *simplifyAssociativeBinOp(Instruction::BinaryOps Opcode, - Value *LHS, Value *RHS, - const SimplifyQuery &Q, - unsigned MaxRecurse) { +template +static Value * +simplifyAssociativeBinOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, + const SimplifyQuery &Q, MatcherContext &Matcher, + unsigned MaxRecurse) { + // This trait blocks re-association. + // Eg. any trait that adds side-effects may clash with free reassociation + // (FIXME are 'fpexcept.strict', 'fast' fp ops a thing?) + // FIXME associativity may depend on a trait parameter of this specific + // instance. + if (!Trait::AllowReassociation) + return nullptr; + assert(Instruction::isAssociative(Opcode) && "Not an associative operation!"); // Recursion is always used, so bail out at once if we already hit the limit. if (!MaxRecurse--) return nullptr; - BinaryOperator *Op0 = dyn_cast(LHS); - BinaryOperator *Op1 = dyn_cast(RHS); + auto *Op0 = trait_dyn_cast(LHS); + auto *Op1 = trait_dyn_cast(RHS); // Transform: "(A op B) op C" ==> "A op (B op C)" if it simplifies completely. - if (Op0 && Op0->getOpcode() == Opcode) { + MatcherContext Op0Matcher(Matcher); + if (Op0Matcher.accept(Op0) && Op0->getOpcode() == Opcode) { Value *A = Op0->getOperand(0); Value *B = Op0->getOperand(1); Value *C = RHS; // Does "B op C" simplify? - if (Value *V = simplifyBinOp(Opcode, B, C, Q, MaxRecurse)) { + if (Value *V = + simplifyBinOp(Opcode, B, C, Q, Op0Matcher, MaxRecurse)) { // It does! Return "A op V" if it simplifies or is already available. // If V equals B then "A op V" is just the LHS. 
- if (V == B) + if (V == B) { + Matcher = Op0Matcher; return LHS; + } // Otherwise return "A op V" if it simplifies. - if (Value *W = simplifyBinOp(Opcode, A, V, Q, MaxRecurse)) { + if (Value *W = simplifyBinOp(Opcode, A, V, Q, Op0Matcher, MaxRecurse)) { + Matcher = Op0Matcher; ++NumReassoc; return W; } @@ -321,19 +342,23 @@ } // Transform: "A op (B op C)" ==> "(A op B) op C" if it simplifies completely. - if (Op1 && Op1->getOpcode() == Opcode) { + MatcherContext Op1Matcher(Matcher); + if (Op1Matcher.accept(Op1) && Op1->getOpcode() == Opcode) { Value *A = LHS; Value *B = Op1->getOperand(0); Value *C = Op1->getOperand(1); // Does "A op B" simplify? - if (Value *V = simplifyBinOp(Opcode, A, B, Q, MaxRecurse)) { + if (Value *V = simplifyBinOp(Opcode, A, B, Q, Op1Matcher, MaxRecurse)) { // It does! Return "V op C" if it simplifies or is already available. // If V equals B then "V op C" is just the RHS. - if (V == B) + if (V == B) { + Matcher = Op1Matcher; return RHS; + } // Otherwise return "V op C" if it simplifies. - if (Value *W = simplifyBinOp(Opcode, V, C, Q, MaxRecurse)) { + if (Value *W = simplifyBinOp(Opcode, V, C, Q, Op1Matcher, MaxRecurse)) { + Matcher = Op1Matcher; ++NumReassoc; return W; } @@ -341,23 +366,30 @@ } // The remaining transforms require commutativity as well as associativity. + // FIXME commutativity may depend on a trait parameter of this specific + // instance. Eg, matrix multiplication is associative but not commutative. if (!Instruction::isCommutative(Opcode)) return nullptr; // Transform: "(A op B) op C" ==> "(C op A) op B" if it simplifies completely. + MatcherContext CommOp0Matcher(Matcher); if (Op0 && Op0->getOpcode() == Opcode) { Value *A = Op0->getOperand(0); Value *B = Op0->getOperand(1); Value *C = RHS; // Does "C op A" simplify? - if (Value *V = simplifyBinOp(Opcode, C, A, Q, MaxRecurse)) { + if (Value *V = simplifyBinOp(Opcode, C, A, Q, CommOp0Matcher, MaxRecurse)) { // It does! Return "V op B" if it simplifies or is already available. // If V equals A then "V op B" is just the LHS. - if (V == A) + if (V == A) { + Matcher = CommOp0Matcher; return LHS; + } // Otherwise return "V op B" if it simplifies. - if (Value *W = simplifyBinOp(Opcode, V, B, Q, MaxRecurse)) { + MatcherContext VContext(Matcher); + if (Value *W = simplifyBinOp(Opcode, V, B, Q, VContext, MaxRecurse)) { + Matcher = VContext; ++NumReassoc; return W; } @@ -387,6 +419,14 @@ return nullptr; } +static Value *simplifyAssociativeBinOp(Instruction::BinaryOps Opcode, + Value *LHS, Value *RHS, + const SimplifyQuery &Q, + unsigned MaxRecurse) { + MatcherContext Matcher; + return simplifyAssociativeBinOp(Opcode, LHS, RHS, Q, Matcher, MaxRecurse); +} + /// In the case of a binary operation with a select instruction as an operand, /// try to simplify the binop by seeing whether evaluating it on both branches /// of the select results in the same value. Returns the common value if so, @@ -614,9 +654,12 @@ } /// Given operands for an Add, see if we can fold the result. -/// If not, this returns null. +/// If not, thi:s returns null. +template static Value *simplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, - const SimplifyQuery &Q, unsigned MaxRecurse) { + const SimplifyQuery &Q, + MatcherContext &Matcher, + unsigned MaxRecurse) { if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q)) return C; @@ -629,7 +672,7 @@ return Op1; // X + 0 -> X - if (match(Op1, m_Zero())) + if (try_match(Op1, m_Zero(), Matcher)) return Op0; // If two operands are negative, return 0. 
@@ -640,24 +683,28 @@ // (Y - X) + X -> Y // Eg: X + -X -> 0 Value *Y = nullptr; - if (match(Op1, m_Sub(m_Value(Y), m_Specific(Op0))) || - match(Op0, m_Sub(m_Value(Y), m_Specific(Op1)))) + if (try_match(Op1, m_Sub(m_Value(Y), m_Specific(Op0)), Matcher) || + try_match(Op0, m_Sub(m_Value(Y), m_Specific(Op1)), Matcher)) return Y; // X + ~X -> -1 since ~X = -X-1 Type *Ty = Op0->getType(); - if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0)))) + if (try_match(Op0, m_Not(m_Specific(Op1)), Matcher) || + try_match(Op1, m_Not(m_Specific(Op0)), Matcher)) return Constant::getAllOnesValue(Ty); // add nsw/nuw (xor Y, signmask), signmask --> Y // The no-wrapping add guarantees that the top bit will be set by the add. // Therefore, the xor must be clearing the already set sign bit of Y. - if ((IsNSW || IsNUW) && match(Op1, m_SignMask()) && - match(Op0, m_Xor(m_Value(Y), m_SignMask()))) + MatcherContext CopyMatch(Matcher); + if ((IsNSW || IsNUW) && match(Op1, m_SignMask(), CopyMatch) && + match(Op0, m_Xor(m_Value(Y), m_SignMask()), CopyMatch)) { + Matcher = CopyMatch; return Y; + } // add nuw %x, -1 -> -1, because %x can only be 0. - if (IsNUW && match(Op1, m_AllOnes())) + if (IsNUW && try_match(Op1, m_AllOnes(), Matcher)) return Op1; // Which is -1. /// i1 add -> xor. @@ -682,9 +729,12 @@ return nullptr; } +template Value *llvm::simplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, - const SimplifyQuery &Query) { - return ::simplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, RecursionLimit); + const SimplifyQuery &Query, + MatcherContext &Matcher) { + return ::simplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, Matcher, + RecursionLimit); } /// Compute the base pointer and cumulative constant offsets for V. @@ -3641,7 +3691,9 @@ if (LHS_CR.icmp(Pred, RHS_CR)) return ConstantInt::getTrue(RHS->getContext()); - if (LHS_CR.icmp(CmpInst::getInversePredicate(Pred), RHS_CR)) + auto InversedSatisfied_CR = ConstantRange::makeSatisfyingICmpRegion( + CmpInst::getInversePredicate(Pred), RHS_CR); + if (InversedSatisfied_CR.contains(LHS_CR)) return ConstantInt::getFalse(RHS->getContext()); } } @@ -4206,6 +4258,36 @@ return ConstantFoldLoadFromConstPtr(ConstOps[0], LI->getType(), Q.DL); return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI); + + // If all operands are constant after substituting Op for RepOp then we can + // constant fold the instruction. + if (Constant *CRepOp = dyn_cast(RepOp)) { + // Build a list of all constant operands. + SmallVector ConstOps; + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { + if (I->getOperand(i) == Op) + ConstOps.push_back(CRepOp); + else if (Constant *COp = dyn_cast(I->getOperand(i))) + ConstOps.push_back(COp); + else + break; + } + + // All operands were constants, fold it. + if (ConstOps.size() == I->getNumOperands()) { + if (CmpInst *C = dyn_cast(I)) + return ConstantFoldCompareInstOperands(C->getPredicate(), ConstOps[0], + ConstOps[1], Q.DL, Q.TLI); + + if (LoadInst *LI = dyn_cast(I)) + if (!LI->isVolatile()) + return ConstantFoldLoadFromConstPtr(ConstOps[0], LI->getType(), Q.DL); + + return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI); + } + } + + return nullptr; } Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, @@ -5146,9 +5228,11 @@ /// Given operands for an FAdd, see if we can fold the result. If not, this /// returns null. 
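+/// Example (illustrative): under FMF.noNaNs(),
+///   %neg = fsub float -0.0, %x
+///   %r   = fadd nnan float %neg, %x
+/// folds to +0.0, and with a constrained-FP MatcherContext the same fold now
+/// applies to the equivalent llvm.experimental.constrained.fsub/fadd calls
+/// (see the fast-math-strictfp.ll changes below).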
+template static Value * simplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse, + const SimplifyQuery &Q, MatcherContext &Matcher, + unsigned MaxRecurse, fp::ExceptionBehavior ExBehavior = fp::ebIgnore, RoundingMode Rounding = RoundingMode::NearestTiesToEven) { if (isDefaultFPEnvironment(ExBehavior, Rounding)) @@ -5166,12 +5250,12 @@ if (canIgnoreSNaN(ExBehavior, FMF) && (!canRoundingModeBe(Rounding, RoundingMode::TowardNegative) || FMF.noSignedZeros())) - if (match(Op1, m_NegZeroFP())) + if (try_match(Op1, m_NegZeroFP(), Matcher)) return Op0; // fadd X, 0 ==> X, when we know X is not -0 if (canIgnoreSNaN(ExBehavior, FMF)) - if (match(Op1, m_PosZeroFP()) && + if (try_match(Op1, m_PosZeroFP(), Matcher) && (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) return Op0; @@ -5186,12 +5270,12 @@ // X = 0.0: (-0.0 - ( 0.0)) + ( 0.0) == (-0.0) + ( 0.0) == 0.0 // X = 0.0: ( 0.0 - ( 0.0)) + ( 0.0) == ( 0.0) + ( 0.0) == 0.0 if (FMF.noNaNs()) { - if (match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1))) || - match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0)))) + if (try_match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1)), Matcher) || + try_match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0)), Matcher)) return ConstantFP::getNullValue(Op0->getType()); - if (match(Op0, m_FNeg(m_Specific(Op1))) || - match(Op1, m_FNeg(m_Specific(Op0)))) + if (try_match(Op0, m_FNeg(m_Specific(Op1)), Matcher) || + try_match(Op1, m_FNeg(m_Specific(Op0)), Matcher)) return ConstantFP::getNullValue(Op0->getType()); } @@ -5199,8 +5283,8 @@ // Y + (X - Y) --> X Value *X; if (FMF.noSignedZeros() && FMF.allowReassoc() && - (match(Op0, m_FSub(m_Value(X), m_Specific(Op1))) || - match(Op1, m_FSub(m_Value(X), m_Specific(Op0))))) + (try_match(Op0, m_FSub(m_Value(X), m_Specific(Op1)), Matcher) || + try_match(Op1, m_FSub(m_Value(X), m_Specific(Op0)), Matcher))) return X; return nullptr; @@ -5316,12 +5400,14 @@ return simplifyFMAFMul(Op0, Op1, FMF, Q, MaxRecurse, ExBehavior, Rounding); } +template Value *llvm::simplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, + MatcherContext &Matcher, fp::ExceptionBehavior ExBehavior, RoundingMode Rounding) { - return ::simplifyFAddInst(Op0, Op1, FMF, Q, RecursionLimit, ExBehavior, - Rounding); + return ::simplifyFAddInst(Op0, Op1, FMF, Q, Matcher, RecursionLimit, + ExBehavior, Rounding); } Value *llvm::simplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, @@ -5480,11 +5566,14 @@ /// Given operands for a BinaryOperator, see if we can fold the result. /// If not, this returns null. 
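+/// Usage sketch (illustrative): the trait-aware helpers reach this dispatcher
+/// with an explicit trait and matcher, e.g.
+///   simplifyBinOp<VPTrait>(Instruction::Add, LHS, RHS, Q, Matcher, MaxRecurse)
+/// where LHS/RHS may be llvm.vp.add/llvm.vp.sub calls that the trait treats
+/// as their functional opcodes.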
-static Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const SimplifyQuery &Q, unsigned MaxRecurse) { +template +static Value * +simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const SimplifyQuery &Q, + MatcherContext &Matcher, unsigned MaxRecurse) { switch (Opcode) { case Instruction::Add: - return simplifyAddInst(LHS, RHS, false, false, Q, MaxRecurse); + return simplifyAddInst(LHS, RHS, false, false, Q, Matcher, + MaxRecurse); case Instruction::Sub: return simplifySubInst(LHS, RHS, false, false, Q, MaxRecurse); case Instruction::Mul: @@ -5510,7 +5599,8 @@ case Instruction::Xor: return simplifyXorInst(LHS, RHS, Q, MaxRecurse); case Instruction::FAdd: - return simplifyFAddInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); + return simplifyFAddInst(LHS, RHS, FastMathFlags(), Q, Matcher, + MaxRecurse); case Instruction::FSub: return simplifyFSubInst(LHS, RHS, FastMathFlags(), Q, MaxRecurse); case Instruction::FMul: @@ -5524,15 +5614,23 @@ } } +static Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse) { + MatcherContext Matcher; + return simplifyBinOp<>(Opcode, LHS, RHS, Q, Matcher, MaxRecurse); +} + /// Given operands for a BinaryOperator, see if we can fold the result. /// If not, this returns null. /// Try to use FastMathFlags when folding the result. +template static Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const FastMathFlags &FMF, const SimplifyQuery &Q, + MatcherContext &Matcher, unsigned MaxRecurse) { switch (Opcode) { case Instruction::FAdd: - return simplifyFAddInst(LHS, RHS, FMF, Q, MaxRecurse); + return simplifyFAddInst(LHS, RHS, FMF, Q, Matcher, MaxRecurse); case Instruction::FSub: return simplifyFSubInst(LHS, RHS, FMF, Q, MaxRecurse); case Instruction::FMul: @@ -5540,19 +5638,31 @@ case Instruction::FDiv: return simplifyFDivInst(LHS, RHS, FMF, Q, MaxRecurse); default: - return simplifyBinOp(Opcode, LHS, RHS, Q, MaxRecurse); + return simplifyBinOp<>(Opcode, LHS, RHS, Q, Matcher, MaxRecurse); } } +template Value *llvm::simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - const SimplifyQuery &Q) { - return ::simplifyBinOp(Opcode, LHS, RHS, Q, RecursionLimit); + const SimplifyQuery &Q, + MatcherContext &Matcher) { + return ::simplifyBinOp<>(Opcode, LHS, RHS, Q, Matcher, RecursionLimit); } +template Value *llvm::simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - FastMathFlags FMF, const SimplifyQuery &Q) { - return ::simplifyBinOp(Opcode, LHS, RHS, FMF, Q, RecursionLimit); -} + FastMathFlags FMF, const SimplifyQuery &Q, + MatcherContext &Matcher) { + return ::simplifyBinOp<>(Opcode, LHS, RHS, FMF, Q, Matcher, RecursionLimit); +} +#define ENABLE_TRAIT(TRAIT) \ + template Value *llvm::simplifyBinOp(unsigned, Value *, Value *, \ + const SimplifyQuery &, \ + MatcherContext &); \ + template Value *llvm::simplifyBinOp(unsigned, Value *, Value *, \ + FastMathFlags, const SimplifyQuery &, \ + MatcherContext &); +#include "llvm/IR/Traits/EnabledTraits.def" /// Given operands for a CmpInst, see if we can fold the result. static Value *simplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, @@ -6308,14 +6418,28 @@ /// See if we can compute a simplified version of this instruction. /// If not, this returns null. -static Value *simplifyInstructionWithOperands(Instruction *I, - ArrayRef NewOps, - const SimplifyQuery &SQ, - OptimizationRemarkEmitter *ORE) { +// FIXME: this will break if masquerading intrinsics do not pass muster. 
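+// Usage sketch (illustrative): callers that know the trait statically can
+// invoke this entry point directly, e.g.
+//   SmallVector<Value *> Ops(I->operands());
+//   if (Value *V = simplifyInstructionWithOperandsAndTrait<VPTrait>(I, Ops, Q))
+//     I->replaceAllUsesWith(V);
+// simplifyInstruction() below instead probes the enabled traits in turn.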
+template +Value *llvm::simplifyInstructionWithOperandsAndTrait( + Instruction *I, ArrayRef NewOps, const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I); Value *Result = nullptr; - switch (I->getOpcode()) { + // Allow Traits to bail for cases we do not want to implement. + if (!Trait::consider(I)) + return nullptr; + + // Create an initial context rooted at I. + MatcherContext Matcher; + if (!Matcher.accept(I)) + return nullptr; + + // Cast into the Trait type hierarchy since I may have a different opcode + // there. + // Eg llvm.*.constrained.fadd(%x, %y, %fpround, %fpexcept) is an 'fadd'. + const auto *TraitInst = trait_cast(I); + switch (TraitInst->getOpcode()) { default: if (llvm::all_of(NewOps, [](Value *V) { return isa(V); })) { SmallVector NewConstOps(NewOps.size()); @@ -6328,20 +6452,25 @@ Result = simplifyFNegInst(NewOps[0], I->getFastMathFlags(), Q); break; case Instruction::FAdd: - Result = simplifyFAddInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); + Result = simplifyFAddInst(NewOps[0], NewOps[1], + I->getFastMathFlags(), Q, Matcher); break; case Instruction::Add: - Result = simplifyAddInst( - NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + Result = simplifyAddInst( + NewOps[0], NewOps[1], + Q.IIQ.hasNoSignedWrap(trait_cast(I)), + Q.IIQ.hasNoUnsignedWrap(trait_cast(I)), Q, + Matcher); break; case Instruction::FSub: Result = simplifyFSubInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Sub: + // TODO: Add Trait abstraction Result = simplifySubInst( - NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + NewOps[0], NewOps[1], + Q.IIQ.hasNoSignedWrap(trait_cast(I)), + Q.IIQ.hasNoUnsignedWrap(trait_cast(I)), Q); break; case Instruction::FMul: Result = simplifyFMulInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); @@ -6445,7 +6574,8 @@ #define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: #include "llvm/IR/Instruction.def" #undef HANDLE_CAST_INST - Result = simplifyCastInst(I->getOpcode(), NewOps[0], I->getType(), Q); + Result = + simplifyCastInst(TraitInst->getOpcode(), NewOps[0], I->getType(), Q); break; case Instruction::Alloca: // No simplifications for Alloca and it can't be constant folded. @@ -6461,6 +6591,33 @@ /// detecting that case here, returning a safe value instead. return Result == I ? UndefValue::get(I->getType()) : Result; } +#define ENABLE_TRAIT(TRAIT) \ + template Value *llvm::simplifyInstructionWithOperandsAndTrait( \ + Instruction *, ArrayRef NewOps, const SimplifyQuery &, \ + OptimizationRemarkEmitter *); +#include "llvm/IR/Traits/EnabledTraits.def" + +Value *llvm::simplifyInstruction(Instruction *I, const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { + SmallVector Ops(I->operands()); + + // Either all or no fp operations in a function are constrained. + if (CFPTrait::consider(I)) { + if (auto *Result = + simplifyInstructionWithOperandsAndTrait(I, Ops, SQ, ORE)) + return Result; + } + + /// Vector-predicated code. + /// FIXME: We use a quick heuristics (is this a vector type?) for now. 
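+ /// Note: VPTrait::consider() (see Traits.h) accepts any value of integer
+ /// vector type, so plain vector arithmetic takes this path as well as
+ /// actual VP intrinsics.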
+ if (VPTrait::consider(I)) { + if (auto *Result = + simplifyInstructionWithOperandsAndTrait(I, Ops, SQ, ORE)) + return Result; + } + + return simplifyInstructionWithOperandsAndTrait(I, Ops, SQ, ORE); +} Value *llvm::simplifyInstructionWithOperands(Instruction *I, ArrayRef NewOps, @@ -6468,13 +6625,8 @@ OptimizationRemarkEmitter *ORE) { assert(NewOps.size() == I->getNumOperands() && "Number of operands should match the instruction!"); - return ::simplifyInstructionWithOperands(I, NewOps, SQ, ORE); -} - -Value *llvm::simplifyInstruction(Instruction *I, const SimplifyQuery &SQ, - OptimizationRemarkEmitter *ORE) { - SmallVector Ops(I->operands()); - return ::simplifyInstructionWithOperands(I, Ops, SQ, ORE); + return ::simplifyInstructionWithOperandsAndTrait(I, NewOps, SQ, + ORE); } /// Implementation of recursive simplification through an instruction's diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp --- a/llvm/lib/IR/IntrinsicInst.cpp +++ b/llvm/lib/IR/IntrinsicInst.cpp @@ -197,6 +197,17 @@ return ConstantInt::get(Type::getInt64Ty(Context), 1); } +bool ConstrainedFPIntrinsic::hasRoundingMode() const { + switch (getIntrinsicID()) { + default: + return false; +#define INSTRUCTION(N, A, R, I) \ + case Intrinsic::I: \ + return R; +#include "llvm/IR/ConstrainedOps.def" + } +} + Optional ConstrainedFPIntrinsic::getRoundingMode() const { unsigned NumOperands = arg_size(); Metadata *MD = nullptr; @@ -262,6 +273,21 @@ return getFPPredicateFromMD(getArgOperand(2)); } +Optional ConstrainedFPIntrinsic::getFunctionalOpcode() const { + switch (getIntrinsicID()) { + default: + // Just some intrinsic call + return None; + +#define DAG_FUNCTION(OPC, FPEXCEPT, FPROUND, INTRIN, SD) + +#define DAG_INSTRUCTION(OPC, FPEXCEPT, FPROUND, INTRIN, SD) \ + case Intrinsic::INTRIN: \ + return OPC; +#include "llvm/IR/ConstrainedOps.def" + } +} + bool ConstrainedFPIntrinsic::isUnaryOp() const { switch (getIntrinsicID()) { default: diff --git a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp --- a/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -2249,6 +2249,11 @@ match_LoopInvariant(const SubPattern_t &SP, const Loop *L) : SubPattern(SP), L(L) {} + // FIXME: Existing trait-unaware patterns should keep working without code changes. + template bool match(ITy *V, MatcherContext &MContext) { + return L->isLoopInvariant(V) && SubPattern.match(V, MContext); + } + template bool match(ITy *V) { return L->isLoopInvariant(V) && SubPattern.match(V); } diff --git a/llvm/test/Transforms/InstSimplify/add_vp.ll b/llvm/test/Transforms/InstSimplify/add_vp.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstSimplify/add_vp.ll @@ -0,0 +1,76 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -instsimplify -S | FileCheck %s + +declare <2 x i32> @llvm.vp.add.v2i32(<2 x i32>, <2 x i32>, <2 x i1>, i32) +declare <2 x i32> @llvm.vp.sub.v2i32(<2 x i32>, <2 x i32>, <2 x i1>, i32) + +declare <2 x i8> @llvm.vp.add.v2i8(<2 x i8>, <2 x i8>, <2 x i1>, i32) +declare <2 x i8> @llvm.vp.sub.v2i8(<2 x i8>, <2 x i8>, <2 x i1>, i32) + +; Constant folding should just work. +define <2 x i32> @constant_vp_add(<2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @constant_vp_add( +; CHECK-NEXT: ret <2 x i32> +; + %Q = call <2 x i32> @llvm.vp.add.v2i32(<2 x i32> , <2 x i32> , <2 x i1> %mask, i32 %evl) + ret <2 x i32> %Q +} + +; Simplifying pure VP intrinsic patterns. 
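+; The next test exercises the VP form of the generic add fold
+; (X - Y) + Y --> X; the scalar comments inside each test show the
+; functional operation that the VP intrinsic stands for.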
+define <2 x i32> @common_sub_operand(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @common_sub_operand( +; CHECK-NEXT: ret <2 x i32> [[X:%.*]] +; + ; %Z = sub i32 %X, %Y, vp(%mask, %evl) + %Z = call <2 x i32> @llvm.vp.sub.v2i32(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) + ; %Q = add i32 %Z, %Y, vp(%mask, %evl) + %Q = call <2 x i32> @llvm.vp.add.v2i32(<2 x i32> %Z, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) + ret <2 x i32> %Q +} + +; Mixing regular SIMD with vp intrinsics (vp add match root). +define <2 x i32> @common_sub_operand_vproot(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @common_sub_operand_vproot( +; CHECK-NEXT: ret <2 x i32> [[X:%.*]] +; + %Z = sub <2 x i32> %X, %Y + ; %Q = add i32 %Z, %Y, vp(%mask, %evl) + %Q = call <2 x i32> @llvm.vp.add.v2i32(<2 x i32> %Z, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) + ret <2 x i32> %Q +} + +; Mixing regular SIMD with vp intrinsics (vp inside pattern, regular instruction root). +define <2 x i32> @common_sub_operand_vpinner(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @common_sub_operand_vpinner( +; CHECK-NEXT: ret <2 x i32> [[X:%.*]] +; + ; %Z = sub i32 %X, %Y, vp(%mask, %evl) + %Z = call <2 x i32> @llvm.vp.sub.v2i32(<2 x i32> %X, <2 x i32> %Y, <2 x i1> %mask, i32 %evl) + %Q = add <2 x i32> %Z, %Y + ret <2 x i32> %Q +} + +define <2 x i32> @negated_operand(<2 x i32> %x, <2 x i1> %mask, i32 %evl) { +; CHECK-LABEL: @negated_operand( +; CHECK-NEXT: ret <2 x i32> zeroinitializer +; + ; %negx = sub i32 0, %x + %negx = call <2 x i32> @llvm.vp.sub.v2i32(<2 x i32> zeroinitializer, <2 x i32> %x, <2 x i1> %mask, i32 %evl) + ; %r = add i32 %negx, %x + %r = call <2 x i32> @llvm.vp.add.v2i32(<2 x i32> %negx, <2 x i32> %x, <2 x i1> %mask, i32 %evl) + ret <2 x i32> %r +} + +; TODO Lift InstSimplify::SimplifyAdd to the trait framework to optimize this. 
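+; (x - y) + (y - x) is always 0, but the known-negation logic in
+; simplifyAddInst has presumably not been lifted to the trait framework yet,
+; so the checks below are left disabled as TODOs.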
+define <2 x i8> @knownnegation(<2 x i8> %x, <2 x i8> %y, <2 x i1> %mask, i32 %evl) { +; TODO-CHECK-LABEL: @knownnegation( +; TODO-CHECK-NEXT: ret <2 x i8> zeroinitializer +; + ; %xy = sub i8 %x, %y + %xy = call <2 x i8> @llvm.vp.sub.v2i8(<2 x i8> %x, <2 x i8> %y, <2 x i1> %mask, i32 %evl) + ; %yx = sub i8 %y, %x + %yx = call <2 x i8> @llvm.vp.sub.v2i8(<2 x i8> %y, <2 x i8> %x, <2 x i1> %mask, i32 %evl) + ; %r = add i8 %xy, %yx + %r = call <2 x i8> @llvm.vp.add.v2i8(<2 x i8> %xy, <2 x i8> %yx, <2 x i1> %mask, i32 %evl) + ret <2 x i8> %r +} diff --git a/llvm/test/Transforms/InstSimplify/fast-math-strictfp.ll b/llvm/test/Transforms/InstSimplify/fast-math-strictfp.ll --- a/llvm/test/Transforms/InstSimplify/fast-math-strictfp.ll +++ b/llvm/test/Transforms/InstSimplify/fast-math-strictfp.ll @@ -58,9 +58,7 @@ define float @fadd_binary_fnegx(float %x) #0 { ; CHECK-LABEL: @fadd_binary_fnegx( -; CHECK-NEXT: [[NEGX:%.*]] = call float @llvm.experimental.constrained.fsub.f32(float -0.000000e+00, float [[X:%.*]], metadata !"round.tonearest", metadata !"fpexcept.ignore") #[[ATTR0]] -; CHECK-NEXT: [[R:%.*]] = call nnan float @llvm.experimental.constrained.fadd.f32(float [[NEGX]], float [[X]], metadata !"round.tonearest", metadata !"fpexcept.ignore") #[[ATTR0]] -; CHECK-NEXT: ret float [[R]] +; CHECK-NEXT: ret float 0.000000e+00 ; %negx = call float @llvm.experimental.constrained.fsub.f32(float -0.0, float %x, metadata !"round.tonearest", metadata !"fpexcept.ignore") #0 %r = call nnan float @llvm.experimental.constrained.fadd.f32(float %negx, float %x, metadata !"round.tonearest", metadata !"fpexcept.ignore") #0 @@ -80,9 +78,7 @@ define <2 x float> @fadd_binary_fnegx_commute_vec(<2 x float> %x) #0 { ; CHECK-LABEL: @fadd_binary_fnegx_commute_vec( -; CHECK-NEXT: [[NEGX:%.*]] = call <2 x float> @llvm.experimental.constrained.fsub.v2f32(<2 x float> , <2 x float> [[X:%.*]], metadata !"round.tonearest", metadata !"fpexcept.ignore") #[[ATTR0]] -; CHECK-NEXT: [[R:%.*]] = call nnan <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> [[X]], <2 x float> [[NEGX]], metadata !"round.tonearest", metadata !"fpexcept.ignore") #[[ATTR0]] -; CHECK-NEXT: ret <2 x float> [[R]] +; CHECK-NEXT: ret <2 x float> zeroinitializer ; %negx = call <2 x float> @llvm.experimental.constrained.fsub.v2f32(<2 x float> , <2 x float> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore") #0 %r = call nnan <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> %x, <2 x float> %negx, metadata !"round.tonearest", metadata !"fpexcept.ignore") #0 @@ -100,9 +96,7 @@ define <2 x float> @fadd_fnegx_commute_vec_undef(<2 x float> %x) #0 { ; CHECK-LABEL: @fadd_fnegx_commute_vec_undef( -; CHECK-NEXT: [[NEGX:%.*]] = call <2 x float> @llvm.experimental.constrained.fsub.v2f32(<2 x float> , <2 x float> [[X:%.*]], metadata !"round.tonearest", metadata !"fpexcept.ignore") #[[ATTR0]] -; CHECK-NEXT: [[R:%.*]] = call nnan <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> [[X]], <2 x float> [[NEGX]], metadata !"round.tonearest", metadata !"fpexcept.ignore") #[[ATTR0]] -; CHECK-NEXT: ret <2 x float> [[R]] +; CHECK-NEXT: ret <2 x float> zeroinitializer ; %negx = call <2 x float> @llvm.experimental.constrained.fsub.v2f32(<2 x float> , <2 x float> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore") #0 %r = call nnan <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> %x, <2 x float> %negx, metadata !"round.tonearest", metadata !"fpexcept.ignore") #0 @@ -160,9 +154,7 @@ define float
@fadd_fsub_nnan_ninf(float %x) #0 { ; CHECK-LABEL: @fadd_fsub_nnan_ninf( -; CHECK-NEXT: [[SUB:%.*]] = call nnan ninf float @llvm.experimental.constrained.fsub.f32(float 0.000000e+00, float [[X:%.*]], metadata !"round.tonearest", metadata !"fpexcept.ignore") #[[ATTR0]] -; CHECK-NEXT: [[ZERO:%.*]] = call nnan ninf float @llvm.experimental.constrained.fadd.f32(float [[X]], float [[SUB]], metadata !"round.tonearest", metadata !"fpexcept.ignore") #[[ATTR0]] -; CHECK-NEXT: ret float [[ZERO]] +; CHECK-NEXT: ret float 0.0 ; %sub = call nnan ninf float @llvm.experimental.constrained.fsub.f32(float 0.0, float %x, metadata !"round.tonearest", metadata !"fpexcept.ignore") #0 %zero = call nnan ninf float @llvm.experimental.constrained.fadd.f32(float %x, float %sub, metadata !"round.tonearest", metadata !"fpexcept.ignore") #0 @@ -173,9 +165,7 @@ define <2 x float> @fadd_fsub_nnan_ninf_commute_vec(<2 x float> %x) #0 { ; CHECK-LABEL: @fadd_fsub_nnan_ninf_commute_vec( -; CHECK-NEXT: [[SUB:%.*]] = call <2 x float> @llvm.experimental.constrained.fsub.v2f32(<2 x float> zeroinitializer, <2 x float> [[X:%.*]], metadata !"round.tonearest", metadata !"fpexcept.ignore") #[[ATTR0]] -; CHECK-NEXT: [[ZERO:%.*]] = call nnan ninf <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> [[SUB]], <2 x float> [[X]], metadata !"round.tonearest", metadata !"fpexcept.ignore") #[[ATTR0]] -; CHECK-NEXT: ret <2 x float> [[ZERO]] +; CHECK-NEXT: ret <2 x float> zeroinitializer ; %sub = call <2 x float> @llvm.experimental.constrained.fsub.v2f32(<2 x float> zeroinitializer, <2 x float> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore") #0 %zero = call nnan ninf <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> %sub, <2 x float> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore") #0 @@ -187,9 +177,7 @@ define float @fadd_fsub_nnan(float %x) #0 { ; CHECK-LABEL: @fadd_fsub_nnan( -; CHECK-NEXT: [[SUB:%.*]] = call float @llvm.experimental.constrained.fsub.f32(float 0.000000e+00, float [[X:%.*]], metadata !"round.tonearest", metadata !"fpexcept.ignore") #[[ATTR0]] -; CHECK-NEXT: [[ZERO:%.*]] = call nnan float @llvm.experimental.constrained.fadd.f32(float [[SUB]], float [[X]], metadata !"round.tonearest", metadata !"fpexcept.ignore") #[[ATTR0]] -; CHECK-NEXT: ret float [[ZERO]] +; CHECK-NEXT: ret float 0.0 ; %sub = call float @llvm.experimental.constrained.fsub.f32(float 0.0, float %x, metadata !"round.tonearest", metadata !"fpexcept.ignore") #0 %zero = call nnan float @llvm.experimental.constrained.fadd.f32(float %sub, float %x, metadata !"round.tonearest", metadata !"fpexcept.ignore") #0 diff --git a/llvm/test/Transforms/InstSimplify/fpadd_constrained.ll b/llvm/test/Transforms/InstSimplify/fpadd_constrained.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstSimplify/fpadd_constrained.ll @@ -0,0 +1,63 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt < %s -instsimplify -S | FileCheck %s + +declare float @llvm.experimental.constrained.fadd.f32(float, float, metadata, metadata) +declare <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float>, <2 x float>, metadata, metadata) +declare float @llvm.experimental.constrained.fsub.f32(float, float, metadata, metadata) + +; fadd X, -0 ==> X +define float @fadd_x_n0(float %a) { +; CHECK-LABEL: @fadd_x_n0( +; CHECK-NEXT: ret float [[A:%.*]] +; + %ret = call float @llvm.experimental.constrained.fadd.f32(float %a, float -0.0, + metadata !"round.tonearest", + metadata !"fpexcept.ignore") 
#0 + ret float %ret +} + +define <2 x float> @fadd_x_n0_vec_undef_elt(<2 x float> %a) { +; CHECK-LABEL: @fadd_x_n0_vec_undef_elt( +; CHECK-NEXT: ret <2 x float> %a +; + %ret = call <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> %a, <2 x float> , + metadata !"round.tonearest", + metadata !"fpexcept.ignore") #0 + ret <2 x float> %ret +} + +; We can't optimize away the fadd in this test because the input +; value to the function and subsequently to the fadd may be -0.0. +; In that one special case, the result of the fadd should be +0.0 +; rather than the first parameter of the fadd. + +; Fragile test warning: We need 6 sqrt calls to trigger the bug +; because the internal logic has a magic recursion limit of 6. +; This is presented without any explanation or ability to customize. + +declare float @sqrtf(float) + +define float @PR22688(float %x) { +; CHECK-LABEL: @PR22688( +; CHECK-NEXT: [[TMP1:%.*]] = call float @sqrtf(float [[X:%.*]]) +; CHECK-NEXT: [[TMP2:%.*]] = call float @sqrtf(float [[TMP1]]) +; CHECK-NEXT: [[TMP3:%.*]] = call float @sqrtf(float [[TMP2]]) +; CHECK-NEXT: [[TMP4:%.*]] = call float @sqrtf(float [[TMP3]]) +; CHECK-NEXT: [[TMP5:%.*]] = call float @sqrtf(float [[TMP4]]) +; CHECK-NEXT: [[TMP6:%.*]] = call float @sqrtf(float [[TMP5]]) +; CHECK-NEXT: [[TMP7:%.*]] = call float @llvm.experimental.constrained.fadd.f32(float [[TMP6]], float 0.000000e+00, +; CHECK-NEXT: ret float [[TMP7]] +; + %1 = call float @sqrtf(float %x) + %2 = call float @sqrtf(float %1) + %3 = call float @sqrtf(float %2) + %4 = call float @sqrtf(float %3) + %5 = call float @sqrtf(float %4) + %6 = call float @sqrtf(float %5) + %7 = call float @llvm.experimental.constrained.fadd.f32(float %6, float 0.0, + metadata !"round.tonearest", + metadata !"fpexcept.ignore") #0 + ret float %7 +} + +attributes #0 = { strictfp nounwind readnone }
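+
+; fadd X, <-0.0, -0.0> ==> X (splat variant of fadd_x_n0 above; illustrative
+; extra case, expected to fold the same way as the scalar form).
+define <2 x float> @fadd_x_n0_splat(<2 x float> %a) {
+; CHECK-LABEL: @fadd_x_n0_splat(
+; CHECK-NEXT: ret <2 x float> %a
+;
+  %ret = call <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> %a, <2 x float> <float -0.000000e+00, float -0.000000e+00>,
+                                                                    metadata !"round.tonearest",
+                                                                    metadata !"fpexcept.ignore") #0
+  ret <2 x float> %ret
+}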