diff --git a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVFrameLowering.cpp @@ -381,7 +381,8 @@ } // 1. Multiply the number of v-slots to the length of registers Register FactorRegister = - TII->getVLENFactoredAmount(MF, MBB, MBBI, DL, Amount, Flag); + MF.getRegInfo().createVirtualRegister(&RISCV::GPRRegClass); + TII->getVLENFactoredAmount(MF, MBB, MBBI, DL, FactorRegister, Amount, Flag); /* The callee now writes its result into the caller-allocated FactorRegister instead of returning a new register. */ // 2. SP = SP - RVV stack size BuildMI(MBB, MBBI, DL, TII->get(Opc), SPReg) .addReg(SPReg) diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.h b/llvm/lib/Target/RISCV/RISCVInstrInfo.h --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.h +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.h @@ -177,10 +177,10 @@ unsigned OpIdx, const TargetRegisterInfo *TRI) const override; - Register getVLENFactoredAmount( + /* Emits code at II that sets DestReg = (Amount / 8) * VLENB. DestReg holds the running value of every emitted step (it is both written and read), so the caller must supply a register it owns; the callers updated in this patch pass a freshly created GPR virtual register. Amount must be a positive multiple of 8 (asserted in the definition). */ void getVLENFactoredAmount( MachineFunction &MF, MachineBasicBlock &MBB, - MachineBasicBlock::iterator II, const DebugLoc &DL, int64_t Amount, - MachineInstr::MIFlag Flag = MachineInstr::NoFlags) const; + MachineBasicBlock::iterator II, const DebugLoc &DL, Register DestReg, + int64_t Amount, MachineInstr::MIFlag Flag = MachineInstr::NoFlags) const; protected: const RISCVSubtarget &STI; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.cpp @@ -1953,12 +1953,12 @@ #undef CASE_WIDEOP_OPCODE_LMULS #undef CASE_WIDEOP_OPCODE_COMMON -Register RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF, - MachineBasicBlock &MBB, - MachineBasicBlock::iterator II, - const DebugLoc &DL, - int64_t Amount, - MachineInstr::MIFlag Flag) const { +void RISCVInstrInfo::getVLENFactoredAmount(MachineFunction &MF, + MachineBasicBlock &MBB, + MachineBasicBlock::iterator II, + const DebugLoc &DL, Register DestReg, + int64_t Amount, + MachineInstr::MIFlag
Flag) const { assert(Amount > 0 && "There is no need to get VLEN scaled value."); assert(Amount % 8 == 0 && "Reserve the stack by the multiple of one vector size."); @@ -1966,17 +1966,15 @@ MachineRegisterInfo &MRI = MF.getRegInfo(); int64_t NumOfVReg = Amount / 8; - Register VL = MRI.createVirtualRegister(&RISCV::GPRRegClass); - BuildMI(MBB, II, DL, get(RISCV::PseudoReadVLENB), VL) - .setMIFlag(Flag); + BuildMI(MBB, II, DL, get(RISCV::PseudoReadVLENB), DestReg).setMIFlag(Flag); /* DestReg = VLENB; every branch below then scales DestReg in place by NumOfVReg. */ assert(isInt<32>(NumOfVReg) && "Expect the number of vector registers within 32-bits."); if (isPowerOf2_32(NumOfVReg)) { uint32_t ShiftAmount = Log2_32(NumOfVReg); if (ShiftAmount == 0) - return VL; - BuildMI(MBB, II, DL, get(RISCV::SLLI), VL) - .addReg(VL, RegState::Kill) + return; /* NumOfVReg == 1: DestReg already holds VLENB. */ + BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg) + .addReg(DestReg, RegState::Kill) .addImm(ShiftAmount) .setMIFlag(Flag); } else if (STI.hasStdExtZba() && @@ -1999,35 +1997,35 @@ llvm_unreachable("Unexpected number of vregs"); } if (ShiftAmount) - BuildMI(MBB, II, DL, get(RISCV::SLLI), VL) - .addReg(VL, RegState::Kill) + BuildMI(MBB, II, DL, get(RISCV::SLLI), DestReg) + .addReg(DestReg, RegState::Kill) .addImm(ShiftAmount) .setMIFlag(Flag); - BuildMI(MBB, II, DL, get(Opc), VL) - .addReg(VL, RegState::Kill) - .addReg(VL) + BuildMI(MBB, II, DL, get(Opc), DestReg) + .addReg(DestReg, RegState::Kill) + .addReg(DestReg) /* both source operands are DestReg */ .setMIFlag(Flag); } else if (isPowerOf2_32(NumOfVReg - 1)) { Register ScaledRegister = MRI.createVirtualRegister(&RISCV::GPRRegClass); uint32_t ShiftAmount = Log2_32(NumOfVReg - 1); BuildMI(MBB, II, DL, get(RISCV::SLLI), ScaledRegister) - .addReg(VL) + .addReg(DestReg) /* no kill flag: the ADD below reads DestReg again */ .addImm(ShiftAmount) .setMIFlag(Flag); - BuildMI(MBB, II, DL, get(RISCV::ADD), VL) + BuildMI(MBB, II, DL, get(RISCV::ADD), DestReg) .addReg(ScaledRegister, RegState::Kill) - .addReg(VL, RegState::Kill) + .addReg(DestReg, RegState::Kill) .setMIFlag(Flag); } else if (isPowerOf2_32(NumOfVReg + 1)) { Register ScaledRegister =
MRI.createVirtualRegister(&RISCV::GPRRegClass); uint32_t ShiftAmount = Log2_32(NumOfVReg + 1); BuildMI(MBB, II, DL, get(RISCV::SLLI), ScaledRegister) - .addReg(VL) + .addReg(DestReg) /* no kill flag: the SUB below reads DestReg again */ .addImm(ShiftAmount) .setMIFlag(Flag); - BuildMI(MBB, II, DL, get(RISCV::SUB), VL) + BuildMI(MBB, II, DL, get(RISCV::SUB), DestReg) .addReg(ScaledRegister, RegState::Kill) - .addReg(VL, RegState::Kill) + .addReg(DestReg, RegState::Kill) .setMIFlag(Flag); } else { Register N = MRI.createVirtualRegister(&RISCV::GPRRegClass); @@ -2037,13 +2035,11 @@ MF.getFunction(), "M- or Zmmul-extension must be enabled to calculate the vscaled size/" "offset."}); - BuildMI(MBB, II, DL, get(RISCV::MUL), VL) - .addReg(VL, RegState::Kill) + BuildMI(MBB, II, DL, get(RISCV::MUL), DestReg) + .addReg(DestReg, RegState::Kill) .addReg(N, RegState::Kill) .setMIFlag(Flag); } - - return VL; } // Returns true if this is the sext.w pattern, addiw rd, rs1, 0. diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp --- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp +++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.cpp @@ -196,8 +196,9 @@ ScalableAdjOpc = RISCV::SUB; } // 1. Get vlenb && multiply vlen with the number of vector registers. - ScalableFactorRegister = - TII->getVLENFactoredAmount(MF, MBB, II, DL, ScalableValue); + ScalableFactorRegister = MRI.createVirtualRegister(&RISCV::GPRRegClass); /* the caller now allocates the destination register */ + TII->getVLENFactoredAmount(MF, MBB, II, DL, ScalableFactorRegister, + ScalableValue); } if (!isInt<12>(Offset.getFixed())) {