diff --git a/llvm/lib/Target/RISCV/RISCVInsertWriteVXRM.cpp b/llvm/lib/Target/RISCV/RISCVInsertWriteVXRM.cpp
--- a/llvm/lib/Target/RISCV/RISCVInsertWriteVXRM.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInsertWriteVXRM.cpp
@@ -192,6 +192,7 @@
   bool computeVXRMChanges(const MachineBasicBlock &MBB, VXRMInfo &CurInfo);
   void computeIncomingVXRM(const MachineBasicBlock &MBB);
   void emitWriteVXRM(MachineBasicBlock &MBB);
+  void doPRE(MachineBasicBlock &MBB);
 };
 
 } // end anonymous namespace
@@ -213,6 +214,12 @@
       unsigned NewVXRMImm = MI.getOperand(VXRMIdx).getImm() & 7;
       NeedVXRMChange = true;
       CurInfo.setVXRMImm(NewVXRMImm);
+      continue;
+    }
+
+    if (MI.getOpcode() == RISCV::WriteVXRMImm) {
+      CurInfo.setVXRMImm(MI.getOperand(0).getImm());
+      continue;
     }
 
     if (MI.isCall() || MI.isInlineAsm() || MI.modifiesRegister(RISCV::VXRM))
@@ -282,6 +289,12 @@
       }
 
       CurInfo.setVXRMImm(NewVXRMImm);
+      continue;
+    }
+
+    if (MI.getOpcode() == RISCV::WriteVXRMImm) {
+      CurInfo.setVXRMImm(MI.getOperand(0).getImm());
+      continue;
     }
 
     if (MI.isCall() || MI.isInlineAsm() || MI.modifiesRegister(RISCV::VXRM))
@@ -292,6 +305,76 @@
     report_fatal_error("Mismatched VXRM state");
 }
 
+/// Perform simple partial redundancy elimination of the VXRM write we're about
+/// to insert by looking for cases where we can PRE from the beginning of one
+/// block to the end of one of its predecessors. Specifically, this is geared
+/// to catch the common case of a single rounding mode value used in a
+/// single-block loop, where we could do the write in the preheader instead.
+void RISCVInsertWriteVXRM::doPRE(MachineBasicBlock &MBB) {
+  if (!BlockInfo[MBB.getNumber()].Pred.isUnknown())
+    return;
+
+  MachineBasicBlock *UnavailablePred = nullptr;
+  VXRMInfo AvailableInfo;
+  for (MachineBasicBlock *P : MBB.predecessors()) {
+    const VXRMInfo &PredInfo = BlockInfo[P->getNumber()].Exit;
+    if (PredInfo.isUnknown()) {
+      if (UnavailablePred)
+        return;
+      UnavailablePred = P;
+    } else if (!AvailableInfo.isValid()) {
+      AvailableInfo = PredInfo;
+    } else if (AvailableInfo != PredInfo) {
+      return;
+    }
+  }
+
+  // Unreachable, single pred, or full redundancy. Note that FRE is handled by
+  // phase 3.
+  if (!UnavailablePred || !AvailableInfo.isStatic())
+    return;
+
+  // Critical edge - TODO: consider splitting?
+  if (UnavailablePred->succ_size() != 1)
+    return;
+
+  // Make sure there's a VXRM write needed at the start of this block.
+  bool FoundUser = false;
+  for (const MachineInstr &MI : MBB) {
+    int VXRMIdx = RISCVII::getVXRMOpNum(MI.getDesc());
+    if (VXRMIdx >= 0) {
+      // Only the first user matters: if it needs another mode, phase 3 writes
+      // that mode here anyway and the PRE'd write would be overwritten unused.
+      unsigned NewVXRMImm = MI.getOperand(VXRMIdx).getImm() & 7;
+      FoundUser = NewVXRMImm == AvailableInfo.getVXRMImm();
+      break;
+    }
+
+    if (MI.isCall() || MI.isInlineAsm() || MI.modifiesRegister(RISCV::VXRM))
+      return;
+  }
+
+  // If we didn't find a user, there's no need to do PRE.
+  if (!FoundUser)
+    return;
+
+  // Finally, update both data flow state and insert the actual VXRM write.
+  // Doing both keeps the code in sync with the dataflow results, which
+  // is critical for correctness of phase 3.
+  LLVM_DEBUG(dbgs() << "PRE VXRM from " << MBB.getName() << " to "
+                    << UnavailablePred->getName() << " with state "
+                    << AvailableInfo << "\n");
+  BlockInfo[UnavailablePred->getNumber()].Exit = AvailableInfo;
+  BlockInfo[MBB.getNumber()].Pred = AvailableInfo;
+
+  // Note there's an implicit assumption here that terminators never use
+  // or modify VXRM.  Also, fallthrough will return end().
+  auto InsertPt = UnavailablePred->getFirstInstrTerminator();
+  BuildMI(*UnavailablePred, InsertPt, UnavailablePred->findDebugLoc(InsertPt),
+          TII->get(RISCV::WriteVXRMImm))
+      .addImm(AvailableInfo.getVXRMImm());
+}
+
 bool RISCVInsertWriteVXRM::runOnMachineFunction(MachineFunction &MF) {
   // Skip if the vector extension is not enabled.
   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
@@ -329,6 +412,10 @@
     computeIncomingVXRM(MBB);
   }
 
+  // Perform partial redundancy elimination of VXRM writes.
+  for (MachineBasicBlock &MBB : MF)
+    doPRE(MBB);
+
   // Phase 3 - add any VXRM writes needed.
   for (MachineBasicBlock &MBB : MF)
     emitWriteVXRM(MBB);
diff --git a/llvm/test/CodeGen/RISCV/rvv/vxrm-insert.ll b/llvm/test/CodeGen/RISCV/rvv/vxrm-insert.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vxrm-insert.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vxrm-insert.ll
@@ -392,19 +392,20 @@
 define void @test10(i8* nocapture %ptr_dest, i8* nocapture readonly %ptr_op1, i8* nocapture readonly %ptr_op2, iXLen %n) {
 ; CHECK-LABEL: test10:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    beqz a3, .LBB9_2
-; CHECK-NEXT:  .LBB9_1: # %for.body
+; CHECK-NEXT:    beqz a3, .LBB9_3
+; CHECK-NEXT:  # %bb.1: # %for.body.preheader
+; CHECK-NEXT:    csrwi vxrm, 2
+; CHECK-NEXT:  .LBB9_2: # %for.body
 ; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
 ; CHECK-NEXT:    vsetvli a4, a3, e8, m4, ta, ma
 ; CHECK-NEXT:    vsetvli zero, a4, e8, mf8, ta, ma
 ; CHECK-NEXT:    vle8.v v8, (a1)
 ; CHECK-NEXT:    vle8.v v9, (a2)
-; CHECK-NEXT:    csrwi vxrm, 2
 ; CHECK-NEXT:    vaadd.vv v8, v8, v9
 ; CHECK-NEXT:    sub a3, a3, a4
 ; CHECK-NEXT:    vse8.v v8, (a0)
-; CHECK-NEXT:    bnez a3, .LBB9_1
-; CHECK-NEXT:  .LBB9_2: # %for.end
+; CHECK-NEXT:    bnez a3, .LBB9_2
+; CHECK-NEXT:  .LBB9_3: # %for.end
 ; CHECK-NEXT:    ret
 entry:
   %tobool.not9 = icmp eq iXLen %n, 0