diff --git a/llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXFrameLowering.cpp
@@ -36,6 +36,9 @@
     MachineInstr *MI = &MBB.front();
     MachineRegisterInfo &MR = MF.getRegInfo();
 
+    const NVPTXRegisterInfo *NRI =
+        MF.getSubtarget<NVPTXSubtarget>().getRegisterInfo();
+
     // This instruction really occurs before first instruction
     // in the BB, so giving it no debug location.
     DebugLoc dl = DebugLoc();
@@ -50,15 +53,15 @@
         (Is64Bit ? NVPTX::cvta_local_yes_64 : NVPTX::cvta_local_yes);
     unsigned MovDepotOpcode =
         (Is64Bit ? NVPTX::MOV_DEPOT_ADDR_64 : NVPTX::MOV_DEPOT_ADDR);
-    if (!MR.use_empty(NVPTX::VRFrame)) {
+    if (!MR.use_empty(NRI->getFrameRegister(MF))) {
       // If %SP is not used, do not bother emitting "cvta.local %SP, %SPL".
       MI = BuildMI(MBB, MI, dl,
                    MF.getSubtarget().getInstrInfo()->get(CvtaLocalOpcode),
-                   NVPTX::VRFrame)
-               .addReg(NVPTX::VRFrameLocal);
+                   NRI->getFrameRegister(MF))
+               .addReg(NRI->getFrameLocalRegister(MF));
     }
     BuildMI(MBB, MI, dl, MF.getSubtarget().getInstrInfo()->get(MovDepotOpcode),
-            NVPTX::VRFrameLocal)
+            NRI->getFrameLocalRegister(MF))
         .addImm(MF.getFunctionNumber());
   }
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXPeephole.cpp b/llvm/lib/Target/NVPTX/NVPTXPeephole.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXPeephole.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXPeephole.cpp
@@ -21,17 +21,19 @@
 // This peephole pass optimizes these cases, for example
 //
 // It will transform the following pattern
-//    %0 = LEA_ADDRi64 %VRFrame, 4
+//    %0 = LEA_ADDRi64 %VRFrame64, 4
 //    %1 = cvta_to_local_yes_64 %0
 //
 // into
-//    %1 = LEA_ADDRi64 %VRFrameLocal, 4
+//    %1 = LEA_ADDRi64 %VRFrameLocal64, 4
 //
-// %VRFrameLocal is the virtual register name of %SPL
+// %VRFrameLocal64 is the virtual register name of %SPL
 //
 //===----------------------------------------------------------------------===//
 
 #include "NVPTX.h"
+#include "NVPTXRegisterInfo.h"
+#include "NVPTXSubtarget.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
@@ -92,9 +94,12 @@
     return false;
   }
 
+  const NVPTXRegisterInfo *NRI =
+      MF.getSubtarget<NVPTXSubtarget>().getRegisterInfo();
+
   // Check the LEA_ADDRi operand is Frame index
   auto &BaseAddrOp = GenericAddrDef->getOperand(1);
-  if (BaseAddrOp.isReg() && BaseAddrOp.getReg() == NVPTX::VRFrame) {
+  if (BaseAddrOp.isReg() && BaseAddrOp.getReg() == NRI->getFrameRegister(MF)) {
     return true;
   }
 
@@ -108,10 +113,13 @@
   const TargetInstrInfo *TII = MF.getSubtarget().getInstrInfo();
   auto &Prev = *MRI.getUniqueVRegDef(Root.getOperand(1).getReg());
 
+  const NVPTXRegisterInfo *NRI =
+      MF.getSubtarget<NVPTXSubtarget>().getRegisterInfo();
+
   MachineInstrBuilder MIB =
       BuildMI(MF, Root.getDebugLoc(), TII->get(Prev.getOpcode()),
               Root.getOperand(0).getReg())
-          .addReg(NVPTX::VRFrameLocal)
+          .addReg(NRI->getFrameLocalRegister(MF))
           .add(Prev.getOperand(2));
 
   MBB.insert((MachineBasicBlock::iterator)&Root, MIB);
@@ -142,10 +150,13 @@
     }  // Instruction
   }    // Basic Block
 
+  const NVPTXRegisterInfo *NRI =
+      MF.getSubtarget<NVPTXSubtarget>().getRegisterInfo();
+
   // Remove unnecessary %VRFrame = cvta.local %VRFrameLocal
   const auto &MRI = MF.getRegInfo();
-  if (MRI.use_empty(NVPTX::VRFrame)) {
-    if (auto MI = MRI.getUniqueVRegDef(NVPTX::VRFrame)) {
+  if (MRI.use_empty(NRI->getFrameRegister(MF))) {
+    if (auto MI = MRI.getUniqueVRegDef(NRI->getFrameRegister(MF))) {
       MI->eraseFromParentAndMarkDBGValuesForRemoval();
     }
   }
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.h
@@ -43,6 +43,7 @@
                            RegScavenger *RS = nullptr) const override;
 
   Register getFrameRegister(const MachineFunction &MF) const override;
+  Register getFrameLocalRegister(const MachineFunction &MF) const;
 
   ManagedStringPool *getStrPool() const {
     return const_cast<ManagedStringPool *>(&ManagedStrPool);
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
@@ -13,6 +13,7 @@
 #include "NVPTXRegisterInfo.h"
 #include "NVPTX.h"
 #include "NVPTXSubtarget.h"
+#include "NVPTXTargetMachine.h"
 #include "llvm/ADT/BitVector.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
@@ -122,10 +123,19 @@
                MI.getOperand(FIOperandNum + 1).getImm();
 
   // Using I0 as the frame pointer
-  MI.getOperand(FIOperandNum).ChangeToRegister(NVPTX::VRFrame, false);
+  MI.getOperand(FIOperandNum).ChangeToRegister(getFrameRegister(MF), false);
   MI.getOperand(FIOperandNum + 1).ChangeToImmediate(Offset);
 }
 
 Register NVPTXRegisterInfo::getFrameRegister(const MachineFunction &MF) const {
-  return NVPTX::VRFrame;
+  const NVPTXTargetMachine &TM =
+      static_cast<const NVPTXTargetMachine &>(MF.getTarget());
+  return TM.is64Bit() ? NVPTX::VRFrame64 : NVPTX::VRFrame32;
+}
+
+Register
+NVPTXRegisterInfo::getFrameLocalRegister(const MachineFunction &MF) const {
+  const NVPTXTargetMachine &TM =
+      static_cast<const NVPTXTargetMachine &>(MF.getTarget());
+  return TM.is64Bit() ? NVPTX::VRFrameLocal64 : NVPTX::VRFrameLocal32;
 }
diff --git a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
--- a/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
@@ -22,8 +22,10 @@
 //===----------------------------------------------------------------------===//
 
 // Special Registers used as stack pointer
-def VRFrame         : NVPTXReg<"%SP">;
-def VRFrameLocal    : NVPTXReg<"%SPL">;
+def VRFrame32         : NVPTXReg<"%SP">;
+def VRFrame64         : NVPTXReg<"%SP">;
+def VRFrameLocal32    : NVPTXReg<"%SPL">;
+def VRFrameLocal64    : NVPTXReg<"%SPL">;
 
 // Special Registers used as the stack
 def VRDepot         : NVPTXReg<"%Depot">;
@@ -56,8 +58,8 @@
 //===----------------------------------------------------------------------===//
 def Int1Regs : NVPTXRegClass<[i1], 8, (add (sequence "P%u", 0, 4))>;
 def Int16Regs : NVPTXRegClass<[i16], 16, (add (sequence "RS%u", 0, 4))>;
-def Int32Regs : NVPTXRegClass<[i32], 32, (add (sequence "R%u", 0, 4))>;
-def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4))>;
+def Int32Regs : NVPTXRegClass<[i32], 32, (add (sequence "R%u", 0, 4), VRFrame32, VRFrameLocal32)>;
+def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
 def Float16Regs : NVPTXRegClass<[f16], 16, (add (sequence "H%u", 0, 4))>;
 def Float16x2Regs : NVPTXRegClass<[v2f16], 32, (add (sequence "HH%u", 0, 4))>;
 def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
@@ -68,5 +70,5 @@
 def Float64ArgRegs : NVPTXRegClass<[f64], 64, (add (sequence "da%u", 0, 4))>;
 
 // Read NVPTXRegisterInfo.cpp to see how VRFrame and VRDepot are used.
-def SpecialRegs : NVPTXRegClass<[i32], 32, (add VRFrame, VRFrameLocal, VRDepot,
+def SpecialRegs : NVPTXRegClass<[i32], 32, (add VRFrame32, VRFrameLocal32, VRDepot,
                                             (sequence "ENVREG%u", 0, 31))>;
diff --git a/llvm/test/CodeGen/NVPTX/local-stack-frame.ll b/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
--- a/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
+++ b/llvm/test/CodeGen/NVPTX/local-stack-frame.ll
@@ -1,5 +1,5 @@
-; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s --check-prefix=PTX32
-; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 | FileCheck %s --check-prefix=PTX64
+; RUN: llc < %s -march=nvptx -mcpu=sm_20 -verify-machineinstrs | FileCheck %s --check-prefix=PTX32
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_20 -verify-machineinstrs | FileCheck %s --check-prefix=PTX64
 
 ; Ensure we access the local stack properly
 