diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
--- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
@@ -435,6 +435,7 @@
   void insertVSETVLI(MachineBasicBlock &MBB, MachineInstr &MI,
                      const VSETVLIInfo &Info, const VSETVLIInfo &PrevInfo);
 
+  void forwardPropagateVLMAX(MachineBasicBlock &MBB);
   bool computeVLVTYPEChanges(const MachineBasicBlock &MBB);
   void computeIncomingVLVTYPE(const MachineBasicBlock &MBB);
   void emitVSETVLIs(MachineBasicBlock &MBB);
@@ -887,6 +888,50 @@
   return CurInfo.isCompatibleWithLoadStoreEEW(EEW, Require);
 }
 
+// Forward propagate the VL result of VLMAX vsetvlis to same-VLMAX users.
+void RISCVInsertVSETVLI::forwardPropagateVLMAX(MachineBasicBlock &MBB) {
+  for (const MachineInstr &MI : MBB) {
+    // Look for VLMAX vsetvlis (AVL of X0) that define a virtual VL register.
+    if (MI.getOpcode() != RISCV::PseudoVSETVLIX0 ||
+        MI.getOperand(0).getReg() == RISCV::X0)
+      continue;
+
+    VSETVLIInfo VI = getInfoForVSETVLI(MI);
+    const MachineOperand &DestVLOp = MI.getOperand(0);
+    Register DestVLReg = DestVLOp.getReg();
+    // Walk through all uses of this VL.
+    for (MachineRegisterInfo::use_nodbg_iterator
+             UI = MRI->use_nodbg_begin(DestVLReg),
+             UIEnd = MRI->use_nodbg_end();
+         UI != UIEnd;) {
+      MachineOperand &Use = *UI++;
+      assert(Use.getParent() != nullptr && "Expected parent instruction");
+      const MachineInstr &UseMI = *Use.getParent();
+      const unsigned UseIndex = UseMI.getOperandNo(&Use);
+
+      bool Propagate = false;
+      uint64_t TSFlags = UseMI.getDesc().TSFlags;
+      if (RISCVII::hasSEWOp(TSFlags) && RISCVII::hasVLOp(TSFlags)) {
+        unsigned NumOperands = UseMI.getNumExplicitOperands();
+        if (RISCVII::hasVecPolicyOp(TSFlags))
+          --NumOperands;
+
+        // The VL operand is second-to-last; the SEW operand is last.
+        if (UseIndex == (NumOperands - 2)) {
+          VSETVLIInfo UseInfo = computeInfoForInstr(UseMI, TSFlags, MRI);
+          Propagate = UseInfo.hasSameVLMAX(VI);
+        }
+      }
+
+      if (Propagate)
+        Use.setReg(RISCV::X0);
+
+      // TODO: Should we update the dead flag or remove the instruction if we
+      // propagated to all users?
+    }
+  }
+}
+
 bool RISCVInsertVSETVLI::computeVLVTYPEChanges(const MachineBasicBlock &MBB) {
   bool HadVectorOp = false;
 
@@ -1143,6 +1188,10 @@
   assert(BlockInfo.empty() && "Expect empty block infos");
   BlockInfo.resize(MF.getNumBlockIDs());
 
+  // Phase 0 - propagate the AVL of VLMAX vsetvlis to same-VLMAX users.
+  for (MachineBasicBlock &MBB : MF)
+    forwardPropagateVLMAX(MBB);
+
   bool HaveVectorOp = false;
 
   // Phase 1 - determine how VL/VTYPE are affected by each block.
diff --git a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll
@@ -582,3 +582,57 @@
   ret <vscale x 2 x i32> %h
 }
 declare <vscale x 2 x i32> @llvm.riscv.vwadd.w.nxv2i32.nxv2i16(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i16>, i64)
+
+; We should only need one vsetvli for this code.
+define void @vlmax(i64 %N, double* %c, double* %a, double* %b) {
+; CHECK-LABEL: vlmax:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a6, zero, e64, m1, ta, mu
+; CHECK-NEXT:    blez a0, .LBB11_3
+; CHECK-NEXT:  # %bb.1: # %for.body.preheader
+; CHECK-NEXT:    li a5, 0
+; CHECK-NEXT:    li t0, 0
+; CHECK-NEXT:    slli a7, a6, 3
+; CHECK-NEXT:  .LBB11_2: # %for.body
+; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    add a4, a2, a5
+; CHECK-NEXT:    vle64.v v8, (a4)
+; CHECK-NEXT:    add a4, a3, a5
+; CHECK-NEXT:    vle64.v v9, (a4)
+; CHECK-NEXT:    vfadd.vv v8, v8, v9
+; CHECK-NEXT:    add a4, a1, a5
+; CHECK-NEXT:    vse64.v v8, (a4)
+; CHECK-NEXT:    add t0, t0, a6
+; CHECK-NEXT:    add a5, a5, a7
+; CHECK-NEXT:    blt t0, a0, .LBB11_2
+; CHECK-NEXT:  .LBB11_3: # %for.end
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call i64 @llvm.riscv.vsetvlimax.i64(i64 3, i64 0)
+  %cmp13 = icmp sgt i64 %N, 0
+  br i1 %cmp13, label %for.body, label %for.end
+
+for.body:                                         ; preds = %entry, %for.body
+  %i.014 = phi i64 [ %add, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds double, double* %a, i64 %i.014
+  %1 = bitcast double* %arrayidx to <vscale x 1 x double>*
+  %2 = tail call <vscale x 1 x double> @llvm.riscv.vle.nxv1f64.i64(<vscale x 1 x double> undef, <vscale x 1 x double>* %1, i64 %0)
+  %arrayidx1 = getelementptr inbounds double, double* %b, i64 %i.014
+  %3 = bitcast double* %arrayidx1 to <vscale x 1 x double>*
+  %4 = tail call <vscale x 1 x double> @llvm.riscv.vle.nxv1f64.i64(<vscale x 1 x double> undef, <vscale x 1 x double>* %3, i64 %0)
+  %5 = tail call <vscale x 1 x double> @llvm.riscv.vfadd.nxv1f64.nxv1f64.i64(<vscale x 1 x double> undef, <vscale x 1 x double> %2, <vscale x 1 x double> %4, i64 %0)
+  %arrayidx2 = getelementptr inbounds double, double* %c, i64 %i.014
+  %6 = bitcast double* %arrayidx2 to <vscale x 1 x double>*
+  tail call void @llvm.riscv.vse.nxv1f64.i64(<vscale x 1 x double> %5, <vscale x 1 x double>* %6, i64 %0)
+  %add = add nuw nsw i64 %i.014, %0
+  %cmp = icmp slt i64 %add, %N
+  br i1 %cmp, label %for.body, label %for.end
+
+for.end:                                          ; preds = %for.body, %entry
+  ret void
+}
+
+declare i64 @llvm.riscv.vsetvlimax.i64(i64, i64)
+declare <vscale x 1 x double> @llvm.riscv.vle.nxv1f64.i64(<vscale x 1 x double>, <vscale x 1 x double>* nocapture, i64)
+declare <vscale x 1 x double> @llvm.riscv.vfadd.nxv1f64.nxv1f64.i64(<vscale x 1 x double>, <vscale x 1 x double>, <vscale x 1 x double>, i64)
+declare void @llvm.riscv.vse.nxv1f64.i64(<vscale x 1 x double>, <vscale x 1 x double>* nocapture, i64)
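
Reviewer note, not part of the patch: the self-contained C++ sketch below is a toy model of the rewrite forwardPropagateVLMAX performs, under the assumption that only the core rule matters: a use of the vsetvli's VL result whose own (SEW, LMUL) pair yields the same VLMAX can take X0 (the "AVL = VLMAX" encoding) in place of the VL register. All names here (Instr, AVLReg, VLMAXKey, the opcode strings) are hypothetical stand-ins for MachineInstr, the VL operand, and the TSFlags-derived VLMAX, not LLVM APIs.

// toy_vlmax.cpp - minimal sketch; compile with: c++ -std=c++17 toy_vlmax.cpp
#include <iostream>
#include <string>
#include <vector>

namespace toy {

constexpr int X0 = 0; // Register 0: as an AVL operand it encodes VLMAX.

struct Instr {
  std::string Opcode;
  int DefReg = 0;    // Virtual register defined (0 = none).
  int AVLReg = 0;    // AVL operand read (0 = none, X0 = VLMAX).
  int VLMAXKey = 0;  // Stand-in for the (SEW, LMUL) pair that fixes VLMAX.
};

// For each VLMAX-producing vsetvli, rewrite the AVL operand of every user
// with a matching VLMAX to X0, mirroring the patch's propagation rule.
void forwardPropagateVLMAX(std::vector<Instr> &Block) {
  for (const Instr &Def : Block) {
    if (Def.Opcode != "vsetvli_x0" || Def.DefReg == X0)
      continue;
    for (Instr &Use : Block)
      if (&Use != &Def && Use.AVLReg == Def.DefReg &&
          Use.VLMAXKey == Def.VLMAXKey)
        Use.AVLReg = X0; // Same VLMAX, so "AVL = VLMAX" is equivalent.
  }
}

} // namespace toy

int main() {
  std::vector<toy::Instr> Block = {
      {"vsetvli_x0", /*DefReg=*/5, /*AVLReg=*/0, /*VLMAXKey=*/1},
      {"vle64", 6, /*AVLReg=*/5, /*VLMAXKey=*/1},  // Rewritten to X0.
      {"vle32", 7, /*AVLReg=*/5, /*VLMAXKey=*/2},  // Different VLMAX: kept.
  };
  toy::forwardPropagateVLMAX(Block);
  for (const toy::Instr &I : Block)
    std::cout << I.Opcode << " avl=" << I.AVLReg << "\n";
}

Because the rewritten users now carry an X0 AVL, the later phases see them as VLMAX-compatible with the defining vsetvli, which is what lets the loop body in the @vlmax test above get by with the single vsetvli in the entry block.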