diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp --- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp +++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp @@ -435,6 +435,7 @@ void insertVSETVLI(MachineBasicBlock &MBB, MachineInstr &MI, const VSETVLIInfo &Info, const VSETVLIInfo &PrevInfo); + void forwardPropagateVLMAX(MachineBasicBlock &MBB); bool computeVLVTYPEChanges(const MachineBasicBlock &MBB); void computeIncomingVLVTYPE(const MachineBasicBlock &MBB); void emitVSETVLIs(MachineBasicBlock &MBB); @@ -453,6 +454,11 @@ MI.getOpcode() == RISCV::PseudoVSETIVLI; } +static bool isVLMaxVSETVLI(const MachineInstr &MI) { + return MI.getOpcode() == RISCV::PseudoVSETVLIX0 && + MI.getOperand(0).getReg() != RISCV::X0; +} + static MachineInstr *elideCopies(MachineInstr *MI, const MachineRegisterInfo *MRI) { while (true) { @@ -887,6 +893,49 @@ return CurInfo.isCompatibleWithLoadStoreEEW(EEW, Require); } +// Rewrite VL operands fed by a VLMAX vsetvli in MBB to X0 when VLMAX matches. +void RISCVInsertVSETVLI::forwardPropagateVLMAX(MachineBasicBlock &MBB) { + for (const MachineInstr &MI : MBB) { + // Only PseudoVSETVLIX0 whose operand 0 is not X0 qualifies. + if (!isVLMaxVSETVLI(MI)) + continue; + + VSETVLIInfo VI = getInfoForVSETVLI(MI); + const MachineOperand &DestVLOp = MI.getOperand(0); + Register DestVLReg = DestVLOp.getReg(); + // Visit each non-debug use of this VL; UI is advanced before Use is edited. 
+ for (MachineRegisterInfo::use_nodbg_iterator + UI = MRI->use_nodbg_begin(DestVLReg), + UIEnd = MRI->use_nodbg_end(); + UI != UIEnd;) { + MachineOperand &Use(*UI++); + assert(Use.getParent() != nullptr && "Expected parent instruction"); + const MachineInstr &UseMI = *Use.getParent(); + const unsigned UseIndex = UseMI.getOperandNo(&Use); + + uint64_t TSFlags = UseMI.getDesc().TSFlags; + if (!RISCVII::hasSEWOp(TSFlags) || !RISCVII::hasVLOp(TSFlags)) + continue; + + unsigned NumOperands = UseMI.getNumExplicitOperands(); + if (RISCVII::hasVecPolicyOp(TSFlags)) + --NumOperands; + + // With any policy operand dropped, the VL operand is NumOperands - 2. + if (UseIndex != (NumOperands - 2)) + continue; + + VSETVLIInfo UseInfo = computeInfoForInstr(UseMI, TSFlags, MRI); + if (!UseInfo.hasSameVLMAX(VI)) + continue; + + Use.setReg(RISCV::X0); + // TODO: Should we update the dead flag or remove the instruction if + // we propagated to all users? + } + } +} + bool RISCVInsertVSETVLI::computeVLVTYPEChanges(const MachineBasicBlock &MBB) { bool HadVectorOp = false; @@ -1143,6 +1192,10 @@ assert(BlockInfo.empty() && "Expect empty block infos"); BlockInfo.resize(MF.getNumBlockIDs()); + // Phase 0 - propagate VLMAX vsetvlis into their users in every block. + for (MachineBasicBlock &MBB : MF) + forwardPropagateVLMAX(MBB); + bool HaveVectorOp = false; // Phase 1 - determine how VL/VTYPE are affected by the each block. 
diff --git a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll --- a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll @@ -591,21 +591,20 @@ ; CHECK-NEXT: blez a0, .LBB11_3 ; CHECK-NEXT: # %bb.1: # %for.body.preheader ; CHECK-NEXT: li a5, 0 -; CHECK-NEXT: li t1, 0 +; CHECK-NEXT: li t0, 0 ; CHECK-NEXT: slli a7, a6, 3 ; CHECK-NEXT: .LBB11_2: # %for.body ; CHECK-NEXT: # =>This Inner Loop Header: Depth=1 -; CHECK-NEXT: add t0, a2, a5 -; CHECK-NEXT: vsetvli zero, a6, e64, m1, ta, mu -; CHECK-NEXT: vle64.v v8, (t0) +; CHECK-NEXT: add a4, a2, a5 +; CHECK-NEXT: vle64.v v8, (a4) ; CHECK-NEXT: add a4, a3, a5 ; CHECK-NEXT: vle64.v v9, (a4) ; CHECK-NEXT: vfadd.vv v8, v8, v9 ; CHECK-NEXT: add a4, a1, a5 ; CHECK-NEXT: vse64.v v8, (a4) -; CHECK-NEXT: add t1, t1, a6 +; CHECK-NEXT: add t0, t0, a6 ; CHECK-NEXT: add a5, a5, a7 -; CHECK-NEXT: blt t1, a0, .LBB11_2 +; CHECK-NEXT: blt t0, a0, .LBB11_2 ; CHECK-NEXT: .LBB11_3: # %for.end ; CHECK-NEXT: ret entry: