[RISCV] Use RISCVRegisterInfo::adjustReg for the scalable SP adjustment

Reviewer summary. This text precedes the first diff header and is
skipped by `git apply` / `patch`.

RISCVFrameLowering.cpp: this change affects the RVV stack-pointer
adjustment routine, in the code after the "compile time offset" early
return. That code no longer open-codes the sequence "scale Amount by
vlenb into a fresh virtual GPR, then ADD/SUB it into SP". It now calls
RI.adjustReg(MBB, MBBI, DL, SPReg, SPReg,
StackOffset::getScalable(Amount), Flag). The TII local is removed.

RISCVRegisterInfo.h/.cpp: the StackOffset overload of adjustReg now
takes MBB, DL and a MachineInstr::MIFlag explicitly, instead of deriving
MBB and DL from *II. It also:
  * passes Flag to getVLENFactoredAmount, to the emitted ADD/SUB, and to
    the fixed-offset adjustReg call (which previously received
    MachineInstr::NoFlags);
  * builds the vlenb multiple in a new virtual GPR (ScratchReg) when
    DestReg == SrcReg, rather than in DestReg.
    Previously the multiple was written to DestReg before the ADD/SUB
    read SrcReg. That clobbers the source when both are the same
    register, as in SP = SP +/- x.
eliminateFrameIndex is updated to the new signature and passes
MachineInstr::NoFlags.

NOTE(review): the new eliminateFrameIndex call passes `DL`. Confirm a
DebugLoc named DL is in scope there; it is not visible in these hunks.
NOTE(review): the context line `MF.getSubtarget()` is bound to
`const RISCVSubtarget &`. That only compiles as
`MF.getSubtarget<RISCVSubtarget>()`. The template argument looks lost in
extraction; check it against the real tree, or the hunk context will
not match.
NOTE(review): with DestReg != SrcReg and a zero offset (both fixed and
scalable parts), the StackOffset overload returns without ever defining
DestReg. Confirm no caller reaches that case.
NOTE(review): the lines of this patch have been collapsed onto a few
physical lines (newlines lost), so it cannot be applied as-is. Restore
the original line structure before using it.

diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp @@ -329,7 +329,6 @@ MachineInstr::MIFlag Flag) const { assert(Amount != 0 && "Did not need to adjust stack pointer for RVV."); - const RISCVInstrInfo *TII = STI.getInstrInfo(); const Register SPReg = getSPReg(STI); // Optimize compile time offset case @@ -348,20 +347,9 @@ return; } - unsigned Opc = RISCV::ADD; - if (Amount < 0) { - Amount = -Amount; - Opc = RISCV::SUB; - } - // 1. Multiply the number of v-slots to the length of registers - Register FactorRegister = - MF.getRegInfo().createVirtualRegister(&RISCV::GPRRegClass); - TII->getVLENFactoredAmount(MF, MBB, MBBI, DL, FactorRegister, Amount, Flag); - // 2. SP = SP - RVV stack size - BuildMI(MBB, MBBI, DL, TII->get(Opc), SPReg) - .addReg(SPReg) - .addReg(FactorRegister, RegState::Kill) - .setMIFlag(Flag); + const RISCVRegisterInfo &RI = *STI.getRegisterInfo(); + RI.adjustReg(MBB, MBBI, DL, SPReg, SPReg, StackOffset::getScalable(Amount), + Flag); } void RISCVFrameLowering::emitPrologue(MachineFunction &MF, diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.h b/llvm/lib/Target/RISCV/RISCVRegisterInfo.h --- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.h +++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.h @@ -49,8 +49,9 @@ MaybeAlign RequiredAlign, bool KillSrcReg) const; // Update DestReg to have the value of SrcReg plus an Offset. 
- void adjustReg(MachineBasicBlock::iterator II, Register DestReg, - Register SrcReg, StackOffset Offset) const; + void adjustReg(MachineBasicBlock &MBB, MachineBasicBlock::iterator II, + const DebugLoc &DL, Register DestReg, Register SrcReg, + StackOffset Offset, MachineInstr::MIFlag Flag) const; bool eliminateFrameIndex(MachineBasicBlock::iterator MI, int SPAdj, unsigned FIOperandNum, diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp --- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp @@ -220,18 +220,19 @@ .setMIFlag(Flag); } -void RISCVRegisterInfo::adjustReg(MachineBasicBlock::iterator II, Register DestReg, - Register SrcReg, StackOffset Offset) const { +void RISCVRegisterInfo::adjustReg(MachineBasicBlock &MBB, + MachineBasicBlock::iterator II, + const DebugLoc &DL, Register DestReg, + Register SrcReg, StackOffset Offset, + MachineInstr::MIFlag Flag) const { if (DestReg == SrcReg && !Offset.getFixed() && !Offset.getScalable()) return; - MachineInstr &MI = *II; - MachineFunction &MF = *MI.getParent()->getParent(); + MachineFunction &MF = *MBB.getParent(); + MachineRegisterInfo &MRI = MF.getRegInfo(); const RISCVSubtarget &ST = MF.getSubtarget(); const RISCVInstrInfo *TII = ST.getInstrInfo(); - DebugLoc DL = MI.getDebugLoc(); - MachineBasicBlock &MBB = *MI.getParent(); bool SrcRegIsKill = false; @@ -243,16 +244,20 @@ ScalableAdjOpc = RISCV::SUB; } // Get vlenb and multiply vlen with the number of vector registers. 
- TII->getVLENFactoredAmount(MF, MBB, II, DL, DestReg, ScalableValue); + Register ScratchReg = DestReg; + if (DestReg == SrcReg) + ScratchReg = MRI.createVirtualRegister(&RISCV::GPRRegClass); + TII->getVLENFactoredAmount(MF, MBB, II, DL, ScratchReg, ScalableValue, Flag); BuildMI(MBB, II, DL, TII->get(ScalableAdjOpc), DestReg) - .addReg(SrcReg).addReg(DestReg, RegState::Kill); + .addReg(SrcReg).addReg(ScratchReg, RegState::Kill) + .setMIFlag(Flag); SrcReg = DestReg; SrcRegIsKill = true; } if (Offset.getFixed()) adjustReg(MBB, II, DL, DestReg, SrcReg, Offset.getFixed(), - MachineInstr::NoFlags, None, SrcRegIsKill); + Flag, None, SrcRegIsKill); } @@ -313,7 +318,8 @@ DestReg = MI.getOperand(0).getReg(); else DestReg = MRI.createVirtualRegister(&RISCV::GPRRegClass); - adjustReg(II, DestReg, FrameReg, Offset); + adjustReg(*II->getParent(), II, DL, DestReg, FrameReg, Offset, + MachineInstr::NoFlags); MI.getOperand(FIOperandNum).ChangeToRegister(DestReg, /*IsDef*/false, /*IsImp*/false, /*IsKill*/true);