diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp --- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp +++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp @@ -1160,13 +1160,15 @@ // with current VL/VTYPE. bool NeedInsertVSETVLI = true; if (PrevVSETVLIMI) { - bool HasSameAVL = - CurInfo.hasSameAVL(NewInfo) || - (NewInfo.hasAVLReg() && NewInfo.getAVLReg().isVirtual() && - NewInfo.getAVLReg() == PrevVSETVLIMI->getOperand(0).getReg()); // If these two VSETVLI have the same AVL and the same VLMAX, // we could merge these two VSETVLI. - if (HasSameAVL && CurInfo.hasSameVLMAX(NewInfo)) { + // TODO: If we remove this, we get a `vsetvli x0, x0, vtype' + // here. We could simply let this be emitted, then remove + // the unused vsetvlis in a post-pass. + if (CurInfo.hasSameAVL(NewInfo) && CurInfo.hasSameVLMAX(NewInfo)) { + // WARNING: For correctness, it is essential that the contents of VL + // and VTYPE stay the same after MI. This greatly limits the + // mutation we can legally do here. PrevVSETVLIMI->getOperand(2).setImm(NewInfo.encodeVTYPE()); NeedInsertVSETVLI = false; } @@ -1248,6 +1250,36 @@ } if (RISCVII::hasSEWOp(TSFlags)) { + if (RISCVII::hasVLOp(TSFlags)) { + const auto Require = computeInfoForInstr(MI, TSFlags, MRI); + // If the AVL is the result of a previous vsetvli which has the + // same AVL and VLMAX as our current state, we can reuse the AVL + // from the current state for the new one. This allows us to + // generate 'vsetvli x0, x0, vtype' or possibly skip the transition + // entirely.
+ if (!CurInfo.isUnknown() && Require.hasAVLReg() && + Require.getAVLReg().isVirtual()) { + if (MachineInstr *DefMI = MRI->getVRegDef(Require.getAVLReg())) { + if (isVectorConfigInstr(*DefMI)) { + VSETVLIInfo DefInfo = getInfoForVSETVLI(*DefMI); + if (DefInfo.hasSameAVL(CurInfo) && + DefInfo.hasSameVLMAX(CurInfo)) { + MachineOperand &VLOp = MI.getOperand(getVLOpNum(MI)); + if (CurInfo.hasAVLImm()) + VLOp.ChangeToImmediate(CurInfo.getAVLImm()); + else { + // Reusing CurInfo's AVL register extends its live range down to MI, + // so a kill flag on an earlier use of it would now be stale. + MRI->clearKillFlags(CurInfo.getAVLReg()); + VLOp.ChangeToRegister(CurInfo.getAVLReg(), /*IsDef*/ false); + } + CurInfo = computeInfoForInstr(MI, TSFlags, MRI); + continue; + } + } + } + } + } CurInfo = computeInfoForInstr(MI, TSFlags, MRI); continue; } diff --git a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll --- a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll @@ -281,10 +281,9 @@ define <vscale x 1 x double> @test15(i64 %avl, <vscale x 1 x double> %a, <vscale x 1 x double> %b) nounwind { ; CHECK-LABEL: test15: ; CHECK: # %bb.0: # %entry -; CHECK-NEXT: vsetvli a0, a0, e64, m1, ta, mu +; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, mu ; CHECK-NEXT: vfadd.vv v8, v8, v9 ; CHECK-NEXT: vfadd.vv v8, v8, v9 -; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, mu ; CHECK-NEXT: ret entry: %vsetvli = tail call i64 @llvm.riscv.vsetvli(i64 %avl, i64 2, i64 7) @@ -354,12 +353,12 @@ define <vscale x 1 x double> @test18(<vscale x 1 x double> %a, double %b) nounwind { ; CHECK-LABEL: test18: ; CHECK: # %bb.0: # %entry -; CHECK-NEXT: vsetivli a0, 6, e64, m1, tu, mu +; CHECK-NEXT: vsetivli zero, 6, e64, m1, tu, mu ; CHECK-NEXT: vmv1r.v v9, v8 ; CHECK-NEXT: vfmv.s.f v9, fa0 -; CHECK-NEXT: vsetvli zero, a0, e64, m1, ta, mu +; CHECK-NEXT: vsetvli zero, zero, e64, m1, ta, mu ; CHECK-NEXT: vfadd.vv v8, v8, v8 -; CHECK-NEXT: vsetivli zero, 1, e64, m1, tu, mu +; CHECK-NEXT: vsetvli zero, zero, e64, m1, tu, mu ; CHECK-NEXT: vfmv.s.f v8, fa0 ; CHECK-NEXT: vsetvli a0, zero, e64, m1, ta, mu ; CHECK-NEXT: vfadd.vv v8, v9, v8