diff --git a/llvm/docs/LangRef.rst b/llvm/docs/LangRef.rst --- a/llvm/docs/LangRef.rst +++ b/llvm/docs/LangRef.rst @@ -14562,6 +14562,23 @@ after performing the required machine specific adjustments. The pointer returned can then be :ref:`bitcast and executed `. + +.. _int_vp: + +Vector Predication Intrinsics +------------------------------ +VP intrinsics are intended for predicated SIMD/vector code. +A typical VP operation takes a vector mask (<W x i1>) and an explicit vector length parameter (i32), as in:: + + llvm.vp.<opcode>.*(<W x T> %x, <W x T> %y, <W x i1> mask %Mask, i32 vlen %evl) + +The mask and explicit vector length parameters are unambiguously identified by the mask and vlen parameter attributes. +Result elements are only computed for enabled lanes. +The explicit vector length parameter only disables lanes if its MSB is zero. +A lane is enabled if the mask is true at that position and, where the explicit vector length is in effect, the lane position is below the explicit vector length. + +For purely vertical operations (SIMD binary operators, etc.), the result is undef on disabled lanes. + .. _int_mload_mstore: Masked Vector Load and Store Intrinsics diff --git a/llvm/docs/Proposals/VectorPredication.rst b/llvm/docs/Proposals/VectorPredication.rst new file mode 100644 --- /dev/null +++ b/llvm/docs/Proposals/VectorPredication.rst @@ -0,0 +1,83 @@ +========================== +Vector Predication Roadmap +========================== + +.. contents:: Table of Contents + :depth: 3 + :local: + +Motivation +========== + +This proposal defines a roadmap towards native vector predication in LLVM, specifically for vector instructions with a mask and/or an explicit vector length. +LLVM currently has no target-independent means to model predicated vector instructions for modern SIMD ISAs such as AVX512, ARM SVE, the RISC-V V extension and NEC SX-Aurora. +Only some predicated vector operations, such as masked loads and stores, are available through intrinsics [MaskedIR]_. + +The Explicit Vector Length extension +==================================== + +The Explicit Vector Length (EVL) extension [EvlRFC]_ can be a first step towards native vector predication. +The EVL prototype in this patch demonstrates the following concepts: + +- Predicated vector intrinsics with an explicit mask and vector length parameter at the IR level. +- First-class predicated SDNodes at the ISel level. Mask and vector length are value operands. +- An incremental strategy to generalize PatternMatch/InstCombine/InstSimplify and DAGCombiner to work on both regular instructions and EVL intrinsics. +- DAGCombiner example: FMA fusion. +- InstCombine/InstSimplify example: FSub pattern rewrites. +- Early experiments on the LNT test suite (Clang static release, O3 -ffast-math) indicate that compile time on non-EVL IR is not affected by the API abstractions in PatternMatch, etc. + +Roadmap +======= + +Drawing from the EVL prototype, we propose the following roadmap towards native vector predication in LLVM: + + +1. IR-level EVL intrinsics +----------------------------------------- + +- There is a consensus on the semantics/instruction set of EVL. +- EVL intrinsics and attributes are available at the IR level. +- TTI has capability flags for EVL (``supportsEVL()``?, ``haveActiveVectorLength()``?). + +Result: EVL usable for IR-level vectorizers (LV, VPlan, RegionVectorizer), potential integration in Clang with builtins. + +2. CodeGen support +------------------ + +- EVL intrinsics translate to first-class SDNodes (``llvm.evl.fdiv.* -> evl_fdiv``).
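For illustration of the semantics described in the LangRef hunk above, a call to one of the proposed intrinsics could be emitted through the existing generic IRBuilder::CreateIntrinsic entry point. This is only a sketch, not part of the patch; the 8 x i32 width, the value names, and the helper function are assumptions:

  // Sketch: creating a call to llvm.vp.add via IRBuilder::CreateIntrinsic.
  // The resulting IR would look roughly like
  //   %r = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %x, <8 x i32> %y,
  //                                          <8 x i1> mask %m, i32 vlen %evl)
  // where the mask/vlen parameter attributes come from the intrinsic
  // declaration (see the Attributes.td and Intrinsics.td hunks further down).
  #include "llvm/IR/IRBuilder.h"
  #include "llvm/IR/Intrinsics.h"

  using namespace llvm;

  static Value *emitVPAdd(IRBuilder<> &Builder, Value *X, Value *Y,
                          Value *Mask, Value *EVL) {
    // Overloaded on the vector type of the operands; lanes disabled by
    // Mask or EVL are undef in the result.
    return Builder.CreateIntrinsic(Intrinsic::vp_add, {X->getType()},
                                   {X, Y, Mask, EVL});
  }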
+- EVL legalization (legalize explicit vector length to mask (AVX512), legalize EVL SDNodes to pre-existing ones (SSE, NEON)). + +Result: Backend development based on EVL SDNodes. + +3. Lift InstSimplify/InstCombine/DAGCombiner to EVL +------------------------------------------------------- + +- Introduce PredicatedInstruction, PredicatedBinaryOperator, ... helper classes that match standard vector IR and EVL intrinsics. +- Add a matcher context to PatternMatch and context-aware IR Builder APIs. +- Incrementally lift DAGCombiner to work on EVL SDNodes as well as on regular vector instructions. +- Incrementally lift InstCombine/InstSimplify to operate on EVL as well as regular IR instructions. + +Result: Optimization of EVL intrinsics on par with standard vector instructions. + +4. Deprecate llvm.masked.* / llvm.experimental.reduce.* +--------------------------------------------------------- + +- Modernize llvm.masked.* / llvm.experimental.reduce.* by translating to EVL. +- DCE transitional APIs. + +Result: EVL has superseded earlier vector intrinsics. + +5. Predicated IR Instructions +--------------------------------------- + +- Vector instructions have an optional mask and vector length parameter. These lower to EVL SDNodes (from Stage 2). +- Phase out EVL intrinsics, keeping only those that are not equivalent to vectorized scalar instructions (reductions, shuffles, ...). +- InstCombine/InstSimplify expect predication in regular Instructions (Stage 3 has laid the groundwork). + +Result: Native vector predication in IR. + +References +========== + +.. [MaskedIR] `llvm.masked.*` intrinsics, https://llvm.org/docs/LangRef.html#masked-vector-load-and-store-intrinsics +.. [EvlRFC] Explicit Vector Length RFC, https://reviews.llvm.org/D53613 diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h --- a/llvm/include/llvm/Analysis/InstructionSimplify.h +++ b/llvm/include/llvm/Analysis/InstructionSimplify.h @@ -53,6 +53,10 @@ class Value; class MDNode; class BinaryOperator; +class VPIntrinsic; +namespace PatternMatch { + struct PredicatedContext; +} /// InstrInfoQuery provides an interface to query additional information for /// instructions like metadata or keywords like nsw, which provides conservative @@ -138,6 +142,13 @@ Value *SimplifyFSubInst(Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q); +/// Given operands for an FSub, fold the result or return null. +Value *SimplifyFSubInst(Value *LHS, Value *RHS, FastMathFlags FMF, + const SimplifyQuery &Q); +Value *SimplifyPredicatedFSubInst(Value *LHS, Value *RHS, + FastMathFlags FMF, const SimplifyQuery &Q, + PatternMatch::PredicatedContext & PC); + /// Given operands for an FMul, fold the result or return null. Value *SimplifyFMulInst(Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q); @@ -259,6 +270,9 @@ /// Given a callsite, fold the result or return null. Value *SimplifyCall(CallBase *Call, const SimplifyQuery &Q); +/// Given a VP intrinsic function, fold the result or return null. +Value *SimplifyVPIntrinsic(VPIntrinsic & VPInst, const SimplifyQuery &Q); + /// See if we can compute a simplified version of this instruction. If not, /// return null.
Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q, diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h --- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h +++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h @@ -632,6 +632,9 @@ ATTR_KIND_NOFREE = 62, ATTR_KIND_NOSYNC = 63, ATTR_KIND_SANITIZE_MEMTAG = 64, + ATTR_KIND_MASK = 65, + ATTR_KIND_VECTORLENGTH = 66, + ATTR_KIND_PASSTHRU = 67, }; enum ComdatSelectionKindCodes { diff --git a/llvm/include/llvm/CodeGen/ISDOpcodes.h b/llvm/include/llvm/CodeGen/ISDOpcodes.h --- a/llvm/include/llvm/CodeGen/ISDOpcodes.h +++ b/llvm/include/llvm/CodeGen/ISDOpcodes.h @@ -198,6 +198,7 @@ /// Simple integer binary arithmetic operators. ADD, SUB, MUL, SDIV, UDIV, SREM, UREM, + VP_ADD, VP_SUB, VP_MUL, VP_SDIV, VP_UDIV, VP_SREM, VP_UREM, /// SMUL_LOHI/UMUL_LOHI - Multiply two integers of type iN, producing /// a signed/unsigned value of type i[2*N], and return the full value as @@ -285,6 +286,7 @@ /// Simple binary floating point operators. FADD, FSUB, FMUL, FDIV, FREM, + VP_FADD, VP_FSUB, VP_FMUL, VP_FDIV, VP_FREM, /// Constrained versions of the binary floating point operators. /// These will be lowered to the simple operators before final selection. @@ -310,8 +312,8 @@ STRICT_FP_TO_SINT, STRICT_FP_TO_UINT, - /// X = STRICT_FP_ROUND(Y, TRUNC) - Rounding 'Y' from a larger floating - /// point type down to the precision of the destination VT. TRUNC is a + /// X = STRICT_FP_ROUND(Y, TRUNC) - Rounding 'Y' from a larger floating + /// point type down to the precision of the destination VT. TRUNC is a /// flag, which is always an integer that is zero or one. If TRUNC is 0, /// this is a normal rounding, if it is 1, this FP_ROUND is known to not /// change the value of Y. @@ -332,6 +334,7 @@ /// FMA - Perform a * b + c with no intermediate rounding step. FMA, + VP_FMA, /// FMAD - Perform a * b + c, while getting the same result as the /// separately rounded operations. @@ -398,6 +401,19 @@ /// in terms of the element size of VEC1/VEC2, not in terms of bytes. VECTOR_SHUFFLE, + /// VP_VSHIFT(VEC1, AMOUNT, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. AMOUNT is an integer value. The returned vector is equivalent + /// to VEC1 shifted by AMOUNT (RETURNED_VEC[idx] = VEC1[idx + AMOUNT]). + VP_VSHIFT, + + /// VP_COMPRESS(VEC1, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. + VP_COMPRESS, + + /// VP_EXPAND(VEC1, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. + VP_EXPAND, + /// SCALAR_TO_VECTOR(VAL) - This represents the operation of loading a /// scalar value into element 0 of the resultant vector type. The top /// elements 1 to N-1 of the N-element vector are undefined. The type @@ -424,6 +440,7 @@ /// Bitwise operators - logical and, logical or, logical xor. AND, OR, XOR, + VP_AND, VP_OR, VP_XOR, /// ABS - Determine the unsigned absolute value of a signed integer value of /// the same bitwidth. @@ -447,6 +464,7 @@ /// fshl(X,Y,Z): (X << (Z % BW)) | (Y >> (BW - (Z % BW))) /// fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW)) SHL, SRA, SRL, ROTL, ROTR, FSHL, FSHR, + VP_SHL, VP_SRA, VP_SRL, /// Byte Swap and Counting operators. BSWAP, CTTZ, CTLZ, CTPOP, BITREVERSE, @@ -466,6 +484,14 @@ /// change the condition type in order to match the VSELECT node using a /// pattern. The condition follows the BooleanContent format of the target. VSELECT, + VP_SELECT, + + /// Select with an integer pivot (op #0) and two vector operands (ops #1 + /// and #2), returning a vector result. 
All vectors have the same length. + /// Similar to the vector select, a comparison of the result's element index + /// with the integer pivot selects whether the corresponding result element + /// is taken from op #1 or op #2. + VP_COMPOSE, /// Select with condition operator - This selects between a true value and /// a false value (ops #2 and #3) based on the boolean result of comparing @@ -480,6 +506,7 @@ /// them with (op #2) as a CondCodeSDNode. If the operands are vector types /// then the result type must also be a vector type. SETCC, + VP_SETCC, /// Like SetCC, ops #0 and #1 are the LHS and RHS operands to compare, but /// op #2 is a boolean indicating if there is an incoming carry. This @@ -620,6 +647,7 @@ FCEIL, FTRUNC, FRINT, FNEARBYINT, FROUND, FFLOOR, LROUND, LLROUND, LRINT, LLRINT, + VP_FNEG, // TODO supplement VP opcodes /// FMINNUM/FMAXNUM - Perform floating-point minimum or maximum on two /// values. // @@ -868,6 +896,7 @@ // Val, OutChain = MLOAD(BasePtr, Mask, PassThru) // OutChain = MSTORE(Value, BasePtr, Mask) MLOAD, MSTORE, + VP_LOAD, VP_STORE, // Masked gather and scatter - load and store operations for a vector of // random addresses with additional mask operand that prevents memory @@ -879,6 +908,7 @@ // The Index operand can have more vector elements than the other operands // due to type legalization. The extra elements are ignored. MGATHER, MSCATTER, + VP_GATHER, VP_SCATTER, /// This corresponds to the llvm.lifetime.* intrinsics. The first operand /// is the chain and the second operand is the alloca pointer. @@ -916,6 +946,14 @@ VECREDUCE_AND, VECREDUCE_OR, VECREDUCE_XOR, VECREDUCE_SMAX, VECREDUCE_SMIN, VECREDUCE_UMAX, VECREDUCE_UMIN, + VP_REDUCE_FADD, VP_REDUCE_FMUL, + VP_REDUCE_ADD, VP_REDUCE_MUL, + VP_REDUCE_AND, VP_REDUCE_OR, VP_REDUCE_XOR, + VP_REDUCE_SMAX, VP_REDUCE_SMIN, VP_REDUCE_UMAX, VP_REDUCE_UMIN, + + /// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants. + VP_REDUCE_FMAX, VP_REDUCE_FMIN, + /// BUILTIN_OP_END - This must be the last enum value in this list. /// The target-specific pre-isel opcode values start here. BUILTIN_OP_END @@ -1092,6 +1130,20 @@ /// SETCC_INVALID if it is not possible to represent the resultant comparison. CondCode getSetCCAndOperation(CondCode Op1, CondCode Op2, bool isInteger); + /// Return the position of the mask operand of this VP SDNode, + /// or -1 if \p OpCode is not a VP SDNode opcode. + int GetMaskPosVP(unsigned OpCode); + + /// Return the position of the vector length operand of this VP SDNode, + /// or -1 if \p OpCode is not a VP SDNode opcode. + int GetVectorLengthPosVP(unsigned OpCode); + + /// Translate this VP OpCode to an unpredicated instruction OpCode. + unsigned GetFunctionOpCodeForVP(unsigned VPOpCode, bool hasFPExcept); + + /// Translate this non-VP Opcode to its corresponding VP Opcode. + unsigned GetVPForFunctionOpCode(unsigned OpCode); + } // end llvm::ISD namespace } // end llvm namespace diff --git a/llvm/include/llvm/CodeGen/SelectionDAG.h b/llvm/include/llvm/CodeGen/SelectionDAG.h --- a/llvm/include/llvm/CodeGen/SelectionDAG.h +++ b/llvm/include/llvm/CodeGen/SelectionDAG.h @@ -1125,6 +1125,20 @@ SDValue getIndexedStore(SDValue OrigStore, const SDLoc &dl, SDValue Base, SDValue Offset, ISD::MemIndexedMode AM); + /// Returns sum of the base pointer and offset.
+ SDValue getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr, + SDValue Mask, SDValue VLen, EVT MemVT, + MachineMemOperand *MMO, ISD::LoadExtType); + SDValue getStoreVP(SDValue Chain, const SDLoc &dl, SDValue Val, + SDValue Ptr, SDValue Mask, SDValue VLen, EVT MemVT, + MachineMemOperand *MMO, bool IsTruncating = false); + SDValue getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, MachineMemOperand *MMO, + ISD::MemIndexType IndexType); + SDValue getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, MachineMemOperand *MMO, + ISD::MemIndexType IndexType); + /// Returns sum of the base pointer and offset. SDValue getMemBasePlusOffset(SDValue Base, unsigned Offset, const SDLoc &DL); diff --git a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h --- a/llvm/include/llvm/CodeGen/SelectionDAGNodes.h +++ b/llvm/include/llvm/CodeGen/SelectionDAGNodes.h @@ -549,6 +549,7 @@ class LSBaseSDNodeBitfields { friend class LSBaseSDNode; friend class MaskedGatherScatterSDNode; + friend class VPGatherScatterSDNode; uint16_t : NumMemSDNodeBits; @@ -563,6 +564,7 @@ class LoadSDNodeBitfields { friend class LoadSDNode; friend class MaskedLoadSDNode; + friend class VPLoadSDNode; uint16_t : NumLSBaseSDNodeBits; @@ -573,6 +575,7 @@ class StoreSDNodeBitfields { friend class StoreSDNode; friend class MaskedStoreSDNode; + friend class VPStoreSDNode; uint16_t : NumLSBaseSDNodeBits; @@ -721,6 +724,66 @@ } } + /// Test whether this is an Explicit Vector Length node. + bool isVP() const { + switch (NodeType) { + default: + return false; + case ISD::VP_LOAD: + case ISD::VP_STORE: + case ISD::VP_GATHER: + case ISD::VP_SCATTER: + + case ISD::VP_FNEG: + + case ISD::VP_FADD: + case ISD::VP_FMUL: + case ISD::VP_FSUB: + case ISD::VP_FDIV: + case ISD::VP_FREM: + + case ISD::VP_FMA: + + case ISD::VP_ADD: + case ISD::VP_MUL: + case ISD::VP_SUB: + case ISD::VP_SRA: + case ISD::VP_SRL: + case ISD::VP_SHL: + case ISD::VP_UDIV: + case ISD::VP_SDIV: + case ISD::VP_UREM: + case ISD::VP_SREM: + + case ISD::VP_EXPAND: + case ISD::VP_COMPRESS: + case ISD::VP_VSHIFT: + case ISD::VP_SETCC: + case ISD::VP_COMPOSE: + + case ISD::VP_AND: + case ISD::VP_XOR: + case ISD::VP_OR: + + case ISD::VP_REDUCE_ADD: + case ISD::VP_REDUCE_SMIN: + case ISD::VP_REDUCE_SMAX: + case ISD::VP_REDUCE_UMIN: + case ISD::VP_REDUCE_UMAX: + + case ISD::VP_REDUCE_MUL: + case ISD::VP_REDUCE_AND: + case ISD::VP_REDUCE_OR: + case ISD::VP_REDUCE_FADD: + case ISD::VP_REDUCE_FMUL: + case ISD::VP_REDUCE_FMIN: + case ISD::VP_REDUCE_FMAX: + + return true; + } + } + + /// Test if this node has a post-isel opcode, directly /// corresponding to a MachineInstr opcode. 
bool isMachineOpcode() const { return NodeType < 0; } @@ -1426,6 +1489,10 @@ N->getOpcode() == ISD::MSTORE || N->getOpcode() == ISD::MGATHER || N->getOpcode() == ISD::MSCATTER || + N->getOpcode() == ISD::VP_LOAD || + N->getOpcode() == ISD::VP_STORE || + N->getOpcode() == ISD::VP_GATHER || + N->getOpcode() == ISD::VP_SCATTER || N->isMemIntrinsic() || N->isTargetMemoryOpcode(); } @@ -2287,6 +2354,96 @@ } }; +/// This base class is used to represent VP_LOAD and VP_STORE nodes +class VPLoadStoreSDNode : public MemSDNode { +public: + friend class SelectionDAG; + + VPLoadStoreSDNode(ISD::NodeType NodeTy, unsigned Order, + const DebugLoc &dl, SDVTList VTs, EVT MemVT, + MachineMemOperand *MMO) + : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {} + + // VPLoadSDNode (Chain, ptr, mask, VLen) + // VPStoreSDNode (Chain, data, ptr, mask, VLen) + // Mask is a vector of i1 elements, Vlen is i32 + const SDValue &getBasePtr() const { + return getOperand(getOpcode() == ISD::VP_LOAD ? 1 : 2); + } + const SDValue &getMask() const { + return getOperand(getOpcode() == ISD::VP_LOAD ? 2 : 3); + } + const SDValue &getVectorLength() const { + return getOperand(getOpcode() == ISD::VP_LOAD ? 3 : 4); + } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_LOAD || + N->getOpcode() == ISD::VP_STORE; + } +}; + +/// This class is used to represent a VP_LOAD node +class VPLoadSDNode : public VPLoadStoreSDNode { +public: + friend class SelectionDAG; + + VPLoadSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + ISD::LoadExtType ETy, EVT MemVT, + MachineMemOperand *MMO) + : VPLoadStoreSDNode(ISD::VP_LOAD, Order, dl, VTs, MemVT, MMO) { + LoadSDNodeBits.ExtTy = ETy; + LoadSDNodeBits.IsExpanding = false; + } + + ISD::LoadExtType getExtensionType() const { + return static_cast<ISD::LoadExtType>(LoadSDNodeBits.ExtTy); + } + + const SDValue &getBasePtr() const { return getOperand(1); } + const SDValue &getMask() const { return getOperand(2); } + const SDValue &getVectorLength() const { return getOperand(3); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_LOAD; + } + bool isExpandingLoad() const { return LoadSDNodeBits.IsExpanding; } +}; + +/// This class is used to represent a VP_STORE node +class VPStoreSDNode : public VPLoadStoreSDNode { +public: + friend class SelectionDAG; + + VPStoreSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + bool isTrunc, EVT MemVT, + MachineMemOperand *MMO) + : VPLoadStoreSDNode(ISD::VP_STORE, Order, dl, VTs, MemVT, MMO) { + StoreSDNodeBits.IsTruncating = isTrunc; + StoreSDNodeBits.IsCompressing = false; + } + + /// Return true if the op does a truncation before store. + /// For integers this is the same as doing a TRUNCATE and storing the result. + /// For floats, it is the same as doing an FP_ROUND and storing the result. + bool isTruncatingStore() const { return StoreSDNodeBits.IsTruncating; } + + /// Returns true if the op does a compression to the vector before storing. + /// The node contiguously stores the active elements (integers or floats) + /// in src (those with their respective bit set in writemask k) to unaligned + /// memory at base_addr.
+ bool isCompressingStore() const { return StoreSDNodeBits.IsCompressing; } + + const SDValue &getValue() const { return getOperand(1); } + const SDValue &getBasePtr() const { return getOperand(2); } + const SDValue &getMask() const { return getOperand(3); } + const SDValue &getVectorLength() const { return getOperand(4); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_STORE; + } +}; + /// This base class is used to represent MLOAD and MSTORE nodes class MaskedLoadStoreSDNode : public MemSDNode { public: @@ -2374,6 +2531,85 @@ } }; +/// This is a base class used to represent +/// VP_GATHER and VP_SCATTER nodes +/// +class VPGatherScatterSDNode : public MemSDNode { +public: + friend class SelectionDAG; + + VPGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, + const DebugLoc &dl, SDVTList VTs, EVT MemVT, + MachineMemOperand *MMO, ISD::MemIndexType IndexType) + : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) { + LSBaseSDNodeBits.AddressingMode = IndexType; + assert(getIndexType() == IndexType && "Value truncated"); + } + + /// How is Index applied to BasePtr when computing addresses. + ISD::MemIndexType getIndexType() const { + return static_cast(LSBaseSDNodeBits.AddressingMode); + } + bool isIndexScaled() const { + return (getIndexType() == ISD::SIGNED_SCALED) || + (getIndexType() == ISD::UNSIGNED_SCALED); + } + bool isIndexSigned() const { + return (getIndexType() == ISD::SIGNED_SCALED) || + (getIndexType() == ISD::SIGNED_UNSCALED); + } + + // In the both nodes address is Op1, mask is Op2: + // VPGatherSDNode (Chain, base, index, scale, mask, vlen) + // VPScatterSDNode (Chain, value, base, index, scale, mask, vlen) + // Mask is a vector of i1 elements + const SDValue &getBasePtr() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 1 : 2); } + const SDValue &getIndex() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 2 : 3); } + const SDValue &getScale() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 3 : 4); } + const SDValue &getMask() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 4 : 5); } + const SDValue &getVectorLength() const { return getOperand((getOpcode() == ISD::VP_GATHER) ? 
5 : 6); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_GATHER || + N->getOpcode() == ISD::VP_SCATTER; + } +}; + +/// This class is used to represent an VP_GATHER node +/// +class VPGatherSDNode : public VPGatherScatterSDNode { +public: + friend class SelectionDAG; + + VPGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemVT, MachineMemOperand *MMO, + ISD::MemIndexType IndexType) + : VPGatherScatterSDNode(ISD::VP_GATHER, Order, dl, VTs, MemVT, MMO, IndexType) {} + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_GATHER; + } +}; + +/// This class is used to represent an VP_SCATTER node +/// +class VPScatterSDNode : public VPGatherScatterSDNode { +public: + friend class SelectionDAG; + + VPScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemVT, MachineMemOperand *MMO, + ISD::MemIndexType IndexType) + : VPGatherScatterSDNode(ISD::VP_SCATTER, Order, dl, VTs, MemVT, MMO, IndexType) {} + + const SDValue &getValue() const { return getOperand(1); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::VP_SCATTER; + } +}; + + /// This is a base class used to represent /// MGATHER and MSCATTER nodes /// diff --git a/llvm/include/llvm/IR/Attributes.td b/llvm/include/llvm/IR/Attributes.td --- a/llvm/include/llvm/IR/Attributes.td +++ b/llvm/include/llvm/IR/Attributes.td @@ -139,6 +139,15 @@ /// Parameter is required to be a trivial constant. def ImmArg : EnumAttr<"immarg">; +/// Return value that is equal to this argument on enabled lanes (mask). +def Passthru : EnumAttr<"passthru">; + +/// Mask argument that applies to this function. +def Mask : EnumAttr<"mask">; + +/// Dynamic Vector Length argument of this function. +def VectorLength : EnumAttr<"vlen">; + /// Function can return twice. 
def ReturnsTwice : EnumAttr<"returns_twice">; diff --git a/llvm/include/llvm/IR/IRBuilder.h b/llvm/include/llvm/IR/IRBuilder.h --- a/llvm/include/llvm/IR/IRBuilder.h +++ b/llvm/include/llvm/IR/IRBuilder.h @@ -29,6 +29,7 @@ #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstrTypes.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" @@ -97,8 +98,8 @@ FastMathFlags FMF; bool IsFPConstrained; - ConstrainedFPIntrinsic::ExceptionBehavior DefaultConstrainedExcept; - ConstrainedFPIntrinsic::RoundingMode DefaultConstrainedRounding; + ExceptionBehavior DefaultConstrainedExcept; + RoundingMode DefaultConstrainedRounding; ArrayRef DefaultOperandBundles; @@ -106,8 +107,8 @@ IRBuilderBase(LLVMContext &context, MDNode *FPMathTag = nullptr, ArrayRef OpBundles = None) : Context(context), DefaultFPMathTag(FPMathTag), IsFPConstrained(false), - DefaultConstrainedExcept(ConstrainedFPIntrinsic::ebStrict), - DefaultConstrainedRounding(ConstrainedFPIntrinsic::rmDynamic), + DefaultConstrainedExcept(ExceptionBehavior::ebStrict), + DefaultConstrainedRounding(RoundingMode::rmDynamic), DefaultOperandBundles(OpBundles) { ClearInsertionPoint(); } @@ -235,23 +236,23 @@ /// Set the exception handling to be used with constrained floating point void setDefaultConstrainedExcept( - ConstrainedFPIntrinsic::ExceptionBehavior NewExcept) { + ExceptionBehavior NewExcept) { DefaultConstrainedExcept = NewExcept; } /// Set the rounding mode handling to be used with constrained floating point void setDefaultConstrainedRounding( - ConstrainedFPIntrinsic::RoundingMode NewRounding) { + RoundingMode NewRounding) { DefaultConstrainedRounding = NewRounding; } /// Get the exception handling used with constrained floating point - ConstrainedFPIntrinsic::ExceptionBehavior getDefaultConstrainedExcept() { + ExceptionBehavior getDefaultConstrainedExcept() { return DefaultConstrainedExcept; } /// Get the rounding mode handling used with constrained floating point - ConstrainedFPIntrinsic::RoundingMode getDefaultConstrainedRounding() { + RoundingMode getDefaultConstrainedRounding() { return DefaultConstrainedRounding; } @@ -1098,35 +1099,25 @@ } Value *getConstrainedFPRounding( - Optional Rounding) { - ConstrainedFPIntrinsic::RoundingMode UseRounding = + Optional Rounding) { + RoundingMode UseRounding = DefaultConstrainedRounding; if (Rounding.hasValue()) UseRounding = Rounding.getValue(); - Optional RoundingStr = - ConstrainedFPIntrinsic::RoundingModeToStr(UseRounding); - assert(RoundingStr.hasValue() && "Garbage strict rounding mode!"); - auto *RoundingMDS = MDString::get(Context, RoundingStr.getValue()); - - return MetadataAsValue::get(Context, RoundingMDS); + return GetConstrainedFPRounding(Context, UseRounding); } Value *getConstrainedFPExcept( - Optional Except) { - ConstrainedFPIntrinsic::ExceptionBehavior UseExcept = + Optional Except) { + ExceptionBehavior UseExcept = DefaultConstrainedExcept; if (Except.hasValue()) UseExcept = Except.getValue(); - Optional ExceptStr = - ConstrainedFPIntrinsic::ExceptionBehaviorToStr(UseExcept); - assert(ExceptStr.hasValue() && "Garbage strict exception behavior!"); - auto *ExceptMDS = MDString::get(Context, ExceptStr.getValue()); - - return MetadataAsValue::get(Context, ExceptMDS); + return GetConstrainedFPExcept(Context, UseExcept); } public: @@ -1483,8 +1474,8 @@ CallInst *CreateConstrainedFPBinOp( Intrinsic::ID ID, Value *L, Value *R, Instruction *FMFSource = nullptr, const 
Twine &Name = "", MDNode *FPMathTag = nullptr, - Optional Rounding = None, - Optional Except = None) { + Optional Rounding = None, + Optional Except = None) { Value *RoundingV = getConstrainedFPRounding(Rounding); Value *ExceptV = getConstrainedFPExcept(Except); @@ -2078,8 +2069,8 @@ Intrinsic::ID ID, Value *V, Type *DestTy, Instruction *FMFSource = nullptr, const Twine &Name = "", MDNode *FPMathTag = nullptr, - Optional Rounding = None, - Optional Except = None) { + Optional Rounding = None, + Optional Except = None) { Value *ExceptV = getConstrainedFPExcept(Except); FastMathFlags UseFMF = FMF; diff --git a/llvm/include/llvm/IR/IntrinsicInst.h b/llvm/include/llvm/IR/IntrinsicInst.h --- a/llvm/include/llvm/IR/IntrinsicInst.h +++ b/llvm/include/llvm/IR/IntrinsicInst.h @@ -205,50 +205,258 @@ /// @} }; - /// This is the common base class for constrained floating point intrinsics. - class ConstrainedFPIntrinsic : public IntrinsicInst { + enum class RoundingMode : uint8_t { + rmInvalid, + rmDynamic, + rmToNearest, + rmDownward, + rmUpward, + rmTowardZero + }; + + enum class ExceptionBehavior : uint8_t { + ebInvalid, + ebIgnore, + ebMayTrap, + ebStrict + }; + + /// Returns a valid RoundingMode enumerator when given a string + /// that is valid as input in constrained intrinsic rounding mode + /// metadata. + Optional StrToRoundingMode(StringRef); + + /// For any RoundingMode enumerator, returns a string valid as input in + /// constrained intrinsic rounding mode metadata. + Optional RoundingModeToStr(RoundingMode); + + /// Returns a valid ExceptionBehavior enumerator when given a string + /// valid as input in constrained intrinsic exception behavior metadata. + Optional StrToExceptionBehavior(StringRef); + + /// For any ExceptionBehavior enumerator, returns a string valid as + /// input in constrained intrinsic exception behavior metadata. + Optional ExceptionBehaviorToStr(ExceptionBehavior); + + /// Return the IR Value representation of any ExceptionBehavior. + Value* + GetConstrainedFPExcept(LLVMContext&, ExceptionBehavior); + + /// Return the IR Value representation of any RoundingMode. + Value* + GetConstrainedFPRounding(LLVMContext&, RoundingMode); + + class VPIntrinsic : public IntrinsicInst { public: - /// Specifies the rounding mode to be assumed. This is only used when - /// when constrained floating point is enabled. See the LLVM Language - /// Reference Manual for details. - enum RoundingMode : uint8_t { - rmDynamic, ///< This corresponds to "fpround.dynamic". - rmToNearest, ///< This corresponds to "fpround.tonearest". - rmDownward, ///< This corresponds to "fpround.downward". - rmUpward, ///< This corresponds to "fpround.upward". - rmTowardZero ///< This corresponds to "fpround.tozero". + enum class VPTypeToken : int8_t { + Returned = 0, // return type token + Scalar = 1, // scalar operand type + Vector = 2, // vectorized operand type + Mask = 3 // vector mask type }; - /// Specifies the required exception behavior. This is only used when - /// when constrained floating point is used. See the LLVM Language - /// Reference Manual for details. - enum ExceptionBehavior : uint8_t { - ebIgnore, ///< This corresponds to "fpexcept.ignore". - ebMayTrap, ///< This corresponds to "fpexcept.maytrap". - ebStrict ///< This corresponds to "fpexcept.strict". - }; + using TypeTokenVec = SmallVector; + using ShortTypeVec = SmallVector; + + // Type tokens required to instantiate this intrinsic. 
+ static TypeTokenVec GetTypeTokens(Intrinsic::ID); + + // whether the intrinsic has a rounding mode parameter (regardless of setting). + static bool hasRoundingModeParam(Intrinsic::ID VPID); + // whether the intrinsic has a exception behavior parameter (regardless of setting). + static bool hasExceptionBehaviorParam(Intrinsic::ID VPID); + static Optional getMaskParamPos(Intrinsic::ID IntrinsicID); + static Optional getVectorLengthParamPos(Intrinsic::ID IntrinsicID); + static Optional getExceptionBehaviorParamPos(Intrinsic::ID IntrinsicID); + static Optional getRoundingModeParamPos(Intrinsic::ID IntrinsicID); + // the llvm.vp.* intrinsic for this llvm.experimental.constrained.* intrinsic + static Intrinsic::ID getForConstrainedIntrinsic(Intrinsic::ID IntrinsicID); + static Intrinsic::ID getForOpcode(unsigned OC); + + // Generate the disambiguating type vec for this VP Intrinsic + static VPIntrinsic::ShortTypeVec + EncodeTypeTokens(VPIntrinsic::TypeTokenVec TTVec, Type * VecRetTy, Type & VectorTy, Type & ScalarTy); + + // available for all VP intrinsics + Value* getMask() const; + Value* getVectorLength() const; bool isUnaryOp() const; + bool isBinaryOp() const; bool isTernaryOp() const; + + // compare intrinsic + bool isCompareOp() const { + switch (getIntrinsicID()) { + default: + return false; + case Intrinsic::vp_icmp: + case Intrinsic::vp_fcmp: + return true; + } + } + CmpInst::Predicate getCmpPredicate() const; + + // Contrained fp-math + // whether this is an fp op with non-standard rounding or exception behavior. + bool isConstrainedOp() const; + + // the specified rounding mode. Optional getRoundingMode() const; + // the specified exception behavior. Optional getExceptionBehavior() const; - /// Returns a valid RoundingMode enumerator when given a string - /// that is valid as input in constrained intrinsic rounding mode - /// metadata. 
- static Optional StrToRoundingMode(StringRef); + // llvm.vp.reduction.* + bool isReductionOp() const; + + // Methods for support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + switch (I->getIntrinsicID()) { + default: + return false; + + // general cmp + case Intrinsic::vp_icmp: + case Intrinsic::vp_fcmp: + + // int arith + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + + // memory + case Intrinsic::vp_load: + case Intrinsic::vp_store: + case Intrinsic::vp_gather: + case Intrinsic::vp_scatter: + + // shuffle + case Intrinsic::vp_select: + case Intrinsic::vp_compose: + case Intrinsic::vp_compress: + case Intrinsic::vp_expand: + case Intrinsic::vp_vshift: + + // fp arith + case Intrinsic::vp_fneg: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + + case Intrinsic::vp_fma: + case Intrinsic::vp_ceil: + case Intrinsic::vp_cos: + case Intrinsic::vp_exp2: + case Intrinsic::vp_exp: + case Intrinsic::vp_floor: + case Intrinsic::vp_log10: + case Intrinsic::vp_log2: + case Intrinsic::vp_log: + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + case Intrinsic::vp_nearbyint: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_rint: + case Intrinsic::vp_round: + case Intrinsic::vp_sin: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_trunc: + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptosi: + case Intrinsic::vp_fpext: + case Intrinsic::vp_fptrunc: + + case Intrinsic::vp_lround: + case Intrinsic::vp_llround: + + // reductions + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + case Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + return true; + } + } + static bool classof(const Value *V) { + return isa(V) && classof(cast(V)); + } + + // Equivalent non-predicated opcode + unsigned getFunctionalOpcode() const { + if (isConstrainedOp()) { + return Instruction::Call; // TODO pass as constrained op + } + + switch (getIntrinsicID()) { + default: return Instruction::Call; + + case Intrinsic::vp_icmp: return Instruction::ICmp; + case Intrinsic::vp_fcmp: return Instruction::FCmp; + + case Intrinsic::vp_and: return Instruction::And; + case Intrinsic::vp_or: return Instruction::Or; + case Intrinsic::vp_xor: return Instruction::Xor; + case Intrinsic::vp_ashr: return Instruction::AShr; + case Intrinsic::vp_lshr: return Instruction::LShr; + case Intrinsic::vp_shl: return Instruction::Shl; + + case Intrinsic::vp_select: return Instruction::Select; + + case Intrinsic::vp_load: return Instruction::Load; + case Intrinsic::vp_store: return Instruction::Store; - /// For any RoundingMode enumerator, returns a string valid as input in - /// constrained intrinsic rounding mode metadata. 
- static Optional RoundingModeToStr(RoundingMode); + case Intrinsic::vp_fneg: return Instruction::FNeg; - /// Returns a valid ExceptionBehavior enumerator when given a string - /// valid as input in constrained intrinsic exception behavior metadata. - static Optional StrToExceptionBehavior(StringRef); + case Intrinsic::vp_fadd: return Instruction::FAdd; + case Intrinsic::vp_fsub: return Instruction::FSub; + case Intrinsic::vp_fmul: return Instruction::FMul; + case Intrinsic::vp_fdiv: return Instruction::FDiv; + case Intrinsic::vp_frem: return Instruction::FRem; - /// For any ExceptionBehavior enumerator, returns a string valid as - /// input in constrained intrinsic exception behavior metadata. - static Optional ExceptionBehaviorToStr(ExceptionBehavior); + case Intrinsic::vp_add: return Instruction::Add; + case Intrinsic::vp_sub: return Instruction::Sub; + case Intrinsic::vp_mul: return Instruction::Mul; + case Intrinsic::vp_udiv: return Instruction::UDiv; + case Intrinsic::vp_sdiv: return Instruction::SDiv; + case Intrinsic::vp_urem: return Instruction::URem; + case Intrinsic::vp_srem: return Instruction::SRem; + } + } + }; + + /// This is the common base class for constrained floating point intrinsics. + class ConstrainedFPIntrinsic : public IntrinsicInst { + public: + + bool isUnaryOp() const; + bool isTernaryOp() const; + Optional getRoundingMode() const; + Optional getExceptionBehavior() const; // Methods for support type inquiry through isa, cast, and dyn_cast: static bool classof(const IntrinsicInst *I) { diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -98,6 +98,25 @@ int ArgNo = argNo; } +// VectorLength - The specified argument is the Dynamic Vector Length of the +// operation. +class VectorLength : IntrinsicProperty { + int ArgNo = argNo; +} + +// Mask - The specified argument contains the per-lane mask of this +// intrinsic. Inputs on masked-out lanes must not affect the result of this +// intrinsic (except for the Passthru argument). +class Mask : IntrinsicProperty { + int ArgNo = argNo; +} +// Passthru - The specified argument contains the per-lane return value +// for this vector intrinsic where the mask is false. 
+// (requires the Mask attribute in the same function) +class Passthru : IntrinsicProperty { + int ArgNo = argNo; +} + def IntrNoReturn : IntrinsicProperty; def IntrWillReturn : IntrinsicProperty; @@ -1099,8 +1118,477 @@ def int_ptrmask: Intrinsic<[llvm_anyptr_ty], [llvm_anyptr_ty, llvm_anyint_ty], [IntrNoMem, IntrSpeculatable, IntrWillReturn]>; +//===---------------- Vector Predication Intrinsics --------------===// + +// Memory Intrinsics +def int_vp_store : Intrinsic<[], + [ llvm_anyvector_ty, + LLVMAnyPointerType>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ NoCapture<1>, IntrArgMemOnly, IntrWillReturn, Mask<2>, VectorLength<3> ]>; + +def int_vp_load : Intrinsic<[ llvm_anyvector_ty], + [ LLVMAnyPointerType>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ NoCapture<0>, IntrReadMem, IntrWillReturn, IntrArgMemOnly, Mask<1>, VectorLength<2> ]>; + +def int_vp_gather: Intrinsic<[ llvm_anyvector_ty], + [ LLVMVectorOfAnyPointersToElt<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrReadMem, IntrWillReturn, IntrArgMemOnly, Mask<1>, VectorLength<2> ]>; + +def int_vp_scatter: Intrinsic<[], + [ llvm_anyvector_ty, + LLVMVectorOfAnyPointersToElt<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrArgMemOnly, IntrWillReturn, Mask<2>, VectorLength<3> ]>; +// TODO allow IntrNoCapture for vectors of pointers + +// Reductions +let IntrProperties = [IntrNoMem, IntrWillReturn, Mask<1>, VectorLength<2>] in { + def int_vp_reduce_add : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_mul : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_and : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_or : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_xor : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_smax : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_smin : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_umax : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_umin : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_fmax : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_fmin : Intrinsic<[LLVMVectorElementType<0>], + [llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; +} + +let IntrProperties = [IntrNoMem, IntrWillReturn, Mask<2>, VectorLength<3>] in { + def int_vp_reduce_fadd : Intrinsic<[LLVMVectorElementType<0>], + [LLVMVectorElementType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_reduce_fmul : Intrinsic<[LLVMVectorElementType<0>], + [LLVMVectorElementType<0>, + llvm_anyvector_ty, + 
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; +} + +// Binary operators +let IntrProperties = [IntrNoMem, IntrWillReturn, Mask<2>, VectorLength<3>] in { + def int_vp_add : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_sub : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_mul : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_sdiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_udiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_srem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_urem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +// Logical operators + def int_vp_ashr : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_lshr : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_shl : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_or : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_and : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_xor : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +// Comparison +// TODO add signalling fcmp +// The last argument is the comparison predicate + def int_vp_icmp : Intrinsic<[ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ], + [ llvm_anyvector_ty, + LLVMMatchType<0>, + llvm_i8_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ], + [ IntrNoMem, Mask<3>, VectorLength<4>, ImmArg<2> ]>; + + def int_vp_fcmp : Intrinsic<[ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty> ], + [ llvm_anyvector_ty, + LLVMMatchType<0>, + llvm_i8_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ], + [ IntrNoMem, Mask<3>, VectorLength<4>, ImmArg<2> ]>; +} + + +// Shuffle +def int_vp_vshift: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_i32_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, IntrWillReturn, Mask<2>, VectorLength<3> ]>; + +def int_vp_expand: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, IntrWillReturn, Mask<1>, VectorLength<2> ]>; + +def int_vp_compress: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, IntrWillReturn, VectorLength<2> ]>; + +// Select +def int_vp_select : Intrinsic<[ llvm_anyvector_ty ], + [ 
LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty], + [ IntrNoMem, IntrWillReturn, Passthru<2>, Mask<0>, VectorLength<3> ]>; + +// Compose +def int_vp_compose : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty, + llvm_i32_ty], + [ IntrNoMem, IntrWillReturn, VectorLength<3> ]>; + + + +// VP fp rounding and truncation +let IntrProperties = [IntrInaccessibleMemOnly, IntrWillReturn, Mask<2>, VectorLength<3> ] in { + + def int_vp_fptosi : Intrinsic<[ llvm_anyint_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + + def int_vp_fptoui : Intrinsic<[ llvm_anyint_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + + def int_vp_fpext : Intrinsic<[ llvm_anyfloat_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + + def int_vp_lround : Intrinsic<[ llvm_anyint_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_llround : Intrinsic<[ llvm_anyint_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; +} + +let IntrProperties = [IntrInaccessibleMemOnly, IntrWillReturn, Mask<3>, VectorLength<4> ] in { + def int_vp_fptrunc : Intrinsic<[ llvm_anyfloat_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +} + +// VP single argument constrained intrinsics. +let IntrProperties = [IntrInaccessibleMemOnly, IntrWillReturn, Mask<3>, VectorLength<4> ] in { + // These intrinsics are sensitive to the rounding mode so we need constrained + // versions of each of them. When strict rounding and exception control are + // not required the non-constrained versions of these intrinsics should be + // used. 
+ def int_vp_sqrt : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_sin : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_cos : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_log : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_log10: Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_log2 : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_exp : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_exp2 : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_rint : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_nearbyint : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_lrint : Intrinsic<[ llvm_anyint_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_llrint : Intrinsic<[ llvm_anyint_ty ], + [ llvm_anyfloat_ty, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_ceil : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_floor : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_round : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_trunc : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; +} + + +// VP two argument constrained intrinsics. +let IntrProperties = [IntrInaccessibleMemOnly, IntrWillReturn, Mask<4>, VectorLength<5> ] in { + // These intrinsics are sensitive to the rounding mode so we need constrained + // versions of each of them. When strict rounding and exception control are + // not required the non-constrained versions of these intrinsics should be + // used. 
+ def int_vp_powi : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + llvm_i32_ty, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_pow : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_maxnum : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_vp_minnum : Intrinsic<[ llvm_anyfloat_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +} + + +// VP standard fp-math intrinsics. +def int_vp_fneg : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrInaccessibleMemOnly, IntrWillReturn, Mask<2>, VectorLength<3> ]>; + +let IntrProperties = [IntrInaccessibleMemOnly, IntrWillReturn, Mask<4>, VectorLength<5> ] in { + // These intrinsics are sensitive to the rounding mode so we need constrained + // versions of each of them. When strict rounding and exception control are + // not required the non-constrained versions of these intrinsics should be + // used. + def int_vp_fadd : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_fsub : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_fmul : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_fdiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; + def int_vp_frem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ]>; +} + +def int_vp_fma : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_metadata_ty, + llvm_metadata_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty ], + [ IntrInaccessibleMemOnly, IntrWillReturn, Mask<5>, VectorLength<6> ]>; + + + + //===-------------------------- Masked Intrinsics -------------------------===// -// +// TODO poised for deprecation (to be superseded by llvm.vp.* intrinsics) def int_masked_store : Intrinsic<[], [llvm_anyvector_ty, LLVMAnyPointerType>, llvm_i32_ty, @@ -1200,6 +1688,7 @@ [ IntrArgMemOnly, IntrWillReturn, NoCapture<0>, WriteOnly<0>, ImmArg<3> ]>; //===------------------------ Reduction Intrinsics ------------------------===// +// TODO poised for deprecation (to be superseded by llvm.vp.*. 
intrinsics) // let IntrProperties = [IntrNoMem, IntrWillReturn] in { def int_experimental_vector_reduce_v2_fadd : Intrinsic<[llvm_anyfloat_ty], diff --git a/llvm/include/llvm/IR/MatcherCast.h b/llvm/include/llvm/IR/MatcherCast.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/MatcherCast.h @@ -0,0 +1,65 @@ +#ifndef LLVM_IR_MATCHERCAST_H +#define LLVM_IR_MATCHERCAST_H + +//===- MatcherCast.h - Match on the LLVM IR --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Parameterized class hierachy for templatized pattern matching. +// +//===----------------------------------------------------------------------===// + + +namespace llvm { +namespace PatternMatch { + + +// type modification +template +struct MatcherCast { }; + +// whether the Value \p Obj behaves like a \p Class. +template +bool match_isa(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return isa(Obj); +} + +template +auto match_cast(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return cast(Obj); +} +template +auto match_dyn_cast(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return dyn_cast(Obj); +} + +template +auto match_cast(Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return cast(Obj); +} +template +auto match_dyn_cast(Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return dyn_cast(Obj); +} + + +} // namespace PatternMatch + +} // namespace llvm + +#endif // LLVM_IR_MATCHERCAST_H + diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h --- a/llvm/include/llvm/IR/PatternMatch.h +++ b/llvm/include/llvm/IR/PatternMatch.h @@ -39,22 +39,81 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" +#include "llvm/IR/MatcherCast.h" + #include + namespace llvm { namespace PatternMatch { +// Use verbatim types in default (empty) context. +struct EmptyContext { + EmptyContext() {} + + EmptyContext(const Value *) {} + + EmptyContext(const EmptyContext & E) {} + + // reset this match context to be rooted at \p V + void reset(Value * V) {} + + // accept a match where \p Val is in a non-leaf position in a match pattern + bool acceptInnerNode(const Value * Val) const { return true; } + + // accept a match where \p Val is bound to a free variable. + bool acceptBoundNode(const Value * Val) const { return true; } + + // whether this context is compatiable with \p E. + bool acceptContext(EmptyContext E) const { return true; } + + // merge the context \p E into this context and return whether the resulting context is valid. + bool mergeContext(EmptyContext E) { return true; } + + // reset this context to \p Val. 
+ template bool reset_match(Val *V, const Pattern &P) { + reset(V); + return const_cast(P).match_context(V, *this); + } + + // match in the current context + template bool try_match(Val *V, const Pattern &P) { + return const_cast(P).match_context(V, *this); + } +}; + +template +struct MatcherCast { using ActualCastType = DestClass; }; + + + + + + +// match without (== empty) context template bool match(Val *V, const Pattern &P) { - return const_cast(P).match(V); + EmptyContext ECtx; + return const_cast(P).match_context(V, ECtx); +} + +// match pattern in a given context +template bool match(Val *V, const Pattern &P, MatchContext & MContext) { + return const_cast(P).match_context(V, MContext); } + + template struct OneUse_match { SubPattern_t SubPattern; OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {} template bool match(OpTy *V) { - return V->hasOneUse() && SubPattern.match(V); + EmptyContext EContext; return match_context(V, EContext); + } + + template bool match_context(OpTy *V, MatchContext & MContext) { + return V->hasOneUse() && SubPattern.match_context(V, MContext); } }; @@ -63,7 +122,11 @@ } template struct class_match { - template bool match(ITy *V) { return isa(V); } + template bool match(ITy *V) { + EmptyContext EContext; return match_context(V, EContext); + } + template + bool match_context(ITy *V, MatchContext & MContext) { return match_isa(V); } }; /// Match an arbitrary value and ignore it. @@ -114,11 +177,17 @@ match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { + MatchContext SubContext; + + if (L.match_context(V, SubContext) && MContext.acceptContext(SubContext)) { + MContext.mergeContext(SubContext); return true; - if (R.match(V)) + } + if (R.match_context(V, MContext)) { return true; + } return false; } }; @@ -129,9 +198,10 @@ match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) - if (R.match(V)) + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { + if (L.match_context(V, MContext)) + if (R.match_context(V, MContext)) return true; return false; } @@ -154,7 +224,8 @@ apint_match(const APInt *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValue(); return true; @@ -174,7 +245,8 @@ struct apfloat_match { const APFloat *&Res; apfloat_match(const APFloat *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValueAPF(); return true; @@ -198,7 +270,8 @@ inline apfloat_match m_APFloat(const APFloat *&Res) { return Res; } template struct constantint_match { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) { const APInt &CIV = CI->getValue(); if (Val >= 0) @@ -221,7 +294,8 @@ /// satisfy a specified predicate. 
/// For vector constants, undefined elements are ignored. template struct cst_pred_ty : public Predicate { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) return this->isValue(CI->getValue()); if (V->getType()->isVectorTy()) { @@ -258,7 +332,8 @@ api_pred_ty(const APInt *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) if (this->isValue(CI->getValue())) { Res = &CI->getValue(); @@ -280,7 +355,8 @@ /// constants that satisfy a specified predicate. /// For vector constants, undefined elements are ignored. template struct cstfp_pred_ty : public Predicate { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CF = dyn_cast(V)) return this->isValue(CF->getValueAPF()); if (V->getType()->isVectorTy()) { @@ -393,7 +469,8 @@ } struct is_zero { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { auto *C = dyn_cast(V); return C && (C->isNullValue() || cst_pred_ty().match(C)); } @@ -541,8 +618,11 @@ bind_ty(Class *&V) : VR(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CV = dyn_cast(V)) { + if (!MContext.acceptBoundNode(V)) return false; + VR = CV; return true; } @@ -580,7 +660,8 @@ specificval_ty(const Value *V) : Val(V) {} - template bool match(ITy *V) { return V == Val; } + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { return V == Val; } }; /// Match if we have a specific specified value. @@ -593,7 +674,8 @@ deferredval_ty(Class *const &V) : Val(V) {} - template bool match(ITy *const V) { return V == Val; } + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *const V, MatchContext & MContext) { return V == Val; } }; /// A commutative-friendly version of m_Specific(). 
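// --------------------------------------------------------------------------
// Editor's illustrative sketch (not part of this patch): how a transform
// elsewhere in the tree is expected to drive the context-parameterized
// matchers introduced above.  `matchNegatedValue` is a hypothetical caller;
// the matchers are existing PatternMatch names and PredicatedContext is the
// context type defined in llvm/IR/PredicatedInst.h.
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/PredicatedInst.h"
using namespace llvm;
using namespace llvm::PatternMatch;

static Value *matchNegatedValue(Instruction *I) {
  Value *X;
  // Context-free form: identical to the pre-existing API and only matches
  // plain IR instructions (this is the EmptyContext overload above).
  if (match(I, m_FSub(m_NegZeroFP(), m_Value(X))))
    return X;
  // Context-aware form: the same pattern also matches a chain of VP
  // intrinsic calls, but only when every inner node carries the same mask
  // and explicit vector length as the root; PredicatedContext's
  // acceptInnerNode check enforces that.
  PredicatedContext PC(I);
  if (PC.reset_match(I, m_FSub(m_NegZeroFP(), m_Value(X))))
    return X;
  return nullptr;
}
// --------------------------------------------------------------------------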
@@ -609,7 +691,8 @@ specific_fpval(double V) : Val(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CFP = dyn_cast(V)) return CFP->isExactlyValue(Val); if (V->getType()->isVectorTy()) @@ -632,7 +715,8 @@ bind_const_intval_ty(uint64_t &V) : VR(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CV = dyn_cast(V)) if (CV->getValue().ule(UINT64_MAX)) { VR = CV->getZExtValue(); @@ -649,7 +733,8 @@ specific_intval(APInt V) : Val(std::move(V)) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { const auto *CI = dyn_cast(V); if (!CI && V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) @@ -679,7 +764,8 @@ specific_bbval(BasicBlock *Val) : Val(Val) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EC; return match_context(V, EC); } + template bool match_context(ITy *V, MatchContext & MContext) { const auto *BB = dyn_cast(V); return BB && BB == Val; } @@ -711,11 +797,16 @@ // The LHS is always matched first. AnyBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto * I = match_dyn_cast(V); + if (!I) return false; + + if (!MContext.acceptInnerNode(I)) return false; + + MatchContext LRContext(MContext); + if (L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext))) return true; return false; } }; @@ -739,12 +830,15 @@ // The LHS is always matched first. 
BinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto * I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + MatchContext LRContext(MContext); + if (!MContext.acceptInnerNode(I)) return false; + if (L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext))) return true; + return false; } if (auto *CE = dyn_cast(V)) return CE->getOpcode() == Opcode && @@ -783,25 +877,26 @@ Op_t X; FNeg_match(const Op_t &Op) : X(Op) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { auto *FPMO = dyn_cast(V); if (!FPMO) return false; - if (FPMO->getOpcode() == Instruction::FNeg) + if (match_cast(V)->getOpcode() == Instruction::FNeg) return X.match(FPMO->getOperand(0)); - if (FPMO->getOpcode() == Instruction::FSub) { + if (match_cast(V)->getOpcode() == Instruction::FSub) { if (FPMO->hasNoSignedZeros()) { // With 'nsz', any zero goes. - if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!cstfp_pred_ty().match_context(FPMO->getOperand(0), MContext)) return false; } else { // Without 'nsz', we need fsub -0.0, X exactly. 
- if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!cstfp_pred_ty().match_context(FPMO->getOperand(0), MContext)) return false; } - return X.match(FPMO->getOperand(1)); + return X.match_context(FPMO->getOperand(1), MContext); } return false; @@ -915,7 +1010,8 @@ OverflowingBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *Op = dyn_cast(V)) { if (Op->getOpcode() != Opcode) return false; @@ -925,7 +1021,7 @@ if (WrapFlags & OverflowingBinaryOperator::NoSignedWrap && !Op->hasNoSignedWrap()) return false; - return L.match(Op->getOperand(0)) && R.match(Op->getOperand(1)); + return L.match_context(Op->getOperand(0), MContext) && R.match_context(Op->getOperand(1), MContext); } return false; } @@ -1007,10 +1103,11 @@ BinOpPred_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return this->isOpType(I->getOpcode()) && L.match(I->getOperand(0)) && - R.match(I->getOperand(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *I = match_dyn_cast(V)) + return this->isOpType(I->getOpcode()) && L.match_context(I->getOperand(0), MContext) && + R.match_context(I->getOperand(1), MContext); if (auto *CE = dyn_cast(V)) return this->isOpType(CE->getOpcode()) && L.match(CE->getOperand(0)) && R.match(CE->getOperand(1)); @@ -1102,9 +1199,10 @@ Exact_match(const SubPattern_t &SP) : SubPattern(SP) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *PEO = dyn_cast(V)) - return PEO->isExact() && SubPattern.match(V); + return PEO->isExact() && SubPattern.match_context(V, MContext); return false; } }; @@ -1129,14 +1227,17 @@ CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) : Predicate(Pred), L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - if ((L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0)))) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *I = match_dyn_cast(V)) { + if (!MContext.acceptInnerNode(I)) return false; + MatchContext LRContext(MContext); + if ((L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) || + (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext)))) { Predicate = I->getPredicate(); return true; } + } return false; } }; @@ -1169,10 +1270,11 @@ OneOps_match(const T0 &Op1) : Op1(Op1) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext); } 
return false; } @@ -1185,10 +1287,12 @@ TwoOps_match(const T0 &Op1, const T1 &Op2) : Op1(Op1), Op2(Op2) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext) && + Op2.match_context(I->getOperand(1), MContext); } return false; } @@ -1204,11 +1308,13 @@ ThreeOps_match(const T0 &Op1, const T1 &Op2, const T2 &Op3) : Op1(Op1), Op2(Op2), Op3(Op3) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && - Op3.match(I->getOperand(2)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext) && + Op2.match_context(I->getOperand(1), MContext) && + Op3.match_context(I->getOperand(2), MContext); } return false; } @@ -1276,9 +1382,10 @@ CastClass_match(const Op_t &OpMatch) : Op(OpMatch) {} - template bool match(OpTy *V) { - if (auto *O = dyn_cast(V)) - return O->getOpcode() == Opcode && Op.match(O->getOperand(0)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto O = match_dyn_cast(V)) + return O->getOpcode() == Opcode && MContext.acceptInnerNode(O) && Op.match_context(O->getOperand(0), MContext); return false; } }; @@ -1380,8 +1487,9 @@ br_match(BasicBlock *&Succ) : Succ(Succ) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *BI = match_dyn_cast(V)) if (BI->isUnconditional()) { Succ = BI->getSuccessor(0); return true; @@ -1401,10 +1509,12 @@ brc_match(const Cond_t &C, const TrueBlock_t &t, const FalseBlock_t &f) : Cond(C), T(t), F(f) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) - if (BI->isConditional() && Cond.match(BI->getCondition())) - return T.match(BI->getSuccessor(0)) && F.match(BI->getSuccessor(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *BI = match_dyn_cast(V)) + if (BI->isConditional() && Cond.match(BI->getCondition())) { + return T.match_context(BI->getSuccessor(0), MContext) && F.match_context(BI->getSuccessor(1), MContext); + } return false; } }; @@ -1436,13 +1546,14 @@ // The LHS is always matched first. MaxMin_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { // Look for "(x pred y) ? x : y" or "(x pred y) ? y : x". 
- auto *SI = dyn_cast(V); - if (!SI) + auto *SI = match_dyn_cast(V); + if (!SI || !MContext.acceptInnerNode(SI)) return false; - auto *Cmp = dyn_cast(SI->getCondition()); - if (!Cmp) + auto *Cmp = match_dyn_cast(SI->getCondition()); + if (!Cmp || !MContext.acceptInnerNode(Cmp)) return false; // At this point we have a select conditioned on a comparison. Check that // it is the values returned by the select that are being compared. @@ -1458,9 +1569,12 @@ // Does "(x pred y) ? x : y" represent the desired max/min operation? if (!Pred_t::match(Pred)) return false; + // It does! Bind the operands. - return (L.match(LHS) && R.match(RHS)) || - (Commutable && L.match(RHS) && R.match(LHS)); + MatchContext LRContext(MContext); + if (L.match_context(LHS, LRContext) && R.match_context(RHS, LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(RHS, MContext) && R.match_context(LHS, MContext))) return true; + return false; } }; @@ -1617,7 +1731,8 @@ UAddWithOverflow_match(const LHS_t &L, const RHS_t &R, const Sum_t &S) : L(L), R(R), S(S) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { Value *ICmpLHS, *ICmpRHS; ICmpInst::Predicate Pred; if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V)) @@ -1670,9 +1785,10 @@ Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { // FIXME: Should likely be switched to use `CallBase`. - if (const auto *CI = dyn_cast(V)) + if (const auto *CI = match_dyn_cast(V)) return Val.match(CI->getArgOperand(OpI)); return false; } @@ -1690,8 +1806,9 @@ IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) {} - template bool match(OpTy *V) { - if (const auto *CI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (const auto *CI = match_dyn_cast(V)) if (const auto *F = CI->getCalledFunction()) return F->getIntrinsicID() == ID; return false; @@ -1901,7 +2018,8 @@ Opnd_t Val; Signum_match(const Opnd_t &V) : Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { unsigned TypeSize = V->getType()->getScalarSizeInBits(); if (TypeSize == 0) return false; diff --git a/llvm/include/llvm/IR/PredicatedInst.h b/llvm/include/llvm/IR/PredicatedInst.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/PredicatedInst.h @@ -0,0 +1,368 @@ +//===-- llvm/PredicatedInst.h - Predication utility subclass --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines various classes for working with predicated instructions. +// Predicated instructions are either regular instructions or calls to +// Vector Predication (VP) intrinsics that have a mask and an explicit +// vector length argument. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_IR_PREDICATEDINST_H +#define LLVM_IR_PREDICATEDINST_H + +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/MatcherCast.h" + +#include + +namespace llvm { + +class BasicBlock; + +class PredicatedInstruction : public User { +public: + // The PredicatedInstruction class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedInstruction() = delete; + ~PredicatedInstruction() = delete; + + void copyIRFlags(const Value * V, bool IncludeWrapFlags) { + cast(this)->copyIRFlags(V, IncludeWrapFlags); + } + + BasicBlock* getParent() { return cast(this)->getParent(); } + const BasicBlock* getParent() const { return cast(this)->getParent(); } + + void *operator new(size_t s) = delete; + + Value* getMask() const { + auto thisVP = dyn_cast(this); + if (!thisVP) return nullptr; + return thisVP->getMask(); + } + + Value* getVectorLength() const { + auto thisVP = dyn_cast(this); + if (!thisVP) return nullptr; + return thisVP->getVectorLength(); + } + + unsigned getOpcode() const { + auto * VPInst = dyn_cast(this); + + if (!VPInst) { + return cast(this)->getOpcode(); + } + + return VPInst->getFunctionalOpcode(); + } + + static bool classof(const Instruction * I) { return isa(I); } + static bool classof(const ConstantExpr * CE) { return false; } + static bool classof(const Value *V) { return isa(V); } +}; + +class PredicatedOperator : public User { +public: + // The PredicatedOperator class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedOperator() = delete; + ~PredicatedOperator() = delete; + + void *operator new(size_t s) = delete; + + /// Return the opcode for this Instruction or ConstantExpr. + unsigned getOpcode() const { + auto * VPInst = dyn_cast(this); + + + // Conceal the fp operation if it has non-default rounding mode or exception behavior + if (VPInst && !VPInst->isConstrainedOp()) { + return VPInst->getFunctionalOpcode(); + } + + if (const Instruction *I = dyn_cast(this)) + return I->getOpcode(); + + return cast(this)->getOpcode(); + } + + Value* getMask() const { + auto thisVP = dyn_cast(this); + if (!thisVP) return nullptr; + return thisVP->getMask(); + } + + Value* getVectorLength() const { + auto thisVP = dyn_cast(this); + if (!thisVP) return nullptr; + return thisVP->getVectorLength(); + } + + void copyIRFlags(const Value * V, bool IncludeWrapFlags = true); + FastMathFlags getFastMathFlags() const { + auto * I = dyn_cast(this); + if (I) return I->getFastMathFlags(); + else return FastMathFlags(); + } + + static bool classof(const Instruction * I) { return isa(I) || isa(I); } + static bool classof(const ConstantExpr * CE) { return isa(CE); } + static bool classof(const Value *V) { return isa(V) || isa(V); } +}; + +class PredicatedBinaryOperator : public PredicatedOperator { +public: + // The PredicatedBinaryOperator class is intended to be used as a utility, and is never itself + // instantiated. 
+ PredicatedBinaryOperator() = delete; + ~PredicatedBinaryOperator() = delete; + + using BinaryOps = Instruction::BinaryOps; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction * I) { + if (isa(I)) return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->isBinaryOp(); + } + static bool classof(const ConstantExpr * CE) { return isa(CE); } + static bool classof(const Value *V) { + auto * I = dyn_cast(V); + if (I && classof(I)) return true; + auto * CE = dyn_cast(V); + return CE && classof(CE); + } + + /// Construct a predicated binary instruction, given the opcode and the two + /// operands. + static Instruction* Create(Module * Mod, + Value *Mask, Value *VectorLen, + Instruction::BinaryOps Opc, + Value *V1, Value *V2, + const Twine &Name, + BasicBlock * InsertAtEnd, + Instruction * InsertBefore); + + static Instruction* Create(Module *Mod, + Value *Mask, Value *VectorLen, + BinaryOps Opc, + Value *V1, Value *V2, + const Twine &Name = Twine(), + Instruction *InsertBefore = nullptr) { + return Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, nullptr, InsertBefore); + } + + static Instruction* Create(Module *Mod, + Value *Mask, Value *VectorLen, + BinaryOps Opc, + Value *V1, Value *V2, + const Twine &Name, + BasicBlock *InsertAtEnd) { + return Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, InsertAtEnd, nullptr); + } + + static Instruction* CreateWithCopiedFlags(Module *Mod, + Value *Mask, Value* VectorLen, + BinaryOps Opc, + Value *V1, Value *V2, + Instruction *CopyBO, + const Twine &Name = "") { + Instruction *BO = Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, nullptr, nullptr); + BO->copyIRFlags(CopyBO); + return BO; + } +}; + +class PredicatedICmpInst : public PredicatedBinaryOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedICmpInst() = delete; + ~PredicatedICmpInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction * I) { + if (isa(I)) return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->getFunctionalOpcode() == Instruction::ICmp; + } + static bool classof(const ConstantExpr * CE) { return CE->getOpcode() == Instruction::ICmp; } + static bool classof(const Value *V) { + auto * I = dyn_cast(V); + if (I && classof(I)) return true; + auto * CE = dyn_cast(V); + return CE && classof(CE); + } + + ICmpInst::Predicate getPredicate() const { + auto * ICInst = dyn_cast(this); + if (ICInst) return ICInst->getPredicate(); + auto * CE = dyn_cast(this); + if (CE) return static_cast(CE->getPredicate()); + return static_cast(cast(this)->getCmpPredicate()); + } +}; + + +class PredicatedFCmpInst : public PredicatedBinaryOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. 
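// --------------------------------------------------------------------------
// Editor's illustrative sketch (not part of this patch): what the Predicated*
// utility classes declared in this header buy a transform.  The same code
// handles a plain `fadd` instruction and a VP floating-point add intrinsic;
// the only difference it observes is whether getMask()/getVectorLength()
// return null.  `rewriteAddOfSelf` is a hypothetical helper; Create() is the
// factory declared above (its out-of-line implementation is expected to emit
// a plain BinaryOperator when mask and EVL are null, and a VP intrinsic call
// otherwise).
static Instruction *rewriteAddOfSelf(Instruction *I) {
  auto *PBO = dyn_cast<PredicatedBinaryOperator>(I);
  if (!PBO || PBO->getOpcode() != Instruction::FAdd ||
      PBO->getOperand(0) != PBO->getOperand(1))
    return nullptr;                          // only handle  x + x
  Value *Mask = PBO->getMask();              // null for a plain fadd
  Value *EVL = PBO->getVectorLength();       // null for a plain fadd
  Value *Two = ConstantFP::get(I->getType(), 2.0);
  Instruction *Mul = PredicatedBinaryOperator::Create(
      I->getModule(), Mask, EVL, Instruction::FMul, PBO->getOperand(0), Two,
      "double", /*InsertBefore=*/I);
  I->replaceAllUsesWith(Mul);
  return Mul;
}
// --------------------------------------------------------------------------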
+ PredicatedFCmpInst() = delete; + ~PredicatedFCmpInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction * I) { + if (isa(I)) return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->getFunctionalOpcode() == Instruction::FCmp; + } + static bool classof(const ConstantExpr * CE) { return CE->getOpcode() == Instruction::FCmp; } + static bool classof(const Value *V) { + auto * I = dyn_cast(V); + if (I && classof(I)) return true; + return isa(V); + } + + FCmpInst::Predicate getPredicate() const { + auto * FCInst = dyn_cast(this); + if (FCInst) return FCInst->getPredicate(); + auto * CE = dyn_cast(this); + if (CE) return static_cast(CE->getPredicate()); + return static_cast(cast(this)->getCmpPredicate()); + } +}; + + +class PredicatedSelectInst : public PredicatedOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedSelectInst() = delete; + ~PredicatedSelectInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction * I) { + if (isa(I)) return true; + auto VPInst = dyn_cast(I); + return VPInst && VPInst->getFunctionalOpcode() == Instruction::Select; + } + static bool classof(const ConstantExpr * CE) { return CE->getOpcode() == Instruction::Select; } + static bool classof(const Value *V) { + auto * I = dyn_cast(V); + if (I && classof(I)) return true; + auto * CE = dyn_cast(V); + return CE && CE->getOpcode() == Instruction::Select; + } + + const Value *getCondition() const { return getOperand(0); } + const Value *getTrueValue() const { return getOperand(1); } + const Value *getFalseValue() const { return getOperand(2); } + Value *getCondition() { return getOperand(0); } + Value *getTrueValue() { return getOperand(1); } + Value *getFalseValue() { return getOperand(2); } + + void setCondition(Value *V) { setOperand(0, V); } + void setTrueValue(Value *V) { setOperand(1, V); } + void setFalseValue(Value *V) { setOperand(2, V); } +}; + + +namespace PatternMatch { + +// PredicatedMatchContext for pattern matching +struct PredicatedContext { + Value * Mask; + Value * VectorLength; + Module * Mod; + + void reset(Value * V) { + auto * PI = dyn_cast(V); + if (!PI) { + VectorLength = nullptr; + Mask = nullptr; + Mod = nullptr; + } else { + VectorLength = PI->getVectorLength(); + Mask = PI->getMask(); + Mod = PI->getParent()->getParent()->getParent(); + } + } + + PredicatedContext(Value * Val) + : Mask(nullptr) + , VectorLength(nullptr) + , Mod(nullptr) + { + reset(Val); + } + + PredicatedContext(const PredicatedContext & PC) + : Mask(PC.Mask) + , VectorLength(PC.VectorLength) + , Mod(PC.Mod) + {} + + /// accept a match where \p Val is in a non-leaf position in a match pattern + bool acceptInnerNode(const Value * Val) const { + auto PredI = dyn_cast(Val); + if (!PredI) return VectorLength == nullptr && Mask == nullptr; + return VectorLength == PredI->getVectorLength() && Mask == PredI->getMask(); + } + + /// accept a match where \p Val is bound to a free variable. + bool acceptBoundNode(const Value * Val) const { return true; } + + /// whether this context is compatiable with \p E. + bool acceptContext(PredicatedContext PC) const { + return std::tie(PC.Mask, PC.VectorLength) == std::tie(Mask, VectorLength); + } + + /// merge the context \p E into this context and return whether the resulting context is valid. + bool mergeContext(PredicatedContext PC) const { return acceptContext(PC); } + + /// match \p P in a new contest for \p Val. 
+ template bool reset_match(Val *V, const Pattern &P) { + reset(V); + return const_cast(P).match_context(V, *this); + } + + /// match \p P in the current context. + template bool try_match(Val *V, const Pattern &P) { + PredicatedContext SubContext(*this); + return const_cast(P).match_context(V, SubContext); + } +}; + +struct PredicatedContext; +template<> struct MatcherCast { using ActualCastType = PredicatedBinaryOperator; }; +template<> struct MatcherCast { using ActualCastType = PredicatedOperator; }; +template<> struct MatcherCast { using ActualCastType = PredicatedICmpInst; }; +template<> struct MatcherCast { using ActualCastType = PredicatedFCmpInst; }; +template<> struct MatcherCast { using ActualCastType = PredicatedSelectInst; }; +template<> struct MatcherCast { using ActualCastType = PredicatedInstruction; }; + +} // namespace PatternMatch + +} // namespace llvm + +#endif // LLVM_IR_PREDICATEDINST_H diff --git a/llvm/include/llvm/IR/VPBuilder.h b/llvm/include/llvm/IR/VPBuilder.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/IR/VPBuilder.h @@ -0,0 +1,231 @@ +#ifndef LLVM_IR_VPBUILDER_H +#define LLVM_IR_VPBUILDER_H + +#include +#include +#include +#include +#include +#include + +namespace llvm { + +using ValArray = ArrayRef; + +class VPBuilder { + IRBuilder<> & Builder; + + // Explicit mask parameter + Value * Mask; + // Explicit vector length parameter + Value * ExplicitVectorLength; + // Compile-time vector length + int StaticVectorLength; + + // get a valid mask/evl argument for the current predication contet + Value& GetMaskForType(VectorType & VecTy); + Value& GetEVLForType(VectorType & VecTy); + +public: + VPBuilder(IRBuilder<> & _builder) + : Builder(_builder) + , Mask(nullptr) + , ExplicitVectorLength(nullptr) + , StaticVectorLength(-1) + {} + + Module & getModule() const; + LLVMContext & getContext() const { return Builder.getContext(); } + + // The cannonical vector type for this \p ElementTy + VectorType& getVectorType(Type &ElementTy); + + // Predication context tracker + VPBuilder& setMask(Value * _Mask) { Mask = _Mask; return *this; } + VPBuilder& setEVL(Value * _ExplicitVectorLength) { ExplicitVectorLength = _ExplicitVectorLength; return *this; } + VPBuilder& setStaticVL(int VLen) { StaticVectorLength = VLen; return *this; } + + // Create a map-vectorized copy of the instruction \p Inst with the underlying IRBuilder instance. + // This operation may return nullptr if the instruction could not be vectorized. 
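// --------------------------------------------------------------------------
// Editor's illustrative sketch (not part of this patch): intended use of
// VPBuilder from a vectorizer.  `widenWithPredicate` is a hypothetical
// helper; ValArray is assumed to be ArrayRef<Value *>, and Mask/EVL are
// whatever mask and explicit-vector-length values the caller has already
// materialized.
static Value *widenWithPredicate(IRBuilder<> &Builder, Instruction &ScalarInst,
                                 ArrayRef<Value *> WideOperands,
                                 Value *Mask, Value *EVL) {
  VPBuilder VPB(Builder);
  VPB.setMask(Mask)      // mask operand for the emitted VP operations
     .setEVL(EVL)        // explicit vector length operand
     .setStaticVL(8);    // compile-time vector length
  // Emits a predicated, vectorized copy of ScalarInst over the pre-widened
  // operands, or null if the instruction cannot be vectorized (see the
  // CreateVectorCopy declaration below).
  return VPB.CreateVectorCopy(ScalarInst, WideOperands);
}
// --------------------------------------------------------------------------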
+ Value* CreateVectorCopy(Instruction & Inst, ValArray VecOpArray); + + // Memory + Value& CreateContiguousStore(Value & Val, Value & Pointer, Align Alignment); + Value& CreateContiguousLoad(Value & Pointer, Align Alignment); + Value& CreateScatter(Value & Val, Value & PointerVec, Align Alignment); + Value& CreateGather(Value & PointerVec, Align Alignment); +}; + + + + + +namespace PatternMatch { + // Factory class to generate instructions in a context + template + class MatchContextBuilder { + public: + // MatchContextBuilder(MatcherContext MC); + }; + + +// Context-free instruction builder +template<> +class MatchContextBuilder { +public: + MatchContextBuilder(EmptyContext & EC) {} + + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name);\ + } \ + template \ + Value *Create##OPC(IRBuilderType & Builder, Value *V1, Value *V2, \ + const Twine &Name = "") const { \ + auto * Inst = BinaryOperator::Create(Instruction::OPC, V1, V2, Name); \ + Builder.Insert(Inst); return Inst; \ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Value *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, BasicBlock *BB) const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name, BB);\ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Value *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, Instruction *I) const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name, I);\ + } + #include "llvm/IR/Instruction.def" + #undef HANDLE_BINARY_INST + + BinaryOperator *CreateFAddFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FAdd, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFSubFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FSub, V1, V2, FMFSource, Name); + } + template + BinaryOperator *CreateFSubFMF(IRBuilderType & Builder, Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + auto * Inst = CreateFSubFMF(V1, V2, FMFSource, Name); + Builder.Insert(Inst); return Inst; + } + BinaryOperator *CreateFMulFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FMul, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFDivFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FDiv, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFRemFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FRem, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFNegFMF(Value *Op, Instruction *FMFSource, + const Twine &Name = "") { + Value *Zero = ConstantFP::getNegativeZero(Op->getType()); + return BinaryOperator::CreateWithCopiedFlags(Instruction::FSub, Zero, Op, FMFSource); + } + + template + Value *CreateFPTrunc(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPTrunc(V, DestTy, Name); } + template + Value *CreateFPExt(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPExt(V, DestTy, Name); } +}; + + + +// 
Context-free instruction builder +template<> +class MatchContextBuilder { + PredicatedContext & PC; +public: + MatchContextBuilder(PredicatedContext & PC) : PC(PC) {} + + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name);\ + } \ + template \ + Instruction *Create##OPC(IRBuilderType & Builder, Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + auto * PredInst = Create##OPC(V1, V2, Name); Builder.Insert(PredInst); return PredInst; \ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, BasicBlock *BB) const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name, BB);\ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, Instruction *I) const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name, I);\ + } + #include "llvm/IR/Instruction.def" + #undef HANDLE_BINARY_INST + + Instruction *CreateFAddFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FAdd, V1, V2, FMFSource, Name); + } + Instruction *CreateFSubFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FSub, V1, V2, FMFSource, Name); + } + template + Instruction *CreateFSubFMF(IRBuilderType & Builder, Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + auto * Inst = CreateFSubFMF(V1, V2, FMFSource, Name); + Builder.Insert(Inst); return Inst; + } + Instruction *CreateFMulFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FMul, V1, V2, FMFSource, Name); + } + Instruction *CreateFDivFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FDiv, V1, V2, FMFSource, Name); + } + Instruction *CreateFRemFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FRem, V1, V2, FMFSource, Name); + } + Instruction *CreateFNegFMF(Value *Op, Instruction *FMFSource, + const Twine &Name = "") { + Value *Zero = ConstantFP::getNegativeZero(Op->getType()); + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FSub, Zero, Op, FMFSource); + } + + // TODO predicated casts + template + Value *CreateFPTrunc(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPTrunc(V, DestTy, Name); } + template + Value *CreateFPExt(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPExt(V, DestTy, Name); } +}; + +} + +} // namespace llvm + +#endif // LLVM_IR_VPBUILDER_H diff --git a/llvm/include/llvm/Target/TargetSelectionDAG.td 
b/llvm/include/llvm/Target/TargetSelectionDAG.td --- a/llvm/include/llvm/Target/TargetSelectionDAG.td +++ b/llvm/include/llvm/Target/TargetSelectionDAG.td @@ -128,6 +128,13 @@ SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisInt<3> ]>; +def SDTIntBinOpVP : SDTypeProfile<1, 4, [ // vp_add, vp_and, etc. + SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; +def SDTIntShiftOpVP : SDTypeProfile<1, 4, [ // shl, sra, srl + SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; + def SDTFPBinOp : SDTypeProfile<1, 2, [ // fadd, fmul, etc. SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0> ]>; @@ -173,6 +180,16 @@ SDTCisOpSmallerThanOp<1, 0> ]>; +def SDTFPUnOpVP : SDTypeProfile<1, 3, [ // vp_fneg, etc. + SDTCisSameAs<0, 1>, SDTCisFP<0>, SDTCisInt<3>, SDTCisSameNumEltsAs<0, 2> +]>; +def SDTFPBinOpVP : SDTypeProfile<1, 4, [ // vp_fadd, etc. + SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; +def SDTFPTernaryOpVP : SDTypeProfile<1, 5, [ // vp_fmadd, etc. + SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisFP<0>, SDTCisInt<5>, SDTCisSameNumEltsAs<0, 4> +]>; + def SDTSetCC : SDTypeProfile<1, 3, [ // setcc SDTCisInt<0>, SDTCisSameAs<1, 2>, SDTCisVT<3, OtherVT> ]>; @@ -185,6 +202,10 @@ SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1> ]>; +def SDTVSelectVP : SDTypeProfile<1, 5, [ // vp_vselect + SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1>, SDTCisInt<5>, SDTCisSameNumEltsAs<0, 4> +]>; + def SDTSelectCC : SDTypeProfile<1, 5, [ // select_cc SDTCisSameAs<1, 2>, SDTCisSameAs<3, 4>, SDTCisSameAs<0, 3>, SDTCisVT<5, OtherVT> @@ -228,11 +249,20 @@ SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameNumEltsAs<0, 2> ]>; +def SDTStoreVP: SDTypeProfile<0, 4, [ // evl store + SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameNumEltsAs<0, 2>, SDTCisInt<3> +]>; + def SDTMaskedLoad: SDTypeProfile<1, 3, [ // masked load SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameAs<0, 3>, SDTCisSameNumEltsAs<0, 2> ]>; +def SDTLoadVP : SDTypeProfile<1, 3, [ // evl load + SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisSameNumEltsAs<0, 2>, SDTCisInt<3>, + SDTCisSameNumEltsAs<0, 2> +]>; + def SDTVecShuffle : SDTypeProfile<1, 2, [ SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2> ]>; @@ -391,6 +421,26 @@ def umax : SDNode<"ISD::UMAX" , SDTIntBinOp, [SDNPCommutative, SDNPAssociative]>; +def vp_and : SDNode<"ISD::VP_AND" , SDTIntBinOpVP , + [SDNPCommutative, SDNPAssociative]>; +def vp_or : SDNode<"ISD::VP_OR" , SDTIntBinOpVP , + [SDNPCommutative, SDNPAssociative]>; +def vp_xor : SDNode<"ISD::VP_XOR" , SDTIntBinOpVP , + [SDNPCommutative, SDNPAssociative]>; +def vp_srl : SDNode<"ISD::VP_SRL" , SDTIntShiftOpVP>; +def vp_sra : SDNode<"ISD::VP_SRA" , SDTIntShiftOpVP>; +def vp_shl : SDNode<"ISD::VP_SHL" , SDTIntShiftOpVP>; + +def vp_add : SDNode<"ISD::VP_ADD" , SDTIntBinOpVP , + [SDNPCommutative, SDNPAssociative]>; +def vp_sub : SDNode<"ISD::VP_SUB" , SDTIntBinOpVP>; +def vp_mul : SDNode<"ISD::VP_MUL" , SDTIntBinOpVP, + [SDNPCommutative, SDNPAssociative]>; +def vp_sdiv : SDNode<"ISD::VP_SDIV" , SDTIntBinOpVP>; +def vp_udiv : SDNode<"ISD::VP_UDIV" , SDTIntBinOpVP>; +def vp_srem : SDNode<"ISD::VP_SREM" , SDTIntBinOpVP>; +def vp_urem : SDNode<"ISD::VP_UREM" , SDTIntBinOpVP>; + def saddsat : SDNode<"ISD::SADDSAT" , SDTIntBinOp, [SDNPCommutative]>; def uaddsat : SDNode<"ISD::UADDSAT" , SDTIntBinOp, 
[SDNPCommutative]>; def ssubsat : SDNode<"ISD::SSUBSAT" , SDTIntBinOp>; @@ -473,6 +523,14 @@ def fpextend : SDNode<"ISD::FP_EXTEND" , SDTFPExtendOp>; def fcopysign : SDNode<"ISD::FCOPYSIGN" , SDTFPSignOp>; +def vp_fneg : SDNode<"ISD::VP_FNEG" , SDTFPUnOpVP>; +def vp_fadd : SDNode<"ISD::VP_FADD" , SDTFPBinOpVP, [SDNPCommutative]>; +def vp_fsub : SDNode<"ISD::VP_FSUB" , SDTFPBinOpVP>; +def vp_fmul : SDNode<"ISD::VP_FMUL" , SDTFPBinOpVP, [SDNPCommutative]>; +def vp_fdiv : SDNode<"ISD::VP_FDIV" , SDTFPBinOpVP>; +def vp_frem : SDNode<"ISD::VP_FREM" , SDTFPBinOpVP>; +def vp_fma : SDNode<"ISD::VP_FMA" , SDTFPTernaryOpVP>; + def sint_to_fp : SDNode<"ISD::SINT_TO_FP" , SDTIntToFPOp>; def uint_to_fp : SDNode<"ISD::UINT_TO_FP" , SDTIntToFPOp>; def fp_to_sint : SDNode<"ISD::FP_TO_SINT" , SDTFPToIntOp>; @@ -610,6 +668,11 @@ def masked_ld : SDNode<"ISD::MLOAD", SDTMaskedLoad, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; +def vp_store : SDNode<"ISD::VP_STORE", SDTStoreVP, + [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; +def vp_load : SDNode<"ISD::VP_LOAD", SDTLoadVP, + [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; + // Do not use ld, st directly. Use load, extload, sextload, zextload, store, // and truncst (see below). def ld : SDNode<"ISD::LOAD" , SDTLoad, diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -37,6 +37,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/KnownBits.h" #include @@ -4545,8 +4546,10 @@ /// Given operands for an FSub, see if we can fold the result. If not, this /// returns null. -static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { +template +static Value *SimplifyFSubInstGeneric(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse, MatchContext & MC) { + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) return C; @@ -4554,26 +4557,26 @@ return C; // fsub X, +0 ==> X - if (match(Op1, m_PosZeroFP())) + if (MC.try_match(Op1, m_PosZeroFP())) return Op0; // fsub X, -0 ==> X, when we know X is not -0 - if (match(Op1, m_NegZeroFP()) && + if (MC.try_match(Op1, m_NegZeroFP()) && (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) return Op0; // fsub -0.0, (fsub -0.0, X) ==> X // fsub -0.0, (fneg X) ==> X Value *X; - if (match(Op0, m_NegZeroFP()) && - match(Op1, m_FNeg(m_Value(X)))) + if (MC.try_match(Op0, m_NegZeroFP()) && + MC.try_match(Op1, m_FNeg(m_Value(X)))) return X; // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored. // fsub 0.0, (fneg X) ==> X if signed zeros are ignored. 
if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) && - (match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))) || - match(Op1, m_FNeg(m_Value(X))))) + (MC.try_match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))) || + MC.try_match(Op1, m_FNeg(m_Value(X))))) return X; // fsub nnan x, x ==> 0.0 @@ -4583,8 +4586,8 @@ // Y - (Y - X) --> X // (X + Y) - Y --> X if (FMF.noSignedZeros() && FMF.allowReassoc() && - (match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || - match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) + (MC.try_match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || + MC.try_match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) return X; return nullptr; @@ -4639,9 +4642,26 @@ } + +/// Given operands for an FSub, see if we can fold the result. +static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) + return C; + + EmptyContext EC; + return SimplifyFSubInstGeneric(Op0, Op1, FMF, Q, RecursionLimit, EC); +} + Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q) { - return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); + // Now apply simplifications that do not require rounding. + return SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); +} + +Value *llvm::SimplifyPredicatedFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, PredicatedContext & PC) { + return ::SimplifyFSubInstGeneric(Op0, Op1, FMF, Q, RecursionLimit, PC); } Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, @@ -5232,9 +5252,20 @@ return ConstantFoldCall(Call, F, ConstantArgs, Q.TLI); } +Value *llvm::SimplifyVPIntrinsic(VPIntrinsic & VPInst, const SimplifyQuery &Q) { + PredicatedContext PC(&VPInst); + + auto & PI = cast(VPInst); + switch (PI.getOpcode()) { + default: + return nullptr; + + case Instruction::FSub: return SimplifyPredicatedFSubInst(VPInst.getOperand(0), VPInst.getOperand(1), VPInst.getFastMathFlags(), Q, PC); + } +} + /// See if we can compute a simplified version of this instruction. /// If not, this returns null. - Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, OptimizationRemarkEmitter *ORE) { const SimplifyQuery Q = SQ.CxtI ? 
SQ : SQ.getWithInstruction(I); @@ -5371,6 +5402,13 @@ Result = SimplifyPHINode(cast(I), Q); break; case Instruction::Call: { + auto * VPInst = dyn_cast(I); + if (VPInst) { + Result = SimplifyVPIntrinsic(*VPInst, Q); + if (Result) break; + } + + CallSite CS((I)); Result = SimplifyCall(cast(I), Q); break; } diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp --- a/llvm/lib/AsmParser/LLLexer.cpp +++ b/llvm/lib/AsmParser/LLLexer.cpp @@ -645,6 +645,7 @@ KEYWORD(inlinehint); KEYWORD(inreg); KEYWORD(jumptable); + KEYWORD(mask); KEYWORD(minsize); KEYWORD(naked); KEYWORD(nest); @@ -666,6 +667,7 @@ KEYWORD(optforfuzzing); KEYWORD(optnone); KEYWORD(optsize); + KEYWORD(passthru); KEYWORD(readnone); KEYWORD(readonly); KEYWORD(returned); @@ -689,6 +691,7 @@ KEYWORD(swiftself); KEYWORD(uwtable); KEYWORD(willreturn); + KEYWORD(vlen); KEYWORD(writeonly); KEYWORD(zeroext); KEYWORD(immarg); diff --git a/llvm/lib/AsmParser/LLParser.cpp b/llvm/lib/AsmParser/LLParser.cpp --- a/llvm/lib/AsmParser/LLParser.cpp +++ b/llvm/lib/AsmParser/LLParser.cpp @@ -1338,15 +1338,18 @@ case lltok::kw_dereferenceable: case lltok::kw_dereferenceable_or_null: case lltok::kw_inalloca: + case lltok::kw_mask: case lltok::kw_nest: case lltok::kw_noalias: case lltok::kw_nocapture: case lltok::kw_nonnull: + case lltok::kw_passthru: case lltok::kw_returned: case lltok::kw_sret: case lltok::kw_swifterror: case lltok::kw_swiftself: case lltok::kw_immarg: + case lltok::kw_vlen: HaveError |= Error(Lex.getLoc(), "invalid use of parameter-only attribute on a function"); @@ -1633,10 +1636,12 @@ } case lltok::kw_inalloca: B.addAttribute(Attribute::InAlloca); break; case lltok::kw_inreg: B.addAttribute(Attribute::InReg); break; + case lltok::kw_mask: B.addAttribute(Attribute::Mask); break; case lltok::kw_nest: B.addAttribute(Attribute::Nest); break; case lltok::kw_noalias: B.addAttribute(Attribute::NoAlias); break; case lltok::kw_nocapture: B.addAttribute(Attribute::NoCapture); break; case lltok::kw_nonnull: B.addAttribute(Attribute::NonNull); break; + case lltok::kw_passthru: B.addAttribute(Attribute::Passthru); break; case lltok::kw_readnone: B.addAttribute(Attribute::ReadNone); break; case lltok::kw_readonly: B.addAttribute(Attribute::ReadOnly); break; case lltok::kw_returned: B.addAttribute(Attribute::Returned); break; @@ -1644,6 +1649,7 @@ case lltok::kw_sret: B.addAttribute(Attribute::StructRet); break; case lltok::kw_swifterror: B.addAttribute(Attribute::SwiftError); break; case lltok::kw_swiftself: B.addAttribute(Attribute::SwiftSelf); break; + case lltok::kw_vlen: B.addAttribute(Attribute::VectorLength); break; case lltok::kw_writeonly: B.addAttribute(Attribute::WriteOnly); break; case lltok::kw_zeroext: B.addAttribute(Attribute::ZExt); break; case lltok::kw_immarg: B.addAttribute(Attribute::ImmArg); break; @@ -1736,13 +1742,16 @@ // Error handling. case lltok::kw_byval: case lltok::kw_inalloca: + case lltok::kw_mask: case lltok::kw_nest: case lltok::kw_nocapture: + case lltok::kw_passthru: case lltok::kw_returned: case lltok::kw_sret: case lltok::kw_swifterror: case lltok::kw_swiftself: case lltok::kw_immarg: + case lltok::kw_vlen: HaveError |= Error(Lex.getLoc(), "invalid use of parameter-only attribute"); break; @@ -3411,7 +3420,7 @@ ID.Kind = ValID::t_Constant; return false; } - + // Unary Operators. 
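// --------------------------------------------------------------------------
// Editor's illustrative sketch (not part of this patch): the mask / vlen /
// passthru keywords handled above map to the new parameter attribute kinds,
// which can equally be attached programmatically.  `VPFAdd` stands for the
// declaration of a VP floating-point add intrinsic with parameters
// (x, y, mask, vlen); the helper name is hypothetical.
static void tagVPFAddParameters(Function &VPFAdd) {
  VPFAdd.addParamAttr(2, Attribute::Mask);          // the vector-of-i1 mask
  VPFAdd.addParamAttr(3, Attribute::VectorLength);  // the i32 %evl parameter
}
// --------------------------------------------------------------------------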
case lltok::kw_fneg: { unsigned Opc = Lex.getUIntVal(); @@ -3421,7 +3430,7 @@ ParseGlobalTypeAndValue(Val) || ParseToken(lltok::rparen, "expected ')' in unary constantexpr")) return true; - + // Check that the type is valid for the operator. switch (Opc) { case Instruction::FNeg: @@ -4757,7 +4766,7 @@ OPTIONAL(declaration, MDField, ); \ OPTIONAL(name, MDStringField, ); \ OPTIONAL(file, MDField, ); \ - OPTIONAL(line, LineField, ); + OPTIONAL(line, LineField, ); PARSE_MD_FIELDS(); #undef VISIT_MD_FIELDS diff --git a/llvm/lib/AsmParser/LLToken.h b/llvm/lib/AsmParser/LLToken.h --- a/llvm/lib/AsmParser/LLToken.h +++ b/llvm/lib/AsmParser/LLToken.h @@ -191,6 +191,7 @@ kw_inlinehint, kw_inreg, kw_jumptable, + kw_mask, kw_minsize, kw_naked, kw_nest, @@ -212,6 +213,7 @@ kw_optforfuzzing, kw_optnone, kw_optsize, + kw_passthru, kw_readnone, kw_readonly, kw_returned, @@ -232,6 +234,7 @@ kw_swiftself, kw_uwtable, kw_willreturn, + kw_vlen, kw_writeonly, kw_zeroext, kw_immarg, diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -1438,6 +1438,8 @@ return Attribute::InReg; case bitc::ATTR_KIND_JUMP_TABLE: return Attribute::JumpTable; + case bitc::ATTR_KIND_MASK: + return Attribute::Mask; case bitc::ATTR_KIND_MIN_SIZE: return Attribute::MinSize; case bitc::ATTR_KIND_NAKED: @@ -1486,6 +1488,8 @@ return Attribute::OptimizeForSize; case bitc::ATTR_KIND_OPTIMIZE_NONE: return Attribute::OptimizeNone; + case bitc::ATTR_KIND_PASSTHRU: + return Attribute::Passthru; case bitc::ATTR_KIND_READ_NONE: return Attribute::ReadNone; case bitc::ATTR_KIND_READ_ONLY: @@ -1532,6 +1536,8 @@ return Attribute::UWTable; case bitc::ATTR_KIND_WILLRETURN: return Attribute::WillReturn; + case bitc::ATTR_KIND_VECTORLENGTH: + return Attribute::VectorLength; case bitc::ATTR_KIND_WRITEONLY: return Attribute::WriteOnly; case bitc::ATTR_KIND_Z_EXT: diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -677,6 +677,12 @@ return bitc::ATTR_KIND_READ_ONLY; case Attribute::Returned: return bitc::ATTR_KIND_RETURNED; + case Attribute::Mask: + return bitc::ATTR_KIND_MASK; + case Attribute::VectorLength: + return bitc::ATTR_KIND_VECTORLENGTH; + case Attribute::Passthru: + return bitc::ATTR_KIND_PASSTHRU; case Attribute::ReturnsTwice: return bitc::ATTR_KIND_RETURNS_TWICE; case Attribute::SExt: diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -427,6 +427,7 @@ SDValue visitBITCAST(SDNode *N); SDValue visitBUILD_PAIR(SDNode *N); SDValue visitFADD(SDNode *N); + SDValue visitFADD_VP(SDNode *N); SDValue visitFSUB(SDNode *N); SDValue visitFMUL(SDNode *N); SDValue visitFMA(SDNode *N); @@ -475,6 +476,7 @@ SDValue visitFP16_TO_FP(SDNode *N); SDValue visitVECREDUCE(SDNode *N); + template SDValue visitFADDForFMACombine(SDNode *N); SDValue visitFSUBForFMACombine(SDNode *N); SDValue visitFMULForFMADistributiveCombine(SDNode *N); @@ -736,6 +738,137 @@ void NodeInserted(SDNode *N) override { DC.ConsiderForPruning(N); } }; +struct EmptyMatchContext { + SelectionDAG & DAG; + + EmptyMatchContext(SelectionDAG & DAG, SDNode * Root) + : DAG(DAG) + {} + + bool match(SDValue OpN, unsigned OpCode) const { return OpCode == 
OpN->getOpcode(); } + + unsigned getFunctionOpCode(SDValue N) const { + return N->getOpcode(); + } + + bool isCompatible(SDValue OpVal) const { return true; } + + // Specialize based on number of operands. + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return DAG.getNode(Opcode, DL, VT); } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, + const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, Operand, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, N1, N2, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, + const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDValue N4) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, N4); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDValue N4, SDValue N5) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, N4, N5); + } +}; + +struct +VPMatchContext { + SelectionDAG & DAG; + SDNode * Root; + SDValue RootMaskOp; + SDValue RootVectorLenOp; + + VPMatchContext(SelectionDAG & DAG, SDNode * Root) + : DAG(DAG) + , Root(Root) + , RootMaskOp() + , RootVectorLenOp() + { + if (Root->isVP()) { + int RootMaskPos = ISD::GetMaskPosVP(Root->getOpcode()); + if (RootMaskPos != -1) { + RootMaskOp = Root->getOperand(RootMaskPos); + } + + int RootVLenPos = ISD::GetVectorLengthPosVP(Root->getOpcode()); + if (RootVLenPos != -1) { + RootVectorLenOp = Root->getOperand(RootVLenPos); + } + } + } + + unsigned getFunctionOpCode(SDValue N) const { + unsigned VPOpCode = N->getOpcode(); + return ISD::GetFunctionOpCodeForVP(VPOpCode, N->getFlags().hasFPExcept()); + } + + bool isCompatible(SDValue OpVal) const { + if (!OpVal->isVP()) { + return !Root->isVP(); + + } else { + unsigned VPOpCode = OpVal->getOpcode(); + int MaskPos = ISD::GetMaskPosVP(VPOpCode); + if (MaskPos != -1 && RootMaskOp != OpVal.getOperand(MaskPos)) { + return false; + } + + int VLenPos = ISD::GetVectorLengthPosVP(VPOpCode); + if (VLenPos != -1 && RootVectorLenOp != OpVal.getOperand(VLenPos)) { + return false; + } + + return true; + } + } + + /// whether \p OpN is a node that is functionally compatible with the NodeType \p OpNodeTy + bool match(SDValue OpVal, unsigned OpNT) const { + return isCompatible(OpVal) && getFunctionOpCode(OpVal) == OpNT; + } + + // Specialize based on number of operands. 
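// --------------------------------------------------------------------------
// Editor's illustrative sketch (not part of this patch): the contract the
// templated combines rely on.  For a root  vp_fadd(a, vp_fmul(b, c, M, L),
// M, L)  the matcher reports the functional opcode (ISD::FMUL) for the inner
// node and only accepts it if its mask/EVL equal the root's, and getNode()
// re-emits the VP counterpart (here ISD::VP_FMA) with the root's mask and
// EVL appended.  Profitability and legality checks from the real combine are
// omitted; `fuseVPFMulIntoFAdd` is a hypothetical helper.
static SDValue fuseVPFMulIntoFAdd(SelectionDAG &DAG, SDNode *Root) {
  VPMatchContext M(DAG, Root);       // Root is assumed to be an ISD::VP_FADD
  SDValue A = Root->getOperand(0);
  SDValue B = Root->getOperand(1);
  if (!M.match(B, ISD::FMUL))        // a VP_FMUL with the same mask/EVL?
    return SDValue();
  return M.getNode(ISD::FMA, SDLoc(Root), Root->getValueType(0),
                   B.getOperand(0), B.getOperand(1), A);
}
// --------------------------------------------------------------------------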
+ // TODO emit VP intrinsics where MaskOp/VectorLenOp != null + // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return DAG.getNode(Opcode, DL, VT); } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, + const SDNodeFlags Flags = SDNodeFlags()) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + int MaskPos = ISD::GetMaskPosVP(VPOpcode); + int VLenPos = ISD::GetVectorLengthPosVP(VPOpcode); + assert(MaskPos == 1 && VLenPos == 2); + + return DAG.getNode(VPOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp}, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, const SDNodeFlags Flags = SDNodeFlags()) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + int MaskPos = ISD::GetMaskPosVP(VPOpcode); + int VLenPos = ISD::GetVectorLengthPosVP(VPOpcode); + assert(MaskPos == 2 && VLenPos == 3); + + return DAG.getNode(VPOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp}, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, + const SDNodeFlags Flags = SDNodeFlags()) { + unsigned VPOpcode = ISD::GetVPForFunctionOpCode(Opcode); + int MaskPos = ISD::GetMaskPosVP(VPOpcode); + int VLenPos = ISD::GetVectorLengthPosVP(VPOpcode); + assert(MaskPos == 3 && VLenPos == 4); + + return DAG.getNode(VPOpcode, DL, VT, {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags); + } +}; + } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -1560,6 +1693,7 @@ case ISD::BITCAST: return visitBITCAST(N); case ISD::BUILD_PAIR: return visitBUILD_PAIR(N); case ISD::FADD: return visitFADD(N); + case ISD::VP_FADD: return visitFADD_VP(N); case ISD::FSUB: return visitFSUB(N); case ISD::FMUL: return visitFMUL(N); case ISD::FMA: return visitFMA(N); @@ -11320,13 +11454,18 @@ return F.hasAllowContract() || F.hasAllowReassociation(); } + /// Try to perform FMA combining on a given FADD node. +template SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); + MatchContextClass matcher(DAG, N); + if (!matcher.isCompatible(N0) || !matcher.isCompatible(N1)) return SDValue(); + const TargetOptions &Options = DAG.getTarget().Options; // Floating-point multiply-add with intermediate rounding. @@ -11359,8 +11498,8 @@ // Is the node an FMUL and contractable either due to global flags or // SDNodeFlags. - auto isContractableFMUL = [AllowFusionGlobally](SDValue N) { - if (N.getOpcode() != ISD::FMUL) + auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) { + if (!matcher.match(N, ISD::FMUL)) return false; return AllowFusionGlobally || isContractable(N.getNode()); }; @@ -11373,42 +11512,42 @@ // fold (fadd (fmul x, y), z) -> (fma x, y, z) if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), N1, Flags); } // fold (fadd x, (fmul y, z)) -> (fma y, z, x) // Note: Commutes FADD operands. if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), N0, Flags); } // Look through FP_EXTEND nodes to do more combining. 
// fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) - if (N0.getOpcode() == ISD::FP_EXTEND) { + if ((N0.getOpcode() == ISD::FP_EXTEND) && matcher.isCompatible(N0.getOperand(0))) { SDValue N00 = N0.getOperand(0); if (isContractableFMUL(N00) && TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1, Flags); } } // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x) // Note: Commutes FADD operands. - if (N1.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N1, ISD::FP_EXTEND)) { SDValue N10 = N1.getOperand(0); if (isContractableFMUL(N10) && TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N10.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0, Flags); } } @@ -11417,12 +11556,12 @@ if (Aggressive) { // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z)) if (CanFuse && - N0.getOpcode() == PreferredFusedOpcode && - N0.getOperand(2).getOpcode() == ISD::FMUL && + matcher.match(N0, PreferredFusedOpcode) && + matcher.match(N0.getOperand(2), ISD::FMUL) && N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(2).getOperand(0), N0.getOperand(2).getOperand(1), N1, Flags), Flags); @@ -11430,12 +11569,12 @@ // fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x)) if (CanFuse && - N1->getOpcode() == PreferredFusedOpcode && - N1.getOperand(2).getOpcode() == ISD::FMUL && + matcher.match(N1, PreferredFusedOpcode) && + matcher.match(N1.getOperand(2), ISD::FMUL) && N1->hasOneUse() && N1.getOperand(2)->hasOneUse()) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(2).getOperand(0), N1.getOperand(2).getOperand(1), N0, Flags), Flags); @@ -11447,15 +11586,15 @@ auto FoldFAddFMAFPExtFMul = [&] ( SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, SDNodeFlags Flags) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y, - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + return matcher.getNode(PreferredFusedOpcode, SL, VT, X, Y, + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z, Flags), Flags); }; - if (N0.getOpcode() == PreferredFusedOpcode) { + if (matcher.match(N0, PreferredFusedOpcode)) { SDValue N02 = N0.getOperand(2); - if (N02.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N02, ISD::FP_EXTEND)) { SDValue N020 = N02.getOperand(0); if (isContractableFMUL(N020) && TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N020.getValueType())) { @@ -11474,12 
+11613,12 @@ auto FoldFAddFPExtFMAFMul = [&] ( SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, SDNodeFlags Flags) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, X), - DAG.getNode(ISD::FP_EXTEND, SL, VT, Y), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, X), + matcher.getNode(ISD::FP_EXTEND, SL, VT, Y), + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z, Flags), Flags); }; if (N0.getOpcode() == ISD::FP_EXTEND) { @@ -11916,6 +12055,15 @@ return SDValue(); } +SDValue DAGCombiner::visitFADD_VP(SDNode *N) { + // FADD -> FMA combines: + if (SDValue Fused = visitFADDForFMACombine(N)) { + AddToWorklist(Fused.getNode()); + return Fused; + } + return SDValue(); +} + SDValue DAGCombiner::visitFADD(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -12088,7 +12236,7 @@ } // enable-unsafe-fp-math // FADD -> FMA combines: - if (SDValue Fused = visitFADDForFMACombine(N)) { + if (SDValue Fused = visitFADDForFMACombine(N)) { AddToWorklist(Fused.getNode()); return Fused; } diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -992,7 +992,7 @@ } // Handle promotion for the ADDE/SUBE/ADDCARRY/SUBCARRY nodes. Notice that -// the third operand of ADDE/SUBE nodes is carry flag, which differs from +// the third operand of ADDE/SUBE nodes is carry flag, which differs from // the ADDCARRY/SUBCARRY nodes in that the third operand is carry Boolean. SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBCARRY(SDNode *N, unsigned ResNo) { if (ResNo == 1) @@ -1144,6 +1144,9 @@ return false; } + if (N->isVP()) { + Res = PromoteIntOp_VP(N, OpNo); + } else { switch (N->getOpcode()) { default: #ifndef NDEBUG @@ -1222,6 +1225,7 @@ case ISD::VECREDUCE_UMAX: case ISD::VECREDUCE_UMIN: Res = PromoteIntOp_VECREDUCE(N); break; } + } // If the result is null, the sub-method took care of registering results etc. if (!Res.getNode()) return false; @@ -1502,6 +1506,25 @@ TruncateStore, N->isCompressingStore()); } +SDValue DAGTypeLegalizer::PromoteIntOp_VP(SDNode *N, unsigned OpNo) { + EVT DataVT; + switch (N->getOpcode()) { + default: + DataVT = N->getValueType(0); + break; + + case ISD::VP_STORE: + case ISD::VP_SCATTER: + llvm_unreachable("TODO implement VP memory nodes"); + } + + // TODO assert that \p OpNo is the mask + SDValue Mask = PromoteTargetBoolean(N->getOperand(OpNo), DataVT); + SmallVector NewOps(N->op_begin(), N->op_end()); + NewOps[OpNo] = Mask; + return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0); +} + SDValue DAGTypeLegalizer::PromoteIntOp_MLOAD(MaskedLoadSDNode *N, unsigned OpNo) { assert(OpNo == 2 && "Only know how to promote the mask!"); diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h --- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -341,6 +341,7 @@ SDValue PromoteIntRes_VECREDUCE(SDNode *N); SDValue PromoteIntRes_ABS(SDNode *N); + // Integer Operand Promotion. 
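// Editor's note (illustrative, not part of the patch): PromoteIntOp_VP,
// declared below and defined in LegalizeIntegerTypes.cpp above, only widens
// the i1 mask operand of a VP node; the data operands and the i32 vector
// length stay untouched.  For example, on a hypothetical target that promotes
// v4i1 masks to v4i32:
//
//   t0: v4f32 = vp_fadd a, b, M:v4i1, L:i32
//     becomes
//   t0: v4f32 = vp_fadd a, b, M':v4i32, L:i32   ; M' = PromoteTargetBoolean(M)
//
// VP_STORE and VP_SCATTER still fall into the "TODO implement VP memory
// nodes" path above.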
bool PromoteIntegerOperand(SDNode *N, unsigned OpNo); SDValue PromoteIntOp_ANY_EXTEND(SDNode *N); @@ -376,6 +377,7 @@ SDValue PromoteIntOp_MULFIX(SDNode *N); SDValue PromoteIntOp_FPOWI(SDNode *N); SDValue PromoteIntOp_VECREDUCE(SDNode *N); + SDValue PromoteIntOp_VP(SDNode *N, unsigned OpNo); void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -428,6 +428,213 @@ return Result; } +//===----------------------------------------------------------------------===// +// SDNode VP Support +//===----------------------------------------------------------------------===// + +int +ISD::GetMaskPosVP(unsigned OpCode) { + switch (OpCode) { + default: return -1; + + case ISD::VP_FNEG: + return 1; + + case ISD::VP_ADD: + case ISD::VP_SUB: + case ISD::VP_MUL: + case ISD::VP_SDIV: + case ISD::VP_SREM: + case ISD::VP_UDIV: + case ISD::VP_UREM: + + case ISD::VP_AND: + case ISD::VP_OR: + case ISD::VP_XOR: + case ISD::VP_SHL: + case ISD::VP_SRA: + case ISD::VP_SRL: + + case ISD::VP_FADD: + case ISD::VP_FMUL: + case ISD::VP_FSUB: + case ISD::VP_FDIV: + case ISD::VP_FREM: + return 2; + + case ISD::VP_FMA: + case ISD::VP_SELECT: + return 3; + + case VP_REDUCE_ADD: + case VP_REDUCE_MUL: + case VP_REDUCE_AND: + case VP_REDUCE_OR: + case VP_REDUCE_XOR: + case VP_REDUCE_SMAX: + case VP_REDUCE_SMIN: + case VP_REDUCE_UMAX: + case VP_REDUCE_UMIN: + case VP_REDUCE_FMAX: + case VP_REDUCE_FMIN: + return 1; + + case VP_REDUCE_FADD: + case VP_REDUCE_FMUL: + return 2; + + /// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants. + // (implicit) case ISD::VP_COMPOSE: return -1 + } +} + +int +ISD::GetVectorLengthPosVP(unsigned OpCode) { + switch (OpCode) { + default: return -1; + + case VP_SELECT: + return 0; + + case VP_FNEG: + return 2; + + case VP_ADD: + case VP_SUB: + case VP_MUL: + case VP_SDIV: + case VP_SREM: + case VP_UDIV: + case VP_UREM: + + case VP_AND: + case VP_OR: + case VP_XOR: + case VP_SHL: + case VP_SRA: + case VP_SRL: + + case VP_FADD: + case VP_FMUL: + case VP_FDIV: + case VP_FREM: + return 3; + + case VP_FMA: + return 4; + + case VP_COMPOSE: + return 3; + + case VP_REDUCE_ADD: + case VP_REDUCE_MUL: + case VP_REDUCE_AND: + case VP_REDUCE_OR: + case VP_REDUCE_XOR: + case VP_REDUCE_SMAX: + case VP_REDUCE_SMIN: + case VP_REDUCE_UMAX: + case VP_REDUCE_UMIN: + case VP_REDUCE_FMAX: + case VP_REDUCE_FMIN: + return 2; + + case VP_REDUCE_FADD: + case VP_REDUCE_FMUL: + return 3; + + } +} + +unsigned +ISD::GetFunctionOpCodeForVP(unsigned OpCode, bool hasFPExcept) { + switch (OpCode) { + default: return OpCode; + + case VP_SELECT: return ISD::VSELECT; + case VP_ADD: return ISD::ADD; + case VP_SUB: return ISD::SUB; + case VP_MUL: return ISD::MUL; + case VP_SDIV: return ISD::SDIV; + case VP_SREM: return ISD::SREM; + case VP_UDIV: return ISD::UDIV; + case VP_UREM: return ISD::UREM; + + case VP_AND: return ISD::AND; + case VP_OR: return ISD::OR; + case VP_XOR: return ISD::XOR; + case VP_SHL: return ISD::SHL; + case VP_SRA: return ISD::SRA; + case VP_SRL: return ISD::SRL; + + case VP_FNEG: return ISD::FNEG; + case VP_FADD: return hasFPExcept ? ISD::STRICT_FADD : ISD::FADD; + case VP_FSUB: return hasFPExcept ? ISD::STRICT_FSUB : ISD::FSUB; + case VP_FMUL: return hasFPExcept ? ISD::STRICT_FMUL : ISD::FMUL; + case VP_FDIV: return hasFPExcept ? 
ISD::STRICT_FDIV : ISD::FDIV; + case VP_FREM: return hasFPExcept ? ISD::STRICT_FREM : ISD::FREM; + + case VP_REDUCE_AND: return VECREDUCE_AND; + case VP_REDUCE_OR: return VECREDUCE_OR; + case VP_REDUCE_XOR: return VECREDUCE_XOR; + case VP_REDUCE_ADD: return VECREDUCE_ADD; + case VP_REDUCE_FADD: return VECREDUCE_FADD; + case VP_REDUCE_FMUL: return VECREDUCE_FMUL; + case VP_REDUCE_FMAX: return VECREDUCE_FMAX; + case VP_REDUCE_FMIN: return VECREDUCE_FMIN; + case VP_REDUCE_UMAX: return VECREDUCE_UMAX; + case VP_REDUCE_UMIN: return VECREDUCE_UMIN; + case VP_REDUCE_SMAX: return VECREDUCE_SMAX; + case VP_REDUCE_SMIN: return VECREDUCE_SMIN; + + case VP_STORE: return ISD::MSTORE; + case VP_LOAD: return ISD::MLOAD; + case VP_GATHER: return ISD::MGATHER; + case VP_SCATTER: return ISD::MSCATTER; + + case VP_FMA: return hasFPExcept ? ISD::STRICT_FMA : ISD::FMA; + } +} + +unsigned +ISD::GetVPForFunctionOpCode(unsigned OpCode) { + switch (OpCode) { + default: llvm_unreachable("can not translate this Opcode to VP"); + + case VSELECT: return ISD::VP_SELECT; + case ADD: return ISD::VP_ADD; + case SUB: return ISD::VP_SUB; + case MUL: return ISD::VP_MUL; + case SDIV: return ISD::VP_SDIV; + case SREM: return ISD::VP_SREM; + case UDIV: return ISD::VP_UDIV; + case UREM: return ISD::VP_UREM; + + case AND: return ISD::VP_AND; + case OR: return ISD::VP_OR; + case XOR: return ISD::VP_XOR; + case SHL: return ISD::VP_SHL; + case SRA: return ISD::VP_SRA; + case SRL: return ISD::VP_SRL; + + case FNEG: return ISD::VP_FNEG; + case STRICT_FADD: + case FADD: return ISD::VP_FADD; + case STRICT_FSUB: + case FSUB: return ISD::VP_FSUB; + case STRICT_FMUL: + case FMUL: return ISD::VP_FMUL; + case STRICT_FDIV: + case FDIV: return ISD::VP_FDIV; + case STRICT_FREM: + case FREM: return ISD::VP_FREM; + + case STRICT_FMA: + case FMA: return ISD::VP_FMA; + } +} + + //===----------------------------------------------------------------------===// // SDNode Profile Support //===----------------------------------------------------------------------===// @@ -558,6 +765,34 @@ ID.AddInteger(ST->getPointerInfo().getAddrSpace()); break; } + case ISD::VP_LOAD: { + const VPLoadSDNode *ELD = cast(N); + ID.AddInteger(ELD->getMemoryVT().getRawBits()); + ID.AddInteger(ELD->getRawSubclassData()); + ID.AddInteger(ELD->getPointerInfo().getAddrSpace()); + break; + } + case ISD::VP_STORE: { + const VPStoreSDNode *EST = cast(N); + ID.AddInteger(EST->getMemoryVT().getRawBits()); + ID.AddInteger(EST->getRawSubclassData()); + ID.AddInteger(EST->getPointerInfo().getAddrSpace()); + break; + } + case ISD::VP_GATHER: { + const VPGatherSDNode *EG = cast(N); + ID.AddInteger(EG->getMemoryVT().getRawBits()); + ID.AddInteger(EG->getRawSubclassData()); + ID.AddInteger(EG->getPointerInfo().getAddrSpace()); + break; + } + case ISD::VP_SCATTER: { + const VPScatterSDNode *ES = cast(N); + ID.AddInteger(ES->getMemoryVT().getRawBits()); + ID.AddInteger(ES->getRawSubclassData()); + ID.AddInteger(ES->getPointerInfo().getAddrSpace()); + break; + } case ISD::MLOAD: { const MaskedLoadSDNode *MLD = cast(N); ID.AddInteger(MLD->getMemoryVT().getRawBits()); @@ -6989,6 +7224,142 @@ return V; } +SDValue SelectionDAG::getLoadVP(EVT VT, const SDLoc &dl, SDValue Chain, + SDValue Ptr, SDValue Mask, SDValue VLen, + EVT MemVT, MachineMemOperand *MMO, + ISD::LoadExtType ExtTy) { + SDVTList VTs = getVTList(VT, MVT::Other); + SDValue Ops[] = { Chain, Ptr, Mask, VLen }; + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_LOAD, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + 
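// Editor's note (descriptive, not part of the patch): the VP memory builders
// (getLoadVP here, getStoreVP/getGatherVP/getScatterVP below) reuse the
// masked-load/store uniquing recipe: the FoldingSet key is the opcode and
// operands (AddNodeIDNode above) plus the memory VT, the node's subclass data
// and the address space (the AddInteger calls around this point), so
// structurally identical VP memory nodes CSE to a single SDNode.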
ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, ExtTy, MemVT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), VTs, + ExtTy, MemVT, MMO); + createOperands(N, Ops); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + + +SDValue SelectionDAG::getStoreVP(SDValue Chain, const SDLoc &dl, + SDValue Val, SDValue Ptr, SDValue Mask, + SDValue VLen, EVT MemVT, MachineMemOperand *MMO, + bool IsTruncating) { + assert(Chain.getValueType() == MVT::Other && + "Invalid chain type"); + SDVTList VTs = getVTList(MVT::Other); + SDValue Ops[] = { Chain, Val, Ptr, Mask, VLen }; + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::MSTORE, VTs, Ops); + ID.AddInteger(MemVT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, IsTruncating, MemVT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), VTs, + IsTruncating, MemVT, MMO); + createOperands(N, Ops); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +SDValue SelectionDAG::getGatherVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, + MachineMemOperand *MMO, + ISD::MemIndexType IndexType) { + assert(Ops.size() == 6 && "Incompatible number of operands"); + + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_GATHER, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, VT, MMO, IndexType)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), + VTs, VT, MMO, IndexType); + createOperands(N, Ops); + + assert(N->getMask().getValueType().getVectorNumElements() == + N->getValueType(0).getVectorNumElements() && + "Vector width mismatch between mask and data"); + assert(N->getIndex().getValueType().getVectorNumElements() >= + N->getValueType(0).getVectorNumElements() && + "Vector width mismatch between index and data"); + assert(isa(N->getScale()) && + cast(N->getScale())->getAPIntValue().isPowerOf2() && + "Scale should be a constant power of 2"); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +SDValue SelectionDAG::getScatterVP(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, + MachineMemOperand *MMO, + ISD::MemIndexType IndexType) { + assert(Ops.size() == 7 && "Incompatible number of operands"); + + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::VP_SCATTER, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, VT, MMO, IndexType)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), + VTs, VT, MMO, IndexType); + createOperands(N, Ops); + + 
assert(N->getMask().getValueType().getVectorNumElements() == + N->getValue().getValueType().getVectorNumElements() && + "Vector width mismatch between mask and data"); + assert(N->getIndex().getValueType().getVectorNumElements() >= + N->getValue().getValueType().getVectorNumElements() && + "Vector width mismatch between index and data"); + assert(isa(N->getScale()) && + cast(N->getScale())->getAPIntValue().isPowerOf2() && + "Scale should be a constant power of 2"); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, SDValue Mask, EVT MemVT, MachineMemOperand *MMO, diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h @@ -747,6 +747,12 @@ void visitIntrinsicCall(const CallInst &I, unsigned Intrinsic); void visitTargetIntrinsic(const CallInst &I, unsigned Intrinsic); void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI); + void visitExplicitVectorLengthIntrinsic(const VPIntrinsic &VPI); + void visitCmpVP(const VPIntrinsic &I); + void visitLoadVP(const CallInst &I); + void visitStoreVP(const CallInst &I); + void visitGatherVP(const CallInst &I); + void visitScatterVP(const CallInst &I); void visitVAStart(const CallInst &I); void visitVAArg(const VAArgInst &I); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -4313,6 +4313,46 @@ setValue(&I, StoreNode); } +void SelectionDAGBuilder::visitStoreVP(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + auto getVPStoreOps = [&](Value* &Ptr, Value* &Mask, Value* &Src0, + Value * &VLen, unsigned & Alignment) { + // llvm.masked.store.*(Src0, Ptr, Mask, VLen) + Src0 = I.getArgOperand(0); + Ptr = I.getArgOperand(1); + Alignment = I.getParamAlignment(1); + Mask = I.getArgOperand(2); + VLen = I.getArgOperand(3); + }; + + Value *PtrOperand, *MaskOperand, *Src0Operand, *VLenOperand; + unsigned Alignment = 0; + getVPStoreOps(PtrOperand, MaskOperand, Src0Operand, VLenOperand, Alignment); + + SDValue Ptr = getValue(PtrOperand); + SDValue Src0 = getValue(Src0Operand); + SDValue Mask = getValue(MaskOperand); + SDValue VLen = getValue(VLenOperand); + + EVT VT = Src0.getValueType(); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + + MachineMemOperand *MMO = + DAG.getMachineFunction(). + getMachineMemOperand(MachinePointerInfo(PtrOperand), + MachineMemOperand::MOStore, VT.getStoreSize(), + Alignment, AAInfo); + SDValue StoreNode = DAG.getStoreVP(getRoot(), sdl, Src0, Ptr, Mask, VLen, VT, + MMO, false /* Truncating */); + DAG.setRoot(StoreNode); + setValue(&I, StoreNode); +} + // Get a uniform base for the Gather/Scatter intrinsic. // The first argument of the Gather/Scatter intrinsic is a vector of pointers. // We try to represent it as a base pointer + vector of indices. 
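// Editor's note (descriptive, not part of the patch): the VP memory visitors
// follow the masked load/store recipe, with the explicit vector length passed
// through as one extra SDValue operand.  Operand conventions used by the
// surrounding visitors (names are the locals of those functions):
//
//   visitLoadVP   : (Chain, Ptr, Mask, VLen)                      -> VP_LOAD
//   visitStoreVP  : (Chain, Val, Ptr, Mask, VLen)                 -> VP_STORE
//   visitGatherVP : (Chain, Base, Index, Scale, Mask, VLen)       -> VP_GATHER
//   visitScatterVP: (Chain, Val, Base, Index, Scale, Mask, VLen)  -> VP_SCATTER
//
// matching the operand lists built in SelectionDAG::getLoadVP/getStoreVP and
// the operand counts asserted in getGatherVP/getScatterVP.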
@@ -4547,6 +4587,160 @@ setValue(&I, Gather); } +void SelectionDAGBuilder::visitGatherVP(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + // @llvm.evl.gather.*(Ptrs, Mask, VLen) + const Value *Ptr = I.getArgOperand(0); + SDValue Mask = getValue(I.getArgOperand(1)); + SDValue VLen = getValue(I.getArgOperand(2)); + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); + unsigned Alignment = I.getParamAlignment(0); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); + + SDValue Root = DAG.getRoot(); + SDValue Base; + SDValue Index; + ISD::MemIndexType IndexType; + SDValue Scale; + const Value *BasePtr = Ptr; + bool UniformBase = getUniformBase(BasePtr, Base, Index, IndexType, Scale, this); + bool ConstantMemory = false; + if (UniformBase && AA && + AA->pointsToConstantMemory( + MemoryLocation(BasePtr, + LocationSize::precise( + DAG.getDataLayout().getTypeStoreSize(I.getType())), + AAInfo))) { + // Do not serialize (non-volatile) loads of constant memory with anything. + Root = DAG.getEntryNode(); + ConstantMemory = true; + } + + MachineMemOperand *MMO = + DAG.getMachineFunction(). + getMachineMemOperand(MachinePointerInfo(UniformBase ? BasePtr : nullptr), + MachineMemOperand::MOLoad, VT.getStoreSize(), + Alignment, AAInfo, Ranges); + + if (!UniformBase) { + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Index = getValue(Ptr); + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); + } + SDValue Ops[] = { Root, Base, Index, Scale, Mask, VLen }; + SDValue Gather = DAG.getGatherVP(DAG.getVTList(VT, MVT::Other), VT, sdl, Ops, MMO, IndexType); + + SDValue OutChain = Gather.getValue(1); + if (!ConstantMemory) + PendingLoads.push_back(OutChain); + setValue(&I, Gather); +} + +void SelectionDAGBuilder::visitScatterVP(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + // llvm.evl.scatter.*(Src0, Ptrs, Mask, VLen) + const Value *Ptr = I.getArgOperand(1); + SDValue Src0 = getValue(I.getArgOperand(0)); + SDValue Mask = getValue(I.getArgOperand(2)); + SDValue VLen = getValue(I.getArgOperand(3)); + EVT VT = Src0.getValueType(); + unsigned Alignment = I.getParamAlignment(1); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + + SDValue Base; + SDValue Index; + ISD::MemIndexType IndexType; + SDValue Scale; + const Value *BasePtr = Ptr; + bool UniformBase = getUniformBase(BasePtr, Base, Index, IndexType, Scale, this); + + const Value *MemOpBasePtr = UniformBase ? BasePtr : nullptr; + MachineMemOperand *MMO = DAG.getMachineFunction(). 
+ getMachineMemOperand(MachinePointerInfo(MemOpBasePtr), + MachineMemOperand::MOStore, VT.getStoreSize(), + Alignment, AAInfo); + if (!UniformBase) { + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Index = getValue(Ptr); + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); + } + SDValue Ops[] = { getRoot(), Src0, Base, Index, Scale, Mask, VLen }; + SDValue Scatter = DAG.getScatterVP(DAG.getVTList(MVT::Other), VT, sdl, + Ops, MMO, IndexType); + DAG.setRoot(Scatter); + setValue(&I, Scatter); +} + +void SelectionDAGBuilder::visitLoadVP(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + auto getMaskedLoadOps = [&](Value* &Ptr, Value* &Mask, Value* &VLen, + unsigned& Alignment) { + // @llvm.evl.load.*(Ptr, Mask, Vlen) + Ptr = I.getArgOperand(0); + Alignment = I.getParamAlignment(0); + Mask = I.getArgOperand(1); + VLen = I.getArgOperand(2); + }; + + Value *PtrOperand, *MaskOperand, *VLenOperand; + unsigned Alignment; + getMaskedLoadOps(PtrOperand, MaskOperand, VLenOperand, Alignment); + + SDValue Ptr = getValue(PtrOperand); + SDValue VLen = getValue(VLenOperand); + SDValue Mask = getValue(MaskOperand); + + // infer the return type + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + SmallVector ValValueVTs; + ComputeValueVTs(TLI, DAG.getDataLayout(), I.getType(), ValValueVTs); + EVT VT = ValValueVTs[0]; + assert((ValValueVTs.size() == 1) && "splitting not implemented"); + + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); + + // Do not serialize masked loads of constant memory with anything. + bool AddToChain = + !AA || !AA->pointsToConstantMemory(MemoryLocation( + PtrOperand, + LocationSize::precise( + DAG.getDataLayout().getTypeStoreSize(I.getType())), + AAInfo)); + SDValue InChain = AddToChain ? DAG.getRoot() : DAG.getEntryNode(); + + MachineMemOperand *MMO = + DAG.getMachineFunction(). 
+ getMachineMemOperand(MachinePointerInfo(PtrOperand), + MachineMemOperand::MOLoad, VT.getStoreSize(), + Alignment, AAInfo, Ranges); + + SDValue Load = DAG.getLoadVP(VT, sdl, InChain, Ptr, Mask, VLen, VT, MMO, + ISD::NON_EXTLOAD); + if (AddToChain) + PendingLoads.push_back(Load.getValue(1)); + setValue(&I, Load); +} + void SelectionDAGBuilder::visitAtomicCmpXchg(const AtomicCmpXchgInst &I) { SDLoc dl = getCurSDLoc(); AtomicOrdering SuccessOrdering = I.getSuccessOrdering(); @@ -6140,6 +6334,64 @@ case Intrinsic::experimental_constrained_trunc: visitConstrainedFPIntrinsic(cast(I)); return; + + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + + case Intrinsic::vp_select: + case Intrinsic::vp_compose: + case Intrinsic::vp_compress: + case Intrinsic::vp_expand: + case Intrinsic::vp_vshift: + + case Intrinsic::vp_load: + case Intrinsic::vp_store: + case Intrinsic::vp_gather: + case Intrinsic::vp_scatter: + + case Intrinsic::vp_fneg: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + + case Intrinsic::vp_fma: + + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + + case Intrinsic::vp_fcmp: + case Intrinsic::vp_icmp: + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmax: + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmul: + + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_umax: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_smin: + visitExplicitVectorLengthIntrinsic(cast(I)); + return; + case Intrinsic::fmuladd: { EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); if (TM.Options.AllowFPOpFusion != FPOpFusion::Strict && @@ -7006,8 +7258,7 @@ { Chain, getValue(FPI.getArgOperand(0)), getValue(FPI.getArgOperand(1)) }); - if (FPI.getExceptionBehavior() != - ConstrainedFPIntrinsic::ExceptionBehavior::ebIgnore) { + if (FPI.getExceptionBehavior() != ExceptionBehavior::ebIgnore) { SDNodeFlags Flags; Flags.setFPExcept(true); Result->setFlags(Flags); @@ -7020,6 +7271,139 @@ setValue(&FPI, FPResult); } +void SelectionDAGBuilder::visitCmpVP(const VPIntrinsic &I) { + ISD::CondCode Condition; + CmpInst::Predicate predicate = I.getCmpPredicate(); + bool IsFP = I.getOperand(0)->getType()->isFPOrFPVectorTy(); + if (IsFP) { + Condition = getFCmpCondCode(predicate); + auto *FPMO = dyn_cast(&I); + if ((FPMO && FPMO->hasNoNaNs()) || TM.Options.NoNaNsFPMath) + Condition = getFCmpCodeWithoutNaN(Condition); + + } else { + Condition = getICmpCondCode(predicate); + } + + SDValue Op1 = getValue(I.getOperand(0)); + SDValue Op2 = getValue(I.getOperand(1)); + + EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(), + I.getType()); + setValue(&I, DAG.getSetCC(getCurSDLoc(), DestVT, Op1, Op2, Condition)); +} + +void SelectionDAGBuilder::visitExplicitVectorLengthIntrinsic( + const VPIntrinsic & VPIntrin) { + SDLoc sdl = getCurSDLoc(); + unsigned Opcode; + switch (VPIntrin.getIntrinsicID()) { + default: + llvm_unreachable("Unforeseen intrinsic"); // Can't reach here. 
+ + case Intrinsic::vp_load: visitLoadVP(VPIntrin); return; + case Intrinsic::vp_store: visitStoreVP(VPIntrin); return; + case Intrinsic::vp_gather: visitGatherVP(VPIntrin); return; + case Intrinsic::vp_scatter: visitScatterVP(VPIntrin); return; + + case Intrinsic::vp_fcmp: + case Intrinsic::vp_icmp: visitCmpVP(VPIntrin); return; + + case Intrinsic::vp_add: Opcode = ISD::VP_ADD; break; + case Intrinsic::vp_sub: Opcode = ISD::VP_SUB; break; + case Intrinsic::vp_mul: Opcode = ISD::VP_MUL; break; + case Intrinsic::vp_udiv: Opcode = ISD::VP_UDIV; break; + case Intrinsic::vp_sdiv: Opcode = ISD::VP_SDIV; break; + case Intrinsic::vp_urem: Opcode = ISD::VP_UREM; break; + case Intrinsic::vp_srem: Opcode = ISD::VP_SREM; break; + + case Intrinsic::vp_and: Opcode = ISD::VP_AND; break; + case Intrinsic::vp_or: Opcode = ISD::VP_OR; break; + case Intrinsic::vp_xor: Opcode = ISD::VP_XOR; break; + case Intrinsic::vp_ashr: Opcode = ISD::VP_SRA; break; + case Intrinsic::vp_lshr: Opcode = ISD::VP_SRL; break; + case Intrinsic::vp_shl: Opcode = ISD::VP_SHL; break; + + case Intrinsic::vp_fneg: Opcode = ISD::VP_FNEG; break; + case Intrinsic::vp_fadd: Opcode = ISD::VP_FADD; break; + case Intrinsic::vp_fsub: Opcode = ISD::VP_FSUB; break; + case Intrinsic::vp_fmul: Opcode = ISD::VP_FMUL; break; + case Intrinsic::vp_fdiv: Opcode = ISD::VP_FDIV; break; + case Intrinsic::vp_frem: Opcode = ISD::VP_FREM; break; + + case Intrinsic::vp_fma: Opcode = ISD::VP_FMA; break; + + case Intrinsic::vp_select: Opcode = ISD::VP_SELECT; break; + case Intrinsic::vp_compose: Opcode = ISD::VP_COMPOSE; break; + case Intrinsic::vp_compress: Opcode = ISD::VP_COMPRESS; break; + case Intrinsic::vp_expand: Opcode = ISD::VP_EXPAND; break; + case Intrinsic::vp_vshift: Opcode = ISD::VP_VSHIFT; break; + + case Intrinsic::vp_reduce_and: Opcode = ISD::VP_REDUCE_AND; break; + case Intrinsic::vp_reduce_or: Opcode = ISD::VP_REDUCE_OR; break; + case Intrinsic::vp_reduce_xor: Opcode = ISD::VP_REDUCE_XOR; break; + case Intrinsic::vp_reduce_add: Opcode = ISD::VP_REDUCE_ADD; break; + case Intrinsic::vp_reduce_mul: Opcode = ISD::VP_REDUCE_MUL; break; + case Intrinsic::vp_reduce_fadd: Opcode = ISD::VP_REDUCE_FADD; break; + case Intrinsic::vp_reduce_fmul: Opcode = ISD::VP_REDUCE_FMUL; break; + case Intrinsic::vp_reduce_smax: Opcode = ISD::VP_REDUCE_SMAX; break; + case Intrinsic::vp_reduce_smin: Opcode = ISD::VP_REDUCE_SMIN; break; + case Intrinsic::vp_reduce_umax: Opcode = ISD::VP_REDUCE_UMAX; break; + case Intrinsic::vp_reduce_umin: Opcode = ISD::VP_REDUCE_UMIN; break; + } + + // TODO memory evl: SDValue Chain = getRoot(); + + SmallVector ValueVTs; + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + ComputeValueVTs(TLI, DAG.getDataLayout(), VPIntrin.getType(), ValueVTs); + SDVTList VTs = DAG.getVTList(ValueVTs); + + // ValueVTs.push_back(MVT::Other); // Out chain + + + SDValue Result; + + switch (VPIntrin.getNumArgOperands()) { + default: + llvm_unreachable("unexpected number of arguments to evl intrinsic"); + case 3: + Result = DAG.getNode(Opcode, sdl, VTs, + { getValue(VPIntrin.getArgOperand(0)), + getValue(VPIntrin.getArgOperand(1)), + getValue(VPIntrin.getArgOperand(2)) }); + break; + + case 4: + Result = DAG.getNode(Opcode, sdl, VTs, + { getValue(VPIntrin.getArgOperand(0)), + getValue(VPIntrin.getArgOperand(1)), + getValue(VPIntrin.getArgOperand(2)), + getValue(VPIntrin.getArgOperand(3)) }); + break; + + case 5: + Result = DAG.getNode(Opcode, sdl, VTs, + { getValue(VPIntrin.getArgOperand(0)), + getValue(VPIntrin.getArgOperand(1)), + 
getValue(VPIntrin.getArgOperand(2)), + getValue(VPIntrin.getArgOperand(3)), + getValue(VPIntrin.getArgOperand(4)) }); + break; + } + + if (Result.getNode()->getNumValues() == 2) { + // this VP node has a chain + SDValue OutChain = Result.getValue(1); + DAG.setRoot(OutChain); + SDValue VPResult = Result.getValue(0); + setValue(&VPIntrin, VPResult); + } else { + // this is a pure node + setValue(&VPIntrin, Result); + } +} + std::pair SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI, const BasicBlock *EHPadBB) { diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -439,6 +439,65 @@ case ISD::VECREDUCE_UMIN: return "vecreduce_umin"; case ISD::VECREDUCE_FMAX: return "vecreduce_fmax"; case ISD::VECREDUCE_FMIN: return "vecreduce_fmin"; + + // Explicit Vector Length erxtension + // VP Memory + case ISD::VP_LOAD: return "vp_load"; + case ISD::VP_STORE: return "vp_store"; + case ISD::VP_GATHER: return "vp_gather"; + case ISD::VP_SCATTER: return "vp_scatter"; + + // VP Unary operators + case ISD::VP_FNEG: return "vp_fneg"; + + // VP Binary operators + case ISD::VP_ADD: return "vp_add"; + case ISD::VP_SUB: return "vp_sub"; + case ISD::VP_MUL: return "vp_mul"; + case ISD::VP_SDIV: return "vp_sdiv"; + case ISD::VP_UDIV: return "vp_udiv"; + case ISD::VP_SREM: return "vp_srem"; + case ISD::VP_UREM: return "vp_urem"; + case ISD::VP_AND: return "vp_and"; + case ISD::VP_OR: return "vp_or"; + case ISD::VP_XOR: return "vp_xor"; + case ISD::VP_SHL: return "vp_shl"; + case ISD::VP_SRA: return "vp_sra"; + case ISD::VP_SRL: return "vp_srl"; + case ISD::VP_FADD: return "vp_fadd"; + case ISD::VP_FSUB: return "vp_fsub"; + case ISD::VP_FMUL: return "vp_fmul"; + case ISD::VP_FDIV: return "vp_fdiv"; + case ISD::VP_FREM: return "vp_frem"; + + // VP comparison + case ISD::VP_SETCC: return "vp_setcc"; + + // VP ternary operators + case ISD::VP_FMA: return "vp_fma"; + + // VP shuffle + case ISD::VP_VSHIFT: return "vp_vshift"; + case ISD::VP_COMPRESS: return "vp_compress"; + case ISD::VP_EXPAND: return "vp_expand"; + + case ISD::VP_COMPOSE: return "vp_compose"; + case ISD::VP_SELECT: return "vp_select"; + + // VP reduction operators + case ISD::VP_REDUCE_FADD: return "vp_reduce_fadd"; + case ISD::VP_REDUCE_FMUL: return "vp_reduce_fmul"; + case ISD::VP_REDUCE_ADD: return "vp_reduce_add"; + case ISD::VP_REDUCE_MUL: return "vp_reduce_mul"; + case ISD::VP_REDUCE_AND: return "vp_reduce_and"; + case ISD::VP_REDUCE_OR: return "vp_reduce_or"; + case ISD::VP_REDUCE_XOR: return "vp_reduce_xor"; + case ISD::VP_REDUCE_SMAX: return "vp_reduce_smax"; + case ISD::VP_REDUCE_SMIN: return "vp_reduce_smin"; + case ISD::VP_REDUCE_UMAX: return "vp_reduce_umax"; + case ISD::VP_REDUCE_UMIN: return "vp_reduce_umin"; + case ISD::VP_REDUCE_FMAX: return "vp_reduce_fmax"; + case ISD::VP_REDUCE_FMIN: return "vp_reduce_fmin"; } } diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -816,6 +816,10 @@ CurDAG->Combine(BeforeLegalizeTypes, AA, OptLevel); } + if (getenv("SDEBUG")) { + CurDAG->dump(); + } + #ifndef NDEBUG if (TTI.hasBranchDivergence()) CurDAG->VerifyDAGDiverence(); diff --git a/llvm/lib/IR/Attributes.cpp b/llvm/lib/IR/Attributes.cpp --- 
a/llvm/lib/IR/Attributes.cpp +++ b/llvm/lib/IR/Attributes.cpp @@ -290,6 +290,8 @@ return "builtin"; if (hasAttribute(Attribute::Convergent)) return "convergent"; + if (hasAttribute(Attribute::VectorLength)) + return "vlen"; if (hasAttribute(Attribute::SwiftError)) return "swifterror"; if (hasAttribute(Attribute::SwiftSelf)) @@ -306,6 +308,10 @@ return "inreg"; if (hasAttribute(Attribute::JumpTable)) return "jumptable"; + if (hasAttribute(Attribute::Mask)) + return "mask"; + if (hasAttribute(Attribute::Passthru)) + return "passthru"; if (hasAttribute(Attribute::MinSize)) return "minsize"; if (hasAttribute(Attribute::Naked)) diff --git a/llvm/lib/IR/CMakeLists.txt b/llvm/lib/IR/CMakeLists.txt --- a/llvm/lib/IR/CMakeLists.txt +++ b/llvm/lib/IR/CMakeLists.txt @@ -46,18 +46,19 @@ PassManager.cpp PassRegistry.cpp PassTimingInfo.cpp + PredicatedInst.cpp + ProfileSummary.cpp RemarkStreamer.cpp SafepointIRVerifier.cpp - ProfileSummary.cpp Statepoint.cpp Type.cpp TypeFinder.cpp Use.cpp User.cpp + VPBuilder.cpp Value.cpp ValueSymbolTable.cpp Verifier.cpp - ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/IR diff --git a/llvm/lib/IR/IntrinsicInst.cpp b/llvm/lib/IR/IntrinsicInst.cpp --- a/llvm/lib/IR/IntrinsicInst.cpp +++ b/llvm/lib/IR/IntrinsicInst.cpp @@ -102,55 +102,108 @@ return ConstantInt::get(Type::getInt64Ty(Context), 1); } -Optional -ConstrainedFPIntrinsic::getRoundingMode() const { - unsigned NumOperands = getNumArgOperands(); - Metadata *MD = - cast(getArgOperand(NumOperands - 2))->getMetadata(); - if (!MD || !isa(MD)) - return None; - return StrToRoundingMode(cast(MD)->getString()); +Optional +llvm::StrToExceptionBehavior(StringRef ExceptionArg) { + return StringSwitch>(ExceptionArg) + .Case("fpexcept.ignore", ExceptionBehavior::ebIgnore) + .Case("fpexcept.maytrap", ExceptionBehavior::ebMayTrap) + .Case("fpexcept.strict", ExceptionBehavior::ebStrict) + .Default(None); +} + +Optional +llvm::ExceptionBehaviorToStr(ExceptionBehavior UseExcept) { + Optional ExceptStr = None; + switch (UseExcept) { + default: break; + case ExceptionBehavior::ebStrict: + ExceptStr = "fpexcept.strict"; + break; + case ExceptionBehavior::ebIgnore: + ExceptStr = "fpexcept.ignore"; + break; + case ExceptionBehavior::ebMayTrap: + ExceptStr = "fpexcept.maytrap"; + break; + } + return ExceptStr; } -Optional -ConstrainedFPIntrinsic::StrToRoundingMode(StringRef RoundingArg) { +Optional +llvm::StrToRoundingMode(StringRef RoundingArg) { // For dynamic rounding mode, we use round to nearest but we will set the // 'exact' SDNodeFlag so that the value will not be rounded. 
return StringSwitch>(RoundingArg) - .Case("round.dynamic", rmDynamic) - .Case("round.tonearest", rmToNearest) - .Case("round.downward", rmDownward) - .Case("round.upward", rmUpward) - .Case("round.towardzero", rmTowardZero) + .Case("round.dynamic", RoundingMode::rmDynamic) + .Case("round.tonearest", RoundingMode::rmToNearest) + .Case("round.downward", RoundingMode::rmDownward) + .Case("round.upward", RoundingMode::rmUpward) + .Case("round.towardzero", RoundingMode::rmTowardZero) .Default(None); } Optional -ConstrainedFPIntrinsic::RoundingModeToStr(RoundingMode UseRounding) { +llvm::RoundingModeToStr(RoundingMode UseRounding) { Optional RoundingStr = None; switch (UseRounding) { - case ConstrainedFPIntrinsic::rmDynamic: + default: break; + case RoundingMode::rmDynamic: RoundingStr = "round.dynamic"; break; - case ConstrainedFPIntrinsic::rmToNearest: + case RoundingMode::rmToNearest: RoundingStr = "round.tonearest"; break; - case ConstrainedFPIntrinsic::rmDownward: + case RoundingMode::rmDownward: RoundingStr = "round.downward"; break; - case ConstrainedFPIntrinsic::rmUpward: + case RoundingMode::rmUpward: RoundingStr = "round.upward"; break; - case ConstrainedFPIntrinsic::rmTowardZero: + case RoundingMode::rmTowardZero: RoundingStr = "round.towardzero"; break; } return RoundingStr; } -Optional +/// Return the IR Value representation of any ExceptionBehavior. +Value* +llvm::GetConstrainedFPExcept(LLVMContext& Context, ExceptionBehavior UseExcept) { + Optional ExceptStr = + ExceptionBehaviorToStr(UseExcept); + assert(ExceptStr.hasValue() && "Garbage strict exception behavior!"); + auto *ExceptMDS = MDString::get(Context, ExceptStr.getValue()); + + return MetadataAsValue::get(Context, ExceptMDS); +} + +/// Return the IR Value representation of any RoundingMode. 
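// Editor's sketch (not part of the patch): intended use of the helper above
// and of its rounding-mode counterpart defined next, when a VP call is built
// for the default FP environment; this mirrors what
// PredicatedBinaryOperator::Create does later in this patch.  M and Args are
// placeholders; the metadata strings are the same ones accepted by the
// llvm.experimental.constrained.* intrinsics.
//
//   LLVMContext &Ctx = M.getContext();
//   Args.push_back(GetConstrainedFPRounding(Ctx, RoundingMode::rmToNearest));
//   Args.push_back(GetConstrainedFPExcept(Ctx, ExceptionBehavior::ebIgnore));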
+Value* +llvm::GetConstrainedFPRounding(LLVMContext& Context, RoundingMode UseRounding) { + Optional RoundingStr = + RoundingModeToStr(UseRounding); + assert(RoundingStr.hasValue() && "Garbage strict rounding mode!"); + auto *RoundingMDS = MDString::get(Context, RoundingStr.getValue()); + + return MetadataAsValue::get(Context, RoundingMDS); +} + + +Optional +ConstrainedFPIntrinsic::getRoundingMode() const { + unsigned NumOperands = getNumArgOperands(); + assert(NumOperands >= 2 && "underflow"); + Metadata *MD = + cast(getArgOperand(NumOperands - 2))->getMetadata(); + if (!MD || !isa(MD)) + return None; + return StrToRoundingMode(cast(MD)->getString()); +} + +Optional ConstrainedFPIntrinsic::getExceptionBehavior() const { unsigned NumOperands = getNumArgOperands(); + assert(NumOperands >= 1 && "underflow"); Metadata *MD = cast(getArgOperand(NumOperands - 1))->getMetadata(); if (!MD || !isa(MD)) @@ -158,32 +211,617 @@ return StrToExceptionBehavior(cast(MD)->getString()); } -Optional -ConstrainedFPIntrinsic::StrToExceptionBehavior(StringRef ExceptionArg) { - return StringSwitch>(ExceptionArg) - .Case("fpexcept.ignore", ebIgnore) - .Case("fpexcept.maytrap", ebMayTrap) - .Case("fpexcept.strict", ebStrict) - .Default(None); + +CmpInst::Predicate +VPIntrinsic::getCmpPredicate() const { + return static_cast(cast(getArgOperand(4))->getZExtValue()); } -Optional -ConstrainedFPIntrinsic::ExceptionBehaviorToStr(ExceptionBehavior UseExcept) { - Optional ExceptStr = None; - switch (UseExcept) { - case ConstrainedFPIntrinsic::ebStrict: - ExceptStr = "fpexcept.strict"; - break; - case ConstrainedFPIntrinsic::ebIgnore: - ExceptStr = "fpexcept.ignore"; - break; - case ConstrainedFPIntrinsic::ebMayTrap: - ExceptStr = "fpexcept.maytrap"; - break; +Optional +VPIntrinsic::getRoundingMode() const { + auto RmParamPos = getRoundingModeParamPos(getIntrinsicID()); + if (!RmParamPos) return None; + + Metadata *MD = + dyn_cast(getArgOperand(RmParamPos.getValue()))->getMetadata(); + if (!MD || !isa(MD)) + return None; + StringRef RoundingArg = cast(MD)->getString(); + return StrToRoundingMode(RoundingArg); +} + +Optional +VPIntrinsic::getExceptionBehavior() const { + auto EbParamPos = getExceptionBehaviorParamPos(getIntrinsicID()); + if (!EbParamPos) return None; + + Metadata *MD = + dyn_cast(getArgOperand(EbParamPos.getValue()))->getMetadata(); + if (!MD || !isa(MD)) + return None; + StringRef ExceptionArg = cast(MD)->getString(); + return StrToExceptionBehavior(ExceptionArg); +} + +bool +VPIntrinsic::hasRoundingModeParam(Intrinsic::ID VPID) { + switch (VPID) { + default: + return false; + + case Intrinsic::vp_ceil: + case Intrinsic::vp_cos: + case Intrinsic::vp_exp2: + case Intrinsic::vp_exp: + case Intrinsic::vp_fadd: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_floor: + case Intrinsic::vp_fma: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fneg: + case Intrinsic::vp_frem: + case Intrinsic::vp_fsub: + case Intrinsic::vp_llround: + case Intrinsic::vp_log10: + case Intrinsic::vp_log2: + case Intrinsic::vp_log: + case Intrinsic::vp_lround: + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + case Intrinsic::vp_nearbyint: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_lrint: + case Intrinsic::vp_llrint: + case Intrinsic::vp_rint: + case Intrinsic::vp_round: + case Intrinsic::vp_sin: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_trunc: + return true; + } +} + +bool +VPIntrinsic::hasExceptionBehaviorParam(Intrinsic::ID VPID) { + switch (VPID) { + default: + case Intrinsic::vp_fpext: 
+ case Intrinsic::vp_fptosi: + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptrunc: + case Intrinsic::vp_llround: + case Intrinsic::vp_lround: + return false; + + case Intrinsic::vp_ceil: + case Intrinsic::vp_cos: + case Intrinsic::vp_exp2: + case Intrinsic::vp_exp: + case Intrinsic::vp_fadd: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_floor: + case Intrinsic::vp_fma: + case Intrinsic::vp_fmul: + case Intrinsic::vp_frem: + case Intrinsic::vp_fsub: + case Intrinsic::vp_log10: + case Intrinsic::vp_log2: + case Intrinsic::vp_log: + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + case Intrinsic::vp_nearbyint: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_lrint: + case Intrinsic::vp_llrint: + case Intrinsic::vp_rint: + case Intrinsic::vp_round: + case Intrinsic::vp_sin: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_trunc: + return true; + } +} + +Value* +VPIntrinsic::getMask() const { + auto maskPos = getMaskParamPos(getIntrinsicID()); + if (maskPos) return getArgOperand(maskPos.getValue()); + return nullptr; +} + +Value* +VPIntrinsic::getVectorLength() const { + auto vlenPos = getVectorLengthParamPos(getIntrinsicID()); + if (vlenPos) return getArgOperand(vlenPos.getValue()); + return nullptr; +} + +VPIntrinsic::TypeTokenVec +VPIntrinsic::GetTypeTokens(Intrinsic::ID ID) { + switch (ID) { + default: + return TypeTokenVec(); + + case Intrinsic::vp_cos: + case Intrinsic::vp_sin: + case Intrinsic::vp_exp: + case Intrinsic::vp_exp2: + + case Intrinsic::vp_log: + case Intrinsic::vp_log2: + case Intrinsic::vp_log10: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_ceil: + case Intrinsic::vp_floor: + case Intrinsic::vp_round: + case Intrinsic::vp_trunc: + case Intrinsic::vp_rint: + case Intrinsic::vp_nearbyint: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + return TypeTokenVec{VPTypeToken::Returned}; + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + case Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + return TypeTokenVec{VPTypeToken::Vector}; + + case Intrinsic::vp_fpext: + case Intrinsic::vp_fptrunc: + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptosi: + case Intrinsic::vp_lround: + case Intrinsic::vp_llround: + case Intrinsic::vp_lrint: + case Intrinsic::vp_llrint: + return TypeTokenVec{VPTypeToken::Returned, VPTypeToken::Vector}; + + case Intrinsic::vp_icmp: + case Intrinsic::vp_fcmp: + return TypeTokenVec{VPTypeToken::Mask, VPTypeToken::Vector}; } - return ExceptStr; } +bool VPIntrinsic::isReductionOp() const { + switch (getIntrinsicID()) { + default: + return false; + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + case Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + + return 
true; + } +} + +bool VPIntrinsic::isConstrainedOp() const { + return (getRoundingMode() != None && getRoundingMode() != RoundingMode::rmToNearest) || + (getExceptionBehavior() != None && getExceptionBehavior() != ExceptionBehavior::ebIgnore); +} + +bool VPIntrinsic::isBinaryOp() const { + switch (getIntrinsicID()) { + default: + return false; + + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_umax: + case Intrinsic::vp_reduce_umin: + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + case Intrinsic::vp_reduce_fmax: + case Intrinsic::vp_reduce_fmin: + + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + return true; + } +} + +bool VPIntrinsic::isTernaryOp() const { + switch (getIntrinsicID()) { + default: + return false; + case Intrinsic::vp_compose: + case Intrinsic::vp_select: + case Intrinsic::vp_fma: + return true; + } +} + +Optional +VPIntrinsic::getMaskParamPos(Intrinsic::ID IntrinsicID) { + switch (IntrinsicID) { + default: return None; + + // general cmp + case Intrinsic::vp_icmp: + case Intrinsic::vp_fcmp: + return 2; + + // int arith + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + return 2; + + // memory + case Intrinsic::vp_load: + case Intrinsic::vp_gather: + return 1; + case Intrinsic::vp_store: + case Intrinsic::vp_scatter: + return 2; + + // shuffle + case Intrinsic::vp_select: + return 0; + + case Intrinsic::vp_compose: + return None; + + case Intrinsic::vp_compress: + case Intrinsic::vp_expand: + case Intrinsic::vp_vshift: + return 2; + + // fp arith + case Intrinsic::vp_fneg: + return 1; + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + return 4; + + case Intrinsic::vp_fma: + return 5; + + case Intrinsic::vp_ceil: + case Intrinsic::vp_cos: + case Intrinsic::vp_exp2: + case Intrinsic::vp_exp: + case Intrinsic::vp_floor: + case Intrinsic::vp_log10: + case Intrinsic::vp_log2: + case Intrinsic::vp_log: + return 3; + + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + return 4; + + case Intrinsic::vp_nearbyint: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_rint: + case Intrinsic::vp_round: + case Intrinsic::vp_sin: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_trunc: + return 3; + + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptosi: + case Intrinsic::vp_lround: + case Intrinsic::vp_llround: + return 2; + + case Intrinsic::vp_fpext: + case Intrinsic::vp_fptrunc: + return 3; + + // reductions + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + case 
Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + return 2; + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + return 3; + } +} + +Optional +VPIntrinsic::getVectorLengthParamPos(Intrinsic::ID IntrinsicID) { + auto maskPos = getMaskParamPos(IntrinsicID); + if (maskPos) { + return maskPos.getValue() + 1; + } + + if (IntrinsicID == Intrinsic::vp_compose) { + return 3; + } + + return None; +} + +Optional +VPIntrinsic::getExceptionBehaviorParamPos(Intrinsic::ID IntrinsicID) { + switch (IntrinsicID) { + default: + return None; + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + return 3; + + case Intrinsic::vp_fma: + return 4; + + case Intrinsic::vp_ceil: + case Intrinsic::vp_cos: + case Intrinsic::vp_exp2: + case Intrinsic::vp_exp: + case Intrinsic::vp_floor: + case Intrinsic::vp_log10: + case Intrinsic::vp_log2: + case Intrinsic::vp_log: + return 2; + + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + return 3; + case Intrinsic::vp_nearbyint: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_rint: + case Intrinsic::vp_round: + case Intrinsic::vp_sin: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_trunc: + return 2; + + case Intrinsic::vp_fpext: + case Intrinsic::vp_fptrunc: + return 2; + + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptosi: + case Intrinsic::vp_lround: + case Intrinsic::vp_llround: + return 1; + } +} + +Optional +VPIntrinsic::getRoundingModeParamPos(Intrinsic::ID IntrinsicID) { + switch (IntrinsicID) { + default: + return None; + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + return 2; + + case Intrinsic::vp_fma: + return 3; + + case Intrinsic::vp_ceil: + case Intrinsic::vp_cos: + case Intrinsic::vp_exp2: + case Intrinsic::vp_exp: + case Intrinsic::vp_floor: + case Intrinsic::vp_log10: + case Intrinsic::vp_log2: + case Intrinsic::vp_log: + return 1; + + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + return 2; + case Intrinsic::vp_nearbyint: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_rint: + case Intrinsic::vp_round: + case Intrinsic::vp_sin: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_trunc: + return 1; + + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptosi: + case Intrinsic::vp_lround: + case Intrinsic::vp_llround: + return None; + + case Intrinsic::vp_fpext: + case Intrinsic::vp_fptrunc: + return 2; + } +} + +Intrinsic::ID +VPIntrinsic::getForConstrainedIntrinsic(Intrinsic::ID IntrinsicID) { + switch (IntrinsicID) { + default: + return Intrinsic::not_intrinsic; + + // llvm.experimental.constrained.* + case Intrinsic::experimental_constrained_cos: return Intrinsic::vp_cos; + case Intrinsic::experimental_constrained_sin: return Intrinsic::vp_sin; + case Intrinsic::experimental_constrained_exp: return Intrinsic::vp_exp; + case Intrinsic::experimental_constrained_exp2: return Intrinsic::vp_exp2; + case Intrinsic::experimental_constrained_log: return Intrinsic::vp_log; + case Intrinsic::experimental_constrained_log2: return Intrinsic::vp_log2; + case Intrinsic::experimental_constrained_log10: return Intrinsic::vp_log10; + case Intrinsic::experimental_constrained_sqrt: return Intrinsic::vp_sqrt; + case 
Intrinsic::experimental_constrained_ceil: return Intrinsic::vp_ceil; + case Intrinsic::experimental_constrained_floor: return Intrinsic::vp_floor; + case Intrinsic::experimental_constrained_round: return Intrinsic::vp_round; + case Intrinsic::experimental_constrained_trunc: return Intrinsic::vp_trunc; + case Intrinsic::experimental_constrained_rint: return Intrinsic::vp_rint; + case Intrinsic::experimental_constrained_nearbyint: return Intrinsic::vp_nearbyint; + + case Intrinsic::experimental_constrained_fadd: return Intrinsic::vp_fadd; + case Intrinsic::experimental_constrained_fsub: return Intrinsic::vp_fsub; + case Intrinsic::experimental_constrained_fmul: return Intrinsic::vp_fmul; + case Intrinsic::experimental_constrained_fdiv: return Intrinsic::vp_fdiv; + case Intrinsic::experimental_constrained_frem: return Intrinsic::vp_frem; + case Intrinsic::experimental_constrained_pow: return Intrinsic::vp_pow; + case Intrinsic::experimental_constrained_powi: return Intrinsic::vp_powi; + case Intrinsic::experimental_constrained_maxnum: return Intrinsic::vp_maxnum; + case Intrinsic::experimental_constrained_minnum: return Intrinsic::vp_minnum; + + case Intrinsic::experimental_constrained_fma: return Intrinsic::vp_fma; + } +} + +Intrinsic::ID +VPIntrinsic::getForOpcode(unsigned OC) { + switch (OC) { + default: + return Intrinsic::not_intrinsic; + + // fp unary + case Instruction::FNeg: return Intrinsic::vp_fneg; + + // fp binary + case Instruction::FAdd: return Intrinsic::vp_fadd; + case Instruction::FSub: return Intrinsic::vp_fsub; + case Instruction::FMul: return Intrinsic::vp_fmul; + case Instruction::FDiv: return Intrinsic::vp_fdiv; + case Instruction::FRem: return Intrinsic::vp_frem; + + // sign-oblivious int + case Instruction::Add: return Intrinsic::vp_add; + case Instruction::Sub: return Intrinsic::vp_sub; + case Instruction::Mul: return Intrinsic::vp_mul; + + // signed/unsigned int + case Instruction::SDiv: return Intrinsic::vp_sdiv; + case Instruction::UDiv: return Intrinsic::vp_udiv; + case Instruction::SRem: return Intrinsic::vp_srem; + case Instruction::URem: return Intrinsic::vp_urem; + + // logical + case Instruction::Or: return Intrinsic::vp_or; + case Instruction::And: return Intrinsic::vp_and; + case Instruction::Xor: return Intrinsic::vp_xor; + + case Instruction::LShr: return Intrinsic::vp_lshr; + case Instruction::AShr: return Intrinsic::vp_ashr; + case Instruction::Shl: return Intrinsic::vp_shl; + + // comparison + case Instruction::ICmp: return Intrinsic::vp_icmp; + case Instruction::FCmp: return Intrinsic::vp_fcmp; + } +} + +VPIntrinsic::ShortTypeVec +VPIntrinsic::EncodeTypeTokens(VPIntrinsic::TypeTokenVec TTVec, Type * VecRetTy, Type & VectorTy, Type & ScalarTy) { + ShortTypeVec STV; + + for (auto Token : TTVec) { + switch (Token) { + default: + llvm_unreachable("unsupported token"); // unsupported VPTypeToken + + case VPIntrinsic::VPTypeToken::Scalar: STV.push_back(&ScalarTy); break; + case VPIntrinsic::VPTypeToken::Vector: STV.push_back(&VectorTy); break; + case VPIntrinsic::VPTypeToken::Returned: + assert(VecRetTy); + STV.push_back(VecRetTy); + break; + case VPIntrinsic::VPTypeToken::Mask: + auto NumElems = VectorTy.getVectorNumElements(); + auto MaskTy = VectorType::get(Type::getInt1Ty(VectorTy.getContext()), NumElems); + STV.push_back(MaskTy); break; + } + } + + return STV; +} + + bool ConstrainedFPIntrinsic::isUnaryOp() const { switch (getIntrinsicID()) { default: diff --git a/llvm/lib/IR/PredicatedInst.cpp b/llvm/lib/IR/PredicatedInst.cpp new file mode 100644 ---
+++ b/llvm/lib/IR/PredicatedInst.cpp
@@ -0,0 +1,75 @@
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace {
+  using namespace llvm;
+  using ShortValueVec = SmallVector;
+}
+
+namespace llvm {
+
+void
+PredicatedOperator::copyIRFlags(const Value * V, bool IncludeWrapFlags) {
+  auto * I = dyn_cast(this);
+  if (I) I->copyIRFlags(V, IncludeWrapFlags);
+}
+
+Instruction*
+PredicatedBinaryOperator::Create(Module * Mod,
+                                 Value *Mask, Value *VectorLen,
+                                 Instruction::BinaryOps Opc,
+                                 Value *V1, Value *V2,
+                                 const Twine &Name,
+                                 BasicBlock * InsertAtEnd,
+                                 Instruction * InsertBefore) {
+  assert(!(InsertAtEnd && InsertBefore));
+  auto VPID = VPIntrinsic::getForOpcode(Opc);
+
+  // Default Code Path
+  if ((!Mod ||
+       (!Mask && !VectorLen)) ||
+      VPID == Intrinsic::not_intrinsic) {
+    if (InsertAtEnd) {
+      return BinaryOperator::Create(Opc, V1, V2, Name, InsertAtEnd);
+    } else {
+      return BinaryOperator::Create(Opc, V1, V2, Name, InsertBefore);
+    }
+  }
+
+  assert(Mod && "Need a module to emit VP Intrinsics");
+
+  // Fetch the VP intrinsic
+  auto & VecTy = cast(*V1->getType());
+  auto & ScalarTy = *VecTy.getVectorElementType();
+  auto TypeTokens = VPIntrinsic::GetTypeTokens(VPID);
+  auto * VPFunc = Intrinsic::getDeclaration(Mod, VPID, VPIntrinsic::EncodeTypeTokens(TypeTokens, &VecTy, VecTy, ScalarTy));
+
+  // Encode default environment fp behavior
+  LLVMContext & Ctx = V1->getContext();
+  SmallVector BinOpArgs({V1, V2});
+  if (VPIntrinsic::hasRoundingModeParam(VPID)) {
+    BinOpArgs.push_back(GetConstrainedFPRounding(Ctx, RoundingMode::rmToNearest));
+  }
+  if (VPIntrinsic::hasExceptionBehaviorParam(VPID)) {
+    BinOpArgs.push_back(GetConstrainedFPExcept(Ctx, ExceptionBehavior::ebIgnore));
+  }
+
+  BinOpArgs.push_back(Mask);
+  BinOpArgs.push_back(VectorLen);
+
+  CallInst * CI;
+  if (InsertAtEnd) {
+    CI = CallInst::Create(VPFunc, BinOpArgs, Name, InsertAtEnd);
+  } else {
+    CI = CallInst::Create(VPFunc, BinOpArgs, Name, InsertBefore);
+  }
+
+  // the VP inst does not touch memory if the exception behavior is "fpexcept.ignore"
+  CI->setDoesNotAccessMemory();
+  return CI;
+}
+
+}
diff --git a/llvm/lib/IR/VPBuilder.cpp b/llvm/lib/IR/VPBuilder.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/IR/VPBuilder.cpp
@@ -0,0 +1,201 @@
+#include 
+#include 
+#include 
+#include 
+#include 
+
+namespace {
+  using namespace llvm;
+  using ShortTypeVec = VPIntrinsic::ShortTypeVec;
+  using ShortValueVec = SmallVector;
+}
+
+namespace llvm {
+
+Module &
+VPBuilder::getModule() const {
+  return *Builder.GetInsertBlock()->getParent()->getParent();
+}
+
+Value&
+VPBuilder::GetMaskForType(VectorType & VecTy) {
+  if (Mask) return *Mask;
+
+  auto * boolTy = Builder.getInt1Ty();
+  auto * maskTy = VectorType::get(boolTy, StaticVectorLength);
+  return *ConstantInt::getAllOnesValue(maskTy);
+}
+
+Value&
+VPBuilder::GetEVLForType(VectorType & VecTy) {
+  if (ExplicitVectorLength) return *ExplicitVectorLength;
+
+  auto * intTy = Builder.getInt32Ty();
+  return *ConstantInt::get(intTy, StaticVectorLength);
+}
+
+Value*
+VPBuilder::CreateVectorCopy(Instruction & Inst, ValArray VecOpArray) {
+  auto OC = Inst.getOpcode();
+  auto VPID = VPIntrinsic::getForOpcode(OC);
+  if (VPID == Intrinsic::not_intrinsic) {
+    return nullptr;
+  }
+
+  // Regular binary instructions
+  if ((OC <= Instruction::BinaryOpsEnd) &&
+      (OC >= Instruction::BinaryOpsBegin)) {
+
+    assert(VecOpArray.size() == 2);
+    Value & FirstOp = *VecOpArray[0];
+    Value & SndOp = *VecOpArray[1];
+
+    // Fetch the VP intrinsic
+    auto & VecTy = cast(*FirstOp.getType());
+
+    auto & VPCall =
+      cast(*PredicatedBinaryOperator::Create(&getModule(), &GetMaskForType(VecTy), &GetEVLForType(VecTy), static_cast(OC), &FirstOp, &SndOp));
+    Builder.Insert(&VPCall);
+
+    // transfer fast math flags
+    if (isa(Inst)) {
+      VPCall.copyFastMathFlags(Inst.getFastMathFlags());
+    }
+
+    return &VPCall;
+  }
+
+  // Regular unary instructions
+  if ((OC <= Instruction::UnaryOpsEnd) &&
+      (OC >= Instruction::UnaryOpsBegin)) {
+    assert(VecOpArray.size() == 1);
+    Value & FirstOp = *VecOpArray[0];
+
+    // Fetch the VP intrinsic
+    auto & VecTy = cast(*FirstOp.getType());
+    auto & ScalarTy = *VecTy.getVectorElementType();
+    auto TypeTokens = VPIntrinsic::GetTypeTokens(VPID);
+    auto * VPFunc = Intrinsic::getDeclaration(&getModule(), VPID, VPIntrinsic::EncodeTypeTokens(TypeTokens, &VecTy, VecTy, ScalarTy));
+
+    // Materialize the Call
+    ShortValueVec Args{&FirstOp, &GetMaskForType(VecTy), &GetEVLForType(VecTy)};
+
+    auto & VPCall = *Builder.CreateCall(VPFunc, Args);
+
+    // transfer fast math flags
+    if (isa(Inst)) {
+      cast(VPCall).copyFastMathFlags(Inst.getFastMathFlags());
+    }
+
+    return &VPCall;
+  }
+
+  // Special cases
+  switch (OC) {
+  default:
+    break;
+
+  case Instruction::FCmp:
+  case Instruction::ICmp: {
+    assert(VecOpArray.size() == 2);
+    Value & FirstOp = *VecOpArray[0];
+    Value & SndOp = *VecOpArray[1];
+
+    // Fetch the VP intrinsic
+    auto & VecTy = cast(*FirstOp.getType());
+    auto & ScalarTy = *VecTy.getVectorElementType();
+    auto TypeTokens = VPIntrinsic::GetTypeTokens(VPID);
+    auto * Func = Intrinsic::getDeclaration(&getModule(), VPID, VPIntrinsic::EncodeTypeTokens(TypeTokens, &VecTy, VecTy, ScalarTy));
+
+    assert((VPIntrinsic::getMaskParamPos(VPID) == 2) && (VPIntrinsic::getVectorLengthParamPos(VPID) == 3));
+
+    // encode the comparison predicate as an i8 immediate
+    uint8_t RawPred = cast(Inst).getPredicate();
+    auto Int8Ty = Builder.getInt8Ty();
+    auto PredArg = ConstantInt::get(Int8Ty, RawPred, false);
+
+    // Materialize the Call
+    ShortValueVec Args{&FirstOp, &SndOp, &GetMaskForType(VecTy), &GetEVLForType(VecTy), PredArg};
+
+    return Builder.CreateCall(Func, Args);
+  }
+
+  case Instruction::Select: {
+    assert(VecOpArray.size() == 3);
+    Value & MaskOp = *VecOpArray[0];
+    Value & OnTrueOp = *VecOpArray[1];
+    Value & OnFalseOp = *VecOpArray[2];
+
+    // Fetch the VP intrinsic
+    auto & VecTy = cast(*OnTrueOp.getType());
+    auto & ScalarTy = *VecTy.getVectorElementType();
+    auto TypeTokens = VPIntrinsic::GetTypeTokens(VPID);
+
+    auto * Func = Intrinsic::getDeclaration(&getModule(), VPID, VPIntrinsic::EncodeTypeTokens(TypeTokens, &VecTy, VecTy, ScalarTy));
+
+    assert((VPIntrinsic::getMaskParamPos(VPID) == 2) && (VPIntrinsic::getVectorLengthParamPos(VPID) == 3));
+
+    // Materialize the Call
+    ShortValueVec Args{&OnTrueOp, &OnFalseOp, &MaskOp, &GetEVLForType(VecTy)};
+
+    return Builder.CreateCall(Func, Args);
+  }
+  }
+
+  // TODO VP reductions
+  // TODO VP casts
+  return nullptr;
+}
+
+VectorType&
+VPBuilder::getVectorType(Type &ElementTy) {
+  return *VectorType::get(&ElementTy, StaticVectorLength);
+}
+
+Value&
+VPBuilder::CreateContiguousStore(Value & Val, Value & Pointer, Align Alignment) {
+  auto & VecTy = cast(*Val.getType());
+  auto * StoreFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_store, {Val.getType(), Pointer.getType()});
+  ShortValueVec Args{&Val, &Pointer, &GetMaskForType(VecTy), &GetEVLForType(VecTy)};
+  CallInst &StoreCall = *Builder.CreateCall(StoreFunc, Args);
+  if (Alignment != None) StoreCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment));
+  return
StoreCall; +} + +Value& +VPBuilder::CreateContiguousLoad(Value & Pointer, Align Alignment) { + auto & PointerTy = cast(*Pointer.getType()); + auto & VecTy = getVectorType(*PointerTy.getPointerElementType()); + + auto * LoadFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_load, {&VecTy, &PointerTy}); + ShortValueVec Args{&Pointer, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &LoadCall= *Builder.CreateCall(LoadFunc, Args); + if (Alignment != None) LoadCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment)); + return LoadCall; +} + +Value& +VPBuilder::CreateScatter(Value & Val, Value & PointerVec, Align Alignment) { + auto & VecTy = cast(*Val.getType()); + auto * ScatterFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_scatter, {Val.getType(), PointerVec.getType()}); + ShortValueVec Args{&Val, &PointerVec, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &ScatterCall = *Builder.CreateCall(ScatterFunc, Args); + if (Alignment != None) ScatterCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment)); + return ScatterCall; +} + +Value& +VPBuilder::CreateGather(Value & PointerVec, Align Alignment) { + auto & PointerVecTy = cast(*PointerVec.getType()); + auto & ElemTy = *cast(*PointerVecTy.getVectorElementType()).getPointerElementType(); + auto & VecTy = *VectorType::get(&ElemTy, PointerVecTy.getNumElements()); + auto * GatherFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::vp_gather, {&VecTy, &PointerVecTy}); + + ShortValueVec Args{&PointerVec, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &GatherCall = *Builder.CreateCall(GatherFunc, Args); + if (Alignment != None) GatherCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment)); + return GatherCall; +} + +} // namespace llvm diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp --- a/llvm/lib/IR/Verifier.cpp +++ b/llvm/lib/IR/Verifier.cpp @@ -475,6 +475,7 @@ void visitUserOp2(Instruction &I) { visitUserOp1(I); } void visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call); void visitConstrainedFPIntrinsic(ConstrainedFPIntrinsic &FPI); + void visitVPIntrinsic(VPIntrinsic &FPI); void visitDbgIntrinsic(StringRef Kind, DbgVariableIntrinsic &DII); void visitDbgLabelIntrinsic(StringRef Kind, DbgLabelInst &DLI); void visitAtomicCmpXchgInst(AtomicCmpXchgInst &CXI); @@ -1690,11 +1691,14 @@ if (Attrs.isEmpty()) return; + bool SawMask = false; bool SawNest = false; + bool SawPassthru = false; bool SawReturned = false; bool SawSRet = false; bool SawSwiftSelf = false; bool SawSwiftError = false; + bool SawVectorLength = false; // Verify return value attributes. 
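// A minimal sketch (not part of the patch itself) of how a vectorizer could
// drive the VPBuilder helpers added in llvm/lib/IR/VPBuilder.cpp above. It
// assumes VPB has already been seeded with the loop's mask and explicit vector
// length (that setup is not shown in this excerpt) and that ValArray accepts a
// braced list of Value pointers; the helper name and alignments are illustrative.
static Value *widenAddSketch(VPBuilder &VPB, Instruction &ScalarAdd,
                             Value &LHSPtr, Value &RHSPtr, Value &ResPtr) {
  // Predicated contiguous loads of both operands; the builder appends the
  // mask and EVL arguments itself.
  Value &A = VPB.CreateContiguousLoad(LHSPtr, Align(4));
  Value &B = VPB.CreateContiguousLoad(RHSPtr, Align(4));

  // Re-emit the scalar operation as the matching llvm.vp.* call.
  Value *Sum = VPB.CreateVectorCopy(ScalarAdd, {&A, &B});
  if (!Sum)
    return nullptr;

  // Predicated contiguous store of the result.
  VPB.CreateContiguousStore(*Sum, ResPtr, Align(4));
  return Sum;
}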
AttributeSet RetAttrs = Attrs.getRetAttributes(); @@ -1762,12 +1766,33 @@ SawSwiftError = true; } + if (ArgAttrs.hasAttribute(Attribute::VectorLength)) { + Assert(!SawVectorLength, "Cannot have multiple 'vlen' parameters!", + V); + SawVectorLength = true; + } + + if (ArgAttrs.hasAttribute(Attribute::Passthru)) { + Assert(!SawPassthru, "Cannot have multiple 'passthru' parameters!", + V); + SawPassthru = true; + } + + if (ArgAttrs.hasAttribute(Attribute::Mask)) { + Assert(!SawMask, "Cannot have multiple 'mask' parameters!", + V); + SawMask = true; + } + if (ArgAttrs.hasAttribute(Attribute::InAlloca)) { Assert(i == FT->getNumParams() - 1, "inalloca isn't on the last parameter!", V); } } + Assert(!SawPassthru || SawMask, + "Cannot have 'passthru' parameter without 'mask' parameter!", V); + if (!Attrs.hasAttributes(AttributeList::FunctionIndex)) return; @@ -3133,7 +3158,7 @@ /// visitUnaryOperator - Check the argument to the unary operator. /// void Verifier::visitUnaryOperator(UnaryOperator &U) { - Assert(U.getType() == U.getOperand(0)->getType(), + Assert(U.getType() == U.getOperand(0)->getType(), "Unary operators must have same type for" "operands and result!", &U); @@ -4331,6 +4356,94 @@ case Intrinsic::experimental_constrained_trunc: visitConstrainedFPIntrinsic(cast(Call)); break; + + // general cmp + case Intrinsic::vp_icmp: + case Intrinsic::vp_fcmp: + + // int arith + case Intrinsic::vp_and: + case Intrinsic::vp_or: + case Intrinsic::vp_xor: + case Intrinsic::vp_ashr: + case Intrinsic::vp_lshr: + case Intrinsic::vp_shl: + case Intrinsic::vp_add: + case Intrinsic::vp_sub: + case Intrinsic::vp_mul: + case Intrinsic::vp_udiv: + case Intrinsic::vp_sdiv: + case Intrinsic::vp_urem: + case Intrinsic::vp_srem: + + // memory + case Intrinsic::vp_load: + case Intrinsic::vp_store: + case Intrinsic::vp_gather: + case Intrinsic::vp_scatter: + + // shuffle + case Intrinsic::vp_select: + case Intrinsic::vp_compose: + case Intrinsic::vp_compress: + case Intrinsic::vp_expand: + case Intrinsic::vp_vshift: + + // fp arith + case Intrinsic::vp_fneg: + + case Intrinsic::vp_fadd: + case Intrinsic::vp_fsub: + case Intrinsic::vp_fmul: + case Intrinsic::vp_fdiv: + case Intrinsic::vp_frem: + + case Intrinsic::vp_fma: + case Intrinsic::vp_ceil: + case Intrinsic::vp_cos: + case Intrinsic::vp_exp2: + case Intrinsic::vp_exp: + case Intrinsic::vp_floor: + case Intrinsic::vp_log10: + case Intrinsic::vp_log2: + case Intrinsic::vp_log: + case Intrinsic::vp_maxnum: + case Intrinsic::vp_minnum: + case Intrinsic::vp_nearbyint: + case Intrinsic::vp_pow: + case Intrinsic::vp_powi: + case Intrinsic::vp_rint: + case Intrinsic::vp_round: + case Intrinsic::vp_sin: + case Intrinsic::vp_sqrt: + case Intrinsic::vp_trunc: + case Intrinsic::vp_fptoui: + case Intrinsic::vp_fptosi: + case Intrinsic::vp_fpext: + case Intrinsic::vp_fptrunc: + + case Intrinsic::vp_lround: + case Intrinsic::vp_llround: + + // reductions + case Intrinsic::vp_reduce_add: + case Intrinsic::vp_reduce_mul: + case Intrinsic::vp_reduce_umin: + case Intrinsic::vp_reduce_umax: + case Intrinsic::vp_reduce_smin: + case Intrinsic::vp_reduce_smax: + + case Intrinsic::vp_reduce_and: + case Intrinsic::vp_reduce_or: + case Intrinsic::vp_reduce_xor: + + case Intrinsic::vp_reduce_fadd: + case Intrinsic::vp_reduce_fmul: + case Intrinsic::vp_reduce_fmin: + case Intrinsic::vp_reduce_fmax: + visitVPIntrinsic(cast(Call)); + break; + case Intrinsic::dbg_declare: // llvm.dbg.declare Assert(isa(Call.getArgOperand(0)), "invalid llvm.dbg.declare intrinsic call 1", Call); @@ -4754,6 
+4867,15 @@ return nullptr; } +void Verifier::visitVPIntrinsic(VPIntrinsic &VPI) { + if (VPI.isConstrainedOp()) { + Assert(VPI.getExceptionBehavior() != ExceptionBehavior::ebInvalid, + "invalid exception behavior argument", &VPI); + Assert(VPI.getRoundingMode() != RoundingMode::rmInvalid, + "invalid rounding mode argument", &VPI); + } +} + void Verifier::visitConstrainedFPIntrinsic(ConstrainedFPIntrinsic &FPI) { unsigned NumOperands = FPI.getNumArgOperands(); bool HasExceptionMD = false; @@ -5159,7 +5281,7 @@ bool runOnFunction(Function &F) override { if (!V->verify(F) && FatalErrors) { - errs() << "in function " << F.getName() << '\n'; + errs() << "in function " << F.getName() << '\n'; report_fatal_error("Broken function found, compilation aborted!"); } return false; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -24,6 +24,9 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredicatedInst.h" +#include "llvm/IR/VPBuilder.h" +#include "llvm/IR/MatcherCast.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/AlignOf.h" @@ -2087,6 +2090,17 @@ return nullptr; } +Instruction *InstCombiner::visitPredicatedFSub(PredicatedBinaryOperator& I) { + auto * Inst = cast(&I); + PredicatedContext PC(&I); + if (Value *V = SimplifyPredicatedFSubInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), + SQ.getWithInstruction(Inst), PC)) + return replaceInstUsesWith(*Inst, V); + + return visitFSubGeneric(*Inst); +} + Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (Value *V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), @@ -2096,11 +2110,19 @@ if (Instruction *X = foldVectorBinop(I)) return X; + return visitFSubGeneric(I); +} + +template +Instruction *InstCombiner::visitFSubGeneric(BinaryOpTy &I) { + MatchContextType MC(cast(&I)); + MatchContextBuilder MCBuilder(MC); + // Subtraction from -0.0 is the canonical form of fneg. // fsub nsz 0, X ==> fsub nsz -0.0, X Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (I.hasNoSignedZeros() && match(Op0, m_PosZeroFP())) - return BinaryOperator::CreateFNegFMF(Op1, &I); + if (I.hasNoSignedZeros() && MC.try_match(Op0, m_PosZeroFP())) + return MCBuilder.CreateFNegFMF(Op1, &I); if (Instruction *X = foldFNegIntoConstant(I)) return X; @@ -2111,6 +2133,18 @@ Value *X, *Y; Constant *C; + // Fold negation into constant operand. This is limited with one-use because + // fneg is assumed better for analysis and cheaper in codegen than fmul/fdiv. + // -(X * C) --> X * (-C) + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FMul(m_Value(X), m_Constant(C)))))) + return MCBuilder.CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); + // -(X / C) --> X / (-C) + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FDiv(m_Value(X), m_Constant(C)))))) + return MCBuilder.CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); + // -(C / X) --> (-C) / X + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X)))))) + return MCBuilder.CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + // If Op0 is not -0.0 or we can ignore -0.0: Z - (X - Y) --> Z + (Y - X) // Canonicalize to fadd to make analysis easier. // This can also help codegen because fadd is commutative. @@ -2118,36 +2152,37 @@ // killed later. 
We still limit that particular transform with 'hasOneUse' // because an fneg is assumed better/cheaper than a generic fsub. if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) { - if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { - Value *NewSub = Builder.CreateFSubFMF(Y, X, &I); - return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I); + if (MC.try_match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { + Value *NewSub = MCBuilder.CreateFSubFMF(Builder, Y, X, &I); + return MCBuilder.CreateFAddFMF(Op0, NewSub, &I); } } - if (isa(Op0)) - if (SelectInst *SI = dyn_cast(Op1)) - if (Instruction *NV = FoldOpIntoSelect(I, SI)) - return NV; + if (auto * PlainBinOp = dyn_cast(&I)) + if (isa(Op0)) + if (SelectInst *SI = dyn_cast(Op1)) + if (Instruction *NV = FoldOpIntoSelect(*PlainBinOp, SI)) + return NV; // X - C --> X + (-C) // But don't transform constant expressions because there's an inverse fold // for X + (-Y) --> X - Y. - if (match(Op1, m_Constant(C)) && !isa(Op1)) - return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); + if (MC.try_match(Op1, m_Constant(C)) && !isa(Op1)) + return MCBuilder.CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); // X - (-Y) --> X + Y - if (match(Op1, m_FNeg(m_Value(Y)))) - return BinaryOperator::CreateFAddFMF(Op0, Y, &I); + if (MC.try_match(Op1, m_FNeg(m_Value(Y)))) + return MCBuilder.CreateFAddFMF(Op0, Y, &I); // Similar to above, but look through a cast of the negated value: // X - (fptrunc(-Y)) --> X + fptrunc(Y) Type *Ty = I.getType(); - if (match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) - return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPTrunc(Y, Ty), &I); + if (MC.try_match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) + return MCBuilder.CreateFAddFMF(Op0, MCBuilder.CreateFPTrunc(Builder, Y, Ty), &I); // X - (fpext(-Y)) --> X + fpext(Y) - if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) - return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I); + if (MC.try_match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) + return MCBuilder.CreateFAddFMF(Op0, MCBuilder.CreateFPExt(Builder, Y, Ty), &I); // Similar to above, but look through fmul/fdiv of the negated value: // Op0 - (-X * Y) --> Op0 + (X * Y) @@ -2165,39 +2200,42 @@ } // Handle special cases for FSub with selects feeding the operation - if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) - return replaceInstUsesWith(I, V); + if (auto * PlainBinOp = dyn_cast(&I)) + if (Value *V = SimplifySelectsFeedingBinaryOp(*PlainBinOp, Op0, Op1)) + return replaceInstUsesWith(I, V); if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { // (Y - X) - Y --> -X - if (match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) - return BinaryOperator::CreateFNegFMF(X, &I); + if (MC.try_match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) + return MCBuilder.CreateFNegFMF(X, &I); // Y - (X + Y) --> -X // Y - (Y + X) --> -X - if (match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) - return BinaryOperator::CreateFNegFMF(X, &I); + if (MC.try_match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) + return MCBuilder.CreateFNegFMF(X, &I); // (X * C) - X --> X * (C - 1.0) - if (match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { + if (MC.try_match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { Constant *CSubOne = ConstantExpr::getFSub(C, ConstantFP::get(Ty, 1.0)); - return BinaryOperator::CreateFMulFMF(Op1, CSubOne, &I); + return MCBuilder.CreateFMulFMF(Op1, CSubOne, &I); } // X - (X * C) --> X * (1.0 - C) - if (match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) { + if 
(MC.try_match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) { Constant *OneSubC = ConstantExpr::getFSub(ConstantFP::get(Ty, 1.0), C); - return BinaryOperator::CreateFMulFMF(Op0, OneSubC, &I); + return MCBuilder.CreateFMulFMF(Op0, OneSubC, &I); } - if (Instruction *F = factorizeFAddFSub(I, Builder)) - return F; + if (auto * PlainBinOp = dyn_cast(&I)) { + if (Instruction *F = factorizeFAddFSub(*PlainBinOp, Builder)) + return F; - // TODO: This performs reassociative folds for FP ops. Some fraction of the - // functionality has been subsumed by simple pattern matching here and in - // InstSimplify. We should let a dedicated reassociation pass handle more - // complex pattern matching and remove this from InstCombine. - if (Value *V = FAddCombine(Builder).simplify(&I)) - return replaceInstUsesWith(I, V); + // TODO: This performs reassociative folds for FP ops. Some fraction of the + // functionality has been subsumed by simple pattern matching here and in + // InstSimplify. We should let a dedicated reassociation pass handle more + // complex pattern matching and remove this from InstCombine. + if (Value *V = FAddCombine(Builder).simplify(PlainBinOp)) + return replaceInstUsesWith(*PlainBinOp, V); + } } return nullptr; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -38,6 +38,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" @@ -1793,6 +1794,14 @@ return &CI; } + // Predicated instruction patterns + auto * VPInst = dyn_cast(&CI); + if (VPInst) { + auto * PredInst = cast(VPInst); + auto Result = visitPredicatedInstruction(PredInst); + if (Result) return Result; + } + IntrinsicInst *II = dyn_cast(&CI); if (!II) return visitCallBase(CI); @@ -1857,7 +1866,8 @@ if (Changed) return II; } - // For vector result intrinsics, use the generic demanded vector support. + // For vector result intrinsics, use the generic demanded vector support to + // simplify any operands before moving on to the per-intrinsic rules. 
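// A minimal sketch (not part of the patch itself) of the uniform view that the
// predicated entry point above relies on: a call to a binary VP intrinsic such
// as llvm.vp.fsub can be inspected like a plain BinaryOperator, while its mask
// and explicit vector length are recovered through the VPIntrinsic position
// helpers from llvm/lib/IR. The helper name is illustrative and a binary VP
// operation is assumed.
static void inspectPredicatedBinOp(VPIntrinsic &VPI) {
  auto &PBO = cast<PredicatedBinaryOperator>(VPI);
  unsigned Opc = PBO.getOpcode();   // e.g. Instruction::FSub for llvm.vp.fsub
  Value *LHS = PBO.getOperand(0);
  Value *RHS = PBO.getOperand(1);

  Intrinsic::ID ID = VPI.getIntrinsicID();
  Value *Mask = nullptr, *EVL = nullptr;
  if (auto MaskPos = VPIntrinsic::getMaskParamPos(ID))
    Mask = VPI.getArgOperand(MaskPos.getValue());
  if (auto VlenPos = VPIntrinsic::getVectorLengthParamPos(ID))
    EVL = VPI.getArgOperand(VlenPos.getValue());

  // A fold can now consult Opc, LHS and RHS exactly as it would for plain IR,
  // and keep Mask/EVL intact on anything it rebuilds.
  (void)Opc; (void)LHS; (void)RHS; (void)Mask; (void)EVL;
}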
if (II->getType()->isVectorTy()) { auto VWidth = II->getType()->getVectorNumElements(); APInt UndefElts(VWidth, 0); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -30,6 +30,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Use.h" @@ -371,6 +372,8 @@ Instruction *visitFAdd(BinaryOperator &I); Value *OptimizePointerDifference(Value *LHS, Value *RHS, Type *Ty); Instruction *visitSub(BinaryOperator &I); + template Instruction *visitFSubGeneric(BinaryOpTy &I); + Instruction *visitPredicatedFSub(PredicatedBinaryOperator &I); Instruction *visitFSub(BinaryOperator &I); Instruction *visitMul(BinaryOperator &I); Instruction *visitFMul(BinaryOperator &I); @@ -447,6 +450,16 @@ Instruction *visitVAStartInst(VAStartInst &I); Instruction *visitVACopyInst(VACopyInst &I); + // Entry point to VPIntrinsic + Instruction *visitPredicatedInstruction(PredicatedInstruction * PI) { + switch (PI->getOpcode()) { + default: + return nullptr; + case Instruction::FSub: + return visitPredicatedFSub(cast(*PI)); + } + } + /// Specify what to return for unhandled instructions. Instruction *visitInstruction(Instruction &I) { return nullptr; } diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp --- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp +++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp @@ -860,6 +860,7 @@ case Attribute::InaccessibleMemOnly: case Attribute::InaccessibleMemOrArgMemOnly: case Attribute::JumpTable: + case Attribute::Mask: case Attribute::Naked: case Attribute::Nest: case Attribute::NoAlias: @@ -869,6 +870,7 @@ case Attribute::NoSync: case Attribute::None: case Attribute::NonNull: + case Attribute::Passthru: case Attribute::ReadNone: case Attribute::ReadOnly: case Attribute::Returned: @@ -880,6 +882,7 @@ case Attribute::SwiftError: case Attribute::SwiftSelf: case Attribute::WillReturn: + case Attribute::VectorLength: case Attribute::WriteOnly: case Attribute::ZExt: case Attribute::ImmArg: diff --git a/llvm/test/Bitcode/attributes.ll b/llvm/test/Bitcode/attributes.ll --- a/llvm/test/Bitcode/attributes.ll +++ b/llvm/test/Bitcode/attributes.ll @@ -374,6 +374,11 @@ ret void; } +; CHECK: define <8 x double> @f64(<8 x double> passthru %0, <8 x i1> mask %1, i32 vlen %2) { +define <8 x double> @f64(<8 x double> passthru, <8 x i1> mask, i32 vlen) { + ret <8 x double> undef +} + ; CHECK: attributes #0 = { noreturn } ; CHECK: attributes #1 = { nounwind } ; CHECK: attributes #2 = { readnone } diff --git a/llvm/test/Transforms/InstCombine/vp-fsub.ll b/llvm/test/Transforms/InstCombine/vp-fsub.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/vp-fsub.ll @@ -0,0 +1,45 @@ +; RUN: opt < %s -instcombine -S | FileCheck %s + +; PR4374 + +define <4 x float> @test1_vp(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) { +; CHECK-LABEL: @test1_vp( +; + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + %t2 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> , <4 x float> %t1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + ret <4 x float> %t2 
+} + +; Can't do anything with the test above because -0.0 - 0.0 = -0.0, but if we have nsz: +; -(X - Y) --> Y - X + +; TODO predicated FAdd folding +define <4 x float> @neg_sub_nsz_vp(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) { +; CH***-LABEL: @neg_sub_nsz_vp( +; + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + %t2 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> , <4 x float> %t1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + ret <4 x float> %t2 +} + +; With nsz: Z - (X - Y) --> Z + (Y - X) + +define <4 x float> @sub_sub_nsz_vp(<4 x float> %x, <4 x float> %y, <4 x float> %z, <4 x i1> %M, i32 %L) { +; CHECK-LABEL: @sub_sub_nsz_vp( +; CHECK-NEXT: %1 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %y, <4 x float> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) # +; CHECK-NEXT: %t2 = call nsz <4 x float> @llvm.vp.fadd.v4f32(<4 x float> %z, <4 x float> %1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) # +; CHECK-NEXT: ret <4 x float> %t2 + %t1 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %x, <4 x float> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + %t2 = call nsz <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %z, <4 x float> %t1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <4 x i1> %M, i32 %L) #0 + ret <4 x float> %t2 +} + + + +; Function Attrs: nounwind readnone +declare <4 x float> @llvm.vp.fadd.v4f32(<4 x float>, <4 x float>, metadata, metadata, <4 x i1> mask, i32 vlen) + +; Function Attrs: nounwind readnone +declare <4 x float> @llvm.vp.fsub.v4f32(<4 x float>, <4 x float>, metadata, metadata, <4 x i1> mask, i32 vlen) + +attributes #0 = { readnone } diff --git a/llvm/test/Transforms/InstSimplify/vp-fsub.ll b/llvm/test/Transforms/InstSimplify/vp-fsub.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstSimplify/vp-fsub.ll @@ -0,0 +1,55 @@ +; RUN: opt < %s -instsimplify -S | FileCheck %s + +define <8 x double> @fsub_fadd_fold_vp_xy(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) { +; CHECK-LABEL: fsub_fadd_fold_vp_xy +; CHECK: ret <8 x double> %x + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %x, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + %res0 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + ret <8 x double> %res0 +} + +define <8 x double> @fsub_fadd_fold_vp_zw(<8 x double> %z, <8 x double> %w, <8 x i1> %m, i32 %len) { +; CHECK-LABEL: fsub_fadd_fold_vp_zw +; CHECK: ret <8 x double> %z + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %w, <8 x double> %z, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + %res1 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %w, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + ret <8 x double> %res1 +} + +define <8 x double> @fsub_fadd_fold_vp_yx_fpexcept(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) #0 { +; CHECK-LABEL: fsub_fadd_fold_vp_yx +; CHECK-NEXT: %tmp = +; CHECK-NEXT: %res2 = +; CHECK-NEXT: ret + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata 
!"round.tonearest", metadata !"fpexcept.strict", <8 x i1> %m, i32 %len) + %res2 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.strict", <8 x i1> %m, i32 %len) + ret <8 x double> %res2 +} + +define <8 x double> @fsub_fadd_fold_vp_yx_olen(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len, i32 %otherLen) { +; CHECK-LABEL: fsub_fadd_fold_vp_yx_olen +; CHECK-NEXT: %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %otherLen) +; CHECK-NEXT: %res3 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) +; CHECK-NEXT: ret <8 x double> %res3 + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %otherLen) + %res3 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + ret <8 x double> %res3 +} + +define <8 x double> @fsub_fadd_fold_vp_yx_omask(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len, <8 x i1> %othermask) { +; CHECK-LABEL: fsub_fadd_fold_vp_yx_omask +; CHECK-NEXT: %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) +; CHECK-NEXT: %res4 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %othermask, i32 %len) +; CHECK-NEXT: ret <8 x double> %res4 + %tmp = call reassoc nsz <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %y, <8 x double> %x, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %len) + %res4 = call reassoc nsz <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %othermask, i32 %len) + ret <8 x double> %res4 +} + +; Function Attrs: nounwind readnone +declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) + +; Function Attrs: nounwind readnone +declare <8 x double> @llvm.vp.fsub.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) + +attributes #0 = { strictfp } diff --git a/llvm/test/Verifier/evl_attribs.ll b/llvm/test/Verifier/evl_attribs.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Verifier/evl_attribs.ll @@ -0,0 +1,13 @@ +; RUN: not llvm-as %s -o /dev/null 2>&1 | FileCheck %s + +declare void @a(<16 x i1> mask %a, <16 x i1> mask %b) +; CHECK: Cannot have multiple 'mask' parameters! + +declare void @b(<16 x i1> mask %a, i32 vlen %x, i32 vlen %y) +; CHECK: Cannot have multiple 'vlen' parameters! + +declare <16 x double> @c(<16 x double> passthru %a) +; CHECK: Cannot have 'passthru' parameter without 'mask' parameter! + +declare <16 x double> @d(<16 x double> passthru %a, <16 x i1> mask %M, <16 x double> passthru %b) +; CHECK: Cannot have multiple 'passthru' parameters! 
diff --git a/llvm/test/Verifier/vp-intrinsics.ll b/llvm/test/Verifier/vp-intrinsics.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Verifier/vp-intrinsics.ll @@ -0,0 +1,118 @@ +; RUN: opt --verify %s + +define void @test_vp_constrainedfp(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, <8 x double> %f3, <8 x i1> %m, i32 %n) #0 { + %r0 = call <8 x double> @llvm.vp.fadd.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tonearest", metadata !"fpexcept.ignore", <8 x i1> %m, i32 %n) + %r1 = call <8 x double> @llvm.vp.fsub.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.strict", <8 x i1> %m, i32 %n) + %r2 = call <8 x double> @llvm.vp.fmul.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.strict", <8 x i1> %m, i32 %n) + %r3 = call <8 x double> @llvm.vp.fdiv.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.strict", <8 x i1> %m, i32 %n) + %r4 = call <8 x double> @llvm.vp.frem.v8f64(<8 x double> %f0, <8 x double> %f1, metadata !"round.tozero", metadata !"fpexcept.strict", <8 x i1> %m, i32 %n) + %r5 = call <8 x double> @llvm.vp.fma.v8f64(<8 x double> %f0, <8 x double> %f1, <8 x double> %f2, metadata !"round.tozero", metadata !"fpexcept.strict", <8 x i1> %m, i32 %n) + %r6 = call <8 x double> @llvm.vp.fneg.v8f64(<8 x double> %f2, metadata !"fpexcept.strict", <8 x i1> %m, i32 %n) + ret void +} + +define void @test_vp_int(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) { + %r0 = call <8 x i32> @llvm.vp.add.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r1 = call <8 x i32> @llvm.vp.sub.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r2 = call <8 x i32> @llvm.vp.mul.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r3 = call <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r4 = call <8 x i32> @llvm.vp.srem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r5 = call <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + %r6 = call <8 x i32> @llvm.vp.urem.v8i32(<8 x i32> %i0, <8 x i32> %i1, <8 x i1> %m, i32 %n) + ret void +} + +define void @test_mem(<16 x i32*> %p0, <16 x i32>* %p1, <16 x i32> %i0, <16 x i1> %m, i32 %n) { + call void @llvm.vp.store.v16i32p0v16i32(<16 x i32> %i0, <16 x i32>* %p1, <16 x i1> %m, i32 %n) + call void @llvm.vp.scatter.v16i32v16p0i32(<16 x i32> %i0 , <16 x i32*> %p0, <16 x i1> %m, i32 %n) + %l0 = call <16 x i32> @llvm.vp.load.v16i32p0v16i32(<16 x i32>* %p1, <16 x i1> %m, i32 %n) + %l1 = call <16 x i32> @llvm.vp.gather.v16i32v16p0i32(<16 x i32*> %p0, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_reduce_fp(<16 x float> %v, <16 x i1> %m, i32 %n) { + %r0 = call float @llvm.vp.reduce.fadd.v16f32(float 0.0, <16 x float> %v, <16 x i1> %m, i32 %n) + %r1 = call float @llvm.vp.reduce.fmul.v16f32(float 42.0, <16 x float> %v, <16 x i1> %m, i32 %n) + %r2 = call float @llvm.vp.reduce.fmin.v16f32(<16 x float> %v, <16 x i1> %m, i32 %n) + %r3 = call float @llvm.vp.reduce.fmax.v16f32(<16 x float> %v, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_reduce_int(<16 x i32> %v, <16 x i1> %m, i32 %n) { + %r0 = call i32 @llvm.vp.reduce.add.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r1 = call i32 @llvm.vp.reduce.mul.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r2 = call i32 @llvm.vp.reduce.and.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + %r3 = call i32 @llvm.vp.reduce.xor.v16i32(<16 x i32> %v, <16 x i1> %m, 
i32 %n) + %r4 = call i32 @llvm.vp.reduce.or.v16i32(<16 x i32> %v, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_shuffle(<16 x float> %v0, <16 x float> %v1, <16 x i1> %m, i32 %k, i32 %n) { + %r0 = call <16 x float> @llvm.vp.select.v16f32(<16 x i1> %m, <16 x float> %v0, <16 x float> %v1, i32 %n) + %r1 = call <16 x float> @llvm.vp.compose.v16f32(<16 x float> %v0, <16 x float> %v1, i32 %k, i32 %n) + %r2 = call <16 x float> @llvm.vp.shift.v16f32(<16 x float> %v0, i32 %k, <16 x i1> %m, i32 %n) + %r3 = call <16 x float> @llvm.vp.compress.v16f32(<16 x float> %v0, <16 x i1> %m, i32 %n) + %r4 = call <16 x float> @llvm.vp.expand.v16f32(<16 x float> %v0, <16 x i1> %m, i32 %n) + ret void +} + +define void @test_xcmp(<16 x i32> %v0, <16 x i32> %v1, <16 x i1> %m, i32 %n) { + %r0 = call <16 x i1> @llvm.vp.icmp.v16i32(<16 x i32> %v0, <16 x i32> %v1, i8 8, <16 x i1> %m, i32 %n) + %r1 = call <16 x i1> @llvm.vp.fcmp.v16i32(<16 x i32> %v0, <16 x i32> %v1, i8 12, <16 x i1> %m, i32 %n) + ret void +} + +; standard floating point arith +declare <8 x double> @llvm.vp.fadd.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fsub.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fmul.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fdiv.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.frem.v8f64(<8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fma.v8f64(<8 x double>, <8 x double>, <8 x double>, metadata, metadata, <8 x i1> mask, i32 vlen) +declare <8 x double> @llvm.vp.fneg.v8f64(<8 x double>, metadata, <8 x i1> mask, i32 vlen) + +; integer arith +declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +; bit arith +declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) +declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1> mask, i32 vlen) + +; memory +declare void @llvm.vp.store.v16i32p0v16i32(<16 x i32>, <16 x i32>*, <16 x i1> mask, i32 vlen) +declare void @llvm.vp.scatter.v16i32v16p0i32(<16 x i32>, <16 x i32*>, <16 x i1> mask, i32 vlen) +declare <16 x i32> @llvm.vp.load.v16i32p0v16i32(<16 x i32>*, <16 x i1> mask, i32 vlen) +declare <16 x i32> @llvm.vp.gather.v16i32v16p0i32(<16 x i32*>, <16 x i1> mask, i32 vlen) + +; reductions +declare float @llvm.vp.reduce.fadd.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmul.v16f32(float, <16 x float>, <16 x i1> mask, i32 vlen) 
+declare float @llvm.vp.reduce.fmin.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) +declare float @llvm.vp.reduce.fmax.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.add.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.mul.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.and.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.xor.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) +declare i32 @llvm.vp.reduce.or.v16i32(<16 x i32>, <16 x i1> mask, i32 vlen) + +; shuffles +declare <16 x float> @llvm.vp.select.v16f32(<16 x i1>, <16 x float>, <16 x float>, i32 vlen) +declare <16 x float> @llvm.vp.compose.v16f32(<16 x float>, <16 x float>, i32, i32 vlen) +declare <16 x float> @llvm.vp.shift.v16f32(<16 x float>, i32, <16 x i1>, i32 vlen) +declare <16 x float> @llvm.vp.compress.v16f32(<16 x float>, <16 x i1>, i32 vlen) +declare <16 x float> @llvm.vp.expand.v16f32(<16 x float>, <16 x i1> mask, i32 vlen) + +; icmp , fcmp +declare <16 x i1> @llvm.vp.icmp.v16i32(<16 x i32>, <16 x i32>, i8, <16 x i1> mask, i32 vlen) +declare <16 x i1> @llvm.vp.fcmp.v16i32(<16 x i32>, <16 x i32>, i8, <16 x i1> mask, i32 vlen) + +attributes #0 = { strictfp } diff --git a/llvm/unittests/IR/IRBuilderTest.cpp b/llvm/unittests/IR/IRBuilderTest.cpp --- a/llvm/unittests/IR/IRBuilderTest.cpp +++ b/llvm/unittests/IR/IRBuilderTest.cpp @@ -242,52 +242,52 @@ V = Builder.CreateFAdd(V, V); ASSERT_TRUE(isa(V)); auto *CII = cast(V); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebStrict); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmDynamic); + ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebStrict); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmDynamic); - Builder.setDefaultConstrainedExcept(ConstrainedFPIntrinsic::ebIgnore); - Builder.setDefaultConstrainedRounding(ConstrainedFPIntrinsic::rmUpward); + Builder.setDefaultConstrainedExcept(ExceptionBehavior::ebIgnore); + Builder.setDefaultConstrainedRounding(RoundingMode::rmUpward); V = Builder.CreateFAdd(V, V); CII = cast(V); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebIgnore); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmUpward); + ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebIgnore); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmUpward); - Builder.setDefaultConstrainedExcept(ConstrainedFPIntrinsic::ebIgnore); - Builder.setDefaultConstrainedRounding(ConstrainedFPIntrinsic::rmToNearest); + Builder.setDefaultConstrainedExcept(ExceptionBehavior::ebIgnore); + Builder.setDefaultConstrainedRounding(RoundingMode::rmToNearest); V = Builder.CreateFAdd(V, V); CII = cast(V); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebIgnore); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmToNearest); + ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebIgnore); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmToNearest); - Builder.setDefaultConstrainedExcept(ConstrainedFPIntrinsic::ebMayTrap); - Builder.setDefaultConstrainedRounding(ConstrainedFPIntrinsic::rmDownward); + Builder.setDefaultConstrainedExcept(ExceptionBehavior::ebMayTrap); + Builder.setDefaultConstrainedRounding(RoundingMode::rmDownward); V = Builder.CreateFAdd(V, V); CII = cast(V); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebMayTrap); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmDownward); + 
ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebMayTrap); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmDownward); - Builder.setDefaultConstrainedExcept(ConstrainedFPIntrinsic::ebStrict); - Builder.setDefaultConstrainedRounding(ConstrainedFPIntrinsic::rmTowardZero); + Builder.setDefaultConstrainedExcept(ExceptionBehavior::ebStrict); + Builder.setDefaultConstrainedRounding(RoundingMode::rmTowardZero); V = Builder.CreateFAdd(V, V); CII = cast(V); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebStrict); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmTowardZero); + ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebStrict); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmTowardZero); - Builder.setDefaultConstrainedExcept(ConstrainedFPIntrinsic::ebIgnore); - Builder.setDefaultConstrainedRounding(ConstrainedFPIntrinsic::rmDynamic); + Builder.setDefaultConstrainedExcept(ExceptionBehavior::ebIgnore); + Builder.setDefaultConstrainedRounding(RoundingMode::rmDynamic); V = Builder.CreateFAdd(V, V); CII = cast(V); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebIgnore); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmDynamic); + ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebIgnore); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmDynamic); // Now override the defaults. Call = Builder.CreateConstrainedFPBinOp( Intrinsic::experimental_constrained_fadd, V, V, nullptr, "", nullptr, - ConstrainedFPIntrinsic::rmDownward, ConstrainedFPIntrinsic::ebMayTrap); + RoundingMode::rmDownward, ExceptionBehavior::ebMayTrap); CII = cast(Call); EXPECT_EQ(CII->getIntrinsicID(), Intrinsic::experimental_constrained_fadd); - ASSERT_TRUE(CII->getExceptionBehavior() == ConstrainedFPIntrinsic::ebMayTrap); - ASSERT_TRUE(CII->getRoundingMode() == ConstrainedFPIntrinsic::rmDownward); + ASSERT_TRUE(CII->getExceptionBehavior() == ExceptionBehavior::ebMayTrap); + ASSERT_TRUE(CII->getRoundingMode() == RoundingMode::rmDownward); Builder.CreateRetVoid(); EXPECT_FALSE(verifyModule(*M)); diff --git a/llvm/utils/TableGen/CodeGenIntrinsics.h b/llvm/utils/TableGen/CodeGenIntrinsics.h --- a/llvm/utils/TableGen/CodeGenIntrinsics.h +++ b/llvm/utils/TableGen/CodeGenIntrinsics.h @@ -146,7 +146,10 @@ ReadOnly, WriteOnly, ReadNone, - ImmArg + ImmArg, + Mask, + VectorLength, + Passthru }; std::vector> ArgumentAttributes; diff --git a/llvm/utils/TableGen/CodeGenTarget.cpp b/llvm/utils/TableGen/CodeGenTarget.cpp --- a/llvm/utils/TableGen/CodeGenTarget.cpp +++ b/llvm/utils/TableGen/CodeGenTarget.cpp @@ -728,12 +728,12 @@ // variants with iAny types; otherwise, if the intrinsic is not // overloaded, all the types can be specified directly. assert(((!TyEl->isSubClassOf("LLVMExtendedType") && - !TyEl->isSubClassOf("LLVMTruncatedType") && - !TyEl->isSubClassOf("LLVMScalarOrSameVectorWidth")) || + !TyEl->isSubClassOf("LLVMTruncatedType")) || VT == MVT::iAny || VT == MVT::vAny) && "Expected iAny or vAny type"); - } else + } else { VT = getValueType(TyEl->getValueAsDef("VT")); + } // Reject invalid types. 
if (VT == MVT::isVoid && i != e-1 /*void at end means varargs*/) @@ -791,6 +791,15 @@ } else if (Property->isSubClassOf("Returned")) { unsigned ArgNo = Property->getValueAsInt("ArgNo"); ArgumentAttributes.push_back(std::make_pair(ArgNo, Returned)); + } else if (Property->isSubClassOf("VectorLength")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, VectorLength)); + } else if (Property->isSubClassOf("Mask")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, Mask)); + } else if (Property->isSubClassOf("Passthru")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, Passthru)); } else if (Property->isSubClassOf("ReadOnly")) { unsigned ArgNo = Property->getValueAsInt("ArgNo"); ArgumentAttributes.push_back(std::make_pair(ArgNo, ReadOnly)); diff --git a/llvm/utils/TableGen/IntrinsicEmitter.cpp b/llvm/utils/TableGen/IntrinsicEmitter.cpp --- a/llvm/utils/TableGen/IntrinsicEmitter.cpp +++ b/llvm/utils/TableGen/IntrinsicEmitter.cpp @@ -671,6 +671,24 @@ OS << "Attribute::Returned"; addComma = true; break; + case CodeGenIntrinsic::VectorLength: + if (addComma) + OS << ","; + OS << "Attribute::VectorLength"; + addComma = true; + break; + case CodeGenIntrinsic::Mask: + if (addComma) + OS << ","; + OS << "Attribute::Mask"; + addComma = true; + break; + case CodeGenIntrinsic::Passthru: + if (addComma) + OS << ","; + OS << "Attribute::Passthru"; + addComma = true; + break; case CodeGenIntrinsic::ReadOnly: if (addComma) OS << ",";
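Tying the TableGen plumbing back to the IR level, here is a unit-test-style sketch (not part of the patch; it presumes the companion Intrinsics.td changes tag the mask and vlen operands with the new Mask/VectorLength properties, and the headers and test name are illustrative) that checks a generated VP declaration against the VPIntrinsic position helpers:

#include "gtest/gtest.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
using namespace llvm;

TEST(VPIntrinsicSketch, MaskAndVlenPositionsMatchAttributes) {
  LLVMContext Ctx;
  Module M("vp_sketch", Ctx);

  // Materialize llvm.vp.fadd for <8 x float> the same way
  // PredicatedBinaryOperator::Create does earlier in this patch.
  auto *VecTy = VectorType::get(Type::getFloatTy(Ctx), 8);
  auto Tokens = VPIntrinsic::GetTypeTokens(Intrinsic::vp_fadd);
  Function *F = Intrinsic::getDeclaration(
      &M, Intrinsic::vp_fadd,
      VPIntrinsic::EncodeTypeTokens(Tokens, VecTy, *VecTy,
                                    *VecTy->getVectorElementType()));

  // The position helpers and the emitted parameter attributes must agree on
  // where the mask and the explicit vector length live, and the vlen operand
  // directly follows the mask operand.
  auto MaskPos = VPIntrinsic::getMaskParamPos(Intrinsic::vp_fadd);
  auto VlenPos = VPIntrinsic::getVectorLengthParamPos(Intrinsic::vp_fadd);
  ASSERT_TRUE(MaskPos.hasValue());
  ASSERT_TRUE(VlenPos.hasValue());
  EXPECT_TRUE(F->hasParamAttribute(MaskPos.getValue(), Attribute::Mask));
  EXPECT_TRUE(F->hasParamAttribute(VlenPos.getValue(), Attribute::VectorLength));
  EXPECT_EQ(VlenPos.getValue(), MaskPos.getValue() + 1);
}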