diff --git a/llvm/lib/Target/RISCV/RISCVMergeBaseOffset.cpp b/llvm/lib/Target/RISCV/RISCVMergeBaseOffset.cpp --- a/llvm/lib/Target/RISCV/RISCVMergeBaseOffset.cpp +++ b/llvm/lib/Target/RISCV/RISCVMergeBaseOffset.cpp @@ -37,9 +37,10 @@ bool detectAndFoldOffset(MachineInstr &Hi, MachineInstr &Lo); void foldOffset(MachineInstr &Hi, MachineInstr &Lo, MachineInstr &Tail, int64_t Offset); - bool matchLargeOffset(MachineInstr &TailAdd, Register GSReg, int64_t &Offset); - bool matchShiftedOffset(MachineInstr &TailShXAdd, Register GSReg, - int64_t &Offset); + bool foldLargeOffset(MachineInstr &Hi, MachineInstr &Lo, + MachineInstr &TailAdd, Register GAReg); + bool foldShiftedOffset(MachineInstr &Hi, MachineInstr &Lo, + MachineInstr &TailShXAdd, Register GAReg); RISCVMergeBaseOffsetOpt() : MachineFunctionPass(ID) {} @@ -59,7 +60,6 @@ private: MachineRegisterInfo *MRI; - std::set<MachineInstr *> DeadInstrs; }; } // end anonymous namespace @@ -125,17 +125,19 @@ if (Hi.getOpcode() != RISCV::AUIPC) Lo.getOperand(2).setOffset(Offset); // Delete the tail instruction. - DeadInstrs.insert(&Tail); MRI->replaceRegWith(Tail.getOperand(0).getReg(), Lo.getOperand(0).getReg()); + Tail.eraseFromParent(); LLVM_DEBUG(dbgs() << " Merged offset " << Offset << " into base.\n" << " " << Hi << " " << Lo;); } // Detect patterns for large offsets that are passed into an ADD instruction. +// If the pattern is found, updates the offset in Hi and Lo instructions +// and deletes TailAdd and the instructions that produced the offset.
// // Base address lowering is of the form: -// HiLUI: lui vreg1, %hi(s) -// LoADDI: addi vreg2, vreg1, %lo(s) +// Hi: lui vreg1, %hi(s) +// Lo: addi vreg2, vreg1, %lo(s) // / \ // / \ // / \ @@ -149,9 +151,10 @@ // \ / // \ / // TailAdd: add vreg4, vreg2, voff -bool RISCVMergeBaseOffsetOpt::matchLargeOffset(MachineInstr &TailAdd, - Register GAReg, - int64_t &Offset) { +bool RISCVMergeBaseOffsetOpt::foldLargeOffset(MachineInstr &Hi, + MachineInstr &Lo, + MachineInstr &TailAdd, + Register GAReg) { assert((TailAdd.getOpcode() == RISCV::ADD) && "Expected ADD instruction!"); Register Rs = TailAdd.getOperand(1).getReg(); Register Rt = TailAdd.getOperand(2).getReg(); @@ -177,7 +180,7 @@ LuiImmOp.getTargetFlags() != RISCVII::MO_None || !MRI->hasOneUse(OffsetLui.getOperand(0).getReg())) return false; - Offset = SignExtend64<32>(LuiImmOp.getImm() << 12); + int64_t Offset = SignExtend64<32>(LuiImmOp.getImm() << 12); Offset += OffLo; // RV32 ignores the upper 32 bits. ADDIW sign extends the result. if (!ST->is64Bit() || OffsetTail.getOpcode() == RISCV::ADDIW) @@ -187,15 +190,17 @@ return false; LLVM_DEBUG(dbgs() << " Offset Instrs: " << OffsetTail << " " << OffsetLui); - DeadInstrs.insert(&OffsetTail); - DeadInstrs.insert(&OffsetLui); + foldOffset(Hi, Lo, TailAdd, Offset); + OffsetTail.eraseFromParent(); + OffsetLui.eraseFromParent(); return true; } else if (OffsetTail.getOpcode() == RISCV::LUI) { // The offset value has all zero bits in the lower 12 bits. Only LUI // exists. LLVM_DEBUG(dbgs() << " Offset Instr: " << OffsetTail); - Offset = SignExtend64<32>(OffsetTail.getOperand(1).getImm() << 12); - DeadInstrs.insert(&OffsetTail); + int64_t Offset = SignExtend64<32>(OffsetTail.getOperand(1).getImm() << 12); + foldOffset(Hi, Lo, TailAdd, Offset); + OffsetTail.eraseFromParent(); return true; } return false; @@ -205,14 +210,17 @@ // The offset has 1, 2, or 3 trailing zeros and fits in simm13, simm14, simm15.
// The constant is created with addi voff, x0, C, and shXadd is used to -// fill insert the trailing zeros and do the addition. +// insert the trailing zeros and do the addition. +// If the pattern is found, updates the offset in Hi and Lo instructions +// and deletes TailShXAdd and the instructions that produced the offset. // // Hi: lui vreg1, %hi(s) // Lo: addi vreg2, vreg1, %lo(s) // OffsetTail: addi voff, x0, C -// TailAdd: shXadd vreg4, voff, vreg2 -bool RISCVMergeBaseOffsetOpt::matchShiftedOffset(MachineInstr &TailShXAdd, - Register GAReg, - int64_t &Offset) { +// TailShXAdd: shXadd vreg4, voff, vreg2 +bool RISCVMergeBaseOffsetOpt::foldShiftedOffset(MachineInstr &Hi, + MachineInstr &Lo, + MachineInstr &TailShXAdd, + Register GAReg) { assert((TailShXAdd.getOpcode() == RISCV::SH1ADD || TailShXAdd.getOpcode() == RISCV::SH2ADD || TailShXAdd.getOpcode() == RISCV::SH3ADD) && @@ -236,7 +244,7 @@ !OffsetTail.getOperand(2).isImm()) return false; - Offset = OffsetTail.getOperand(2).getImm(); + int64_t Offset = OffsetTail.getOperand(2).getImm(); assert(isInt<12>(Offset) && "Unexpected offset"); unsigned ShAmt; @@ -250,7 +258,8 @@ Offset = (uint64_t)Offset << ShAmt; LLVM_DEBUG(dbgs() << " Offset Instr: " << OffsetTail); - DeadInstrs.insert(&OffsetTail); + foldOffset(Hi, Lo, TailShXAdd, Offset); + OffsetTail.eraseFromParent(); return true; } @@ -280,8 +289,8 @@ if (TailTail.getOpcode() == RISCV::ADDI) { Offset += TailTail.getOperand(2).getImm(); LLVM_DEBUG(dbgs() << " Offset Instrs: " << Tail << TailTail); - DeadInstrs.insert(&Tail); foldOffset(Hi, Lo, TailTail, Offset); + Tail.eraseFromParent(); return true; } } @@ -299,11 +308,7 @@ // both hi 20 and lo 12 bits. // 2) LUI (offset20) // This happens in case the lower 12 bits of the offset are zeros. - int64_t Offset; - if (!matchLargeOffset(Tail, DestReg, Offset)) - return false; - foldOffset(Hi, Lo, Tail, Offset); - return true; + return foldLargeOffset(Hi, Lo, Tail, DestReg); } case RISCV::SH1ADD: case RISCV::SH2ADD: @@ -311,11 +316,7 @@ // The offset is too large to fit in the immediate field of ADDI.
- // It may be encoded as (SH2ADD (ADDI X0, C), DestReg) or - // (SH3ADD (ADDI X0, C), DestReg). - int64_t Offset; - if (!matchShiftedOffset(Tail, DestReg, Offset)) - return false; - foldOffset(Hi, Lo, Tail, Offset); - return true; + // It may be encoded as (SH1ADD (ADDI X0, C), DestReg), + // (SH2ADD (ADDI X0, C), DestReg), or (SH3ADD (ADDI X0, C), DestReg). + return foldShiftedOffset(Hi, Lo, Tail, DestReg); } } } @@ -389,7 +390,7 @@ UseMI.getOperand(1).setReg(Hi.getOperand(0).getReg()); } - DeadInstrs.insert(&Lo); + Lo.eraseFromParent(); return true; } @@ -400,7 +401,6 @@ ST = &Fn.getSubtarget<RISCVSubtarget>(); bool MadeChange = false; - DeadInstrs.clear(); MRI = &Fn.getRegInfo(); for (MachineBasicBlock &MBB : Fn) { LLVM_DEBUG(dbgs() << "MBB: " << MBB.getName() << "\n"); @@ -413,9 +413,7 @@ MadeChange |= detectAndFoldOffset(Hi, *Lo); } } - // Delete dead instructions. - for (auto *MI : DeadInstrs) - MI->eraseFromParent(); + return MadeChange; }