Index: docs/Proposals/VectorPredication.rst =================================================================== --- /dev/null +++ docs/Proposals/VectorPredication.rst @@ -0,0 +1,84 @@ +========================== +Vector Predication Roadmap +========================== + +.. contents:: Table of Contents + :depth: 3 + :local: + +Motivation +========== + +This proposal defines a roadmap towards native vector predication in LLVM, specifically for vector instructions with a mask and/or an explicit vector length. +LLVM currently has no target-independent means to model predicated vector instructions for modern SIMD ISAs such as AVX512, ARM SVE, the RISC-V V extension and NEC SX-Aurora. +Only some predicated vector operations, such as masked loads and stores are available through intrinsics [MaskedIR]_. + +The Explicit Vector Length extension +==================================== + +The Explicit Vector Length (EVL) extension [EvlRFC]_ can be a first step towards native vector predication. +The EVL prototype in this patch demonstrates the following concepts: + +- Predicated vector intrinsics with an explicit mask and vector length parameter on IR level. +- First-class predicated SDNodes on ISel level. Mask and vector length are value operands. +- An incremental strategy to generalize PatternMatch/InstCombine/InstSimplify and DAGCombiner to work on both regular instructions and EVL intrinsics. +- DAGCombiner example: FMA fusion. +- InstCombine/InstSimplify example: FSub pattern re-writes. +- Early experiments on the LNT test suite (Clang static release, O3 -ffast-math) indicate that compile time on non-EVL IR is not affected by the API abstractions in PatternMatch, etc. + +Roadmap +======= + +Drawing from the EVL prototype, we propose the following roadmap towards native vector predication in LLVM: + + +1. IR-level EVL intrinsics +----------------------------------------- + +- There is a consensus on the semantics/instruction set of EVL. +- EVL intrinsics and attributes are available on IR level. +- TTI has capability flags for EVL (``supportsEVL()``?, ``haveActiveVectorLength()``?). + +Result: EVL usable for IR-level vectorizers (LV, VPlan, RegionVectorizer), potential integration in Clang with builtins. + +2. CodeGen support +------------------ + +- EVL intrinsics translate to first-class SDNodes (``llvm.evl.fdiv.* -> evl_fdiv``). +- EVL legalization (legalize explicit vector length to mask (AVX512), legalize EVL SDNodes to pre-existing ones (SSE, NEON)). + +Result: Backend development based on EVL SDNodes. + +3. Lift InstSimplify/InstCombine/DAGCombiner to EVL +------------------------------------------------ + +- Introduce PredicatedInstruction, PredicatedBinaryOperator, .. helper classes that match standard vector IR and EVL intrinsics. +- Add a matcher context to PatternMatch and context-aware IR Builder APIs. +- Incrementally lift DAGCombiner to work on EVL SDNodes as well as on regular vector instructions. +- Incrementally lift InstCombine/InstSimplify to operate on EVL as well as regular IR instructions. + +Result: Optimization of EVL intrinsics on par with standard vector instructions. + +4. Deprecate llvm.masked.* / llvm.experimental.reduce.* +------------------------------------------------------- + +- Modernize llvm.masked.* / llvm.experimental.reduce* by translating to EVL. +- DCE transitional APIs. + +Result: EVL has superseded earlier vector intrinsics. + +5. 
Predicated IR Instructions +--------------------------------------- + +- Vector instructions have an optional mask and an optional vector length parameter. These lower to EVL SDNodes (from Stage 2). +- Phase out EVL intrinsics, keeping only those that are not equivalent to vectorized scalar instructions (reductions, shuffles, ..). +- InstCombine/InstSimplify expect predication in regular Instructions (Stage 3 has laid the groundwork). + +Result: Native vector predication in IR. + +References +========== + +.. [MaskedIR] `llvm.masked.*` intrinsics, https://llvm.org/docs/LangRef.html#masked-vector-load-and-store-intrinsics +.. [EvlRFC] Explicit Vector Length RFC, https://reviews.llvm.org/D53613 Index: include/llvm/Analysis/InstructionSimplify.h =================================================================== --- include/llvm/Analysis/InstructionSimplify.h +++ include/llvm/Analysis/InstructionSimplify.h @@ -52,6 +52,10 @@ class Value; class MDNode; class BinaryOperator; +class EVLIntrinsic; +namespace PatternMatch { + struct PredicatedContext; +} /// InstrInfoQuery provides an interface to query additional information for /// instructions like metadata or keywords like nsw, which provides conservative @@ -133,6 +137,13 @@ Value *SimplifyFSubInst(Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q); +/// Given operands for a predicated FSub (mask and vector length supplied by +/// the context), fold the result or return null. +Value *SimplifyPredicatedFSubInst(Value *LHS, Value *RHS, + FastMathFlags FMF, const SimplifyQuery &Q, + PatternMatch::PredicatedContext & PC); + /// Given operands for an FMul, fold the result or return null. Value *SimplifyFMulInst(Value *LHS, Value *RHS, FastMathFlags FMF, const SimplifyQuery &Q); @@ -245,6 +256,10 @@ Value *SimplifyCall(ImmutableCallSite CS, Value *V, User::op_iterator ArgBegin, User::op_iterator ArgEnd, const SimplifyQuery &Q); +/// Given an EVL intrinsic call, fold the result or return null. +Value *SimplifyEVLIntrinsic(EVLIntrinsic & EVLInst, const SimplifyQuery &Q); + /// Given a function and set of arguments, fold the result or return null. Value *SimplifyCall(ImmutableCallSite CS, Value *V, ArrayRef<Value *> Args, const SimplifyQuery &Q); Index: include/llvm/Bitcode/LLVMBitCodes.h =================================================================== --- include/llvm/Bitcode/LLVMBitCodes.h +++ include/llvm/Bitcode/LLVMBitCodes.h @@ -603,6 +603,9 @@ ATTR_KIND_OPT_FOR_FUZZING = 57, ATTR_KIND_SHADOWCALLSTACK = 58, ATTR_KIND_SPECULATIVE_LOAD_HARDENING = 59, + ATTR_KIND_MASK = 60, + ATTR_KIND_VECTORLENGTH = 61, + ATTR_KIND_PASSTHRU = 62, }; enum ComdatSelectionKindCodes { Index: include/llvm/CodeGen/ISDOpcodes.h =================================================================== --- include/llvm/CodeGen/ISDOpcodes.h +++ include/llvm/CodeGen/ISDOpcodes.h @@ -198,6 +198,7 @@ /// Simple integer binary arithmetic operators. ADD, SUB, MUL, SDIV, UDIV, SREM, UREM, + EVL_ADD, EVL_SUB, EVL_MUL, EVL_SDIV, EVL_UDIV, EVL_SREM, EVL_UREM, /// SMUL_LOHI/UMUL_LOHI - Multiply two integers of type iN, producing /// a signed/unsigned value of type i[2*N], and return the full value as @@ -280,6 +281,7 @@ /// Simple binary floating point operators. FADD, FSUB, FMUL, FDIV, FREM, + EVL_FADD, EVL_FSUB, EVL_FMUL, EVL_FDIV, EVL_FREM, /// Constrained versions of the binary floating point operators. /// These will be lowered to the simple operators before final selection.
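// --- Illustrative sketch (not part of this patch): how the SimplifyEVLIntrinsic
// entry point declared in InstructionSimplify.h above could dispatch an EVL
// intrinsic to the predicated FSub simplification. The PredicatedContext
// constructor taking the intrinsic is an assumption made only for this example.
Value *SimplifyEVLIntrinsic(EVLIntrinsic &EVLInst, const SimplifyQuery &Q) {
  PatternMatch::PredicatedContext PC(&EVLInst); // assumed constructor
  switch (EVLInst.getFunctionalOpcode()) {
  case Instruction::FSub:
    return SimplifyPredicatedFSubInst(EVLInst.getArgOperand(0),
                                      EVLInst.getArgOperand(1),
                                      EVLInst.getFastMathFlags(), Q, PC);
  default:
    return nullptr; // other opcodes: no EVL-aware folds wired up in this sketch
  }
}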
@@ -299,6 +301,7 @@ /// FMA - Perform a * b + c with no intermediate rounding step. FMA, + EVL_FMA, /// FMAD - Perform a * b + c, while getting the same result as the /// separately rounded operations. @@ -365,6 +368,19 @@ /// in terms of the element size of VEC1/VEC2, not in terms of bytes. VECTOR_SHUFFLE, + /// EVL_VSHIFT(VEC1, AMOUNT, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. AMOUNT is an integer value. The returned vector is equivalent + /// to VEC1 shifted by AMOUNT (RETURNED_VEC[idx] = VEC1[idx + AMOUNT]). + EVL_VSHIFT, + + /// EVL_COMPRESS(VEC1, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. + EVL_COMPRESS, + + /// EVL_EXPAND(VEC1, MASK, VLEN) - Returns a vector, of the same type as + /// VEC1. + EVL_EXPAND, + /// SCALAR_TO_VECTOR(VAL) - This represents the operation of loading a /// scalar value into element 0 of the resultant vector type. The top /// elements 1 to N-1 of the N-element vector are undefined. The type @@ -384,6 +400,7 @@ /// Bitwise operators - logical and, logical or, logical xor. AND, OR, XOR, + EVL_AND, EVL_OR, EVL_XOR, /// ABS - Determine the unsigned absolute value of a signed integer value of /// the same bitwidth. @@ -407,6 +424,7 @@ /// fshl(X,Y,Z): (X << (Z % BW)) | (Y >> (BW - (Z % BW))) /// fshr(X,Y,Z): (X << (BW - (Z % BW))) | (Y >> (Z % BW)) SHL, SRA, SRL, ROTL, ROTR, FSHL, FSHR, + EVL_SHL, EVL_SRA, EVL_SRL, /// Byte Swap and Counting operators. BSWAP, CTTZ, CTLZ, CTPOP, BITREVERSE, @@ -426,6 +444,14 @@ /// change the condition type in order to match the VSELECT node using a /// pattern. The condition follows the BooleanContent format of the target. VSELECT, + EVL_SELECT, + + /// Select with an integer pivot (op #0) and two vector operands (ops #1 + /// and #2), returning a vector result. All vectors have the same length. + /// Similar to the vector select, a comparison of the results element index + /// with the integer pivot selects hether the corresponding result element + /// is taken from op #1 or op #2. + EVL_COMPOSE, /// Select with condition operator - This selects between a true value and /// a false value (ops #2 and #3) based on the boolean result of comparing @@ -440,6 +466,7 @@ /// them with (op #2) as a CondCodeSDNode. If the operands are vector types /// then the result type must also be a vector type. SETCC, + EVL_SETCC, /// Like SetCC, ops #0 and #1 are the LHS and RHS operands to compare, but /// op #2 is a boolean indicating if there is an incoming carry. This @@ -583,6 +610,7 @@ FNEG, FABS, FSQRT, FCBRT, FSIN, FCOS, FPOWI, FPOW, FLOG, FLOG2, FLOG10, FEXP, FEXP2, FCEIL, FTRUNC, FRINT, FNEARBYINT, FROUND, FFLOOR, + EVL_FNEG, /// FMINNUM/FMAXNUM - Perform floating-point minimum or maximum on two /// values. // @@ -828,6 +856,7 @@ // Val, OutChain = MLOAD(BasePtr, Mask, PassThru) // OutChain = MSTORE(Value, BasePtr, Mask) MLOAD, MSTORE, + EVL_LOAD, EVL_STORE, // Masked gather and scatter - load and store operations for a vector of // random addresses with additional mask operand that prevents memory @@ -839,6 +868,7 @@ // The Index operand can have more vector elements than the other operands // due to type legalization. The extra elements are ignored. MGATHER, MSCATTER, + EVL_GATHER, EVL_SCATTER, /// This corresponds to the llvm.lifetime.* intrinsics. The first operand /// is the chain and the second operand is the alloca pointer. 
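// --- Illustrative sketch (not part of this patch): scalar reference semantics
// of the EVL_VSHIFT node described earlier in this file
// (RETURNED_VEC[idx] = VEC1[idx + AMOUNT]), written as plain C++ over arrays.
// The comment does not specify what happens for out-of-range source indices;
// this sketch zero-fills them purely for illustration.
static void evlVShiftRef(const double *Vec, double *Result, int Amount, int N) {
  for (int Idx = 0; Idx < N; ++Idx) {
    int Src = Idx + Amount;
    Result[Idx] = (Src >= 0 && Src < N) ? Vec[Src] : 0.0;
  }
}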
@@ -870,8 +900,15 @@ VECREDUCE_ADD, VECREDUCE_MUL, VECREDUCE_AND, VECREDUCE_OR, VECREDUCE_XOR, VECREDUCE_SMAX, VECREDUCE_SMIN, VECREDUCE_UMAX, VECREDUCE_UMIN, + + EVL_REDUCE_FADD, EVL_REDUCE_FMUL, + EVL_REDUCE_ADD, EVL_REDUCE_MUL, + EVL_REDUCE_AND, EVL_REDUCE_OR, EVL_REDUCE_XOR, + EVL_REDUCE_SMAX, EVL_REDUCE_SMIN, EVL_REDUCE_UMAX, EVL_REDUCE_UMIN, + /// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants. VECREDUCE_FMAX, VECREDUCE_FMIN, + EVL_REDUCE_FMAX, EVL_REDUCE_FMIN, /// BUILTIN_OP_END - This must be the last enum value in this list. /// The target-specific pre-isel opcode values start here. @@ -1032,6 +1069,19 @@ /// SETCC_INVALID if it is not possible to represent the resultant comparison. CondCode getSetCCAndOperation(CondCode Op1, CondCode Op2, bool isInteger); + /// Return the operand position of the mask for the EVL opcode \p OpCode. + /// Otherwise, return -1. + int GetMaskPosEVL(unsigned OpCode); + + /// Return the operand position of the vector length for the EVL opcode \p OpCode. + /// Otherwise, return -1. + int GetVectorLengthPosEVL(unsigned OpCode); + + /// Translate the EVL opcode \p EVLOpCode to the equivalent native (unpredicated) opcode. + unsigned GetFunctionOpCodeForEVL(unsigned EVLOpCode); + + /// Translate the native opcode \p OpCode to the equivalent EVL opcode. + unsigned GetEVLForFunctionOpCode(unsigned OpCode); + } // end llvm::ISD namespace } // end llvm namespace Index: include/llvm/CodeGen/SelectionDAG.h =================================================================== --- include/llvm/CodeGen/SelectionDAG.h +++ include/llvm/CodeGen/SelectionDAG.h @@ -1083,6 +1083,20 @@ SDValue getIndexedStore(SDValue OrigStore, const SDLoc &dl, SDValue Base, SDValue Offset, ISD::MemIndexedMode AM); + /// Returns an EVL_LOAD node: a load predicated by Mask and VLen. + SDValue getLoadEVL(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr, + SDValue Mask, SDValue VLen, EVT MemVT, + MachineMemOperand *MMO, ISD::LoadExtType); + + /// Returns an EVL_STORE node: a store predicated by Mask and VLen. + SDValue getStoreEVL(SDValue Chain, const SDLoc &dl, SDValue Val, + SDValue Ptr, SDValue Mask, SDValue VLen, + EVT MemVT, MachineMemOperand *MMO, + bool IsTruncating = false); + SDValue getGatherEVL(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef<SDValue> Ops, MachineMemOperand *MMO); + SDValue getScatterEVL(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef<SDValue> Ops, MachineMemOperand *MMO); + /// Returns sum of the base pointer and offset. SDValue getMemBasePlusOffset(SDValue Base, unsigned Offset, const SDLoc &DL); Index: include/llvm/CodeGen/SelectionDAGNodes.h =================================================================== --- include/llvm/CodeGen/SelectionDAGNodes.h +++ include/llvm/CodeGen/SelectionDAGNodes.h @@ -533,6 +533,7 @@ class LoadSDNodeBitfields { friend class LoadSDNode; friend class MaskedLoadSDNode; + friend class EVLLoadSDNode; uint16_t : NumLSBaseSDNodeBits; @@ -543,6 +544,7 @@ class StoreSDNodeBitfields { friend class StoreSDNode; friend class MaskedStoreSDNode; + friend class EVLStoreSDNode; uint16_t : NumLSBaseSDNodeBits; @@ -680,6 +682,66 @@ } } + /// Test whether this is an Explicit Vector Length node.
+ bool isEVL() const { + switch (NodeType) { + default: + return false; + case ISD::EVL_LOAD: + case ISD::EVL_STORE: + case ISD::EVL_GATHER: + case ISD::EVL_SCATTER: + + case ISD::EVL_FNEG: + + case ISD::EVL_FADD: + case ISD::EVL_FMUL: + case ISD::EVL_FSUB: + case ISD::EVL_FDIV: + case ISD::EVL_FREM: + + case ISD::EVL_FMA: + + case ISD::EVL_ADD: + case ISD::EVL_MUL: + case ISD::EVL_SUB: + case ISD::EVL_SRA: + case ISD::EVL_SRL: + case ISD::EVL_SHL: + case ISD::EVL_UDIV: + case ISD::EVL_SDIV: + case ISD::EVL_UREM: + case ISD::EVL_SREM: + + case ISD::EVL_EXPAND: + case ISD::EVL_COMPRESS: + case ISD::EVL_VSHIFT: + case ISD::EVL_SETCC: + case ISD::EVL_COMPOSE: + + case ISD::EVL_AND: + case ISD::EVL_XOR: + case ISD::EVL_OR: + + case ISD::EVL_REDUCE_ADD: + case ISD::EVL_REDUCE_SMIN: + case ISD::EVL_REDUCE_SMAX: + case ISD::EVL_REDUCE_UMIN: + case ISD::EVL_REDUCE_UMAX: + + case ISD::EVL_REDUCE_MUL: + case ISD::EVL_REDUCE_AND: + case ISD::EVL_REDUCE_OR: + case ISD::EVL_REDUCE_FADD: + case ISD::EVL_REDUCE_FMUL: + case ISD::EVL_REDUCE_FMIN: + case ISD::EVL_REDUCE_FMAX: + + return true; + } + } + + /// Test if this node has a post-isel opcode, directly /// corresponding to a MachineInstr opcode. bool isMachineOpcode() const { return NodeType < 0; } @@ -1367,6 +1429,10 @@ N->getOpcode() == ISD::MSTORE || N->getOpcode() == ISD::MGATHER || N->getOpcode() == ISD::MSCATTER || + N->getOpcode() == ISD::EVL_LOAD || + N->getOpcode() == ISD::EVL_STORE || + N->getOpcode() == ISD::EVL_GATHER || + N->getOpcode() == ISD::EVL_SCATTER || N->isMemIntrinsic() || N->isTargetMemoryOpcode(); } @@ -2139,6 +2205,96 @@ } }; +/// This base class is used to represent MLOAD and MSTORE nodes +class EVLLoadStoreSDNode : public MemSDNode { +public: + friend class SelectionDAG; + + EVLLoadStoreSDNode(ISD::NodeType NodeTy, unsigned Order, + const DebugLoc &dl, SDVTList VTs, EVT MemVT, + MachineMemOperand *MMO) + : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {} + + // EVLLoadSDNode (Chain, ptr, mask, VLen) + // EVLStoreSDNode (Chain, data, ptr, mask, VLen) + // Mask is a vector of i1 elements, Vlen is i32 + const SDValue &getBasePtr() const { + return getOperand(getOpcode() == ISD::EVL_LOAD ? 1 : 2); + } + const SDValue &getMask() const { + return getOperand(getOpcode() == ISD::EVL_LOAD ? 2 : 3); + } + const SDValue &getVectorLength() const { + return getOperand(getOpcode() == ISD::EVL_LOAD ? 
3 : 4); + } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::EVL_LOAD || + N->getOpcode() == ISD::EVL_STORE; + } +}; + +/// This class is used to represent an MLOAD node +class EVLLoadSDNode : public EVLLoadStoreSDNode { +public: + friend class SelectionDAG; + + EVLLoadSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + ISD::LoadExtType ETy, EVT MemVT, + MachineMemOperand *MMO) + : EVLLoadStoreSDNode(ISD::EVL_LOAD, Order, dl, VTs, MemVT, MMO) { + LoadSDNodeBits.ExtTy = ETy; + LoadSDNodeBits.IsExpanding = false; + } + + ISD::LoadExtType getExtensionType() const { + return static_cast(LoadSDNodeBits.ExtTy); + } + + const SDValue &getBasePtr() const { return getOperand(1); } + const SDValue &getMask() const { return getOperand(2); } + const SDValue &getVectorLength() const { return getOperand(3); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::EVL_LOAD; + } + bool isExpandingLoad() const { return LoadSDNodeBits.IsExpanding; } +}; + +/// This class is used to represent an MSTORE node +class EVLStoreSDNode : public EVLLoadStoreSDNode { +public: + friend class SelectionDAG; + + EVLStoreSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + bool isTrunc, EVT MemVT, + MachineMemOperand *MMO) + : EVLLoadStoreSDNode(ISD::EVL_STORE, Order, dl, VTs, MemVT, MMO) { + StoreSDNodeBits.IsTruncating = isTrunc; + StoreSDNodeBits.IsCompressing = false; + } + + /// Return true if the op does a truncation before store. + /// For integers this is the same as doing a TRUNCATE and storing the result. + /// For floats, it is the same as doing an FP_ROUND and storing the result. + bool isTruncatingStore() const { return StoreSDNodeBits.IsTruncating; } + + /// Returns true if the op does a compression to the vector before storing. + /// The node contiguously stores the active elements (integers or floats) + /// in src (those with their respective bit set in writemask k) to unaligned + /// memory at base_addr. + bool isCompressingStore() const { return StoreSDNodeBits.IsCompressing; } + + const SDValue &getValue() const { return getOperand(1); } + const SDValue &getBasePtr() const { return getOperand(2); } + const SDValue &getMask() const { return getOperand(3); } + const SDValue &getVectorLength() const { return getOperand(4); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::EVL_STORE; + } +}; + /// This base class is used to represent MLOAD and MSTORE nodes class MaskedLoadStoreSDNode : public MemSDNode { public: @@ -2226,6 +2382,67 @@ } }; +/// This is a base class used to represent +/// EVL_GATHER and EVL_SCATTER nodes +/// +class EVLGatherScatterSDNode : public MemSDNode { +public: + friend class SelectionDAG; + + EVLGatherScatterSDNode(ISD::NodeType NodeTy, unsigned Order, + const DebugLoc &dl, SDVTList VTs, EVT MemVT, + MachineMemOperand *MMO) + : MemSDNode(NodeTy, Order, dl, VTs, MemVT, MMO) {} + + // In the both nodes address is Op1, mask is Op2: + // EVLGatherSDNode (Chain, base, index, scale, mask, vlen) + // EVLScatterSDNode (Chain, value, base, index, sckae, mask, vlen) + // Mask is a vector of i1 elements + const SDValue &getBasePtr() const { return getOperand((getOpcode() == ISD::EVL_GATHER) ? 1 : 2); } + const SDValue &getIndex() const { return getOperand((getOpcode() == ISD::EVL_GATHER) ? 2 : 3); } + const SDValue &getScale() const { return getOperand((getOpcode() == ISD::EVL_GATHER) ? 3 : 4); } + const SDValue &getMask() const { return getOperand((getOpcode() == ISD::EVL_GATHER) ? 
4 : 5); } + const SDValue &getVectorLength() const { return getOperand((getOpcode() == ISD::EVL_GATHER) ? 5 : 6); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::EVL_GATHER || + N->getOpcode() == ISD::EVL_SCATTER; + } +}; + +/// This class is used to represent an EVL_GATHER node +/// +class EVLGatherSDNode : public EVLGatherScatterSDNode { +public: + friend class SelectionDAG; + + EVLGatherSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemVT, MachineMemOperand *MMO) + : EVLGatherScatterSDNode(ISD::EVL_GATHER, Order, dl, VTs, MemVT, MMO) {} + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::EVL_GATHER; + } +}; + +/// This class is used to represent an EVL_SCATTER node +/// +class EVLScatterSDNode : public EVLGatherScatterSDNode { +public: + friend class SelectionDAG; + + EVLScatterSDNode(unsigned Order, const DebugLoc &dl, SDVTList VTs, + EVT MemVT, MachineMemOperand *MMO) + : EVLGatherScatterSDNode(ISD::EVL_SCATTER, Order, dl, VTs, MemVT, MMO) {} + + const SDValue &getValue() const { return getOperand(1); } + + static bool classof(const SDNode *N) { + return N->getOpcode() == ISD::EVL_SCATTER; + } +}; + + /// This is a base class used to represent /// MGATHER and MSCATTER nodes /// Index: include/llvm/IR/Attributes.td =================================================================== --- include/llvm/IR/Attributes.td +++ include/llvm/IR/Attributes.td @@ -130,6 +130,15 @@ /// Return value is always equal to this argument. def Returned : EnumAttr<"returned">; +/// Return value that is equal to this argument on enabled lanes (mask). +def Passthru : EnumAttr<"passthru">; + +/// Mask argument that applies to this function. +def Mask : EnumAttr<"mask">; + +/// Dynamic Vector Length argument of this function. +def VectorLength : EnumAttr<"vlen">; + /// Function can return twice. def ReturnsTwice : EnumAttr<"returns_twice">; Index: include/llvm/IR/EVLBuilder.h =================================================================== --- /dev/null +++ include/llvm/IR/EVLBuilder.h @@ -0,0 +1,232 @@ +#ifndef LLVM_IR_EVLBUILDER_H +#define LLVM_IR_EVLBUILDER_H + +#include +#include +#include +#include +#include +#include + +namespace llvm { + +using ValArray = ArrayRef; + +class EVLBuilder { + IRBuilder<> & Builder; + // Explicit mask parameter + Value * Mask; + // Explicit vector length parameter + Value * ExplicitVectorLength; + // Compile-time vector length + int StaticVectorLength; + + // get a vlaid mask/evl argument for the current predication contet + Value& GetMaskForType(VectorType & VecTy); + Value& GetEVLForType(VectorType & VecTy); + +public: + EVLBuilder(IRBuilder<> & _builder) + : Builder(_builder) + , Mask(nullptr) + , ExplicitVectorLength(nullptr) + , StaticVectorLength(-1) + {} + + Module & getModule() const; + LLVMContext & getContext() const { return Builder.getContext(); } + + // The cannonical vector type for this \p ElementTy + VectorType& getVectorType(Type &ElementTy); + + // Predication context tracker + EVLBuilder& setMask(Value * _Mask) { Mask = _Mask; return *this; } + EVLBuilder& setEVL(Value * _ExplicitVectorLength) { ExplicitVectorLength = _ExplicitVectorLength; return *this; } + EVLBuilder& setStaticVL(int VLen) { StaticVectorLength = VLen; return *this; } + + EVLIntrinsic::EVLIntrinsicDesc GetEVLIntrinsicDesc(unsigned OC); + + // Create a map-vectorized copy of the instruction \p Inst with the underlying IRBuilder instance. + // This operation may return nullptr if the instruction could not be vectorized. 
+ Value* CreateVectorCopy(Instruction & Inst, ValArray VecOpArray); + + // Memory + Value& CreateContiguousStore(Value & Val, Value & Pointer, unsigned Alignment=0); + Value& CreateContiguousLoad(Value & Pointer, unsigned Alignment=0); + Value& CreateScatter(Value & Val, Value & PointerVec, unsigned Alignment=0); + Value& CreateGather(Value & PointerVec, unsigned Alignment=0); +}; + + + + + +namespace PatternMatch { + // Factory class to generate instructions in a context + template + class MatchContextBuilder { + public: + // MatchContextBuilder(MatcherContext MC); + }; + + +// Context-free instruction builder +template<> +class MatchContextBuilder { +public: + MatchContextBuilder(EmptyContext & EC) {} + + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name);\ + } \ + template \ + Value *Create##OPC(IRBuilderType & Builder, Value *V1, Value *V2, \ + const Twine &Name = "") const { \ + auto * Inst = BinaryOperator::Create(Instruction::OPC, V1, V2, Name); \ + Builder.Insert(Inst); return Inst; \ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Value *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, BasicBlock *BB) const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name, BB);\ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Value *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, Instruction *I) const {\ + return BinaryOperator::Create(Instruction::OPC, V1, V2, Name, I);\ + } + #include "llvm/IR/Instruction.def" + #undef HANDLE_BINARY_INST + + BinaryOperator *CreateFAddFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FAdd, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFSubFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FSub, V1, V2, FMFSource, Name); + } + template + BinaryOperator *CreateFSubFMF(IRBuilderType & Builder, Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + auto * Inst = CreateFSubFMF(V1, V2, FMFSource, Name); + Builder.Insert(Inst); return Inst; + } + BinaryOperator *CreateFMulFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FMul, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFDivFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FDiv, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFRemFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return BinaryOperator::CreateWithCopiedFlags(Instruction::FRem, V1, V2, FMFSource, Name); + } + BinaryOperator *CreateFNegFMF(Value *Op, Instruction *FMFSource, + const Twine &Name = "") { + Value *Zero = ConstantFP::getNegativeZero(Op->getType()); + return BinaryOperator::CreateWithCopiedFlags(Instruction::FSub, Zero, Op, FMFSource); + } + + template + Value *CreateFPTrunc(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPTrunc(V, DestTy, Name); } + template + Value *CreateFPExt(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPExt(V, DestTy, Name); } +}; 
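// --- Illustrative sketch (not part of this patch): a rewrite written once
// against MatchContextBuilder so that the same code emits either a plain FSub
// or a predicated (EVL) FSub, depending on the context it is instantiated
// with. The helper name emitReverseSub is hypothetical; CreateFSub is one of
// the Create##OPC methods generated by the HANDLE_BINARY_INST macros above.
template <typename MatcherContext>
static Value *emitReverseSub(IRBuilder<> &Builder, MatcherContext &MC,
                             Value *A, Value *B) {
  MatchContextBuilder<MatcherContext> MCB(MC);
  // With EmptyContext this builds a regular "fsub B, A"; with the
  // PredicatedContext specialization below it builds the matching EVL
  // intrinsic carrying the context's mask and vector length.
  return MCB.CreateFSub(Builder, B, A, "rsub");
}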
+ + + +// Context-free instruction builder +template<> +class MatchContextBuilder { + PredicatedContext & PC; +public: + MatchContextBuilder(PredicatedContext & PC) : PC(PC) {} + + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name);\ + } \ + template \ + Instruction *Create##OPC(IRBuilderType & Builder, Value *V1, Value *V2, \ + const Twine &Name = "") const {\ + auto * PredInst = Create##OPC(V1, V2, Name); Builder.Insert(PredInst); return PredInst; \ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, BasicBlock *BB) const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name, BB);\ + } + #include "llvm/IR/Instruction.def" + #define HANDLE_BINARY_INST(N, OPC, CLASS) \ + Instruction *Create##OPC(Value *V1, Value *V2, \ + const Twine &Name, Instruction *I) const {\ + return PredicatedBinaryOperator::Create(PC.Mod, PC.Mask, PC.VectorLength, Instruction::OPC, V1, V2, Name, I);\ + } + #include "llvm/IR/Instruction.def" + #undef HANDLE_BINARY_INST + + Instruction *CreateFAddFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FAdd, V1, V2, FMFSource, Name); + } + Instruction *CreateFSubFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FSub, V1, V2, FMFSource, Name); + } + template + Instruction *CreateFSubFMF(IRBuilderType & Builder, Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + auto * Inst = CreateFSubFMF(V1, V2, FMFSource, Name); + Builder.Insert(Inst); return Inst; + } + Instruction *CreateFMulFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FMul, V1, V2, FMFSource, Name); + } + Instruction *CreateFDivFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FDiv, V1, V2, FMFSource, Name); + } + Instruction *CreateFRemFMF(Value *V1, Value *V2, + Instruction *FMFSource, + const Twine &Name = "") { + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FRem, V1, V2, FMFSource, Name); + } + Instruction *CreateFNegFMF(Value *Op, Instruction *FMFSource, + const Twine &Name = "") { + Value *Zero = ConstantFP::getNegativeZero(Op->getType()); + return PredicatedBinaryOperator::CreateWithCopiedFlags(PC.Mod, PC.Mask, PC.VectorLength, Instruction::FSub, Zero, Op, FMFSource); + } + + // TODO predicated casts + template + Value *CreateFPTrunc(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPTrunc(V, DestTy, Name); } + template + Value *CreateFPExt(IRBuilderType & Builder, Value *V, Type *DestTy, const Twine & Name = Twine()) { return Builder.CreateFPExt(V, DestTy, Name); } +}; + +} + +} // namespace llvm + +#endif // LLVM_IR_EVLBUILDER_H Index: include/llvm/IR/InstrTypes.h 
=================================================================== --- include/llvm/IR/InstrTypes.h +++ include/llvm/IR/InstrTypes.h @@ -161,7 +161,7 @@ static BinaryOperator *CreateWithCopiedFlags(BinaryOps Opc, Value *V1, Value *V2, - BinaryOperator *CopyBO, + Instruction *CopyBO, const Twine &Name = "") { BinaryOperator *BO = Create(Opc, V1, V2, Name); BO->copyIRFlags(CopyBO); @@ -169,31 +169,31 @@ } static BinaryOperator *CreateFAddFMF(Value *V1, Value *V2, - BinaryOperator *FMFSource, + Instruction *FMFSource, const Twine &Name = "") { return CreateWithCopiedFlags(Instruction::FAdd, V1, V2, FMFSource, Name); } static BinaryOperator *CreateFSubFMF(Value *V1, Value *V2, - BinaryOperator *FMFSource, + Instruction *FMFSource, const Twine &Name = "") { return CreateWithCopiedFlags(Instruction::FSub, V1, V2, FMFSource, Name); } static BinaryOperator *CreateFMulFMF(Value *V1, Value *V2, - BinaryOperator *FMFSource, + Instruction *FMFSource, const Twine &Name = "") { return CreateWithCopiedFlags(Instruction::FMul, V1, V2, FMFSource, Name); } static BinaryOperator *CreateFDivFMF(Value *V1, Value *V2, - BinaryOperator *FMFSource, + Instruction *FMFSource, const Twine &Name = "") { return CreateWithCopiedFlags(Instruction::FDiv, V1, V2, FMFSource, Name); } static BinaryOperator *CreateFRemFMF(Value *V1, Value *V2, - BinaryOperator *FMFSource, + Instruction *FMFSource, const Twine &Name = "") { return CreateWithCopiedFlags(Instruction::FRem, V1, V2, FMFSource, Name); } - static BinaryOperator *CreateFNegFMF(Value *Op, BinaryOperator *FMFSource, + static BinaryOperator *CreateFNegFMF(Value *Op, Instruction *FMFSource, const Twine &Name = "") { Value *Zero = ConstantFP::getNegativeZero(Op->getType()); return CreateWithCopiedFlags(Instruction::FSub, Zero, Op, FMFSource); Index: include/llvm/IR/IntrinsicInst.h =================================================================== --- include/llvm/IR/IntrinsicInst.h +++ include/llvm/IR/IntrinsicInst.h @@ -205,6 +205,150 @@ /// @} }; + class EVLIntrinsic : public IntrinsicInst { + public: + enum class EVLTypeToken : int8_t { + Scalar = 1, // scalar operand type + Vector = 2, // vectorized operand type + Mask = 3 // vector mask type + }; + + using TypeTokenVec = SmallVector; + using ShortTypeVec = SmallVector; + + struct + EVLIntrinsicDesc { + Intrinsic::ID ID; // LLVM Intrinsic ID. + TypeTokenVec typeTokens; // Type Parmeters for the LLVM Intrinsic. + int MaskPos; // Parameter index of the Mask parameter. + int EVLPos; // Parameter index of the EVL parameter. 
+ }; + + // Translate this generic Opcode to an EVLIntrinsic + static EVLIntrinsicDesc GetEVLIntrinsicDesc(unsigned OC); + + // Generate the disambiguating type vec for this EVL Intrinsic + static EVLIntrinsic::ShortTypeVec + EncodeTypeTokens(EVLIntrinsic::TypeTokenVec TTVec, Type & VectorTy, Type & ScalarTy); + + bool isUnaryOp() const; + bool isBinaryOp() const; + bool isTernaryOp() const; + + CmpInst::Predicate getCmpPredicate() const; + + Value* getMask() const; + Value* getVectorLength() const; + + // Methods for support type inquiry through isa, cast, and dyn_cast: + static bool classof(const IntrinsicInst *I) { + switch (I->getIntrinsicID()) { + default: + return false; + + case Intrinsic::evl_cmp: + + case Intrinsic::evl_and: + case Intrinsic::evl_or: + case Intrinsic::evl_xor: + case Intrinsic::evl_ashr: + case Intrinsic::evl_lshr: + case Intrinsic::evl_shl: + + case Intrinsic::evl_select: + case Intrinsic::evl_compose: + case Intrinsic::evl_compress: + case Intrinsic::evl_expand: + case Intrinsic::evl_vshift: + + case Intrinsic::evl_load: + case Intrinsic::evl_store: + case Intrinsic::evl_gather: + case Intrinsic::evl_scatter: + + case Intrinsic::evl_fneg: + + case Intrinsic::evl_fadd: + case Intrinsic::evl_fsub: + case Intrinsic::evl_fmul: + case Intrinsic::evl_fdiv: + case Intrinsic::evl_frem: + + case Intrinsic::evl_fma: + + case Intrinsic::evl_add: + case Intrinsic::evl_sub: + case Intrinsic::evl_mul: + case Intrinsic::evl_udiv: + case Intrinsic::evl_sdiv: + case Intrinsic::evl_urem: + case Intrinsic::evl_srem: + + case Intrinsic::evl_reduce_add: + case Intrinsic::evl_reduce_mul: + case Intrinsic::evl_reduce_umin: + case Intrinsic::evl_reduce_umax: + case Intrinsic::evl_reduce_smin: + case Intrinsic::evl_reduce_smax: + + case Intrinsic::evl_reduce_and: + case Intrinsic::evl_reduce_or: + case Intrinsic::evl_reduce_xor: + + case Intrinsic::evl_reduce_fadd: + case Intrinsic::evl_reduce_fmul: + case Intrinsic::evl_reduce_fmin: + case Intrinsic::evl_reduce_fmax: + return true; + } + } + static bool classof(const Value *V) { + return isa(V) && classof(cast(V)); + } + + // Equivalent non-predicated opcode + unsigned getFunctionalOpcode() const { + switch (getIntrinsicID()) { + default: return Instruction::Call; + + case Intrinsic::evl_cmp: + if (getArgOperand(0)->getType()->isFloatingPointTy()) { + return Instruction::FCmp; + } else { + return Instruction::ICmp; + } + + case Intrinsic::evl_and: return Instruction::And; + case Intrinsic::evl_or: return Instruction::Or; + case Intrinsic::evl_xor: return Instruction::Xor; + case Intrinsic::evl_ashr: return Instruction::AShr; + case Intrinsic::evl_lshr: return Instruction::LShr; + case Intrinsic::evl_shl: return Instruction::Shl; + + case Intrinsic::evl_select: return Instruction::Select; + + case Intrinsic::evl_load: return Instruction::Load; + case Intrinsic::evl_store: return Instruction::Store; + + case Intrinsic::evl_fneg: return Instruction::FNeg; + + case Intrinsic::evl_fadd: return Instruction::FAdd; + case Intrinsic::evl_fsub: return Instruction::FSub; + case Intrinsic::evl_fmul: return Instruction::FMul; + case Intrinsic::evl_fdiv: return Instruction::FDiv; + case Intrinsic::evl_frem: return Instruction::FRem; + + case Intrinsic::evl_add: return Instruction::Add; + case Intrinsic::evl_sub: return Instruction::Sub; + case Intrinsic::evl_mul: return Instruction::Mul; + case Intrinsic::evl_udiv: return Instruction::UDiv; + case Intrinsic::evl_sdiv: return Instruction::SDiv; + case Intrinsic::evl_urem: return Instruction::URem; + 
case Intrinsic::evl_srem: return Instruction::SRem; + } + } + }; + /// This is the common base class for constrained floating point intrinsics. class ConstrainedFPIntrinsic : public IntrinsicInst { public: Index: include/llvm/IR/Intrinsics.td =================================================================== --- include/llvm/IR/Intrinsics.td +++ include/llvm/IR/Intrinsics.td @@ -87,6 +87,25 @@ int ArgNo = argNo; } +// VectorLength - The specified argument is the Dynamic Vector Length of the +// operation. +class VectorLength : IntrinsicProperty { + int ArgNo = argNo; +} + +// Mask - The specified argument contains the per-lane mask of this +// intrinsic. Inputs on masked-out lanes must not affect the result of this +// intrinsic (except for the Passthru argument). +class Mask : IntrinsicProperty { + int ArgNo = argNo; +} +// Passthru - The specified argument contains the per-lane return value +// for this vector intrinsic where the mask is false. +// (requires the Mask attribute in the same function) +class Passthru : IntrinsicProperty { + int ArgNo = argNo; +} + def IntrNoReturn : IntrinsicProperty; // IntrCold - Calls to this intrinsic are cold. @@ -1006,6 +1025,267 @@ // Intrinsic to detect whether its argument is a constant. def int_is_constant : Intrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem], "llvm.is.constant">; +//===---------------- Masked/Explicit Vector Length Intrinsics --------------===// + +// Memory Intrinsics +def int_evl_store : Intrinsic<[], + [ llvm_anyvector_ty, + LLVMAnyPointerType>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrArgMemOnly, Mask<2>, VectorLength<3> ]>; + +def int_evl_load : Intrinsic<[ llvm_anyvector_ty], + [ LLVMAnyPointerType>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrReadMem, IntrArgMemOnly, Mask<1>, VectorLength<2> ]>; + +def int_evl_gather: Intrinsic<[ llvm_anyvector_ty], + [ LLVMVectorOfAnyPointersToElt<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrReadMem, IntrReadMem, Mask<1>, VectorLength<2> ]>; + +def int_evl_scatter: Intrinsic<[], + [ llvm_anyvector_ty, + LLVMVectorOfAnyPointersToElt<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrArgMemOnly, Mask<2>, VectorLength<3> ]>; + +// Reductions +let IntrProperties = [IntrNoMem, Mask<2>, VectorLength<3>] in { +def int_evl_reduce_add : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_evl_reduce_mul : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_evl_reduce_and : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_evl_reduce_or : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_evl_reduce_xor : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_evl_reduce_smax : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_evl_reduce_smin : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_evl_reduce_umax : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + 
llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_evl_reduce_umin : Intrinsic<[llvm_anyint_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; + +def int_evl_reduce_fadd : Intrinsic<[llvm_anyfloat_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_evl_reduce_fmul : Intrinsic<[llvm_anyfloat_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_evl_reduce_fmax : Intrinsic<[llvm_anyfloat_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +def int_evl_reduce_fmin : Intrinsic<[llvm_anyfloat_ty], + [LLVMMatchType<0>, + llvm_anyvector_ty, + LLVMScalarOrSameVectorWidth<1, llvm_i1_ty>, + llvm_i32_ty]>; +} + +// Binary operators +let IntrProperties = [IntrNoMem, Mask<2>, VectorLength<3>] in { + def int_evl_add : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_sub : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_mul : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_sdiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_udiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_srem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_urem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + + def int_evl_fadd : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_fsub : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_fmul : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_fdiv : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_frem : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +// Logical operators + def int_evl_ashr : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_lshr : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_shl : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_or : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def 
int_evl_and : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + def int_evl_xor : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty]>; + +// Comparison +// The last argument is the comparison predicate + def int_evl_cmp : Intrinsic<[ llvm_anyvector_ty ], + [ llvm_anyvector_ty, + LLVMMatchType<1>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty, + llvm_i8_ty]>; +} + + + +def int_evl_fneg : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, Mask<1>, VectorLength<2> ]>; + +def int_evl_fma : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, Mask<3>, VectorLength<4> ]>; + +// Shuffle +def int_evl_vshift: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + llvm_i32_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, Mask<2>, VectorLength<3> ]>; + +def int_evl_expand: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, Mask<2>, VectorLength<3> ]>; + +def int_evl_compress: Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + llvm_i32_ty], + [ IntrNoMem, Mask<2>, VectorLength<3> ]>; + +// Select +def int_evl_select : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, + LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty], + [ IntrNoMem, Passthru<2>, Mask<0>, VectorLength<3> ]>; + +// Compose +def int_evl_compose : Intrinsic<[ llvm_anyvector_ty ], + [ LLVMMatchType<0>, + LLVMMatchType<0>, + llvm_i32_ty, + llvm_i32_ty], + [ IntrNoMem, VectorLength<3> ]>; + + + //===-------------------------- Masked Intrinsics -------------------------===// // Index: include/llvm/IR/MatcherCast.h =================================================================== --- /dev/null +++ include/llvm/IR/MatcherCast.h @@ -0,0 +1,65 @@ +#ifndef LLVM_IR_MATCHERCAST_H +#define LLVM_IR_MATCHERCAST_H + +//===- MatcherCast.h - Match on the LLVM IR --------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Parameterized class hierachy for templatized pattern matching. +// +//===----------------------------------------------------------------------===// + + +namespace llvm { +namespace PatternMatch { + + +// type modification +template +struct MatcherCast { }; + +// whether the Value \p Obj behaves like a \p Class. 
+template +bool match_isa(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return isa(Obj); +} + +template +auto match_cast(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return cast(Obj); +} +template +auto match_dyn_cast(const Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return dyn_cast(Obj); +} + +template +auto match_cast(Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return cast(Obj); +} +template +auto match_dyn_cast(Value* Obj) { + using UnconstClass = typename std::remove_cv::type; + using DestClass = typename MatcherCast::ActualCastType; + return dyn_cast(Obj); +} + + +} // namespace PatternMatch + +} // namespace llvm + +#endif // LLVM_IR_MATCHERCAST_H + Index: include/llvm/IR/PatternMatch.h =================================================================== --- include/llvm/IR/PatternMatch.h +++ include/llvm/IR/PatternMatch.h @@ -39,22 +39,81 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" +#include "llvm/IR/MatcherCast.h" + #include + namespace llvm { namespace PatternMatch { +// Use verbatim types in default (empty) context. +struct EmptyContext { + EmptyContext() {} + + EmptyContext(const Value *) {} + + EmptyContext(const EmptyContext & E) {} + + // reset this match context to be rooted at \p V + void reset(Value * V) {} + + // accept a match where \p Val is in a non-leaf position in a match pattern + bool acceptInnerNode(const Value * Val) const { return true; } + + // accept a match where \p Val is bound to a free variable. + bool acceptBoundNode(const Value * Val) const { return true; } + + // whether this context is compatiable with \p E. + bool acceptContext(EmptyContext E) const { return true; } + + // merge the context \p E into this context and return whether the resulting context is valid. + bool mergeContext(EmptyContext E) { return true; } + + // reset this context to \p Val. 
+ template bool reset_match(Val *V, const Pattern &P) { + reset(V); + return const_cast(P).match_context(V, *this); + } + + // match in the current context + template bool try_match(Val *V, const Pattern &P) { + return const_cast(P).match_context(V, *this); + } +}; + +template +struct MatcherCast { using ActualCastType = DestClass; }; + + + + + + +// match without (== empty) context template bool match(Val *V, const Pattern &P) { - return const_cast(P).match(V); + EmptyContext ECtx; + return const_cast(P).match_context(V, ECtx); +} + +// match pattern in a given context +template bool match(Val *V, const Pattern &P, MatchContext & MContext) { + return const_cast(P).match_context(V, MContext); } + + template struct OneUse_match { SubPattern_t SubPattern; OneUse_match(const SubPattern_t &SP) : SubPattern(SP) {} template bool match(OpTy *V) { - return V->hasOneUse() && SubPattern.match(V); + EmptyContext EContext; return match_context(V, EContext); + } + + template bool match_context(OpTy *V, MatchContext & MContext) { + return V->hasOneUse() && SubPattern.match_context(V, MContext); } }; @@ -63,7 +122,11 @@ } template struct class_match { - template bool match(ITy *V) { return isa(V); } + template bool match(ITy *V) { + EmptyContext EContext; return match_context(V, EContext); + } + template + bool match_context(ITy *V, MatchContext & MContext) { return match_isa(V); } }; /// Match an arbitrary value and ignore it. @@ -95,11 +158,17 @@ match_combine_or(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { + MatchContext SubContext; + + if (L.match_context(V, SubContext) && MContext.acceptContext(SubContext)) { + MContext.mergeContext(SubContext); return true; - if (R.match(V)) + } + if (R.match_context(V, MContext)) { return true; + } return false; } }; @@ -110,9 +179,10 @@ match_combine_and(const LTy &Left, const RTy &Right) : L(Left), R(Right) {} - template bool match(ITy *V) { - if (L.match(V)) - if (R.match(V)) + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { + if (L.match_context(V, MContext)) + if (R.match_context(V, MContext)) return true; return false; } @@ -135,7 +205,8 @@ apint_match(const APInt *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValue(); return true; @@ -155,7 +226,8 @@ struct apfloat_match { const APFloat *&Res; apfloat_match(const APFloat *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CI = dyn_cast(V)) { Res = &CI->getValueAPF(); return true; @@ -179,7 +251,8 @@ inline apfloat_match m_APFloat(const APFloat *&Res) { return Res; } template struct constantint_match { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) { const APInt &CIV = CI->getValue(); if (Val >= 0) @@ -202,7 +275,8 @@ /// satisfy a specified predicate. 
/// For vector constants, undefined elements are ignored. template struct cst_pred_ty : public Predicate { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) return this->isValue(CI->getValue()); if (V->getType()->isVectorTy()) { @@ -239,7 +313,8 @@ api_pred_ty(const APInt *&R) : Res(R) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CI = dyn_cast(V)) if (this->isValue(CI->getValue())) { Res = &CI->getValue(); @@ -261,7 +336,8 @@ /// constants that satisfy a specified predicate. /// For vector constants, undefined elements are ignored. template struct cstfp_pred_ty : public Predicate { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CF = dyn_cast(V)) return this->isValue(CF->getValueAPF()); if (V->getType()->isVectorTy()) { @@ -365,7 +441,8 @@ } struct is_zero { - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { auto *C = dyn_cast(V); return C && (C->isNullValue() || cst_pred_ty().match(C)); } @@ -461,8 +538,11 @@ bind_ty(Class *&V) : VR(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (auto *CV = dyn_cast(V)) { + if (!MContext.acceptBoundNode(V)) return false; + VR = CV; return true; } @@ -494,7 +574,8 @@ specificval_ty(const Value *V) : Val(V) {} - template bool match(ITy *V) { return V == Val; } + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { return V == Val; } }; /// Match if we have a specific specified value. @@ -507,7 +588,8 @@ deferredval_ty(Class *const &V) : Val(V) {} - template bool match(ITy *const V) { return V == Val; } + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *const V, MatchContext & MContext) { return V == Val; } }; /// A commutative-friendly version of m_Specific(). 
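// --- Illustrative sketch (not part of this patch): using the context-aware
// match() overload added above. With a PredicatedContext threaded through the
// pattern, an EVL intrinsic that computes an fsub can be matched by the same
// m_FSub pattern as a plain fsub instruction. The PredicatedContext
// constructor from a root value is assumed here, mirroring
// EmptyContext(const Value *).
static bool matchPredicatedFSub(Value *V, Value *&X, Value *&Y) {
  PredicatedContext PC(V); // assumed constructor
  return match(V, m_FSub(m_Value(X), m_Value(Y)), PC);
}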
@@ -523,7 +605,8 @@ specific_fpval(double V) : Val(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CFP = dyn_cast(V)) return CFP->isExactlyValue(Val); if (V->getType()->isVectorTy()) @@ -546,7 +629,8 @@ bind_const_intval_ty(uint64_t &V) : VR(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { if (const auto *CV = dyn_cast(V)) if (CV->getValue().ule(UINT64_MAX)) { VR = CV->getZExtValue(); @@ -563,7 +647,8 @@ specific_intval(uint64_t V) : Val(V) {} - template bool match(ITy *V) { + template bool match(ITy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(ITy *V, MatchContext & MContext) { const auto *CI = dyn_cast(V); if (!CI && V->getType()->isVectorTy()) if (const auto *C = dyn_cast(V)) @@ -593,11 +678,16 @@ // The LHS is always matched first. AnyBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto * I = match_dyn_cast(V); + if (!I) return false; + + if (!MContext.acceptInnerNode(I)) return false; + + MatchContext LRContext(MContext); + if (L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext))) return true; return false; } }; @@ -621,12 +711,15 @@ // The LHS is always matched first. BinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0))); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto * I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode) { + MatchContext LRContext(MContext); + if (!MContext.acceptInnerNode(I)) return false; + if (L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext))) return true; + return false; } if (auto *CE = dyn_cast(V)) return CE->getOpcode() == Opcode && @@ -665,20 +758,21 @@ Op_t X; FNeg_match(const Op_t &Op) : X(Op) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { auto *FPMO = dyn_cast(V); - if (!FPMO || FPMO->getOpcode() != Instruction::FSub) + if (!FPMO || match_cast(V)->getOpcode() != Instruction::FSub) return false; if (FPMO->hasNoSignedZeros()) { // With 'nsz', any zero goes. 
- if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!cstfp_pred_ty().match_context(FPMO->getOperand(0), MContext)) return false; } else { // Without 'nsz', we need fsub -0.0, X exactly. - if (!cstfp_pred_ty().match(FPMO->getOperand(0))) + if (!cstfp_pred_ty().match_context(FPMO->getOperand(0), MContext)) return false; } - return X.match(FPMO->getOperand(1)); + return X.match_context(FPMO->getOperand(1), MContext); } }; @@ -789,7 +883,8 @@ OverflowingBinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *Op = dyn_cast(V)) { if (Op->getOpcode() != Opcode) return false; @@ -799,7 +894,7 @@ if (WrapFlags & OverflowingBinaryOperator::NoSignedWrap && !Op->hasNoSignedWrap()) return false; - return L.match(Op->getOperand(0)) && R.match(Op->getOperand(1)); + return L.match_context(Op->getOperand(0), MContext) && R.match_context(Op->getOperand(1), MContext); } return false; } @@ -881,10 +976,11 @@ BinOpPred_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - return this->isOpType(I->getOpcode()) && L.match(I->getOperand(0)) && - R.match(I->getOperand(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *I = match_dyn_cast(V)) + return this->isOpType(I->getOpcode()) && L.match_context(I->getOperand(0), MContext) && + R.match_context(I->getOperand(1), MContext); if (auto *CE = dyn_cast(V)) return this->isOpType(CE->getOpcode()) && L.match(CE->getOperand(0)) && R.match(CE->getOperand(1)); @@ -963,9 +1059,10 @@ Exact_match(const SubPattern_t &SP) : SubPattern(SP) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { if (auto *PEO = dyn_cast(V)) - return PEO->isExact() && SubPattern.match(V); + return PEO->isExact() && SubPattern.match_context(V, MContext); return false; } }; @@ -990,14 +1087,17 @@ CmpClass_match(PredicateTy &Pred, const LHS_t &LHS, const RHS_t &RHS) : Predicate(Pred), L(LHS), R(RHS) {} - template bool match(OpTy *V) { - if (auto *I = dyn_cast(V)) - if ((L.match(I->getOperand(0)) && R.match(I->getOperand(1))) || - (Commutable && L.match(I->getOperand(1)) && - R.match(I->getOperand(0)))) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *I = match_dyn_cast(V)) { + if (!MContext.acceptInnerNode(I)) return false; + MatchContext LRContext(MContext); + if ((L.match_context(I->getOperand(0), LRContext) && R.match_context(I->getOperand(1), LRContext) && MContext.mergeContext(LRContext)) || + (Commutable && (L.match_context(I->getOperand(1), MContext) && R.match_context(I->getOperand(0), MContext)))) { Predicate = I->getPredicate(); return true; } + } return false; } }; @@ -1030,10 +1130,11 @@ OneOps_match(const T0 &Op1) : Op1(Op1) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & 
MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext); } return false; } @@ -1046,10 +1147,12 @@ TwoOps_match(const T0 &Op1, const T1 &Op2) : Op1(Op1), Op2(Op2) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext) && + Op2.match_context(I->getOperand(1), MContext); } return false; } @@ -1065,11 +1168,13 @@ ThreeOps_match(const T0 &Op1, const T1 &Op2, const T2 &Op3) : Op1(Op1), Op2(Op2), Op3(Op3) {} - template bool match(OpTy *V) { - if (V->getValueID() == Value::InstructionVal + Opcode) { - auto *I = cast(V); - return Op1.match(I->getOperand(0)) && Op2.match(I->getOperand(1)) && - Op3.match(I->getOperand(2)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + auto *I = match_dyn_cast(V); + if (I && I->getOpcode() == Opcode && MContext.acceptInnerNode(I)) { + return Op1.match_context(I->getOperand(0), MContext) && + Op2.match_context(I->getOperand(1), MContext) && + Op3.match_context(I->getOperand(2), MContext); } return false; } @@ -1137,9 +1242,10 @@ CastClass_match(const Op_t &OpMatch) : Op(OpMatch) {} - template bool match(OpTy *V) { - if (auto *O = dyn_cast(V)) - return O->getOpcode() == Opcode && Op.match(O->getOperand(0)); + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto O = match_dyn_cast(V)) + return O->getOpcode() == Opcode && MContext.acceptInnerNode(O) && Op.match_context(O->getOperand(0), MContext); return false; } }; @@ -1214,8 +1320,9 @@ br_match(BasicBlock *&Succ) : Succ(Succ) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *BI = match_dyn_cast(V)) if (BI->isUnconditional()) { Succ = BI->getSuccessor(0); return true; @@ -1233,8 +1340,9 @@ brc_match(const Cond_t &C, BasicBlock *&t, BasicBlock *&f) : Cond(C), T(t), F(f) {} - template bool match(OpTy *V) { - if (auto *BI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (auto *BI = match_dyn_cast(V)) if (BI->isConditional() && Cond.match(BI->getCondition())) { T = BI->getSuccessor(0); F = BI->getSuccessor(1); @@ -1263,13 +1371,14 @@ // The LHS is always matched first. MaxMin_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { // Look for "(x pred y) ? x : y" or "(x pred y) ? y : x". 
- auto *SI = dyn_cast(V); - if (!SI) + auto *SI = match_dyn_cast(V); + if (!SI || !MContext.acceptInnerNode(SI)) return false; - auto *Cmp = dyn_cast(SI->getCondition()); - if (!Cmp) + auto *Cmp = match_dyn_cast(SI->getCondition()); + if (!Cmp || !MContext.acceptInnerNode(Cmp)) return false; // At this point we have a select conditioned on a comparison. Check that // it is the values returned by the select that are being compared. @@ -1285,9 +1394,12 @@ // Does "(x pred y) ? x : y" represent the desired max/min operation? if (!Pred_t::match(Pred)) return false; + // It does! Bind the operands. - return (L.match(LHS) && R.match(RHS)) || - (Commutable && L.match(RHS) && R.match(LHS)); + MatchContext LRContext(MContext); + if (L.match_context(LHS, LRContext) && R.match_context(RHS, LRContext) && MContext.mergeContext(LRContext)) return true; + if (Commutable && (L.match_context(RHS, MContext) && R.match_context(LHS, MContext))) return true; + return false; } }; @@ -1444,7 +1556,8 @@ UAddWithOverflow_match(const LHS_t &L, const RHS_t &R, const Sum_t &S) : L(L), R(R), S(S) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { Value *ICmpLHS, *ICmpRHS; ICmpInst::Predicate Pred; if (!m_ICmp(Pred, m_Value(ICmpLHS), m_Value(ICmpRHS)).match(V)) @@ -1483,9 +1596,10 @@ Argument_match(unsigned OpIdx, const Opnd_t &V) : OpI(OpIdx), Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { // FIXME: Should likely be switched to use `CallBase`. - if (const auto *CI = dyn_cast(V)) + if (const auto *CI = match_dyn_cast(V)) return Val.match(CI->getArgOperand(OpI)); return false; } @@ -1503,8 +1617,9 @@ IntrinsicID_match(Intrinsic::ID IntrID) : ID(IntrID) {} - template bool match(OpTy *V) { - if (const auto *CI = dyn_cast(V)) + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { + if (const auto *CI = match_dyn_cast(V)) if (const auto *F = CI->getCalledFunction()) return F->getIntrinsicID() == ID; return false; @@ -1714,7 +1829,8 @@ Opnd_t Val; Signum_match(const Opnd_t &V) : Val(V) {} - template bool match(OpTy *V) { + template bool match(OpTy *V) { EmptyContext EContext; return match_context(V, EContext); } + template bool match_context(OpTy *V, MatchContext & MContext) { unsigned TypeSize = V->getType()->getScalarSizeInBits(); if (TypeSize == 0) return false; Index: include/llvm/IR/PredicatedInst.h =================================================================== --- /dev/null +++ include/llvm/IR/PredicatedInst.h @@ -0,0 +1,369 @@ +//===-- llvm/PredicatedInst.h - Predication utility subclass --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines various classes for working with Predicated Instructions. +// Predicated instructions are either regular instructions or calls to +// Explicit Vector Length (EVL) intrinsics that have a mask and an explicit +// vector length argument. 
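Together with the context-aware matchers above, the PredicatedContext defined later in this header lets one pattern be applied uniformly to plain vector IR and to calls of EVL intrinsics. A sketch of the intended use (for illustration only, not part of the patch; PredicatedContext and reset_match are the helpers introduced below):

#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/PredicatedInst.h"

using namespace llvm;
using namespace llvm::PatternMatch;

// Return X if V computes (fsub -0.0, X), regardless of whether V is a plain
// fsub instruction or a call to the corresponding EVL intrinsic.  The
// PredicatedContext records V's mask and explicit vector length (both null
// for unpredicated IR) and rejects matched sub-expressions whose predicate
// differs from the root's.
static Value *matchNegatedOperand(Value *V) {
  Value *X = nullptr;
  PredicatedContext PC(V);
  if (PC.reset_match(V, m_FSub(m_NegZeroFP(), m_Value(X))))
    return X;
  return nullptr;
}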
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_IR_PREDICATEDINST_H +#define LLVM_IR_PREDICATEDINST_H + +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/MatcherCast.h" + +#include + +namespace llvm { + +class BasicBlock; + +class PredicatedInstruction : public User { +public: + // The PredicatedInstruction class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedInstruction() = delete; + ~PredicatedInstruction() = delete; + + void copyIRFlags(const Value * V, bool IncludeWrapFlags) { + cast(this)->copyIRFlags(V, IncludeWrapFlags); + } + + BasicBlock* getParent() { return cast(this)->getParent(); } + const BasicBlock* getParent() const { return cast(this)->getParent(); } + + void *operator new(size_t s) = delete; + + Value* getMask() const { + auto thisEVL = dyn_cast(this); + if (!thisEVL) return nullptr; + return thisEVL->getMask(); + } + + Value* getVectorLength() const { + auto thisEVL = dyn_cast(this); + if (!thisEVL) return nullptr; + return thisEVL->getVectorLength(); + } + + unsigned getOpcode() const { + auto * EVLInst = dyn_cast(this); + if (EVLInst) + return EVLInst->getFunctionalOpcode(); + return cast(this)->getOpcode(); + } + +#if 0 + operator Instruction() { return cast(this); } + operator const Value() const { return cast(this); } +#endif + + static bool classof(const Instruction * I) { return isa(I); } + static bool classof(const ConstantExpr * CE) { return false; } + static bool classof(const Value *V) { return isa(V); } +}; + +class PredicatedOperator : public User { +public: + // The PredicatedOperator class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedOperator() = delete; + ~PredicatedOperator() = delete; + + void *operator new(size_t s) = delete; + +#if 0 + operator Value*() { return cast(this); } + operator const Value*() const { return cast(this); } +#endif + + /// Return the opcode for this Instruction or ConstantExpr. + unsigned getOpcode() const { + auto * EVLInst = dyn_cast(this); + if (EVLInst) + return EVLInst->getFunctionalOpcode(); + if (const Instruction *I = dyn_cast(this)) + return I->getOpcode(); + return cast(this)->getOpcode(); + } + + Value* getMask() const { + auto thisEVL = dyn_cast(this); + if (!thisEVL) return nullptr; + return thisEVL->getMask(); + } + + Value* getVectorLength() const { + auto thisEVL = dyn_cast(this); + if (!thisEVL) return nullptr; + return thisEVL->getVectorLength(); + } + + void copyIRFlags(const Value * V, bool IncludeWrapFlags = true); + FastMathFlags getFastMathFlags() const { + auto * I = dyn_cast(this); + if (I) return I->getFastMathFlags(); + else return FastMathFlags(); + } + + static bool classof(const Instruction * I) { return isa(I) || isa(I); } + static bool classof(const ConstantExpr * CE) { return isa(CE); } + static bool classof(const Value *V) { return isa(V) || isa(V); } +}; + +class PredicatedBinaryOperator : public PredicatedOperator { +public: + // The PredicatedBinaryOperator class is intended to be used as a utility, and is never itself + // instantiated. 
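The PredicatedOperator utilities above give transforms a single opcode-centric view of both forms. A usage sketch (illustrative only):

#include "llvm/IR/PredicatedInst.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

// A pass can treat a plain vector fadd and a call to the matching EVL
// intrinsic uniformly: getOpcode() reports the functional opcode
// (Instruction::FAdd) for both, while getMask()/getVectorLength() return the
// predicate operands, or null for the unpredicated form.
static void describePredicatedOp(Value *V) {
  auto *PO = dyn_cast<PredicatedOperator>(V);
  if (!PO)
    return;
  errs() << "functional opcode: " << PO->getOpcode() << "\n";
  if (Value *Mask = PO->getMask())
    errs() << "  mask: " << *Mask << "\n";
  if (Value *VLen = PO->getVectorLength())
    errs() << "  vector length: " << *VLen << "\n";
}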
+ PredicatedBinaryOperator() = delete; + ~PredicatedBinaryOperator() = delete; + + using BinaryOps = Instruction::BinaryOps; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction * I) { + if (isa(I)) return true; + auto EVLInst = dyn_cast(I); + return EVLInst && EVLInst->isBinaryOp(); + } + static bool classof(const ConstantExpr * CE) { return isa(CE); } + static bool classof(const Value *V) { + auto * I = dyn_cast(V); + if (I && classof(I)) return true; + auto * CE = dyn_cast(V); + return CE && classof(CE); + } + + /// Construct a predicated binary instruction, given the opcode and the two + /// operands. + static Instruction* Create(Module * Mod, + Value *Mask, Value *VectorLen, + Instruction::BinaryOps Opc, + Value *V1, Value *V2, + const Twine &Name, + BasicBlock * InsertAtEnd, + Instruction * InsertBefore); + + static Instruction* Create(Module *Mod, + Value *Mask, Value *VectorLen, + BinaryOps Opc, + Value *V1, Value *V2, + const Twine &Name = Twine(), + Instruction *InsertBefore = nullptr) { + return Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, nullptr, InsertBefore); + } + + static Instruction* Create(Module *Mod, + Value *Mask, Value *VectorLen, + BinaryOps Opc, + Value *V1, Value *V2, + const Twine &Name, + BasicBlock *InsertAtEnd) { + return Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, InsertAtEnd, nullptr); + } + + static Instruction* CreateWithCopiedFlags(Module *Mod, + Value *Mask, Value* VectorLen, + BinaryOps Opc, + Value *V1, Value *V2, + Instruction *CopyBO, + const Twine &Name = "") { + Instruction *BO = Create(Mod, Mask, VectorLen, Opc, V1, V2, Name, nullptr, nullptr); + BO->copyIRFlags(CopyBO); + return BO; + } +}; + +class PredicatedICmpInst : public PredicatedBinaryOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedICmpInst() = delete; + ~PredicatedICmpInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction * I) { + if (isa(I)) return true; + auto EVLInst = dyn_cast(I); + return EVLInst && EVLInst->getFunctionalOpcode() == Instruction::ICmp; + } + static bool classof(const ConstantExpr * CE) { return CE->getOpcode() == Instruction::ICmp; } + static bool classof(const Value *V) { + auto * I = dyn_cast(V); + if (I && classof(I)) return true; + auto * CE = dyn_cast(V); + return CE && classof(CE); + } + + ICmpInst::Predicate getPredicate() const { + auto * ICInst = dyn_cast(this); + if (ICInst) return ICInst->getPredicate(); + auto * CE = dyn_cast(this); + if (CE) return static_cast(CE->getPredicate()); + return static_cast(cast(this)->getCmpPredicate()); + } +}; + + +class PredicatedFCmpInst : public PredicatedBinaryOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. 
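A sketch of how the Create helpers declared above are meant to be used (their out-of-line definitions, which decide how the predicated operation is materialized, are not part of this header; note that CreateWithCopiedFlags takes no insertion point, so the caller must insert the returned instruction):

#include "llvm/IR/Module.h"
#include "llvm/IR/PredicatedInst.h"

using namespace llvm;

// Build A * B predicated on Mask and VectorLen, copying the IR flags of an
// existing multiply.  The result still has to be inserted into a block.
static Instruction *emitPredicatedMul(Module *M, Value *A, Value *B,
                                      Value *Mask, Value *VectorLen,
                                      Instruction *FlagSource) {
  return PredicatedBinaryOperator::CreateWithCopiedFlags(
      M, Mask, VectorLen, Instruction::Mul, A, B, FlagSource, "pred.mul");
}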
+ PredicatedFCmpInst() = delete; + ~PredicatedFCmpInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction * I) { + if (isa(I)) return true; + auto EVLInst = dyn_cast(I); + return EVLInst && EVLInst->getFunctionalOpcode() == Instruction::FCmp; + } + static bool classof(const ConstantExpr * CE) { return CE->getOpcode() == Instruction::FCmp; } + static bool classof(const Value *V) { + auto * I = dyn_cast(V); + if (I && classof(I)) return true; + return isa(V); + } + + FCmpInst::Predicate getPredicate() const { + auto * FCInst = dyn_cast(this); + if (FCInst) return FCInst->getPredicate(); + auto * CE = dyn_cast(this); + if (CE) return static_cast(CE->getPredicate()); + return static_cast(cast(this)->getCmpPredicate()); + } +}; + + +class PredicatedSelectInst : public PredicatedOperator { +public: + // The Operator class is intended to be used as a utility, and is never itself + // instantiated. + PredicatedSelectInst() = delete; + ~PredicatedSelectInst() = delete; + + void *operator new(size_t s) = delete; + + static bool classof(const Instruction * I) { + if (isa(I)) return true; + auto EVLInst = dyn_cast(I); + return EVLInst && EVLInst->getFunctionalOpcode() == Instruction::Select; + } + static bool classof(const ConstantExpr * CE) { return CE->getOpcode() == Instruction::Select; } + static bool classof(const Value *V) { + auto * I = dyn_cast(V); + if (I && classof(I)) return true; + auto * CE = dyn_cast(V); + return CE && CE->getOpcode() == Instruction::Select; + } + + const Value *getCondition() const { return getOperand(0); } + const Value *getTrueValue() const { return getOperand(1); } + const Value *getFalseValue() const { return getOperand(2); } + Value *getCondition() { return getOperand(0); } + Value *getTrueValue() { return getOperand(1); } + Value *getFalseValue() { return getOperand(2); } + + void setCondition(Value *V) { setOperand(0, V); } + void setTrueValue(Value *V) { setOperand(1, V); } + void setFalseValue(Value *V) { setOperand(2, V); } +}; + + +namespace PatternMatch { + +// PredicatedMatchContext for pattern matching +struct PredicatedContext { + Value * Mask; + Value * VectorLength; + Module * Mod; + + void reset(Value * V) { + auto * PI = dyn_cast(V); + if (!PI) { + VectorLength = nullptr; + Mask = nullptr; + Mod = nullptr; + } else { + VectorLength = PI->getVectorLength(); + Mask = PI->getMask(); + Mod = PI->getParent()->getParent()->getParent(); + } + } + + PredicatedContext(Value * Val) + : Mask(nullptr) + , VectorLength(nullptr) + , Mod(nullptr) + { + reset(Val); + } + + PredicatedContext(const PredicatedContext & PC) + : Mask(PC.Mask) + , VectorLength(PC.VectorLength) + , Mod(PC.Mod) + {} + + // accept a match where \p Val is in a non-leaf position in a match pattern + bool acceptInnerNode(const Value * Val) const { + auto PredI = dyn_cast(Val); + if (!PredI) return VectorLength == nullptr && Mask == nullptr; + return VectorLength == PredI->getVectorLength() && Mask == PredI->getMask(); + } + + // accept a match where \p Val is bound to a free variable. + bool acceptBoundNode(const Value * Val) const { return true; } + + // whether this context is compatiable with \p E. + bool acceptContext(PredicatedContext PC) const { + return PC.Mask == Mask && PC.VectorLength == VectorLength; + } + + // merge the context \p E into this context and return whether the resulting context is valid. + bool mergeContext(PredicatedContext PC) const { return acceptContext(PC); } + + // match \p P in a new contest for \p Val. 
+ template bool reset_match(Val *V, const Pattern &P) { + reset(V); + return const_cast(P).match_context(V, *this); + } + + // match \p P in the current context. + template bool try_match(Val *V, const Pattern &P) { + PredicatedContext SubContext(*this); + return const_cast(P).match_context(V, SubContext); + } +}; + +struct PredicatedContext; +template<> struct MatcherCast { using ActualCastType = PredicatedBinaryOperator; }; +template<> struct MatcherCast { using ActualCastType = PredicatedOperator; }; +template<> struct MatcherCast { using ActualCastType = PredicatedICmpInst; }; +template<> struct MatcherCast { using ActualCastType = PredicatedFCmpInst; }; +template<> struct MatcherCast { using ActualCastType = PredicatedSelectInst; }; +template<> struct MatcherCast { using ActualCastType = PredicatedInstruction; }; + +} // namespace PatternMatch + +} // namespace llvm + +#endif // LLVM_IR_PREDICATEDINST_H Index: include/llvm/Target/TargetSelectionDAG.td =================================================================== --- include/llvm/Target/TargetSelectionDAG.td +++ include/llvm/Target/TargetSelectionDAG.td @@ -128,6 +128,13 @@ SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisInt<3> ]>; +def SDTIntBinOpEVL : SDTypeProfile<1, 4, [ // evl_add, evl_and, etc. + SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisInt<0>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; +def SDTIntShiftOpEVL : SDTypeProfile<1, 4, [ // shl, sra, srl + SDTCisSameAs<0, 1>, SDTCisInt<0>, SDTCisInt<2>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; + def SDTFPBinOp : SDTypeProfile<1, 2, [ // fadd, fmul, etc. SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0> ]>; @@ -170,6 +177,16 @@ SDTCisOpSmallerThanOp<1, 0> ]>; +def SDTFPUnOpEVL : SDTypeProfile<1, 3, [ // evl_fneg, etc. + SDTCisSameAs<0, 1>, SDTCisFP<0>, SDTCisInt<3>, SDTCisSameNumEltsAs<0, 2> +]>; +def SDTFPBinOpEVL : SDTypeProfile<1, 4, [ // evl_fadd, etc. + SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisFP<0>, SDTCisInt<4>, SDTCisSameNumEltsAs<0, 3> +]>; +def SDTFPTernaryOpEVL : SDTypeProfile<1, 5, [ // evl_fmadd, etc. 
+ SDTCisSameAs<0, 1>, SDTCisSameAs<0, 2>, SDTCisSameAs<0, 3>, SDTCisFP<0>, SDTCisInt<5>, SDTCisSameNumEltsAs<0, 4> +]>; + def SDTSetCC : SDTypeProfile<1, 3, [ // setcc SDTCisInt<0>, SDTCisSameAs<1, 2>, SDTCisVT<3, OtherVT> ]>; @@ -182,6 +199,10 @@ SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1> ]>; +def SDTVSelectEVL : SDTypeProfile<1, 5, [ // evl_vselect + SDTCisVec<0>, SDTCisInt<1>, SDTCisSameAs<0, 2>, SDTCisSameAs<2, 3>, SDTCisSameNumEltsAs<0, 1>, SDTCisInt<5>, SDTCisSameNumEltsAs<0, 4> +]>; + def SDTSelectCC : SDTypeProfile<1, 5, [ // select_cc SDTCisSameAs<1, 2>, SDTCisSameAs<3, 4>, SDTCisSameAs<0, 3>, SDTCisVT<5, OtherVT> @@ -225,11 +246,20 @@ SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameNumEltsAs<0, 2> ]>; +def SDTStoreEVL: SDTypeProfile<0, 4, [ // evl store + SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameNumEltsAs<0, 2>, SDTCisInt<3> +]>; + def SDTMaskedLoad: SDTypeProfile<1, 3, [ // masked load SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisVec<2>, SDTCisSameAs<0, 3>, SDTCisSameNumEltsAs<0, 2> ]>; +def SDTLoadEVL : SDTypeProfile<1, 3, [ // evl load + SDTCisVec<0>, SDTCisPtrTy<1>, SDTCisSameNumEltsAs<0, 2>, SDTCisInt<3>, + SDTCisSameNumEltsAs<0, 2> +]>; + def SDTVecShuffle : SDTypeProfile<1, 2, [ SDTCisSameAs<0, 1>, SDTCisSameAs<1, 2> ]>; @@ -385,6 +415,26 @@ def umax : SDNode<"ISD::UMAX" , SDTIntBinOp, [SDNPCommutative, SDNPAssociative]>; +def evl_and : SDNode<"ISD::EVL_AND" , SDTIntBinOpEVL , + [SDNPCommutative, SDNPAssociative]>; +def evl_or : SDNode<"ISD::EVL_OR" , SDTIntBinOpEVL , + [SDNPCommutative, SDNPAssociative]>; +def evl_xor : SDNode<"ISD::EVL_XOR" , SDTIntBinOpEVL , + [SDNPCommutative, SDNPAssociative]>; +def evl_srl : SDNode<"ISD::EVL_SRL" , SDTIntShiftOpEVL>; +def evl_sra : SDNode<"ISD::EVL_SRA" , SDTIntShiftOpEVL>; +def evl_shl : SDNode<"ISD::EVL_SHL" , SDTIntShiftOpEVL>; + +def evl_add : SDNode<"ISD::EVL_ADD" , SDTIntBinOpEVL , + [SDNPCommutative, SDNPAssociative]>; +def evl_sub : SDNode<"ISD::EVL_SUB" , SDTIntBinOpEVL>; +def evl_mul : SDNode<"ISD::EVL_MUL" , SDTIntBinOpEVL, + [SDNPCommutative, SDNPAssociative]>; +def evl_sdiv : SDNode<"ISD::EVL_SDIV" , SDTIntBinOpEVL>; +def evl_udiv : SDNode<"ISD::EVL_UDIV" , SDTIntBinOpEVL>; +def evl_srem : SDNode<"ISD::EVL_SREM" , SDTIntBinOpEVL>; +def evl_urem : SDNode<"ISD::EVL_UREM" , SDTIntBinOpEVL>; + def saddsat : SDNode<"ISD::SADDSAT" , SDTIntBinOp, [SDNPCommutative]>; def uaddsat : SDNode<"ISD::UADDSAT" , SDTIntBinOp, [SDNPCommutative]>; def ssubsat : SDNode<"ISD::SSUBSAT" , SDTIntBinOp>; @@ -452,6 +502,14 @@ def fpextend : SDNode<"ISD::FP_EXTEND" , SDTFPExtendOp>; def fcopysign : SDNode<"ISD::FCOPYSIGN" , SDTFPSignOp>; +def evl_fneg : SDNode<"ISD::EVL_FNEG" , SDTFPUnOpEVL>; +def evl_fadd : SDNode<"ISD::EVL_FADD" , SDTFPBinOpEVL, [SDNPCommutative]>; +def evl_fsub : SDNode<"ISD::EVL_FSUB" , SDTFPBinOpEVL>; +def evl_fmul : SDNode<"ISD::EVL_FMUL" , SDTFPBinOpEVL, [SDNPCommutative]>; +def evl_fdiv : SDNode<"ISD::EVL_FDIV" , SDTFPBinOpEVL>; +def evl_frem : SDNode<"ISD::EVL_FREM" , SDTFPBinOpEVL>; +def evl_fma : SDNode<"ISD::EVL_FMA" , SDTFPTernaryOpEVL>; + def sint_to_fp : SDNode<"ISD::SINT_TO_FP" , SDTIntToFPOp>; def uint_to_fp : SDNode<"ISD::UINT_TO_FP" , SDTIntToFPOp>; def fp_to_sint : SDNode<"ISD::FP_TO_SINT" , SDTFPToIntOp>; @@ -459,10 +517,10 @@ def f16_to_fp : SDNode<"ISD::FP16_TO_FP" , SDTIntToFPOp>; def fp_to_f16 : SDNode<"ISD::FP_TO_FP16" , SDTFPToIntOp>; -def setcc : SDNode<"ISD::SETCC" , SDTSetCC>; -def select : SDNode<"ISD::SELECT" , SDTSelect>; -def vselect : 
SDNode<"ISD::VSELECT" , SDTVSelect>; -def selectcc : SDNode<"ISD::SELECT_CC" , SDTSelectCC>; +def setcc : SDNode<"ISD::SETCC" , SDTSetCC>; +def select : SDNode<"ISD::SELECT" , SDTSelect>; +def vselect : SDNode<"ISD::VSELECT" , SDTVSelect>; +def selectcc : SDNode<"ISD::SELECT_CC" , SDTSelectCC>; def brcc : SDNode<"ISD::BR_CC" , SDTBrCC, [SDNPHasChain]>; def brcond : SDNode<"ISD::BRCOND" , SDTBrcond, [SDNPHasChain]>; @@ -530,6 +588,11 @@ def masked_load : SDNode<"ISD::MLOAD", SDTMaskedLoad, [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; +def evl_store : SDNode<"ISD::EVL_STORE", SDTMaskedStore, + [SDNPHasChain, SDNPMayStore, SDNPMemOperand]>; +def evl_load : SDNode<"ISD::EVL_LOAD", SDTMaskedLoad, + [SDNPHasChain, SDNPMayLoad, SDNPMemOperand]>; + // Do not use ld, st directly. Use load, extload, sextload, zextload, store, // and truncst (see below). def ld : SDNode<"ISD::LOAD" , SDTLoad, Index: lib/Analysis/InstructionSimplify.cpp =================================================================== --- lib/Analysis/InstructionSimplify.cpp +++ lib/Analysis/InstructionSimplify.cpp @@ -35,6 +35,7 @@ #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Support/KnownBits.h" #include @@ -4559,8 +4560,10 @@ /// Given operands for an FSub, see if we can fold the result. If not, this /// returns null. -static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const SimplifyQuery &Q, unsigned MaxRecurse) { +template +static Value *SimplifyFSubInstGeneric(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse, MatchContext & MC) { + if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) return C; @@ -4568,23 +4571,23 @@ return C; // fsub X, +0 ==> X - if (match(Op1, m_PosZeroFP())) + if (MC.try_match(Op1, m_PosZeroFP())) return Op0; // fsub X, -0 ==> X, when we know X is not -0 - if (match(Op1, m_NegZeroFP()) && + if (MC.try_match(Op1, m_NegZeroFP()) && (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) return Op0; // fsub -0.0, (fsub -0.0, X) ==> X Value *X; - if (match(Op0, m_NegZeroFP()) && - match(Op1, m_FSub(m_NegZeroFP(), m_Value(X)))) + if (MC.try_match(Op0, m_NegZeroFP()) && + MC.try_match(Op1, m_FSub(m_NegZeroFP(), m_Value(X)))) return X; // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored. 
if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) && - match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X)))) + MC.try_match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X)))) return X; // fsub nnan x, x ==> 0.0 @@ -4594,13 +4597,20 @@ // Y - (Y - X) --> X // (X + Y) - Y --> X if (FMF.noSignedZeros() && FMF.allowReassoc() && - (match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || - match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) + (MC.try_match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || + MC.try_match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) return X; return nullptr; } +static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, unsigned MaxRecurse) { + EmptyContext EC; + return SimplifyFSubInstGeneric(Op0, Op1, FMF, Q, MaxRecurse, EC); +} + + /// Given the operands for an FMul, see if we can fold the result static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, unsigned MaxRecurse) { @@ -4641,6 +4651,11 @@ return ::SimplifyFSubInst(Op0, Op1, FMF, Q, RecursionLimit); } +Value *llvm::SimplifyPredicatedFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, + const SimplifyQuery &Q, PredicatedContext & PC) { + return ::SimplifyFSubInstGeneric(Op0, Op1, FMF, Q, RecursionLimit, PC); +} + Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q) { return ::SimplifyFMulInst(Op0, Op1, FMF, Q, RecursionLimit); @@ -5190,9 +5205,20 @@ Q, RecursionLimit); } +Value *llvm::SimplifyEVLIntrinsic(EVLIntrinsic & EVLInst, const SimplifyQuery &Q) { + PredicatedContext PC(&EVLInst); + + auto & PI = cast(EVLInst); + switch (PI.getOpcode()) { + default: + return nullptr; + + case Instruction::FSub: return SimplifyPredicatedFSubInst(EVLInst.getOperand(0), EVLInst.getOperand(1), EVLInst.getFastMathFlags(), Q, PC); + } +} + /// See if we can compute a simplified version of this instruction. /// If not, this returns null. - Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, OptimizationRemarkEmitter *ORE) { const SimplifyQuery Q = SQ.CxtI ? 
SQ : SQ.getWithInstruction(I); @@ -5326,6 +5352,12 @@ Result = SimplifyPHINode(cast(I), Q); break; case Instruction::Call: { + auto * EVL = dyn_cast(I); + if (EVL) { + Result = SimplifyEVLIntrinsic(*EVL, Q); + if (Result) break; + } + CallSite CS(cast(I)); Result = SimplifyCall(CS, Q); break; Index: lib/AsmParser/LLLexer.cpp =================================================================== --- lib/AsmParser/LLLexer.cpp +++ lib/AsmParser/LLLexer.cpp @@ -642,6 +642,7 @@ KEYWORD(inlinehint); KEYWORD(inreg); KEYWORD(jumptable); + KEYWORD(mask); KEYWORD(minsize); KEYWORD(naked); KEYWORD(nest); @@ -661,6 +662,7 @@ KEYWORD(optforfuzzing); KEYWORD(optnone); KEYWORD(optsize); + KEYWORD(passthru); KEYWORD(readnone); KEYWORD(readonly); KEYWORD(returned); @@ -682,6 +684,7 @@ KEYWORD(swifterror); KEYWORD(swiftself); KEYWORD(uwtable); + KEYWORD(vlen); KEYWORD(writeonly); KEYWORD(zeroext); Index: lib/AsmParser/LLParser.cpp =================================================================== --- lib/AsmParser/LLParser.cpp +++ lib/AsmParser/LLParser.cpp @@ -1294,14 +1294,17 @@ case lltok::kw_dereferenceable: case lltok::kw_dereferenceable_or_null: case lltok::kw_inalloca: + case lltok::kw_mask: case lltok::kw_nest: case lltok::kw_noalias: case lltok::kw_nocapture: case lltok::kw_nonnull: + case lltok::kw_passthru: case lltok::kw_returned: case lltok::kw_sret: case lltok::kw_swifterror: case lltok::kw_swiftself: + case lltok::kw_vlen: HaveError |= Error(Lex.getLoc(), "invalid use of parameter-only attribute on a function"); @@ -1582,10 +1585,12 @@ } case lltok::kw_inalloca: B.addAttribute(Attribute::InAlloca); break; case lltok::kw_inreg: B.addAttribute(Attribute::InReg); break; + case lltok::kw_mask: B.addAttribute(Attribute::Mask); break; case lltok::kw_nest: B.addAttribute(Attribute::Nest); break; case lltok::kw_noalias: B.addAttribute(Attribute::NoAlias); break; case lltok::kw_nocapture: B.addAttribute(Attribute::NoCapture); break; case lltok::kw_nonnull: B.addAttribute(Attribute::NonNull); break; + case lltok::kw_passthru: B.addAttribute(Attribute::Passthru); break; case lltok::kw_readnone: B.addAttribute(Attribute::ReadNone); break; case lltok::kw_readonly: B.addAttribute(Attribute::ReadOnly); break; case lltok::kw_returned: B.addAttribute(Attribute::Returned); break; @@ -1593,6 +1598,7 @@ case lltok::kw_sret: B.addAttribute(Attribute::StructRet); break; case lltok::kw_swifterror: B.addAttribute(Attribute::SwiftError); break; case lltok::kw_swiftself: B.addAttribute(Attribute::SwiftSelf); break; + case lltok::kw_vlen: B.addAttribute(Attribute::VectorLength); break; case lltok::kw_writeonly: B.addAttribute(Attribute::WriteOnly); break; case lltok::kw_zeroext: B.addAttribute(Attribute::ZExt); break; @@ -1683,12 +1689,15 @@ // Error handling. case lltok::kw_byval: case lltok::kw_inalloca: + case lltok::kw_mask: case lltok::kw_nest: case lltok::kw_nocapture: + case lltok::kw_passthru: case lltok::kw_returned: case lltok::kw_sret: case lltok::kw_swifterror: case lltok::kw_swiftself: + case lltok::kw_vlen: HaveError |= Error(Lex.getLoc(), "invalid use of parameter-only attribute"); break; @@ -3294,7 +3303,7 @@ ID.Kind = ValID::t_Constant; return false; } - + // Unary Operators. case lltok::kw_fneg: { unsigned Opc = Lex.getUIntVal(); @@ -3304,7 +3313,7 @@ ParseGlobalTypeAndValue(Val) || ParseToken(lltok::rparen, "expected ')' in unary constantexpr")) return true; - + // Check that the type is valid for the operator. 
switch (Opc) { case Instruction::FNeg: @@ -6169,11 +6178,11 @@ Valid = LHS->getType()->isIntOrIntVectorTy() || LHS->getType()->isFPOrFPVectorTy(); break; - case 1: - Valid = LHS->getType()->isIntOrIntVectorTy(); + case 1: + Valid = LHS->getType()->isIntOrIntVectorTy(); break; - case 2: - Valid = LHS->getType()->isFPOrFPVectorTy(); + case 2: + Valid = LHS->getType()->isFPOrFPVectorTy(); break; } Index: lib/AsmParser/LLToken.h =================================================================== --- lib/AsmParser/LLToken.h +++ lib/AsmParser/LLToken.h @@ -186,6 +186,7 @@ kw_inlinehint, kw_inreg, kw_jumptable, + kw_mask, kw_minsize, kw_naked, kw_nest, @@ -205,6 +206,7 @@ kw_optforfuzzing, kw_optnone, kw_optsize, + kw_passthru, kw_readnone, kw_readonly, kw_returned, @@ -224,6 +226,7 @@ kw_swifterror, kw_swiftself, kw_uwtable, + kw_vlen, kw_writeonly, kw_zeroext, Index: lib/Bitcode/Reader/BitcodeReader.cpp =================================================================== --- lib/Bitcode/Reader/BitcodeReader.cpp +++ lib/Bitcode/Reader/BitcodeReader.cpp @@ -1332,6 +1332,8 @@ return Attribute::InReg; case bitc::ATTR_KIND_JUMP_TABLE: return Attribute::JumpTable; + case bitc::ATTR_KIND_MASK: + return Attribute::Mask; case bitc::ATTR_KIND_MIN_SIZE: return Attribute::MinSize; case bitc::ATTR_KIND_NAKED: @@ -1376,6 +1378,8 @@ return Attribute::OptimizeForSize; case bitc::ATTR_KIND_OPTIMIZE_NONE: return Attribute::OptimizeNone; + case bitc::ATTR_KIND_PASSTHRU: + return Attribute::Passthru; case bitc::ATTR_KIND_READ_NONE: return Attribute::ReadNone; case bitc::ATTR_KIND_READ_ONLY: @@ -1420,6 +1424,8 @@ return Attribute::SwiftSelf; case bitc::ATTR_KIND_UW_TABLE: return Attribute::UWTable; + case bitc::ATTR_KIND_VECTORLENGTH: + return Attribute::VectorLength; case bitc::ATTR_KIND_WRITEONLY: return Attribute::WriteOnly; case bitc::ATTR_KIND_Z_EXT: Index: lib/Bitcode/Writer/BitcodeWriter.cpp =================================================================== --- lib/Bitcode/Writer/BitcodeWriter.cpp +++ lib/Bitcode/Writer/BitcodeWriter.cpp @@ -670,6 +670,12 @@ return bitc::ATTR_KIND_READ_ONLY; case Attribute::Returned: return bitc::ATTR_KIND_RETURNED; + case Attribute::Mask: + return bitc::ATTR_KIND_MASK; + case Attribute::VectorLength: + return bitc::ATTR_KIND_VECTORLENGTH; + case Attribute::Passthru: + return bitc::ATTR_KIND_PASSTHRU; case Attribute::ReturnsTwice: return bitc::ATTR_KIND_RETURNS_TWICE; case Attribute::SExt: Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -353,6 +353,7 @@ SDValue visitBITCAST(SDNode *N); SDValue visitBUILD_PAIR(SDNode *N); SDValue visitFADD(SDNode *N); + SDValue visitFADD_EVL(SDNode *N); SDValue visitFSUB(SDNode *N); SDValue visitFMUL(SDNode *N); SDValue visitFMA(SDNode *N); @@ -400,6 +401,7 @@ SDValue visitFP_TO_FP16(SDNode *N); SDValue visitFP16_TO_FP(SDNode *N); + template SDValue visitFADDForFMACombine(SDNode *N); SDValue visitFSUBForFMACombine(SDNode *N); SDValue visitFMULForFMADistributiveCombine(SDNode *N); @@ -641,6 +643,138 @@ } }; +// TODO port this to EVL nodes +struct EmptyMatchContext { + SelectionDAG & DAG; + + EmptyMatchContext(SelectionDAG & DAG, SDNode * Root) + : DAG(DAG) + {} + + bool match(SDValue OpN, unsigned OpCode) const { return OpCode == OpN->getOpcode(); } + + unsigned getFunctionOpCode(SDValue N) const { + return N->getOpcode(); + } + + bool isCompatible(SDValue OpVal) const { return true; } 
+ + // Specialize based on number of operands. + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return DAG.getNode(Opcode, DL, VT); } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, + const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, Operand, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, N1, N2, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, + const SDNodeFlags Flags = SDNodeFlags()) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDValue N4) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, N4); + } + + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, SDValue N4, SDValue N5) { + return DAG.getNode(Opcode, DL, VT, N1, N2, N3, N4, N5); + } +}; + +struct +EVLMatchContext { + SelectionDAG & DAG; + SDNode * Root; + SDValue RootMaskOp; + SDValue RootVectorLenOp; + + EVLMatchContext(SelectionDAG & DAG, SDNode * Root) + : DAG(DAG) + , Root(Root) + , RootMaskOp() + , RootVectorLenOp() + { + if (Root->isEVL()) { + int RootMaskPos = ISD::GetMaskPosEVL(Root->getOpcode()); + if (RootMaskPos != -1) { + RootMaskOp = Root->getOperand(RootMaskPos); + } + + int RootVLenPos = ISD::GetVectorLengthPosEVL(Root->getOpcode()); + if (RootVLenPos != -1) { + RootVectorLenOp = Root->getOperand(RootVLenPos); + } + } + } + + unsigned getFunctionOpCode(SDValue N) const { + unsigned EVLOpCode = N->getOpcode(); + return ISD::GetFunctionOpCodeForEVL(EVLOpCode); + } + + bool isCompatible(SDValue OpVal) const { + if (!OpVal->isEVL()) { + return !Root->isEVL(); + + } else { + unsigned EVLOpCode = OpVal->getOpcode(); + int MaskPos = ISD::GetMaskPosEVL(EVLOpCode); + if (MaskPos != -1 && RootMaskOp != OpVal.getOperand(MaskPos)) { + return false; + } + + int VLenPos = ISD::GetVectorLengthPosEVL(EVLOpCode); + if (VLenPos != -1 && RootVectorLenOp != OpVal.getOperand(VLenPos)) { + return false; + } + + return true; + } + } + + /// whether \p OpN is a node that is functionally compatible with the NodeType \p OpNodeTy + bool match(SDValue OpVal, unsigned OpNT) const { + return isCompatible(OpVal) && getFunctionOpCode(OpVal) == OpNT; + } + + // Specialize based on number of operands. 
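EVLMatchContext captures the root's mask and vector-length operands at construction time; match(Op, Opcode) then succeeds only for operands of the requested functional opcode whose predicate operands are the very same SDValues, so a rewrite cannot mix lanes governed by different predicates. A sketch, written as if inside DAGCombiner.cpp where this struct is visible:

static bool isFMulUnderRootPredicate(SelectionDAG &DAG, SDNode *Root,
                                     SDValue Op) {
  // Root is assumed to be an EVL node here.
  EVLMatchContext Matcher(DAG, Root);
  // True, e.g., for Op = EVL_FMUL(x, y, m, l) where m and l are Root's mask
  // and vector-length operands.
  return Matcher.match(Op, ISD::FMUL);
}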
+ // TODO emit EVL intrinsics where MaskOp/VectorLenOp != null + // SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT) { return DAG.getNode(Opcode, DL, VT); } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue Operand, + const SDNodeFlags Flags = SDNodeFlags()) { + unsigned EVLOpcode = ISD::GetEVLForFunctionOpCode(Opcode); + int MaskPos = ISD::GetMaskPosEVL(EVLOpcode); + int VLenPos = ISD::GetVectorLengthPosEVL(EVLOpcode); + assert(MaskPos == 1 && VLenPos == 2); + + return DAG.getNode(EVLOpcode, DL, VT, {Operand, RootMaskOp, RootVectorLenOp}, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, const SDNodeFlags Flags = SDNodeFlags()) { + unsigned EVLOpcode = ISD::GetEVLForFunctionOpCode(Opcode); + int MaskPos = ISD::GetMaskPosEVL(EVLOpcode); + int VLenPos = ISD::GetVectorLengthPosEVL(EVLOpcode); + assert(MaskPos == 2 && VLenPos == 3); + + return DAG.getNode(EVLOpcode, DL, VT, {N1, N2, RootMaskOp, RootVectorLenOp}, Flags); + } + SDValue getNode(unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N1, + SDValue N2, SDValue N3, + const SDNodeFlags Flags = SDNodeFlags()) { + unsigned EVLOpcode = ISD::GetEVLForFunctionOpCode(Opcode); + int MaskPos = ISD::GetMaskPosEVL(EVLOpcode); + int VLenPos = ISD::GetVectorLengthPosEVL(EVLOpcode); + assert(MaskPos == 3 && VLenPos == 4); + + return DAG.getNode(EVLOpcode, DL, VT, {N1, N2, N3, RootMaskOp, RootVectorLenOp}, Flags); + } +}; + } // end anonymous namespace //===----------------------------------------------------------------------===// @@ -1549,6 +1683,7 @@ case ISD::BITCAST: return visitBITCAST(N); case ISD::BUILD_PAIR: return visitBUILD_PAIR(N); case ISD::FADD: return visitFADD(N); + case ISD::EVL_FADD: return visitFADD_EVL(N); case ISD::FSUB: return visitFSUB(N); case ISD::FMUL: return visitFMUL(N); case ISD::FMA: return visitFMA(N); @@ -10444,13 +10579,18 @@ return F.hasAllowContract() || F.hasAllowReassociation(); } + /// Try to perform FMA combining on a given FADD node. +template SDValue DAGCombiner::visitFADDForFMACombine(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); EVT VT = N->getValueType(0); SDLoc SL(N); + MatchContextClass matcher(DAG, N); + if (!matcher.isCompatible(N0) || !matcher.isCompatible(N1)) return SDValue(); + const TargetOptions &Options = DAG.getTarget().Options; // Floating-point multiply-add with intermediate rounding. @@ -10483,8 +10623,8 @@ // Is the node an FMUL and contractable either due to global flags or // SDNodeFlags. - auto isContractableFMUL = [AllowFusionGlobally](SDValue N) { - if (N.getOpcode() != ISD::FMUL) + auto isContractableFMUL = [AllowFusionGlobally, &matcher](SDValue N) { + if (!matcher.match(N, ISD::FMUL)) return false; return AllowFusionGlobally || isContractable(N.getNode()); }; @@ -10497,42 +10637,42 @@ // fold (fadd (fmul x, y), z) -> (fma x, y, z) if (isContractableFMUL(N0) && (Aggressive || N0->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), N1, Flags); } // fold (fadd x, (fmul y, z)) -> (fma y, z, x) // Note: Commutes FADD operands. if (isContractableFMUL(N1) && (Aggressive || N1->hasOneUse())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), N0, Flags); } // Look through FP_EXTEND nodes to do more combining. 
// fold (fadd (fpext (fmul x, y)), z) -> (fma (fpext x), (fpext y), z) - if (N0.getOpcode() == ISD::FP_EXTEND) { + if ((N0.getOpcode() == ISD::FP_EXTEND) && matcher.isCompatible(N0.getOperand(0))) { SDValue N00 = N0.getOperand(0); if (isContractableFMUL(N00) && TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N00.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N00.getOperand(1)), N1, Flags); } } // fold (fadd x, (fpext (fmul y, z))) -> (fma (fpext y), (fpext z), x) // Note: Commutes FADD operands. - if (N1.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N1, ISD::FP_EXTEND)) { SDValue N10 = N1.getOperand(0); if (isContractableFMUL(N10) && TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N10.getValueType())) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(0)), - DAG.getNode(ISD::FP_EXTEND, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, N10.getOperand(1)), N0, Flags); } } @@ -10541,12 +10681,12 @@ if (Aggressive) { // fold (fadd (fma x, y, (fmul u, v)), z) -> (fma x, y (fma u, v, z)) if (CanFuse && - N0.getOpcode() == PreferredFusedOpcode && - N0.getOperand(2).getOpcode() == ISD::FMUL && + matcher.match(N0, PreferredFusedOpcode) && + matcher.match(N0.getOperand(2), ISD::FMUL) && N0->hasOneUse() && N0.getOperand(2)->hasOneUse()) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(0), N0.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(PreferredFusedOpcode, SL, VT, N0.getOperand(2).getOperand(0), N0.getOperand(2).getOperand(1), N1, Flags), Flags); @@ -10554,12 +10694,12 @@ // fold (fadd x, (fma y, z, (fmul u, v)) -> (fma y, z (fma u, v, x)) if (CanFuse && - N1->getOpcode() == PreferredFusedOpcode && - N1.getOperand(2).getOpcode() == ISD::FMUL && + matcher.match(N1, PreferredFusedOpcode) && + matcher.match(N1.getOperand(2), ISD::FMUL) && N1->hasOneUse() && N1.getOperand(2)->hasOneUse()) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, + return matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(0), N1.getOperand(1), - DAG.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(PreferredFusedOpcode, SL, VT, N1.getOperand(2).getOperand(0), N1.getOperand(2).getOperand(1), N0, Flags), Flags); @@ -10571,15 +10711,15 @@ auto FoldFAddFMAFPExtFMul = [&] ( SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, SDNodeFlags Flags) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, X, Y, - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + return matcher.getNode(PreferredFusedOpcode, SL, VT, X, Y, + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z, Flags), Flags); }; - if (N0.getOpcode() == PreferredFusedOpcode) { + if (matcher.match(N0, PreferredFusedOpcode)) { SDValue N02 = N0.getOperand(2); - if (N02.getOpcode() == ISD::FP_EXTEND) { + if (matcher.match(N02, ISD::FP_EXTEND)) { SDValue N020 = N02.getOperand(0); if (isContractableFMUL(N020) && TLI.isFPExtFoldable(PreferredFusedOpcode, VT, N020.getValueType())) { @@ -10598,12 
+10738,12 @@ auto FoldFAddFPExtFMAFMul = [&] ( SDValue X, SDValue Y, SDValue U, SDValue V, SDValue Z, SDNodeFlags Flags) { - return DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, X), - DAG.getNode(ISD::FP_EXTEND, SL, VT, Y), - DAG.getNode(PreferredFusedOpcode, SL, VT, - DAG.getNode(ISD::FP_EXTEND, SL, VT, U), - DAG.getNode(ISD::FP_EXTEND, SL, VT, V), + return matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, X), + matcher.getNode(ISD::FP_EXTEND, SL, VT, Y), + matcher.getNode(PreferredFusedOpcode, SL, VT, + matcher.getNode(ISD::FP_EXTEND, SL, VT, U), + matcher.getNode(ISD::FP_EXTEND, SL, VT, V), Z, Flags), Flags); }; if (N0.getOpcode() == ISD::FP_EXTEND) { @@ -11040,6 +11180,15 @@ return SDValue(); } +SDValue DAGCombiner::visitFADD_EVL(SDNode *N) { + // FADD -> FMA combines: + if (SDValue Fused = visitFADDForFMACombine(N)) { + AddToWorklist(Fused.getNode()); + return Fused; + } + return SDValue(); +} + SDValue DAGCombiner::visitFADD(SDNode *N) { SDValue N0 = N->getOperand(0); SDValue N1 = N->getOperand(1); @@ -11210,7 +11359,7 @@ } // enable-unsafe-fp-math // FADD -> FMA combines: - if (SDValue Fused = visitFADDForFMACombine(N)) { + if (SDValue Fused = visitFADDForFMACombine(N)) { AddToWorklist(Fused.getNode()); return Fused; } @@ -17732,7 +17881,7 @@ NewMask.push_back(M < 0 ? -1 : Scale * M + s); return NewMask; }; - + SDValue BC0 = peekThroughOneUseBitcasts(N0); if (BC0.getOpcode() == ISD::VECTOR_SHUFFLE && BC0.hasOneUse()) { EVT SVT = VT.getScalarType(); Index: lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp +++ lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp @@ -885,7 +885,7 @@ } // Handle promotion for the ADDE/SUBE/ADDCARRY/SUBCARRY nodes. Notice that -// the third operand of ADDE/SUBE nodes is carry flag, which differs from +// the third operand of ADDE/SUBE nodes is carry flag, which differs from // the ADDCARRY/SUBCARRY nodes in that the third operand is carry Boolean. SDValue DAGTypeLegalizer::PromoteIntRes_ADDSUBCARRY(SDNode *N, unsigned ResNo) { if (ResNo == 1) @@ -1031,6 +1031,9 @@ return false; } + if (N->isEVL()) { + Res = PromoteIntOp_EVL(N, OpNo); + } else { switch (N->getOpcode()) { default: #ifndef NDEBUG @@ -1092,6 +1095,7 @@ case ISD::SMULFIX: Res = PromoteIntOp_SMULFIX(N); break; } + } // If the result is null, the sub-method took care of registering results etc. 
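Instantiated with EVLMatchContext, the FMA combine above rewrites, for example, EVL_FADD(EVL_FMUL(x, y, m, l), z, m, l) into EVL_FMA(x, y, z, m, l). A condensed sketch of that core step (as it would appear inside DAGCombiner.cpp; the full combine additionally checks target legality, contraction flags and other conditions):

static SDValue fuseEVLFAddOfFMul(SelectionDAG &DAG, SDNode *N) {
  // N is an EVL_FADD node; the matcher remembers its mask and vector length.
  EVLMatchContext Matcher(DAG, N);
  SDValue Mul = N->getOperand(0), Addend = N->getOperand(1);
  // The multiply must be an EVL_FMUL under the same predicate as N.
  if (!Matcher.match(Mul, ISD::FMUL) || !Mul->hasOneUse())
    return SDValue();
  // getNode translates ISD::FMA to ISD::EVL_FMA and appends N's mask and
  // vector-length operands.
  return Matcher.getNode(ISD::FMA, SDLoc(N), N->getValueType(0),
                         Mul.getOperand(0), Mul.getOperand(1), Addend);
}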
if (!Res.getNode()) return false; @@ -1365,6 +1369,25 @@ TruncateStore, N->isCompressingStore()); } +SDValue DAGTypeLegalizer::PromoteIntOp_EVL(SDNode *N, unsigned OpNo) { + EVT DataVT; + switch (N->getOpcode()) { + default: + DataVT = N->getValueType(0); + break; + + case ISD::EVL_STORE: + case ISD::EVL_SCATTER: + llvm_unreachable("TODO implement EVL memory nodes"); + } + + // TODO assert that \p OpNo is the mask + SDValue Mask = PromoteTargetBoolean(N->getOperand(OpNo), DataVT); + SmallVector NewOps(N->op_begin(), N->op_end()); + NewOps[OpNo] = Mask; + return SDValue(DAG.UpdateNodeOperands(N, NewOps), 0); +} + SDValue DAGTypeLegalizer::PromoteIntOp_MLOAD(MaskedLoadSDNode *N, unsigned OpNo) { assert(OpNo == 2 && "Only know how to promote the mask!"); Index: lib/CodeGen/SelectionDAG/LegalizeTypes.h =================================================================== --- lib/CodeGen/SelectionDAG/LegalizeTypes.h +++ lib/CodeGen/SelectionDAG/LegalizeTypes.h @@ -347,6 +347,7 @@ SDValue PromoteIntRes_SMULFIX(SDNode *N); SDValue PromoteIntRes_FLT_ROUNDS(SDNode *N); + // Integer Operand Promotion. bool PromoteIntegerOperand(SDNode *N, unsigned OpNo); SDValue PromoteIntOp_ANY_EXTEND(SDNode *N); @@ -379,6 +380,7 @@ SDValue PromoteIntOp_FRAMERETURNADDR(SDNode *N); SDValue PromoteIntOp_PREFETCH(SDNode *N, unsigned OpNo); SDValue PromoteIntOp_SMULFIX(SDNode *N); + SDValue PromoteIntOp_EVL(SDNode *N, unsigned OpNo); void PromoteSetCCOperands(SDValue &LHS,SDValue &RHS, ISD::CondCode Code); Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -432,6 +432,183 @@ return Result; } +//===----------------------------------------------------------------------===// +// SDNode EVL Support +//===----------------------------------------------------------------------===// + +int +ISD::GetMaskPosEVL(unsigned OpCode) { + switch (OpCode) { + default: return -1; + + case ISD::EVL_FNEG: + return 1; + + case ISD::EVL_ADD: + case ISD::EVL_SUB: + case ISD::EVL_MUL: + case ISD::EVL_SDIV: + case ISD::EVL_SREM: + case ISD::EVL_UDIV: + case ISD::EVL_UREM: + + case ISD::EVL_AND: + case ISD::EVL_OR: + case ISD::EVL_XOR: + case ISD::EVL_SHL: + case ISD::EVL_SRA: + case ISD::EVL_SRL: + case ISD::EVL_FDIV: + case ISD::EVL_FREM: + + case ISD::EVL_FADD: + case ISD::EVL_FMUL: + return 2; + + case ISD::EVL_FMA: + case ISD::EVL_SELECT: + return 3; + + case EVL_REDUCE_FADD: + case EVL_REDUCE_FMUL: + case EVL_REDUCE_ADD: + case EVL_REDUCE_MUL: + case EVL_REDUCE_AND: + case EVL_REDUCE_OR: + case EVL_REDUCE_XOR: + case EVL_REDUCE_SMAX: + case EVL_REDUCE_SMIN: + case EVL_REDUCE_UMAX: + case EVL_REDUCE_UMIN: + case VECREDUCE_FMAX: + case VECREDUCE_FMIN: + case EVL_REDUCE_FMAX: + case EVL_REDUCE_FMIN: + return 1; + + /// FMIN/FMAX nodes can have flags, for NaN/NoNaN variants. 
+ // (implicit) case ISD::EVL_COMPOSE: return -1 + } +} + +int +ISD::GetVectorLengthPosEVL(unsigned OpCode) { + switch (OpCode) { + default: return -1; + + case ISD::EVL_SELECT: + return 0; + + case ISD::EVL_FNEG: + return 2; + + case ISD::EVL_ADD: + case ISD::EVL_SUB: + case ISD::EVL_MUL: + case ISD::EVL_SDIV: + case ISD::EVL_SREM: + case ISD::EVL_UDIV: + case ISD::EVL_UREM: + + case ISD::EVL_AND: + case ISD::EVL_OR: + case ISD::EVL_XOR: + case ISD::EVL_SHL: + case ISD::EVL_SRA: + case ISD::EVL_SRL: + + case ISD::EVL_FADD: + case ISD::EVL_FMUL: + case ISD::EVL_FDIV: + case ISD::EVL_FREM: + return 3; + + case ISD::EVL_FMA: + return 4; + + case ISD::EVL_COMPOSE: + return 3; + + case EVL_REDUCE_FADD: + case EVL_REDUCE_FMUL: + case EVL_REDUCE_ADD: + case EVL_REDUCE_MUL: + case EVL_REDUCE_AND: + case EVL_REDUCE_OR: + case EVL_REDUCE_XOR: + case EVL_REDUCE_SMAX: + case EVL_REDUCE_SMIN: + case EVL_REDUCE_UMAX: + case EVL_REDUCE_UMIN: + case EVL_REDUCE_FMAX: + case EVL_REDUCE_FMIN: + return 2; + } +} + +unsigned +ISD::GetFunctionOpCodeForEVL(unsigned OpCode) { + switch (OpCode) { + default: return OpCode; + + case ISD::EVL_SELECT: return ISD::VSELECT; + case ISD::EVL_FNEG: return ISD::FNEG; + case ISD::EVL_ADD: return ISD::ADD; + case ISD::EVL_SUB: return ISD::SUB; + case ISD::EVL_MUL: return ISD::MUL; + case ISD::EVL_SDIV: return ISD::SDIV; + case ISD::EVL_SREM: return ISD::SREM; + case ISD::EVL_UDIV: return ISD::UDIV; + case ISD::EVL_UREM: return ISD::UREM; + + case ISD::EVL_AND: return ISD::AND; + case ISD::EVL_OR: return ISD::OR; + case ISD::EVL_XOR: return ISD::XOR; + case ISD::EVL_SHL: return ISD::SHL; + case ISD::EVL_SRA: return ISD::SRA; + case ISD::EVL_SRL: return ISD::SRL; + case ISD::EVL_FDIV: return ISD::FDIV; + case ISD::EVL_FREM: return ISD::FREM; + + case ISD::EVL_FADD: return ISD::FADD; + case ISD::EVL_FMUL: return ISD::FMUL; + + case ISD::EVL_FMA: return ISD::FMA; + } +} + +unsigned +ISD::GetEVLForFunctionOpCode(unsigned OpCode) { + switch (OpCode) { + default: llvm_unreachable("can not translate this Opcode to EVL"); + + case ISD::VSELECT:return ISD::EVL_SELECT; + case ISD::FNEG: return ISD::EVL_FNEG; + case ISD::ADD: return ISD::EVL_ADD; + case ISD::SUB: return ISD::EVL_SUB; + case ISD::MUL: return ISD::EVL_MUL; + case ISD::SDIV: return ISD::EVL_SDIV; + case ISD::SREM: return ISD::EVL_SREM; + case ISD::UDIV: return ISD::EVL_UDIV; + case ISD::UREM: return ISD::EVL_UREM; + + case ISD::AND: return ISD::EVL_AND; + case ISD::OR: return ISD::EVL_OR; + case ISD::XOR: return ISD::EVL_XOR; + case ISD::SHL: return ISD::EVL_SHL; + case ISD::SRA: return ISD::EVL_SRA; + case ISD::SRL: return ISD::EVL_SRL; + case ISD::FDIV: return ISD::EVL_FDIV; + case ISD::FREM: return ISD::EVL_FREM; + + case ISD::FADD: return ISD::EVL_FADD; + case ISD::FMUL: return ISD::EVL_FMUL; + + case ISD::FMA: return ISD::EVL_FMA; + } +} + + //===----------------------------------------------------------------------===// // SDNode Profile Support //===----------------------------------------------------------------------===// @@ -555,6 +732,34 @@ ID.AddInteger(ST->getPointerInfo().getAddrSpace()); break; } + case ISD::EVL_LOAD: { + const EVLLoadSDNode *ELD = cast(N); + ID.AddInteger(ELD->getMemoryVT().getRawBits()); + ID.AddInteger(ELD->getRawSubclassData()); + ID.AddInteger(ELD->getPointerInfo().getAddrSpace()); + break; + } + case ISD::EVL_STORE: { + const EVLStoreSDNode *EST = cast(N); + ID.AddInteger(EST->getMemoryVT().getRawBits()); + ID.AddInteger(EST->getRawSubclassData()); + 
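The operand-position and opcode-translation tables above let generic DAG code handle every EVL node uniformly. For example, fetching the predicate operands of an arbitrary EVL node (illustrative helper, not part of the patch):

static void getPredicateOperands(const SDNode *N, SDValue &Mask,
                                 SDValue &VectorLen) {
  int MaskPos = ISD::GetMaskPosEVL(N->getOpcode());
  if (MaskPos >= 0)
    Mask = N->getOperand(MaskPos);
  int VLenPos = ISD::GetVectorLengthPosEVL(N->getOpcode());
  if (VLenPos >= 0)
    VectorLen = N->getOperand(VLenPos);
}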
ID.AddInteger(EST->getPointerInfo().getAddrSpace()); + break; + } + case ISD::EVL_GATHER: { + const EVLGatherSDNode *EG = cast(N); + ID.AddInteger(EG->getMemoryVT().getRawBits()); + ID.AddInteger(EG->getRawSubclassData()); + ID.AddInteger(EG->getPointerInfo().getAddrSpace()); + break; + } + case ISD::EVL_SCATTER: { + const EVLScatterSDNode *ES = cast(N); + ID.AddInteger(ES->getMemoryVT().getRawBits()); + ID.AddInteger(ES->getRawSubclassData()); + ID.AddInteger(ES->getPointerInfo().getAddrSpace()); + break; + } case ISD::MLOAD: { const MaskedLoadSDNode *MLD = cast(N); ID.AddInteger(MLD->getMemoryVT().getRawBits()); @@ -6868,6 +7073,34 @@ return V; } +SDValue SelectionDAG::getLoadEVL(EVT VT, const SDLoc &dl, SDValue Chain, + SDValue Ptr, SDValue Mask, SDValue VLen, + EVT MemVT, MachineMemOperand *MMO, + ISD::LoadExtType ExtTy) { + SDVTList VTs = getVTList(VT, MVT::Other); + SDValue Ops[] = { Chain, Ptr, Mask, VLen }; + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::EVL_LOAD, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, ExtTy, MemVT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), VTs, + ExtTy, MemVT, MMO); + createOperands(N, Ops); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + SDValue SelectionDAG::getMaskedLoad(EVT VT, const SDLoc &dl, SDValue Chain, SDValue Ptr, SDValue Mask, SDValue PassThru, EVT MemVT, MachineMemOperand *MMO, @@ -6896,6 +7129,111 @@ return V; } +SDValue SelectionDAG::getStoreEVL(SDValue Chain, const SDLoc &dl, + SDValue Val, SDValue Ptr, SDValue Mask, + SDValue VLen, EVT MemVT, MachineMemOperand *MMO, + bool IsTruncating) { + assert(Chain.getValueType() == MVT::Other && + "Invalid chain type"); + EVT VT = Val.getValueType(); + SDVTList VTs = getVTList(MVT::Other); + SDValue Ops[] = { Chain, Val, Ptr, Mask, VLen }; + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::EVL_STORE, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, IsTruncating, MemVT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), VTs, + IsTruncating, MemVT, MMO); + createOperands(N, Ops); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +SDValue SelectionDAG::getGatherEVL(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, + MachineMemOperand *MMO) { + assert(Ops.size() == 6 && "Incompatible number of operands"); + + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::EVL_GATHER, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, VT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), + VTs, VT, MMO); + createOperands(N, Ops); + + assert(N->getMask().getValueType().getVectorNumElements() == + N->getValueType(0).getVectorNumElements() && + 
"Vector width mismatch between mask and data"); + assert(N->getIndex().getValueType().getVectorNumElements() >= + N->getValueType(0).getVectorNumElements() && + "Vector width mismatch between index and data"); + assert(isa(N->getScale()) && + cast(N->getScale())->getAPIntValue().isPowerOf2() && + "Scale should be a constant power of 2"); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} + +SDValue SelectionDAG::getScatterEVL(SDVTList VTs, EVT VT, const SDLoc &dl, + ArrayRef Ops, + MachineMemOperand *MMO) { + assert(Ops.size() == 7 && "Incompatible number of operands"); + + FoldingSetNodeID ID; + AddNodeIDNode(ID, ISD::EVL_SCATTER, VTs, Ops); + ID.AddInteger(VT.getRawBits()); + ID.AddInteger(getSyntheticNodeSubclassData( + dl.getIROrder(), VTs, VT, MMO)); + ID.AddInteger(MMO->getPointerInfo().getAddrSpace()); + void *IP = nullptr; + if (SDNode *E = FindNodeOrInsertPos(ID, dl, IP)) { + cast(E)->refineAlignment(MMO); + return SDValue(E, 0); + } + auto *N = newSDNode(dl.getIROrder(), dl.getDebugLoc(), + VTs, VT, MMO); + createOperands(N, Ops); + + assert(N->getMask().getValueType().getVectorNumElements() == + N->getValue().getValueType().getVectorNumElements() && + "Vector width mismatch between mask and data"); + assert(N->getIndex().getValueType().getVectorNumElements() >= + N->getValue().getValueType().getVectorNumElements() && + "Vector width mismatch between index and data"); + assert(isa(N->getScale()) && + cast(N->getScale())->getAPIntValue().isPowerOf2() && + "Scale should be a constant power of 2"); + + CSEMap.InsertNode(N, IP); + InsertNode(N); + SDValue V(N, 0); + NewSDValueDbgMsg(V, "Creating new node: ", this); + return V; +} SDValue SelectionDAG::getMaskedStore(SDValue Chain, const SDLoc &dl, SDValue Val, SDValue Ptr, SDValue Mask, EVT MemVT, MachineMemOperand *MMO, Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h +++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h @@ -934,6 +934,12 @@ const char *visitIntrinsicCall(const CallInst &I, unsigned Intrinsic); void visitTargetIntrinsic(const CallInst &I, unsigned Intrinsic); void visitConstrainedFPIntrinsic(const ConstrainedFPIntrinsic &FPI); + void visitExplicitVectorLengthIntrinsic(const EVLIntrinsic &EVLI); + void visitCmpEVL(const EVLIntrinsic &I); + void visitLoadEVL(const CallInst &I); + void visitStoreEVL(const CallInst &I); + void visitGatherEVL(const CallInst &I); + void visitScatterEVL(const CallInst &I); void visitVAStart(const CallInst &I); void visitVAArg(const VAArgInst &I); Index: lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -3982,6 +3982,46 @@ setValue(&I, StoreNode); } +void SelectionDAGBuilder::visitStoreEVL(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + auto getEVLStoreOps = [&](Value* &Ptr, Value* &Mask, Value* &Src0, + Value * &VLen, unsigned & Alignment) { + // llvm.masked.store.*(Src0, Ptr, Mask, VLen) + Src0 = I.getArgOperand(0); + Ptr = I.getArgOperand(1); + Alignment = I.getParamAlignment(1); + Mask = I.getArgOperand(2); + VLen = I.getArgOperand(3); + }; + + Value *PtrOperand, *MaskOperand, *Src0Operand, *VLenOperand; + unsigned Alignment = 0; + getEVLStoreOps(PtrOperand, MaskOperand, Src0Operand, VLenOperand, Alignment); 
+ + SDValue Ptr = getValue(PtrOperand); + SDValue Src0 = getValue(Src0Operand); + SDValue Mask = getValue(MaskOperand); + SDValue VLen = getValue(VLenOperand); + + EVT VT = Src0.getValueType(); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + + MachineMemOperand *MMO = + DAG.getMachineFunction(). + getMachineMemOperand(MachinePointerInfo(PtrOperand), + MachineMemOperand::MOStore, VT.getStoreSize(), + Alignment, AAInfo); + SDValue StoreNode = DAG.getStoreEVL(getRoot(), sdl, Src0, Ptr, Mask, VLen, VT, + MMO, false /* Truncating */); + DAG.setRoot(StoreNode); + setValue(&I, StoreNode); +} + // Get a uniform base for the Gather/Scatter intrinsic. // The first argument of the Gather/Scatter intrinsic is a vector of pointers. // We try to represent it as a base pointer + vector of indices. @@ -4200,6 +4240,158 @@ setValue(&I, Gather); } +void SelectionDAGBuilder::visitGatherEVL(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + // @llvm.evl.gather.*(Ptrs, Mask, VLen) + const Value *Ptr = I.getArgOperand(0); + SDValue Mask = getValue(I.getArgOperand(1)); + SDValue VLen = getValue(I.getArgOperand(2)); + + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); + unsigned Alignment = I.getParamAlignment(0); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); + + SDValue Root = DAG.getRoot(); + SDValue Base; + SDValue Index; + SDValue Scale; + const Value *BasePtr = Ptr; + bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this); + bool ConstantMemory = false; + if (UniformBase && AA && + AA->pointsToConstantMemory( + MemoryLocation(BasePtr, + LocationSize::precise( + DAG.getDataLayout().getTypeStoreSize(I.getType())), + AAInfo))) { + // Do not serialize (non-volatile) loads of constant memory with anything. + Root = DAG.getEntryNode(); + ConstantMemory = true; + } + + MachineMemOperand *MMO = + DAG.getMachineFunction(). + getMachineMemOperand(MachinePointerInfo(UniformBase ? BasePtr : nullptr), + MachineMemOperand::MOLoad, VT.getStoreSize(), + Alignment, AAInfo, Ranges); + + if (!UniformBase) { + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Index = getValue(Ptr); + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); + } + SDValue Ops[] = { Root, Base, Index, Scale, Mask, VLen }; + SDValue Gather = DAG.getGatherEVL(DAG.getVTList(VT, MVT::Other), VT, sdl, Ops, MMO); + + SDValue OutChain = Gather.getValue(1); + if (!ConstantMemory) + PendingLoads.push_back(OutChain); + setValue(&I, Gather); +} + +void SelectionDAGBuilder::visitScatterEVL(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + // llvm.evl.scatter.*(Src0, Ptrs, Mask, VLen) + const Value *Ptr = I.getArgOperand(1); + SDValue Src0 = getValue(I.getArgOperand(0)); + SDValue Mask = getValue(I.getArgOperand(2)); + SDValue VLen = getValue(I.getArgOperand(3)); + EVT VT = Src0.getValueType(); + unsigned Alignment = I.getParamAlignment(1); + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + + SDValue Base; + SDValue Index; + SDValue Scale; + const Value *BasePtr = Ptr; + bool UniformBase = getUniformBase(BasePtr, Base, Index, Scale, this); + + const Value *MemOpBasePtr = UniformBase ? 
BasePtr : nullptr; + MachineMemOperand *MMO = DAG.getMachineFunction(). + getMachineMemOperand(MachinePointerInfo(MemOpBasePtr), + MachineMemOperand::MOStore, VT.getStoreSize(), + Alignment, AAInfo); + if (!UniformBase) { + Base = DAG.getConstant(0, sdl, TLI.getPointerTy(DAG.getDataLayout())); + Index = getValue(Ptr); + Scale = DAG.getTargetConstant(1, sdl, TLI.getPointerTy(DAG.getDataLayout())); + } + SDValue Ops[] = { getRoot(), Src0, Base, Index, Scale, Mask, VLen }; + SDValue Scatter = DAG.getScatterEVL(DAG.getVTList(MVT::Other), VT, sdl, + Ops, MMO); + DAG.setRoot(Scatter); + setValue(&I, Scatter); +} + +void SelectionDAGBuilder::visitLoadEVL(const CallInst &I) { + SDLoc sdl = getCurSDLoc(); + + auto getMaskedLoadOps = [&](Value* &Ptr, Value* &Mask, Value* &VLen, + unsigned& Alignment) { + // @llvm.evl.load.*(Ptr, Mask, Vlen) + Ptr = I.getArgOperand(0); + Alignment = I.getParamAlignment(0); + Mask = I.getArgOperand(1); + VLen = I.getArgOperand(2); + }; + + Value *PtrOperand, *MaskOperand, *VLenOperand; + unsigned Alignment; + getMaskedLoadOps(PtrOperand, MaskOperand, VLenOperand, Alignment); + + SDValue Ptr = getValue(PtrOperand); + SDValue VLen = getValue(VLenOperand); + SDValue Mask = getValue(MaskOperand); + + // infer the return type + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + SmallVector ValValueVTs; + ComputeValueVTs(TLI, DAG.getDataLayout(), I.getType(), ValValueVTs); + EVT VT = ValValueVTs[0]; + assert((ValValueVTs.size() == 1) && "splitting not implemented"); + + if (!Alignment) + Alignment = DAG.getEVTAlignment(VT); + + AAMDNodes AAInfo; + I.getAAMetadata(AAInfo); + const MDNode *Ranges = I.getMetadata(LLVMContext::MD_range); + + // Do not serialize masked loads of constant memory with anything. + bool AddToChain = + !AA || !AA->pointsToConstantMemory(MemoryLocation( + PtrOperand, + LocationSize::precise( + DAG.getDataLayout().getTypeStoreSize(I.getType())), + AAInfo)); + SDValue InChain = AddToChain ? DAG.getRoot() : DAG.getEntryNode(); + + MachineMemOperand *MMO = + DAG.getMachineFunction(). 
+ getMachineMemOperand(MachinePointerInfo(PtrOperand), + MachineMemOperand::MOLoad, VT.getStoreSize(), + Alignment, AAInfo, Ranges); + + SDValue Load = DAG.getLoadEVL(VT, sdl, InChain, Ptr, Mask, VLen, VT, MMO, + ISD::NON_EXTLOAD); + if (AddToChain) + PendingLoads.push_back(Load.getValue(1)); + setValue(&I, Load); +} + void SelectionDAGBuilder::visitAtomicCmpXchg(const AtomicCmpXchgInst &I) { SDLoc dl = getCurSDLoc(); AtomicOrdering SuccessOrder = I.getSuccessOrdering(); @@ -5711,6 +5903,63 @@ case Intrinsic::experimental_constrained_trunc: visitConstrainedFPIntrinsic(cast(I)); return nullptr; + + case Intrinsic::evl_and: + case Intrinsic::evl_or: + case Intrinsic::evl_xor: + case Intrinsic::evl_ashr: + case Intrinsic::evl_lshr: + case Intrinsic::evl_shl: + + case Intrinsic::evl_select: + case Intrinsic::evl_compose: + case Intrinsic::evl_compress: + case Intrinsic::evl_expand: + case Intrinsic::evl_vshift: + + case Intrinsic::evl_load: + case Intrinsic::evl_store: + case Intrinsic::evl_gather: + case Intrinsic::evl_scatter: + + case Intrinsic::evl_fneg: + + case Intrinsic::evl_fadd: + case Intrinsic::evl_fsub: + case Intrinsic::evl_fmul: + case Intrinsic::evl_fdiv: + case Intrinsic::evl_frem: + + case Intrinsic::evl_fma: + + case Intrinsic::evl_add: + case Intrinsic::evl_sub: + case Intrinsic::evl_mul: + case Intrinsic::evl_udiv: + case Intrinsic::evl_sdiv: + case Intrinsic::evl_urem: + case Intrinsic::evl_srem: + + case Intrinsic::evl_cmp: + + case Intrinsic::evl_reduce_and: + case Intrinsic::evl_reduce_or: + case Intrinsic::evl_reduce_xor: + + case Intrinsic::evl_reduce_fadd: + case Intrinsic::evl_reduce_fmax: + case Intrinsic::evl_reduce_fmin: + case Intrinsic::evl_reduce_fmul: + + case Intrinsic::evl_reduce_add: + case Intrinsic::evl_reduce_mul: + case Intrinsic::evl_reduce_umax: + case Intrinsic::evl_reduce_umin: + case Intrinsic::evl_reduce_smax: + case Intrinsic::evl_reduce_smin: + visitExplicitVectorLengthIntrinsic(cast(I)); + return nullptr; + case Intrinsic::fmuladd: { EVT VT = TLI.getValueType(DAG.getDataLayout(), I.getType()); if (TM.Options.AllowFPOpFusion != FPOpFusion::Strict && @@ -6524,6 +6773,138 @@ setValue(&FPI, FPResult); } +void SelectionDAGBuilder::visitCmpEVL(const EVLIntrinsic &I) { + ISD::CondCode Condition; + CmpInst::Predicate predicate = I.getCmpPredicate(); + bool IsFP = I.getOperand(0)->getType()->isFPOrFPVectorTy(); + if (IsFP) { + Condition = getFCmpCondCode(predicate); + auto *FPMO = dyn_cast(&I); + if ((FPMO && FPMO->hasNoNaNs()) || TM.Options.NoNaNsFPMath) + Condition = getFCmpCodeWithoutNaN(Condition); + + } else { + Condition = getICmpCondCode(predicate); + } + + SDValue Op1 = getValue(I.getOperand(0)); + SDValue Op2 = getValue(I.getOperand(1)); + + EVT DestVT = DAG.getTargetLoweringInfo().getValueType(DAG.getDataLayout(), + I.getType()); + setValue(&I, DAG.getSetCC(getCurSDLoc(), DestVT, Op1, Op2, Condition)); +} + +void SelectionDAGBuilder::visitExplicitVectorLengthIntrinsic( + const EVLIntrinsic & EVLInst) { + SDLoc sdl = getCurSDLoc(); + unsigned Opcode; + switch (EVLInst.getIntrinsicID()) { + default: + llvm_unreachable("Unforeseen intrinsic"); // Can't reach here. 
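+  // The memory (load/store/gather/scatter) and compare intrinsics below are
+  // dispatched to dedicated visitors; every other EVL intrinsic maps 1:1 onto
+  // an EVL_* SDNode (e.g. llvm.evl.fadd.* becomes ISD::EVL_FADD with operands
+  // (lhs, rhs, mask, evl)) and is built generically after this switch.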
+
+  case Intrinsic::evl_load:    visitLoadEVL(EVLInst); return;
+  case Intrinsic::evl_store:   visitStoreEVL(EVLInst); return;
+  case Intrinsic::evl_gather:  visitGatherEVL(EVLInst); return;
+  case Intrinsic::evl_scatter: visitScatterEVL(EVLInst); return;
+
+  case Intrinsic::evl_cmp: visitCmpEVL(EVLInst); return;
+
+  case Intrinsic::evl_add:  Opcode = ISD::EVL_ADD; break;
+  case Intrinsic::evl_sub:  Opcode = ISD::EVL_SUB; break;
+  case Intrinsic::evl_mul:  Opcode = ISD::EVL_MUL; break;
+  case Intrinsic::evl_udiv: Opcode = ISD::EVL_UDIV; break;
+  case Intrinsic::evl_sdiv: Opcode = ISD::EVL_SDIV; break;
+  case Intrinsic::evl_urem: Opcode = ISD::EVL_UREM; break;
+  case Intrinsic::evl_srem: Opcode = ISD::EVL_SREM; break;
+
+  case Intrinsic::evl_and:  Opcode = ISD::EVL_AND; break;
+  case Intrinsic::evl_or:   Opcode = ISD::EVL_OR; break;
+  case Intrinsic::evl_xor:  Opcode = ISD::EVL_XOR; break;
+  case Intrinsic::evl_ashr: Opcode = ISD::EVL_SRA; break;
+  case Intrinsic::evl_lshr: Opcode = ISD::EVL_SRL; break;
+  case Intrinsic::evl_shl:  Opcode = ISD::EVL_SHL; break;
+
+  case Intrinsic::evl_fneg: Opcode = ISD::EVL_FNEG; break;
+  case Intrinsic::evl_fadd: Opcode = ISD::EVL_FADD; break;
+  case Intrinsic::evl_fsub: Opcode = ISD::EVL_FSUB; break;
+  case Intrinsic::evl_fmul: Opcode = ISD::EVL_FMUL; break;
+  case Intrinsic::evl_fdiv: Opcode = ISD::EVL_FDIV; break;
+  case Intrinsic::evl_frem: Opcode = ISD::EVL_FREM; break;
+
+  case Intrinsic::evl_fma: Opcode = ISD::EVL_FMA; break;
+
+  case Intrinsic::evl_select:   Opcode = ISD::EVL_SELECT; break;
+  case Intrinsic::evl_compose:  Opcode = ISD::EVL_COMPOSE; break;
+  case Intrinsic::evl_compress: Opcode = ISD::EVL_COMPRESS; break;
+  case Intrinsic::evl_expand:   Opcode = ISD::EVL_EXPAND; break;
+  case Intrinsic::evl_vshift:   Opcode = ISD::EVL_VSHIFT; break;
+
+  case Intrinsic::evl_reduce_and:  Opcode = ISD::EVL_REDUCE_AND; break;
+  case Intrinsic::evl_reduce_or:   Opcode = ISD::EVL_REDUCE_OR; break;
+  case Intrinsic::evl_reduce_xor:  Opcode = ISD::EVL_REDUCE_XOR; break;
+  case Intrinsic::evl_reduce_add:  Opcode = ISD::EVL_REDUCE_ADD; break;
+  case Intrinsic::evl_reduce_mul:  Opcode = ISD::EVL_REDUCE_MUL; break;
+  case Intrinsic::evl_reduce_fadd: Opcode = ISD::EVL_REDUCE_FADD; break;
+  case Intrinsic::evl_reduce_fmul: Opcode = ISD::EVL_REDUCE_FMUL; break;
+  case Intrinsic::evl_reduce_fmax: Opcode = ISD::EVL_REDUCE_FMAX; break;
+  case Intrinsic::evl_reduce_fmin: Opcode = ISD::EVL_REDUCE_FMIN; break;
+  case Intrinsic::evl_reduce_smax: Opcode = ISD::EVL_REDUCE_SMAX; break;
+  case Intrinsic::evl_reduce_smin: Opcode = ISD::EVL_REDUCE_SMIN; break;
+  case Intrinsic::evl_reduce_umax: Opcode = ISD::EVL_REDUCE_UMAX; break;
+  case Intrinsic::evl_reduce_umin: Opcode = ISD::EVL_REDUCE_UMIN; break;
+  }
+
+  // TODO memory evl: SDValue Chain = getRoot();
+
+  SmallVector ValueVTs;
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  ComputeValueVTs(TLI, DAG.getDataLayout(), EVLInst.getType(), ValueVTs);
+  SDVTList VTs = DAG.getVTList(ValueVTs);
+
+  // ValueVTs.push_back(MVT::Other); // Out chain
+
+
+  SDValue Result;
+
+  switch (EVLInst.getNumArgOperands()) {
+  default:
+    llvm_unreachable("unexpected number of arguments to evl intrinsic");
+  case 3:
+    Result = DAG.getNode(Opcode, sdl, VTs,
+                         { getValue(EVLInst.getArgOperand(0)),
+                           getValue(EVLInst.getArgOperand(1)),
+                           getValue(EVLInst.getArgOperand(2)) });
+    break;
+
+  case 4:
+    Result = DAG.getNode(Opcode, sdl, VTs,
+                         { getValue(EVLInst.getArgOperand(0)),
+                           getValue(EVLInst.getArgOperand(1)),
+                           getValue(EVLInst.getArgOperand(2)),
+                           getValue(EVLInst.getArgOperand(3)) });
+    break;
+
+  case 5:
+    Result = DAG.getNode(Opcode, sdl, VTs,
+                         { getValue(EVLInst.getArgOperand(0)),
getValue(EVLInst.getArgOperand(1)), + getValue(EVLInst.getArgOperand(2)), + getValue(EVLInst.getArgOperand(3)), + getValue(EVLInst.getArgOperand(4)) }); + break; + } + + if (Result.getNode()->getNumValues() == 2) { + // this evl node has a chain + SDValue OutChain = Result.getValue(1); + DAG.setRoot(OutChain); + SDValue EVLResult = Result.getValue(0); + setValue(&EVLInst, EVLResult); + } else { + // this is a pure node + setValue(&EVLInst, Result); + } +} + std::pair SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI, const BasicBlock *EHPadBB) { Index: lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -421,6 +421,65 @@ case ISD::VECREDUCE_UMIN: return "vecreduce_umin"; case ISD::VECREDUCE_FMAX: return "vecreduce_fmax"; case ISD::VECREDUCE_FMIN: return "vecreduce_fmin"; + + // Explicit Vector Length erxtension + // EVL Memory + case ISD::EVL_LOAD: return "evl_load"; + case ISD::EVL_STORE: return "evl_store"; + case ISD::EVL_GATHER: return "evl_gather"; + case ISD::EVL_SCATTER: return "evl_scatter"; + + // EVL Unary operators + case ISD::EVL_FNEG: return "evl_fneg"; + + // EVL Binary operators + case ISD::EVL_ADD: return "evl_add"; + case ISD::EVL_SUB: return "evl_sub"; + case ISD::EVL_MUL: return "evl_mul"; + case ISD::EVL_SDIV: return "evl_sdiv"; + case ISD::EVL_UDIV: return "evl_udiv"; + case ISD::EVL_SREM: return "evl_srem"; + case ISD::EVL_UREM: return "evl_urem"; + case ISD::EVL_AND: return "evl_and"; + case ISD::EVL_OR: return "evl_or"; + case ISD::EVL_XOR: return "evl_xor"; + case ISD::EVL_SHL: return "evl_shl"; + case ISD::EVL_SRA: return "evl_sra"; + case ISD::EVL_SRL: return "evl_srl"; + case ISD::EVL_FADD: return "evl_fadd"; + case ISD::EVL_FSUB: return "evl_fsub"; + case ISD::EVL_FMUL: return "evl_fmul"; + case ISD::EVL_FDIV: return "evl_fdiv"; + case ISD::EVL_FREM: return "evl_frem"; + + // EVL comparison + case ISD::EVL_SETCC: return "evl_setcc"; + + // EVL ternary operators + case ISD::EVL_FMA: return "evl_fma"; + + // EVL shuffle + case ISD::EVL_VSHIFT: return "evl_vshift"; + case ISD::EVL_COMPRESS: return "evl_compress"; + case ISD::EVL_EXPAND: return "evl_expand"; + + case ISD::EVL_COMPOSE: return "evl_compose"; + case ISD::EVL_SELECT: return "evl_select"; + + // EVL reduction operators + case ISD::EVL_REDUCE_FADD: return "evl_reduce_fadd"; + case ISD::EVL_REDUCE_FMUL: return "evl_reduce_fmul"; + case ISD::EVL_REDUCE_ADD: return "evl_reduce_add"; + case ISD::EVL_REDUCE_MUL: return "evl_reduce_mul"; + case ISD::EVL_REDUCE_AND: return "evl_reduce_and"; + case ISD::EVL_REDUCE_OR: return "evl_reduce_or"; + case ISD::EVL_REDUCE_XOR: return "evl_reduce_xor"; + case ISD::EVL_REDUCE_SMAX: return "evl_reduce_smax"; + case ISD::EVL_REDUCE_SMIN: return "evl_reduce_smin"; + case ISD::EVL_REDUCE_UMAX: return "evl_reduce_umax"; + case ISD::EVL_REDUCE_UMIN: return "evl_reduce_umin"; + case ISD::EVL_REDUCE_FMAX: return "evl_reduce_fmax"; + case ISD::EVL_REDUCE_FMIN: return "evl_reduce_fmin"; } } Index: lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp =================================================================== --- lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp +++ lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp @@ -782,6 +782,10 @@ CurDAG->Combine(BeforeLegalizeTypes, AA, OptLevel); } + if (getenv("SDEBUG")) { + CurDAG->dump(); + } + #ifndef NDEBUG if (TTI.hasBranchDivergence()) 
      CurDAG->VerifyDAGDiverence();
Index: lib/IR/Attributes.cpp
===================================================================
--- lib/IR/Attributes.cpp
+++ lib/IR/Attributes.cpp
@@ -256,6 +256,8 @@
     return "byval";
   if (hasAttribute(Attribute::Convergent))
     return "convergent";
+  if (hasAttribute(Attribute::VectorLength))
+    return "vlen";
   if (hasAttribute(Attribute::SwiftError))
     return "swifterror";
   if (hasAttribute(Attribute::SwiftSelf))
@@ -272,6 +274,10 @@
     return "inreg";
   if (hasAttribute(Attribute::JumpTable))
     return "jumptable";
+  if (hasAttribute(Attribute::Mask))
+    return "mask";
+  if (hasAttribute(Attribute::Passthru))
+    return "passthru";
   if (hasAttribute(Attribute::MinSize))
     return "minsize";
   if (hasAttribute(Attribute::Naked))
Index: lib/IR/CMakeLists.txt
===================================================================
--- lib/IR/CMakeLists.txt
+++ lib/IR/CMakeLists.txt
@@ -23,6 +23,7 @@
   DiagnosticPrinter.cpp
   Dominators.cpp
   DomTreeUpdater.cpp
+  EVLBuilder.cpp
   Function.cpp
   GVMaterializer.cpp
   Globals.cpp
@@ -47,6 +48,7 @@
   PassManager.cpp
   PassRegistry.cpp
   PassTimingInfo.cpp
+  PredicatedInst.cpp
   SafepointIRVerifier.cpp
   ProfileSummary.cpp
   Statepoint.cpp
Index: lib/IR/EVLBuilder.cpp
===================================================================
--- /dev/null
+++ lib/IR/EVLBuilder.cpp
@@ -0,0 +1,195 @@
+#include
+#include
+#include
+#include
+#include
+
+namespace {
+  using namespace llvm;
+  using ShortTypeVec = EVLIntrinsic::ShortTypeVec;
+  using ShortValueVec = SmallVector;
+}
+
+namespace llvm {
+
+Module &
+EVLBuilder::getModule() const {
+  return *Builder.GetInsertBlock()->getParent()->getParent();
+}
+
+Value&
+EVLBuilder::GetMaskForType(VectorType & VecTy) {
+  if (Mask) return *Mask;
+
+  auto * boolTy = Builder.getInt1Ty();
+  auto * maskTy = VectorType::get(boolTy, StaticVectorLength);
+  return *ConstantInt::getAllOnesValue(maskTy);
+}
+
+Value&
+EVLBuilder::GetEVLForType(VectorType & VecTy) {
+  if (ExplicitVectorLength) return *ExplicitVectorLength;
+
+  auto * intTy = Builder.getInt32Ty();
+  return *ConstantInt::get(intTy, StaticVectorLength);
+}
+
+Value*
+EVLBuilder::CreateVectorCopy(Instruction & Inst, ValArray VecOpArray) {
+
+  auto oc = Inst.getOpcode();
+
+  auto evlDesc = EVLIntrinsic::GetEVLIntrinsicDesc(oc);
+  if (evlDesc.ID == Intrinsic::not_intrinsic) {
+    return nullptr;
+  }
+
+  if ((oc >= Instruction::BinaryOpsBegin) &&
+      (oc < Instruction::BinaryOpsEnd)) {
+
+    assert(VecOpArray.size() == 2);
+    Value & FirstOp = *VecOpArray[0];
+    Value & SndOp = *VecOpArray[1];
+
+    // Fetch the EVL intrinsic
+    auto & VecTy = cast(*FirstOp.getType());
+    assert((evlDesc.MaskPos == 2) && (evlDesc.EVLPos == 3));
+
+    auto & EVLCall =
+      cast(*PredicatedBinaryOperator::Create(&getModule(), &GetMaskForType(VecTy), &GetEVLForType(VecTy), static_cast(oc), &FirstOp, &SndOp));
+    Builder.Insert(&EVLCall);
+
+    // transfer fast math flags
+    if (isa(Inst)) {
+      EVLCall.copyFastMathFlags(Inst.getFastMathFlags());
+    }
+
+    return &EVLCall;
+  }
+
+  if ((oc >= Instruction::UnaryOpsBegin) &&
+      (oc < Instruction::UnaryOpsEnd)) {
+    assert(VecOpArray.size() == 1);
+    Value & FirstOp = *VecOpArray[0];
+
+    // Fetch the EVL intrinsic
+    auto & VecTy = cast(*FirstOp.getType());
+    auto & ScalarTy = *VecTy.getVectorElementType();
+    auto * Func = Intrinsic::getDeclaration(&getModule(), evlDesc.ID, EVLIntrinsic::EncodeTypeTokens(evlDesc.typeTokens, VecTy, ScalarTy));
+
+    assert((evlDesc.MaskPos == 1) && (evlDesc.EVLPos == 2));
+
+    // Materialize the Call
+    ShortValueVec Args{&FirstOp, &GetMaskForType(VecTy),
&GetEVLForType(VecTy)}; + + auto & EVLCall = *Builder.CreateCall(Func, Args); + + // transfer fast math flags + if (isa(Inst)) { + cast(EVLCall).copyFastMathFlags(Inst.getFastMathFlags()); + } + + return &EVLCall; + } + + switch (oc) { + default: + return nullptr; + + case Instruction::FCmp: + case Instruction::ICmp: { + assert(VecOpArray.size() == 2); + Value & FirstOp = *VecOpArray[0]; + Value & SndOp = *VecOpArray[1]; + + // Fetch the EVL intrinsic + auto & VecTy = cast(*FirstOp.getType()); + auto & ScalarTy = *VecTy.getVectorElementType(); + auto * Func = Intrinsic::getDeclaration(&getModule(), evlDesc.ID, EVLIntrinsic::EncodeTypeTokens(evlDesc.typeTokens, VecTy, ScalarTy)); + + assert((evlDesc.MaskPos == 2) && (evlDesc.EVLPos == 3)); + + // encode comparison predicate as MD + uint8_t RawPred = cast(Inst).getPredicate(); + auto Int8Ty = Builder.getInt8Ty(); + auto PredArg = ConstantInt::get(Int8Ty, RawPred, false); + + // Materialize the Call + ShortValueVec Args{&FirstOp, &SndOp, &GetMaskForType(VecTy), &GetEVLForType(VecTy), PredArg}; + + return Builder.CreateCall(Func, Args); + } + + case Instruction::Select: { + assert(VecOpArray.size() == 2); + Value & MaskOp = *VecOpArray[0]; + Value & OnTrueOp = *VecOpArray[1]; + Value & OnFalseOp = *VecOpArray[2]; + + // Fetch the EVL intrinsic + auto & VecTy = cast(*OnTrueOp.getType()); + auto & ScalarTy = *VecTy.getVectorElementType(); + + auto * Func = Intrinsic::getDeclaration(&getModule(), evlDesc.ID, EVLIntrinsic::EncodeTypeTokens(evlDesc.typeTokens, VecTy, ScalarTy)); + + assert((evlDesc.MaskPos == 2) && (evlDesc.EVLPos == 3)); + + // Materialize the Call + ShortValueVec Args{&OnTrueOp, &OnFalseOp, &MaskOp, &GetEVLForType(VecTy)}; + + return Builder.CreateCall(Func, Args); + } + } +} + +VectorType& +EVLBuilder::getVectorType(Type &ElementTy) { + return *VectorType::get(&ElementTy, StaticVectorLength); +} + +Value& +EVLBuilder::CreateContiguousStore(Value & Val, Value & Pointer, unsigned Alignment) { + auto & VecTy = cast(*Val.getType()); + auto * StoreFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::evl_store, {Val.getType(), Pointer.getType()}); + ShortValueVec Args{&Val, &Pointer, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &StoreCall = *Builder.CreateCall(StoreFunc, Args); + if (Alignment) StoreCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment)); + return StoreCall; +} + +Value& +EVLBuilder::CreateContiguousLoad(Value & Pointer, unsigned Alignment) { + auto & PointerTy = cast(*Pointer.getType()); + auto & VecTy = getVectorType(*PointerTy.getPointerElementType()); + + auto * LoadFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::evl_load, {&VecTy, &PointerTy}); + ShortValueVec Args{&Pointer, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &LoadCall= *Builder.CreateCall(LoadFunc, Args); + if (Alignment) LoadCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment)); + return LoadCall; +} + +Value& +EVLBuilder::CreateScatter(Value & Val, Value & PointerVec, unsigned Alignment) { + auto & VecTy = cast(*Val.getType()); + auto * ScatterFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::evl_scatter, {Val.getType(), PointerVec.getType()}); + ShortValueVec Args{&Val, &PointerVec, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &ScatterCall = *Builder.CreateCall(ScatterFunc, Args); + if (Alignment) ScatterCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment)); + return ScatterCall; +} + +Value& +EVLBuilder::CreateGather(Value & 
PointerVec, unsigned Alignment) { + auto & PointerVecTy = cast(*PointerVec.getType()); + auto & ElemTy = *cast(*PointerVecTy.getVectorElementType()).getPointerElementType(); + auto & VecTy = *VectorType::get(&ElemTy, PointerVecTy.getNumElements()); + auto * GatherFunc = Intrinsic::getDeclaration(&getModule(), Intrinsic::evl_gather, {&VecTy, &PointerVecTy}); + + ShortValueVec Args{&PointerVec, &GetMaskForType(VecTy), &GetEVLForType(VecTy)}; + CallInst &GatherCall = *Builder.CreateCall(GatherFunc, Args); + if (Alignment) GatherCall.addParamAttr(1, Attribute::getWithAlignment(getContext(), Alignment)); + return GatherCall; +} + +} // namespace llvm Index: lib/IR/IntrinsicInst.cpp =================================================================== --- lib/IR/IntrinsicInst.cpp +++ lib/IR/IntrinsicInst.cpp @@ -137,6 +137,154 @@ .Default(ebInvalid); } +CmpInst::Predicate +EVLIntrinsic::getCmpPredicate() const { + return static_cast(cast(getArgOperand(4))->getZExtValue()); +} + +bool EVLIntrinsic::isUnaryOp() const { + switch (getIntrinsicID()) { + default: + return false; + case Intrinsic::evl_fneg: + return true; + } +} + +Value* +EVLIntrinsic::getMask() const { + if (isBinaryOp()) { return getArgOperand(2); } + else if (isTernaryOp()) { return getArgOperand(3); } + else if (isUnaryOp()) { return getArgOperand(1); } + else return nullptr; +} + +Value* +EVLIntrinsic::getVectorLength() const { + if (isBinaryOp()) { return getArgOperand(3); } + else if (isTernaryOp()) { return getArgOperand(4); } + else if (isUnaryOp()) { return getArgOperand(2); } + else return nullptr; +} + +bool EVLIntrinsic::isBinaryOp() const { + switch (getIntrinsicID()) { + default: + return false; + + case Intrinsic::evl_and: + case Intrinsic::evl_or: + case Intrinsic::evl_xor: + case Intrinsic::evl_ashr: + case Intrinsic::evl_lshr: + case Intrinsic::evl_shl: + + case Intrinsic::evl_fadd: + case Intrinsic::evl_fsub: + case Intrinsic::evl_fmul: + case Intrinsic::evl_fdiv: + case Intrinsic::evl_frem: + + case Intrinsic::evl_reduce_or: + case Intrinsic::evl_reduce_xor: + case Intrinsic::evl_reduce_add: + case Intrinsic::evl_reduce_mul: + case Intrinsic::evl_reduce_smax: + case Intrinsic::evl_reduce_smin: + case Intrinsic::evl_reduce_umax: + case Intrinsic::evl_reduce_umin: + + case Intrinsic::evl_reduce_fadd: + case Intrinsic::evl_reduce_fmul: + case Intrinsic::evl_reduce_fmax: + case Intrinsic::evl_reduce_fmin: + + case Intrinsic::evl_add: + case Intrinsic::evl_sub: + case Intrinsic::evl_mul: + case Intrinsic::evl_udiv: + case Intrinsic::evl_sdiv: + case Intrinsic::evl_urem: + case Intrinsic::evl_srem: + return true; + } +} + +bool EVLIntrinsic::isTernaryOp() const { + switch (getIntrinsicID()) { + default: + return false; + case Intrinsic::evl_fma: + case Intrinsic::evl_select: + return true; + } +} + +EVLIntrinsic::EVLIntrinsicDesc +EVLIntrinsic::GetEVLIntrinsicDesc(unsigned OC) { + switch (OC) { + // fp unary + case Instruction::FNeg: return EVLIntrinsicDesc{ Intrinsic::evl_fneg, TypeTokenVec{EVLTypeToken::Vector}, 1, 2}; break; + + // fp binary + case Instruction::FAdd: return EVLIntrinsicDesc{ Intrinsic::evl_fadd, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + case Instruction::FSub: return EVLIntrinsicDesc{ Intrinsic::evl_fsub, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + case Instruction::FMul: return EVLIntrinsicDesc{ Intrinsic::evl_fmul, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + case Instruction::FDiv: return EVLIntrinsicDesc{ Intrinsic::evl_fdiv, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; 
break; + case Instruction::FRem: return EVLIntrinsicDesc{ Intrinsic::evl_frem, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + + // sign-oblivious int + case Instruction::Add: return EVLIntrinsicDesc{ Intrinsic::evl_add, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + case Instruction::Sub: return EVLIntrinsicDesc{ Intrinsic::evl_sub, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + case Instruction::Mul: return EVLIntrinsicDesc{ Intrinsic::evl_mul, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + + // signed/unsigned int + case Instruction::SDiv: return EVLIntrinsicDesc{ Intrinsic::evl_sdiv, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + case Instruction::UDiv: return EVLIntrinsicDesc{ Intrinsic::evl_udiv, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + case Instruction::SRem: return EVLIntrinsicDesc{ Intrinsic::evl_srem, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + case Instruction::URem: return EVLIntrinsicDesc{ Intrinsic::evl_urem, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + + // logical + case Instruction::Or: return EVLIntrinsicDesc{ Intrinsic::evl_or, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + case Instruction::And: return EVLIntrinsicDesc{ Intrinsic::evl_and, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + case Instruction::Xor: return EVLIntrinsicDesc{ Intrinsic::evl_xor, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + + case Instruction::LShr: return EVLIntrinsicDesc{ Intrinsic::evl_lshr, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + case Instruction::AShr: return EVLIntrinsicDesc{ Intrinsic::evl_ashr, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + case Instruction::Shl: return EVLIntrinsicDesc{ Intrinsic::evl_shl, TypeTokenVec{EVLTypeToken::Vector}, 2, 3}; break; + + // comparison + case Instruction::ICmp: + case Instruction::FCmp: + return EVLIntrinsicDesc{ Intrinsic::evl_cmp, TypeTokenVec{EVLTypeToken::Mask, EVLTypeToken::Vector}, 2, 3}; break; + + default: + return EVLIntrinsicDesc{Intrinsic::not_intrinsic, TypeTokenVec(), -1, -1}; + } +} + +EVLIntrinsic::ShortTypeVec +EVLIntrinsic::EncodeTypeTokens(EVLIntrinsic::TypeTokenVec TTVec, Type & VectorTy, Type & ScalarTy) { + ShortTypeVec STV; + + for (auto Token : TTVec) { + switch (Token) { + default: + llvm_unreachable("unsupported token"); // unsupported EVLTypeToken + + case EVLIntrinsic::EVLTypeToken::Scalar: STV.push_back(&ScalarTy); break; + case EVLIntrinsic::EVLTypeToken::Vector: STV.push_back(&VectorTy); break; + case EVLIntrinsic::EVLTypeToken::Mask: + auto NumElems = VectorTy.getVectorNumElements(); + auto MaskTy = VectorType::get(Type::getInt1Ty(VectorTy.getContext()), NumElems); + STV.push_back(MaskTy); break; + } + } + + return STV; +} + + bool ConstrainedFPIntrinsic::isUnaryOp() const { switch (getIntrinsicID()) { default: Index: lib/IR/PredicatedInst.cpp =================================================================== --- /dev/null +++ lib/IR/PredicatedInst.cpp @@ -0,0 +1,61 @@ +#include +#include +#include +#include +#include + +namespace { + using namespace llvm; + using ShortValueVec = SmallVector; +} + +namespace llvm { + +void +PredicatedOperator::copyIRFlags(const Value * V, bool IncludeWrapFlags) { + auto * I = dyn_cast(this); + if (I) I->copyIRFlags(V, IncludeWrapFlags); +} + +Instruction* +PredicatedBinaryOperator::Create(Module * Mod, + Value *Mask, Value *VectorLen, + Instruction::BinaryOps Opc, + Value *V1, Value *V2, + const Twine &Name, + BasicBlock * InsertAtEnd, + Instruction * InsertBefore) { + assert(!(InsertAtEnd && 
InsertBefore)); + + auto evlDesc = EVLIntrinsic::GetEVLIntrinsicDesc(Opc); + + if ((!Mod || + (!Mask && !VectorLen)) || + evlDesc.ID == Intrinsic::not_intrinsic) { + if (InsertAtEnd) { + return BinaryOperator::Create(Opc, V1, V2, Name, InsertAtEnd); + } else { + return BinaryOperator::Create(Opc, V1, V2, Name, InsertBefore); + } + } + + assert(Mod && "Need a module to emit EVL Intrinsics"); + + // Fetch the EVL intrinsic + auto & VecTy = cast(*V1->getType()); + auto & ScalarTy = *VecTy.getVectorElementType(); + auto * Func = Intrinsic::getDeclaration(Mod, evlDesc.ID, EVLIntrinsic::EncodeTypeTokens(evlDesc.typeTokens, VecTy, ScalarTy)); + + assert((evlDesc.MaskPos == 2) && (evlDesc.EVLPos == 3)); + + // Materialize the Call + ShortValueVec Args{V1, V2, Mask, VectorLen}; + + if (InsertAtEnd) { + return CallInst::Create(Func, {V1, V2, Mask, VectorLen}, Name, InsertAtEnd); + } else { + return CallInst::Create(Func, {V1, V2, Mask, VectorLen}, Name, InsertBefore); + } +} + +} Index: lib/IR/Verifier.cpp =================================================================== --- lib/IR/Verifier.cpp +++ lib/IR/Verifier.cpp @@ -1652,11 +1652,14 @@ if (Attrs.isEmpty()) return; + bool SawMask = false; bool SawNest = false; + bool SawPassthru = false; bool SawReturned = false; bool SawSRet = false; bool SawSwiftSelf = false; bool SawSwiftError = false; + bool SawVectorLength = false; // Verify return value attributes. AttributeSet RetAttrs = Attrs.getRetAttributes(); @@ -1719,12 +1722,33 @@ SawSwiftError = true; } + if (ArgAttrs.hasAttribute(Attribute::VectorLength)) { + Assert(!SawVectorLength, "Cannot have multiple 'vlen' parameters!", + V); + SawVectorLength = true; + } + + if (ArgAttrs.hasAttribute(Attribute::Passthru)) { + Assert(!SawPassthru, "Cannot have multiple 'passthru' parameters!", + V); + SawPassthru = true; + } + + if (ArgAttrs.hasAttribute(Attribute::Mask)) { + Assert(!SawMask, "Cannot have multiple 'mask' parameters!", + V); + SawMask = true; + } + if (ArgAttrs.hasAttribute(Attribute::InAlloca)) { Assert(i == FT->getNumParams() - 1, "inalloca isn't on the last parameter!", V); } } + Assert(!SawPassthru || SawMask, + "Cannot have 'passthru' parameter without 'mask' parameter!", V); + if (!Attrs.hasAttributes(AttributeList::FunctionIndex)) return; @@ -3041,7 +3065,7 @@ /// visitUnaryOperator - Check the argument to the unary operator. /// void Verifier::visitUnaryOperator(UnaryOperator &U) { - Assert(U.getType() == U.getOperand(0)->getType(), + Assert(U.getType() == U.getOperand(0)->getType(), "Unary operators must have same type for" "operands and result!", &U); @@ -4870,7 +4894,7 @@ bool runOnFunction(Function &F) override { if (!V->verify(F) && FatalErrors) { - errs() << "in function " << F.getName() << '\n'; + errs() << "in function " << F.getName() << '\n'; report_fatal_error("Broken function found, compilation aborted!"); } return false; Index: lib/Transforms/InstCombine/InstCombineAddSub.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -24,6 +24,9 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/PredicatedInst.h" +#include "llvm/IR/EVLBuilder.h" +#include "llvm/IR/MatcherCast.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Support/AlignOf.h" @@ -1744,6 +1747,17 @@ return Changed ? 
&I : nullptr; } +Instruction *InstCombiner::visitPredicatedFSub(PredicatedBinaryOperator& I) { + auto * Inst = cast(&I); + PredicatedContext PC(&I); + if (Value *V = SimplifyPredicatedFSubInst(I.getOperand(0), I.getOperand(1), + I.getFastMathFlags(), + SQ.getWithInstruction(Inst), PC)) + return replaceInstUsesWith(*Inst, V); + + return visitFSubGeneric(*Inst); +} + Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (Value *V = SimplifyFSubInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), @@ -1753,11 +1767,19 @@ if (Instruction *X = foldVectorBinop(I)) return X; + return visitFSubGeneric(I); +} + +template +Instruction *InstCombiner::visitFSubGeneric(BinaryOpTy &I) { + MatchContextType MC(cast(&I)); + MatchContextBuilder MCBuilder(MC); + // Subtraction from -0.0 is the canonical form of fneg. // fsub nsz 0, X ==> fsub nsz -0.0, X Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (I.hasNoSignedZeros() && match(Op0, m_PosZeroFP())) - return BinaryOperator::CreateFNegFMF(Op1, &I); + if (I.hasNoSignedZeros() && MC.try_match(Op0, m_PosZeroFP())) + return MCBuilder.CreateFNegFMF(Op1, &I); Value *X, *Y; Constant *C; @@ -1765,14 +1787,14 @@ // Fold negation into constant operand. This is limited with one-use because // fneg is assumed better for analysis and cheaper in codegen than fmul/fdiv. // -(X * C) --> X * (-C) - if (match(&I, m_FNeg(m_OneUse(m_FMul(m_Value(X), m_Constant(C)))))) - return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FMul(m_Value(X), m_Constant(C)))))) + return MCBuilder.CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); // -(X / C) --> X / (-C) - if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Value(X), m_Constant(C)))))) - return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FDiv(m_Value(X), m_Constant(C)))))) + return MCBuilder.CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); // -(C / X) --> (-C) / X - if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X)))))) - return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + if (MC.try_match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X)))))) + return MCBuilder.CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); // If Op0 is not -0.0 or we can ignore -0.0: Z - (X - Y) --> Z + (Y - X) // Canonicalize to fadd to make analysis easier. @@ -1781,71 +1803,75 @@ // killed later. We still limit that particular transform with 'hasOneUse' // because an fneg is assumed better/cheaper than a generic fsub. if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) { - if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { - Value *NewSub = Builder.CreateFSubFMF(Y, X, &I); - return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I); + if (MC.try_match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { + Value *NewSub = MCBuilder.CreateFSubFMF(Builder, Y, X, &I); + return MCBuilder.CreateFAddFMF(Op0, NewSub, &I); } } - if (isa(Op0)) - if (SelectInst *SI = dyn_cast(Op1)) - if (Instruction *NV = FoldOpIntoSelect(I, SI)) - return NV; + if (auto * PlainBinOp = dyn_cast(&I)) + if (isa(Op0)) + if (SelectInst *SI = dyn_cast(Op1)) + if (Instruction *NV = FoldOpIntoSelect(*PlainBinOp, SI)) + return NV; // X - C --> X + (-C) // But don't transform constant expressions because there's an inverse fold // for X + (-Y) --> X - Y. 
- if (match(Op1, m_Constant(C)) && !isa(Op1)) - return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); + if (MC.try_match(Op1, m_Constant(C)) && !isa(Op1)) + return MCBuilder.CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); // X - (-Y) --> X + Y - if (match(Op1, m_FNeg(m_Value(Y)))) - return BinaryOperator::CreateFAddFMF(Op0, Y, &I); + if (MC.try_match(Op1, m_FNeg(m_Value(Y)))) + return MCBuilder.CreateFAddFMF(Op0, Y, &I); // Similar to above, but look through a cast of the negated value: // X - (fptrunc(-Y)) --> X + fptrunc(Y) Type *Ty = I.getType(); - if (match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) - return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPTrunc(Y, Ty), &I); + if (MC.try_match(Op1, m_OneUse(m_FPTrunc(m_FNeg(m_Value(Y)))))) + return MCBuilder.CreateFAddFMF(Op0, MCBuilder.CreateFPTrunc(Builder, Y, Ty), &I); // X - (fpext(-Y)) --> X + fpext(Y) - if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) - return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I); + if (MC.try_match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) + return MCBuilder.CreateFAddFMF(Op0, MCBuilder.CreateFPExt(Builder, Y, Ty), &I); // Handle special cases for FSub with selects feeding the operation - if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) - return replaceInstUsesWith(I, V); + if (auto * PlainBinOp = dyn_cast(&I)) + if (Value *V = SimplifySelectsFeedingBinaryOp(*PlainBinOp, Op0, Op1)) + return replaceInstUsesWith(I, V); if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { // (Y - X) - Y --> -X - if (match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) - return BinaryOperator::CreateFNegFMF(X, &I); + if (MC.try_match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) + return MCBuilder.CreateFNegFMF(X, &I); // Y - (X + Y) --> -X // Y - (Y + X) --> -X - if (match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) - return BinaryOperator::CreateFNegFMF(X, &I); + if (MC.try_match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) + return MCBuilder.CreateFNegFMF(X, &I); // (X * C) - X --> X * (C - 1.0) - if (match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { + if (MC.try_match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { Constant *CSubOne = ConstantExpr::getFSub(C, ConstantFP::get(Ty, 1.0)); - return BinaryOperator::CreateFMulFMF(Op1, CSubOne, &I); + return MCBuilder.CreateFMulFMF(Op1, CSubOne, &I); } // X - (X * C) --> X * (1.0 - C) - if (match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) { + if (MC.try_match(Op1, m_FMul(m_Specific(Op0), m_Constant(C)))) { Constant *OneSubC = ConstantExpr::getFSub(ConstantFP::get(Ty, 1.0), C); - return BinaryOperator::CreateFMulFMF(Op0, OneSubC, &I); + return MCBuilder.CreateFMulFMF(Op0, OneSubC, &I); } - if (Instruction *F = factorizeFAddFSub(I, Builder)) - return F; + if (auto * PlainBinOp = dyn_cast(&I)) { + if (Instruction *F = factorizeFAddFSub(*PlainBinOp, Builder)) + return F; - // TODO: This performs reassociative folds for FP ops. Some fraction of the - // functionality has been subsumed by simple pattern matching here and in - // InstSimplify. We should let a dedicated reassociation pass handle more - // complex pattern matching and remove this from InstCombine. - if (Value *V = FAddCombine(Builder).simplify(&I)) - return replaceInstUsesWith(I, V); + // TODO: This performs reassociative folds for FP ops. Some fraction of the + // functionality has been subsumed by simple pattern matching here and in + // InstSimplify. 
We should let a dedicated reassociation pass handle more + // complex pattern matching and remove this from InstCombine. + if (Value *V = FAddCombine(Builder).simplify(PlainBinOp)) + return replaceInstUsesWith(*PlainBinOp, V); + } } return nullptr; Index: lib/Transforms/InstCombine/InstCombineCalls.cpp =================================================================== --- lib/Transforms/InstCombine/InstCombineCalls.cpp +++ lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -37,6 +37,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" @@ -1804,6 +1805,14 @@ return &CI; } + // Predicated instruction patterns + auto * EVLInst = dyn_cast(&CI); + if (EVLInst) { + auto * PredInst = cast(EVLInst); + auto Result = visitPredicatedInstruction(PredInst); + if (Result) return Result; + } + IntrinsicInst *II = dyn_cast(&CI); if (!II) return visitCallSite(&CI); Index: lib/Transforms/InstCombine/InstCombineInternal.h =================================================================== --- lib/Transforms/InstCombine/InstCombineInternal.h +++ lib/Transforms/InstCombine/InstCombineInternal.h @@ -30,6 +30,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PredicatedInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Use.h" @@ -348,6 +349,8 @@ Instruction *visitFAdd(BinaryOperator &I); Value *OptimizePointerDifference(Value *LHS, Value *RHS, Type *Ty); Instruction *visitSub(BinaryOperator &I); + template Instruction *visitFSubGeneric(BinaryOpTy &I); + Instruction *visitPredicatedFSub(PredicatedBinaryOperator &I); Instruction *visitFSub(BinaryOperator &I); Instruction *visitMul(BinaryOperator &I); Instruction *visitFMul(BinaryOperator &I); @@ -415,6 +418,16 @@ Instruction *visitVAStartInst(VAStartInst &I); Instruction *visitVACopyInst(VACopyInst &I); + // Entry point to EVLIntrinsic + Instruction *visitPredicatedInstruction(PredicatedInstruction * PI) { + switch (PI->getOpcode()) { + default: + return nullptr; + case Instruction::FSub: + return visitPredicatedFSub(cast(*PI)); + } + } + /// Specify what to return for unhandled instructions. 
Instruction *visitInstruction(Instruction &I) { return nullptr; } Index: lib/Transforms/Utils/CodeExtractor.cpp =================================================================== --- lib/Transforms/Utils/CodeExtractor.cpp +++ lib/Transforms/Utils/CodeExtractor.cpp @@ -773,6 +773,7 @@ case Attribute::InaccessibleMemOnly: case Attribute::InaccessibleMemOrArgMemOnly: case Attribute::JumpTable: + case Attribute::Mask: case Attribute::Naked: case Attribute::Nest: case Attribute::NoAlias: @@ -781,6 +782,7 @@ case Attribute::NoReturn: case Attribute::None: case Attribute::NonNull: + case Attribute::Passthru: case Attribute::ReadNone: case Attribute::ReadOnly: case Attribute::Returned: @@ -791,6 +793,7 @@ case Attribute::StructRet: case Attribute::SwiftError: case Attribute::SwiftSelf: + case Attribute::VectorLength: case Attribute::WriteOnly: case Attribute::ZExt: case Attribute::EndAttrKinds: Index: test/Bitcode/attributes.ll =================================================================== --- test/Bitcode/attributes.ll +++ test/Bitcode/attributes.ll @@ -351,6 +351,11 @@ ret void } +; CHECK: define <8 x double> @f60(<8 x double> passthru, <8 x i1> mask, i32 vlen) { +define <8 x double> @f60(<8 x double> passthru, <8 x i1> mask, i32 vlen) { + ret <8 x double> undef +} + ; CHECK: attributes #0 = { noreturn } ; CHECK: attributes #1 = { nounwind } ; CHECK: attributes #2 = { readnone } Index: test/Transforms/InstCombine/evl-fsub.ll =================================================================== --- /dev/null +++ test/Transforms/InstCombine/evl-fsub.ll @@ -0,0 +1,43 @@ +; RUN: opt < %s -instcombine -S | FileCheck %s + +; PR4374 + +define <4 x float> @test1_evl(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) { +; CHECK-LABEL: @test1_evl( +; + %t1 = call <4 x float> @llvm.evl.fsub.v4f32(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) + %t2 = call <4 x float> @llvm.evl.fsub.v4f32(<4 x float> , <4 x float> %t1, <4 x i1> %M, i32 %L) + ret <4 x float> %t2 +} + +; Can't do anything with the test above because -0.0 - 0.0 = -0.0, but if we have nsz: +; -(X - Y) --> Y - X + +; TODO predicated FAdd folding +define <4 x float> @neg_sub_nsz_evl(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) { +; CH***-LABEL: @neg_sub_nsz_evl( +; + %t1 = call <4 x float> @llvm.evl.fsub.v4f32(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) + %t2 = call nsz <4 x float> @llvm.evl.fsub.v4f32(<4 x float> , <4 x float> %t1, <4 x i1> %M, i32 %L) + ret <4 x float> %t2 +} + +; With nsz: Z - (X - Y) --> Z + (Y - X) + +define <4 x float> @sub_sub_nsz_evl(<4 x float> %x, <4 x float> %y, <4 x float> %z, <4 x i1> %M, i32 %L) { +; CHECK-LABEL: @sub_sub_nsz_evl( +; CHECK-NEXT: %1 = call nsz <4 x float> @llvm.evl.fsub.v4f32(<4 x float> %y, <4 x float> %x, <4 x i1> %M, i32 %L) +; CHECK-NEXT: %t2 = call nsz <4 x float> @llvm.evl.fadd.v4f32(<4 x float> %z, <4 x float> %1, <4 x i1> %M, i32 %L) +; CHECK-NEXT: ret <4 x float> %t2 + %t1 = call <4 x float> @llvm.evl.fsub.v4f32(<4 x float> %x, <4 x float> %y, <4 x i1> %M, i32 %L) + %t2 = call nsz <4 x float> @llvm.evl.fsub.v4f32(<4 x float> %z, <4 x float> %t1, <4 x i1> %M, i32 %L) + ret <4 x float> %t2 +} + + + +; Function Attrs: nounwind readnone +declare <4 x float> @llvm.evl.fadd.v4f32(<4 x float>, <4 x float>, <4 x i1> mask, i32 vlen) #0 + +; Function Attrs: nounwind readnone +declare <4 x float> @llvm.evl.fsub.v4f32(<4 x float>, <4 x float>, <4 x i1> mask, i32 vlen) #0 Index: test/Transforms/InstSimplify/evl-fsub.ll 
===================================================================
--- /dev/null
+++ test/Transforms/InstSimplify/evl-fsub.ll
@@ -0,0 +1,43 @@
+; RUN: opt < %s -instsimplify -S | FileCheck %s
+
+define <8 x double> @fsub_fadd_fold_evl_xy(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) {
+; CHECK-LABEL: fsub_fadd_fold_evl_xy
+; CHECK-NEXT: ret <8 x double> %x
+  %tmp = call reassoc nsz <8 x double> @llvm.evl.fadd.v8f64(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len)
+  %res = call reassoc nsz <8 x double> @llvm.evl.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %m, i32 %len)
+  ret <8 x double> %res
+}
+
+define <8 x double> @fsub_fadd_fold_evl_yx(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len) {
+; CHECK-LABEL: fsub_fadd_fold_evl_yx
+; CHECK-NEXT: ret <8 x double> %x
+  %tmp = call reassoc nsz <8 x double> @llvm.evl.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %len)
+  %res = call reassoc nsz <8 x double> @llvm.evl.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %m, i32 %len)
+  ret <8 x double> %res
+}
+
+define <8 x double> @fsub_fadd_fold_evl_yx_olen(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len, i32 %otherLen) {
+; CHECK-LABEL: fsub_fadd_fold_evl_yx_olen
+; CHECK-NEXT: %tmp = call reassoc nsz <8 x double> @llvm.evl.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %otherLen)
+; CHECK-NEXT: %res = call reassoc nsz <8 x double> @llvm.evl.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %m, i32 %len)
+; CHECK-NEXT: ret <8 x double> %res
+  %tmp = call reassoc nsz <8 x double> @llvm.evl.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %otherLen)
+  %res = call reassoc nsz <8 x double> @llvm.evl.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %m, i32 %len)
+  ret <8 x double> %res
+}
+
+define <8 x double> @fsub_fadd_fold_evl_yx_omask(<8 x double> %x, <8 x double> %y, <8 x i1> %m, i32 %len, <8 x i1> %othermask) {
+; CHECK-LABEL: fsub_fadd_fold_evl_yx_omask
+; CHECK-NEXT: %tmp = call reassoc nsz <8 x double> @llvm.evl.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %len)
+; CHECK-NEXT: %res = call reassoc nsz <8 x double> @llvm.evl.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %othermask, i32 %len)
+; CHECK-NEXT: ret <8 x double> %res
+  %tmp = call reassoc nsz <8 x double> @llvm.evl.fadd.v8f64(<8 x double> %y, <8 x double> %x, <8 x i1> %m, i32 %len)
+  %res = call reassoc nsz <8 x double> @llvm.evl.fsub.v8f64(<8 x double> %tmp, <8 x double> %y, <8 x i1> %othermask, i32 %len)
+  ret <8 x double> %res
+}
+
+; Function Attrs: nounwind readnone
+declare <8 x double> @llvm.evl.fadd.v8f64(<8 x double>, <8 x double>, <8 x i1> mask, i32 vlen) #0
+
+; Function Attrs: nounwind readnone
+declare <8 x double> @llvm.evl.fsub.v8f64(<8 x double>, <8 x double>, <8 x i1> mask, i32 vlen) #0
Index: test/Verifier/evl_attribs.ll
===================================================================
--- /dev/null
+++ test/Verifier/evl_attribs.ll
@@ -0,0 +1,13 @@
+; RUN: not llvm-as %s -o /dev/null 2>&1 | FileCheck %s
+
+declare void @a(<16 x i1> mask %a, <16 x i1> mask %b)
+; CHECK: Cannot have multiple 'mask' parameters!
+
+declare void @b(<16 x i1> mask %a, i32 vlen %x, i32 vlen %y)
+; CHECK: Cannot have multiple 'vlen' parameters!
+
+declare <16 x double> @c(<16 x double> passthru %a)
+; CHECK: Cannot have 'passthru' parameter without 'mask' parameter!
+ +declare <16 x double> @d(<16 x double> passthru %a, <16 x i1> mask %M, <16 x double> passthru %b) +; CHECK: Cannot have multiple 'passthru' parameters! Index: utils/TableGen/CodeGenIntrinsics.h =================================================================== --- utils/TableGen/CodeGenIntrinsics.h +++ utils/TableGen/CodeGenIntrinsics.h @@ -136,7 +136,7 @@ // True if the intrinsic is marked as speculatable. bool isSpeculatable; - enum ArgAttribute { NoCapture, Returned, ReadOnly, WriteOnly, ReadNone }; + enum ArgAttribute { Mask, NoCapture, Passthru, Returned, ReadOnly, WriteOnly, ReadNone, VectorLength }; std::vector> ArgumentAttributes; bool hasProperty(enum SDNP Prop) const { Index: utils/TableGen/CodeGenTarget.cpp =================================================================== --- utils/TableGen/CodeGenTarget.cpp +++ utils/TableGen/CodeGenTarget.cpp @@ -599,10 +599,10 @@ "Expected iAny or vAny type"); } else { VT = getValueType(TyEl->getValueAsDef("VT")); - } - if (MVT(VT).isOverloaded()) { - OverloadedVTs.push_back(VT); - isOverloaded = true; + if (MVT(VT).isOverloaded()) { + OverloadedVTs.push_back(VT); + isOverloaded = true; + } } // Reject invalid types. @@ -636,14 +636,15 @@ !TyEl->isSubClassOf("LLVMScalarOrSameVectorWidth")) || VT == MVT::iAny || VT == MVT::vAny) && "Expected iAny or vAny type"); - } else + } else { VT = getValueType(TyEl->getValueAsDef("VT")); - - if (MVT(VT).isOverloaded()) { - OverloadedVTs.push_back(VT); - isOverloaded = true; + if (MVT(VT).isOverloaded()) { + OverloadedVTs.push_back(VT); + isOverloaded = true; + } } + // Reject invalid types. if (VT == MVT::isVoid && i != e-1 /*void at end means varargs*/) PrintFatalError("Intrinsic '" + DefName + " has void in result type list!"); @@ -694,6 +695,15 @@ } else if (Property->isSubClassOf("Returned")) { unsigned ArgNo = Property->getValueAsInt("ArgNo"); ArgumentAttributes.push_back(std::make_pair(ArgNo, Returned)); + } else if (Property->isSubClassOf("VectorLength")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, VectorLength)); + } else if (Property->isSubClassOf("Mask")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, Mask)); + } else if (Property->isSubClassOf("Passthru")) { + unsigned ArgNo = Property->getValueAsInt("ArgNo"); + ArgumentAttributes.push_back(std::make_pair(ArgNo, Passthru)); } else if (Property->isSubClassOf("ReadOnly")) { unsigned ArgNo = Property->getValueAsInt("ArgNo"); ArgumentAttributes.push_back(std::make_pair(ArgNo, ReadOnly)); Index: utils/TableGen/IntrinsicEmitter.cpp =================================================================== --- utils/TableGen/IntrinsicEmitter.cpp +++ utils/TableGen/IntrinsicEmitter.cpp @@ -594,6 +594,24 @@ OS << "Attribute::Returned"; addComma = true; break; + case CodeGenIntrinsic::VectorLength: + if (addComma) + OS << ","; + OS << "Attribute::VectorLength"; + addComma = true; + break; + case CodeGenIntrinsic::Mask: + if (addComma) + OS << ","; + OS << "Attribute::Mask"; + addComma = true; + break; + case CodeGenIntrinsic::Passthru: + if (addComma) + OS << ","; + OS << "Attribute::Passthru"; + addComma = true; + break; case CodeGenIntrinsic::ReadOnly: if (addComma) OS << ",";