diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp
--- a/llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -362,6 +362,8 @@
     addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48)
         .addReg(Xmm);
   }
+  // Fill in the palette first.
+  addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOV8mi)), SS).addImm(1);
   return true;
 }
 
diff --git a/llvm/lib/Target/X86/X86TileConfig.cpp b/llvm/lib/Target/X86/X86TileConfig.cpp
--- a/llvm/lib/Target/X86/X86TileConfig.cpp
+++ b/llvm/lib/Target/X86/X86TileConfig.cpp
@@ -22,9 +22,7 @@
 #include "X86MachineFunctionInfo.h"
 #include "X86RegisterInfo.h"
 #include "X86Subtarget.h"
-#include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/CodeGen/LiveIntervals.h"
-#include "llvm/CodeGen/MachineDominators.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineInstr.h"
@@ -32,8 +30,6 @@
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
 #include "llvm/CodeGen/TargetRegisterInfo.h"
-#include "llvm/CodeGen/TileShapeInfo.h"
-#include "llvm/CodeGen/VirtRegMap.h"
 #include "llvm/InitializePasses.h"
 
 using namespace llvm;
@@ -42,28 +38,19 @@
 
 namespace {
 
-class X86TileConfig : public MachineFunctionPass {
-  // context
-  MachineFunction *MF = nullptr;
-  const X86Subtarget *ST = nullptr;
-  const TargetRegisterInfo *TRI;
-  const TargetInstrInfo *TII;
-  MachineDominatorTree *DomTree = nullptr;
-  MachineRegisterInfo *MRI = nullptr;
-  VirtRegMap *VRM = nullptr;
-  LiveIntervals *LIS = nullptr;
-
-  MachineInstr *getTileConfigPoint();
-  void tileConfig();
-
-public:
+struct X86TileConfig : public MachineFunctionPass {
+
   X86TileConfig() : MachineFunctionPass(ID) {}
 
   /// Return the pass name.
   StringRef getPassName() const override { return "Tile Register Configure"; }
 
   /// X86TileConfig analysis usage.
-  void getAnalysisUsage(AnalysisUsage &AU) const override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesAll();
+    AU.addRequired<LiveIntervals>();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
 
   /// Perform register allocation.
   bool runOnMachineFunction(MachineFunction &mf) override;
@@ -82,168 +69,115 @@
 
 INITIALIZE_PASS_BEGIN(X86TileConfig, "tileconfig", "Tile Register Configure",
                       false, false)
-INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
-INITIALIZE_PASS_DEPENDENCY(VirtRegMap)
 INITIALIZE_PASS_END(X86TileConfig, "tileconfig", "Tile Register Configure",
                     false, false)
 
-void X86TileConfig::getAnalysisUsage(AnalysisUsage &AU) const {
-  AU.addRequired<MachineDominatorTree>();
-  AU.addRequired<LiveIntervals>();
-  AU.addPreserved<SlotIndexes>();
-  AU.addRequired<VirtRegMap>();
-  AU.setPreservesAll();
-  MachineFunctionPass::getAnalysisUsage(AU);
-}
-
-static unsigned getTilePhysRegIndex(Register PhysReg) {
-  assert((PhysReg >= X86::TMM0 && X86::TMM0 <= X86::TMM7) &&
-         "Tile register number is invalid");
-  return (PhysReg - X86::TMM0);
-}
-
-static MachineInstr *
-storeRegToStackSlot(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI,
-                    Register SrcReg, unsigned BitSize, int FrameIdx, int Offset,
-                    const TargetInstrInfo *TII, const TargetRegisterClass *RC,
-                    const TargetRegisterInfo *TRI) {
-
-  unsigned SubIdx = (BitSize == 8) ? X86::sub_8bit : X86::sub_16bit;
-  unsigned Opc = (BitSize == 8) ? X86::MOV8mr : X86::MOV16mr;
-  if (BitSize == TRI->getRegSizeInBits(*RC))
-    SubIdx = 0;
-  MachineInstr *NewMI =
-      addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)), FrameIdx,
-                        Offset)
-          .addReg(SrcReg, 0, SubIdx);
-  return NewMI;
-}
-
-static MachineInstr *storeImmToStackSlot(MachineBasicBlock &MBB,
-                                         MachineBasicBlock::iterator MI,
-                                         int64_t Imm, unsigned BitSize,
-                                         int FrameIdx, int Offset,
-                                         const TargetInstrInfo *TII) {
-  unsigned Opc = (BitSize == 8) ? X86::MOV8mi : X86::MOV16mi;
-  return addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)),
-                           FrameIdx, Offset)
-      .addImm(Imm);
-}
-
-MachineInstr *X86TileConfig::getTileConfigPoint() {
-  MachineBasicBlock *Entry = &*MF->begin();
-  ReversePostOrderTraversal<MachineBasicBlock *> RPOT(Entry);
-  for (MachineBasicBlock *MBB : RPOT) {
-    for (MachineInstr &MI : *MBB)
-      // Refer X86PreTileConfig.cpp.
-      // We only support one tile config for now. The other ldtilecfg
-      // is for spill purpose and is dominated by the first ldtilecfg.
-      if (MI.getOpcode() == X86::LDTILECFG)
-        return &MI;
+// Store the row/column shape of every physical tile register that has a
+// non-LDTILECFG def into the LDTILECFG stack slot. Move-immediate shapes are
+// stored before the palette store emitted by X86PreTileConfig; other shapes
+// right after their defining instruction. Returns false without touching MF
+// unless a palette store in the entry block precedes an LDTILECFG.
+bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) {
+  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
+  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
+  const TargetInstrInfo *TII = ST.getInstrInfo();
+  MachineRegisterInfo &MRI = MF.getRegInfo();
+  LiveIntervals *LIS = &getAnalysis<LiveIntervals>();
+
+  int SS = INT_MAX;
+  MachineInstr *ConstMI = nullptr;
+  for (MachineBasicBlock &MBB : MF) {
+    for (MachineInstr &MI : MBB) {
+      if (MI.getOpcode() == X86::LDTILECFG) {
+        // We need to distinguish from the legacy interface.
+        if (!ConstMI)
+          return false;
+        SS = MI.getOperand(0).getIndex();
+        assert(SS == ConstMI->getOperand(0).getIndex() &&
+               "Found a wrong point");
+        break;
+      }
+      // Recognize the palette store (MOV8mi of 1 to offset 0 of a frame slot).
+      // Operand 1 is the address scale, which is 1 for every frame reference;
+      // the stored immediate follows the X86::AddrNumOperands address operands.
+      if (!ConstMI && MI.getOpcode() == X86::MOV8mi &&
+          MI.getOperand(0).isFI() &&
+          MI.getOperand(X86::AddrDisp).getImm() == 0 &&
+          MI.getOperand(X86::AddrNumOperands).getImm() == 1)
+        ConstMI = &MI;
+    }
+    if (!ConstMI)
+      return false;
+    if (SS != INT_MAX)
+      break;
   }
-  return nullptr;
-}
-
-void X86TileConfig::tileConfig() {
-  MachineInstr *MI = getTileConfigPoint();
-  if (!MI)
-    return;
-  MachineBasicBlock *MBB = MI->getParent();
-  int SS = MI->getOperand(0).getIndex();
-  BitVector PhysRegs(TRI->getNumRegs());
-
-  // Fill in the palette first.
-  auto *NewMI = storeImmToStackSlot(*MBB, *MI, 1, 8, SS, 0, TII);
-  LIS->InsertMachineInstrInMaps(*NewMI);
+
+  // Without an LDTILECFG there is no config slot to fill.
+  if (SS == INT_MAX)
+    return false;
 
   // Fill in the shape of each tile physical register.
-  for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
-    Register VirtReg = Register::index2VirtReg(i);
-    if (MRI->reg_nodbg_empty(VirtReg))
-      continue;
-    const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
-    if (RC.getID() != X86::TILERegClassID)
-      continue;
-    Register PhysReg = VRM->getPhys(VirtReg);
-    if (PhysRegs.test(PhysReg))
-      continue;
-    PhysRegs.set(PhysReg);
-    ShapeT Shape = VRM->getShape(VirtReg);
-    Register RowReg = Shape.getRow()->getReg();
-    Register ColReg = Shape.getCol()->getReg();
-
-    // Here is the data format for the tile config.
-    // 0      palette
-    // 1      start_row
-    // 2-15   reserved, must be zero
-    // 16-17  tile0.colsb Tile 0 bytes per row.
-    // 18-19  tile1.colsb Tile 1 bytes per row.
-    // 20-21  tile2.colsb Tile 2 bytes per row.
-    // ... (sequence continues)
-    // 30-31  tile7.colsb Tile 7 bytes per row.
-    // 32-47  reserved, must be zero
-    // 48     tile0.rows Tile 0 rows.
-    // 49     tile1.rows Tile 1 rows.
-    // 50     tile2.rows Tile 2 rows.
-    // ... (sequence continues)
-    // 55     tile7.rows Tile 7 rows.
-    // 56-63  reserved, must be zero
-    unsigned Index = getTilePhysRegIndex(PhysReg);
-    int RowOffset = 48 + Index;
-    int ColOffset = 16 + Index * 2;
-
-    unsigned BitSize = 8;
-    for (const auto &Pair : {std::make_pair(RowReg, RowOffset),
-                             std::make_pair(ColReg, ColOffset)}) {
-      int64_t Imm;
-      int ImmCount = 0;
-      // All def must be the same value, otherwise it is invalid MIs.
-      // Immediate is prefered.
-      for (const MachineOperand &MO : MRI->def_operands(Pair.first)) {
-        const auto *Inst = MO.getParent();
-        if (Inst->isMoveImmediate()) {
-          ImmCount++;
-          Imm = Inst->getOperand(1).getImm();
-          break;
-        }
-      }
-      auto StoreConfig = [&](int Offset) {
+  const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
+  for (unsigned I = 0;
+       I < RC->getNumRegs() && !MRI.reg_nodbg_empty(X86::TMM0 + I); ++I) {
+    for (const MachineOperand &MO : MRI.def_operands(X86::TMM0 + I)) {
+      const MachineInstr *MI = MO.getParent();
+      if (MI->getOpcode() == X86::LDTILECFG)
+        continue;
+      auto StackOffsetFixup = [&](bool IsRow) {
+        // Here is the data format for the tile config.
+        // 0      palette
+        // 1      start_row
+        // 2-15   reserved, must be zero
+        // 16-17  tile0.colsb Tile 0 bytes per row.
+        // 18-19  tile1.colsb Tile 1 bytes per row.
+        // 20-21  tile2.colsb Tile 2 bytes per row.
+        // ... (sequence continues)
+        // 30-31  tile7.colsb Tile 7 bytes per row.
+        // 32-47  reserved, must be zero
+        // 48     tile0.rows Tile 0 rows.
+        // 49     tile1.rows Tile 1 rows.
+        // 50     tile2.rows Tile 2 rows.
+        // ... (sequence continues)
+        // 55     tile7.rows Tile 7 rows.
+        // 56-63  reserved, must be zero
         MachineInstr *NewMI = nullptr;
-        if (ImmCount)
-          NewMI = storeImmToStackSlot(*MBB, *MI, Imm, BitSize, SS, Offset, TII);
-        else {
-          const TargetRegisterClass *RC = MRI->getRegClass(Pair.first);
-          NewMI = storeRegToStackSlot(*MBB, *MI, Pair.first, BitSize, SS,
-                                      Offset, TII, RC, TRI);
+        int Offset = IsRow ? 48 + I : 16 + I * 2;
+        Register ShapeReg = MI->getOperand(IsRow ? 1 : 2).getReg();
+        LiveInterval &LI = LIS->getInterval(ShapeReg);
+        VNInfo *VNI = LI.getVNInfoBefore(LIS->getInstructionIndex(*MI));
+        assert(VNI && "Shape register must be live at the tile def");
+        MachineInstr *ShapeMI = LIS->getInstructionFromIndex(VNI->def);
+        assert(ShapeMI && "Shape value is not defined by an instruction");
+        MachineBasicBlock *MBB = ShapeMI->getParent();
+        if (ShapeMI->isMoveImmediate()) {
+          auto MIB = BuildMI(MF.front(), *ConstMI, DebugLoc(),
+                             TII->get(IsRow ? X86::MOV8mi : X86::MOV16mi));
+          NewMI = addFrameReference(MIB, SS, Offset)
+                      .addImm(ShapeMI->getOperand(1).getImm());
+        } else {
+          unsigned SubIdx = IsRow ? X86::sub_8bit : X86::sub_16bit;
+          unsigned RegSize = TRI->getRegSizeInBits(*MRI.getRegClass(ShapeReg));
+          if ((IsRow && RegSize == 8) || (!IsRow && RegSize == 16))
+            SubIdx = 0;
+          auto MIB = BuildMI(*MBB, ++ShapeMI->getIterator(), DebugLoc(),
+                             TII->get(IsRow ? X86::MOV8mr : X86::MOV16mr));
+          NewMI =
+              addFrameReference(MIB, SS, Offset).addReg(ShapeReg, 0, SubIdx);
         }
         SlotIndex SIdx = LIS->InsertMachineInstrInMaps(*NewMI);
-        if (!ImmCount) {
-          // Extend the live interval.
+        if (!ShapeMI->isMoveImmediate()) {
           SmallVector<SlotIndex, 8> EndPoints = {SIdx.getRegSlot()};
-          LiveInterval &Int = LIS->getInterval(Pair.first);
+          LiveInterval &Int = LIS->getInterval(ShapeReg);
           LIS->extendToIndices(Int, EndPoints);
         }
       };
-      StoreConfig(Pair.second);
-      BitSize += 8;
+      StackOffsetFixup(true);
+      StackOffsetFixup(false);
+      // The same AMX register always has the same shape.
+      break;
     }
   }
-}
-
-bool X86TileConfig::runOnMachineFunction(MachineFunction &mf) {
-  MF = &mf;
-  MRI = &mf.getRegInfo();
-  ST = &mf.getSubtarget<X86Subtarget>();
-  TRI = ST->getRegisterInfo();
-  TII = mf.getSubtarget().getInstrInfo();
-  DomTree = &getAnalysis<MachineDominatorTree>();
-  VRM = &getAnalysis<VirtRegMap>();
-  LIS = &getAnalysis<LiveIntervals>();
-
-  if (VRM->isShapeMapEmpty())
-    return false;
-
-  tileConfig();
   return true;
 }
 