Index: llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
+++ llvm/lib/Target/RISCV/RISCVFrameLowering.cpp
@@ -328,7 +328,6 @@
                                            MachineInstr::MIFlag Flag) const {
   assert(Amount != 0 && "Did not need to adjust stack pointer for RVV.");
 
-  const RISCVInstrInfo *TII = STI.getInstrInfo();
   const Register SPReg = getSPReg(STI);
 
   // Optimize compile time offset case
@@ -347,20 +346,9 @@
     return;
   }
 
-  unsigned Opc = RISCV::ADD;
-  if (Amount < 0) {
-    Amount = -Amount;
-    Opc = RISCV::SUB;
-  }
-  // 1. Multiply the number of v-slots to the length of registers
-  Register FactorRegister =
-      MF.getRegInfo().createVirtualRegister(&RISCV::GPRRegClass);
-  TII->getVLENFactoredAmount(MF, MBB, MBBI, DL, FactorRegister, Amount, Flag);
-  // 2. SP = SP - RVV stack size
-  BuildMI(MBB, MBBI, DL, TII->get(Opc), SPReg)
-      .addReg(SPReg)
-      .addReg(FactorRegister, RegState::Kill)
-      .setMIFlag(Flag);
+  const RISCVRegisterInfo &RI = *STI.getRegisterInfo();
+  RI.adjustReg(MBB, MBBI, DL, SPReg, SPReg, StackOffset::getScalable(Amount),
+               Flag);
 }
 
 void RISCVFrameLowering::emitPrologue(MachineFunction &MF,
Index: llvm/lib/Target/RISCV/RISCVRegisterInfo.h
===================================================================
--- llvm/lib/Target/RISCV/RISCVRegisterInfo.h
+++ llvm/lib/Target/RISCV/RISCVRegisterInfo.h
@@ -48,8 +48,13 @@
                          MaybeAlign RequiredAlign) const;
 
-  // Update DestReg to have the value of SrcReg plus an Offset.
-  void adjustReg(MachineBasicBlock::iterator II, Register DestReg,
-                 Register SrcReg, StackOffset Offset) const;
+  // Update DestReg to have the value of SrcReg plus an Offset. Offset may
+  // have both a fixed and a scalable part; the scalable part is materialized
+  // via getVLENFactoredAmount. New instructions are inserted into MBB before
+  // II, use DL, and are tagged with Flag. May create scratch virtual registers
+  // (scalable Offset with DestReg == SrcReg, or fixed part not fitting simm12).
+  void adjustReg(MachineBasicBlock &MBB, MachineBasicBlock::iterator II,
+                 const DebugLoc &DL, Register DestReg, Register SrcReg,
+                 StackOffset Offset, MachineInstr::MIFlag Flag) const;
 
   bool eliminateFrameIndex(MachineBasicBlock::iterator MI, int SPAdj,
                            unsigned FIOperandNum,
Index: llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
===================================================================
--- llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
+++ llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp
@@ -219,19 +219,19 @@
       .setMIFlag(Flag);
 }
 
-void RISCVRegisterInfo::adjustReg(MachineBasicBlock::iterator II, Register DestReg,
-                                  Register SrcReg, StackOffset Offset) const {
+void RISCVRegisterInfo::adjustReg(MachineBasicBlock &MBB,
+                                  MachineBasicBlock::iterator II,
+                                  const DebugLoc &DL, Register DestReg,
+                                  Register SrcReg, StackOffset Offset,
+                                  MachineInstr::MIFlag Flag) const {
 
   if (DestReg == SrcReg && !Offset.getFixed() && !Offset.getScalable())
     return;
 
-  MachineInstr &MI = *II;
-  MachineFunction &MF = *MI.getParent()->getParent();
+  MachineFunction &MF = *MBB.getParent();
   MachineRegisterInfo &MRI = MF.getRegInfo();
   const RISCVSubtarget &ST = MF.getSubtarget<RISCVSubtarget>();
   const RISCVInstrInfo *TII = ST.getInstrInfo();
-  DebugLoc DL = MI.getDebugLoc();
-  MachineBasicBlock &MBB = *MI.getParent();
 
   bool SrcRegIsKill = false;
 
@@ -243,9 +243,16 @@
       ScalableAdjOpc = RISCV::SUB;
     }
     // Get vlenb and multiply vlen with the number of vector registers.
-    TII->getVLENFactoredAmount(MF, MBB, II, DL, DestReg, ScalableValue);
+    // If DestReg aliases SrcReg, compute the scaled amount in a fresh
+    // register so SrcReg is not clobbered before the ADD/SUB below reads it.
+    Register ScratchReg = DestReg;
+    if (DestReg == SrcReg)
+      ScratchReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
+    TII->getVLENFactoredAmount(MF, MBB, II, DL, ScratchReg, ScalableValue,
+                               Flag);
     BuildMI(MBB, II, DL, TII->get(ScalableAdjOpc), DestReg)
-      .addReg(SrcReg).addReg(DestReg, RegState::Kill);
+      .addReg(SrcReg).addReg(ScratchReg, RegState::Kill)
+      .setMIFlag(Flag);
     SrcReg = DestReg;
     SrcRegIsKill = true;
   }
@@ -256,13 +263,15 @@
     if (isInt<12>(Offset.getFixed())) {
       BuildMI(MBB, II, DL, TII->get(RISCV::ADDI), DestReg)
           .addReg(SrcReg, getKillRegState(SrcRegIsKill))
-          .addImm(Offset.getFixed());
+          .addImm(Offset.getFixed())
+          .setMIFlag(Flag);
     } else {
       Register ScratchReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
-      TII->movImm(MBB, II, DL, ScratchReg, Offset.getFixed());
+      TII->movImm(MBB, II, DL, ScratchReg, Offset.getFixed(), Flag);
       BuildMI(MBB, II, DL, TII->get(RISCV::ADD), DestReg)
           .addReg(SrcReg, getKillRegState(SrcRegIsKill))
-          .addReg(ScratchReg, RegState::Kill);
+          .addReg(ScratchReg, RegState::Kill)
+          .setMIFlag(Flag);
     }
   }
 }
@@ -325,7 +334,8 @@
     DestReg = MI.getOperand(0).getReg();
   else
     DestReg = MRI.createVirtualRegister(&RISCV::GPRRegClass);
-  adjustReg(II, DestReg, FrameReg, Offset);
+  adjustReg(*II->getParent(), II, DL, DestReg, FrameReg, Offset,
+            MachineInstr::NoFlags);
   MI.getOperand(FIOperandNum).ChangeToRegister(DestReg, /*IsDef*/false,
                                                /*IsImp*/false,
                                                /*IsKill*/true);