diff --git a/llvm/lib/Target/X86/X86InstrInfo.h b/llvm/lib/Target/X86/X86InstrInfo.h
--- a/llvm/lib/Target/X86/X86InstrInfo.h
+++ b/llvm/lib/Target/X86/X86InstrInfo.h
@@ -374,6 +374,10 @@
                             int FrameIndex, const TargetRegisterClass *RC,
                             const TargetRegisterInfo *TRI) const override;
 
+  void loadStoreTileReg(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI,
+                        unsigned Opc, Register Reg, int FrameIdx,
+                        bool isKill = false) const;
+
   bool expandPostRAPseudo(MachineInstr &MI) const override;
 
   /// Check whether the target can fold a load that feeds a subreg operand
diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp
--- a/llvm/lib/Target/X86/X86InstrInfo.cpp
+++ b/llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -3550,10 +3550,23 @@
   return None;
 }
 
+static unsigned getLoadStoreOpcodeForFP16(bool Load, const X86Subtarget &STI) {
+  if (STI.hasFP16())
+    return Load ? X86::VMOVSHZrm_alt : X86::VMOVSHZmr;
+  if (Load)
+    return STI.hasAVX512() ? X86::VMOVSSZrm
+           : STI.hasAVX()  ? X86::VMOVSSrm
+                           : X86::MOVSSrm;
+  else
+    return STI.hasAVX512() ? X86::VMOVSSZmr
+           : STI.hasAVX()  ? X86::VMOVSSmr
+                           : X86::MOVSSmr;
+}
+
 static unsigned getLoadStoreRegOpcode(Register Reg,
                                       const TargetRegisterClass *RC,
                                       bool IsStackAligned,
-                                      const X86Subtarget &STI, bool load) {
+                                      const X86Subtarget &STI, bool Load) {
   bool HasAVX = STI.hasAVX();
   bool HasAVX512 = STI.hasAVX512();
   bool HasVLX = STI.hasVLX();
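[Note, not part of the patch] getLoadStoreOpcodeForFP16 folds the half-precision spill fallback that previously lived in storeRegToStackSlot/loadRegFromStackSlot into the opcode selection: without AVX512-FP16 an FR16/FR16X value is spilled with a 32-bit scalar move, which is safe because the FR16 spill slot is 4 bytes (the case 4 bucket below). A minimal standalone sketch of the same fallback order, with plain strings standing in for the X86:: opcode enums:

#include <cstdio>

// Mirrors the fallback order of getLoadStoreOpcodeForFP16 (illustrative only).
static const char *fp16SpillOpcode(bool Load, bool HasFP16, bool HasAVX512,
                                   bool HasAVX) {
  if (HasFP16)
    return Load ? "VMOVSHZrm_alt" : "VMOVSHZmr"; // true 16-bit scalar moves
  if (HasAVX512)
    return Load ? "VMOVSSZrm" : "VMOVSSZmr";     // EVEX 32-bit scalar
  if (HasAVX)
    return Load ? "VMOVSSrm" : "VMOVSSmr";       // VEX 32-bit scalar
  return Load ? "MOVSSrm" : "MOVSSmr";           // SSE 32-bit scalar
}

int main() {
  // e.g. reloading an f16 on an AVX2-only target falls back to VMOVSSrm.
  std::printf("%s\n", fp16SpillOpcode(/*Load=*/true, /*HasFP16=*/false,
                                      /*HasAVX512=*/false, /*HasAVX=*/true));
  return 0;
}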
@@ -3567,18 +3580,18 @@
       // Copying to or from a physical H register on x86-64 requires a NOREX
       // move. Otherwise use a normal move.
       if (isHReg(Reg) || X86::GR8_ABCD_HRegClass.hasSubClassEq(RC))
-        return load ? X86::MOV8rm_NOREX : X86::MOV8mr_NOREX;
-    return load ? X86::MOV8rm : X86::MOV8mr;
+        return Load ? X86::MOV8rm_NOREX : X86::MOV8mr_NOREX;
+    return Load ? X86::MOV8rm : X86::MOV8mr;
   case 2:
     if (X86::VK16RegClass.hasSubClassEq(RC))
-      return load ? X86::KMOVWkm : X86::KMOVWmk;
+      return Load ? X86::KMOVWkm : X86::KMOVWmk;
     assert(X86::GR16RegClass.hasSubClassEq(RC) && "Unknown 2-byte regclass");
-    return load ? X86::MOV16rm : X86::MOV16mr;
+    return Load ? X86::MOV16rm : X86::MOV16mr;
   case 4:
     if (X86::GR32RegClass.hasSubClassEq(RC))
-      return load ? X86::MOV32rm : X86::MOV32mr;
+      return Load ? X86::MOV32rm : X86::MOV32mr;
     if (X86::FR32XRegClass.hasSubClassEq(RC))
-      return load ?
+      return Load ?
         (HasAVX512 ? X86::VMOVSSZrm_alt :
          HasAVX    ? X86::VMOVSSrm_alt :
                      X86::MOVSSrm_alt) :
@@ -3586,10 +3599,10 @@
          HasAVX    ? X86::VMOVSSmr :
                      X86::MOVSSmr);
     if (X86::RFP32RegClass.hasSubClassEq(RC))
-      return load ? X86::LD_Fp32m : X86::ST_Fp32m;
+      return Load ? X86::LD_Fp32m : X86::ST_Fp32m;
     if (X86::VK32RegClass.hasSubClassEq(RC)) {
       assert(STI.hasBWI() && "KMOVD requires BWI");
-      return load ? X86::KMOVDkm : X86::KMOVDmk;
+      return Load ? X86::KMOVDkm : X86::KMOVDmk;
     }
     // All of these mask pair classes have the same spill size, the same kind
     // of kmov instructions can be used with all of them.
@@ -3598,17 +3611,16 @@
         X86::VK4PAIRRegClass.hasSubClassEq(RC) ||
         X86::VK8PAIRRegClass.hasSubClassEq(RC) ||
         X86::VK16PAIRRegClass.hasSubClassEq(RC))
-      return load ? X86::MASKPAIR16LOAD : X86::MASKPAIR16STORE;
-    if ((X86::FR16RegClass.hasSubClassEq(RC) ||
-         X86::FR16XRegClass.hasSubClassEq(RC)) &&
-        STI.hasFP16())
-      return load ? X86::VMOVSHZrm_alt : X86::VMOVSHZmr;
+      return Load ? X86::MASKPAIR16LOAD : X86::MASKPAIR16STORE;
+    if (X86::FR16RegClass.hasSubClassEq(RC) ||
+        X86::FR16XRegClass.hasSubClassEq(RC))
+      return getLoadStoreOpcodeForFP16(Load, STI);
     llvm_unreachable("Unknown 4-byte regclass");
   case 8:
     if (X86::GR64RegClass.hasSubClassEq(RC))
-      return load ? X86::MOV64rm : X86::MOV64mr;
+      return Load ? X86::MOV64rm : X86::MOV64mr;
     if (X86::FR64XRegClass.hasSubClassEq(RC))
-      return load ?
+      return Load ?
         (HasAVX512 ? X86::VMOVSDZrm_alt :
          HasAVX    ? X86::VMOVSDrm_alt :
                      X86::MOVSDrm_alt) :
@@ -3616,22 +3628,22 @@
          HasAVX    ? X86::VMOVSDmr :
                      X86::MOVSDmr);
     if (X86::VR64RegClass.hasSubClassEq(RC))
-      return load ? X86::MMX_MOVQ64rm : X86::MMX_MOVQ64mr;
+      return Load ? X86::MMX_MOVQ64rm : X86::MMX_MOVQ64mr;
     if (X86::RFP64RegClass.hasSubClassEq(RC))
-      return load ? X86::LD_Fp64m : X86::ST_Fp64m;
+      return Load ? X86::LD_Fp64m : X86::ST_Fp64m;
     if (X86::VK64RegClass.hasSubClassEq(RC)) {
       assert(STI.hasBWI() && "KMOVQ requires BWI");
-      return load ? X86::KMOVQkm : X86::KMOVQmk;
+      return Load ? X86::KMOVQkm : X86::KMOVQmk;
     }
     llvm_unreachable("Unknown 8-byte regclass");
   case 10:
     assert(X86::RFP80RegClass.hasSubClassEq(RC) && "Unknown 10-byte regclass");
-    return load ? X86::LD_Fp80m : X86::ST_FpP80m;
+    return Load ? X86::LD_Fp80m : X86::ST_FpP80m;
   case 16: {
     if (X86::VR128XRegClass.hasSubClassEq(RC)) {
       // If stack is realigned we can use aligned stores.
       if (IsStackAligned)
-        return load ?
+        return Load ?
          (HasVLX    ? X86::VMOVAPSZ128rm :
           HasAVX512 ? X86::VMOVAPSZ128rm_NOVLX :
           HasAVX    ? X86::VMOVAPSrm :
                       X86::MOVAPSrm) :
@@ -3641,7 +3653,7 @@
           HasAVX    ? X86::VMOVAPSmr :
                       X86::MOVAPSmr);
       else
-        return load ?
+        return Load ?
          (HasVLX    ? X86::VMOVUPSZ128rm :
           HasAVX512 ? X86::VMOVUPSZ128rm_NOVLX :
           HasAVX    ? X86::VMOVUPSrm :
@@ -3657,7 +3669,7 @@
     assert(X86::VR256XRegClass.hasSubClassEq(RC) && "Unknown 32-byte regclass");
     // If stack is realigned we can use aligned stores.
     if (IsStackAligned)
-      return load ?
+      return Load ?
        (HasVLX    ? X86::VMOVAPSZ256rm :
         HasAVX512 ? X86::VMOVAPSZ256rm_NOVLX :
                     X86::VMOVAPSYrm) :
@@ -3665,7 +3677,7 @@
        HasAVX512 ? X86::VMOVAPSZ256mr_NOVLX :
                    X86::VMOVAPSYmr);
     else
-      return load ?
+      return Load ?
        (HasVLX    ? X86::VMOVUPSZ256rm :
         HasAVX512 ? X86::VMOVUPSZ256rm_NOVLX :
                     X86::VMOVUPSYrm) :
@@ -3676,9 +3688,13 @@
   case 64:
     assert(X86::VR512RegClass.hasSubClassEq(RC) && "Unknown 64-byte regclass");
     assert(STI.hasAVX512() && "Using 512-bit register requires AVX512");
     if (IsStackAligned)
-      return load ? X86::VMOVAPSZrm : X86::VMOVAPSZmr;
+      return Load ? X86::VMOVAPSZrm : X86::VMOVAPSZmr;
     else
-      return load ? X86::VMOVUPSZrm : X86::VMOVUPSZmr;
+      return Load ? X86::VMOVUPSZrm : X86::VMOVUPSZmr;
+  case 1024:
+    assert(X86::TILERegClass.hasSubClassEq(RC) && "Unknown 1024-byte regclass");
+    assert(STI.hasAMXTILE() && "Using 8*1024-bit register requires AMX-TILE");
+    return Load ? X86::TILELOADD : X86::TILESTORED;
   }
 }
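[Note, not part of the patch] The new case 1024 bucket is the TILE register class: an AMX tile register holds at most 16 rows of 64 bytes, i.e. 1024 bytes (the "8*1024-bit" in the assert), and 64 is also the row stride that loadStoreTileReg in the next hunk materializes with MOV64ri. A trivial sketch of that arithmetic, assuming the usual AMX tile geometry:

#include <cstdio>

int main() {
  // Assumed AMX tile geometry: up to 16 rows x 64 bytes per row.
  constexpr unsigned RowBytes = 64; // stride loaded with MOV64ri in the patch
  constexpr unsigned MaxRows = 16;
  std::printf("TILE spill slot size: %u bytes\n", RowBytes * MaxRows); // 1024
  return 0;
}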
@@ -3835,44 +3851,73 @@
   return getLoadStoreRegOpcode(DestReg, RC, IsStackAligned, STI, true);
 }
 
+static bool isAMXOpcode(unsigned Opc) {
+  switch (Opc) {
+  default:
+    return false;
+  case X86::TILELOADD:
+  case X86::TILESTORED:
+    return true;
+  }
+}
+
+void X86InstrInfo::loadStoreTileReg(MachineBasicBlock &MBB,
+                                    MachineBasicBlock::iterator MI,
+                                    unsigned Opc, Register Reg, int FrameIdx,
+                                    bool isKill) const {
+  switch (Opc) {
+  default:
+    llvm_unreachable("Unexpected special opcode!");
+  case X86::TILESTORED: {
+    // tilestored %tmm, (%sp, %idx)
+    MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();
+    Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass);
+    BuildMI(MBB, MI, DebugLoc(), get(X86::MOV64ri), VirtReg).addImm(64);
+    MachineInstr *NewMI =
+        addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx)
+            .addReg(Reg, getKillRegState(isKill));
+    MachineOperand &MO = NewMI->getOperand(X86::AddrIndexReg);
+    MO.setReg(VirtReg);
+    MO.setIsKill(true);
+    break;
+  }
+  case X86::TILELOADD: {
+    // tileloadd (%sp, %idx), %tmm
+    MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();
+    Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass);
+    BuildMI(MBB, MI, DebugLoc(), get(X86::MOV64ri), VirtReg).addImm(64);
+    MachineInstr *NewMI = addFrameReference(
+        BuildMI(MBB, MI, DebugLoc(), get(Opc), Reg), FrameIdx);
+    MachineOperand &MO = NewMI->getOperand(1 + X86::AddrIndexReg);
+    MO.setReg(VirtReg);
+    MO.setIsKill(true);
+    break;
+  }
+  }
+}
+
 void X86InstrInfo::storeRegToStackSlot(MachineBasicBlock &MBB,
                                        MachineBasicBlock::iterator MI,
-                                       Register SrcReg, bool isKill, int FrameIdx,
+                                       Register SrcReg, bool isKill,
+                                       int FrameIdx,
                                        const TargetRegisterClass *RC,
                                        const TargetRegisterInfo *TRI) const {
   const MachineFunction &MF = *MBB.getParent();
   const MachineFrameInfo &MFI = MF.getFrameInfo();
-  MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();
   assert(MFI.getObjectSize(FrameIdx) >= TRI->getSpillSize(*RC) &&
          "Stack slot too small for store");
-  if (RC->getID() == X86::TILERegClassID) {
-    unsigned Opc = X86::TILESTORED;
-    // tilestored %tmm, (%sp, %idx)
-    Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass);
-    BuildMI(MBB, MI, DebugLoc(), get(X86::MOV64ri), VirtReg).addImm(64);
-    MachineInstr *NewMI =
-        addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx)
-            .addReg(SrcReg, getKillRegState(isKill));
-    MachineOperand &MO = NewMI->getOperand(2);
-    MO.setReg(VirtReg);
-    MO.setIsKill(true);
-  } else if ((RC->getID() == X86::FR16RegClassID ||
-              RC->getID() == X86::FR16XRegClassID) &&
-             !Subtarget.hasFP16()) {
-    unsigned Opc = Subtarget.hasAVX512() ? X86::VMOVSSZmr
-                   : Subtarget.hasAVX()  ? X86::VMOVSSmr
-                                         : X86::MOVSSmr;
-    addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx)
-        .addReg(SrcReg, getKillRegState(isKill));
-  } else {
-    unsigned Alignment = std::max<uint32_t>(TRI->getSpillSize(*RC), 16);
-    bool isAligned =
-        (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) ||
-        (RI.canRealignStack(MF) && !MFI.isFixedObjectIndex(FrameIdx));
-    unsigned Opc = getStoreRegOpcode(SrcReg, RC, isAligned, Subtarget);
+
+  unsigned Alignment = std::max<uint32_t>(TRI->getSpillSize(*RC), 16);
+  bool isAligned =
+      (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) ||
+      (RI.canRealignStack(MF) && !MFI.isFixedObjectIndex(FrameIdx));
+
+  unsigned Opc = getStoreRegOpcode(SrcReg, RC, isAligned, Subtarget);
+  if (isAMXOpcode(Opc))
+    loadStoreTileReg(MBB, MI, Opc, SrcReg, FrameIdx, isKill);
+  else
    addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx)
        .addReg(SrcReg, getKillRegState(isKill));
-  }
 }
 
 void X86InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
@@ -3884,35 +3929,17 @@
   const MachineFrameInfo &MFI = MF.getFrameInfo();
   assert(MFI.getObjectSize(FrameIdx) >= TRI->getSpillSize(*RC) &&
          "Load size exceeds stack slot");
-  if (RC->getID() == X86::TILERegClassID) {
-    unsigned Opc = X86::TILELOADD;
-    // tileloadd (%sp, %idx), %tmm
-    MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();
-    Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass);
-    MachineInstr *NewMI =
-        BuildMI(MBB, MI, DebugLoc(), get(X86::MOV64ri), VirtReg).addImm(64);
-    NewMI = addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc), DestReg),
-                              FrameIdx);
-    MachineOperand &MO = NewMI->getOperand(3);
-    MO.setReg(VirtReg);
-    MO.setIsKill(true);
-  } else if ((RC->getID() == X86::FR16RegClassID ||
-              RC->getID() == X86::FR16XRegClassID) &&
-             !Subtarget.hasFP16()) {
-    unsigned Opc = Subtarget.hasAVX512() ? X86::VMOVSSZrm
-                   : Subtarget.hasAVX()  ? X86::VMOVSSrm
-                                         : X86::MOVSSrm;
-    addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc), DestReg),
-                      FrameIdx);
-  } else {
-    unsigned Alignment = std::max<uint32_t>(TRI->getSpillSize(*RC), 16);
-    bool isAligned =
-        (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) ||
-        (RI.canRealignStack(MF) && !MFI.isFixedObjectIndex(FrameIdx));
-    unsigned Opc = getLoadRegOpcode(DestReg, RC, isAligned, Subtarget);
+  unsigned Alignment = std::max<uint32_t>(TRI->getSpillSize(*RC), 16);
+  bool isAligned =
+      (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) ||
+      (RI.canRealignStack(MF) && !MFI.isFixedObjectIndex(FrameIdx));
+
+  unsigned Opc = getLoadRegOpcode(DestReg, RC, isAligned, Subtarget);
+  if (isAMXOpcode(Opc))
+    loadStoreTileReg(MBB, MI, Opc, DestReg, FrameIdx);
+  else
    addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc), DestReg),
                      FrameIdx);
-  }
 }
 
 bool X86InstrInfo::analyzeCompare(const MachineInstr &MI, Register &SrcReg,
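[Note, not part of the patch] loadStoreTileReg rewrites the index register of the frame reference it just built so that it carries the stride vreg loaded by MOV64ri 64; the hard-coded getOperand(2)/getOperand(3) of the removed code become X86::AddrIndexReg and 1 + X86::AddrIndexReg. A standalone sketch of where those indices come from, assuming the memory-operand layout mirrored from X86BaseInfo.h:

#include <cstdio>

// Assumed X86 memory-operand layout (mirrors X86BaseInfo.h).
enum : unsigned {
  AddrBaseReg = 0,   // frame index after addFrameReference()
  AddrScaleAmt = 1,
  AddrIndexReg = 2,  // rewritten to the GR64_NOSP vreg holding the stride
  AddrDisp = 3,
  AddrSegmentReg = 4
};

int main() {
  // TILESTORED operands: [mem..., source tile] -> stride is operand 2.
  std::printf("TILESTORED stride operand index: %u\n", AddrIndexReg);
  // TILELOADD operands: [dest tile, mem...] -> stride is operand 1 + 2 = 3.
  std::printf("TILELOADD stride operand index: %u\n", 1 + AddrIndexReg);
  return 0;
}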
diff --git a/llvm/test/CodeGen/X86/AMX/amx-fastconfig-phi.mir b/llvm/test/CodeGen/X86/AMX/amx-fastconfig-phi.mir
--- a/llvm/test/CodeGen/X86/AMX/amx-fastconfig-phi.mir
+++ b/llvm/test/CodeGen/X86/AMX/amx-fastconfig-phi.mir
@@ -1,5 +1,5 @@
 # NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
-# RUN: llc -mtriple=x86_64-- -run-pass=fastpretileconfig -o - %s | FileCheck %s
+# RUN: llc -mtriple=x86_64-- -mattr=+amx-tile -run-pass=fastpretileconfig -o - %s | FileCheck %s
 #
 # This case test tile phi is nested accessed, but the its def block is
 # not visited yet.
diff --git a/llvm/test/CodeGen/X86/AMX/amx-fastconfig-phi2.mir b/llvm/test/CodeGen/X86/AMX/amx-fastconfig-phi2.mir
--- a/llvm/test/CodeGen/X86/AMX/amx-fastconfig-phi2.mir
+++ b/llvm/test/CodeGen/X86/AMX/amx-fastconfig-phi2.mir
@@ -1,5 +1,5 @@
 # NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
-# RUN: llc -mtriple=x86_64-- -run-pass=fastpretileconfig -o - %s | FileCheck %s
+# RUN: llc -mtriple=x86_64-- -mattr=+amx-tile -run-pass=fastpretileconfig -o - %s | FileCheck %s
 #
 # bb.0
 #    def %0
diff --git a/llvm/test/CodeGen/X86/AMX/amx-fastconfig-phi4.mir b/llvm/test/CodeGen/X86/AMX/amx-fastconfig-phi4.mir
--- a/llvm/test/CodeGen/X86/AMX/amx-fastconfig-phi4.mir
+++ b/llvm/test/CodeGen/X86/AMX/amx-fastconfig-phi4.mir
@@ -1,5 +1,5 @@
 # NOTE: Assertions have been autogenerated by utils/update_mir_test_checks.py
-# RUN: llc -mtriple=x86_64-- -run-pass=fastpretileconfig -o - %s | FileCheck %s
+# RUN: llc -mtriple=x86_64-- -mattr=+amx-tile -run-pass=fastpretileconfig -o - %s | FileCheck %s
 #
 # bb.0
 #    def %0
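[Note, not part of the patch] The -mattr=+amx-tile additions are needed because the spill helpers no longer special-case the TILE register class up front: they always ask getStoreRegOpcode/getLoadRegOpcode for an opcode first, and the 1024-byte path now asserts STI.hasAMXTILE(). A minimal sketch of the resulting control flow (plain C++, not LLVM code; the tiny opcode enum is made up):

#include <cstdio>

enum class Opc { MOV64mr, MOVAPSmr, TILESTORED }; // stand-in opcode set

// Mirrors the new isAMXOpcode() dispatch in storeRegToStackSlot.
static bool isAMXOpcode(Opc O) { return O == Opc::TILESTORED; }

static void storeRegToStackSlot(Opc O) {
  if (isAMXOpcode(O))
    std::puts("expand via loadStoreTileReg: MOV64ri 64 + TILESTORED");
  else
    std::puts("emit one frame-indexed store instruction");
}

int main() {
  storeRegToStackSlot(Opc::MOVAPSmr);   // ordinary vector spill
  storeRegToStackSlot(Opc::TILESTORED); // AMX tile spill
  return 0;
}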