diff --git a/llvm/lib/Target/RISCV/RISCVMergeBaseOffset.cpp b/llvm/lib/Target/RISCV/RISCVMergeBaseOffset.cpp --- a/llvm/lib/Target/RISCV/RISCVMergeBaseOffset.cpp +++ b/llvm/lib/Target/RISCV/RISCVMergeBaseOffset.cpp @@ -49,9 +49,10 @@ bool detectAndFoldOffset(MachineInstr &HiLUI, MachineInstr &LoADDI); void foldOffset(MachineInstr &HiLUI, MachineInstr &LoADDI, MachineInstr &Tail, int64_t Offset); - bool matchLargeOffset(MachineInstr &TailAdd, Register GSReg, int64_t &Offset); - bool matchShiftedOffset(MachineInstr &TailShXAdd, Register GSReg, - int64_t &Offset); + bool foldLargeOffset(MachineInstr &HiLUI, MachineInstr &LoADDI, + MachineInstr &TailAdd, Register GSReg); + bool foldShiftedOffset(MachineInstr &HiLUI, MachineInstr &LoADDI, + MachineInstr &TailShXAdd, Register GSReg); RISCVMergeBaseOffsetOpt() : MachineFunctionPass(ID) {} @@ -71,7 +72,6 @@ private: MachineRegisterInfo *MRI; - std::set<MachineInstr *> DeadInstrs; }; } // end anonymous namespace @@ -117,15 +117,17 @@ // Put the offset back in HiLUI and the LoADDI HiLUI.getOperand(1).setOffset(Offset); LoADDI.getOperand(2).setOffset(Offset); - // Delete the tail instruction. - DeadInstrs.insert(&Tail); MRI->replaceRegWith(Tail.getOperand(0).getReg(), LoADDI.getOperand(0).getReg()); + // Delete the tail instruction. + Tail.eraseFromParent(); LLVM_DEBUG(dbgs() << " Merged offset " << Offset << " into base.\n" << " " << HiLUI << " " << LoADDI;); } // Detect patterns for large offsets that are passed into an ADD instruction. +// If the pattern is found, updates the offset in HiLUI and LoADDI instructions +// and deletes TailAdd and the instructions that produced the offset. 
// // Base address lowering is of the form: // HiLUI: lui vreg1, %hi(s) @@ -143,9 +145,10 @@ // \ / // \ / // TailAdd: add vreg4, vreg2, voff -bool RISCVMergeBaseOffsetOpt::matchLargeOffset(MachineInstr &TailAdd, - Register GAReg, - int64_t &Offset) { +bool RISCVMergeBaseOffsetOpt::foldLargeOffset(MachineInstr &HiLUI, + MachineInstr &LoADDI, + MachineInstr &TailAdd, + Register GAReg) { assert((TailAdd.getOpcode() == RISCV::ADD) && "Expected ADD instruction!"); Register Rs = TailAdd.getOperand(1).getReg(); Register Rt = TailAdd.getOperand(2).getReg(); @@ -171,7 +174,7 @@ LuiImmOp.getTargetFlags() != RISCVII::MO_None || !MRI->hasOneUse(OffsetLui.getOperand(0).getReg())) return false; - Offset = SignExtend64<32>(LuiImmOp.getImm() << 12); + int64_t Offset = SignExtend64<32>(LuiImmOp.getImm() << 12); Offset += OffLo; // RV32 ignores the upper 32 bits. ADDIW sign extends the result. if (!ST->is64Bit() || OffsetTail.getOpcode() == RISCV::ADDIW) @@ -181,15 +184,17 @@ return false; LLVM_DEBUG(dbgs() << " Offset Instrs: " << OffsetTail << " " << OffsetLui); - DeadInstrs.insert(&OffsetTail); - DeadInstrs.insert(&OffsetLui); + foldOffset(HiLUI, LoADDI, TailAdd, Offset); + OffsetTail.eraseFromParent(); + OffsetLui.eraseFromParent(); return true; } else if (OffsetTail.getOpcode() == RISCV::LUI) { // The offset value has all zero bits in the lower 12 bits. Only LUI // exists. LLVM_DEBUG(dbgs() << " Offset Instr: " << OffsetTail); - Offset = SignExtend64<32>(OffsetTail.getOperand(1).getImm() << 12); - DeadInstrs.insert(&OffsetTail); + int64_t Offset = SignExtend64<32>(OffsetTail.getOperand(1).getImm() << 12); + foldOffset(HiLUI, LoADDI, TailAdd, Offset); + OffsetTail.eraseFromParent(); return true; } return false; @@ -199,14 +204,17 @@ // The offset has 1,2, or 3 trailing zeros and fits in simm13, simm14, simm15. // The constant is created with addi voff, x0, C, and shXadd is used to // fill insert the trailing zeros and do the addition. 
+// If the pattern is found, updates the offset in HiLUI and LoADDI instructions +// and deletes TailShXAdd and the instructions that produced the offset. // // HiLUI: lui vreg1, %hi(s) // LoADDI: addi vreg2, vreg1, %lo(s) // OffsetTail: addi voff, x0, C // TailAdd: shXadd vreg4, voff, vreg2 -bool RISCVMergeBaseOffsetOpt::matchShiftedOffset(MachineInstr &TailShXAdd, - Register GAReg, - int64_t &Offset) { +bool RISCVMergeBaseOffsetOpt::foldShiftedOffset(MachineInstr &HiLUI, + MachineInstr &LoADDI, + MachineInstr &TailShXAdd, + Register GAReg) { assert((TailShXAdd.getOpcode() == RISCV::SH1ADD || TailShXAdd.getOpcode() == RISCV::SH2ADD || TailShXAdd.getOpcode() == RISCV::SH3ADD) && @@ -230,7 +238,7 @@ !OffsetTail.getOperand(2).isImm()) return false; - Offset = OffsetTail.getOperand(2).getImm(); + int64_t Offset = OffsetTail.getOperand(2).getImm(); assert(isInt<12>(Offset) && "Unexpected offset"); unsigned ShAmt; @@ -244,7 +252,8 @@ Offset = (uint64_t)Offset << ShAmt; LLVM_DEBUG(dbgs() << " Offset Instr: " << OffsetTail); - DeadInstrs.insert(&OffsetTail); + foldOffset(HiLUI, LoADDI, TailShXAdd, Offset); + OffsetTail.eraseFromParent(); return true; } @@ -274,8 +283,8 @@ if (TailTail.getOpcode() == RISCV::ADDI) { Offset += TailTail.getOperand(2).getImm(); LLVM_DEBUG(dbgs() << " Offset Instrs: " << Tail << TailTail); - DeadInstrs.insert(&Tail); foldOffset(HiLUI, LoADDI, TailTail, Offset); + Tail.eraseFromParent(); return true; } } @@ -293,11 +302,7 @@ // both hi 20 and lo 12 bits. // 2) LUI (offset20) // This happens in case the lower 12 bits of the offset are zeros. - int64_t Offset; - if (!matchLargeOffset(Tail, DestReg, Offset)) - return false; - foldOffset(HiLUI, LoADDI, Tail, Offset); - return true; + return foldLargeOffset(HiLUI, LoADDI, Tail, DestReg); } case RISCV::SH1ADD: case RISCV::SH2ADD: @@ -305,11 +310,7 @@ // The offset is too large to fit in the immediate field of ADDI. 
// It may be encoded as (SH2ADD (ADDI X0, C), DestReg) or // (SH3ADD (ADDI X0, C), DestReg). - int64_t Offset; - if (!matchShiftedOffset(Tail, DestReg, Offset)) - return false; - foldOffset(HiLUI, LoADDI, Tail, Offset); - return true; + return foldShiftedOffset(HiLUI, LoADDI, Tail, DestReg); } } } @@ -376,7 +377,7 @@ UseMI.getOperand(1).setReg(HiLUI.getOperand(0).getReg()); } - DeadInstrs.insert(&LoADDI); + LoADDI.eraseFromParent(); return true; } @@ -387,7 +388,6 @@ ST = &Fn.getSubtarget<RISCVSubtarget>(); bool MadeChange = false; - DeadInstrs.clear(); MRI = &Fn.getRegInfo(); for (MachineBasicBlock &MBB : Fn) { LLVM_DEBUG(dbgs() << "MBB: " << MBB.getName() << "\n"); @@ -400,9 +400,7 @@ MadeChange |= detectAndFoldOffset(HiLUI, *LoADDI); } } - // Delete dead instructions. - for (auto *MI : DeadInstrs) - MI->eraseFromParent(); + return MadeChange; }