diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
--- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
@@ -435,6 +435,7 @@
   void insertVSETVLI(MachineBasicBlock &MBB, MachineInstr &MI,
                      const VSETVLIInfo &Info, const VSETVLIInfo &PrevInfo);
 
+  void forwardPropagateVLMAX(MachineBasicBlock &MBB);
   bool computeVLVTYPEChanges(const MachineBasicBlock &MBB);
   void computeIncomingVLVTYPE(const MachineBasicBlock &MBB);
   void emitVSETVLIs(MachineBasicBlock &MBB);
@@ -887,6 +888,81 @@
   return CurInfo.isCompatibleWithLoadStoreEEW(EEW, Require);
 }
 
+// Forward propagate VLMAX vsetvlis.
+//
+// For each `vsetvli rd, x0, vtype` with rd != x0 in MBB, rewrite the VL
+// operand of every vector pseudo that reads rd to X0, provided the pseudo has
+// the same SEW/LMUL ratio as the vsetvli. The ratio check matters: rd holds
+// VLMAX for the vsetvli's vtype, while an X0 AVL requests VLMAX for the
+// user's own vtype, and the two only coincide when the ratios match.
+void RISCVInsertVSETVLI::forwardPropagateVLMAX(MachineBasicBlock &MBB) {
+  // SEW/LMUL as a fixed point value with 3 fractional bits, so fractional
+  // LMULs stay integral.
+  // NOTE(review): if this file already has an equivalent ratio helper, call
+  // that instead of this local copy.
+  auto GetSEWLMULRatio = [](unsigned SEW, RISCVII::VLMUL VLMul) {
+    const std::pair<unsigned, bool> Decoded = RISCVVType::decodeVLMUL(VLMul);
+    const unsigned LMulFixed =
+        Decoded.second ? (8 / Decoded.first) : (Decoded.first * 8);
+    return (SEW * 8) / LMulFixed;
+  };
+
+  for (const MachineInstr &MI : MBB) {
+    // Look for VLMAX vsetvlis: AVL is the X0 register and the resulting VL is
+    // written to a real destination.
+    if (!(isVectorConfigInstr(MI) && MI.getOperand(1).isReg() &&
+          MI.getOperand(1).getReg() == RISCV::X0 &&
+          MI.getOperand(0).getReg() != RISCV::X0))
+      continue;
+
+    const Register DestVLReg = MI.getOperand(0).getReg();
+    // The use list walked below only corresponds to this definition when the
+    // destination is a virtual register.
+    if (!DestVLReg.isVirtual())
+      continue;
+
+    // NOTE(review): assumes operand 2 of the vsetvli pseudo is the vtype
+    // immediate; confirm.
+    const unsigned DefVType = MI.getOperand(2).getImm();
+    const unsigned DefRatio = GetSEWLMULRatio(RISCVVType::getSEW(DefVType),
+                                              RISCVVType::getVLMUL(DefVType));
+
+    // Walk through all uses of this VL. The operand may be rewritten inside
+    // the loop, so the iterator is advanced before the operand is touched.
+    for (MachineRegisterInfo::use_nodbg_iterator
+             UI = MRI->use_nodbg_begin(DestVLReg),
+             UIEnd = MRI->use_nodbg_end();
+         UI != UIEnd;) {
+      MachineOperand &Use(*UI++);
+      assert(Use.getParent() != nullptr && "Expected parent instruction");
+      const MachineInstr &UseMI = *Use.getParent();
+
+      const uint64_t TSFlags = UseMI.getDesc().TSFlags;
+      if (!RISCVII::hasSEWOp(TSFlags) || !RISCVII::hasVLOp(TSFlags))
+        continue;
+
+      // Trailing explicit operands are laid out as: ..., VL, SEW[, policy].
+      unsigned NumOperands = UseMI.getNumExplicitOperands();
+      if (RISCVII::hasVecPolicyOp(TSFlags))
+        --NumOperands;
+      if (UseMI.getOperandNo(&Use) != NumOperands - 2)
+        continue;
+
+      // NOTE(review): a Log2SEW of 0 is mapped to SEW=8 here on the
+      // assumption that it denotes a mask-register operation; confirm this
+      // matches how the rest of the pass decodes the SEW operand.
+      const unsigned Log2SEW = UseMI.getOperand(NumOperands - 1).getImm();
+      const unsigned UseSEW = Log2SEW ? 1u << Log2SEW : 8;
+      if (GetSEWLMULRatio(UseSEW, RISCVII::getLMul(TSFlags)) != DefRatio)
+        continue;
+
+      Use.setReg(RISCV::X0);
+      // TODO: Should we update the dead flag or remove the instruction if
+      // we propagated to all users?
+    }
+  }
+}
+
 bool RISCVInsertVSETVLI::computeVLVTYPEChanges(const MachineBasicBlock &MBB) {
   bool HadVectorOp = false;
 
@@ -1143,6 +1219,10 @@
   assert(BlockInfo.empty() && "Expect empty block infos");
   BlockInfo.resize(MF.getNumBlockIDs());
 
+  // Phase 0 - propagate VLMAX vsetvlis.
+  for (MachineBasicBlock &MBB : MF)
+    forwardPropagateVLMAX(MBB);
+
   bool HaveVectorOp = false;
 
   // Phase 1 - determine how VL/VTYPE are affected by the each block.
diff --git a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll
@@ -591,21 +591,20 @@
 ; CHECK-NEXT:    blez a0, .LBB11_3
 ; CHECK-NEXT:  # %bb.1: # %for.body.preheader
 ; CHECK-NEXT:    li a5, 0
-; CHECK-NEXT:    li t1, 0
+; CHECK-NEXT:    li t0, 0
 ; CHECK-NEXT:    slli a7, a6, 3
 ; CHECK-NEXT:  .LBB11_2: # %for.body
 ; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    add t0, a2, a5
-; CHECK-NEXT:    vsetvli zero, a6, e64, m1, ta, mu
-; CHECK-NEXT:    vle64.v v8, (t0)
+; CHECK-NEXT:    add a4, a2, a5
+; CHECK-NEXT:    vle64.v v8, (a4)
 ; CHECK-NEXT:    add a4, a3, a5
 ; CHECK-NEXT:    vle64.v v9, (a4)
 ; CHECK-NEXT:    vfadd.vv v8, v8, v9
 ; CHECK-NEXT:    add a4, a1, a5
 ; CHECK-NEXT:    vse64.v v8, (a4)
-; CHECK-NEXT:    add t1, t1, a6
+; CHECK-NEXT:    add t0, t0, a6
 ; CHECK-NEXT:    add a5, a5, a7
-; CHECK-NEXT:    blt t1, a0, .LBB11_2
+; CHECK-NEXT:    blt t0, a0, .LBB11_2
 ; CHECK-NEXT:  .LBB11_3: # %for.end
 ; CHECK-NEXT:    ret
 entry: