diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
--- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
@@ -28,6 +28,7 @@
 #include "llvm/CodeGen/LiveIntervals.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include <queue>
+#include <unordered_map>
 using namespace llvm;

 #define DEBUG_TYPE "riscv-insert-vsetvli"
@@ -49,9 +50,12 @@
     Uninitialized,
     AVLIsReg,
     AVLIsImm,
+    VLModified,
     Unknown,
   } State = Uninitialized;

+  const MachineInstr *VLModifier;
+
   // Fields from VTYPE.
   RISCVII::VLMUL VLMul = RISCVII::LMUL_1;
   uint8_t SEW = 0;
@@ -76,6 +80,11 @@
   bool isValid() const { return State != Uninitialized; }
   void setUnknown() { State = Unknown; }
   bool isUnknown() const { return State == Unknown; }
+  void setVLModified(const MachineInstr *MI) {
+    State = VLModified;
+    VLModifier = MI;
+  }
+  bool isVLModified() const { return State == VLModified; }

   void setAVLReg(Register Reg) {
     AVLReg = Reg;
@@ -110,6 +119,8 @@
     return false;
   }

+  const MachineInstr *getVLModifier() const { return VLModifier; }
+
   bool hasSameAVL(const VSETVLIInfo &Other) const {
     assert(isValid() && Other.isValid() &&
            "Can't compare invalid VSETVLIInfos");
@@ -243,6 +254,10 @@
     if (isUnknown() || InstrInfo.isUnknown())
       return false;

+    // Nothing is compatible with the VLModified state.
+    if (isVLModified() || InstrInfo.isVLModified())
+      return false;
+
     // If only our VLMAX ratio is valid, then this isn't compatible.
     if (SEWLMULRatioOnly)
       return false;
@@ -349,6 +364,10 @@
     if (!isValid())
       return Other;

+    // If either is VL-modified, the result is unknown.
+    if (isVLModified() || Other.isVLModified())
+      return VSETVLIInfo::getUnknown();
+
     // If either is unknown, the result is unknown.
     if (isUnknown() || Other.isUnknown())
       return VSETVLIInfo::getUnknown();
@@ -413,6 +432,7 @@

   std::vector<BlockData> BlockInfo;
   std::queue<const MachineBasicBlock *> WorkList;
+  std::unordered_map<unsigned, const MachineInstr *> RegToVLModifier;

 public:
   static char ID;
@@ -438,6 +458,7 @@
   bool computeVLVTYPEChanges(const MachineBasicBlock &MBB);
   void computeIncomingVLVTYPE(const MachineBasicBlock &MBB);
   void emitVSETVLIs(MachineBasicBlock &MBB);
+  void collectVLCopy(const MachineBasicBlock &MBB);
 };

 } // end anonymous namespace
@@ -685,6 +706,10 @@
       if (DefInfo.hasSameAVL(CurInfo) && DefInfo.hasSameVTYPE(CurInfo))
         return false;
     }
+
+    if (CurInfo.isVLModified() &&
+        RegToVLModifier[Require.getAVLReg()] == CurInfo.getVLModifier())
+      return false;
   }
 }
@@ -920,10 +945,13 @@
           BBInfo.Change = NewInfo;
       }
     }
+    if (MI.modifiesRegister(RISCV::VL))
+      BBInfo.Change.setVLModified(&MI);

     // If this is something that updates VL/VTYPE that we don't know about, set
     // the state to unknown.
-    if (MI.isCall() || MI.isInlineAsm() || MI.modifiesRegister(RISCV::VL) ||
+    if (MI.isCall() || MI.isInlineAsm() ||
+        (!BBInfo.Change.isVLModified() && MI.modifiesRegister(RISCV::VL)) ||
         MI.modifiesRegister(RISCV::VTYPE)) {
       BBInfo.Change = VSETVLIInfo::getUnknown();
     }
@@ -1002,6 +1030,13 @@
         !PBBInfo.Exit.hasCompatibleVTYPE(Require, /*Strict*/ false))
       return true;

+    if (PBBInfo.Exit.isVLModified()) {
+      // We need InReg to be a copy of the VL result of the VL modifier.
+      if (RegToVLModifier[InReg] != PBBInfo.Exit.getVLModifier())
+        return true;
+      continue;
+    }
+
     // We need the PHI input to the be the output of a VSET(I)VLI.
     MachineInstr *DefMI = MRI->getVRegDef(InReg);
     if (!DefMI || !isVectorConfigInstr(*DefMI))
@@ -1105,14 +1140,25 @@
           if (NeedInsertVSETVLI)
             insertVSETVLI(MBB, MI, NewInfo, CurInfo);
           CurInfo = NewInfo;
+        } else if (CurInfo.isVLModified()) {
+          // No VSET(I)VLI is needed; clear the VL-modified state so one is
+          // not inserted later just because CurInfo differs from MBB's exit.
+          CurInfo = NewInfo;
         }
       }
+
+      if (MI.modifiesRegister(RISCV::VL)) {
+        if (!CurInfo.isValid())
+          CurInfo = NewInfo;
+        CurInfo.setVLModified(&MI);
+      }
       PrevVSETVLIMI = nullptr;
     }

     // If this is something updates VL/VTYPE that we don't know about, set
     // the state to unknown.
-    if (MI.isCall() || MI.isInlineAsm() || MI.modifiesRegister(RISCV::VL) ||
+    if (MI.isCall() || MI.isInlineAsm() ||
+        (MI.modifiesRegister(RISCV::VL) && !CurInfo.isVLModified()) ||
         MI.modifiesRegister(RISCV::VTYPE)) {
       CurInfo = VSETVLIInfo::getUnknown();
       PrevVSETVLIMI = nullptr;
@@ -1123,7 +1169,7 @@
     if (MI.isTerminator()) {
       const VSETVLIInfo &ExitInfo = BlockInfo[MBB.getNumber()].Exit;
       if (CurInfo.isValid() && ExitInfo.isValid() && !ExitInfo.isUnknown() &&
-          CurInfo != ExitInfo) {
+          !ExitInfo.isVLModified() && CurInfo != ExitInfo) {
         insertVSETVLI(MBB, MI, ExitInfo, CurInfo);
         CurInfo = ExitInfo;
       }
@@ -1131,6 +1177,35 @@
   }
 }

+// For every VL-modifying vector operation other than VSET(I)VLI, try to
+// record the virtual register that holds the copy of its resulting VL.
+void RISCVInsertVSETVLI::collectVLCopy(const MachineBasicBlock &MBB) {
+  const MachineInstr *VLModifier = nullptr;
+
+  for (auto &MI : MBB) {
+    if (VLModifier && MI.getOpcode() == RISCV::PseudoReadVL) {
+      Register ResultedVL = MI.getOperand(0).getReg();
+      // Only handle virtual registers to avoid aliasing problems.
+      if (Register::isVirtualRegister(ResultedVL))
+        RegToVLModifier[ResultedVL] = VLModifier;
+      continue;
+    }
+
+    uint64_t TSFlags = MI.getDesc().TSFlags;
+    if (RISCVII::hasSEWOp(TSFlags)) {
+      if (MI.modifiesRegister(RISCV::VL))
+        VLModifier = &MI;
+      else
+        VLModifier = nullptr;
+      continue;
+    }
+
+    if (MI.isCall() || MI.isInlineAsm() || MI.modifiesRegister(RISCV::VTYPE) ||
+        MI.modifiesRegister(RISCV::VL))
+      VLModifier = nullptr;
+  }
+}
+
 bool RISCVInsertVSETVLI::runOnMachineFunction(MachineFunction &MF) {
   // Skip if the vector extension is not enabled.
   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
@@ -1164,6 +1239,9 @@
     computeIncomingVLVTYPE(MBB);
   }

+  for (const MachineBasicBlock &MBB : MF)
+    collectVLCopy(MBB);
+
   // Phase 3 - add any vsetvli instructions needed in the block. Use the
   // Phase 2 information to avoid adding vsetvlis before the first vector
   // instruction in the block if the VL/VTYPE is satisfied by its
diff --git a/llvm/test/CodeGen/RISCV/rvv/vsetvli-modify-vl.ll b/llvm/test/CodeGen/RISCV/rvv/vsetvli-modify-vl.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/vsetvli-modify-vl.ll
@@ -0,0 +1,84 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv64 -mattr=+v \
+; RUN:   -target-abi=lp64d -verify-machineinstrs < %s | FileCheck %s
+
+declare i64 @llvm.riscv.vsetvli.i64(i64, i64 immarg, i64 immarg)
+declare { <vscale x 32 x i8>, i64 } @llvm.riscv.vleff.nxv32i8.i64(<vscale x 32 x i8>, <vscale x 32 x i8>* nocapture, i64)
+declare <vscale x 32 x i1> @llvm.riscv.vmseq.nxv32i8.i8.i64(<vscale x 32 x i8>, i8, i64)
+declare <vscale x 32 x i8> @llvm.riscv.vadd.nxv32i8.i8.i64(<vscale x 32 x i8>, <vscale x 32 x i8>, i8, i64)
+declare <vscale x 16 x i16> @llvm.riscv.vadd.nxv16i16.i16.i64(<vscale x 16 x i16>, <vscale x 16 x i16>, i16, i64)
+
+define <vscale x 32 x i1> @seq1(i1 zeroext %cond, i8* %str, i64 %n, i8 %x) {
+; CHECK-LABEL: seq1:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a0, a2, e8, m4, ta, mu
+; CHECK-NEXT:    vle8ff.v v8, (a1)
+; CHECK-NEXT:    vadd.vx v8, v8, a3
+; CHECK-NEXT:    vmseq.vi v0, v8, 0
+; CHECK-NEXT:    csrr a0, vl
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call i64 @llvm.riscv.vsetvli.i64(i64 %n, i64 0, i64 2)
+  %1 = bitcast i8* %str to <vscale x 32 x i8>*
+  %2 = tail call { <vscale x 32 x i8>, i64 } @llvm.riscv.vleff.nxv32i8.i64(<vscale x 32 x i8> undef, <vscale x 32 x i8>* %1, i64 %0)
+  %3 = extractvalue { <vscale x 32 x i8>, i64 } %2, 0
+  %4 = extractvalue { <vscale x 32 x i8>, i64 } %2, 1
+  %5 = tail call <vscale x 32 x i8> @llvm.riscv.vadd.nxv32i8.i8.i64(<vscale x 32 x i8> undef, <vscale x 32 x i8> %3, i8 %x, i64 %4)
+  %6 = tail call <vscale x 32 x i1> @llvm.riscv.vmseq.nxv32i8.i8.i64(<vscale x 32 x i8> %5, i8 0, i64 %4)
+  ret <vscale x 32 x i1> %6
+}
+
+define <vscale x 32 x i8> @cross_bb(i1 zeroext %cond, i8 zeroext %x, <vscale x 32 x i8> %vv, i8* %str, i64 %n) {
+; CHECK-LABEL: cross_bb:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a4, a3, e8, m4, ta, mu
+; CHECK-NEXT:    beqz a0, .LBB1_2
+; CHECK-NEXT:  # %bb.1: # %if.then
+; CHECK-NEXT:    vle8ff.v v12, (a2)
+; CHECK-NEXT:    j .LBB1_3
+; CHECK-NEXT:  .LBB1_2: # %if.else
+; CHECK-NEXT:    vsetvli a0, a3, e8, m4, ta, mu
+; CHECK-NEXT:  .LBB1_3: # %if.end
+; CHECK-NEXT:    vadd.vx v8, v8, a1
+; CHECK-NEXT:    vadd.vx v8, v8, a1
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call i64 @llvm.riscv.vsetvli.i64(i64 %n, i64 0, i64 2)
+  br i1 %cond, label %if.then, label %if.else
+
+if.then:                                          ; preds = %entry
+  %1 = bitcast i8* %str to <vscale x 32 x i8>*
+  %2 = tail call { <vscale x 32 x i8>, i64 } @llvm.riscv.vleff.nxv32i8.i64(<vscale x 32 x i8> undef, <vscale x 32 x i8>* %1, i64 %0)
+  %3 = extractvalue { <vscale x 32 x i8>, i64 } %2, 1
+  br label %if.end
+
+if.else:                                          ; preds = %entry
+  %4 = tail call i64 @llvm.riscv.vsetvli.i64(i64 %n, i64 0, i64 2)
+  br label %if.end
+
+if.end:                                           ; preds = %if.else, %if.then
+  %new_vl.0 = phi i64 [ %3, %if.then ], [ %4, %if.else ]
+  %5 = tail call <vscale x 32 x i8> @llvm.riscv.vadd.nxv32i8.i8.i64(<vscale x 32 x i8> undef, <vscale x 32 x i8> %vv, i8 %x, i64 %new_vl.0)
+  %6 = tail call <vscale x 32 x i8> @llvm.riscv.vadd.nxv32i8.i8.i64(<vscale x 32 x i8> undef, <vscale x 32 x i8> %5, i8 %x, i64 %new_vl.0)
+  ret <vscale x 32 x i8> %6
+}
+
+; Test that a useful vsetvli is not eliminated.
+define <vscale x 16 x i16> @no_work(i1 zeroext %cond, i8* %str, i64 %n, <vscale x 16 x i16> %v, i16 %x) {
+; CHECK-LABEL: no_work:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    vsetvli a0, a2, e8, m4, ta, mu
+; CHECK-NEXT:    vle8ff.v v12, (a1)
+; CHECK-NEXT:    csrr a0, vl
+; CHECK-NEXT:    vsetvli zero, a0, e16, m4, ta, mu
+; CHECK-NEXT:    vadd.vx v8, v8, a3
+; CHECK-NEXT:    ret
+entry:
+  %0 = tail call i64 @llvm.riscv.vsetvli.i64(i64 %n, i64 0, i64 2)
+  %1 = bitcast i8* %str to <vscale x 32 x i8>*
+  %2 = tail call { <vscale x 32 x i8>, i64 } @llvm.riscv.vleff.nxv32i8.i64(<vscale x 32 x i8> undef, <vscale x 32 x i8>* %1, i64 %0)
+  %3 = extractvalue { <vscale x 32 x i8>, i64 } %2, 1
+  %4 = tail call <vscale x 16 x i16> @llvm.riscv.vadd.nxv16i16.i16.i64(<vscale x 16 x i16> undef, <vscale x 16 x i16> %v, i16 %x, i64 %3)
+  ret <vscale x 16 x i16> %4
+}
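
For reference, the seq1 test above pins down exactly the codegen this change is after. Without the VLModified state, the pass only knows that vle8ff.v clobbers VL, treats the state as unknown, and re-installs the VL it just read back from the CSR before the next vector operation. A rough, hand-written sketch of that pre-patch output for seq1 (illustrative only, not captured from update_llc_test_checks.py):

    vsetvli  a0, a2, e8, m4, ta, mu
    vle8ff.v v8, (a1)
    csrr     a0, vl
    vsetvli  zero, a0, e8, m4, ta, mu   # redundant: VL/VTYPE already match
    vadd.vx  v8, v8, a3
    vmseq.vi v0, v8, 0
    ret

With this patch, collectVLCopy records that the csrr result is a copy of the VL produced by the vle8ff.v, so the pass recognizes the requested AVL as the VL already in effect and drops the second vsetvli, as the CHECK lines in seq1 verify.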