diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -13,6 +13,7 @@
 #include "RISCVISelDAGToDAG.h"
 #include "MCTargetDesc/RISCVMCTargetDesc.h"
 #include "MCTargetDesc/RISCVMatInt.h"
+#include "RISCVISelLowering.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/IR/IntrinsicsRISCV.h"
 #include "llvm/Support/Alignment.h"
@@ -62,64 +63,6 @@
   return Result;
 }
 
-static RISCVVLMUL getLMUL(MVT VT) {
-  switch (VT.getSizeInBits().getKnownMinValue() / 8) {
-  default:
-    llvm_unreachable("Invalid LMUL.");
-  case 1:
-    return RISCVVLMUL::LMUL_F8;
-  case 2:
-    return RISCVVLMUL::LMUL_F4;
-  case 4:
-    return RISCVVLMUL::LMUL_F2;
-  case 8:
-    return RISCVVLMUL::LMUL_1;
-  case 16:
-    return RISCVVLMUL::LMUL_2;
-  case 32:
-    return RISCVVLMUL::LMUL_4;
-  case 64:
-    return RISCVVLMUL::LMUL_8;
-  }
-}
-
-static unsigned getRegClassIDForLMUL(RISCVVLMUL LMul) {
-  switch (LMul) {
-  default:
-    llvm_unreachable("Invalid LMUL.");
-  case RISCVVLMUL::LMUL_F8:
-  case RISCVVLMUL::LMUL_F4:
-  case RISCVVLMUL::LMUL_F2:
-  case RISCVVLMUL::LMUL_1:
-    return RISCV::VRRegClassID;
-  case RISCVVLMUL::LMUL_2:
-    return RISCV::VRM2RegClassID;
-  case RISCVVLMUL::LMUL_4:
-    return RISCV::VRM4RegClassID;
-  case RISCVVLMUL::LMUL_8:
-    return RISCV::VRM8RegClassID;
-  }
-}
-
-static unsigned getSubregIndexByMVT(MVT VT, unsigned Index) {
-  RISCVVLMUL LMUL = getLMUL(VT);
-  if (LMUL == RISCVVLMUL::LMUL_F8 || LMUL == RISCVVLMUL::LMUL_F4 ||
-      LMUL == RISCVVLMUL::LMUL_F2 || LMUL == RISCVVLMUL::LMUL_1) {
-    static_assert(RISCV::sub_vrm1_7 == RISCV::sub_vrm1_0 + 7,
-                  "Unexpected subreg numbering");
-    return RISCV::sub_vrm1_0 + Index;
-  } else if (LMUL == RISCVVLMUL::LMUL_2) {
-    static_assert(RISCV::sub_vrm2_3 == RISCV::sub_vrm2_0 + 3,
-                  "Unexpected subreg numbering");
-    return RISCV::sub_vrm2_0 + Index;
-  } else if (LMUL == RISCVVLMUL::LMUL_4) {
-    static_assert(RISCV::sub_vrm4_1 == RISCV::sub_vrm4_0 + 1,
-                  "Unexpected subreg numbering");
-    return RISCV::sub_vrm4_0 + Index;
-  }
-  llvm_unreachable("Invalid vector type.");
-}
-
 static SDValue createTupleImpl(SelectionDAG &CurDAG, ArrayRef<SDValue> Regs,
                                unsigned RegClassID, unsigned SubReg0) {
   assert(Regs.size() >= 2 && Regs.size() <= 8);
@@ -187,7 +130,7 @@
   MVT VT = Node->getSimpleValueType(0);
   unsigned ScalarSize = VT.getScalarSizeInBits();
   MVT XLenVT = Subtarget->getXLenVT();
-  RISCVVLMUL LMUL = getLMUL(VT);
+  RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
   SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
   unsigned CurOp = 2;
   SmallVector<SDValue, 8> Operands;
@@ -218,10 +161,11 @@
   CurDAG->setNodeMemRefs(Load, {MemOp->getMemOperand()});
 
   SDValue SuperReg = SDValue(Load, 0);
-  for (unsigned I = 0; I < NF; ++I)
+  for (unsigned I = 0; I < NF; ++I) {
+    unsigned SubRegIdx = RISCVTargetLowering::getSubregIndexByMVT(VT, I);
     ReplaceUses(SDValue(Node, I),
-                CurDAG->getTargetExtractSubreg(getSubregIndexByMVT(VT, I), DL,
-                                               VT, SuperReg));
+                CurDAG->getTargetExtractSubreg(SubRegIdx, DL, VT, SuperReg));
+  }
 
   ReplaceUses(SDValue(Node, NF), SDValue(Load, 1));
   CurDAG->RemoveDeadNode(Node);
@@ -233,7 +177,7 @@
   MVT VT = Node->getSimpleValueType(0);
   MVT XLenVT = Subtarget->getXLenVT();
   unsigned ScalarSize = VT.getScalarSizeInBits();
-  RISCVVLMUL LMUL = getLMUL(VT);
+  RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
   SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
   unsigned CurOp = 2;
@@ -265,10 +209,11 @@
   CurDAG->setNodeMemRefs(Load, {MemOp->getMemOperand()});
 
   SDValue SuperReg = SDValue(Load, 0);
-  for (unsigned I = 0; I < NF; ++I)
+  for (unsigned I = 0; I < NF; ++I) {
+    unsigned SubRegIdx = RISCVTargetLowering::getSubregIndexByMVT(VT, I);
     ReplaceUses(SDValue(Node, I),
-                CurDAG->getTargetExtractSubreg(getSubregIndexByMVT(VT, I), DL,
-                                               VT, SuperReg));
+                CurDAG->getTargetExtractSubreg(SubRegIdx, DL, VT, SuperReg));
+  }
 
   ReplaceUses(SDValue(Node, NF), SDValue(ReadVL, 0));   // VL
   ReplaceUses(SDValue(Node, NF + 1), SDValue(Load, 1)); // Chain
@@ -282,7 +227,7 @@
   MVT VT = Node->getSimpleValueType(0);
   unsigned ScalarSize = VT.getScalarSizeInBits();
   MVT XLenVT = Subtarget->getXLenVT();
-  RISCVVLMUL LMUL = getLMUL(VT);
+  RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
   SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
   unsigned CurOp = 2;
   SmallVector<SDValue, 8> Operands;
@@ -307,7 +252,7 @@
   assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
          "Element count mismatch");
 
-  RISCVVLMUL IndexLMUL = getLMUL(IndexVT);
+  RISCVVLMUL IndexLMUL = RISCVTargetLowering::getLMUL(IndexVT);
   unsigned IndexScalarSize = IndexVT.getScalarSizeInBits();
   const RISCV::VLXSEGPseudo *P = RISCV::getVLXSEGPseudo(
       NF, IsMasked, IsOrdered, IndexScalarSize, static_cast<unsigned>(LMUL),
@@ -319,10 +264,11 @@
   CurDAG->setNodeMemRefs(Load, {MemOp->getMemOperand()});
 
   SDValue SuperReg = SDValue(Load, 0);
-  for (unsigned I = 0; I < NF; ++I)
+  for (unsigned I = 0; I < NF; ++I) {
+    unsigned SubRegIdx = RISCVTargetLowering::getSubregIndexByMVT(VT, I);
     ReplaceUses(SDValue(Node, I),
-                CurDAG->getTargetExtractSubreg(getSubregIndexByMVT(VT, I), DL,
-                                               VT, SuperReg));
+                CurDAG->getTargetExtractSubreg(SubRegIdx, DL, VT, SuperReg));
+  }
 
   ReplaceUses(SDValue(Node, NF), SDValue(Load, 1));
   CurDAG->RemoveDeadNode(Node);
@@ -339,7 +285,7 @@
   MVT VT = Node->getOperand(2)->getSimpleValueType(0);
   unsigned ScalarSize = VT.getScalarSizeInBits();
   MVT XLenVT = Subtarget->getXLenVT();
-  RISCVVLMUL LMUL = getLMUL(VT);
+  RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
   SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
   SmallVector<SDValue, 8> Regs(Node->op_begin() + 2, Node->op_begin() + 2 + NF);
   SDValue StoreVal = createTuple(*CurDAG, Regs, NF, LMUL);
@@ -376,7 +322,7 @@
   MVT VT = Node->getOperand(2)->getSimpleValueType(0);
   unsigned ScalarSize = VT.getScalarSizeInBits();
   MVT XLenVT = Subtarget->getXLenVT();
-  RISCVVLMUL LMUL = getLMUL(VT);
+  RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
   SDValue SEW = CurDAG->getTargetConstant(ScalarSize, DL, XLenVT);
   SmallVector<SDValue, 8> Operands;
   SmallVector<SDValue, 8> Regs(Node->op_begin() + 2, Node->op_begin() + 2 + NF);
@@ -397,7 +343,7 @@
   assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
          "Element count mismatch");
 
-  RISCVVLMUL IndexLMUL = getLMUL(IndexVT);
+  RISCVVLMUL IndexLMUL = RISCVTargetLowering::getLMUL(IndexVT);
   unsigned IndexScalarSize = IndexVT.getScalarSizeInBits();
   const RISCV::VSXSEGPseudo *P = RISCV::getVSXSEGPseudo(
       NF, IsMasked, IsOrdered, IndexScalarSize, static_cast<unsigned>(LMUL),
@@ -411,47 +357,6 @@
   ReplaceNode(Node, Store);
 }
 
-static unsigned getRegClassIDForVecVT(MVT VT) {
-  if (VT.getVectorElementType() == MVT::i1)
-    return RISCV::VRRegClassID;
-  return getRegClassIDForLMUL(getLMUL(VT));
-}
-
-// Attempt to decompose a subvector insert/extract between VecVT and
-// SubVecVT via subregister indices. Returns the subregister index that
-// can perform the subvector insert/extract with the given element index, as
-// well as the index corresponding to any leftover subvectors that must be
-// further inserted/extracted within the register class for SubVecVT.
-static std::pair<unsigned, unsigned>
-decomposeSubvectorInsertExtractToSubRegs(MVT VecVT, MVT SubVecVT,
-                                         unsigned InsertExtractIdx,
-                                         const RISCVRegisterInfo *TRI) {
-  static_assert((RISCV::VRM8RegClassID > RISCV::VRM4RegClassID &&
-                 RISCV::VRM4RegClassID > RISCV::VRM2RegClassID &&
-                 RISCV::VRM2RegClassID > RISCV::VRRegClassID),
-                "Register classes not ordered");
-  unsigned VecRegClassID = getRegClassIDForVecVT(VecVT);
-  unsigned SubRegClassID = getRegClassIDForVecVT(SubVecVT);
-  // Try to compose a subregister index that takes us from the incoming
-  // LMUL>1 register class down to the outgoing one. At each step we half
-  // the LMUL:
-  // nxv16i32@12 -> nxv2i32: sub_vrm4_1_then_sub_vrm2_1_then_sub_vrm1_0
-  // Note that this is not guaranteed to find a subregister index, such as
-  // when we are extracting from one VR type to another.
-  unsigned SubRegIdx = RISCV::NoSubRegister;
-  for (const unsigned RCID :
-       {RISCV::VRM4RegClassID, RISCV::VRM2RegClassID, RISCV::VRRegClassID})
-    if (VecRegClassID > RCID && SubRegClassID <= RCID) {
-      VecVT = VecVT.getHalfNumVectorElementsVT();
-      bool IsHi =
-          InsertExtractIdx >= VecVT.getVectorElementCount().getKnownMinValue();
-      SubRegIdx = TRI->composeSubRegIndices(SubRegIdx,
-                                            getSubregIndexByMVT(VecVT, IsHi));
-      if (IsHi)
-        InsertExtractIdx -= VecVT.getVectorElementCount().getKnownMinValue();
-    }
-  return {SubRegIdx, InsertExtractIdx};
-}
-
 void RISCVDAGToDAGISel::Select(SDNode *Node) {
   // If we have a custom node, we have already selected.
@@ -726,8 +631,8 @@
       assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
              "Element count mismatch");
 
-      RISCVVLMUL LMUL = getLMUL(VT);
-      RISCVVLMUL IndexLMUL = getLMUL(IndexVT);
+      RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
+      RISCVVLMUL IndexLMUL = RISCVTargetLowering::getLMUL(IndexVT);
       unsigned IndexScalarSize = IndexVT.getScalarSizeInBits();
       const RISCV::VLX_VSXPseudo *P = RISCV::getVLXPseudo(
           IsMasked, IsOrdered, IndexScalarSize, static_cast<unsigned>(LMUL),
@@ -855,8 +760,8 @@
       assert(VT.getVectorElementCount() == IndexVT.getVectorElementCount() &&
             "Element count mismatch");
 
-      RISCVVLMUL LMUL = getLMUL(VT);
-      RISCVVLMUL IndexLMUL = getLMUL(IndexVT);
+      RISCVVLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
+      RISCVVLMUL IndexLMUL = RISCVTargetLowering::getLMUL(IndexVT);
       unsigned IndexScalarSize = IndexVT.getScalarSizeInBits();
       const RISCV::VLX_VSXPseudo *P = RISCV::getVSXPseudo(
           IsMasked, IsOrdered, IndexScalarSize, static_cast<unsigned>(LMUL),
@@ -895,7 +800,7 @@
     // For now, keep the two paths separate.
     if (VT.isScalableVector() && SubVecVT.isScalableVector()) {
       bool IsFullVecReg = false;
-      switch (getLMUL(SubVecVT)) {
+      switch (RISCVTargetLowering::getLMUL(SubVecVT)) {
       default:
        break;
      case RISCVVLMUL::LMUL_1:
@@ -915,10 +820,11 @@
       const auto *TRI = Subtarget->getRegisterInfo();
       unsigned SubRegIdx;
       std::tie(SubRegIdx, Idx) =
-          decomposeSubvectorInsertExtractToSubRegs(VT, SubVecVT, Idx, TRI);
+          RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
+              VT, SubVecVT, Idx, TRI);
 
       // If the Idx hasn't been completely eliminated then this is a subvector
-      // extract which doesn't naturally align to a vector register. These must
+      // insert which doesn't naturally align to a vector register. These must
       // be handled using instructions to manipulate the vector registers.
       if (Idx != 0)
         break;
 
@@ -936,7 +842,7 @@
       if (!Node->getOperand(0).isUndef())
         break;
 
-      unsigned RegClassID = getRegClassIDForVecVT(VT);
+      unsigned RegClassID = RISCVTargetLowering::getRegClassIDForVecVT(VT);
       SDValue RC =
           CurDAG->getTargetConstant(RegClassID, DL, Subtarget->getXLenVT());
@@ -961,7 +867,8 @@
       const auto *TRI = Subtarget->getRegisterInfo();
       unsigned SubRegIdx;
       std::tie(SubRegIdx, Idx) =
-          decomposeSubvectorInsertExtractToSubRegs(InVT, VT, Idx, TRI);
+          RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
+              InVT, VT, Idx, TRI);
 
       // If the Idx hasn't been completely eliminated then this is a subvector
       // extract which doesn't naturally align to a vector register. These must
@@ -972,8 +879,10 @@
       // If we haven't set a SubRegIdx, then we must be going between LMUL<=1
       // types (VR -> VR). This can be done as a copy.
       if (SubRegIdx == RISCV::NoSubRegister) {
-        unsigned InRegClassID = getRegClassIDForVecVT(InVT);
-        assert(getRegClassIDForVecVT(VT) == RISCV::VRRegClassID &&
+        unsigned InRegClassID =
+            RISCVTargetLowering::getRegClassIDForVecVT(InVT);
+        assert(RISCVTargetLowering::getRegClassIDForVecVT(VT) ==
+                   RISCV::VRRegClassID &&
               InRegClassID == RISCV::VRRegClassID &&
               "Unexpected subvector extraction");
        SDValue RC =
@@ -993,7 +902,7 @@
       if (Idx != 0)
         break;
 
-      unsigned InRegClassID = getRegClassIDForVecVT(InVT);
+      unsigned InRegClassID = RISCVTargetLowering::getRegClassIDForVecVT(InVT);
       SDValue RC =
           CurDAG->getTargetConstant(InRegClassID, DL, Subtarget->getXLenVT());
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -20,6 +20,7 @@
 namespace llvm {
 class RISCVSubtarget;
+struct RISCVRegisterInfo;
 namespace RISCVISD {
 enum NodeType : unsigned {
   FIRST_NUMBER = ISD::BUILTIN_OP_END,
@@ -374,6 +375,15 @@
       MachineMemOperand::Flags Flags = MachineMemOperand::MONone,
       bool *Fast = nullptr) const override;
 
+  static RISCVVLMUL getLMUL(MVT VT);
+  static unsigned getRegClassIDForLMUL(RISCVVLMUL LMul);
+  static unsigned getSubregIndexByMVT(MVT VT, unsigned Index);
+  static unsigned getRegClassIDForVecVT(MVT VT);
+  static std::pair<unsigned, unsigned>
+  decomposeSubvectorInsertExtractToSubRegs(MVT VecVT, MVT SubVecVT,
+                                           unsigned InsertExtractIdx,
+                                           const RISCVRegisterInfo *TRI);
+
 private:
   void analyzeInputArgs(MachineFunction &MF, CCState &CCInfo,
                         const SmallVectorImpl<ISD::InputArg> &Ins,
@@ -410,6 +420,7 @@
   SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFPVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerEXTRACT_SUBVECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorLoadToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorStoreToRVV(SDValue Op, SelectionDAG &DAG) const;
   SDValue lowerFixedLengthVectorSetccToRVV(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -464,6 +464,8 @@
       setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
       setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
       setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
+
+      setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
     }
 
     // Expand various CCs to best match the RVV ISA, which natively supports UNE
@@ -498,6 +500,8 @@
       setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
       setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);
       setOperationAction(ISD::FCOPYSIGN, VT, Legal);
+
+      setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
     };
 
     if (Subtarget.hasStdExtZfh())
@@ -800,6 +804,108 @@
   }
 }
 
+RISCVVLMUL RISCVTargetLowering::getLMUL(MVT VT) {
+  switch (VT.getSizeInBits().getKnownMinValue() / 8) {
+  default:
+    llvm_unreachable("Invalid LMUL.");
+  case 1:
+    return RISCVVLMUL::LMUL_F8;
+  case 2:
+    return RISCVVLMUL::LMUL_F4;
+  case 4:
+    return RISCVVLMUL::LMUL_F2;
+  case 8:
+    return RISCVVLMUL::LMUL_1;
+  case 16:
+    return RISCVVLMUL::LMUL_2;
+  case 32:
+    return RISCVVLMUL::LMUL_4;
+  case 64:
+    return RISCVVLMUL::LMUL_8;
+  }
+}
+
+unsigned RISCVTargetLowering::getRegClassIDForLMUL(RISCVVLMUL LMul) {
+  switch (LMul) {
+  default:
+    llvm_unreachable("Invalid LMUL.");
+  case RISCVVLMUL::LMUL_F8:
+  case RISCVVLMUL::LMUL_F4:
+  case RISCVVLMUL::LMUL_F2:
+  case RISCVVLMUL::LMUL_1:
+    return RISCV::VRRegClassID;
+  case RISCVVLMUL::LMUL_2:
+    return RISCV::VRM2RegClassID;
+  case RISCVVLMUL::LMUL_4:
+    return RISCV::VRM4RegClassID;
+  case RISCVVLMUL::LMUL_8:
+    return RISCV::VRM8RegClassID;
+  }
+}
+
+unsigned RISCVTargetLowering::getSubregIndexByMVT(MVT VT, unsigned Index) {
+  RISCVVLMUL LMUL = getLMUL(VT);
+  if (LMUL == RISCVVLMUL::LMUL_F8 || LMUL == RISCVVLMUL::LMUL_F4 ||
+      LMUL == RISCVVLMUL::LMUL_F2 || LMUL == RISCVVLMUL::LMUL_1) {
+    static_assert(RISCV::sub_vrm1_7 == RISCV::sub_vrm1_0 + 7,
+                  "Unexpected subreg numbering");
+    return RISCV::sub_vrm1_0 + Index;
+  }
+  if (LMUL == RISCVVLMUL::LMUL_2) {
+    static_assert(RISCV::sub_vrm2_3 == RISCV::sub_vrm2_0 + 3,
+                  "Unexpected subreg numbering");
+    return RISCV::sub_vrm2_0 + Index;
+  }
+  if (LMUL == RISCVVLMUL::LMUL_4) {
+    static_assert(RISCV::sub_vrm4_1 == RISCV::sub_vrm4_0 + 1,
+                  "Unexpected subreg numbering");
+    return RISCV::sub_vrm4_0 + Index;
+  }
+  llvm_unreachable("Invalid vector type.");
+}
+
+unsigned RISCVTargetLowering::getRegClassIDForVecVT(MVT VT) {
+  if (VT.getVectorElementType() == MVT::i1)
+    return RISCV::VRRegClassID;
+  return getRegClassIDForLMUL(getLMUL(VT));
+}
+
+// Attempt to decompose a subvector insert/extract between VecVT and
+// SubVecVT via subregister indices. Returns the subregister index that
+// can perform the subvector insert/extract with the given element index, as
+// well as the index corresponding to any leftover subvectors that must be
+// further inserted/extracted within the register class for SubVecVT.
+std::pair<unsigned, unsigned>
+RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
+    MVT VecVT, MVT SubVecVT, unsigned InsertExtractIdx,
+    const RISCVRegisterInfo *TRI) {
+  static_assert((RISCV::VRM8RegClassID > RISCV::VRM4RegClassID &&
+                 RISCV::VRM4RegClassID > RISCV::VRM2RegClassID &&
+                 RISCV::VRM2RegClassID > RISCV::VRRegClassID),
+                "Register classes not ordered");
+  unsigned VecRegClassID = getRegClassIDForVecVT(VecVT);
+  unsigned SubRegClassID = getRegClassIDForVecVT(SubVecVT);
+  // Try to compose a subregister index that takes us from the incoming
+  // LMUL>1 register class down to the outgoing one. At each step we halve
+  // the LMUL:
+  // nxv16i32@12 -> nxv2i32: sub_vrm4_1_then_sub_vrm2_1_then_sub_vrm1_0
+  // Note that this is not guaranteed to find a subregister index, such as
+  // when we are extracting from one VR type to another.
+  unsigned SubRegIdx = RISCV::NoSubRegister;
+  for (const unsigned RCID :
+       {RISCV::VRM4RegClassID, RISCV::VRM2RegClassID, RISCV::VRRegClassID})
+    if (VecRegClassID > RCID && SubRegClassID <= RCID) {
+      VecVT = VecVT.getHalfNumVectorElementsVT();
+      bool IsHi =
+          InsertExtractIdx >= VecVT.getVectorElementCount().getKnownMinValue();
+      SubRegIdx = TRI->composeSubRegIndices(SubRegIdx,
+                                            getSubregIndexByMVT(VecVT, IsHi));
+      if (IsHi)
+        InsertExtractIdx -= VecVT.getVectorElementCount().getKnownMinValue();
+    }
+  return {SubRegIdx, InsertExtractIdx};
+}
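For illustration (this note and sketch are not part of the patch): the nxv16i32@12 example in the comment above can be checked with plain integer arithmetic. The following standalone sketch models only the halving loop using minimum element counts; the names decomposeIdx, VecElts and SubElts are hypothetical, it assumes RVVBitsPerBlock is 64 (so the LMUL=1 container for i32 has a minimum of 2 elements), and it only counts how many subregister indices get composed rather than naming them (that is getSubregIndexByMVT's job).

  #include <cstdio>
  #include <utility>

  // Model of the halving loop in decomposeSubvectorInsertExtractToSubRegs,
  // specialised to i32 elements (LMUL=1 container: nxv2i32, i.e. 2 elements).
  // Returns {number of subregister indices composed, leftover element index}.
  static std::pair<unsigned, unsigned> decomposeIdx(unsigned VecElts,
                                                    unsigned SubElts,
                                                    unsigned Idx) {
    const unsigned LMUL1Elts = 2; // assumed: 64-bit block / 32-bit elements
    unsigned Steps = 0;
    while (VecElts > LMUL1Elts && VecElts > SubElts) {
      VecElts /= 2;       // halve the LMUL
      if (Idx >= VecElts) // the index lands in the high half...
        Idx -= VecElts;   // ...so fold it into that half
      ++Steps;            // one more subregister index composed
    }
    return {Steps, Idx};
  }

  int main() {
    // nxv16i32 @ 12 -> nxv2i32: halve 16 -> 8 -> 4 -> 2 with remainder 0,
    // matching sub_vrm4_1_then_sub_vrm2_1_then_sub_vrm1_0 from the comment.
    auto [Steps, Rem] = decomposeIdx(16, 2, 12);
    std::printf("steps=%u leftover=%u\n", Steps, Rem); // steps=3 leftover=0
  }

A nonzero leftover is exactly the case the new lowerEXTRACT_SUBVECTOR below has to handle with a slide.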
+
 // Return the largest legal scalable vector type that matches VT's element
 // type.
 static MVT getContainerForFixedLengthVector(SelectionDAG &DAG, MVT VT,
                                             const RISCVSubtarget &Subtarget) {
@@ -1206,6 +1312,8 @@
   case ISD::VECREDUCE_FADD:
   case ISD::VECREDUCE_SEQ_FADD:
     return lowerFPVECREDUCE(Op, DAG);
+  case ISD::EXTRACT_SUBVECTOR:
+    return lowerEXTRACT_SUBVECTOR(Op, DAG);
   case ISD::BUILD_VECTOR:
     return lowerBUILD_VECTOR(Op, DAG, Subtarget);
   case ISD::VECTOR_SHUFFLE:
@@ -2129,6 +2237,70 @@
                      DAG.getConstant(0, DL, Subtarget.getXLenVT()));
 }
 
+static MVT getLMUL1VT(MVT VT) {
+  assert(VT.getVectorElementType().getSizeInBits() <= 64 &&
+         "Unexpected vector MVT");
+  return MVT::getScalableVectorVT(
+      VT.getVectorElementType(),
+      RISCV::RVVBitsPerBlock / VT.getVectorElementType().getSizeInBits());
+}
+
+SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  SDValue Vec = Op.getOperand(0);
+  MVT SubVecVT = Op.getSimpleValueType();
+  MVT VecVT = Vec.getSimpleValueType();
+
+  // TODO: Only handle scalable->scalable extracts for now, and revisit this
+  // for fixed-length vectors later.
+  if (!SubVecVT.isScalableVector() || !VecVT.isScalableVector())
+    return Op;
+
+  SDLoc DL(Op);
+  unsigned OrigIdx = Op.getConstantOperandVal(1);
+  const RISCVRegisterInfo *TRI = Subtarget.getRegisterInfo();
+
+  unsigned SubRegIdx, RemIdx;
+  std::tie(SubRegIdx, RemIdx) =
+      RISCVTargetLowering::decomposeSubvectorInsertExtractToSubRegs(
+          VecVT, SubVecVT, OrigIdx, TRI);
+
+  // If the Idx has been completely eliminated then this is a subvector extract
+  // which naturally aligns to a vector register. These can easily be handled
+  // using subregister manipulation.
+  if (RemIdx == 0)
+    return Op;
+
+  // Else we must shift our vector register directly to extract the subvector.
+  // Do this using VSLIDEDOWN.
+  MVT XLenVT = Subtarget.getXLenVT();
+
+  // Extract a subvector equal to the nearest full vector register type. This
+  // should resolve to an EXTRACT_SUBREG instruction.
+  unsigned AlignedIdx = OrigIdx - RemIdx;
+  MVT InterSubVT = getLMUL1VT(VecVT);
+  SDValue AlignedExtract =
+      DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, InterSubVT, Vec,
+                  DAG.getConstant(AlignedIdx, DL, XLenVT));
+
+  // Slide this vector register down by the desired number of elements in order
+  // to place the desired subvector starting at element 0.
+  SDValue SlidedownAmt = DAG.getConstant(RemIdx, DL, XLenVT);
+  // For scalable vectors this must be further multiplied by vscale.
+  SlidedownAmt = DAG.getNode(ISD::VSCALE, DL, XLenVT, SlidedownAmt);
+
+  SDValue Mask, VL;
+  std::tie(Mask, VL) = getDefaultScalableVLOps(InterSubVT, DL, DAG, Subtarget);
+  SDValue Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN_VL, DL, InterSubVT,
+                                  DAG.getUNDEF(InterSubVT), AlignedExtract,
+                                  SlidedownAmt, Mask, VL);
+
+  // Now the vector is in the right position, extract our final subvector. This
+  // should resolve to a COPY.
+  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SubVecVT, Slidedown,
+                     DAG.getConstant(0, DL, XLenVT));
+}
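For illustration (this note and sketch are not part of the patch): the unaligned path above splits the original index into a register-aligned part and a remainder. A minimal standalone sketch of that arithmetic follows; the helper name splitExtractIdx is hypothetical, it assumes RVVBitsPerBlock is 64, and the modulo is a simplification that holds when the extracted subvector fits within a single vector register.

  #include <cstdio>

  // Split an extract index into the part handled by a subregister copy and
  // the part handled by a slide, mirroring AlignedIdx/RemIdx in
  // lowerEXTRACT_SUBVECTOR for subvectors no larger than one vector register.
  static void splitExtractIdx(unsigned OrigIdx, unsigned LMUL1Elts) {
    unsigned RemIdx = OrigIdx % LMUL1Elts;  // leftover elements inside one VR
    unsigned AlignedIdx = OrigIdx - RemIdx; // start of the enclosing VR
    // The aligned part resolves to an EXTRACT_SUBREG of one vector register;
    // the remainder becomes a VSLIDEDOWN_VL by RemIdx * vscale elements.
    std::printf("idx %u: aligned element offset %u, slide %u * vscale\n",
                OrigIdx, AlignedIdx, RemIdx);
  }

  int main() {
    // extract nxv1i32 from nxv16i32 (LMUL=1 container nxv2i32, 2 elements):
    splitExtractIdx(3, 2);  // offset 2 (register v9 in the test), slide 1 * vscale
    splitExtractIdx(15, 2); // offset 14 (register v15), slide 1 * vscale
    // extract nxv2i8 from nxv32i8 (LMUL=1 container nxv8i8, 8 elements):
    splitExtractIdx(22, 8); // offset 16 (register v10), slide 6 * vscale
  }

In the generated checks below, vscale is obtained as vlenb >> 3 (csrr a0, vlenb; srli a0, a0, 3), and non-power-of-two slide amounts use an explicit multiply, as in the nxv32i8 index-6 and index-22 tests.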
+
 SDValue
 RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,
                                                      SelectionDAG &DAG) const {
diff --git a/llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll b/llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll
--- a/llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/extract-subvector.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
-; RUN: llc -mtriple riscv64 -mattr=+experimental-v -verify-machineinstrs < %s | FileCheck %s
+; RUN: llc -mtriple riscv64 -mattr=+m,+d,+experimental-zfh,+experimental-v -verify-machineinstrs < %s | FileCheck %s
 
 define <vscale x 4 x i32> @extract_nxv8i32_nxv4i32_0(<vscale x 8 x i32> %vec) {
 ; CHECK-LABEL: extract_nxv8i32_nxv4i32_0:
@@ -190,13 +190,41 @@
   ret %c
 }
 
-; TODO: Extracts that don't align to a vector register are not yet supported.
-; In this case we want to extract the upper half of the lowest VR subregister
-; in the LMUL group.
-; define <vscale x 1 x i32> @extract_nxv16i32_nxv1i32_1(<vscale x 16 x i32> %vec) {
-;   %c = call <vscale x 1 x i32> @llvm.experimental.vector.extract.nxv1i32.nxv16i32(<vscale x 16 x i32> %vec, i64 1)
-;   ret <vscale x 1 x i32> %c
-; }
+define <vscale x 1 x i32> @extract_nxv16i32_nxv1i32_1(<vscale x 16 x i32> %vec) {
+; CHECK-LABEL: extract_nxv16i32_nxv1i32_1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 3
+; CHECK-NEXT:    vsetvli a1, zero, e32,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v8, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 1 x i32> @llvm.experimental.vector.extract.nxv1i32.nxv16i32(<vscale x 16 x i32> %vec, i64 1)
+  ret <vscale x 1 x i32> %c
+}
+
+define <vscale x 1 x i32> @extract_nxv16i32_nxv1i32_3(<vscale x 16 x i32> %vec) {
+; CHECK-LABEL: extract_nxv16i32_nxv1i32_3:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 3
+; CHECK-NEXT:    vsetvli a1, zero, e32,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v9, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 1 x i32> @llvm.experimental.vector.extract.nxv1i32.nxv16i32(<vscale x 16 x i32> %vec, i64 3)
+  ret <vscale x 1 x i32> %c
+}
+
+define <vscale x 1 x i32> @extract_nxv16i32_nxv1i32_15(<vscale x 16 x i32> %vec) {
+; CHECK-LABEL: extract_nxv16i32_nxv1i32_15:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 3
+; CHECK-NEXT:    vsetvli a1, zero, e32,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v15, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 1 x i32> @llvm.experimental.vector.extract.nxv1i32.nxv16i32(<vscale x 16 x i32> %vec, i64 15)
+  ret <vscale x 1 x i32> %c
+}
 
 define <vscale x 1 x i32> @extract_nxv16i32_nxv1i32_2(<vscale x 16 x i32> %vec) {
 ; CHECK-LABEL: extract_nxv16i32_nxv1i32_2:
@@ -215,6 +243,124 @@
   ret %c
 }
 
+define <vscale x 2 x i8> @extract_nxv32i8_nxv2i8_0(<vscale x 32 x i8> %vec) {
+; CHECK-LABEL: extract_nxv32i8_nxv2i8_0:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    # kill: def $v8 killed $v8 killed $v8m4
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 0)
+  ret <vscale x 2 x i8> %c
+}
+
+define <vscale x 2 x i8> @extract_nxv32i8_nxv2i8_2(<vscale x 32 x i8> %vec) {
+; CHECK-LABEL: extract_nxv32i8_nxv2i8_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 2
+; CHECK-NEXT:    vsetvli a1, zero, e8,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v8, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 2)
+  ret <vscale x 2 x i8> %c
+}
+
+define <vscale x 2 x i8> @extract_nxv32i8_nxv2i8_4(<vscale x 32 x i8> %vec) {
+; CHECK-LABEL: extract_nxv32i8_nxv2i8_4:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 1
+; CHECK-NEXT:    vsetvli a1, zero, e8,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v8, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 4)
+  ret <vscale x 2 x i8> %c
+}
+
+define <vscale x 2 x i8> @extract_nxv32i8_nxv2i8_6(<vscale x 32 x i8> %vec) {
+; CHECK-LABEL: extract_nxv32i8_nxv2i8_6:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 3
+; CHECK-NEXT:    addi a1, zero, 6
+; CHECK-NEXT:    mul a0, a0, a1
+; CHECK-NEXT:    vsetvli a1, zero, e8,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v8, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 6)
+  ret <vscale x 2 x i8> %c
+}
+
+define <vscale x 2 x i8> @extract_nxv32i8_nxv2i8_8(<vscale x 32 x i8> %vec) {
+; CHECK-LABEL: extract_nxv32i8_nxv2i8_8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 8)
+  ret <vscale x 2 x i8> %c
+}
+
+define <vscale x 2 x i8> @extract_nxv32i8_nxv2i8_22(<vscale x 32 x i8> %vec) {
+; CHECK-LABEL: extract_nxv32i8_nxv2i8_22:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 3
+; CHECK-NEXT:    addi a1, zero, 6
+; CHECK-NEXT:    mul a0, a0, a1
+; CHECK-NEXT:    vsetvli a1, zero, e8,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v10, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 22)
+  ret <vscale x 2 x i8> %c
+}
+
+define <vscale x 1 x i8> @extract_nxv8i8_nxv1i8_7(<vscale x 8 x i8> %vec) {
+; CHECK-LABEL: extract_nxv8i8_nxv1i8_7:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 3
+; CHECK-NEXT:    slli a1, a0, 3
+; CHECK-NEXT:    sub a0, a1, a0
+; CHECK-NEXT:    vsetvli a1, zero, e8,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v8, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 1 x i8> @llvm.experimental.vector.extract.nxv1i8.nxv8i8(<vscale x 8 x i8> %vec, i64 7)
+  ret <vscale x 1 x i8> %c
+}
+
+define <vscale x 2 x half> @extract_nxv2f16_nxv16f16_0(<vscale x 16 x half> %vec) {
+; CHECK-LABEL: extract_nxv2f16_nxv16f16_0:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    # kill: def $v8 killed $v8 killed $v8m4
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x half> @llvm.experimental.vector.extract.nxv2f16.nxv16f16(<vscale x 16 x half> %vec, i64 0)
+  ret <vscale x 2 x half> %c
+}
+
+define <vscale x 2 x half> @extract_nxv2f16_nxv16f16_2(<vscale x 16 x half> %vec) {
+; CHECK-LABEL: extract_nxv2f16_nxv16f16_2:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    csrr a0, vlenb
+; CHECK-NEXT:    srli a0, a0, 2
+; CHECK-NEXT:    vsetvli a1, zero, e16,m1,ta,mu
+; CHECK-NEXT:    vslidedown.vx v8, v8, a0
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x half> @llvm.experimental.vector.extract.nxv2f16.nxv16f16(<vscale x 16 x half> %vec, i64 2)
+  ret <vscale x 2 x half> %c
+}
+
+define <vscale x 2 x half> @extract_nxv2f16_nxv16f16_4(<vscale x 16 x half> %vec) {
+; CHECK-LABEL: extract_nxv2f16_nxv16f16_4:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vmv1r.v v8, v9
+; CHECK-NEXT:    ret
+  %c = call <vscale x 2 x half> @llvm.experimental.vector.extract.nxv2f16.nxv16f16(<vscale x 16 x half> %vec, i64 4)
+  ret <vscale x 2 x half> %c
+}
+
+declare <vscale x 1 x i8> @llvm.experimental.vector.extract.nxv1i8.nxv8i8(<vscale x 8 x i8> %vec, i64 %idx)
+
+declare <vscale x 2 x i8> @llvm.experimental.vector.extract.nxv2i8.nxv32i8(<vscale x 32 x i8> %vec, i64 %idx)
+
 declare <vscale x 1 x i32> @llvm.experimental.vector.extract.nxv1i32.nxv2i32(<vscale x 2 x i32> %vec, i64 %idx)
 declare <vscale x 2 x i32> @llvm.experimental.vector.extract.nxv2i32.nxv8i32(<vscale x 8 x i32> %vec, i64 %idx)
@@ -224,3 +370,5 @@
 declare <vscale x 2 x i32> @llvm.experimental.vector.extract.nxv2i32.nxv16i32(<vscale x 16 x i32> %vec, i64 %idx)
 declare <vscale x 4 x i32> @llvm.experimental.vector.extract.nxv4i32.nxv16i32(<vscale x 16 x i32> %vec, i64 %idx)
 declare <vscale x 8 x i32> @llvm.experimental.vector.extract.nxv8i32.nxv16i32(<vscale x 16 x i32> %vec, i64 %idx)
+
+declare <vscale x 2 x half> @llvm.experimental.vector.extract.nxv2f16.nxv16f16(<vscale x 16 x half> %vec, i64 %idx)