Index: llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
+++ llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
@@ -135,6 +135,7 @@
   bool doPeepholeMergeVVMFold();
   bool performVMergeToVAdd(SDNode *N);
   bool performCombineVMergeAndVOps(SDNode *N, bool IsTA);
+  SDNode *tryShrinkVLForVMV(SDNode *Node);
 };
 
 namespace RISCV {
Index: llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -56,6 +56,32 @@
   return getLastNonGlueOrChainOpIdx(Node);
 }
 
+static unsigned getSEWOpIdx(const SDNode *Node, const MCInstrDesc &MCID) {
+  assert(RISCVII::hasSEWOp(MCID.TSFlags));
+  unsigned SEWOpIdx = getLastNonGlueOrChainOpIdx(Node);
+  if (RISCVII::hasVecPolicyOp(MCID.TSFlags))
+    --SEWOpIdx;
+  return SEWOpIdx;
+}
+
+static unsigned getVLOpIdx(const SDNode *Node, const MCInstrDesc &MCID) {
+  assert(RISCVII::hasVLOp(MCID.TSFlags) && RISCVII::hasSEWOp(MCID.TSFlags));
+  // An instruction with a VL operand also has a SEW operand right after it.
+  return getSEWOpIdx(Node, MCID) - 1;
+}
+
+static unsigned getSEWOp(const SDNode *Node, const MCInstrDesc &MCID) {
+  assert(RISCVII::hasSEWOp(MCID.TSFlags));
+  unsigned Log2SEW = Node->getConstantOperandVal(getSEWOpIdx(Node, MCID));
+  unsigned SEW = Log2SEW ? 1 << Log2SEW : 8;
+  return SEW;
+}
+
+static SDValue getVLOperand(const SDNode *Node, const MCInstrDesc &MCID) {
+  assert(RISCVII::hasVLOp(MCID.TSFlags));
+  return Node->getOperand(getVLOpIdx(Node, MCID));
+}
+
 void RISCVDAGToDAGISel::PreprocessISelDAG() {
   SelectionDAG::allnodes_iterator Position = CurDAG->allnodes_end();
 
@@ -1786,10 +1812,18 @@
     ReplaceNode(Node, Extract.getNode());
     return;
   }
-  case RISCVISD::VMV_S_X_VL:
-  case RISCVISD::VFMV_S_F_VL:
   case RISCVISD::VMV_V_X_VL:
   case RISCVISD::VFMV_V_F_VL: {
+    // Try to shrink VL for a splat-like move.
+    SDNode *UpdatedNode = tryShrinkVLForVMV(Node);
+    if (UpdatedNode != Node) {
+      ReplaceNode(Node, UpdatedNode);
+      return;
+    }
+    [[fallthrough]];
+  }
+  case RISCVISD::VMV_S_X_VL:
+  case RISCVISD::VFMV_S_F_VL: {
     // Try to match splat of a scalar load to a strided load with stride of x0.
     bool IsScalarMove = Node->getOpcode() == RISCVISD::VMV_S_X_VL ||
                         Node->getOpcode() == RISCVISD::VFMV_S_F_VL;
@@ -2476,6 +2510,91 @@
   return false;
 }
 
+
+static bool isVLMax(SDValue VL) {
+  if (auto *Constant = dyn_cast<ConstantSDNode>(VL))
+    return Constant->getSExtValue() == RISCV::VLMaxSentinel;
+  auto *RegVL = dyn_cast<RegisterSDNode>(VL);
+  return RegVL && RegVL->getReg() == RISCV::X0;
+}
+
+static bool isVLLessThan(SDValue VL1, SDValue VL2) {
+  assert(VL1 && VL2);
+  if (isVLMax(VL1))
+    return false;
+  if (isVLMax(VL2))
+    return true;
+  auto *ConstantVL1 = dyn_cast<ConstantSDNode>(VL1);
+  auto *ConstantVL2 = dyn_cast<ConstantSDNode>(VL2);
+  if (!ConstantVL1 || !ConstantVL2)
+    // Cannot compare reg-reg/constant-reg/reg-constant cases apart from X0
+    // and VLMaxSentinel, which are handled above.
+    return false;
+  return ConstantVL1->getSExtValue() < ConstantVL2->getSExtValue();
+}
+
+/// Returns true if the user instruction has a VL operand and is known to
+/// demand only that number of lanes from this input use.
+static bool allowsVLShrinking(const SDUse &Use) {
+  const SDNode *User = Use.getUser();
+  if (!User->isMachineOpcode())
+    return false;
+
+  // A VSE instruction doesn't have a merge operand, and doesn't
+  // read past VL at all. That makes it a simple case to start with.
+  const RISCVVPseudosTable::PseudoInfo *RVV =
+      RISCVVPseudosTable::getPseudoInfo(User->getMachineOpcode());
+  if (!RVV)
+    return false;
+  switch (RVV->BaseInstr) {
+  default:
+    return false;
+  case RISCV::VSE8_V:
+  case RISCV::VSE16_V:
+  case RISCV::VSE32_V:
+  case RISCV::VSE64_V:
+    return true;
+  }
+}
+
+// Analyzes users of a splat-like VMV/VFMV instruction and chooses the
+// minimal possible VL.
+SDNode *RISCVDAGToDAGISel::tryShrinkVLForVMV(SDNode *Node) {
+  const RISCVInstrInfo &TII = *Subtarget->getInstrInfo();
+
+  // FIXME: This can be profitable for moves with multiple uses as well.
+  if (!Node->hasOneUse())
+    return Node;
+  const SDNode::use_iterator UI = Node->use_begin();
+  if (!allowsVLShrinking(UI.getUse()))
+    return Node;
+
+  const SDNode *User = *UI;
+  const MCInstrDesc &UserMCID = TII.get(User->getMachineOpcode());
+
+  // If SEW or LMUL differs, then VL values may not be comparable.
+  MVT VT = Node->getSimpleValueType(0);
+  const unsigned SEW = VT.getScalarSizeInBits();
+  const unsigned UserSEW = getSEWOp(User, UserMCID);
+  if (SEW != UserSEW)
+    return Node;
+
+  RISCVII::VLMUL LMUL = RISCVTargetLowering::getLMUL(VT);
+  RISCVII::VLMUL UserLMUL = RISCVII::getLMul(UserMCID.TSFlags);
+  if (LMUL != UserLMUL)
+    return Node;
+
+  SDValue VL = getVLOperand(User, UserMCID);
+  SDValue OldVL = Node->getOperand(Node->getNumOperands() - 1);
+  if (!isVLLessThan(VL, OldVL))
+    return Node;
+
+  // MergeOp, Src, VL.
+  SmallVector<SDValue> Ops(Node->op_begin(), Node->op_end());
+  Ops[Node->getNumOperands() - 1] = VL;
+  return CurDAG->UpdateNodeOperands(Node, Ops);
+}
+
 // Try to remove sext.w if the input is a W instruction or can be made into
 // a W instruction cheaply.
 bool RISCVDAGToDAGISel::doPeepholeSExtW(SDNode *N) {
Index: llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll
===================================================================
--- llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll
+++ llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll
@@ -681,7 +681,6 @@
 ; CHECK-NEXT:  # %bb.1: # %for.body.preheader
 ; CHECK-NEXT:    li a3, 0
 ; CHECK-NEXT:    slli a4, a2, 3
-; CHECK-NEXT:    vsetvli a5, zero, e64, m1, ta, ma
 ; CHECK-NEXT:    vmv.v.i v8, 0
 ; CHECK-NEXT:  .LBB13_2: # %for.body
 ; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
@@ -717,11 +716,10 @@
 ; CHECK-NEXT:    li a2, 0
 ; CHECK-NEXT:    vsetivli a3, 4, e64, m1, ta, mu
 ; CHECK-NEXT:    slli a4, a3, 3
-; CHECK-NEXT:    vsetvli a5, zero, e64, m1, ta, ma
 ; CHECK-NEXT:    vmv.v.i v8, 0
 ; CHECK-NEXT:  .LBB14_1: # %for.body
 ; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    vsetivli zero, 4, e64, m1, ta, ma
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
 ; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    add a2, a2, a3
 ; CHECK-NEXT:    add a1, a1, a4
@@ -751,11 +749,11 @@
 ; CHECK-LABEL: vector_init_vsetvli_fv2:
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    li a2, 0
-; CHECK-NEXT:    vsetvli a3, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vsetivli zero, 4, e64, m1, ta, mu
 ; CHECK-NEXT:    vmv.v.i v8, 0
 ; CHECK-NEXT:  .LBB15_1: # %for.body
 ; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    vsetivli zero, 4, e64, m1, ta, ma
+; CHECK-NEXT:    vsetvli zero, zero, e64, m1, ta, ma
 ; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    addi a2, a2, 4
 ; CHECK-NEXT:    addi a1, a1, 32
@@ -785,11 +783,10 @@
 ; CHECK-LABEL: vector_init_vsetvli_fv3:
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    li a2, 0
-; CHECK-NEXT:    vsetvli a3, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vsetivli zero, 4, e64, m1, ta, ma
 ; CHECK-NEXT:    vmv.v.i v8, 0
 ; CHECK-NEXT:  .LBB16_1: # %for.body
 ; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    vsetivli zero, 4, e64, m1, ta, ma
 ; CHECK-NEXT:    vse64.v v8, (a1)
 ; CHECK-NEXT:    addi a2, a2, 4
 ; CHECK-NEXT:    addi a1, a1, 32
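
Note for reviewers: a minimal before/after sketch of the intended effect,
distilled from the vector_init_vsetvli_fv3 diff above (register choices are
illustrative, not taken verbatim from the test):

  # Before: the splat is selected with VL = VLMAX, so a separate
  # vsetivli must precede the 4-element store inside the loop.
  vsetvli a3, zero, e64, m1, ta, ma    # VL = VLMAX for vmv.v.i
  vmv.v.i v8, 0
  vsetivli zero, 4, e64, m1, ta, ma    # VL = 4 for vse64.v
  vse64.v v8, (a1)

  # After: tryShrinkVLForVMV rewrites the splat's VL operand to the
  # store's VL, so a single vsetivli covers both instructions and the
  # in-loop vsetivli disappears.
  vsetivli zero, 4, e64, m1, ta, ma    # VL = 4 for both
  vmv.v.i v8, 0
  vse64.v v8, (a1)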