Index: docs/LangRef.rst =================================================================== --- docs/LangRef.rst +++ docs/LangRef.rst @@ -13979,6 +13979,23 @@ after performing the required machine specific adjustments. The pointer returned can then be :ref:`bitcast and executed `. + +.. _int_vp: + +Vector Predication Intrinsics +----------------------------- +VP intrinsics are intended for predicated SIMD/vector code. +A typical VP operation takes a mask () and an explicit vector length parameter (i32) as in: + + llvm.vp..*( %x, %y, mask %Mask, i32 vlen %evl) + +The mask and explicit vector length parameters are unambiguously identified by the mask and vlen parameter attributes. +Result elements are only computed for enabled lanes. +The explicit vector length parameter only disables lanes if its most significant bit (MSB) is zero. +A lane is enabled if the mask at that position is true and, when the explicit vector length is in effect, the lane position is below the explicit vector length. + +For purely vertical operations (SIMD binary operators, etc.), the result is undef on disabled lanes. + .. _int_mload_mstore: Masked Vector Load and Store Intrinsics Index: docs/Proposals/VectorPredication.rst =================================================================== --- /dev/null +++ docs/Proposals/VectorPredication.rst @@ -0,0 +1,83 @@ +========================== +Vector Predication Roadmap +========================== + +.. contents:: Table of Contents + :depth: 3 + :local: + +Motivation +========== + +This proposal defines a roadmap towards native vector predication in LLVM, specifically for vector instructions with a mask and/or an explicit vector length. +LLVM currently has no target-independent means to model predicated vector instructions for modern SIMD ISAs such as AVX512, ARM SVE, the RISC-V V extension and NEC SX-Aurora. +Only some predicated vector operations, such as masked loads and stores, are available through intrinsics [MaskedIR]_.
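The lane-enable rule from the LangRef text above can be modelled in plain C++. This is a minimal sketch under stated assumptions, not part of the patch: the helper names `isLaneEnabled` and `vpAddModel` are hypothetical, the mask is modelled as `std::vector<bool>`, the explicit vector length as a `uint32_t`, and `std::nullopt` stands in for an `undef` result lane.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of the VP lane-enable rule (not an LLVM API).
// The explicit vector length (EVL) only restricts lanes when its most
// significant bit is zero; otherwise the mask alone decides.
bool isLaneEnabled(const std::vector<bool> &Mask, uint32_t EVL,
                   std::size_t Lane) {
  assert(Lane < Mask.size() && "lane out of range");
  const bool EVLInEffect = (EVL & 0x80000000u) == 0;
  const bool BelowEVL = !EVLInEffect || Lane < EVL;
  return Mask[Lane] && BelowEVL;
}

// Hypothetical model of a purely vertical VP operation such as an integer
// add: enabled lanes hold X + Y, disabled lanes are undef, modelled here
// as std::nullopt.
std::vector<std::optional<int32_t>>
vpAddModel(const std::vector<int32_t> &X, const std::vector<int32_t> &Y,
           const std::vector<bool> &Mask, uint32_t EVL) {
  assert(X.size() == Y.size() && X.size() == Mask.size() &&
         "operand widths must match");
  std::vector<std::optional<int32_t>> Result(X.size());
  for (std::size_t Lane = 0; Lane < X.size(); ++Lane)
    if (isLaneEnabled(Mask, EVL, Lane))
      Result[Lane] = X[Lane] + Y[Lane];
  return Result;
}
```

With mask `{1, 1, 0, 1}` and an EVL of 3, only lanes 0 and 1 produce a value: lane 2 is masked off and lane 3 is not below the explicit vector length. Passing an EVL whose MSB is set, such as `0xFFFFFFFF`, leaves the mask as the only predicate.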
+ +The Explicit Vector Length extension +==================================== + +The Explicit Vector Length (EVL) extension [EvlRFC]_ can be a first step towards native vector predication. +The EVL prototype in this patch demonstrates the following concepts: + +- Predicated vector intrinsics with an explicit mask and vector length parameter at the IR level. +- First-class predicated SDNodes at the ISel level. Mask and vector length are value operands. +- An incremental strategy to generalize PatternMatch/InstCombine/InstSimplify and DAGCombiner to work on both regular instructions and EVL intrinsics. +- DAGCombiner example: FMA fusion. +- InstCombine/InstSimplify example: FSub pattern rewrites. +- Early experiments on the LNT test suite (Clang static release, O3 -ffast-math) indicate that compile time on non-EVL IR is not affected by the API abstractions in PatternMatch, etc. + +Roadmap +======= + +Drawing from the EVL prototype, we propose the following roadmap towards native vector predication in LLVM: + + +1. IR-level EVL intrinsics +----------------------------------------- + +- There is consensus on the semantics/instruction set of EVL. +- EVL intrinsics and attributes are available at the IR level. +- TTI has capability flags for EVL (``supportsEVL()``?, ``haveActiveVectorLength()``?). + +Result: EVL is usable by IR-level vectorizers (LV, VPlan, RegionVectorizer), with potential integration into Clang through builtins. + +2. CodeGen support +------------------ + +- EVL intrinsics translate to first-class SDNodes (``llvm.evl.fdiv.* -> evl_fdiv``). +- EVL legalization (legalize explicit vector length to mask (AVX512), legalize EVL SDNodes to pre-existing ones (SSE, NEON)). + +Result: Backend development based on EVL SDNodes. + +3. Lift InstSimplify/InstCombine/DAGCombiner to EVL +--------------------------------------------------- + +- Introduce PredicatedInstruction, PredicatedBinaryOperator and similar helper classes that match both standard vector IR and EVL intrinsics.
+- Add a matcher context to PatternMatch and context-aware IR Builder APIs. +- Incrementally lift DAGCombiner to work on EVL SDNodes as well as on regular vector instructions. +- Incrementally lift InstCombine/InstSimplify to operate on EVL as well as regular IR instructions. + +Result: Optimization of EVL intrinsics on par with standard vector instructions. + +4. Deprecate llvm.masked.* / llvm.experimental.reduce.* +------------------------------------------------------- + +- Modernize llvm.masked.* / llvm.experimental.reduce.* by translating them to EVL. +- Remove the then-dead transitional APIs. + +Result: EVL has superseded earlier vector intrinsics. + +5. Predicated IR Instructions +--------------------------------------- + +- Vector instructions have an optional mask and vector length parameter. These lower to EVL SDNodes (from Stage 2). +- Phase out EVL intrinsics, only keeping those that are not equivalent to vectorized scalar instructions (reduce, shuffles, ...). +- InstCombine/InstSimplify expect predication in regular Instructions (Stage 3 has laid the groundwork). + +Result: Native vector predication in IR. + +References +========== + +.. [MaskedIR] ``llvm.masked.*`` intrinsics, https://llvm.org/docs/LangRef.html#masked-vector-load-and-store-intrinsics +..
[EvlRFC] Explicit Vector Length RFC, https://reviews.llvm.org/D53613 Index: include/llvm/Analysis/InstructionSimplify.h =================================================================== --- include/llvm/Analysis/InstructionSimplify.h +++ include/llvm/Analysis/InstructionSimplify.h @@ -52,6 +52,10 @@ class Value; class MDNode; class BinaryOperator; +class VPIntrinsic; +namespace PatternMatch { + struct PredicatedContext; +} /// InstrInfoQuery provides an interface to query additional information for /// instructions like metadata or keywords like nsw, which provides conservative @@ -133,6 +137,13 @@ Value *SimplifyFSubInst(Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q); +/// Given operands for an FSub, fold the result or return null. +Value *SimplifyFSubInst(Value *LHS, Value *RHS, FastMathFlags FMF, + const SimplifyQuery &Q); +Value *SimplifyPredicatedFSubInst(Value *LHS, Value *RHS, + FastMathFlags FMF, const SimplifyQuery &Q, + PatternMatch::PredicatedContext & PC); + /// Given operands for an FMul, fold the result or return null. Value *SimplifyFMulInst(Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q); @@ -245,6 +256,10 @@ Value *SimplifyCall(CallBase *Call, Value *V, User::op_iterator ArgBegin, User::op_iterator ArgEnd, const SimplifyQuery &Q); +/// Given a VP intrinsic call, fold the result or return +/// null. +Value *SimplifyVPIntrinsic(VPIntrinsic & VPInst, const SimplifyQuery &Q); + /// Given a function and set of arguments, fold the result or return null.
Value *SimplifyCall(CallBase *Call, Value *V, ArrayRef Args, const SimplifyQuery &Q); Index: include/llvm/Bitcode/LLVMBitCodes.h =================================================================== --- include/llvm/Bitcode/LLVMBitCodes.h +++ include/llvm/Bitcode/LLVMBitCodes.h @@ -606,7 +606,10 @@ ATTR_KIND_OPT_FOR_FUZZING = 57, ATTR_KIND_SHADOWCALLSTACK = 58, ATTR_KIND_SPECULATIVE_LOAD_HARDENING = 59, - ATTR_KIND_IMMARG = 60 + ATTR_KIND_IMMARG = 60, + ATTR_KIND_MASK = 61, + ATTR_KIND_VECTORLENGTH = 62, + ATTR_KIND_PASSTHRU = 63, }; enum ComdatSelectionKindCodes { Index: include/llvm/CodeGen/ISDOpcodes.h =================================================================== --- include/llvm/CodeGen/ISDOpcodes.h +++ include/llvm/CodeGen/ISDOpcodes.h @@ -198,6 +198,7 @@ /// Simple integer binary arithmetic operators. ADD, SUB, MUL, SDIV, UDIV, SREM, UREM, + VP_ADD, VP_SUB, VP_MUL, VP_SDIV, VP_UDIV, VP_SREM, VP_UREM, /// SMUL_LOHI/UMUL_LOHI - Multiply two integers of type iN, producing /// a signed/unsigned value of type i[2*N], and return the full value as @@ -280,6 +281,7 @@ /// Simple binary floating point operators. FADD, FSUB, FMUL, FDIV, FREM, + VP_FADD, VP_FSUB, VP_FMUL, VP_FDIV, VP_FREM, /// Constrained versions of the binary floating point operators. /// These will be lowered to the simple operators before final selection. @@ -299,6 +301,7 @@ /// FMA - Perform a * b + c with no intermediate rounding step. FMA, + VP_FMA, /// FMAD - Perform a * b + c, while getting the same result as the /// separately rounded operations. @@ -365,6 +368,19 @@ /// in terms of the element size of VEC1/VEC2, not in terms of bytes. VECTOR_SHUFFLE, + /// VP_VSHIFT(VEC1, AMOUNT, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. AMOUNT is an integer value. The returned vector is equivalent + /// to VEC1 shifted by AMOUNT (RETURNED_VEC[idx] = VEC1[idx + AMOUNT]). + VP_VSHIFT, + + /// VP_COMPRESS(VEC1, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. 
+ VP_COMPRESS, + + /// VP_EXPAND(VEC1, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. + VP_EXPAND, + /// SCALAR_TO_VECTOR(VAL) - This represents the operation of loading a /// scalar value into element 0 of the resultant vector type. The top /// elements 1 to N-1 of the N-element vector are undefined. The type @@ -384,6 +400,7 @@ /// Bitwise operators - logical and, logical or, logical xor. AND, OR, XOR, + VP_AND, VP_OR, VP_XOR, /// ABS - Determine the unsigned absolute value of a signed integer value of /// the same bitwidth. @@ -407,6 +424,7 @@ /// fshl(X,Y,Z): (X << (Z % BW)) | (Y >> (BW - (Z % BW))) /// fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW)) SHL, SRA, SRL, ROTL, ROTR, FSHL, FSHR, + VP_SHL, VP_SRA, VP_SRL, /// Byte Swap and Counting operators. BSWAP, CTTZ, CTLZ, CTPOP, BITREVERSE, @@ -426,6 +444,14 @@ /// change the condition type in order to match the VSELECT node using a /// pattern. The condition follows the BooleanContent format of the target. VSELECT, + VP_SELECT, + + /// Select with an integer pivot (op #0) and two vector operands (ops #1 + /// and #2), returning a vector result. All vectors have the same length. + /// Similar to the vector select, comparing each result element's index + /// with the integer pivot selects whether the corresponding result element + /// is taken from op #1 or op #2. + VP_COMPOSE, /// Select with condition operator - This selects between a true value and /// a false value (ops #2 and #3) based on the boolean result of comparing @@ -440,6 +466,7 @@ /// them with (op #2) as a CondCodeSDNode. If the operands are vector types /// then the result type must also be a vector type. SETCC, + VP_SETCC, /// Like SetCC, ops #0 and #1 are the LHS and RHS operands to compare, but /// op #2 is a boolean indicating if there is an incoming carry.
This @@ -585,6 +612,7 @@ FNEG, FABS, FSQRT, FCBRT, FSIN, FCOS, FPOWI, FPOW, FLOG, FLOG2, FLOG10, FEXP, FEXP2, FCEIL, FTRUNC, FRINT, FNEARBYINT, FROUND, FFLOOR, + VP_FNEG, /// FMINNUM/FMAXNUM - Perform floating-point minimum or maximum on two /// values. // @@ -833,6 +861,7 @@ // Val, OutChain = MLOAD(BasePtr, Mask, PassThru) // OutChain = MSTORE(Value, BasePtr, Mask) MLOAD, MSTORE, + VP_LOAD, VP_STORE, // Masked gather and scatter - load and store operations for a vector of // random addresses with additional mask operand that prevents memory @@ -844,6 +873,7 @@ // The Index operand can have more vector elements than the other operands // due to type legalization. The extra elements are ignored. MGATHER, MSCATTER, + VP_GATHER, VP_SCATTER, /// This corresponds to the llvm.lifetime.* intrinsics. The first operand /// is the chain and the second operand is the alloca pointer. @@ -881,6 +911,14 @@ VECREDUCE_AND, VECREDUCE_OR, VECREDUCE_XOR, VECREDUCE_SMAX, VECREDUCE_SMIN, VECREDUCE_UMAX, VECREDUCE_UMIN, + VP_REDUCE_FADD, VP_REDUCE_FMUL, + VP_REDUCE_ADD, VP_REDUCE_MUL, + VP_REDUCE_AND, VP_REDUCE_OR, VP_REDUCE_XOR, + VP_REDUCE_SMAX, VP_REDUCE_SMIN, VP_REDUCE_UMAX, VP_REDUCE_UMIN, + + /// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants. + VP_REDUCE_FMAX, VP_REDUCE_FMIN, + /// BUILTIN_OP_END - This must be the last enum value in this list. /// The target-specific pre-isel opcode values start here. BUILTIN_OP_END @@ -1040,6 +1078,19 @@ /// SETCC_INVALID if it is not possible to represent the resultant comparison. CondCode getSetCCAndOperation(CondCode Op1, CondCode Op2, bool isInteger); + /// Return the operand index of the mask for this VP opcode, + /// or -1 if it has none. + int GetMaskPosVP(unsigned OpCode); + + /// Return the operand index of the vector length for this VP opcode, + /// or -1 if it has none. + int GetVectorLengthPosVP(unsigned OpCode); + + /// Translate this VP OpCode to a native instruction OpCode.
+ unsigned GetFunctionOpCodeForVP(unsigned VPOpCode); + + unsigned GetVPForFunctionOpCode(unsigned OpCode); + } // end llvm::ISD namespace } // end llvm namespace Index: include/llvm/CodeGen/SelectionDAG.h =================================================================== --- include/llvm/CodeGen/SelectionDAG.h +++ include/llvm/CodeGen/SelectionDAG.h @@ -1089,6 +1089,20 @@ SDValue getIndexedStore(SDValue OrigStore, const SDLoc &dl, SDValue Base, SDValue Offset, ISD::MemIndexedMode AM); + /// Create VP memory nodes (VP_LOAD, VP_STORE, VP_GATHER, VP_SCATTER). + SDValue getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr, + SDValue Mask, SDValue VLen, EVT MemVT, + MachineMemOperand *MMO, ISD::LoadExtType); + + SDValue getStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val, + SDValue Ptr, SDValue Mask, SDValue VLen, + EVT MemVT, MachineMemOperand *MMO, + bool IsTruncating = false); + SDValue getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, MachineMemOperand *MMO); + SDValue getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, MachineMemOperand *MMO); + /// Returns sum of the base pointer and offset. SDValue getMemBasePlusOffset(SDValue Base, unsigned Offset, const SDLoc &DL); Index: include/llvm/CodeGen/SelectionDAGNodes.h =================================================================== --- include/llvm/CodeGen/SelectionDAGNodes.h +++ include/llvm/CodeGen/SelectionDAGNodes.h @@ -545,6 +545,7 @@ class LoadSDNodeBitfields { friend class LoadSDNode; friend class MaskedLoadSDNode; + friend class VPLoadSDNode; uint16_t : NumLSBaseSDNodeBits; @@ -555,6 +556,7 @@ class StoreSDNodeBitfields { friend class StoreSDNode; friend class MaskedStoreSDNode; + friend class VPStoreSDNode; uint16_t : NumLSBaseSDNodeBits; @@ -695,6 +697,66 @@ } } + /// Test whether this is a Vector Predication (VP) node.
+ bool isVP() const { + switch (NodeType) { + default: + return false; + case ISD::VP_LOAD: + case ISD::VP_STORE: + case ISD::VP_GATHER: + case ISD::VP_SCATTER: + + case ISD::VP_FNEG: + + case ISD::VP_FADD: + case ISD::VP_FMUL: + case ISD::VP_FSUB: + case ISD::VP_FDIV: + case ISD::VP_FREM: + + case ISD::VP_FMA: + + case ISD::VP_ADD: + case ISD::VP_MUL: + case ISD::VP_SUB: + case ISD::VP_SRA: + case ISD::VP_SRL: + case ISD::VP_SHL: + case ISD::VP_UDIV: + case ISD::VP_SDIV: + case ISD::VP_UREM: + case ISD::VP_SREM: + + case ISD::VP_EXPAND: + case ISD::VP_COMPRESS: + case ISD::VP_VSHIFT: + case ISD::VP_SETCC: + case ISD::VP_COMPOSE: + + case ISD::VP_AND: + case ISD::VP_XOR: + case ISD::VP_OR: + + case ISD::VP_REDUCE_ADD: + case ISD::VP_REDUCE_SMIN: + case ISD::VP_REDUCE_SMAX: + case ISD::VP_REDUCE_UMIN: + case ISD::VP_REDUCE_UMAX: + + case ISD::VP_REDUCE_MUL: + case ISD::VP_REDUCE_AND: + case ISD::VP_REDUCE_OR: + case ISD::VP_REDUCE_FADD: + case ISD::VP_REDUCE_FMUL: + case ISD::VP_REDUCE_FMIN: + case ISD::VP_REDUCE_FMAX: + + return true; + } + } + + /// Test if this node has a post-isel opcode, directly /// corresponding to a MachineInstr opcode. 
bool isMachineOpcode() const { return NodeType < 0; } @@ -1389,6 +1451,10 @@ N->getOpcode() == ISD::MSTORE || N->getOpcode() == ISD::MGATHER || N->getOpcode() == ISD::MSCATTER || + N->getOpcode() == ISD::VP_LOAD || + N->getOpcode() == ISD::VP_STORE || + N->getOpcode() == ISD::VP_GATHER || + N->getOpcode() == ISD::VP_SCATTER || N->isMemIntrinsic() || N->isTargetMemoryOpcode(); } @@ -2241,6 +2307,96 @@ } }; +/// This base class is used to represent VP_LOAD and VP_STORE nodes +class VPLoadStoreSDNode : public MemSDNode { +public: + friend class SelectionDAG; + + VPLoadStoreSDNode(ISD::NodeType NodeTy, unsigned Order, + const DebugLoc &dl, SDVTList VTs, EVT MemVT, + MachineMemOperand *MMO) + : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {} + + // VPLoadSDNode (Chain, ptr, mask, VLen) + // VPStoreSDNode (Chain, data, ptr, mask, VLen) + // Mask is a vector of i1 elements, VLen is i32 + const SDValue &getBasePtr() const { + return getOperand(getOpcode() == ISD::VP_LOAD ? 1 : 2); + } + const SDValue &getMask() const { + return getOperand(getOpcode() == ISD::VP_LOAD ? 2 : 3); + } + const SDValue &getVectorLength() const { + return getOperand(getOpcode() == ISD::VP_LOAD ?
3 : 4); + } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_LOAD || + N->getOpcode() == ISD::VP_STORE; + } +}; + +/// This class is used to represent a VP_LOAD node +class VPLoadSDNode : public VPLoadStoreSDNode { +public: + friend class SelectionDAG; + + VPLoadSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + ISD::LoadExtType ETy, EVT MemVT, + MachineMemOperand *MMO) + : VPLoadStoreSDNode(ISD::VP_LOAD, Order, dl, VTs, MemVT, MMO) { + LoadSDNodeBits.ExtTy = ETy; + LoadSDNodeBits.IsExpanding = false; + } + + ISD::LoadExtType getExtensionType() const { + return static_cast(LoadSDNodeBits.ExtTy); + } + + const SDValue &getBasePtr() const { return getOperand(1); } + const SDValue &getMask() const { return getOperand(2); } + const SDValue &getVectorLength() const { return getOperand(3); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_LOAD; + } + bool isExpandingLoad() const { return LoadSDNodeBits.IsExpanding; } +}; + +/// This class is used to represent a VP_STORE node +class VPStoreSDNode : public VPLoadStoreSDNode { +public: + friend class SelectionDAG; + + VPStoreSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + bool isTrunc, EVT MemVT, + MachineMemOperand *MMO) + : VPLoadStoreSDNode(ISD::VP_STORE, Order, dl, VTs, MemVT, MMO) { + StoreSDNodeBits.IsTruncating = isTrunc; + StoreSDNodeBits.IsCompressing = false; + } + + /// Return true if the op does a truncation before store. + /// For integers this is the same as doing a TRUNCATE and storing the result. + /// For floats, it is the same as doing an FP_ROUND and storing the result. + bool isTruncatingStore() const { return StoreSDNodeBits.IsTruncating; } + + /// Returns true if the op does a compression to the vector before storing. + /// The node contiguously stores the active elements (integers or floats) + /// in src (those with their respective bit set in writemask k) to unaligned + /// memory at base_addr.
+ bool isCompressingStore() const { return StoreSDNodeBits.IsCompressing; } + + const SDValue &getValue() const { return getOperand(1); } + const SDValue &getBasePtr() const { return getOperand(2); } + const SDValue &getMask() const { return getOperand(3); } + const SDValue &getVectorLength() const { return getOperand(4); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_STORE; + } +}; + /// This base class is used to represent MLOAD and MSTORE nodes class MaskedLoadStoreSDNode : public MemSDNode { public: @@ -2328,6 +2484,67 @@ } }; +/// This is a base class used to represent +/// VP_GATHER and VP_SCATTER nodes +/// +class VPGatherScatterSDNode : public MemSDNode { +public: + friend class SelectionDAG; + + VPGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, + const DebugLoc &dl, SDVTList VTs, EVT MemVT, + MachineMemOperand *MMO) + : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {} + + // Operand layout (VP_SCATTER carries the stored value as an extra op #1): + // VPGatherSDNode (Chain, base, index, scale, mask, vlen) + // VPScatterSDNode (Chain, value, base, index, scale, mask, vlen) + // Mask is a vector of i1 elements + const SDValue &getBasePtr() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 1 : 2); } + const SDValue &getIndex() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 2 : 3); } + const SDValue &getScale() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 3 : 4); } + const SDValue &getMask() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 4 : 5); } + const SDValue &getVectorLength() const { return getOperand((getOpcode() == ISD::VP_GATHER) ?
5 : 6); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_GATHER || + N->getOpcode() == ISD::VP_SCATTER; + } +}; + +/// This class is used to represent a VP_GATHER node +/// +class VPGatherSDNode : public VPGatherScatterSDNode { +public: + friend class SelectionDAG; + + VPGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemVT, MachineMemOperand *MMO) + : VPGatherScatterSDNode(ISD::VP_GATHER, Order, dl, VTs, MemVT, MMO) {} + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_GATHER; + } +}; + +/// This class is used to represent a VP_SCATTER node +/// +class VPScatterSDNode : public VPGatherScatterSDNode { +public: + friend class SelectionDAG; + + VPScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemVT, MachineMemOperand *MMO) + : VPGatherScatterSDNode(ISD::VP_SCATTER, Order, dl, VTs, MemVT, MMO) {} + + const SDValue &getValue() const { return getOperand(1); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_SCATTER; + } +}; + + /// This is a base class used to represent /// MGATHER and MSCATTER nodes /// Index: include/llvm/IR/Attributes.td =================================================================== --- include/llvm/IR/Attributes.td +++ include/llvm/IR/Attributes.td @@ -133,6 +133,15 @@ /// Parameter is required to be a trivial constant. def ImmArg : EnumAttr<"immarg">; +/// Return value that is equal to this argument on disabled lanes (mask). +def Passthru : EnumAttr<"passthru">; + +/// Mask argument that applies to this function. +def Mask : EnumAttr<"mask">; + +/// Dynamic Vector Length argument of this function. +def VectorLength : EnumAttr<"vlen">; + /// Function can return twice.
def ReturnsTwice : EnumAttr<"returns_twice">; Index: include/llvm/IR/InstrTypes.h =================================================================== --- include/llvm/IR/InstrTypes.h +++ include/llvm/IR/InstrTypes.h @@ -162,7 +162,7 @@ static BinaryOperator *CreateWithCopiedFlags(BinaryOps Opc, Value *V1, Value *V2, - BinaryOperator *CopyBO, + Instruction *CopyBO, const Twine &Name = "") { BinaryOperator *BO = Create(Opc, V1, V2, Name); BO->copyIRFlags(CopyBO); @@ -170,31 +170,31 @@ } static BinaryOperator *CreateFAddFMF(Value *V1, Value *V2, - BinaryOperator *FMFSource, + Instruction *FMFSource, const Twine &Name = "") { return CreateWithCopiedFlags(Instruction::FAdd, V1, V2, FMFSource, Name); } static BinaryOperator *CreateFSubFMF(Value *V1, Value *V2, - BinaryOperator *FMFSource, + Instruction *FMFSource, const Twine &Name = "") { return CreateWithCopiedFlags(Instruction::FSub, V1, V2, FMFSource, Name); } static BinaryOperator *CreateFMulFMF(Value *V1, Value *V2, - BinaryOperator *FMFSource, + Instruction *FMFSource, const Twine &Name = "") { return CreateWithCopiedFlags(Instruction::FMul, V1, V2, FMFSource, Name); } static BinaryOperator *CreateFDivFMF(Value *V1, Value *V2, - BinaryOperator *FMFSource, + Instruction *FMFSource, const Twine &Name = "") { return CreateWithCopiedFlags(Instruction::FDiv, V1, V2, FMFSource, Name); } static BinaryOperator *CreateFRemFMF(Value *V1, Value *V2, - BinaryOperator *FMFSource, + Instruction *FMFSource, const Twine &Name = "") { return CreateWithCopiedFlags(Instruction::FRem, V1, V2, FMFSource, Name); } - static BinaryOperator *CreateFNegFMF(Value *Op, BinaryOperator *FMFSource, + static BinaryOperator *CreateFNegFMF(Value *Op, Instruction *FMFSource, const Twine &Name = "") { Value *Zero = ConstantFP::getNegativeZero(Op->getType()); return CreateWithCopiedFlags(Instruction::FSub, Zero, Op, FMFSource); Index: include/llvm/IR/IntrinsicInst.h =================================================================== --- 
include/llvm/IR/IntrinsicInst.h +++ include/llvm/IR/IntrinsicInst.h @@ -205,25 +205,207 @@ /// @} }; - /// This is the common base class for constrained floating point intrinsics. - class ConstrainedFPIntrinsic : public IntrinsicInst { + enum class RoundingMode { + rmInvalid, + rmDynamic, + rmToNearest, + rmDownward, + rmUpward, + rmTowardZero + }; + + enum class ExceptionBehavior { + ebInvalid, + ebIgnore, + ebMayTrap, + ebStrict + }; + + class VPIntrinsic : public IntrinsicInst { public: - enum RoundingMode { - rmInvalid, - rmDynamic, - rmToNearest, - rmDownward, - rmUpward, - rmTowardZero + enum class VPTypeToken : int8_t { + Scalar = 1, // scalar operand type + Vector = 2, // vectorized operand type + Mask = 3 // vector mask type }; - enum ExceptionBehavior { - ebInvalid, - ebIgnore, - ebMayTrap, - ebStrict + using TypeTokenVec = SmallVector; + using ShortTypeVec = SmallVector; + + struct + VPIntrinsicDesc { + Intrinsic::ID ID; // LLVM Intrinsic ID. + TypeTokenVec typeTokens; // Type Parameters for the LLVM Intrinsic. + int MaskPos; // Parameter index of the Mask parameter. + int EVLPos; // Parameter index of the explicit vector length parameter. }; + // Translate this generic Opcode to a VPIntrinsic + static VPIntrinsicDesc GetVPIntrinsicDesc(unsigned OC); + // Translate this non-VP intrinsic to a VPIntrinsic.
+ static VPIntrinsicDesc GetVPDescForIntrinsic(unsigned IntrinsicID); + + // Generate the disambiguating type vec for this VP Intrinsic + static VPIntrinsic::ShortTypeVec + EncodeTypeTokens(VPIntrinsic::TypeTokenVec TTVec, Type & VectorTy, Type & ScalarTy); + + // available for all VP intrinsics + Value* getMask() const; + Value* getVectorLength() const; + + bool isUnaryOp() const; + bool isBinaryOp() const; + bool isTernaryOp() const; + + // compare intrinsic + bool isCompareOp() const { return getIntrinsicID() == Intrinsic::vp_cmp; } + CmpInst::Predicate getCmpPredicate() const; + + // llvm.vp.constrained.* + bool isConstrainedOp() const; + RoundingMode getRoundingMode() const; + ExceptionBehavior getExceptionBehavior() const; + + // llvm.vp.reduction.* + bool isReductionOp() const; + + // Methods for support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + switch (I->getIntrinsicID()) { + default: + return false; + + case Intrinsic::vp_cmp: + + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + + case Intrinsic::vp_select: + case Intrinsic::vp_compose: + case Intrinsic::vp_compress: + case Intrinsic::vp_expand: + case Intrinsic::vp_vshift: + + case Intrinsic::vp_load: + case Intrinsic::vp_store: + case Intrinsic::vp_gather: + case Intrinsic::vp_scatter: + + case Intrinsic::vp_fneg: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + + case Intrinsic::vp_fma: + + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + case Intrinsic::vp_reduce_smin: + case 
Intrinsic::vp_reduce_smax: + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + + case Intrinsic::vp_constrained_fadd: + case Intrinsic::vp_constrained_fsub: + case Intrinsic::vp_constrained_fmul: + case Intrinsic::vp_constrained_fdiv: + case Intrinsic::vp_constrained_frem: + case Intrinsic::vp_constrained_fma: + case Intrinsic::vp_constrained_sqrt: + case Intrinsic::vp_constrained_pow: + case Intrinsic::vp_constrained_powi: + case Intrinsic::vp_constrained_sin: + case Intrinsic::vp_constrained_cos: + case Intrinsic::vp_constrained_exp: + case Intrinsic::vp_constrained_exp2: + case Intrinsic::vp_constrained_log: + case Intrinsic::vp_constrained_log10: + case Intrinsic::vp_constrained_log2: + case Intrinsic::vp_constrained_rint: + case Intrinsic::vp_constrained_nearbyint: + case Intrinsic::vp_constrained_maxnum: + case Intrinsic::vp_constrained_minnum: + case Intrinsic::vp_constrained_ceil: + case Intrinsic::vp_constrained_floor: + case Intrinsic::vp_constrained_round: + case Intrinsic::vp_constrained_trunc: + return true; + } + } + static bool classof(const Value *V) { + return isa(V) && classof(cast(V)); + } + + // Equivalent non-predicated opcode + unsigned getFunctionalOpcode() const { + switch (getIntrinsicID()) { + default: return Instruction::Call; + + case Intrinsic::vp_cmp: + if (getArgOperand(0)->getType()->isFloatingPointTy()) { + return Instruction::FCmp; + } else { + return Instruction::ICmp; + } + + case Intrinsic::vp_and: return Instruction::And; + case Intrinsic::vp_or: return Instruction::Or; + case Intrinsic::vp_xor: return Instruction::Xor; + case Intrinsic::vp_ashr: return Instruction::AShr; + case Intrinsic::vp_lshr: return Instruction::LShr; + case Intrinsic::vp_shl: return Instruction::Shl; + + case Intrinsic::vp_select: return Instruction::Select; + + case 
Intrinsic::vp_load: return Instruction::Load; + case Intrinsic::vp_store: return Instruction::Store; + + case Intrinsic::vp_fneg: return Instruction::FNeg; + + case Intrinsic::vp_fadd: return Instruction::FAdd; + case Intrinsic::vp_fsub: return Instruction::FSub; + case Intrinsic::vp_fmul: return Instruction::FMul; + case Intrinsic::vp_fdiv: return Instruction::FDiv; + case Intrinsic::vp_frem: return Instruction::FRem; + + case Intrinsic::vp_add: return Instruction::Add; + case Intrinsic::vp_sub: return Instruction::Sub; + case Intrinsic::vp_mul: return Instruction::Mul; + case Intrinsic::vp_udiv: return Instruction::UDiv; + case Intrinsic::vp_sdiv: return Instruction::SDiv; + case Intrinsic::vp_urem: return Instruction::URem; + case Intrinsic::vp_srem: return Instruction::SRem; + } + } + }; + + /// This is the common base class for constrained floating point intrinsics. + class ConstrainedFPIntrinsic : public IntrinsicInst { + public: bool isUnaryOp() const; bool isTernaryOp() const; RoundingMode getRoundingMode() const; Index: include/llvm/IR/Intrinsics.td =================================================================== --- include/llvm/IR/Intrinsics.td +++ include/llvm/IR/Intrinsics.td @@ -92,6 +92,25 @@ int ArgNo = argNo; } +// VectorLength - The specified argument is the Dynamic Vector Length of the +// operation. +class VectorLength : IntrinsicProperty { + int ArgNo = argNo; +} + +// Mask - The specified argument contains the per-lane mask of this +// intrinsic. Inputs on masked-out lanes must not affect the result of this +// intrinsic (except for the Passthru argument). +class Mask : IntrinsicProperty { + int ArgNo = argNo; +} +// Passthru - The specified argument contains the per-lane return value +// for this vector intrinsic where the mask is false. +// (requires the Mask attribute in the same function) +class Passthru : IntrinsicProperty { + int ArgNo = argNo; +} + def IntrNoReturn : IntrinsicProperty; // IntrCold - Calls to this intrinsic are cold. 
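The `Mask` and `Passthru` intrinsic properties above describe a per-lane merge: where the mask is false, the intrinsic returns the passthru operand, and the other inputs on those lanes must not affect the result. The following is a minimal sketch of that reading for a binary operation. `mergeWithPassthru` is a hypothetical helper, not an LLVM API, and it deliberately ignores the explicit vector length, because the property comments do not say how passthru interacts with it.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of a masked binary operation with a passthru operand
// (not an LLVM API). Op runs only on lanes whose mask bit is set; every
// other lane returns the matching passthru element, and its X/Y inputs are
// never read, so they cannot influence the result.
template <typename T, typename BinaryOp>
std::vector<T> mergeWithPassthru(const std::vector<T> &X,
                                 const std::vector<T> &Y,
                                 const std::vector<bool> &Mask,
                                 const std::vector<T> &Passthru, BinaryOp Op) {
  assert(X.size() == Y.size() && X.size() == Mask.size() &&
         X.size() == Passthru.size() && "operand widths must match");
  std::vector<T> Result(X.size());
  for (std::size_t Lane = 0; Lane < X.size(); ++Lane)
    Result[Lane] = Mask[Lane] ? Op(X[Lane], Y[Lane]) : Passthru[Lane];
  return Result;
}
```

For example, with mask `{true, false, true}` the middle lane takes its passthru element rather than the product of its `X` and `Y` inputs.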
@@ -1021,9 +1040,440 @@ // Intrinsic to detect whether its argument is a constant. def int_is_constant : Intrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem], "llvm.is.constant">; +//===---------------- Vector Predication Intrinsics --------------===// + +// Memory Intrinsics +def int_vp_store : Intrinsic<[], + [ llvm_anyvector_ty, + LLVMAnyPointerType>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrArgMemOnly, Mask<2>, VectorLength<3> ]>; + +def int_vp_load : Intrinsic<[ llvm_anyvector_ty], + [ LLVMAnyPointerType>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrReadMem, IntrArgMemOnly, Mask<1>, VectorLength<2> ]>; + +def int_vp_gather: Intrinsic<[ llvm_anyvector_ty], + [ LLVMVectorOfAnyPointersToElt<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrReadMem, IntrReadMem, Mask<1>, VectorLength<2> ]>; + +def int_vp_scatter: Intrinsic<[], + [ llvm_anyvector_ty, + LLVMVectorOfAnyPointersToElt<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrArgMemOnly, Mask<2>, VectorLength<3> ]>; + +// Reductions +let IntrProperties = [IntrNoMem, Mask<2>, VectorLength<3>] in { +def int_vp_reduce_add : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_vp_reduce_mul : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_vp_reduce_and : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_vp_reduce_or : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_vp_reduce_xor : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_vp_reduce_smax : 
Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_vp_reduce_smin : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_vp_reduce_umax : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_vp_reduce_umin : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; + +def int_vp_reduce_fadd : Intrinsic<[llvm_anyfloat_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_vp_reduce_fmul : Intrinsic<[llvm_anyfloat_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_vp_reduce_fmax : Intrinsic<[llvm_anyfloat_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_vp_reduce_fmin : Intrinsic<[llvm_anyfloat_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +} + +// Binary operators +let IntrProperties = [IntrNoMem, Mask<2>, VectorLength<3>] in { + def int_vp_add : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_sub : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_mul : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_sdiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_udiv : 
Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_srem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_urem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + + def int_vp_fadd : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_fsub : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_fmul : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_fdiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_frem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +// Bitwise and shift operators + def int_vp_ashr : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_lshr : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_shl : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_or : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_and : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, +
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_xor : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +// Comparison +// The last argument is the comparison predicate + def int_vp_cmp : Intrinsic<[ llvm_anyvector_ty ], + [ llvm_anyvector_ty, + LLVMMatchType<1>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty, + llvm_i8_ty]>; +} + + + +def int_vp_fneg : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, Mask<1>, VectorLength<2> ]>; + +def int_vp_fma : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, Mask<3>, VectorLength<4> ]>; + +// Shuffle +def int_vp_vshift: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_i32_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, Mask<2>, VectorLength<3> ]>; + +def int_vp_expand: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, Mask<2>, VectorLength<3> ]>; + +def int_vp_compress: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, Mask<2>, VectorLength<3> ]>; + +// Select +def int_vp_select : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty], + [ IntrNoMem, Passthru<2>, Mask<0>, VectorLength<3> ]>; + +// Compose +def int_vp_compose : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty, + llvm_i32_ty], + [ IntrNoMem, VectorLength<3> ]>; + + + +// These intrinsics are sensitive to the rounding mode so we need constrained +// versions of each of 
them. When strict rounding and exception control are +// not required the non-constrained versions of these intrinsics should be +// used. + +let IntrProperties = [IntrInaccessibleMemOnly, Mask<4>, VectorLength<5> ] in { + def int_vp_constrained_fadd : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_fsub : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_fmul : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_fdiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_frem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + + def int_vp_constrained_pow : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_maxnum : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_minnum : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; +} + +def int_vp_constrained_fma : Intrinsic<[ llvm_anyvector_ty ], + [ 
LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ], + [ IntrInaccessibleMemOnly, Mask<5>, VectorLength<6> ]>; + +let IntrProperties = [IntrInaccessibleMemOnly, Mask<3>, VectorLength<4> ] in { + def int_vp_constrained_sqrt : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_powi : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_i32_ty, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_sin : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_cos : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_log : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_log10: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_log2 : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_exp : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_exp2 : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, 
llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_rint : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_nearbyint : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_ceil : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_floor : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_round : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_constrained_trunc : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; +} + + //===-------------------------- Masked Intrinsics -------------------------===// -// +// TODO poised for deprecation (superseded by llvm.vp*. intrinsics) def int_masked_store : Intrinsic<[], [llvm_anyvector_ty, LLVMAnyPointerType>, llvm_i32_ty, @@ -1121,6 +1571,7 @@ [ IntrArgMemOnly, NoCapture<0>, WriteOnly<0>, ImmArg<3> ]>; //===------------------------ Reduction Intrinsics ------------------------===// +// TODO poised for deprecation (superseded by llvm.vp*. 
intrinsics) // def int_experimental_vector_reduce_fadd : Intrinsic<[llvm_anyfloat_ty], [llvm_anyfloat_ty, Index: include/llvm/IR/MatcherCast.h =================================================================== --- /dev/null +++ include/llvm/IR/MatcherCast.h @@ -0,0 +1,65 @@ +#ifndef LLVM_IR_MATCHERCAST_H +#define LLVM_IR_MATCHERCAST_H + +//===- MatcherCast.h - Match on the LLVM IR --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Parameterized class hierarchy for templatized pattern matching. +// +//===----------------------------------------------------------------------===// + + +namespace llvm { +namespace PatternMatch { + + +// Maps the class a matcher asks for to the class it actually casts to +// (ActualCastType) in a given match context. +template +struct MatcherCast { }; + +// Returns whether the Value \p Obj behaves like a \p Class.
+template +bool match_isa(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return isa(Obj); +} + +template +auto match_cast(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return cast(Obj); +} +template +auto match_dyn_cast(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return dyn_cast(Obj); +} + +template +auto match_cast(Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return cast(Obj); +} +template +auto match_dyn_cast(Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return dyn_cast(Obj); +} + + +} // namespace PatternMatch + +} // namespace llvm + +#endif // LLVM_IR_MATCHERCAST_H + Index: include/llvm/IR/PatternMatch.h =================================================================== --- include/llvm/IR/PatternMatch.h +++ include/llvm/IR/PatternMatch.h @@ -39,22 +39,81 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" +#include "llvm/IR/MatcherCast.h" + #include + namespace llvm { namespace PatternMatch { +// Use verbatim types in default (empty) context. +struct EmptyContext { + EmptyContext() {} + + EmptyContext(const Value *) {} + + EmptyContext(const EmptyContext & E) {} + + // reset this match context to be rooted at \p V + void reset(Value * V) {} + + // accept a match where \p Val is in a non-leaf position in a match pattern + bool acceptInnerNode(const Value * Val) const { return true; } + + // accept a match where \p Val is bound to a free variable. + bool acceptBoundNode(const Value * Val) const { return true; } + + // whether this context is compatible with \p E.
+ bool acceptContext(EmptyContext E) const { return true; } + + // merge the context \p E into this context and return whether the resulting context is valid. + bool mergeContext(EmptyContext E) { return true; } + + // reset this context to \p Val. + template bool reset_match(Val *V, const Pattern &P) { + reset(V); + return const_cast(P).match_context(V, *this); + } + + // match in the current context + template bool try_match(Val *V, const Pattern &P) { + return const_cast(P).match_context(V, *this); + } +}; + +template +struct MatcherCast { using ActualCastType = DestClass; }; + + + + + + +// match without (== empty) context template bool match(Val *V, const Pattern &P) { - return const_cast(P).match(V); + EmptyContext ECtx; + return const_cast(P).match_context(V, ECtx); +} + +// match pattern in a given context +template bool match(Val *V, const Pattern &P, MatchContext & MContext) { + return const_cast(P).match_context(V, MContext); } + + template struct OneUse_match { SubPattern_t SubPattern; OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {} template bool match(OpTy *V) { - return V->hasOneUse() && SubPattern.match(V); + EmptyContext EContext; return match_context(V, EContext); + } + + template bool match_context(OpTy *V, MatchContext & MContext) { + return V->hasOneUse() && SubPattern.match_context(V, MContext); } }; @@ -63,7 +122,11 @@ } template struct class_match { - template bool match(ITy *V) { return isa(V); } + template bool match(ITy *V) { + EmptyContext EContext; return match_context(V, EContext); + } + template + bool match_context(ITy *V, MatchContext & MContext) { return match_isa(V); } }; /// Match an arbitrary value and ignore it. 
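The hunk above makes every matcher context-parameterized. `match(V, P)` now delegates to `match_context(V, EmptyContext)`, and matchers consult the context through `acceptInnerNode` / `acceptBoundNode` before committing. The following is a self-contained toy sketch of that pattern, assuming a two-node toy IR in place of `llvm::Value`. `ToyValue`, `UnpredicatedOnlyContext`, `BindMatch`, `AddMatch` and the local `m_Value` / `m_Add` helpers are all hypothetical and only mirror the shape of the real PatternMatch classes. It shows how a non-empty context can veto a match that the empty context accepts.

```cpp
#include <cassert>

// Hypothetical toy IR node: a leaf, or an "add" with two operands.
// Predicated marks a node standing for a VP intrinsic instead of a plain add.
struct ToyValue {
  bool IsAdd = false;
  const ToyValue *LHS = nullptr;
  const ToyValue *RHS = nullptr;
  bool Predicated = false;
};

// Mirrors the two acceptance hooks of EmptyContext: accept every node.
struct EmptyContext {
  bool acceptInnerNode(const ToyValue *) const { return true; }
  bool acceptBoundNode(const ToyValue *) const { return true; }
};

// Hypothetical context that refuses predicated inner nodes.
struct UnpredicatedOnlyContext {
  bool acceptInnerNode(const ToyValue *V) const { return !V->Predicated; }
  bool acceptBoundNode(const ToyValue *) const { return true; }
};

// Like m_Value: binds any operand after asking the context.
struct BindMatch {
  const ToyValue *&Out;
  template <typename Ctx> bool match_context(const ToyValue *V, Ctx &C) {
    if (!C.acceptBoundNode(V))
      return false;
    Out = V;
    return true;
  }
};

// Like m_Add: checks the node kind, asks the context, then recurses.
template <typename L, typename R> struct AddMatch {
  L LP;
  R RP;
  template <typename Ctx> bool match_context(const ToyValue *V, Ctx &C) {
    if (!V->IsAdd || !C.acceptInnerNode(V))
      return false;
    return LP.match_context(V->LHS, C) && RP.match_context(V->RHS, C);
  }
  // Context-free entry point: delegate with an empty context.
  bool match(const ToyValue *V) {
    EmptyContext E;
    return match_context(V, E);
  }
};

inline BindMatch m_Value(const ToyValue *&V) { return {V}; }

template <typename L, typename R> AddMatch<L, R> m_Add(L Lhs, R Rhs) {
  return {Lhs, Rhs};
}
```

Because the context is a template parameter rather than a virtual interface, the empty context's always-true hooks can be inlined away. This is consistent with the roadmap's claim that non-VP matching does not get slower.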
@@ -95,11 +158,17 @@ match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { + MatchContext SubContext; + + if (L.match_context(V, SubContext) && MContext.acceptContext(SubContext)) { + MContext.mergeContext(SubContext); return true; - if (R.match(V)) + } + if (R.match_context(V, MContext)) { return true; + } return false; } }; @@ -110,9 +179,10 @@ match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) - if (R.match(V)) + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { + if (L.match_context(V, MContext)) + if (R.match_context(V, MContext)) return true; return false; } @@ -135,7 +205,8 @@ apint_match(const APInt *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValue(); return true; @@ -155,7 +226,8 @@ struct apfloat_match { const APFloat *&Res; apfloat_match(const APFloat *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValueAPF(); return true; @@ -179,7 +251,8 @@ inline apfloat_match m_APFloat(const APFloat *&Res) { return Res; } template struct constantint_match { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = 
dyn_cast(V)) { const APInt &CIV = CI->getValue(); if (Val >= 0) @@ -202,7 +275,8 @@ /// satisfy a specified predicate. /// For vector constants, undefined elements are ignored. template struct cst_pred_ty : public Predicate { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) return this->isValue(CI->getValue()); if (V->getType()->isVectorTy()) { @@ -239,7 +313,8 @@ api_pred_ty(const APInt *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) if (this->isValue(CI->getValue())) { Res = &CI->getValue(); @@ -261,7 +336,8 @@ /// constants that satisfy a specified predicate. /// For vector constants, undefined elements are ignored. template struct cstfp_pred_ty : public Predicate { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CF = dyn_cast(V)) return this->isValue(CF->getValueAPF()); if (V->getType()->isVectorTy()) { @@ -365,7 +441,8 @@ } struct is_zero { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { auto *C = dyn_cast(V); return C && (C->isNullValue() || cst_pred_ty().match(C)); } @@ -461,8 +538,11 @@ bind_ty(Class *&V) : VR(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CV = dyn_cast(V)) { + if (!MContext.acceptBoundNode(V)) return false; + VR = CV; return true; } 
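In `match_combine_or` (and later in `AnyBinaryOp_match`), the sub-match runs against a separate context (`SubContext` / `LRContext`). The result is folded back through `acceptContext` / `mergeContext` only on success, so a failed alternative cannot leave stale state in the outer context. The following is a minimal sketch of a stateful context that makes this matter. It assumes a context that pins the mask and explicit vector length of the first inner node and rejects inner nodes that disagree. `ToyOp`, `SamePredicateContext` and `acceptAll` are hypothetical, and integer ids stand in for the mask and EVL `Value`s.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical toy operation: mask and EVL operands identified by integer
// ids instead of llvm::Value pointers.
struct ToyOp {
  int MaskId;
  int EVLId;
};

// Hypothetical context: all inner nodes of one match must share the same
// mask and EVL. The first accepted node pins them.
struct SamePredicateContext {
  std::optional<int> MaskId;
  std::optional<int> EVLId;

  bool acceptInnerNode(const ToyOp &Op) {
    if (!MaskId) {
      MaskId = Op.MaskId;
      EVLId = Op.EVLId;
      return true;
    }
    return *MaskId == Op.MaskId && *EVLId == Op.EVLId;
  }

  // Compatible unless both contexts pinned different predicates.
  bool acceptContext(const SamePredicateContext &O) const {
    if (!MaskId || !O.MaskId)
      return true;
    return *MaskId == *O.MaskId && *EVLId == *O.EVLId;
  }

  // Adopt O's predicate; fails if the two contexts conflict.
  bool mergeContext(const SamePredicateContext &O) {
    if (!acceptContext(O))
      return false;
    if (!MaskId) {
      MaskId = O.MaskId;
      EVLId = O.EVLId;
    }
    return true;
  }
};

// Copy-then-merge: try the sub-match on a copy of Outer and commit to Outer
// only if every node was accepted.
bool acceptAll(SamePredicateContext &Outer, const std::vector<ToyOp> &Ops) {
  SamePredicateContext Trial(Outer);
  for (const ToyOp &Op : Ops)
    if (!Trial.acceptInnerNode(Op))
      return false;
  return Outer.mergeContext(Trial);
}
```

If `acceptAll` worked on `Outer` directly, a rejected sequence such as mask 1 followed by mask 2 would still leave `Outer` pinned to mask 1. The copy confines that side effect to the failed attempt.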
@@ -494,7 +574,8 @@ specificval_ty(const Value *V) : Val(V) {} - template bool match(ITy *V) { return V == Val; } + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { return V == Val; } }; /// Match if we have a specific specified value. @@ -507,7 +588,8 @@ deferredval_ty(Class *const &V) : Val(V) {} - template bool match(ITy *const V) { return V == Val; } + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *const V, MatchContext & MContext) { return V == Val; } }; /// A commutative-friendly version of m_Specific(). @@ -523,7 +605,8 @@ specific_fpval(double V) : Val(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CFP = dyn_cast(V)) return CFP->isExactlyValue(Val); if (V->getType()->isVectorTy()) @@ -546,7 +629,8 @@ bind_const_intval_ty(uint64_t &V) : VR(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CV = dyn_cast(V)) if (CV->getValue().ule(UINT64_MAX)) { VR = CV->getZExtValue(); @@ -563,7 +647,8 @@ specific_intval(uint64_t V) : Val(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { const auto *CI = dyn_cast(V); if (!CI && V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) @@ -593,11 +678,16 @@ // The LHS is always matched first. 
AnyBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto * I = match_dyn_cast(V); + if (!I) return false; + + if (!MContext.acceptInnerNode(I)) return false; + + MatchContext LRContext(MContext); + if (L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext))) return true; return false; } }; @@ -621,12 +711,15 @@ // The LHS is always matched first. BinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto * I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + MatchContext LRContext(MContext); + if (!MContext.acceptInnerNode(I)) return false; + if (L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext))) return true; + return false; } if (auto *CE = dyn_cast(V)) return CE->getOpcode() == Opcode && @@ -665,20 +758,21 @@ Op_t X; FNeg_match(const Op_t &Op) 
: X(Op) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { auto *FPMO = dyn_cast(V); - if (!FPMO || FPMO->getOpcode() != Instruction::FSub) + if (!FPMO || match_cast(V)->getOpcode() != Instruction::FSub) return false; if (FPMO->hasNoSignedZeros()) { // With 'nsz', any zero goes. - if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!cstfp_pred_ty().match_context(FPMO->getOperand(0), MContext)) return false; } else { // Without 'nsz', we need fsub -0.0, X exactly. - if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!cstfp_pred_ty().match_context(FPMO->getOperand(0), MContext)) return false; } - return X.match(FPMO->getOperand(1)); + return X.match_context(FPMO->getOperand(1), MContext); } }; @@ -789,7 +883,8 @@ OverflowingBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *Op = dyn_cast(V)) { if (Op->getOpcode() != Opcode) return false; @@ -799,7 +894,7 @@ if (WrapFlags & OverflowingBinaryOperator::NoSignedWrap && !Op->hasNoSignedWrap()) return false; - return L.match(Op->getOperand(0)) && R.match(Op->getOperand(1)); + return L.match_context(Op->getOperand(0), MContext) && R.match_context(Op->getOperand(1), MContext); } return false; } @@ -881,10 +976,11 @@ BinOpPred_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return this->isOpType(I->getOpcode()) && L.match(I->getOperand(0)) && - R.match(I->getOperand(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *I = match_dyn_cast(V)) + return 
this->isOpType(I->getOpcode()) && L.match_context(I->getOperand(0), MContext) && + R.match_context(I->getOperand(1), MContext); if (auto *CE = dyn_cast(V)) return this->isOpType(CE->getOpcode()) && L.match(CE->getOperand(0)) && R.match(CE->getOperand(1)); @@ -963,9 +1059,10 @@ Exact_match(const SubPattern_t &SP) : SubPattern(SP) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *PEO = dyn_cast(V)) - return PEO->isExact() && SubPattern.match(V); + return PEO->isExact() && SubPattern.match_context(V, MContext); return false; } }; @@ -990,14 +1087,17 @@ CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) : Predicate(Pred), L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - if ((L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0)))) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *I = match_dyn_cast(V)) { + if (!MContext.acceptInnerNode(I)) return false; + MatchContext LRContext(MContext); + if ((L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) || + (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext)))) { Predicate = I->getPredicate(); return true; } + } return false; } }; @@ -1030,10 +1130,11 @@ OneOps_match(const T0 &Op1) : Op1(Op1) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & 
MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext); } return false; } @@ -1046,10 +1147,12 @@ TwoOps_match(const T0 &Op1, const T1 &Op2) : Op1(Op1), Op2(Op2) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext) && + Op2.match_context(I->getOperand(1), MContext); } return false; } @@ -1065,11 +1168,13 @@ ThreeOps_match(const T0 &Op1, const T1 &Op2, const T2 &Op3) : Op1(Op1), Op2(Op2), Op3(Op3) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && - Op3.match(I->getOperand(2)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext) && + Op2.match_context(I->getOperand(1), MContext) && + Op3.match_context(I->getOperand(2), MContext); } return false; } @@ -1137,9 +1242,10 @@ CastClass_match(const Op_t &OpMatch) : Op(OpMatch) {} - template bool match(OpTy *V) { - if (auto *O = dyn_cast(V)) - return O->getOpcode() == Opcode && Op.match(O->getOperand(0)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto O = 
match_dyn_cast(V)) + return O->getOpcode() == Opcode && MContext.acceptInnerNode(O) && Op.match_context(O->getOperand(0), MContext); return false; } }; @@ -1214,8 +1320,9 @@ br_match(BasicBlock *&Succ) : Succ(Succ) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *BI = match_dyn_cast(V)) if (BI->isUnconditional()) { Succ = BI->getSuccessor(0); return true; @@ -1233,8 +1340,9 @@ brc_match(const Cond_t &C, BasicBlock *&t, BasicBlock *&f) : Cond(C), T(t), F(f) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *BI = match_dyn_cast(V)) if (BI->isConditional() && Cond.match(BI->getCondition())) { T = BI->getSuccessor(0); F = BI->getSuccessor(1); @@ -1263,13 +1371,14 @@ // The LHS is always matched first. MaxMin_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { // Look for "(x pred y) ? x : y" or "(x pred y) ? y : x". - auto *SI = dyn_cast(V); - if (!SI) + auto *SI = match_dyn_cast(V); + if (!SI || !MContext.acceptInnerNode(SI)) return false; - auto *Cmp = dyn_cast(SI->getCondition()); - if (!Cmp) + auto *Cmp = match_dyn_cast(SI->getCondition()); + if (!Cmp || !MContext.acceptInnerNode(Cmp)) return false; // At this point we have a select conditioned on a comparison. Check that // it is the values returned by the select that are being compared. @@ -1285,9 +1394,12 @@ // Does "(x pred y) ? x : y" represent the desired max/min operation? if (!Pred_t::match(Pred)) return false; + // It does! 
Bind the operands. - return (L.match(LHS) && R.match(RHS)) || - (Commutable && L.match(RHS) && R.match(LHS)); + MatchContext LRContext(MContext); + if (L.match_context(LHS, LRContext) && R.match_context(RHS, LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(RHS, MContext) && R.match_context(LHS, MContext))) return true; + return false; } }; @@ -1444,7 +1556,8 @@ UAddWithOverflow_match(const LHS_t &L, const RHS_t &R, const Sum_t &S) : L(L), R(R), S(S) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { Value *ICmpLHS, *ICmpRHS; ICmpInst::Predicate Pred; if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V)) @@ -1497,9 +1610,10 @@ Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { // FIXME: Should likely be switched to use `CallBase`. 
- if (const auto *CI = dyn_cast(V)) + if (const auto *CI = match_dyn_cast(V)) return Val.match(CI->getArgOperand(OpI)); return false; } @@ -1517,8 +1631,9 @@ IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) {} - template bool match(OpTy *V) { - if (const auto *CI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (const auto *CI = match_dyn_cast(V)) if (const auto *F = CI->getCalledFunction()) return F->getIntrinsicID() == ID; return false; @@ -1728,7 +1843,8 @@ Opnd_t Val; Signum_match(const Opnd_t &V) : Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { unsigned TypeSize = V->getType()->getScalarSizeInBits(); if (TypeSize == 0) return false; Index: include/llvm/IR/PredicatedInst.h =================================================================== --- /dev/null +++ include/llvm/IR/PredicatedInst.h @@ -0,0 +1,369 @@ +//===-- llvm/PredicatedInst.h - Predication utility subclass --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines various classes for working with Predicated Instructions. +// Predicated instructions are either regular instructions or calls to +// Vector Predication (VP) intrinsics that have a mask and an explicit +// vector length argument. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_IR_PREDICATEDINST_H +#define LLVM_IR_PREDICATEDINST_H + +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/MatcherCast.h" + +#include + +namespace llvm { + +class BasicBlock; + +class PredicatedInstruction : public User { +public: + // The PredicatedInstruction class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedInstruction() = delete; + ~PredicatedInstruction() = delete; + + void copyIRFlags(const Value * V, bool IncludeWrapFlags) { + cast(this)->copyIRFlags(V, IncludeWrapFlags); + } + + BasicBlock* getParent() { return cast(this)->getParent(); } + const BasicBlock* getParent() const { return cast(this)->getParent(); } + + void *operator new(size_t s) = delete; + + Value* getMask() const { + auto thisVP = dyn_cast(this); + if (!thisVP) return nullptr; + return thisVP->getMask(); + } + + Value* getVectorLength() const { + auto thisVP = dyn_cast(this); + if (!thisVP) return nullptr; + return thisVP->getVectorLength(); + } + + unsigned getOpcode() const { + auto * VPInst = dyn_cast(this); + if (VPInst) + return VPInst->getFunctionalOpcode(); + return cast(this)->getOpcode(); + } + +#if 0 + operator Instruction() { return cast(this); } + operator const Value() const { return cast(this); } +#endif + + static bool classof(const Instruction * I) { return isa(I); } + static bool classof(const ConstantExpr * CE) { return false; } + static bool classof(const Value *V) { return isa(V); } +}; + +class PredicatedOperator : public User { +public: + // The PredicatedOperator class is intended to be used as a utility, and is never itself + // instantiated. 
+ PredicatedOperator() = delete; + ~PredicatedOperator() = delete; + + void *operator new(size_t s) = delete; + +#if 0 + operator Value*() { return cast(this); } + operator const Value*() const { return cast(this); } +#endif + + /// Return the opcode for this Instruction or ConstantExpr. + unsigned getOpcode() const { + auto * VPInst = dyn_cast(this); + if (VPInst) + return VPInst->getFunctionalOpcode(); + if (const Instruction *I = dyn_cast(this)) + return I->getOpcode(); + return cast(this)->getOpcode(); + } + + Value* getMask() const { + auto thisVP = dyn_cast(this); + if (!thisVP) return nullptr; + return thisVP->getMask(); + } + + Value* getVectorLength() const { + auto thisVP = dyn_cast(this); + if (!thisVP) return nullptr; + return thisVP->getVectorLength(); + } + + void copyIRFlags(const Value * V, bool IncludeWrapFlags = true); + FastMathFlags getFastMathFlags() const { + auto * I = dyn_cast(this); + if (I) return I->getFastMathFlags(); + else return FastMathFlags(); + } + + static bool classof(const Instruction * I) { return isa(I) || isa(I); } + static bool classof(const ConstantExpr * CE) { return isa(CE); } + static bool classof(const Value *V) { return isa(V) || isa(V); } +}; + +class PredicatedBinaryOperator : public PredicatedOperator { +public: + // The PredicatedBinaryOperator class is intended to be used as a utility, and is never itself + // instantiated. 
+ PredicatedBinaryOperator() = delete; + ~PredicatedBinaryOperator() = delete; + + using BinaryOps = Instruction::BinaryOps; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction * I) { + if (isa(I)) return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->isBinaryOp(); + } + static bool classof(const ConstantExpr * CE) { return isa(CE); } + static bool classof(const Value *V) { + auto * I = dyn_cast(V); + if (I && classof(I)) return true; + auto * CE = dyn_cast(V); + return CE && classof(CE); + } + + /// Construct a predicated binary instruction, given the opcode and the two + /// operands. + static Instruction* Create(Module * Mod, + Value *Mask, Value *VectorLen, + Instruction::BinaryOps Opc, + Value *V1, Value *V2, + const Twine &Name, + BasicBlock * InsertAtEnd, + Instruction * InsertBefore); + + static Instruction* Create(Module *Mod, + Value *Mask, Value *VectorLen, + BinaryOps Opc, + Value *V1, Value *V2, + const Twine &Name = Twine(), + Instruction *InsertBefore = nullptr) { + return Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, nullptr, InsertBefore); + } + + static Instruction* Create(Module *Mod, + Value *Mask, Value *VectorLen, + BinaryOps Opc, + Value *V1, Value *V2, + const Twine &Name, + BasicBlock *InsertAtEnd) { + return Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, InsertAtEnd, nullptr); + } + + static Instruction* CreateWithCopiedFlags(Module *Mod, + Value *Mask, Value* VectorLen, + BinaryOps Opc, + Value *V1, Value *V2, + Instruction *CopyBO, + const Twine &Name = "") { + Instruction *BO = Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, nullptr, nullptr); + BO->copyIRFlags(CopyBO); + return BO; + } +}; + +class PredicatedICmpInst : public PredicatedBinaryOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. 
+ PredicatedICmpInst() = delete; + ~PredicatedICmpInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction * I) { + if (isa(I)) return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->getFunctionalOpcode() == Instruction::ICmp; + } + static bool classof(const ConstantExpr * CE) { return CE->getOpcode() == Instruction::ICmp; } + static bool classof(const Value *V) { + auto * I = dyn_cast(V); + if (I && classof(I)) return true; + auto * CE = dyn_cast(V); + return CE && classof(CE); + } + + ICmpInst::Predicate getPredicate() const { + auto * ICInst = dyn_cast(this); + if (ICInst) return ICInst->getPredicate(); + auto * CE = dyn_cast(this); + if (CE) return static_cast(CE->getPredicate()); + return static_cast(cast(this)->getCmpPredicate()); + } +}; + + +class PredicatedFCmpInst : public PredicatedBinaryOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedFCmpInst() = delete; + ~PredicatedFCmpInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction * I) { + if (isa(I)) return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->getFunctionalOpcode() == Instruction::FCmp; + } + static bool classof(const ConstantExpr * CE) { return CE->getOpcode() == Instruction::FCmp; } + static bool classof(const Value *V) { + auto * I = dyn_cast(V); + if (I && classof(I)) return true; + return isa(V); + } + + FCmpInst::Predicate getPredicate() const { + auto * FCInst = dyn_cast(this); + if (FCInst) return FCInst->getPredicate(); + auto * CE = dyn_cast(this); + if (CE) return static_cast(CE->getPredicate()); + return static_cast(cast(this)->getCmpPredicate()); + } +}; + + +class PredicatedSelectInst : public PredicatedOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. 
+ PredicatedSelectInst() = delete; + ~PredicatedSelectInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction * I) { + if (isa(I)) return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->getFunctionalOpcode() == Instruction::Select; + } + static bool classof(const ConstantExpr * CE) { return CE->getOpcode() == Instruction::Select; } + static bool classof(const Value *V) { + auto * I = dyn_cast(V); + if (I && classof(I)) return true; + auto * CE = dyn_cast(V); + return CE && CE->getOpcode() == Instruction::Select; + } + + const Value *getCondition() const { return getOperand(0); } + const Value *getTrueValue() const { return getOperand(1); } + const Value *getFalseValue() const { return getOperand(2); } + Value *getCondition() { return getOperand(0); } + Value *getTrueValue() { return getOperand(1); } + Value *getFalseValue() { return getOperand(2); } + + void setCondition(Value *V) { setOperand(0, V); } + void setTrueValue(Value *V) { setOperand(1, V); } + void setFalseValue(Value *V) { setOperand(2, V); } +}; + + +namespace PatternMatch { + +// PredicatedMatchContext for pattern matching +struct PredicatedContext { + Value * Mask; + Value * VectorLength; + Module * Mod; + + void reset(Value * V) { + auto * PI = dyn_cast(V); + if (!PI) { + VectorLength = nullptr; + Mask = nullptr; + Mod = nullptr; + } else { + VectorLength = PI->getVectorLength(); + Mask = PI->getMask(); + Mod = PI->getParent()->getParent()->getParent(); + } + } + + PredicatedContext(Value * Val) + : Mask(nullptr) + , VectorLength(nullptr) + , Mod(nullptr) + { + reset(Val); + } + + PredicatedContext(const PredicatedContext & PC) + : Mask(PC.Mask) + , VectorLength(PC.VectorLength) + , Mod(PC.Mod) + {} + + // accept a match where \p Val is in a non-leaf position in a match pattern + bool acceptInnerNode(const Value * Val) const { + auto PredI = dyn_cast(Val); + if (!PredI) return VectorLength == nullptr && Mask == nullptr; + return 
VectorLength == PredI->getVectorLength() && Mask == PredI->getMask(); + } + + // accept a match where \p Val is bound to a free variable. + bool acceptBoundNode(const Value * Val) const { return true; } + + // whether this context is compatible with \p PC. + bool acceptContext(PredicatedContext PC) const { + return PC.Mask == Mask && PC.VectorLength == VectorLength; + } + + // merge the context \p PC into this context and return whether the resulting context is valid. + bool mergeContext(PredicatedContext PC) const { return acceptContext(PC); } + + // match \p P in a new context for \p V. + template bool reset_match(Val *V, const Pattern &P) { + reset(V); + return const_cast(P).match_context(V, *this); + } + + // match \p P in the current context. + template bool try_match(Val *V, const Pattern &P) { + PredicatedContext SubContext(*this); + return const_cast(P).match_context(V, SubContext); + } +}; + +struct PredicatedContext; +template<> struct MatcherCast { using ActualCastType = PredicatedBinaryOperator; }; +template<> struct MatcherCast { using ActualCastType = PredicatedOperator; }; +template<> struct MatcherCast { using ActualCastType = PredicatedICmpInst; }; +template<> struct MatcherCast { using ActualCastType = PredicatedFCmpInst; }; +template<> struct MatcherCast { using ActualCastType = PredicatedSelectInst; }; +template<> struct MatcherCast { using ActualCastType = PredicatedInstruction; }; + +} // namespace PatternMatch + +} // namespace llvm + +#endif // LLVM_IR_PREDICATEDINST_H Index: include/llvm/IR/VPBuilder.h =================================================================== --- /dev/null +++ include/llvm/IR/VPBuilder.h @@ -0,0 +1,232 @@ +#ifndef LLVM_IR_VPBUILDER_H +#define LLVM_IR_VPBUILDER_H + +#include +#include +#include +#include +#include +#include + +namespace llvm { + +using ValArray = ArrayRef; + +class VPBuilder { + IRBuilder<> & Builder; + // Explicit mask parameter + Value * Mask; + // Explicit vector length parameter + Value *
ExplicitVectorLength; + // Compile-time vector length + int StaticVectorLength; + + // get a valid mask/EVL argument for the current predication context + Value& GetMaskForType(VectorType & VecTy); + Value& GetEVLForType(VectorType & VecTy); + +public: + VPBuilder(IRBuilder<> & _builder) + : Builder(_builder) + , Mask(nullptr) + , ExplicitVectorLength(nullptr) + , StaticVectorLength(-1) + {} + + Module & getModule() const; + LLVMContext & getContext() const { return Builder.getContext(); } + + // The canonical vector type for this \p ElementTy + VectorType& getVectorType(Type &ElementTy); + + // Predication context tracker + VPBuilder& setMask(Value * _Mask) { Mask = _Mask; return *this; } + VPBuilder& setEVL(Value * _ExplicitVectorLength) { ExplicitVectorLength = _ExplicitVectorLength; return *this; } + VPBuilder& setStaticVL(int VLen) { StaticVectorLength = VLen; return *this; } + + VPIntrinsic::VPIntrinsicDesc GetVPIntrinsicDesc(unsigned OC); + + // Create a map-vectorized copy of the instruction \p Inst with the underlying IRBuilder instance. + // This operation may return nullptr if the instruction could not be vectorized.
+ Value* CreateVectorCopy(Instruction & Inst, ValArray VecOpArray); + + // Memory + Value& CreateContiguousStore(Value & Val, Value & Pointer, unsigned Alignment=0); + Value& CreateContiguousLoad(Value & Pointer, unsigned Alignment=0); + Value& CreateScatter(Value & Val, Value & PointerVec, unsigned Alignment=0); + Value& CreateGather(Value & PointerVec, unsigned Alignment=0); +}; + + + + + +namespace PatternMatch { + // Factory class to generate instructions in a context + template + class MatchContextBuilder { + public: + // MatchContextBuilder(MatcherContext MC); + }; + + +// Context-free instruction builder +template<> +class MatchContextBuilder { +public: + MatchContextBuilder(EmptyContext & EC) {} + + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name);\ + } \ + template \ + Value *Create##OPC(IRBuilderType & Builder, Value *V1, Value *V2, \ + const Twine &Name = "") const { \ + auto * Inst = BinaryOperator::Create(Instruction::OPC, V1, V2, Name); \ + Builder.Insert(Inst); return Inst; \ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Value *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, BasicBlock *BB) const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name, BB);\ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Value *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, Instruction *I) const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name, I);\ + } + #include "llvm/IR/Instruction.def" + #undef HANDLE_BINARY_INST + + BinaryOperator *CreateFAddFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FAdd, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFSubFMF(Value *V1, Value *V2, + Instruction 
*FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FSub, V1, V2, FMFSource, Name); + } + template + BinaryOperator *CreateFSubFMF(IRBuilderType & Builder, Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + auto * Inst = CreateFSubFMF(V1, V2, FMFSource, Name); + Builder.Insert(Inst); return Inst; + } + BinaryOperator *CreateFMulFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FMul, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFDivFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FDiv, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFRemFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FRem, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFNegFMF(Value *Op, Instruction *FMFSource, + const Twine &Name = "") { + Value *Zero = ConstantFP::getNegativeZero(Op->getType()); + return BinaryOperator::CreateWithCopiedFlags(Instruction::FSub, Zero, Op, FMFSource, Name); + } + + template + Value *CreateFPTrunc(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPTrunc(V, DestTy, Name); } + template + Value *CreateFPExt(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPExt(V, DestTy, Name); } +}; + + + +// Predicated instruction builder +template<> +class MatchContextBuilder { + PredicatedContext & PC; +public: + MatchContextBuilder(PredicatedContext & PC) : PC(PC) {} + + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name);\ + } \ + template \
template \ + Instruction *Create##OPC(IRBuilderType & Builder, Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + auto * PredInst = Create##OPC(V1, V2, Name); Builder.Insert(PredInst); return PredInst; \ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, BasicBlock *BB) const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name, BB);\ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, Instruction *I) const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name, I);\ + } + #include "llvm/IR/Instruction.def" + #undef HANDLE_BINARY_INST + + Instruction *CreateFAddFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FAdd, V1, V2, FMFSource, Name); + } + Instruction *CreateFSubFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FSub, V1, V2, FMFSource, Name); + } + template + Instruction *CreateFSubFMF(IRBuilderType & Builder, Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + auto * Inst = CreateFSubFMF(V1, V2, FMFSource, Name); + Builder.Insert(Inst); return Inst; + } + Instruction *CreateFMulFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FMul, V1, V2, FMFSource, Name); + } + Instruction *CreateFDivFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return 
PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FDiv, V1, V2, FMFSource, Name); + } + Instruction *CreateFRemFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FRem, V1, V2, FMFSource, Name); + } + Instruction *CreateFNegFMF(Value *Op, Instruction *FMFSource, + const Twine &Name = "") { + Value *Zero = ConstantFP::getNegativeZero(Op->getType()); + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FSub, Zero, Op, FMFSource, Name); + } + + // TODO: predicated casts (these currently emit unpredicated IR) + template + Value *CreateFPTrunc(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPTrunc(V, DestTy, Name); } + template + Value *CreateFPExt(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPExt(V, DestTy, Name); } +}; + +} + +} // namespace llvm + +#endif // LLVM_IR_VPBUILDER_H Index: include/llvm/Target/TargetSelectionDAG.td =================================================================== --- include/llvm/Target/TargetSelectionDAG.td +++ include/llvm/Target/TargetSelectionDAG.td @@ -128,6 +128,13 @@ SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisInt<3> ]>; +def SDTIntBinOpVP : SDTypeProfile<1, 4, [ // vp_add, vp_and, etc. + SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; +def SDTIntShiftOpVP : SDTypeProfile<1, 4, [ // vp_shl, vp_sra, vp_srl + SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; + def SDTFPBinOp : SDTypeProfile<1, 2, [ // fadd, fmul, etc. SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0> ]>; @@ -170,6 +177,16 @@ SDTCisOpSmallerThanOp<1, 0> ]>; +def SDTFPUnOpVP : SDTypeProfile<1, 3, [ // vp_fneg, etc.
+ SDTCisSameAs<0, 1>, SDTCisFP<0>, SDTCisInt<3>, SDTCisSameNumEltsAs<0, 2> +]>; +def SDTFPBinOpVP : SDTypeProfile<1, 4, [ // vp_fadd, etc. + SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; +def SDTFPTernaryOpVP : SDTypeProfile<1, 5, [ // vp_fmadd, etc. + SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisFP<0>, SDTCisInt<5>, SDTCisSameNumEltsAs<0, 4> +]>; + def SDTSetCC : SDTypeProfile<1, 3, [ // setcc SDTCisInt<0>, SDTCisSameAs<1, 2>, SDTCisVT<3, OtherVT> ]>; @@ -182,6 +199,10 @@ SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1> ]>; +def SDTVSelectVP : SDTypeProfile<1, 5, [ // vp_vselect + SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1>, SDTCisInt<5>, SDTCisSameNumEltsAs<0, 4> +]>; + def SDTSelectCC : SDTypeProfile<1, 5, [ // select_cc SDTCisSameAs<1, 2>, SDTCisSameAs<3, 4>, SDTCisSameAs<0, 3>, SDTCisVT<5, OtherVT> @@ -225,11 +246,20 @@ SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameNumEltsAs<0, 2> ]>; +def SDTStoreVP: SDTypeProfile<0, 4, [ // evl store + SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameNumEltsAs<0, 2>, SDTCisInt<3> +]>; + def SDTMaskedLoad: SDTypeProfile<1, 3, [ // masked load SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameAs<0, 3>, SDTCisSameNumEltsAs<0, 2> ]>; +def SDTLoadVP : SDTypeProfile<1, 3, [ // evl load + SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisSameNumEltsAs<0, 2>, SDTCisInt<3>, + SDTCisSameNumEltsAs<0, 2> +]>; + def SDTVecShuffle : SDTypeProfile<1, 2, [ SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2> ]>; @@ -385,6 +415,26 @@ def umax : SDNode<"ISD::UMAX" , SDTIntBinOp, [SDNPCommutative, SDNPAssociative]>; +def vp_and : SDNode<"ISD::VP_AND" , SDTIntBinOpVP , + [SDNPCommutative, SDNPAssociative]>; +def vp_or : SDNode<"ISD::VP_OR" , SDTIntBinOpVP , + [SDNPCommutative, SDNPAssociative]>; +def vp_xor : SDNode<"ISD::VP_XOR" , SDTIntBinOpVP , + [SDNPCommutative, SDNPAssociative]>; +def 
vp_srl : SDNode<"ISD::VP_SRL" , SDTIntShiftOpVP>; +def vp_sra : SDNode<"ISD::VP_SRA" , SDTIntShiftOpVP>; +def vp_shl : SDNode<"ISD::VP_SHL" , SDTIntShiftOpVP>; + +def vp_add : SDNode<"ISD::VP_ADD" , SDTIntBinOpVP , + [SDNPCommutative, SDNPAssociative]>; +def vp_sub : SDNode<"ISD::VP_SUB" , SDTIntBinOpVP>; +def vp_mul : SDNode<"ISD::VP_MUL" , SDTIntBinOpVP, + [SDNPCommutative, SDNPAssociative]>; +def vp_sdiv : SDNode<"ISD::VP_SDIV" , SDTIntBinOpVP>; +def vp_udiv : SDNode<"ISD::VP_UDIV" , SDTIntBinOpVP>; +def vp_srem : SDNode<"ISD::VP_SREM" , SDTIntBinOpVP>; +def vp_urem : SDNode<"ISD::VP_UREM" , SDTIntBinOpVP>; + def saddsat : SDNode<"ISD::SADDSAT" , SDTIntBinOp, [SDNPCommutative]>; def uaddsat : SDNode<"ISD::UADDSAT" , SDTIntBinOp, [SDNPCommutative]>; def ssubsat : SDNode<"ISD::SSUBSAT" , SDTIntBinOp>; @@ -454,6 +504,14 @@ def fpextend : SDNode<"ISD::FP_EXTEND" , SDTFPExtendOp>; def fcopysign : SDNode<"ISD::FCOPYSIGN" , SDTFPSignOp>; +def vp_fneg : SDNode<"ISD::VP_FNEG" , SDTFPUnOpVP>; +def vp_fadd : SDNode<"ISD::VP_FADD" , SDTFPBinOpVP, [SDNPCommutative]>; +def vp_fsub : SDNode<"ISD::VP_FSUB" , SDTFPBinOpVP>; +def vp_fmul : SDNode<"ISD::VP_FMUL" , SDTFPBinOpVP, [SDNPCommutative]>; +def vp_fdiv : SDNode<"ISD::VP_FDIV" , SDTFPBinOpVP>; +def vp_frem : SDNode<"ISD::VP_FREM" , SDTFPBinOpVP>; +def vp_fma : SDNode<"ISD::VP_FMA" , SDTFPTernaryOpVP>; + def sint_to_fp : SDNode<"ISD::SINT_TO_FP" , SDTIntToFPOp>; def uint_to_fp : SDNode<"ISD::UINT_TO_FP" , SDTIntToFPOp>; def fp_to_sint : SDNode<"ISD::FP_TO_SINT" , SDTFPToIntOp>; @@ -461,10 +519,10 @@ def f16_to_fp : SDNode<"ISD::FP16_TO_FP" , SDTIntToFPOp>; def fp_to_f16 : SDNode<"ISD::FP_TO_FP16" , SDTFPToIntOp>; -def setcc : SDNode<"ISD::SETCC" , SDTSetCC>; -def select : SDNode<"ISD::SELECT" , SDTSelect>; -def vselect : SDNode<"ISD::VSELECT" , SDTVSelect>; -def selectcc : SDNode<"ISD::SELECT_CC" , SDTSelectCC>; +def setcc : SDNode<"ISD::SETCC" , SDTSetCC>; +def select : SDNode<"ISD::SELECT" , SDTSelect>; +def vselect : 
SDNode<"ISD::VSELECT" , SDTVSelect>; +def selectcc : SDNode<"ISD::SELECT_CC" , SDTSelectCC>; def brcc : SDNode<"ISD::BR_CC" , SDTBrCC, [SDNPHasChain]>; def brcond : SDNode<"ISD::BRCOND" , SDTBrcond, [SDNPHasChain]>; @@ -532,6 +590,11 @@ def masked_load : SDNode<"ISD::MLOAD", SDTMaskedLoad, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; +def vp_store : SDNode<"ISD::VP_STORE", SDTStoreVP, + [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; +def vp_load : SDNode<"ISD::VP_LOAD", SDTLoadVP, + [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; + // Do not use ld, st directly. Use load, extload, sextload, zextload, store, // and truncst (see below). def ld : SDNode<"ISD::LOAD" , SDTLoad, Index: lib/Analysis/InstructionSimplify.cpp =================================================================== --- lib/Analysis/InstructionSimplify.cpp +++ lib/Analysis/InstructionSimplify.cpp @@ -37,6 +37,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/KnownBits.h" #include @@ -4310,8 +4311,10 @@ /// Given operands for an FSub, see if we can fold the result. If not, this /// returns null. 
-static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { +template +static Value *SimplifyFSubInstGeneric(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse, MatchContext & MC) { + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) return C; @@ -4319,23 +4322,23 @@ return C; // fsub X, +0 ==> X - if (match(Op1, m_PosZeroFP())) + if (MC.try_match(Op1, m_PosZeroFP())) return Op0; // fsub X, -0 ==> X, when we know X is not -0 - if (match(Op1, m_NegZeroFP()) && + if (MC.try_match(Op1, m_NegZeroFP()) && (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) return Op0; // fsub -0.0, (fsub -0.0, X) ==> X Value *X; - if (match(Op0, m_NegZeroFP()) && - match(Op1, m_FSub(m_NegZeroFP(), m_Value(X)))) + if (MC.try_match(Op0, m_NegZeroFP()) && + MC.try_match(Op1, m_FSub(m_NegZeroFP(), m_Value(X)))) return X; // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored. if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) && - match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X)))) + MC.try_match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X)))) return X; // fsub nnan x, x ==> 0.0 @@ -4345,13 +4348,20 @@ // Y - (Y - X) --> X // (X + Y) - Y --> X if (FMF.noSignedZeros() && FMF.allowReassoc() && - (match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || - match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) + (MC.try_match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || + MC.try_match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) return X; return nullptr; } +static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + EmptyContext EC; + return SimplifyFSubInstGeneric(Op0, Op1, FMF, Q, MaxRecurse, EC); +} + + /// Given the operands for an FMul, see if we can fold the result static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, unsigned MaxRecurse) { @@ -4392,6 +4402,11 @@ 
return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); } +Value *llvm::SimplifyPredicatedFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, PredicatedContext & PC) { + return ::SimplifyFSubInstGeneric(Op0, Op1, FMF, Q, RecursionLimit, PC); +} + Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q) { return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit); @@ -4962,9 +4977,20 @@ Call->arg_end(), Q, RecursionLimit); } +Value *llvm::SimplifyVPIntrinsic(VPIntrinsic & VPInst, const SimplifyQuery &Q) { + PredicatedContext PC(&VPInst); + + auto & PI = cast(VPInst); + switch (PI.getOpcode()) { + default: + return nullptr; + + case Instruction::FSub: return SimplifyPredicatedFSubInst(VPInst.getOperand(0), VPInst.getOperand(1), VPInst.getFastMathFlags(), Q, PC); + } +} + /// See if we can compute a simplified version of this instruction. /// If not, this returns null. - Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, OptimizationRemarkEmitter *ORE) { const SimplifyQuery Q = SQ.CxtI ? 
SQ : SQ.getWithInstruction(I); @@ -5098,6 +5124,13 @@ Result = SimplifyPHINode(cast(I), Q); break; case Instruction::Call: { + auto * VPInst = dyn_cast(I); + if (VPInst) { + Result = SimplifyVPIntrinsic(*VPInst, Q); + if (Result) break; + } + + CallSite CS((I)); Result = SimplifyCall(cast(I), Q); break; } Index: lib/AsmParser/LLLexer.cpp =================================================================== --- lib/AsmParser/LLLexer.cpp +++ lib/AsmParser/LLLexer.cpp @@ -642,6 +642,7 @@ KEYWORD(inlinehint); KEYWORD(inreg); KEYWORD(jumptable); + KEYWORD(mask); KEYWORD(minsize); KEYWORD(naked); KEYWORD(nest); @@ -661,6 +662,7 @@ KEYWORD(optforfuzzing); KEYWORD(optnone); KEYWORD(optsize); + KEYWORD(passthru); KEYWORD(readnone); KEYWORD(readonly); KEYWORD(returned); @@ -682,6 +684,7 @@ KEYWORD(swifterror); KEYWORD(swiftself); KEYWORD(uwtable); + KEYWORD(vlen); KEYWORD(writeonly); KEYWORD(zeroext); KEYWORD(immarg); Index: lib/AsmParser/LLParser.cpp =================================================================== --- lib/AsmParser/LLParser.cpp +++ lib/AsmParser/LLParser.cpp @@ -1302,15 +1302,18 @@ case lltok::kw_dereferenceable: case lltok::kw_dereferenceable_or_null: case lltok::kw_inalloca: + case lltok::kw_mask: case lltok::kw_nest: case lltok::kw_noalias: case lltok::kw_nocapture: case lltok::kw_nonnull: + case lltok::kw_passthru: case lltok::kw_returned: case lltok::kw_sret: case lltok::kw_swifterror: case lltok::kw_swiftself: case lltok::kw_immarg: + case lltok::kw_vlen: HaveError |= Error(Lex.getLoc(), "invalid use of parameter-only attribute on a function"); @@ -1591,10 +1594,12 @@ } case lltok::kw_inalloca: B.addAttribute(Attribute::InAlloca); break; case lltok::kw_inreg: B.addAttribute(Attribute::InReg); break; + case lltok::kw_mask: B.addAttribute(Attribute::Mask); break; case lltok::kw_nest: B.addAttribute(Attribute::Nest); break; case lltok::kw_noalias: B.addAttribute(Attribute::NoAlias); break; case lltok::kw_nocapture: B.addAttribute(Attribute::NoCapture); 
break; case lltok::kw_nonnull: B.addAttribute(Attribute::NonNull); break; + case lltok::kw_passthru: B.addAttribute(Attribute::Passthru); break; case lltok::kw_readnone: B.addAttribute(Attribute::ReadNone); break; case lltok::kw_readonly: B.addAttribute(Attribute::ReadOnly); break; case lltok::kw_returned: B.addAttribute(Attribute::Returned); break; @@ -1602,6 +1607,7 @@ case lltok::kw_sret: B.addAttribute(Attribute::StructRet); break; case lltok::kw_swifterror: B.addAttribute(Attribute::SwiftError); break; case lltok::kw_swiftself: B.addAttribute(Attribute::SwiftSelf); break; + case lltok::kw_vlen: B.addAttribute(Attribute::VectorLength); break; case lltok::kw_writeonly: B.addAttribute(Attribute::WriteOnly); break; case lltok::kw_zeroext: B.addAttribute(Attribute::ZExt); break; case lltok::kw_immarg: B.addAttribute(Attribute::ImmArg); break; @@ -1693,13 +1699,16 @@ // Error handling. case lltok::kw_byval: case lltok::kw_inalloca: + case lltok::kw_mask: case lltok::kw_nest: case lltok::kw_nocapture: + case lltok::kw_passthru: case lltok::kw_returned: case lltok::kw_sret: case lltok::kw_swifterror: case lltok::kw_swiftself: case lltok::kw_immarg: + case lltok::kw_vlen: HaveError |= Error(Lex.getLoc(), "invalid use of parameter-only attribute"); break; @@ -3319,7 +3328,7 @@ ID.Kind = ValID::t_Constant; return false; } - + // Unary Operators. case lltok::kw_fneg: { unsigned Opc = Lex.getUIntVal(); @@ -3329,7 +3338,7 @@ ParseGlobalTypeAndValue(Val) || ParseToken(lltok::rparen, "expected ')' in unary constantexpr")) return true; - + // Check that the type is valid for the operator. 
switch (Opc) { case Instruction::FNeg: @@ -6225,11 +6234,11 @@ Valid = LHS->getType()->isIntOrIntVectorTy() || LHS->getType()->isFPOrFPVectorTy(); break; - case 1: - Valid = LHS->getType()->isIntOrIntVectorTy(); + case 1: + Valid = LHS->getType()->isIntOrIntVectorTy(); break; - case 2: - Valid = LHS->getType()->isFPOrFPVectorTy(); + case 2: + Valid = LHS->getType()->isFPOrFPVectorTy(); break; } Index: lib/AsmParser/LLToken.h =================================================================== --- lib/AsmParser/LLToken.h +++ lib/AsmParser/LLToken.h @@ -186,6 +186,7 @@ kw_inlinehint, kw_inreg, kw_jumptable, + kw_mask, kw_minsize, kw_naked, kw_nest, @@ -205,6 +206,7 @@ kw_optforfuzzing, kw_optnone, kw_optsize, + kw_passthru, kw_readnone, kw_readonly, kw_returned, @@ -224,6 +226,7 @@ kw_swifterror, kw_swiftself, kw_uwtable, + kw_vlen, kw_writeonly, kw_zeroext, kw_immarg, Index: lib/Bitcode/Reader/BitcodeReader.cpp =================================================================== --- lib/Bitcode/Reader/BitcodeReader.cpp +++ lib/Bitcode/Reader/BitcodeReader.cpp @@ -1334,6 +1334,8 @@ return Attribute::InReg; case bitc::ATTR_KIND_JUMP_TABLE: return Attribute::JumpTable; + case bitc::ATTR_KIND_MASK: + return Attribute::Mask; case bitc::ATTR_KIND_MIN_SIZE: return Attribute::MinSize; case bitc::ATTR_KIND_NAKED: @@ -1378,6 +1380,8 @@ return Attribute::OptimizeForSize; case bitc::ATTR_KIND_OPTIMIZE_NONE: return Attribute::OptimizeNone; + case bitc::ATTR_KIND_PASSTHRU: + return Attribute::Passthru; case bitc::ATTR_KIND_READ_NONE: return Attribute::ReadNone; case bitc::ATTR_KIND_READ_ONLY: @@ -1422,6 +1426,8 @@ return Attribute::SwiftSelf; case bitc::ATTR_KIND_UW_TABLE: return Attribute::UWTable; + case bitc::ATTR_KIND_VECTORLENGTH: + return Attribute::VectorLength; case bitc::ATTR_KIND_WRITEONLY: return Attribute::WriteOnly; case bitc::ATTR_KIND_Z_EXT: Index: lib/Bitcode/Writer/BitcodeWriter.cpp =================================================================== --- 
lib/Bitcode/Writer/BitcodeWriter.cpp +++ lib/Bitcode/Writer/BitcodeWriter.cpp @@ -672,6 +672,12 @@ return bitc::ATTR_KIND_READ_ONLY; case Attribute::Returned: return bitc::ATTR_KIND_RETURNED; + case Attribute::Mask: + return bitc::ATTR_KIND_MASK; + case Attribute::VectorLength: + return bitc::ATTR_KIND_VECTORLENGTH; + case Attribute::Passthru: + return bitc::ATTR_KIND_PASSTHRU; case Attribute::ReturnsTwice: return bitc::ATTR_KIND_RETURNS_TWICE; case Attribute::SExt: Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -395,6 +395,7 @@ SDValue visitBITCAST(SDNode *N); SDValue visitBUILD_PAIR(SDNode *N); SDValue visitFADD(SDNode *N); + SDValue visitFADD_VP(SDNode *N); SDValue visitFSUB(SDNode *N); SDValue visitFMUL(SDNode *N); SDValue visitFMA(SDNode *N); @@ -444,6 +445,7 @@ SDValue visitFP16_TO_FP(SDNode *N); SDValue visitVECREDUCE(SDNode *N); + template SDValue visitFADDForFMACombine(SDNode *N); SDValue visitFSUBForFMACombine(SDNode *N); SDValue visitFMULForFMADistributiveCombine(SDNode *N); @@ -698,6 +700,137 @@ void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); } }; +struct EmptyMatchContext { + SelectionDAG & DAG; + + EmptyMatchContext(SelectionDAG & DAG, SDNode * Root) + : DAG(DAG) + {} + + bool match(SDValue OpN, unsigned OpCode) const { return OpCode == OpN->getOpcode(); } + + unsigned getFunctionOpCode(SDValue N) const { + return N->getOpcode(); + } + + bool isCompatible(SDValue OpVal) const { return true; } + + // Specialize based on number of operands. 
+ SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return DAG.getNode(Opcode, DL, VT); } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, + const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, Operand, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, N1, N2, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, + const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, Flags); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDValue N4) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, N4); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDValue N4, SDValue N5) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, N4, N5); + } +}; + +struct +VPMatchContext { + SelectionDAG & DAG; + SDNode * Root; + SDValue RootMaskOp; + SDValue RootVectorLenOp; + + VPMatchContext(SelectionDAG & DAG, SDNode * Root) + : DAG(DAG) + , Root(Root) + , RootMaskOp() + , RootVectorLenOp() + { + if (Root->isVP()) { + int RootMaskPos = ISD::GetMaskPosVP(Root->getOpcode()); + if (RootMaskPos != -1) { + RootMaskOp = Root->getOperand(RootMaskPos); + } + + int RootVLenPos = ISD::GetVectorLengthPosVP(Root->getOpcode()); + if (RootVLenPos != -1) { + RootVectorLenOp = Root->getOperand(RootVLenPos); + } + } + } + + unsigned getFunctionOpCode(SDValue N) const { + unsigned VPOpCode = N->getOpcode(); + return ISD::GetFunctionOpCodeForVP(VPOpCode); + } + + bool isCompatible(SDValue OpVal) const { + if (!OpVal->isVP()) { + return !Root->isVP(); + + } else { + unsigned VPOpCode = OpVal->getOpcode(); + int MaskPos = ISD::GetMaskPosVP(VPOpCode); + if (MaskPos != -1 && RootMaskOp != OpVal.getOperand(MaskPos)) { + return false;
+ } + + int VLenPos = ISD::GetVectorLengthPosVP(VPOpCode); + if (VLenPos != -1 && RootVectorLenOp != OpVal.getOperand(VLenPos)) { + return false; + } + + return true; + } + } + + /// Whether \p OpVal is a node that is functionally compatible with the node type \p OpNT. + bool match(SDValue OpVal, unsigned OpNT) const { + return isCompatible(OpVal) && getFunctionOpCode(OpVal) == OpNT; + } + + // Specialize based on number of operands. + // TODO emit VP intrinsics where MaskOp/VectorLenOp != null + // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return DAG.getNode(Opcode, DL, VT); } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, + const SDNodeFlags Flags = SDNodeFlags()) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + assert(ISD::GetMaskPosVP(VPOpcode) == 1 && + ISD::GetVectorLengthPosVP(VPOpcode) == 2 && + "unexpected operand layout for unary VP node"); + + return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp}, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, const SDNodeFlags Flags = SDNodeFlags()) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + assert(ISD::GetMaskPosVP(VPOpcode) == 2 && + ISD::GetVectorLengthPosVP(VPOpcode) == 3 && + "unexpected operand layout for binary VP node"); + + return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp}, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, + const SDNodeFlags Flags = SDNodeFlags()) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + assert(ISD::GetMaskPosVP(VPOpcode) == 3 && + ISD::GetVectorLengthPosVP(VPOpcode) == 4 && + "unexpected operand layout for ternary VP node"); + + return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags); + } +}; + } // end anonymous namespace 
//===----------------------------------------------------------------------===// @@ -1611,6
+1744,7 @@ case ISD::BITCAST: return visitBITCAST(N); case ISD::BUILD_PAIR: return visitBUILD_PAIR(N); case ISD::FADD: return visitFADD(N); + case ISD::VP_FADD: return visitFADD_VP(N); case ISD::FSUB: return visitFSUB(N); case ISD::FMUL: return visitFMUL(N); case ISD::FMA: return visitFMA(N); @@ -10760,13 +10894,18 @@ return F.hasAllowContract() || F.hasAllowReassociation(); } + /// Try to perform FMA combining on a given FADD node. +template SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); + MatchContextClass matcher(DAG, N); + if (!matcher.isCompatible(N0) || !matcher.isCompatible(N1)) return SDValue(); + const TargetOptions &Options = DAG.getTarget().Options; // Floating-point multiply-add with intermediate rounding. @@ -10799,8 +10938,8 @@ // Is the node an FMUL and contractable either due to global flags or // SDNodeFlags. - auto isContractableFMUL = [AllowFusionGlobally](SDValue N) { - if (N.getOpcode() != ISD::FMUL) + auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) { + if (!matcher.match(N, ISD::FMUL)) return false; return AllowFusionGlobally || isContractable(N.getNode()); }; @@ -10813,42 +10952,42 @@ // fold (fadd (fmul x, y), z) -> (fma x, y, z) if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), N1, Flags); } // fold (fadd x, (fmul y, z)) -> (fma y, z, x) // Note: Commutes FADD operands. if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), N0, Flags); } // Look through FP_EXTEND nodes to do more combining. 
// fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) - if (N0.getOpcode() == ISD::FP_EXTEND) { + if ((N0.getOpcode() == ISD::FP_EXTEND) && matcher.isCompatible(N0.getOperand(0))) { SDValue N00 = N0.getOperand(0); if (isContractableFMUL(N00) && TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1, Flags); } } // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x) // Note: Commutes FADD operands. - if (N1.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N1, ISD::FP_EXTEND)) { SDValue N10 = N1.getOperand(0); if (isContractableFMUL(N10) && TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N10.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0, Flags); } } @@ -10857,12 +10996,12 @@ if (Aggressive) { // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z)) if (CanFuse && - N0.getOpcode() == PreferredFusedOpcode && - N0.getOperand(2).getOpcode() == ISD::FMUL && + matcher.match(N0, PreferredFusedOpcode) && + matcher.match(N0.getOperand(2), ISD::FMUL) && N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(2).getOperand(0), N0.getOperand(2).getOperand(1), N1, Flags), Flags); @@ -10870,12 
+11009,12 @@ // fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x)) if (CanFuse && - N1->getOpcode() == PreferredFusedOpcode && - N1.getOperand(2).getOpcode() == ISD::FMUL && + matcher.match(N1, PreferredFusedOpcode) && + matcher.match(N1.getOperand(2), ISD::FMUL) && N1->hasOneUse() && N1.getOperand(2)->hasOneUse()) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(2).getOperand(0), N1.getOperand(2).getOperand(1), N0, Flags), Flags); @@ -10887,15 +11026,15 @@ auto FoldFAddFMAFPExtFMul = [&] ( SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, SDNodeFlags Flags) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y, - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + return matcher.getNode(PreferredFusedOpcode, SL, VT, X, Y, + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z, Flags), Flags); }; - if (N0.getOpcode() == PreferredFusedOpcode) { + if (matcher.match(N0, PreferredFusedOpcode)) { SDValue N02 = N0.getOperand(2); - if (N02.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N02, ISD::FP_EXTEND)) { SDValue N020 = N02.getOperand(0); if (isContractableFMUL(N020) && TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N020.getValueType())) { @@ -10914,12 +11053,12 @@ auto FoldFAddFPExtFMAFMul = [&] ( SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, SDNodeFlags Flags) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, X), - DAG.getNode(ISD::FP_EXTEND, SL, VT, Y), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + return matcher.getNode(PreferredFusedOpcode, SL, 
VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, X), + matcher.getNode(ISD::FP_EXTEND, SL, VT, Y), + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z, Flags), Flags); }; if (N0.getOpcode() == ISD::FP_EXTEND) { @@ -11356,6 +11495,15 @@ return SDValue(); } +SDValue DAGCombiner::visitFADD_VP(SDNode *N) { + // FADD -> FMA combines: + if (SDValue Fused = visitFADDForFMACombine<VPMatchContext>(N)) { + AddToWorklist(Fused.getNode()); + return Fused; + } + return SDValue(); +} + SDValue DAGCombiner::visitFADD(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -11528,7 +11676,7 @@ } // enable-unsafe-fp-math // FADD -> FMA combines: - if (SDValue Fused = visitFADDForFMACombine(N)) { + if (SDValue Fused = visitFADDForFMACombine<EmptyMatchContext>(N)) { AddToWorklist(Fused.getNode()); return Fused; } Index: lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -909,7 +909,7 @@ } // Handle promotion for the ADDE/SUBE/ADDCARRY/SUBCARRY nodes. Notice that -// the third operand of ADDE/SUBE nodes is carry flag, which differs from +// the third operand of ADDE/SUBE nodes is carry flag, which differs from // the ADDCARRY/SUBCARRY nodes in that the third operand is carry Boolean. SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBCARRY(SDNode *N, unsigned ResNo) { if (ResNo == 1) @@ -1062,6 +1062,9 @@ return false; } + if (N->isVP()) { + Res = PromoteIntOp_VP(N, OpNo); + } else { switch (N->getOpcode()) { default: #ifndef NDEBUG @@ -1136,6 +1139,7 @@ case ISD::VECREDUCE_UMAX: case ISD::VECREDUCE_UMIN: Res = PromoteIntOp_VECREDUCE(N); break; } + } // If the result is null, the sub-method took care of registering results etc.
if (!Res.getNode()) return false; @@ -1409,6 +1413,25 @@ TruncateStore, N->isCompressingStore()); } +SDValue DAGTypeLegalizer::PromoteIntOp_VP(SDNode *N, unsigned OpNo) { + EVT DataVT; + switch (N->getOpcode()) { + default: + DataVT = N->getValueType(0); + break; + + case ISD::VP_STORE: + case ISD::VP_SCATTER: + llvm_unreachable("TODO implement VP memory nodes"); + } + + // TODO assert that \p OpNo is the mask + SDValue Mask = PromoteTargetBoolean(N->getOperand(OpNo), DataVT); + SmallVector NewOps(N->op_begin(), N->op_end()); + NewOps[OpNo] = Mask; + return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0); +} + SDValue DAGTypeLegalizer::PromoteIntOp_MLOAD(MaskedLoadSDNode *N, unsigned OpNo) { assert(OpNo == 2 && "Only know how to promote the mask!"); Index: lib/CodeGen/SelectionDAG/LegalizeTypes.h =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -349,6 +349,7 @@ SDValue PromoteIntRes_VECREDUCE(SDNode *N); SDValue PromoteIntRes_ABS(SDNode *N); + // Integer Operand Promotion. 
bool PromoteIntegerOperand(SDNode *N, unsigned OpNo); SDValue PromoteIntOp_ANY_EXTEND(SDNode *N); @@ -383,6 +384,7 @@ SDValue PromoteIntOp_MULFIX(SDNode *N); SDValue PromoteIntOp_FPOWI(SDNode *N); SDValue PromoteIntOp_VECREDUCE(SDNode *N); + SDValue PromoteIntOp_VP(SDNode *N, unsigned OpNo); void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code); Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -433,6 +433,181 @@ return Result; } +//===----------------------------------------------------------------------===// +// SDNode VP Support +//===----------------------------------------------------------------------===// + +int +ISD::GetMaskPosVP(unsigned OpCode) { + switch (OpCode) { + default: return -1; + + case ISD::VP_FNEG: + return 1; + + case ISD::VP_ADD: + case ISD::VP_SUB: + case ISD::VP_MUL: + case ISD::VP_SDIV: + case ISD::VP_SREM: + case ISD::VP_UDIV: + case ISD::VP_UREM: + + case ISD::VP_AND: + case ISD::VP_OR: + case ISD::VP_XOR: + case ISD::VP_SHL: + case ISD::VP_SRA: + case ISD::VP_SRL: + case ISD::VP_FDIV: + case ISD::VP_FREM: + + case ISD::VP_FADD: + case ISD::VP_FMUL: + return 2; + + case ISD::VP_FMA: + case ISD::VP_SELECT: + return 3; + + case VP_REDUCE_FADD: + case VP_REDUCE_FMUL: + case VP_REDUCE_ADD: + case VP_REDUCE_MUL: + case VP_REDUCE_AND: + case VP_REDUCE_OR: + case VP_REDUCE_XOR: + case VP_REDUCE_SMAX: + case VP_REDUCE_SMIN: + case VP_REDUCE_UMAX: + case VP_REDUCE_UMIN: + case VP_REDUCE_FMAX: + case VP_REDUCE_FMIN: + return 1; + + /// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants. 
+ // (implicit) case ISD::VP_COMPOSE: return -1 + } +} + +int +ISD::GetVectorLengthPosVP(unsigned OpCode) { + switch (OpCode) { + default: return -1; + + case VP_SELECT: + return 0; + + case VP_FNEG: + return 2; + + case VP_ADD: + case VP_SUB: + case VP_MUL: + case VP_SDIV: + case VP_SREM: + case VP_UDIV: + case VP_UREM: + + case VP_AND: + case VP_OR: + case VP_XOR: + case VP_SHL: + case VP_SRA: + case VP_SRL: + + case VP_FADD: + case VP_FMUL: + case VP_FDIV: + case VP_FREM: + return 3; + + case VP_FMA: + return 4; + + case VP_COMPOSE: + return 3; + + case VP_REDUCE_FADD: + case VP_REDUCE_FMUL: + case VP_REDUCE_ADD: + case VP_REDUCE_MUL: + case VP_REDUCE_AND: + case VP_REDUCE_OR: + case VP_REDUCE_XOR: + case VP_REDUCE_SMAX: + case VP_REDUCE_SMIN: + case VP_REDUCE_UMAX: + case VP_REDUCE_UMIN: + case VP_REDUCE_FMAX: + case VP_REDUCE_FMIN: + return 2; + } +} + +unsigned +ISD::GetFunctionOpCodeForVP(unsigned OpCode) { + switch (OpCode) { + default: return OpCode; + + case VP_SELECT: return ISD::VSELECT; + case VP_FNEG: return ISD::FNEG; + case VP_ADD: return ISD::ADD; + case VP_SUB: return ISD::SUB; + case VP_MUL: return ISD::MUL; + case VP_SDIV: return ISD::SDIV; + case VP_SREM: return ISD::SREM; + case VP_UDIV: return ISD::UDIV; + case VP_UREM: return ISD::UREM; + + case VP_AND: return ISD::AND; + case VP_OR: return ISD::OR; + case VP_XOR: return ISD::XOR; + case VP_SHL: return ISD::SHL; + case VP_SRA: return ISD::SRA; + case VP_SRL: return ISD::SRL; + case VP_FDIV: return ISD::FDIV; + case VP_FREM: return ISD::FREM; + + case VP_FADD: return ISD::FADD; + case VP_FMUL: return ISD::FMUL; + + case VP_FMA: return ISD::FMA; + } +} + +unsigned +ISD::GetVPForFunctionOpCode(unsigned OpCode) { + switch (OpCode) { + default: llvm_unreachable("can not translate this Opcode to VP"); + + case VSELECT:return ISD::VP_SELECT; + case FNEG: return ISD::VP_FNEG; + case ADD: return ISD::VP_ADD; + case SUB: return ISD::VP_SUB; + case MUL: return ISD::VP_MUL; + case SDIV: return 
ISD::VP_SDIV; + case SREM: return ISD::VP_SREM; + case UDIV: return ISD::VP_UDIV; + case UREM: return ISD::VP_UREM; + + case AND: return ISD::VP_AND; + case OR: return ISD::VP_OR; + case XOR: return ISD::VP_XOR; + case SHL: return ISD::VP_SHL; + case SRA: return ISD::VP_SRA; + case SRL: return ISD::VP_SRL; + case FDIV: return ISD::VP_FDIV; + case FREM: return ISD::VP_FREM; + + case FADD: return ISD::VP_FADD; + case FMUL: return ISD::VP_FMUL; + + case FMA: return ISD::VP_FMA; + } +} + + //===----------------------------------------------------------------------===// // SDNode Profile Support //===----------------------------------------------------------------------===// @@ -563,6 +738,34 @@ ID.AddInteger(ST->getPointerInfo().getAddrSpace()); break; } + case ISD::VP_LOAD: { + const VPLoadSDNode *ELD = cast<VPLoadSDNode>(N); + ID.AddInteger(ELD->getMemoryVT().getRawBits()); + ID.AddInteger(ELD->getRawSubclassData()); + ID.AddInteger(ELD->getPointerInfo().getAddrSpace()); + break; + } + case ISD::VP_STORE: { + const VPStoreSDNode *EST = cast<VPStoreSDNode>(N); + ID.AddInteger(EST->getMemoryVT().getRawBits()); + ID.AddInteger(EST->getRawSubclassData()); + ID.AddInteger(EST->getPointerInfo().getAddrSpace()); + break; + } + case ISD::VP_GATHER: { + const VPGatherSDNode *EG = cast<VPGatherSDNode>(N); + ID.AddInteger(EG->getMemoryVT().getRawBits()); + ID.AddInteger(EG->getRawSubclassData()); + ID.AddInteger(EG->getPointerInfo().getAddrSpace()); + break; + } + case ISD::VP_SCATTER: { + const VPScatterSDNode *ES = cast<VPScatterSDNode>(N); + ID.AddInteger(ES->getMemoryVT().getRawBits()); + ID.AddInteger(ES->getRawSubclassData()); + ID.AddInteger(ES->getPointerInfo().getAddrSpace()); + break; + } case ISD::MLOAD: { const MaskedLoadSDNode *MLD = cast<MaskedLoadSDNode>(N); ID.AddInteger(MLD->getMemoryVT().getRawBits()); @@ -6843,6 +7046,34 @@ return V; } +SDValue SelectionDAG::getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain, + SDValue Ptr, SDValue Mask, SDValue VLen, + EVT MemVT, MachineMemOperand *MMO, + ISD::LoadExtType ExtTy) { + SDVTList VTs =
getVTList(VT, MVT::Other); + SDValue Ops[] = { Chain, Ptr, Mask, VLen }; + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_LOAD, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, ExtTy, MemVT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), VTs, + ExtTy, MemVT, MMO); + createOperands(N, Ops); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr, SDValue Mask, SDValue PassThru, EVT MemVT, MachineMemOperand *MMO, @@ -6871,6 +7102,111 @@ return V; } +SDValue SelectionDAG::getStoreVP(SDValue Chain, const SDLoc &dl, + SDValue Val, SDValue Ptr, SDValue Mask, + SDValue VLen, EVT MemVT, MachineMemOperand *MMO, + bool IsTruncating) { + assert(Chain.getValueType() == MVT::Other && + "Invalid chain type"); + EVT VT = Val.getValueType(); + SDVTList VTs = getVTList(MVT::Other); + SDValue Ops[] = { Chain, Val, Ptr, Mask, VLen }; + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_STORE, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, IsTruncating, MemVT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), VTs, + IsTruncating, MemVT, MMO); + createOperands(N, Ops); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +SDValue SelectionDAG::getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef 
Ops, + MachineMemOperand *MMO) { + assert(Ops.size() == 6 && "Incompatible number of operands"); + + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_GATHER, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, VT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), + VTs, VT, MMO); + createOperands(N, Ops); + + assert(N->getMask().getValueType().getVectorNumElements() == + N->getValueType(0).getVectorNumElements() && + "Vector width mismatch between mask and data"); + assert(N->getIndex().getValueType().getVectorNumElements() >= + N->getValueType(0).getVectorNumElements() && + "Vector width mismatch between index and data"); + assert(isa(N->getScale()) && + cast(N->getScale())->getAPIntValue().isPowerOf2() && + "Scale should be a constant power of 2"); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +SDValue SelectionDAG::getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, + MachineMemOperand *MMO) { + assert(Ops.size() == 7 && "Incompatible number of operands"); + + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_SCATTER, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, VT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), + VTs, VT, MMO); + createOperands(N, Ops); + + assert(N->getMask().getValueType().getVectorNumElements() == + N->getValue().getValueType().getVectorNumElements() && + "Vector width mismatch between mask and data"); 
+ assert(N->getIndex().getValueType().getVectorNumElements() >= + N->getValue().getValueType().getVectorNumElements() && + "Vector width mismatch between index and data"); + assert(isa<ConstantSDNode>(N->getScale()) && + cast<ConstantSDNode>(N->getScale())->getAPIntValue().isPowerOf2() && + "Scale should be a constant power of 2"); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, SDValue Mask, EVT MemVT, MachineMemOperand *MMO, Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h +++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h @@ -950,6 +950,12 @@ const char *visitIntrinsicCall(const CallInst &I, unsigned Intrinsic); void visitTargetIntrinsic(const CallInst &I, unsigned Intrinsic); void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI); + void visitExplicitVectorLengthIntrinsic(const VPIntrinsic &VPI); + void visitCmpVP(const VPIntrinsic &I); + void visitLoadVP(const CallInst &I); + void visitStoreVP(const CallInst &I); + void visitGatherVP(const CallInst &I); + void visitScatterVP(const CallInst &I); void visitVAStart(const CallInst &I); void visitVAArg(const VAArgInst &I); Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -4265,6 +4265,46 @@ setValue(&I, StoreNode); } +void SelectionDAGBuilder::visitStoreVP(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + auto getVPStoreOps = [&](Value* &Ptr, Value* &Mask, Value* &Src0, + Value * &VLen, unsigned & Alignment) { + // @llvm.evl.store.*(Src0, Ptr, Mask, VLen) + Src0 = 
I.getArgOperand(0); + Ptr = I.getArgOperand(1); + Alignment = I.getParamAlignment(1);
+ Mask = I.getArgOperand(2); + VLen = I.getArgOperand(3); + }; + + Value *PtrOperand, *MaskOperand, *Src0Operand, *VLenOperand; + unsigned Alignment = 0; + getVPStoreOps(PtrOperand, MaskOperand, Src0Operand, VLenOperand, Alignment); + + SDValue Ptr = getValue(PtrOperand); + SDValue Src0 = getValue(Src0Operand); + SDValue Mask = getValue(MaskOperand); + SDValue VLen = getValue(VLenOperand); + + EVT VT = Src0.getValueType(); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + + MachineMemOperand *MMO = + DAG.getMachineFunction(). + getMachineMemOperand(MachinePointerInfo(PtrOperand), + MachineMemOperand::MOStore, VT.getStoreSize(), + Alignment, AAInfo); + SDValue StoreNode = DAG.getStoreVP(getRoot(), sdl, Src0, Ptr, Mask, VLen, VT, + MMO, false /* Truncating */); + DAG.setRoot(StoreNode); + setValue(&I, StoreNode); +} + // Get a uniform base for the Gather/Scatter intrinsic. // The first argument of the Gather/Scatter intrinsic is a vector of pointers. // We try to represent it as a base pointer + vector of indices. 
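`visitStoreVP` lowers a VP store whose operands arrive in the order (Src0, Ptr, Mask, VLen). The per-lane behaviour it is meant to implement can be modeled in scalar C++ as follows. This is a minimal sketch, not part of the patch: `isLaneEnabled` and `referenceStoreVP` are hypothetical helpers, `float` elements are an arbitrary choice, and the lane rule comes from the LangRef text added by this patch (a lane is enabled if its mask bit is set and, when the EVL is effective, i.e. when the MSB of the i32 parameter is zero, its index is below the EVL).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: lane rule taken from the VP LangRef text in this patch.
// The explicit vector length only takes effect when the MSB of the i32 is zero.
bool isLaneEnabled(const std::vector<bool> &Mask, std::uint32_t EVL,
                   std::size_t Lane) {
  const bool EVLIsEffective = (EVL & 0x80000000u) == 0;
  return Mask[Lane] && (!EVLIsEffective || Lane < EVL);
}

// Hypothetical scalar reference for a VP store with operands
// (Src0, Ptr, Mask, VLen): enabled lanes are written, disabled lanes leave
// memory untouched. Assumes Mask.size() >= Src0.size() and that Ptr points
// to at least Src0.size() writable floats.
void referenceStoreVP(const std::vector<float> &Src0, float *Ptr,
                      const std::vector<bool> &Mask, std::uint32_t EVL) {
  for (std::size_t Lane = 0; Lane < Src0.size(); ++Lane)
    if (isLaneEnabled(Mask, EVL, Lane))
      Ptr[Lane] = Src0[Lane];
}
```

Leaving disabled lanes untouched mirrors how `llvm.masked.store` treats masked-off lanes; for the VP variant this is an assumption rather than something the patch spells out.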
@@ -4483,6 +4523,158 @@ setValue(&I, Gather); } +void SelectionDAGBuilder::visitGatherVP(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + // @llvm.evl.gather.*(Ptrs, Mask, VLen) + const Value *Ptr = I.getArgOperand(0); + SDValue Mask = getValue(I.getArgOperand(1)); + SDValue VLen = getValue(I.getArgOperand(2)); + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); + unsigned Alignment = I.getParamAlignment(0); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); + + SDValue Root = DAG.getRoot(); + SDValue Base; + SDValue Index; + SDValue Scale; + const Value *BasePtr = Ptr; + bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this); + bool ConstantMemory = false; + if (UniformBase && AA && + AA->pointsToConstantMemory( + MemoryLocation(BasePtr, + LocationSize::precise( + DAG.getDataLayout().getTypeStoreSize(I.getType())), + AAInfo))) { + // Do not serialize (non-volatile) loads of constant memory with anything. + Root = DAG.getEntryNode(); + ConstantMemory = true; + } + + MachineMemOperand *MMO = + DAG.getMachineFunction(). + getMachineMemOperand(MachinePointerInfo(UniformBase ? 
BasePtr : nullptr), + MachineMemOperand::MOLoad, VT.getStoreSize(), + Alignment, AAInfo, Ranges); + + if (!UniformBase) { + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Index = getValue(Ptr); + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); + } + SDValue Ops[] = { Root, Base, Index, Scale, Mask, VLen }; + SDValue Gather = DAG.getGatherVP(DAG.getVTList(VT, MVT::Other), VT, sdl, Ops, MMO); + + SDValue OutChain = Gather.getValue(1); + if (!ConstantMemory) + PendingLoads.push_back(OutChain); + setValue(&I, Gather); +} + +void SelectionDAGBuilder::visitScatterVP(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + // @llvm.vp.scatter.*(Src0, Ptrs, Mask, VLen) + const Value *Ptr = I.getArgOperand(1); + SDValue Src0 = getValue(I.getArgOperand(0)); + SDValue Mask = getValue(I.getArgOperand(2)); + SDValue VLen = getValue(I.getArgOperand(3)); + EVT VT = Src0.getValueType(); + unsigned Alignment = I.getParamAlignment(1); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + + SDValue Base; + SDValue Index; + SDValue Scale; + const Value *BasePtr = Ptr; + bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this); + + const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr; + MachineMemOperand *MMO = DAG.getMachineFunction().
+ getMachineMemOperand(MachinePointerInfo(MemOpBasePtr), + MachineMemOperand::MOStore, VT.getStoreSize(), + Alignment, AAInfo); + if (!UniformBase) { + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Index = getValue(Ptr); + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); + } + SDValue Ops[] = { getRoot(), Src0, Base, Index, Scale, Mask, VLen }; + SDValue Scatter = DAG.getScatterVP(DAG.getVTList(MVT::Other), VT, sdl, + Ops, MMO); + DAG.setRoot(Scatter); + setValue(&I, Scatter); +} + +void SelectionDAGBuilder::visitLoadVP(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + auto getMaskedLoadOps = [&](Value* &Ptr, Value* &Mask, Value* &VLen, + unsigned& Alignment) { + // @llvm.vp.load.*(Ptr, Mask, VLen) + Ptr = I.getArgOperand(0); + Alignment = I.getParamAlignment(0); + Mask = I.getArgOperand(1); + VLen = I.getArgOperand(2); + }; + + Value *PtrOperand, *MaskOperand, *VLenOperand; + unsigned Alignment; + getMaskedLoadOps(PtrOperand, MaskOperand, VLenOperand, Alignment); + + SDValue Ptr = getValue(PtrOperand); + SDValue VLen = getValue(VLenOperand); + SDValue Mask = getValue(MaskOperand); + + // infer the return type + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + SmallVector ValValueVTs; + ComputeValueVTs(TLI, DAG.getDataLayout(), I.getType(), ValValueVTs); + assert((ValValueVTs.size() == 1) && "splitting not implemented"); + EVT VT = ValValueVTs[0]; + + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); + + // Do not serialize masked loads of constant memory with anything. + bool AddToChain = + !AA || !AA->pointsToConstantMemory(MemoryLocation( + PtrOperand, + LocationSize::precise( + DAG.getDataLayout().getTypeStoreSize(I.getType())), + AAInfo)); + SDValue InChain = AddToChain ? DAG.getRoot() : DAG.getEntryNode(); + + MachineMemOperand *MMO = + DAG.getMachineFunction().
+ getMachineMemOperand(MachinePointerInfo(PtrOperand), + MachineMemOperand::MOLoad, VT.getStoreSize(), + Alignment, AAInfo, Ranges); + + SDValue Load = DAG.getLoadVP(VT, sdl, InChain, Ptr, Mask, VLen, VT, MMO, + ISD::NON_EXTLOAD); + if (AddToChain) + PendingLoads.push_back(Load.getValue(1)); + setValue(&I, Load); +} + void SelectionDAGBuilder::visitAtomicCmpXchg(const AtomicCmpXchgInst &I) { SDLoc dl = getCurSDLoc(); AtomicOrdering SuccessOrdering = I.getSuccessOrdering(); @@ -6028,6 +6220,63 @@ case Intrinsic::experimental_constrained_trunc: visitConstrainedFPIntrinsic(cast(I)); return nullptr; + + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + + case Intrinsic::vp_select: + case Intrinsic::vp_compose: + case Intrinsic::vp_compress: + case Intrinsic::vp_expand: + case Intrinsic::vp_vshift: + + case Intrinsic::vp_load: + case Intrinsic::vp_store: + case Intrinsic::vp_gather: + case Intrinsic::vp_scatter: + + case Intrinsic::vp_fneg: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + + case Intrinsic::vp_fma: + + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + + case Intrinsic::vp_cmp: + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmax: + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmul: + + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_umax: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_smin: + visitExplicitVectorLengthIntrinsic(cast(I)); + return nullptr; + case Intrinsic::fmuladd: { EVT VT = 
TLI.getValueType(DAG.getDataLayout(), I.getType()); if (TM.Options.AllowFPOpFusion != FPOpFusion::Strict && @@ -6847,6 +7096,138 @@ setValue(&FPI, FPResult); } +void SelectionDAGBuilder::visitCmpVP(const VPIntrinsic &I) { + ISD::CondCode Condition; + CmpInst::Predicate predicate = I.getCmpPredicate(); + bool IsFP = I.getOperand(0)->getType()->isFPOrFPVectorTy(); + if (IsFP) { + Condition = getFCmpCondCode(predicate); + auto *FPMO = dyn_cast(&I); + if ((FPMO && FPMO->hasNoNaNs()) || TM.Options.NoNaNsFPMath) + Condition = getFCmpCodeWithoutNaN(Condition); + + } else { + Condition = getICmpCondCode(predicate); + } + + SDValue Op1 = getValue(I.getOperand(0)); + SDValue Op2 = getValue(I.getOperand(1)); + + EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(), + I.getType()); + setValue(&I, DAG.getSetCC(getCurSDLoc(), DestVT, Op1, Op2, Condition)); +} + +void SelectionDAGBuilder::visitExplicitVectorLengthIntrinsic( + const VPIntrinsic & VPInst) { + SDLoc sdl = getCurSDLoc(); + unsigned Opcode; + switch (VPInst.getIntrinsicID()) { + default: + llvm_unreachable("Unforeseen intrinsic"); // Can't reach here. 
+ + case Intrinsic::vp_load: visitLoadVP(VPInst); return; + case Intrinsic::vp_store: visitStoreVP(VPInst); return; + case Intrinsic::vp_gather: visitGatherVP(VPInst); return; + case Intrinsic::vp_scatter: visitScatterVP(VPInst); return; + + case Intrinsic::vp_cmp: visitCmpVP(VPInst); return; + + case Intrinsic::vp_add: Opcode = ISD::VP_ADD; break; + case Intrinsic::vp_sub: Opcode = ISD::VP_SUB; break; + case Intrinsic::vp_mul: Opcode = ISD::VP_MUL; break; + case Intrinsic::vp_udiv: Opcode = ISD::VP_UDIV; break; + case Intrinsic::vp_sdiv: Opcode = ISD::VP_SDIV; break; + case Intrinsic::vp_urem: Opcode = ISD::VP_UREM; break; + case Intrinsic::vp_srem: Opcode = ISD::VP_SREM; break; + + case Intrinsic::vp_and: Opcode = ISD::VP_AND; break; + case Intrinsic::vp_or: Opcode = ISD::VP_OR; break; + case Intrinsic::vp_xor: Opcode = ISD::VP_XOR; break; + case Intrinsic::vp_ashr: Opcode = ISD::VP_SRA; break; + case Intrinsic::vp_lshr: Opcode = ISD::VP_SRL; break; + case Intrinsic::vp_shl: Opcode = ISD::VP_SHL; break; + + case Intrinsic::vp_fneg: Opcode = ISD::VP_FNEG; break; + case Intrinsic::vp_fadd: Opcode = ISD::VP_FADD; break; + case Intrinsic::vp_fsub: Opcode = ISD::VP_FSUB; break; + case Intrinsic::vp_fmul: Opcode = ISD::VP_FMUL; break; + case Intrinsic::vp_fdiv: Opcode = ISD::VP_FDIV; break; + case Intrinsic::vp_frem: Opcode = ISD::VP_FREM; break; + + case Intrinsic::vp_fma: Opcode = ISD::VP_FMA; break; + + case Intrinsic::vp_select: Opcode = ISD::VP_SELECT; break; + case Intrinsic::vp_compose: Opcode = ISD::VP_COMPOSE; break; + case Intrinsic::vp_compress: Opcode = ISD::VP_COMPRESS; break; + case Intrinsic::vp_expand: Opcode = ISD::VP_EXPAND; break; + case Intrinsic::vp_vshift: Opcode = ISD::VP_VSHIFT; break; + + case Intrinsic::vp_reduce_and: Opcode = ISD::VP_REDUCE_AND; break; + case Intrinsic::vp_reduce_or: Opcode = ISD::VP_REDUCE_OR; break; + case Intrinsic::vp_reduce_xor: Opcode = ISD::VP_REDUCE_XOR; break; + case Intrinsic::vp_reduce_add: Opcode = 
ISD::VP_REDUCE_ADD; break; + case Intrinsic::vp_reduce_mul: Opcode = ISD::VP_REDUCE_MUL; break; + case Intrinsic::vp_reduce_fadd: Opcode = ISD::VP_REDUCE_FADD; break; + case Intrinsic::vp_reduce_fmul: Opcode = ISD::VP_REDUCE_FMUL; break; + case Intrinsic::vp_reduce_smax: Opcode = ISD::VP_REDUCE_SMAX; break; + case Intrinsic::vp_reduce_smin: Opcode = ISD::VP_REDUCE_SMIN; break; + case Intrinsic::vp_reduce_umax: Opcode = ISD::VP_REDUCE_UMAX; break; + case Intrinsic::vp_reduce_umin: Opcode = ISD::VP_REDUCE_UMIN; break; + } + + // TODO memory evl: SDValue Chain = getRoot(); + + SmallVector ValueVTs; + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + ComputeValueVTs(TLI, DAG.getDataLayout(), VPInst.getType(), ValueVTs); + SDVTList VTs = DAG.getVTList(ValueVTs); + + // ValueVTs.push_back(MVT::Other); // Out chain + + + SDValue Result; + + switch (VPInst.getNumArgOperands()) { + default: + llvm_unreachable("unexpected number of arguments to evl intrinsic"); + case 3: + Result = DAG.getNode(Opcode, sdl, VTs, + { getValue(VPInst.getArgOperand(0)), + getValue(VPInst.getArgOperand(1)), + getValue(VPInst.getArgOperand(2)) }); + break; + + case 4: + Result = DAG.getNode(Opcode, sdl, VTs, + { getValue(VPInst.getArgOperand(0)), + getValue(VPInst.getArgOperand(1)), + getValue(VPInst.getArgOperand(2)), + getValue(VPInst.getArgOperand(3)) }); + break; + + case 5: + Result = DAG.getNode(Opcode, sdl, VTs, + { getValue(VPInst.getArgOperand(0)), + getValue(VPInst.getArgOperand(1)), + getValue(VPInst.getArgOperand(2)), + getValue(VPInst.getArgOperand(3)), + getValue(VPInst.getArgOperand(4)) }); + break; + } + + if (Result.getNode()->getNumValues() == 2) { + // this evl node has a chain + SDValue OutChain = Result.getValue(1); + DAG.setRoot(OutChain); + SDValue VPResult = Result.getValue(0); + setValue(&VPInst, VPResult); + } else { + // this is a pure node + setValue(&VPInst, Result); + } +} + std::pair SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI, 
const BasicBlock *EHPadBB) { Index: lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -423,6 +423,65 @@ case ISD::VECREDUCE_UMIN: return "vecreduce_umin"; case ISD::VECREDUCE_FMAX: return "vecreduce_fmax"; case ISD::VECREDUCE_FMIN: return "vecreduce_fmin"; + + // Explicit Vector Length extension + // VP Memory + case ISD::VP_LOAD: return "vp_load"; + case ISD::VP_STORE: return "vp_store"; + case ISD::VP_GATHER: return "vp_gather"; + case ISD::VP_SCATTER: return "vp_scatter"; + + // VP Unary operators + case ISD::VP_FNEG: return "vp_fneg"; + + // VP Binary operators + case ISD::VP_ADD: return "vp_add"; + case ISD::VP_SUB: return "vp_sub"; + case ISD::VP_MUL: return "vp_mul"; + case ISD::VP_SDIV: return "vp_sdiv"; + case ISD::VP_UDIV: return "vp_udiv"; + case ISD::VP_SREM: return "vp_srem"; + case ISD::VP_UREM: return "vp_urem"; + case ISD::VP_AND: return "vp_and"; + case ISD::VP_OR: return "vp_or"; + case ISD::VP_XOR: return "vp_xor"; + case ISD::VP_SHL: return "vp_shl"; + case ISD::VP_SRA: return "vp_sra"; + case ISD::VP_SRL: return "vp_srl"; + case ISD::VP_FADD: return "vp_fadd"; + case ISD::VP_FSUB: return "vp_fsub"; + case ISD::VP_FMUL: return "vp_fmul"; + case ISD::VP_FDIV: return "vp_fdiv"; + case ISD::VP_FREM: return "vp_frem"; + + // VP comparison + case ISD::VP_SETCC: return "vp_setcc"; + + // VP ternary operators + case ISD::VP_FMA: return "vp_fma"; + + // VP shuffle + case ISD::VP_VSHIFT: return "vp_vshift"; + case ISD::VP_COMPRESS: return "vp_compress"; + case ISD::VP_EXPAND: return "vp_expand"; + + case ISD::VP_COMPOSE: return "vp_compose"; + case ISD::VP_SELECT: return "vp_select"; + + // VP reduction operators + case ISD::VP_REDUCE_FADD: return "vp_reduce_fadd"; + case ISD::VP_REDUCE_FMUL: return "vp_reduce_fmul"; + case ISD::VP_REDUCE_ADD: return "vp_reduce_add"; + case
ISD::VP_REDUCE_MUL: return "vp_reduce_mul"; + case ISD::VP_REDUCE_AND: return "vp_reduce_and"; + case ISD::VP_REDUCE_OR: return "vp_reduce_or"; + case ISD::VP_REDUCE_XOR: return "vp_reduce_xor"; + case ISD::VP_REDUCE_SMAX: return "vp_reduce_smax"; + case ISD::VP_REDUCE_SMIN: return "vp_reduce_smin"; + case ISD::VP_REDUCE_UMAX: return "vp_reduce_umax"; + case ISD::VP_REDUCE_UMIN: return "vp_reduce_umin"; + case ISD::VP_REDUCE_FMAX: return "vp_reduce_fmax"; + case ISD::VP_REDUCE_FMIN: return "vp_reduce_fmin"; } } Index: lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -783,6 +783,10 @@ CurDAG->Combine(BeforeLegalizeTypes, AA, OptLevel); } + if (getenv("SDEBUG")) { + CurDAG->dump(); + } + #ifndef NDEBUG if (TTI.hasBranchDivergence()) CurDAG->VerifyDAGDiverence(); Index: lib/IR/Attributes.cpp =================================================================== --- lib/IR/Attributes.cpp +++ lib/IR/Attributes.cpp @@ -256,6 +256,8 @@ return "byval"; if (hasAttribute(Attribute::Convergent)) return "convergent"; + if (hasAttribute(Attribute::VectorLength)) + return "vlen"; if (hasAttribute(Attribute::SwiftError)) return "swifterror"; if (hasAttribute(Attribute::SwiftSelf)) @@ -272,6 +274,10 @@ return "inreg"; if (hasAttribute(Attribute::JumpTable)) return "jumptable"; + if (hasAttribute(Attribute::Mask)) + return "mask"; + if (hasAttribute(Attribute::Passthru)) + return "passthru"; if (hasAttribute(Attribute::MinSize)) return "minsize"; if (hasAttribute(Attribute::Naked)) Index: lib/IR/CMakeLists.txt =================================================================== --- lib/IR/CMakeLists.txt +++ lib/IR/CMakeLists.txt @@ -46,18 +46,19 @@ PassManager.cpp PassRegistry.cpp PassTimingInfo.cpp + PredicatedInst.cpp + ProfileSummary.cpp RemarkStreamer.cpp SafepointIRVerifier.cpp - ProfileSummary.cpp Statepoint.cpp 
Type.cpp TypeFinder.cpp Use.cpp User.cpp + VPBuilder.cpp Value.cpp ValueSymbolTable.cpp Verifier.cpp - ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/IR Index: lib/IR/IntrinsicInst.cpp =================================================================== --- lib/IR/IntrinsicInst.cpp +++ lib/IR/IntrinsicInst.cpp @@ -102,41 +102,353 @@ return ConstantInt::get(Type::getInt64Ty(Context), 1); } -ConstrainedFPIntrinsic::RoundingMode +static +RoundingMode +DecodeRoundingMode(StringRef RoundingArg) { + // For dynamic rounding mode, we use round to nearest but we will set the + // 'exact' SDNodeFlag so that the value will not be rounded. + return StringSwitch(RoundingArg) + .Case("round.dynamic", RoundingMode::rmDynamic) + .Case("round.tonearest", RoundingMode::rmToNearest) + .Case("round.downward", RoundingMode::rmDownward) + .Case("round.upward", RoundingMode::rmUpward) + .Case("round.towardzero", RoundingMode::rmTowardZero) + .Default(RoundingMode::rmInvalid); +} + +static +ExceptionBehavior +DecodeExceptionBehavior(StringRef ExceptionArg) { + return StringSwitch(ExceptionArg) + .Case("fpexcept.ignore", ExceptionBehavior::ebIgnore) + .Case("fpexcept.maytrap", ExceptionBehavior::ebMayTrap) + .Case("fpexcept.strict", ExceptionBehavior::ebStrict) + .Default(ExceptionBehavior::ebInvalid); +} + +RoundingMode ConstrainedFPIntrinsic::getRoundingMode() const { unsigned NumOperands = getNumArgOperands(); + assert(NumOperands >= 2 && "underflow"); Metadata *MD = dyn_cast(getArgOperand(NumOperands - 2))->getMetadata(); if (!MD || !isa(MD)) - return rmInvalid; + return RoundingMode::rmInvalid; StringRef RoundingArg = cast(MD)->getString(); - - // For dynamic rounding mode, we use round to nearest but we will set the - // 'exact' SDNodeFlag so that the value will not be rounded. 
- return StringSwitch(RoundingArg) - .Case("round.dynamic", rmDynamic) - .Case("round.tonearest", rmToNearest) - .Case("round.downward", rmDownward) - .Case("round.upward", rmUpward) - .Case("round.towardzero", rmTowardZero) - .Default(rmInvalid); + return DecodeRoundingMode(RoundingArg); } -ConstrainedFPIntrinsic::ExceptionBehavior +ExceptionBehavior ConstrainedFPIntrinsic::getExceptionBehavior() const { unsigned NumOperands = getNumArgOperands(); + assert(NumOperands >= 1 && "underflow"); Metadata *MD = dyn_cast(getArgOperand(NumOperands - 1))->getMetadata(); if (!MD || !isa(MD)) - return ebInvalid; + return ExceptionBehavior::ebInvalid; StringRef ExceptionArg = cast(MD)->getString(); - return StringSwitch(ExceptionArg) - .Case("fpexcept.ignore", ebIgnore) - .Case("fpexcept.maytrap", ebMayTrap) - .Case("fpexcept.strict", ebStrict) - .Default(ebInvalid); + return DecodeExceptionBehavior(ExceptionArg); +} + +CmpInst::Predicate +VPIntrinsic::getCmpPredicate() const { + return static_cast(cast(getArgOperand(4))->getZExtValue()); +} + +RoundingMode +VPIntrinsic::getRoundingMode() const { + unsigned NumOperands = getNumArgOperands(); + assert(NumOperands >= 4 && "underflow"); + Metadata *MD = + dyn_cast(getArgOperand(NumOperands - 4))->getMetadata(); + if (!MD || !isa(MD)) + return RoundingMode::rmInvalid; + StringRef RoundingArg = cast(MD)->getString(); + return DecodeRoundingMode(RoundingArg); +} + +ExceptionBehavior +VPIntrinsic::getExceptionBehavior() const { + unsigned NumOperands = getNumArgOperands(); + assert(NumOperands >= 3 && "underflow"); + Metadata *MD = + dyn_cast(getArgOperand(NumOperands - 3))->getMetadata(); + if (!MD || !isa(MD)) + return ExceptionBehavior::ebInvalid; + StringRef ExceptionArg = cast(MD)->getString(); + return DecodeExceptionBehavior(ExceptionArg); } +bool VPIntrinsic::isUnaryOp() const { + switch (getIntrinsicID()) { + default: + return false; + case Intrinsic::vp_fneg: + case Intrinsic::vp_constrained_sin: + case 
Intrinsic::vp_constrained_cos: + case Intrinsic::vp_constrained_exp: + case Intrinsic::vp_constrained_exp2: + case Intrinsic::vp_constrained_log: + case Intrinsic::vp_constrained_log10: + case Intrinsic::vp_constrained_log2: + case Intrinsic::vp_constrained_sqrt: + case Intrinsic::vp_constrained_ceil: + case Intrinsic::vp_constrained_floor: + case Intrinsic::vp_constrained_round: + case Intrinsic::vp_constrained_trunc: + case Intrinsic::vp_constrained_rint: + case Intrinsic::vp_constrained_nearbyint: + return true; + } +} + +Value* +VPIntrinsic::getMask() const { + int offset = 0; + if (isConstrainedOp()) offset += 2; // skip rounding, exception args + + if (isBinaryOp()) { return getArgOperand(offset + 2); } + else if (isTernaryOp()) { return getArgOperand(offset + 3); } + else if (isUnaryOp()) { return getArgOperand(offset + 1); } + else return nullptr; +} + +Value* +VPIntrinsic::getVectorLength() const { + int offset = 0; + if (isConstrainedOp()) offset += 2; // skip rounding, exception args + + if (isBinaryOp()) { return getArgOperand(offset + 3); } + else if (isTernaryOp()) { return getArgOperand(offset + 4); } + else if (isUnaryOp()) { return getArgOperand(offset + 2); } + else return nullptr; +} + +bool VPIntrinsic::isReductionOp() const { + switch (getIntrinsicID()) { + default: + return false; + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + case Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + + return true; + } +} + +bool VPIntrinsic::isConstrainedOp() const { + switch (getIntrinsicID()) { + default: + return false; + + case Intrinsic::vp_constrained_fadd: + case Intrinsic::vp_constrained_fsub: + case 
Intrinsic::vp_constrained_fmul: + case Intrinsic::vp_constrained_fdiv: + case Intrinsic::vp_constrained_frem: + case Intrinsic::vp_constrained_fma: + case Intrinsic::vp_constrained_sqrt: + case Intrinsic::vp_constrained_pow: + case Intrinsic::vp_constrained_powi: + case Intrinsic::vp_constrained_sin: + case Intrinsic::vp_constrained_cos: + case Intrinsic::vp_constrained_exp: + case Intrinsic::vp_constrained_exp2: + case Intrinsic::vp_constrained_log: + case Intrinsic::vp_constrained_log10: + case Intrinsic::vp_constrained_log2: + case Intrinsic::vp_constrained_rint: + case Intrinsic::vp_constrained_nearbyint: + case Intrinsic::vp_constrained_maxnum: + case Intrinsic::vp_constrained_minnum: + case Intrinsic::vp_constrained_ceil: + case Intrinsic::vp_constrained_floor: + case Intrinsic::vp_constrained_round: + case Intrinsic::vp_constrained_trunc: + return true; + } +} + +bool VPIntrinsic::isBinaryOp() const { + switch (getIntrinsicID()) { + default: + return false; + + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_umax: + case Intrinsic::vp_reduce_umin: + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + case Intrinsic::vp_reduce_fmax: + case Intrinsic::vp_reduce_fmin: + + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + + case Intrinsic::vp_constrained_fadd: + case Intrinsic::vp_constrained_fsub: + case Intrinsic::vp_constrained_fmul: + case 
Intrinsic::vp_constrained_fdiv: + case Intrinsic::vp_constrained_frem: + case Intrinsic::vp_constrained_pow: + case Intrinsic::vp_constrained_powi: + case Intrinsic::vp_constrained_maxnum: + case Intrinsic::vp_constrained_minnum: + + return true; + } +} + +bool VPIntrinsic::isTernaryOp() const { + switch (getIntrinsicID()) { + default: + return false; + case Intrinsic::vp_compose: + case Intrinsic::vp_select: + case Intrinsic::vp_fma: + case Intrinsic::vp_constrained_fma: + return true; + } +} + +VPIntrinsic::VPIntrinsicDesc +VPIntrinsic::GetVPDescForIntrinsic(unsigned IntrinsicID) { + switch (IntrinsicID) { + default: + return VPIntrinsicDesc{Intrinsic::not_intrinsic, TypeTokenVec(), -1, -1}; + + // llvm.experimental.constrained.* + case Intrinsic::experimental_constrained_cos: return VPIntrinsicDesc{ Intrinsic::vp_constrained_cos, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + case Intrinsic::experimental_constrained_sin: return VPIntrinsicDesc{ Intrinsic::vp_constrained_sin, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + case Intrinsic::experimental_constrained_exp: return VPIntrinsicDesc{ Intrinsic::vp_constrained_exp, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + case Intrinsic::experimental_constrained_exp2: return VPIntrinsicDesc{ Intrinsic::vp_constrained_exp2, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + case Intrinsic::experimental_constrained_log: return VPIntrinsicDesc{ Intrinsic::vp_constrained_log, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + case Intrinsic::experimental_constrained_log2: return VPIntrinsicDesc{ Intrinsic::vp_constrained_log2, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + case Intrinsic::experimental_constrained_log10: return VPIntrinsicDesc{ Intrinsic::vp_constrained_log10, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + case Intrinsic::experimental_constrained_sqrt: return VPIntrinsicDesc{ Intrinsic::vp_constrained_sqrt, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + case
Intrinsic::experimental_constrained_ceil: return VPIntrinsicDesc{ Intrinsic::vp_constrained_ceil, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + case Intrinsic::experimental_constrained_floor: return VPIntrinsicDesc{ Intrinsic::vp_constrained_floor, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + case Intrinsic::experimental_constrained_round: return VPIntrinsicDesc{ Intrinsic::vp_constrained_round, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + case Intrinsic::experimental_constrained_trunc: return VPIntrinsicDesc{ Intrinsic::vp_constrained_trunc, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + case Intrinsic::experimental_constrained_rint: return VPIntrinsicDesc{ Intrinsic::vp_constrained_rint, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + case Intrinsic::experimental_constrained_nearbyint: return VPIntrinsicDesc{ Intrinsic::vp_constrained_nearbyint, TypeTokenVec{VPTypeToken::Vector}, 3, 4}; break; + + case Intrinsic::experimental_constrained_fadd: return VPIntrinsicDesc{ Intrinsic::vp_constrained_fadd, TypeTokenVec{VPTypeToken::Vector}, 4, 5}; break; + case Intrinsic::experimental_constrained_fsub: return VPIntrinsicDesc{ Intrinsic::vp_constrained_fsub, TypeTokenVec{VPTypeToken::Vector}, 4, 5}; break; + case Intrinsic::experimental_constrained_fmul: return VPIntrinsicDesc{ Intrinsic::vp_constrained_fmul, TypeTokenVec{VPTypeToken::Vector}, 4, 5}; break; + case Intrinsic::experimental_constrained_fdiv: return VPIntrinsicDesc{ Intrinsic::vp_constrained_fdiv, TypeTokenVec{VPTypeToken::Vector}, 4, 5}; break; + case Intrinsic::experimental_constrained_frem: return VPIntrinsicDesc{ Intrinsic::vp_constrained_frem, TypeTokenVec{VPTypeToken::Vector}, 4, 5}; break; + case Intrinsic::experimental_constrained_pow: return VPIntrinsicDesc{ Intrinsic::vp_constrained_pow, TypeTokenVec{VPTypeToken::Vector}, 4, 5}; break; + case Intrinsic::experimental_constrained_powi: return VPIntrinsicDesc{ Intrinsic::vp_constrained_powi, TypeTokenVec{VPTypeToken::Vector}, 4, 
5}; break; + case Intrinsic::experimental_constrained_maxnum: return VPIntrinsicDesc{ Intrinsic::vp_constrained_maxnum, TypeTokenVec{VPTypeToken::Vector}, 4, 5}; break; + case Intrinsic::experimental_constrained_minnum: return VPIntrinsicDesc{ Intrinsic::vp_constrained_minnum, TypeTokenVec{VPTypeToken::Vector}, 4, 5}; break; + + case Intrinsic::experimental_constrained_fma: return VPIntrinsicDesc{ Intrinsic::vp_constrained_fma, TypeTokenVec{VPTypeToken::Vector}, 5, 6}; break; + } +} + +VPIntrinsic::VPIntrinsicDesc +VPIntrinsic::GetVPIntrinsicDesc(unsigned OC) { + switch (OC) { + default: + return VPIntrinsicDesc{Intrinsic::not_intrinsic, TypeTokenVec(), -1, -1}; + + // fp unary + case Instruction::FNeg: return VPIntrinsicDesc{ Intrinsic::vp_fneg, TypeTokenVec{VPTypeToken::Vector}, 1, 2}; break; + + // fp binary + case Instruction::FAdd: return VPIntrinsicDesc{ Intrinsic::vp_fadd, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + case Instruction::FSub: return VPIntrinsicDesc{ Intrinsic::vp_fsub, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + case Instruction::FMul: return VPIntrinsicDesc{ Intrinsic::vp_fmul, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + case Instruction::FDiv: return VPIntrinsicDesc{ Intrinsic::vp_fdiv, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + case Instruction::FRem: return VPIntrinsicDesc{ Intrinsic::vp_frem, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + + // sign-oblivious int + case Instruction::Add: return VPIntrinsicDesc{ Intrinsic::vp_add, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + case Instruction::Sub: return VPIntrinsicDesc{ Intrinsic::vp_sub, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + case Instruction::Mul: return VPIntrinsicDesc{ Intrinsic::vp_mul, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + + // signed/unsigned int + case Instruction::SDiv: return VPIntrinsicDesc{ Intrinsic::vp_sdiv, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + case Instruction::UDiv: return VPIntrinsicDesc{ 
Intrinsic::vp_udiv, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + case Instruction::SRem: return VPIntrinsicDesc{ Intrinsic::vp_srem, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + case Instruction::URem: return VPIntrinsicDesc{ Intrinsic::vp_urem, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + + // logical + case Instruction::Or: return VPIntrinsicDesc{ Intrinsic::vp_or, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + case Instruction::And: return VPIntrinsicDesc{ Intrinsic::vp_and, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + case Instruction::Xor: return VPIntrinsicDesc{ Intrinsic::vp_xor, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + + case Instruction::LShr: return VPIntrinsicDesc{ Intrinsic::vp_lshr, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + case Instruction::AShr: return VPIntrinsicDesc{ Intrinsic::vp_ashr, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + case Instruction::Shl: return VPIntrinsicDesc{ Intrinsic::vp_shl, TypeTokenVec{VPTypeToken::Vector}, 2, 3}; break; + + // comparison + case Instruction::ICmp: + case Instruction::FCmp: + return VPIntrinsicDesc{ Intrinsic::vp_cmp, TypeTokenVec{VPTypeToken::Mask, VPTypeToken::Vector}, 2, 3}; break; + } +} + +VPIntrinsic::ShortTypeVec +VPIntrinsic::EncodeTypeTokens(VPIntrinsic::TypeTokenVec TTVec, Type & VectorTy, Type & ScalarTy) { + ShortTypeVec STV; + + for (auto Token : TTVec) { + switch (Token) { + default: + llvm_unreachable("unsupported token"); // unsupported VPTypeToken + + case VPIntrinsic::VPTypeToken::Scalar: STV.push_back(&ScalarTy); break; + case VPIntrinsic::VPTypeToken::Vector: STV.push_back(&VectorTy); break; + case VPIntrinsic::VPTypeToken::Mask: + auto NumElems = VectorTy.getVectorNumElements(); + auto MaskTy = VectorType::get(Type::getInt1Ty(VectorTy.getContext()), NumElems); + STV.push_back(MaskTy); break; + } + } + + return STV; +} + + bool ConstrainedFPIntrinsic::isUnaryOp() const { switch (getIntrinsicID()) { default: Index: 
lib/IR/PredicatedInst.cpp =================================================================== --- /dev/null +++ lib/IR/PredicatedInst.cpp @@ -0,0 +1,61 @@ +#include +#include +#include +#include +#include + +namespace { + using namespace llvm; + using ShortValueVec = SmallVector; +} + +namespace llvm { + +void +PredicatedOperator::copyIRFlags(const Value * V, bool IncludeWrapFlags) { + auto * I = dyn_cast(this); + if (I) I->copyIRFlags(V, IncludeWrapFlags); +} + +Instruction* +PredicatedBinaryOperator::Create(Module * Mod, + Value *Mask, Value *VectorLen, + Instruction::BinaryOps Opc, + Value *V1, Value *V2, + const Twine &Name, + BasicBlock * InsertAtEnd, + Instruction * InsertBefore) { + assert(!(InsertAtEnd && InsertBefore)); + + auto evlDesc = VPIntrinsic::GetVPIntrinsicDesc(Opc); + + if ((!Mod || + (!Mask && !VectorLen)) || + evlDesc.ID == Intrinsic::not_intrinsic) { + if (InsertAtEnd) { + return BinaryOperator::Create(Opc, V1, V2, Name, InsertAtEnd); + } else { + return BinaryOperator::Create(Opc, V1, V2, Name, InsertBefore); + } + } + + assert(Mod && "Need a module to emit VP Intrinsics"); + + // Fetch the VP intrinsic + auto & VecTy = cast(*V1->getType()); + auto & ScalarTy = *VecTy.getVectorElementType(); + auto * Func = Intrinsic::getDeclaration(Mod, evlDesc.ID, VPIntrinsic::EncodeTypeTokens(evlDesc.typeTokens, VecTy, ScalarTy)); + + assert((evlDesc.MaskPos == 2) && (evlDesc.EVLPos == 3)); + + // Materialize the Call + ShortValueVec Args{V1, V2, Mask, VectorLen}; + + if (InsertAtEnd) { + return CallInst::Create(Func, {V1, V2, Mask, VectorLen}, Name, InsertAtEnd); + } else { + return CallInst::Create(Func, {V1, V2, Mask, VectorLen}, Name, InsertBefore); + } +} + +} Index: lib/IR/VPBuilder.cpp =================================================================== --- /dev/null +++ lib/IR/VPBuilder.cpp @@ -0,0 +1,195 @@ +#include +#include +#include +#include +#include + +namespace { + using namespace llvm; + using ShortTypeVec = VPIntrinsic::ShortTypeVec; + 
using ShortValueVec = SmallVector; +} + +namespace llvm { + +Module & +VPBuilder::getModule() const { + return *Builder.GetInsertBlock()->getParent()->getParent(); +} + +Value& +VPBuilder::GetMaskForType(VectorType & VecTy) { + if (Mask) return *Mask; + + auto * boolTy = Builder.getInt1Ty(); + auto * maskTy = VectorType::get(boolTy, StaticVectorLength); + return *ConstantInt::getAllOnesValue(maskTy); +} + +Value& +VPBuilder::GetEVLForType(VectorType & VecTy) { + if (ExplicitVectorLength) return *ExplicitVectorLength; + + auto * intTy = Builder.getInt32Ty(); + return *ConstantInt::get(intTy, StaticVectorLength); +} + +Value* +VPBuilder::CreateVectorCopy(Instruction & Inst, ValArray VecOpArray) { + + auto oc = Inst.getOpcode(); + + auto evlDesc = VPIntrinsic::GetVPIntrinsicDesc(oc); + if (evlDesc.ID == Intrinsic::not_intrinsic) { + return nullptr; + } + + if ((oc >= Instruction::BinaryOpsBegin) && + (oc < Instruction::BinaryOpsEnd)) { + + assert(VecOpArray.size() == 2); + Value & FirstOp = *VecOpArray[0]; + Value & SndOp = *VecOpArray[1]; + + // Fetch the VP intrinsic + auto & VecTy = cast(*FirstOp.getType()); + assert((evlDesc.MaskPos == 2) && (evlDesc.EVLPos == 3)); + + auto & VPCall = + cast(*PredicatedBinaryOperator::Create(&getModule(), &GetMaskForType(VecTy), &GetEVLForType(VecTy), static_cast(oc), &FirstOp, &SndOp)); + Builder.Insert(&VPCall); + + // transfer fast math flags + if (isa(Inst)) { + VPCall.copyFastMathFlags(Inst.getFastMathFlags()); + } + + return &VPCall; + } + + if ((oc >= Instruction::UnaryOpsBegin) && + (oc < Instruction::UnaryOpsEnd)) { + assert(VecOpArray.size() == 1); + Value & FirstOp = *VecOpArray[0]; + + // Fetch the VP intrinsic + auto & VecTy = cast(*FirstOp.getType()); + auto & ScalarTy = *VecTy.getVectorElementType(); + auto * Func = Intrinsic::getDeclaration(&getModule(), evlDesc.ID, VPIntrinsic::EncodeTypeTokens(evlDesc.typeTokens, VecTy, ScalarTy)); + + assert((evlDesc.MaskPos == 1) && (evlDesc.EVLPos == 2)); + + // Materialize
the Call + ShortValueVec Args{&FirstOp, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + + auto & VPCall = *Builder.CreateCall(Func, Args); + + // transfer fast math flags + if (isa(Inst)) { + cast(VPCall).copyFastMathFlags(Inst.getFastMathFlags()); + } + + return &VPCall; + } + + switch (oc) { + default: + return nullptr; + + case Instruction::FCmp: + case Instruction::ICmp: { + assert(VecOpArray.size() == 2); + Value & FirstOp = *VecOpArray[0]; + Value & SndOp = *VecOpArray[1]; + + // Fetch the VP intrinsic + auto & VecTy = cast(*FirstOp.getType()); + auto & ScalarTy = *VecTy.getVectorElementType(); + auto * Func = Intrinsic::getDeclaration(&getModule(), evlDesc.ID, VPIntrinsic::EncodeTypeTokens(evlDesc.typeTokens, VecTy, ScalarTy)); + + assert((evlDesc.MaskPos == 2) && (evlDesc.EVLPos == 3)); + + // encode the comparison predicate as an i8 immediate argument + uint8_t RawPred = cast(Inst).getPredicate(); + auto Int8Ty = Builder.getInt8Ty(); + auto PredArg = ConstantInt::get(Int8Ty, RawPred, false); + + // Materialize the Call + ShortValueVec Args{&FirstOp, &SndOp, &GetMaskForType(VecTy), &GetEVLForType(VecTy), PredArg}; + + return Builder.CreateCall(Func, Args); + } + + case Instruction::Select: { + assert(VecOpArray.size() == 3); + Value & MaskOp = *VecOpArray[0]; + Value & OnTrueOp = *VecOpArray[1]; + Value & OnFalseOp = *VecOpArray[2]; + + // Fetch the VP intrinsic + auto & VecTy = cast(*OnTrueOp.getType()); + auto & ScalarTy = *VecTy.getVectorElementType(); + + auto * Func = Intrinsic::getDeclaration(&getModule(), evlDesc.ID, VPIntrinsic::EncodeTypeTokens(evlDesc.typeTokens, VecTy, ScalarTy)); + + assert((evlDesc.MaskPos == 2) && (evlDesc.EVLPos == 3)); + + // Materialize the Call + ShortValueVec Args{&OnTrueOp, &OnFalseOp, &MaskOp, &GetEVLForType(VecTy)}; + + return Builder.CreateCall(Func, Args); + } + } +} + +VectorType& +VPBuilder::getVectorType(Type &ElementTy) { + return *VectorType::get(&ElementTy, StaticVectorLength); +} + +Value& +VPBuilder::CreateContiguousStore(Value &
Val, Value & Pointer, unsigned Alignment) { + auto & VecTy = cast(*Val.getType()); + auto * StoreFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_store, {Val.getType(), Pointer.getType()}); + ShortValueVec Args{&Val, &Pointer, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &StoreCall = *Builder.CreateCall(StoreFunc, Args); + if (Alignment) StoreCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment)); + return StoreCall; +} + +Value& +VPBuilder::CreateContiguousLoad(Value & Pointer, unsigned Alignment) { + auto & PointerTy = cast(*Pointer.getType()); + auto & VecTy = getVectorType(*PointerTy.getPointerElementType()); + + auto * LoadFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_load, {&VecTy, &PointerTy}); + ShortValueVec Args{&Pointer, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &LoadCall = *Builder.CreateCall(LoadFunc, Args); + if (Alignment) LoadCall.addParamAttr(0, Attribute::getWithAlignment(getContext(), Alignment)); // the pointer is operand 0 + return LoadCall; +} + +Value& +VPBuilder::CreateScatter(Value & Val, Value & PointerVec, unsigned Alignment) { + auto & VecTy = cast(*Val.getType()); + auto * ScatterFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_scatter, {Val.getType(), PointerVec.getType()}); + ShortValueVec Args{&Val, &PointerVec, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &ScatterCall = *Builder.CreateCall(ScatterFunc, Args); + if (Alignment) ScatterCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment)); + return ScatterCall; +} + +Value& +VPBuilder::CreateGather(Value & PointerVec, unsigned Alignment) { + auto & PointerVecTy = cast(*PointerVec.getType()); + auto & ElemTy = *cast(*PointerVecTy.getVectorElementType()).getPointerElementType(); + auto & VecTy = *VectorType::get(&ElemTy, PointerVecTy.getNumElements()); + auto * GatherFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_gather, {&VecTy, &PointerVecTy}); + + ShortValueVec
Args{&PointerVec, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &GatherCall = *Builder.CreateCall(GatherFunc, Args); + if (Alignment) GatherCall.addParamAttr(0, Attribute::getWithAlignment(getContext(), Alignment)); // the pointer vector is operand 0 + return GatherCall; +} + +} // namespace llvm Index: lib/IR/Verifier.cpp =================================================================== --- lib/IR/Verifier.cpp +++ lib/IR/Verifier.cpp @@ -472,6 +472,7 @@ void visitUserOp2(Instruction &I) { visitUserOp1(I); } void visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call); void visitConstrainedFPIntrinsic(ConstrainedFPIntrinsic &FPI); + void visitVPIntrinsic(VPIntrinsic &VPI); void visitDbgIntrinsic(StringRef Kind, DbgVariableIntrinsic &DII); void visitDbgLabelIntrinsic(StringRef Kind, DbgLabelInst &DLI); void visitAtomicCmpXchgInst(AtomicCmpXchgInst &CXI); @@ -1665,11 +1666,14 @@ if (Attrs.isEmpty()) return; + bool SawMask = false; bool SawNest = false; + bool SawPassthru = false; bool SawReturned = false; bool SawSRet = false; bool SawSwiftSelf = false; bool SawSwiftError = false; + bool SawVectorLength = false; // Verify return value attributes.
AttributeSet RetAttrs = Attrs.getRetAttributes(); @@ -1737,12 +1741,33 @@ SawSwiftError = true; } + if (ArgAttrs.hasAttribute(Attribute::VectorLength)) { + Assert(!SawVectorLength, "Cannot have multiple 'vlen' parameters!", + V); + SawVectorLength = true; + } + + if (ArgAttrs.hasAttribute(Attribute::Passthru)) { + Assert(!SawPassthru, "Cannot have multiple 'passthru' parameters!", + V); + SawPassthru = true; + } + + if (ArgAttrs.hasAttribute(Attribute::Mask)) { + Assert(!SawMask, "Cannot have multiple 'mask' parameters!", + V); + SawMask = true; + } + if (ArgAttrs.hasAttribute(Attribute::InAlloca)) { Assert(i == FT->getNumParams() - 1, "inalloca isn't on the last parameter!", V); } } + Assert(!SawPassthru || SawMask, + "Cannot have 'passthru' parameter without 'mask' parameter!", V); + if (!Attrs.hasAttributes(AttributeList::FunctionIndex)) return; @@ -3082,7 +3107,7 @@ /// visitUnaryOperator - Check the argument to the unary operator. /// void Verifier::visitUnaryOperator(UnaryOperator &U) { - Assert(U.getType() == U.getOperand(0)->getType(), + Assert(U.getType() == U.getOperand(0)->getType(), "Unary operators must have same type for" "operands and result!", &U); @@ -4228,6 +4253,88 @@ case Intrinsic::experimental_constrained_trunc: visitConstrainedFPIntrinsic(cast(Call)); break; + + case Intrinsic::vp_cmp: + + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + + case Intrinsic::vp_select: + case Intrinsic::vp_compose: + case Intrinsic::vp_compress: + case Intrinsic::vp_expand: + case Intrinsic::vp_vshift: + + case Intrinsic::vp_load: + case Intrinsic::vp_store: + case Intrinsic::vp_gather: + case Intrinsic::vp_scatter: + + case Intrinsic::vp_fneg: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + + case Intrinsic::vp_fma: + + case Intrinsic::vp_add: + case 
Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + case Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + + case Intrinsic::vp_constrained_fadd: + case Intrinsic::vp_constrained_fsub: + case Intrinsic::vp_constrained_fmul: + case Intrinsic::vp_constrained_fdiv: + case Intrinsic::vp_constrained_frem: + case Intrinsic::vp_constrained_fma: + case Intrinsic::vp_constrained_sqrt: + case Intrinsic::vp_constrained_pow: + case Intrinsic::vp_constrained_powi: + case Intrinsic::vp_constrained_sin: + case Intrinsic::vp_constrained_cos: + case Intrinsic::vp_constrained_exp: + case Intrinsic::vp_constrained_exp2: + case Intrinsic::vp_constrained_log: + case Intrinsic::vp_constrained_log10: + case Intrinsic::vp_constrained_log2: + case Intrinsic::vp_constrained_rint: + case Intrinsic::vp_constrained_nearbyint: + case Intrinsic::vp_constrained_maxnum: + case Intrinsic::vp_constrained_minnum: + case Intrinsic::vp_constrained_ceil: + case Intrinsic::vp_constrained_floor: + case Intrinsic::vp_constrained_round: + case Intrinsic::vp_constrained_trunc: + visitVPIntrinsic(cast(Call)); + break; + case Intrinsic::dbg_declare: // llvm.dbg.declare Assert(isa(Call.getArgOperand(0)), "invalid llvm.dbg.declare intrinsic call 1", Call); @@ -4639,6 +4746,15 @@ return nullptr; } +void Verifier::visitVPIntrinsic(VPIntrinsic &VPI) { + if (VPI.isConstrainedOp()) { + Assert(VPI.getExceptionBehavior() != ExceptionBehavior::ebInvalid, + "invalid exception behavior argument", &VPI); + Assert(VPI.getRoundingMode() != 
RoundingMode::rmInvalid, + "invalid rounding mode argument", &VPI); + } +} + void Verifier::visitConstrainedFPIntrinsic(ConstrainedFPIntrinsic &FPI) { unsigned NumOperands = FPI.getNumArgOperands(); bool HasExceptionMD = false; @@ -4696,11 +4812,11 @@ // argument type check is needed here. if (HasExceptionMD) { - Assert(FPI.getExceptionBehavior() != ConstrainedFPIntrinsic::ebInvalid, + Assert(FPI.getExceptionBehavior() != ExceptionBehavior::ebInvalid, "invalid exception behavior argument", &FPI); } if (HasRoundingMD) { - Assert(FPI.getRoundingMode() != ConstrainedFPIntrinsic::rmInvalid, + Assert(FPI.getRoundingMode() != RoundingMode::rmInvalid, "invalid rounding mode argument", &FPI); } } @@ -4946,7 +5062,7 @@ bool runOnFunction(Function &F) override { if (!V->verify(F) && FatalErrors) { - errs() << "in function " << F.getName() << '\n'; + errs() << "in function " << F.getName() << '\n'; report_fatal_error("Broken function found, compilation aborted!"); } return false; Index: lib/Transforms/InstCombine/InstCombineAddSub.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -24,6 +24,9 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredicatedInst.h" +#include "llvm/IR/VPBuilder.h" +#include "llvm/IR/MatcherCast.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/AlignOf.h" @@ -1817,6 +1820,17 @@ return Changed ? 
&I : nullptr; } +Instruction *InstCombiner::visitPredicatedFSub(PredicatedBinaryOperator& I) { + auto * Inst = cast(&I); + PredicatedContext PC(&I); + if (Value *V = SimplifyPredicatedFSubInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), + SQ.getWithInstruction(Inst), PC)) + return replaceInstUsesWith(*Inst, V); + + return visitFSubGeneric(*Inst); +} + Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (Value *V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), @@ -1826,11 +1840,19 @@ if (Instruction *X = foldVectorBinop(I)) return X; + return visitFSubGeneric(I); +} + +template +Instruction *InstCombiner::visitFSubGeneric(BinaryOpTy &I) { + MatchContextType MC(cast(&I)); + MatchContextBuilder MCBuilder(MC); + // Subtraction from -0.0 is the canonical form of fneg. // fsub nsz 0, X ==> fsub nsz -0.0, X Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (I.hasNoSignedZeros() && match(Op0, m_PosZeroFP())) - return BinaryOperator::CreateFNegFMF(Op1, &I); + if (I.hasNoSignedZeros() && MC.try_match(Op0, m_PosZeroFP())) + return MCBuilder.CreateFNegFMF(Op1, &I); Value *X, *Y; Constant *C; @@ -1838,14 +1860,14 @@ // Fold negation into constant operand. This is limited with one-use because // fneg is assumed better for analysis and cheaper in codegen than fmul/fdiv. 
// -(X * C) --> X * (-C) - if (match(&I, m_FNeg(m_OneUse(m_FMul(m_Value(X), m_Constant(C)))))) - return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FMul(m_Value(X), m_Constant(C)))))) + return MCBuilder.CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); // -(X / C) --> X / (-C) - if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Value(X), m_Constant(C)))))) - return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FDiv(m_Value(X), m_Constant(C)))))) + return MCBuilder.CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); // -(C / X) --> (-C) / X - if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X)))))) - return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X)))))) + return MCBuilder.CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); // If Op0 is not -0.0 or we can ignore -0.0: Z - (X - Y) --> Z + (Y - X) // Canonicalize to fadd to make analysis easier. @@ -1854,71 +1876,75 @@ // killed later. We still limit that particular transform with 'hasOneUse' // because an fneg is assumed better/cheaper than a generic fsub. 
if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) { - if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { - Value *NewSub = Builder.CreateFSubFMF(Y, X, &I); - return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I); + if (MC.try_match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { + Value *NewSub = MCBuilder.CreateFSubFMF(Builder, Y, X, &I); + return MCBuilder.CreateFAddFMF(Op0, NewSub, &I); } } - if (isa(Op0)) - if (SelectInst *SI = dyn_cast(Op1)) - if (Instruction *NV = FoldOpIntoSelect(I, SI)) - return NV; + if (auto * PlainBinOp = dyn_cast(&I)) + if (isa(Op0)) + if (SelectInst *SI = dyn_cast(Op1)) + if (Instruction *NV = FoldOpIntoSelect(*PlainBinOp, SI)) + return NV; // X - C --> X + (-C) // But don't transform constant expressions because there's an inverse fold // for X + (-Y) --> X - Y. - if (match(Op1, m_Constant(C)) && !isa(Op1)) - return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); + if (MC.try_match(Op1, m_Constant(C)) && !isa(Op1)) + return MCBuilder.CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); // X - (-Y) --> X + Y - if (match(Op1, m_FNeg(m_Value(Y)))) - return BinaryOperator::CreateFAddFMF(Op0, Y, &I); + if (MC.try_match(Op1, m_FNeg(m_Value(Y)))) + return MCBuilder.CreateFAddFMF(Op0, Y, &I); // Similar to above, but look through a cast of the negated value: // X - (fptrunc(-Y)) --> X + fptrunc(Y) Type *Ty = I.getType(); - if (match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) - return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPTrunc(Y, Ty), &I); + if (MC.try_match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) + return MCBuilder.CreateFAddFMF(Op0, MCBuilder.CreateFPTrunc(Builder, Y, Ty), &I); // X - (fpext(-Y)) --> X + fpext(Y) - if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) - return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I); + if (MC.try_match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) + return MCBuilder.CreateFAddFMF(Op0, MCBuilder.CreateFPExt(Builder, Y, 
Ty), &I); // Handle special cases for FSub with selects feeding the operation - if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) - return replaceInstUsesWith(I, V); + if (auto * PlainBinOp = dyn_cast(&I)) + if (Value *V = SimplifySelectsFeedingBinaryOp(*PlainBinOp, Op0, Op1)) + return replaceInstUsesWith(I, V); if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { // (Y - X) - Y --> -X - if (match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) - return BinaryOperator::CreateFNegFMF(X, &I); + if (MC.try_match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) + return MCBuilder.CreateFNegFMF(X, &I); // Y - (X + Y) --> -X // Y - (Y + X) --> -X - if (match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) - return BinaryOperator::CreateFNegFMF(X, &I); + if (MC.try_match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) + return MCBuilder.CreateFNegFMF(X, &I); // (X * C) - X --> X * (C - 1.0) - if (match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { + if (MC.try_match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { Constant *CSubOne = ConstantExpr::getFSub(C, ConstantFP::get(Ty, 1.0)); - return BinaryOperator::CreateFMulFMF(Op1, CSubOne, &I); + return MCBuilder.CreateFMulFMF(Op1, CSubOne, &I); } // X - (X * C) --> X * (1.0 - C) - if (match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) { + if (MC.try_match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) { Constant *OneSubC = ConstantExpr::getFSub(ConstantFP::get(Ty, 1.0), C); - return BinaryOperator::CreateFMulFMF(Op0, OneSubC, &I); + return MCBuilder.CreateFMulFMF(Op0, OneSubC, &I); } - if (Instruction *F = factorizeFAddFSub(I, Builder)) - return F; + if (auto * PlainBinOp = dyn_cast(&I)) { + if (Instruction *F = factorizeFAddFSub(*PlainBinOp, Builder)) + return F; - // TODO: This performs reassociative folds for FP ops. Some fraction of the - // functionality has been subsumed by simple pattern matching here and in - // InstSimplify. 
We should let a dedicated reassociation pass handle more - // complex pattern matching and remove this from InstCombine. - if (Value *V = FAddCombine(Builder).simplify(&I)) - return replaceInstUsesWith(I, V); + // TODO: This performs reassociative folds for FP ops. Some fraction of the + // functionality has been subsumed by simple pattern matching here and in + // InstSimplify. We should let a dedicated reassociation pass handle more + // complex pattern matching and remove this from InstCombine. + if (Value *V = FAddCombine(Builder).simplify(PlainBinOp)) + return replaceInstUsesWith(*PlainBinOp, V); + } } return nullptr; Index: lib/Transforms/InstCombine/InstCombineCalls.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineCalls.cpp +++ lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -37,6 +37,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" @@ -1894,6 +1895,14 @@ return &CI; } + // Predicated instruction patterns + auto * VPInst = dyn_cast(&CI); + if (VPInst) { + auto * PredInst = cast(VPInst); + auto Result = visitPredicatedInstruction(PredInst); + if (Result) return Result; + } + IntrinsicInst *II = dyn_cast(&CI); if (!II) return visitCallBase(CI); @@ -1958,7 +1967,8 @@ if (Changed) return II; } - // For vector result intrinsics, use the generic demanded vector support. + // For vector result intrinsics, use the generic demanded vector support to + // simplify any operands before moving on to the per-intrinsic rules. 
if (II->getType()->isVectorTy()) { auto VWidth = II->getType()->getVectorNumElements(); APInt UndefElts(VWidth, 0); Index: lib/Transforms/InstCombine/InstCombineInternal.h =================================================================== --- lib/Transforms/InstCombine/InstCombineInternal.h +++ lib/Transforms/InstCombine/InstCombineInternal.h @@ -30,6 +30,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Use.h" @@ -351,6 +352,8 @@ Instruction *visitFAdd(BinaryOperator &I); Value *OptimizePointerDifference(Value *LHS, Value *RHS, Type *Ty); Instruction *visitSub(BinaryOperator &I); + template Instruction *visitFSubGeneric(BinaryOpTy &I); + Instruction *visitPredicatedFSub(PredicatedBinaryOperator &I); Instruction *visitFSub(BinaryOperator &I); Instruction *visitMul(BinaryOperator &I); Instruction *visitFMul(BinaryOperator &I); @@ -420,6 +423,16 @@ Instruction *visitVAStartInst(VAStartInst &I); Instruction *visitVACopyInst(VACopyInst &I); + // Entry point for predicated instructions (VP intrinsic calls) + Instruction *visitPredicatedInstruction(PredicatedInstruction * PI) { + switch (PI->getOpcode()) { + default: + return nullptr; + case Instruction::FSub: + return visitPredicatedFSub(cast(*PI)); + } + } + /// Specify what to return for unhandled instructions.
Instruction *visitInstruction(Instruction &I) { return nullptr; } Index: lib/Transforms/Utils/CodeExtractor.cpp =================================================================== --- lib/Transforms/Utils/CodeExtractor.cpp +++ lib/Transforms/Utils/CodeExtractor.cpp @@ -783,6 +783,7 @@ case Attribute::InaccessibleMemOnly: case Attribute::InaccessibleMemOrArgMemOnly: case Attribute::JumpTable: + case Attribute::Mask: case Attribute::Naked: case Attribute::Nest: case Attribute::NoAlias: @@ -791,6 +792,7 @@ case Attribute::NoReturn: case Attribute::None: case Attribute::NonNull: + case Attribute::Passthru: case Attribute::ReadNone: case Attribute::ReadOnly: case Attribute::Returned: @@ -801,6 +803,7 @@ case Attribute::StructRet: case Attribute::SwiftError: case Attribute::SwiftSelf: + case Attribute::VectorLength: case Attribute::WriteOnly: case Attribute::ZExt: case Attribute::ImmArg: Index: test/Bitcode/attributes.ll =================================================================== --- test/Bitcode/attributes.ll +++ test/Bitcode/attributes.ll @@ -351,6 +351,11 @@ ret void } +; CHECK: define <8 x double> @f60(<8 x double> passthru, <8 x i1> mask, i32 vlen) { +define <8 x double> @f60(<8 x double> passthru, <8 x i1> mask, i32 vlen) { + ret <8 x double> undef +} + ; CHECK: attributes #0 = { noreturn } ; CHECK: attributes #1 = { nounwind } ; CHECK: attributes #2 = { readnone } Index: test/Transforms/InstCombine/vp-fsub.ll =================================================================== --- /dev/null +++ test/Transforms/InstCombine/vp-fsub.ll @@ -0,0 +1,43 @@ +; RUN: opt < %s -instcombine -S | FileCheck %s + +; PR4374 + +define <4 x float> @test1_vp(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) { +; CHECK-LABEL: @test1_vp( +; + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) + %t2 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> , <4 x float> %t1, <4 x i1> %M, i32 %L) + ret <4 x float> %t2 +} + +; Can't do 
anything with the test above because -0.0 - 0.0 = -0.0, but if we have nsz: +; -(X - Y) --> Y - X + +; TODO predicated FAdd folding +define <4 x float> @neg_sub_nsz_vp(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) { +; CH***-LABEL: @neg_sub_nsz_vp( +; + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) + %t2 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> , <4 x float> %t1, <4 x i1> %M, i32 %L) + ret <4 x float> %t2 +} + +; With nsz: Z - (X - Y) --> Z + (Y - X) + +define <4 x float> @sub_sub_nsz_vp(<4 x float> %x, <4 x float> %y, <4 x float> %z, <4 x i1> %M, i32 %L) { +; CHECK-LABEL: @sub_sub_nsz_vp( +; CHECK-NEXT: %1 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %y, <4 x float> %x, <4 x i1> %M, i32 %L) +; CHECK-NEXT: %t2 = call nsz <4 x float> @llvm.vp.fadd.v4f32(<4 x float> %z, <4 x float> %1, <4 x i1> %M, i32 %L) +; CHECK-NEXT: ret <4 x float> %t2 + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) + %t2 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %z, <4 x float> %t1, <4 x i1> %M, i32 %L) + ret <4 x float> %t2 +} + + + +; Function Attrs: nounwind readnone +declare <4 x float> @llvm.vp.fadd.v4f32(<4 x float>, <4 x float>, <4 x i1> mask, i32 vlen) #0 + +; Function Attrs: nounwind readnone +declare <4 x float> @llvm.vp.fsub.v4f32(<4 x float>, <4 x float>, <4 x i1> mask, i32 vlen) #0 Index: test/Transforms/InstSimplify/vp-fsub.ll =================================================================== --- /dev/null +++ test/Transforms/InstSimplify/vp-fsub.ll @@ -0,0 +1,43 @@ +; RUN: opt < %s -instsimplify -S | FileCheck %s + +define <8 x double> @fsub_fadd_fold_vp_xy(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) { +; CHECK-LABEL: fsub_fadd_fold_vp_xy +; CHECK-NEXT: ret <8 x double> %x + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) + %res = call reassoc nsz <8 x double> 
@llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %m, i32 %len) + ret <8 x double> %res +} + +define <8 x double> @fsub_fadd_fold_vp_yx(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) { +; CHECK-LABEL: fsub_fadd_fold_vp_yx +; CHECK-NEXT: ret <8 x double> %x + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %len) + %res = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %m, i32 %len) + ret <8 x double> %res +} + +define <8 x double> @fsub_fadd_fold_vp_yx_olen(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len, i32 %otherLen) { +; CHECK-LABEL: fsub_fadd_fold_vp_yx_olen +; CHECK-NEXT: %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %otherLen) +; CHECK-NEXT: %res = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %m, i32 %len) +; CHECK-NEXT: ret <8 x double> %res + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %otherLen) + %res = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %m, i32 %len) + ret <8 x double> %res +} + +define <8 x double> @fsub_fadd_fold_vp_yx_omask(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len, <8 x i1> %othermask) { +; CHECK-LABEL: fsub_fadd_fold_vp_yx_omask +; CHECK-NEXT: %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %len) +; CHECK-NEXT: %res = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %othermask, i32 %len) +; CHECK-NEXT: ret <8 x double> %res + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %len) + %res = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %othermask, i32 %len) + ret
<8 x double> %res +} + +; Function Attrs: nounwind readnone +declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, <8 x i1> mask, i32 vlen) #0 + +; Function Attrs: nounwind readnone +declare <8 x double> @llvm.vp.fsub.v8f64(<8 x double>, <8 x double>, <8 x i1> mask, i32 vlen) #0 Index: test/Verifier/evl_attribs.ll =================================================================== --- /dev/null +++ test/Verifier/evl_attribs.ll @@ -0,0 +1,13 @@ +; RUN: not llvm-as %s -o /dev/null 2>&1 | FileCheck %s + +declare void @a(<16 x i1> mask %a, <16 x i1> mask %b) +; CHECK: Cannot have multiple 'mask' parameters! + +declare void @b(<16 x i1> mask %a, i32 vlen %x, i32 vlen %y) +; CHECK: Cannot have multiple 'vlen' parameters! + +declare <16 x double> @c(<16 x double> passthru %a) +; CHECK: Cannot have 'passthru' parameter without 'mask' parameter! + +declare <16 x double> @d(<16 x double> passthru %a, <16 x i1> mask %M, <16 x double> passthru %b) +; CHECK: Cannot have multiple 'passthru' parameters! Index: utils/TableGen/CodeGenIntrinsics.h =================================================================== --- utils/TableGen/CodeGenIntrinsics.h +++ utils/TableGen/CodeGenIntrinsics.h @@ -142,7 +142,10 @@ ReadOnly, WriteOnly, ReadNone, - ImmArg + ImmArg, + Mask, + VectorLength, + Passthru }; std::vector> ArgumentAttributes; Index: utils/TableGen/CodeGenTarget.cpp =================================================================== --- utils/TableGen/CodeGenTarget.cpp +++ utils/TableGen/CodeGenTarget.cpp @@ -612,10 +612,10 @@ "Expected iAny or vAny type"); } else { VT = getValueType(TyEl->getValueAsDef("VT")); - } - if (MVT(VT).isOverloaded()) { - OverloadedVTs.push_back(VT); - isOverloaded = true; + if (MVT(VT).isOverloaded()) { + OverloadedVTs.push_back(VT); + isOverloaded = true; + } } // Reject invalid types. 
@@ -651,14 +651,15 @@ !TyEl->isSubClassOf("LLVMScalarOrSameVectorWidth")) || VT == MVT::iAny || VT == MVT::vAny) && "Expected iAny or vAny type"); - } else + } else { VT = getValueType(TyEl->getValueAsDef("VT")); - - if (MVT(VT).isOverloaded()) { - OverloadedVTs.push_back(VT); - isOverloaded = true; + if (MVT(VT).isOverloaded()) { + OverloadedVTs.push_back(VT); + isOverloaded = true; + } } + // Reject invalid types. if (VT == MVT::isVoid && i != e-1 /*void at end means varargs*/) PrintFatalError(DefLoc, "Intrinsic '" + DefName + @@ -710,6 +711,15 @@ } else if (Property->isSubClassOf("Returned")) { unsigned ArgNo = Property->getValueAsInt("ArgNo"); ArgumentAttributes.push_back(std::make_pair(ArgNo, Returned)); + } else if (Property->isSubClassOf("VectorLength")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, VectorLength)); + } else if (Property->isSubClassOf("Mask")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, Mask)); + } else if (Property->isSubClassOf("Passthru")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, Passthru)); } else if (Property->isSubClassOf("ReadOnly")) { unsigned ArgNo = Property->getValueAsInt("ArgNo"); ArgumentAttributes.push_back(std::make_pair(ArgNo, ReadOnly)); Index: utils/TableGen/IntrinsicEmitter.cpp =================================================================== --- utils/TableGen/IntrinsicEmitter.cpp +++ utils/TableGen/IntrinsicEmitter.cpp @@ -593,6 +593,24 @@ OS << "Attribute::Returned"; addComma = true; break; + case CodeGenIntrinsic::VectorLength: + if (addComma) + OS << ","; + OS << "Attribute::VectorLength"; + addComma = true; + break; + case CodeGenIntrinsic::Mask: + if (addComma) + OS << ","; + OS << "Attribute::Mask"; + addComma = true; + break; + case CodeGenIntrinsic::Passthru: + if (addComma) + OS << ","; + OS << "Attribute::Passthru"; + 
addComma = true; + break; case CodeGenIntrinsic::ReadOnly: if (addComma) OS << ",";
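
The InstSimplify tests above expect `fsub(fadd(x, y), y)` to fold to `%x` only when both VP calls share the same mask and EVL (the `_olen` and `_omask` variants must stay unchanged). The following standalone C++ sketch is a scalar reference model, not code from the patch, that shows why. Its assumptions: `laneEnabled`, `vpBinOp` and `fsubOfFAddFoldsToX` are hypothetical names; a disabled lane is modelled as `std::nullopt`, standing in for `undef`; and the EVL counts only when its most significant bit is zero, following the LangRef wording this patch adds. The test values are small integers, so `(x + y) - y == x` holds exactly, much as `reassoc nsz` permits in the IR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

// One lane of a vector value; std::nullopt stands in for an undef lane.
using Lane = std::optional<double>;
using Vec = std::vector<Lane>;
using MaskVec = std::vector<bool>;

// A lane is enabled if its mask bit is set and, when the EVL is effective
// (most significant bit clear), its position is below the EVL.
bool laneEnabled(const MaskVec &Mask, uint32_t EVL, std::size_t I) {
  assert(I < Mask.size() && "lane index out of range");
  bool EVLEffective = (EVL & 0x80000000u) == 0;
  return Mask[I] && (!EVLEffective || I < EVL);
}

// Scalar model of a vertical VP binary operation such as llvm.vp.fadd:
// disabled lanes, and lanes with an undef input, produce undef.
Vec vpBinOp(const Vec &A, const Vec &B, const MaskVec &Mask, uint32_t EVL,
            const std::function<double(double, double)> &Op) {
  Vec R(A.size()); // every lane starts out undef
  for (std::size_t I = 0; I < A.size(); ++I)
    if (laneEnabled(Mask, EVL, I) && A[I] && B[I])
      R[I] = Op(*A[I], *B[I]);
  return R;
}

// Is it correct to replace fsub(fadd(X, Y, M1, L1), Y, M2, L2) by X?
// Only the lanes enabled in the outer fsub have to match X.
bool fsubOfFAddFoldsToX(const Vec &X, const Vec &Y, const MaskVec &M1,
                        uint32_t L1, const MaskVec &M2, uint32_t L2) {
  Vec Tmp = vpBinOp(X, Y, M1, L1, std::plus<double>());
  Vec Res = vpBinOp(Tmp, Y, M2, L2, std::minus<double>());
  for (std::size_t I = 0; I < X.size(); ++I)
    if (laneEnabled(M2, L2, I) && Res[I] != X[I])
      return false;
  return true;
}
```

The model also accepts an inner `fadd` whose enabled lanes are a superset of the outer `fsub`'s, for example a full EVL inside and a shorter one outside. Requiring identical mask and EVL operands, which the `_olen` and `_omask` tests suggest the patch does, is a simpler and more conservative condition.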
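
The new Verifier rules for the `mask`, `vlen` and `passthru` parameter attributes, exercised by `test/Verifier/evl_attribs.ll`, come down to three constraints. Each attribute may appear at most once, and `passthru` requires `mask`. The standalone C++ sketch below models these rules. It is not LLVM code: `ParamAttr` and `checkVPParamAttrs` are hypothetical stand-ins for `Attribute::Mask`, `Attribute::VectorLength` and `Attribute::Passthru`. For simplicity each parameter carries at most one attribute, and the sketch returns the first diagnostic instead of calling `Assert`. The checks run in the same order as in the patched `verifyFunctionAttrs`.

```cpp
#include <string>
#include <vector>

// Hypothetical stand-ins for the VP-related parameter attributes.
enum class ParamAttr { None, Mask, VectorLength, Passthru };

// Returns an empty string if the attribute combination is valid, otherwise
// the first diagnostic the patched Verifier would report.
std::string checkVPParamAttrs(const std::vector<ParamAttr> &Params) {
  bool SawMask = false, SawVectorLength = false, SawPassthru = false;
  for (ParamAttr A : Params) {
    if (A == ParamAttr::VectorLength) {
      if (SawVectorLength)
        return "Cannot have multiple 'vlen' parameters!";
      SawVectorLength = true;
    }
    if (A == ParamAttr::Passthru) {
      if (SawPassthru)
        return "Cannot have multiple 'passthru' parameters!";
      SawPassthru = true;
    }
    if (A == ParamAttr::Mask) {
      if (SawMask)
        return "Cannot have multiple 'mask' parameters!";
      SawMask = true;
    }
  }
  // Checked only after all parameters have been seen.
  if (SawPassthru && !SawMask)
    return "Cannot have 'passthru' parameter without 'mask' parameter!";
  return "";
}
```

The four failing declarations in `evl_attribs.ll` map one-to-one onto the non-empty return paths. A declaration such as `@f60` in `test/Bitcode/attributes.ll`, with one attribute of each kind, passes.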