diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h
--- a/llvm/include/llvm/CodeGen/TargetLowering.h
+++ b/llvm/include/llvm/CodeGen/TargetLowering.h
@@ -393,6 +393,12 @@
     return getPointerTy(DL);
   }
 
+  /// Returns the type to be used for the EVL/AVL operand of VP nodes:
+  /// ISD::VP_ADD, ISD::VP_SUB, etc. It must be a legal scalar integer type,
+  /// and must be at least as large as i32. The EVL is implicitly zero-extended
+  /// to any larger type.
+  virtual MVT getVPExplicitVectorLengthTy() const { return MVT::i32; }
+
   /// This callback is used to inspect load/store instructions and add
   /// target-specific MachineMemOperand flags to them. The default
   /// implementation does nothing.
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -7287,6 +7287,7 @@
 
 void SelectionDAGBuilder::visitVectorPredicationIntrinsic(
     const VPIntrinsic &VPIntrin) {
+  SDLoc DL = getCurSDLoc();
   unsigned Opcode = getISDForVPIntrinsic(VPIntrin);
 
   SmallVector<EVT, 4> ValueVTs;
@@ -7294,12 +7295,23 @@
   ComputeValueVTs(TLI, DAG.getDataLayout(), VPIntrin.getType(), ValueVTs);
   SDVTList VTs = DAG.getVTList(ValueVTs);
 
+  auto EVLParamPos =
+      VPIntrinsic::GetVectorLengthParamPos(VPIntrin.getIntrinsicID());
+
+  MVT EVLParamVT = TLI.getVPExplicitVectorLengthTy();
+  assert(EVLParamVT.isScalarInteger() && EVLParamVT.bitsGE(MVT::i32) &&
+         "Unexpected target EVL type");
+
   // Request operands.
   SmallVector<SDValue, 7> OpValues;
-  for (int i = 0; i < (int)VPIntrin.getNumArgOperands(); ++i)
-    OpValues.push_back(getValue(VPIntrin.getArgOperand(i)));
+  for (int I = 0; I < (int)VPIntrin.getNumArgOperands(); ++I) {
+    auto Op = getValue(VPIntrin.getArgOperand(I));
+    // The EVL operand is i32 in IR; widen it to the target's EVL type.
+    if (I == EVLParamPos)
+      Op = DAG.getNode(ISD::ZERO_EXTEND, DL, EVLParamVT, Op);
+    OpValues.push_back(Op);
+  }
 
-  SDLoc DL = getCurSDLoc();
   SDValue Result = DAG.getNode(Opcode, DL, VTs, OpValues);
   setValue(&VPIntrin, Result);
 }
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -558,6 +558,8 @@
 
   bool useRVVForFixedLengthVectorVT(MVT VT) const;
 
+  MVT getVPExplicitVectorLengthTy() const override;
+
   /// RVV code generation for fixed length vectors does not lower all
   /// BUILD_VECTORs. This makes BUILD_VECTOR legalisation a source of stores to
   /// merge. However, merging them creates a BUILD_VECTOR that is just as
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -505,12 +505,8 @@
       setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
       setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
 
-      for (unsigned VPOpc : IntegerVPOps) {
+      for (unsigned VPOpc : IntegerVPOps)
         setOperationAction(VPOpc, VT, Custom);
-        // RV64 must custom-legalize the i32 EVL parameter.
-        if (Subtarget.is64Bit())
-          setOperationAction(VPOpc, MVT::i32, Custom);
-      }
 
       setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::MSTORE, VT, Custom);
@@ -721,12 +717,8 @@
         setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
         setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
 
-        for (unsigned VPOpc : IntegerVPOps) {
+        for (unsigned VPOpc : IntegerVPOps)
           setOperationAction(VPOpc, VT, Custom);
-          // RV64 must custom-legalize the i32 EVL parameter.
-          if (Subtarget.is64Bit())
-            setOperationAction(VPOpc, MVT::i32, Custom);
-        }
       }
 
       for (MVT VT : MVT::fp_fixedlen_vector_valuetypes()) {
@@ -831,6 +823,12 @@
   return VT.changeVectorElementTypeToInteger();
 }
 
+// Request XLenVT so that the EVL operand of VP nodes is already XLEN-sized
+// when lowerVPOp sees it; lowerVPOp no longer has to extend it itself.
+MVT RISCVTargetLowering::getVPExplicitVectorLengthTy() const {
+  return Subtarget.getXLenVT();
+}
+
 bool RISCVTargetLowering::getTgtMemIntrinsic(IntrinsicInfo &Info,
                                              const CallInst &I,
                                              MachineFunction &MF,
@@ -4315,17 +4313,10 @@
                                        unsigned RISCVISDOpc) const {
   SDLoc DL(Op);
   MVT VT = Op.getSimpleValueType();
-  Optional<unsigned> EVLIdx = ISD::getVPExplicitVectorLengthIdx(Op.getOpcode());
-
   SmallVector<SDValue, 4> Ops;
-  MVT XLenVT = Subtarget.getXLenVT();
 
   for (const auto &OpIdx : enumerate(Op->ops())) {
     SDValue V = OpIdx.value();
-    if ((unsigned)OpIdx.index() == EVLIdx) {
-      Ops.push_back(DAG.getZExtOrTrunc(V, DL, XLenVT));
-      continue;
-    }
     assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
     // Pass through operands which aren't fixed-length vectors.
     if (!V.getValueType().isFixedLengthVector()) {