diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -1071,6 +1071,12 @@
     return false;
   }
 
+  /// Return how a target-specific node (opcode >= ISD::BUILTIN_OP_END) should
+  /// be legalized. The default treats all such nodes as Legal.
+  virtual LegalizeAction getCustomOperationAction(SDNode &Op) const {
+    return Legal;
+  }
+
   /// Return how this operation should be treated: either it is legal, needs to
   /// be promoted to a larger size, needs to be expanded to some other code
   /// sequence, or the target has a custom expander for it.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
@@ -1212,7 +1212,7 @@
     break;
   default:
     if (Node->getOpcode() >= ISD::BUILTIN_OP_END) {
-      Action = TargetLowering::Legal;
+      Action = TLI.getCustomOperationAction(*Node);
     } else {
       Action = TLI.getOperationAction(Node->getOpcode(), Node->getValueType(0));
     }
diff --git a/llvm/lib/Target/VE/VECustomDAG.h b/llvm/lib/Target/VE/VECustomDAG.h
--- a/llvm/lib/Target/VE/VECustomDAG.h
+++ b/llvm/lib/Target/VE/VECustomDAG.h
@@ -27,6 +27,27 @@
 
 bool isPackedVectorType(EVT SomeVT);
 
+// Return the AVL operand position of this VP, VVP or VEC opcode, if any.
+Optional<unsigned> getAVLPos(unsigned);
+
+// Whether this AVL is a LEGALAVL annotation node.
+bool isLegalAVL(SDValue AVL);
+
+// Return the AVL operand of this node, or an empty SDValue if it has none.
+SDValue getNodeAVL(SDValue);
+
+// Whether this opcode is a VEC_* or VVP_* node.
+bool isVVPOrVEC(unsigned);
+
+// Whether the mask of this operation may be ignored (false for divisions and
+// VVP_SELECT).
+bool maySafelyIgnoreMask(SDValue Op);
+
+// Return the unwrapped AVL of this operation and whether that AVL is legal:
+// true if it was unwrapped from a LEGALAVL, or if the operation has no AVL at
+// all (in which case the returned SDValue is empty).
+std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue);
+
 class VECustomDAG {
   SelectionDAG &DAG;
   SDLoc DL;
@@ -72,6 +93,9 @@
                       bool IsOpaque = false) const;
 
   SDValue getBroadcast(EVT ResultVT, SDValue Scalar, SDValue AVL) const;
+
+  // Wrap AVL in a LEGALAVL node (unless it is one already).
+  SDValue annotateLegalAVL(SDValue AVL) const;
 };
 
 } // namespace llvm
diff --git a/llvm/lib/Target/VE/VECustomDAG.cpp b/llvm/lib/Target/VE/VECustomDAG.cpp
--- a/llvm/lib/Target/VE/VECustomDAG.cpp
+++ b/llvm/lib/Target/VE/VECustomDAG.cpp
@@ -42,6 +42,32 @@
   return None;
 }
 
+bool maySafelyIgnoreMask(SDValue Op) {
+  auto VVPOpc = getVVPOpcode(Op->getOpcode());
+  auto Opc = VVPOpc ? *VVPOpc : Op->getOpcode();
+
+  switch (Opc) {
+  case VEISD::VVP_SDIV:
+  case VEISD::VVP_UDIV:
+  case VEISD::VVP_FDIV:
+  case VEISD::VVP_SELECT:
+    return false;
+
+  default:
+    return true;
+  }
+}
+
+bool isVVPOrVEC(unsigned Opcode) {
+  switch (Opcode) {
+  case VEISD::VEC_BROADCAST:
+#define ADD_VVP_OP(VVPNAME, ...) case VEISD::VVPNAME:
+#include "VVPNodes.def"
+    return true;
+  }
+  return false;
+}
+
 bool isVVPBinaryOp(unsigned VVPOpcode) {
   switch (VVPOpcode) {
 #define ADD_BINARY_VVP_OP(VVPNAME, ...)                                        \
@@ -52,6 +78,44 @@
   return false;
 }
 
+// Return the AVL operand position for this VVP or VEC Op.
+Optional<unsigned> getAVLPos(unsigned Opc) {
+  // This is only available for VP SDNodes
+  auto PosOpt = ISD::getVPExplicitVectorLengthIdx(Opc);
+  if (PosOpt)
+    return *PosOpt;
+
+  // VVP Opcodes.
+  if (isVVPBinaryOp(Opc))
+    return 3;
+
+  // Other VEC and VVP Opcodes.
+  switch (Opc) {
+  case VEISD::VEC_BROADCAST:
+    return 1;
+  case VEISD::VVP_SELECT:
+    return 3;
+  }
+
+  return None;
+}
+
+bool isLegalAVL(SDValue AVL) { return AVL->getOpcode() == VEISD::LEGALAVL; }
+
+SDValue getNodeAVL(SDValue Op) {
+  auto PosOpt = getAVLPos(Op->getOpcode());
+  return PosOpt ? Op->getOperand(*PosOpt) : SDValue();
+}
+
+std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue Op) {
+  SDValue AVL = getNodeAVL(Op);
+  if (!AVL)
+    return {SDValue(), true};
+  if (isLegalAVL(AVL))
+    return {AVL->getOperand(0), true};
+  return {AVL, false};
+}
+
 SDValue VECustomDAG::getConstant(uint64_t Val, EVT VT, bool IsTarget,
                                  bool IsOpaque) const {
   return DAG.getConstant(Val, DL, VT, IsTarget, IsOpaque);
@@ -78,4 +142,10 @@
   return getNode(VEISD::VEC_BROADCAST, ResultVT, {Scalar, AVL});
 }
 
+SDValue VECustomDAG::annotateLegalAVL(SDValue AVL) const {
+  if (isLegalAVL(AVL))
+    return AVL;
+  return getNode(VEISD::LEGALAVL, AVL.getValueType(), AVL);
+}
+
 } // namespace llvm
diff --git a/llvm/lib/Target/VE/VEISelDAGToDAG.cpp b/llvm/lib/Target/VE/VEISelDAGToDAG.cpp
--- a/llvm/lib/Target/VE/VEISelDAGToDAG.cpp
+++ b/llvm/lib/Target/VE/VEISelDAGToDAG.cpp
@@ -335,6 +335,12 @@
   }
 
   switch (N->getOpcode()) {
+  // Late eliminate the LEGALAVL wrapper: forward its operand to all users.
+  case VEISD::LEGALAVL:
+    ReplaceUses(SDValue(N, 0), N->getOperand(0));
+    CurDAG->RemoveDeadNode(N);
+    return;
+
   case VEISD::GLOBAL_BASE_REG:
     ReplaceNode(N, getGlobalBaseReg());
     return;
diff --git a/llvm/lib/Target/VE/VEISelLowering.h b/llvm/lib/Target/VE/VEISelLowering.h
--- a/llvm/lib/Target/VE/VEISelLowering.h
+++ b/llvm/lib/Target/VE/VEISelLowering.h
@@ -43,12 +43,19 @@
   REPL_I32,
   REPL_F32, // Replicate subregister to other half.
 
+  // Annotation as a wrapper. LEGALAVL(VL) means that VL refers to 64bit of
+  // data, whereas the raw EVL coming in from VP nodes always refers to number
+  // of elements, regardless of their size.
+  LEGALAVL,
+
   // VVP_* nodes.
 #define ADD_VVP_OP(VVP_NAME, ...) VVP_NAME,
 #include "VVPNodes.def"
 };
 }
 
+class VECustomDAG;
+
 class VETargetLowering : public TargetLowering {
   const VESubtarget *Subtarget;
 
@@ -105,6 +112,9 @@
   }
 
   /// Custom Lower {
+  TargetLoweringBase::LegalizeAction
+  getCustomOperationAction(SDNode &) const override;
+
   SDValue LowerOperation(SDValue Op, SelectionDAG &DAG) const override;
   unsigned getJumpTableEncoding() const override;
   const MCExpr *LowerCustomJumpTableEntry(const MachineJumpTableInfo *MJTI,
@@ -170,6 +180,8 @@
 
   /// VVP Lowering {
   SDValue lowerToVVP(SDValue Op, SelectionDAG &DAG) const;
+  SDValue legalizeInternalVectorOp(SDValue Op, SelectionDAG &DAG) const;
+  SDValue legalizePackedAVL(SDValue Op, VECustomDAG &CDAG) const;
   /// } VVPLowering
 
   /// Custom DAGCombine {
diff --git a/llvm/lib/Target/VE/VEISelLowering.cpp b/llvm/lib/Target/VE/VEISelLowering.cpp
--- a/llvm/lib/Target/VE/VEISelLowering.cpp
+++ b/llvm/lib/Target/VE/VEISelLowering.cpp
@@ -902,6 +902,8 @@
   TARGET_NODE_CASE(REPL_I32)
   TARGET_NODE_CASE(REPL_F32)
 
+  TARGET_NODE_CASE(LEGALAVL)
+
   // Register the VVP_* SDNodes.
 #define ADD_VVP_OP(VVP_NAME, ...) TARGET_NODE_CASE(VVP_NAME)
 #include "VVPNodes.def"
@@ -1658,10 +1660,8 @@
   // Else emit a broadcast.
   if (SDValue ScalarV = getSplatValue(Op.getNode())) {
     unsigned NumEls = ResultVT.getVectorNumElements();
-    // TODO: Legalize packed-mode AVL.
-    //       For now, cap the AVL at 256.
-    auto CappedLength = std::min<unsigned>(256, NumEls);
-    auto AVL = CDAG.getConstant(CappedLength, MVT::i32);
-    return CDAG.getBroadcast(ResultVT, Op.getOperand(0), AVL);
+    auto AVL = CDAG.getConstant(NumEls, MVT::i32);
+    // Broadcast the splat value; operand 0 itself may be an undef element.
+    return CDAG.getBroadcast(ResultVT, ScalarV, AVL);
   }
 
@@ -1669,7 +1669,71 @@
   return SDValue();
 }
 
+SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,
+                                                   SelectionDAG &DAG) const {
+  VECustomDAG CDAG(DAG, Op);
+  // TODO: Implement odd/even splitting.
+  return legalizePackedAVL(Op, CDAG);
+}
+
+SDValue VETargetLowering::legalizePackedAVL(SDValue Op,
+                                            VECustomDAG &CDAG) const {
+  LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);
+  // Only required for VEC and VVP ops.
+  if (!isVVPOrVEC(Op->getOpcode()))
+    return Op;
+
+  // Nothing to do if the operation has no AVL or its AVL is already legal.
+  auto AVL = getNodeAVL(Op);
+  if (!AVL || isLegalAVL(AVL))
+    return Op;
+
+  // Halve and round up the AVL for packed 32bit element types.
+  SDValue LegalAVL = AVL;
+  if (isPackedVectorType(Op.getValueType())) {
+    assert(maySafelyIgnoreMask(Op) &&
+           "TODO Shift predication from EVL into Mask");
+
+    if (auto *ConstAVL = dyn_cast<ConstantSDNode>(AVL)) {
+      LegalAVL = CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32);
+    } else {
+      auto ConstOne = CDAG.getConstant(1, MVT::i32);
+      auto PlusOne = CDAG.getNode(ISD::ADD, MVT::i32, {AVL, ConstOne});
+      LegalAVL = CDAG.getNode(ISD::SRL, MVT::i32, {PlusOne, ConstOne});
+    }
+  }
+
+  SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(LegalAVL);
+
+  // Copy the operand list, substituting the annotated AVL.
+  unsigned NumOp = Op->getNumOperands();
+  auto AVLPos = getAVLPos(Op->getOpcode());
+  std::vector<SDValue> FixedOperands;
+  for (unsigned i = 0; i < NumOp; ++i) {
+    if (AVLPos && (i == *AVLPos)) {
+      FixedOperands.push_back(AnnotatedLegalAVL);
+      continue;
+    }
+    FixedOperands.push_back(Op->getOperand(i));
+  }
+
+  // Clone the operation with fixed operands.
+  auto Flags = Op->getFlags();
+  SDValue NewN =
+      CDAG.getNode(Op->getOpcode(), Op->getVTList(), FixedOperands, Flags);
+  return NewN;
+}
+
+TargetLowering::LegalizeAction
+VETargetLowering::getCustomOperationAction(SDNode &Op) const {
+  // Custom lower to legalize AVL for packed mode.
+  if (isVVPOrVEC(Op.getOpcode()))
+    return Custom;
+  return Legal;
+}
+
 SDValue VETargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
+  LLVM_DEBUG(dbgs() << "::LowerOperation "; Op->print(dbgs()); dbgs() << "\n";);
   unsigned Opcode = Op.getOpcode();
   if (ISD::isVPOpcode(Opcode))
     return lowerToVVP(Op, DAG);
@@ -1721,6 +1785,16 @@
   case ISD::EXTRACT_VECTOR_ELT:
     return lowerEXTRACT_VECTOR_ELT(Op, DAG);
 
+  // Legalize the AVL of this internal node.
+  case VEISD::VEC_BROADCAST:
+#define ADD_VVP_OP(VVP_NAME, ...) case VEISD::VVP_NAME:
+#include "VVPNodes.def"
+    // AVL already legalized.
+    if (getAnnotatedNodeAVL(Op).second)
+      return Op;
+    return legalizeInternalVectorOp(Op, DAG);
+
+    // Translate into a VEC_*/VVP_* layer operation.
 #define ADD_VVP_OP(VVP_NAME, ISD_NAME) case ISD::ISD_NAME:
 #include "VVPNodes.def"
     return lowerToVVP(Op, DAG);