diff --git a/llvm/lib/Target/RISCV/RISCVRedundantCopyElimination.cpp b/llvm/lib/Target/RISCV/RISCVRedundantCopyElimination.cpp
--- a/llvm/lib/Target/RISCV/RISCVRedundantCopyElimination.cpp
+++ b/llvm/lib/Target/RISCV/RISCVRedundantCopyElimination.cpp
@@ -24,6 +24,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "RISCV.h"
+#include "RISCVInstrInfo.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
@@ -39,6 +40,7 @@
 class RISCVRedundantCopyElimination : public MachineFunctionPass {
   const MachineRegisterInfo *MRI;
   const TargetRegisterInfo *TRI;
+  const TargetInstrInfo *TII;
 
 public:
   static char ID;
@@ -68,16 +70,20 @@
 INITIALIZE_PASS(RISCVRedundantCopyElimination, "riscv-copyelim",
                 "RISCV redundant copy elimination pass", false, false)
 
-static bool guaranteesZeroRegInBlock(const MachineInstr &MI,
-                                     const MachineBasicBlock &MBB) {
-  unsigned Opc = MI.getOpcode();
-  if (Opc == RISCV::BEQ && MI.getOperand(1).getReg() == RISCV::X0 &&
-      &MBB == MI.getOperand(2).getMBB())
+// Return true if the conditional branch described by Cond and TBB (as filled
+// in by analyzeBranch, expected to be {CC, LHS, RHS}) compares LHS against X0
+// such that LHS must be zero whenever control reaches MBB through it.
+static bool
+guaranteesZeroRegInBlock(MachineBasicBlock &MBB,
+                         const SmallVectorImpl<MachineOperand> &Cond,
+                         MachineBasicBlock *TBB) {
+  assert(Cond.size() == 3 && "Unexpected number of operands");
+  assert(TBB != nullptr && "Expected branch target basic block");
+  auto CC = static_cast<RISCVCC::CondCode>(Cond[0].getImm());
+  if (CC == RISCVCC::COND_EQ && Cond[2].getReg() == RISCV::X0 && TBB == &MBB)
     return true;
-  if (Opc == RISCV::BNE && MI.getOperand(1).getReg() == RISCV::X0 &&
-      &MBB != MI.getOperand(2).getMBB())
+  if (CC == RISCVCC::COND_NE && Cond[2].getReg() == RISCV::X0 && TBB != &MBB)
     return true;
-
   return false;
 }
 
@@ -92,24 +98,17 @@
   if (PredMBB->succ_size() != 2)
     return false;
 
-  MachineBasicBlock::iterator CondBr = PredMBB->getLastNonDebugInstr();
-  if (CondBr == PredMBB->end())
+  MachineBasicBlock *TBB = nullptr, *FBB = nullptr;
+  SmallVector<MachineOperand, 3> Cond;
+  if (TII->analyzeBranch(*PredMBB, TBB, FBB, Cond, /*AllowModify*/ false) ||
+      Cond.empty())
     return false;
 
-  while (true) {
-    // If we run out of terminators, give up.
-    if (!CondBr->isTerminator())
-      return false;
-    // If we found a branch with X0, stop searching and try to remove copies.
-    if (guaranteesZeroRegInBlock(*CondBr, MBB))
-      break;
-    // If we reached the beginning of the basic block, give up.
-    if (CondBr == PredMBB->begin())
-      return false;
-    --CondBr;
-  }
+  // Is this a branch with X0?
+  if (!guaranteesZeroRegInBlock(MBB, Cond, TBB))
+    return false;
 
-  Register TargetReg = CondBr->getOperand(0).getReg();
+  Register TargetReg = Cond[1].getReg();
   if (!TargetReg)
     return false;
 
@@ -144,6 +143,12 @@
   if (!Changed)
     return false;
 
+  MachineBasicBlock::iterator CondBr = PredMBB->getFirstTerminator();
+  assert((CondBr->getOpcode() == RISCV::BEQ ||
+          CondBr->getOpcode() == RISCV::BNE) &&
+         "Unexpected opcode");
+  assert(CondBr->getOperand(0).getReg() == TargetReg && "Unexpected register");
+
   // Otherwise, we have to fixup the use-def chain, starting with the
   // BEQ/BNE. Conservatively mark as much as we can live.
   CondBr->clearRegisterKills(TargetReg, TRI);
@@ -163,6 +168,7 @@
   if (skipFunction(MF.getFunction()))
     return false;
 
+  TII = MF.getSubtarget().getInstrInfo();
   TRI = MF.getSubtarget().getRegisterInfo();
   MRI = &MF.getRegInfo();
 