diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp
--- a/llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -6,31 +6,20 @@
 //
 //===----------------------------------------------------------------------===//
 //
-/// \file Pass to pre-config the shape of AMX register
-/// AMX register need to be configured before use. The shape of AMX register
-/// is encoded in the 1st and 2nd machine operand of AMX pseudo instructions.
-/// The pldtilecfg is to config tile registers. It should dominator all AMX
-/// instructions. The pldtilecfg produce a virtual cfg register and the cfg
-/// register is used by all AMX instructions.
-/// This pass is to find the common dominator of all AMX instructions and
-/// insert the pldtilecfg instruction. Besides the cfg register that pldtilecfg
-/// produces is inserted as the last operand of each AMX instruction. We use
-/// this scheme to model the def-use relationship between AMX config instruction
-/// and other AMX instructions. Below is an example.
+/// \file Pass to pre-config the shapes of AMX registers
+/// AMX registers need to be configured before use. The shapes of an AMX
+/// register are encoded in the 1st and 2nd machine operands of AMX pseudo
+/// instructions.
 ///
-/// ----B1----
-///  /        \
-/// /          \
-/// B2          B3
-/// %1:tile = PTILELOADDV     %2:tile = PTILELOADDV
+/// The ldtilecfg instruction is used to config the shapes. It must be
+/// reachable from the definitions of all variable shapes. ldtilecfg is
+/// inserted more than once if we cannot find a single dominating point for
+/// all AMX instructions.
 ///
-/// is transformed to
+/// The tile config register is caller-saved according to the ABI. We need to
+/// insert ldtilecfg again after a call instruction if the callee clobbers any
+/// AMX registers.
 ///
-///  B1
-/// %25:tilecfg = PLDTILECFG
-///  /        \
-/// /          \
-/// %1:tile = PTILELOADDV %25   %2:tile = PTILELOADDV %25
+/// This pass calculates all the points where ldtilecfg needs to be inserted
+/// and inserts it there. It reports an error if the reachability conditions
+/// aren't met.
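+///
+/// For example (this mirrors the PHI test case added to amx-ldtilecfg-insert.ll
+/// below): when a variable shape is defined on both sides of a branch and an
+/// AMX instruction lives at the join point, ldtilecfg is inserted at the join
+/// point, after the PHI that merges the shape:
+///
+///              ----B1----
+///             /          \
+///   B2: %r1 = ...        B3: %r2 = ...
+///             \          /
+///        B4: %r = PHI(%r1, %r2)
+///            ldtilecfg
+///            %t:tile = PTILEZEROV %r, %col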
 //
 //===----------------------------------------------------------------------===//
@@ -38,9 +27,9 @@
 #include "X86InstrBuilder.h"
 #include "X86RegisterInfo.h"
 #include "X86Subtarget.h"
-#include "llvm/CodeGen/MachineDominators.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineLoopInfo.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/Passes.h"
 #include "llvm/CodeGen/TargetInstrInfo.h"
@@ -51,19 +40,93 @@
 using namespace llvm;
 
 #define DEBUG_TYPE "tile-pre-config"
+#define ASSERT_VALID_COMPARE                                                   \
+  assert((!MBB || !RHS.MBB || MBB == RHS.MBB) &&                               \
+         "Cannot compare between different BBs");
+#define REPORT_CONFIG_FAIL                                                     \
+  report_fatal_error(                                                          \
+      MF.getName() +                                                           \
+      ": Failed to config tile register, please define the shape earlier");
 
 namespace {
 
+struct MIRef {
+  MachineInstr *MI = nullptr;
+  MachineBasicBlock *MBB = nullptr;
+  size_t Pos = 0; /* Virtual position; insertions at this MIRef go after MI. */
+  MIRef() = default;
+  MIRef(MachineBasicBlock *MBB) : MBB(MBB) {
+    for (auto I = MBB->begin(), E = MBB->end(); I != E && I->isPHI();
+         ++I, ++Pos)
+      MI = &*I;
+  }
+  MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
+      : MI(MI), MBB(MBB),
+        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
+  MIRef(MachineInstr *MI, MachineBasicBlock *MBB, size_t Pos)
+      : MI(MI), MBB(MBB), Pos(Pos) {}
+  MachineInstr *operator->() { return MI; }
+  operator bool() const { return MBB != nullptr; }
+  bool operator==(const MIRef &RHS) const {
+    return MI == RHS.MI && MBB == RHS.MBB;
+  }
+  bool operator<(const MIRef &RHS) const {
+    ASSERT_VALID_COMPARE;
+    return Pos < RHS.Pos;
+  }
+  bool operator>(const MIRef &RHS) const {
+    ASSERT_VALID_COMPARE;
+    return Pos > RHS.Pos;
+  }
+};
+
+struct BBInfo {
+  MIRef FirstAMX;
+  MIRef LastCall;
+  MIRef LastShape;
+  bool NeedTileCfgLiveIn = false;
+  unsigned ShapeReachedCount = 0;
+};
+
 class X86PreTileConfig : public MachineFunctionPass {
-  // context
-  MachineFunction *MF = nullptr;
-  const X86Subtarget *ST = nullptr;
-  const TargetRegisterInfo *TRI;
-  const TargetInstrInfo *TII;
-  MachineDominatorTree *DomTree = nullptr;
-  MachineRegisterInfo *MRI = nullptr;
+  MachineRegisterInfo *MRI;
+  SmallSet<MIRef, 8> CfgNeedInsert;
+  SmallSet<MachineBasicBlock *, 8> ShapeBBs;
+  SmallVector<MachineBasicBlock *, 8> CfgLiveInBBs;
+  DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
 
-  MachineInstr *getTileConfigPoint();
+  /// Check if the callee will clobber AMX registers.
+  bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
+    auto Iter = llvm::find_if(
+        MI.operands(), [](MachineOperand &MO) { return MO.isRegMask(); });
+    if (Iter == MI.operands_end())
+      return false;
+    UsableRegs.clearBitsInMask(Iter->getRegMask());
+    return !UsableRegs.none();
+  }
+
+  /// Check if MI is an AMX instruction and collect its shape info.
+  bool isAMXInstruction(MachineInstr &MI) {
+    switch (MI.getOpcode()) {
+    default:
+      return false;
+    case X86::PTILESTOREDV:
+      collectShapeInfo(ShapeT(&MI.getOperand(0), &MI.getOperand(1), MRI));
+      return true;
+    case X86::PTILELOADDV:
+    case X86::PTDPBSSDV:
+    case X86::PTDPBSUDV:
+    case X86::PTDPBUSDV:
+    case X86::PTDPBUUDV:
+    case X86::PTILEZEROV:
+    case X86::PTDPBF16PSV:
+      collectShapeInfo(ShapeT(&MI.getOperand(1), &MI.getOperand(2), MRI));
+      return true;
+    }
+  }
+
+  /// Collect the shape def information for later use.
+  void collectShapeInfo(ShapeT Shape);
 
 public:
   X86PreTileConfig() : MachineFunctionPass(ID) {}
@@ -74,10 +137,22 @@
   }
 
   /// X86PreTileConfig analysis usage.
-  void getAnalysisUsage(AnalysisUsage &AU) const override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesAll();
+    AU.addRequired<MachineLoopInfo>();
+    MachineFunctionPass::getAnalysisUsage(AU);
+  }
 
-  /// Perform register allocation.
-  bool runOnMachineFunction(MachineFunction &mf) override;
+  /// Clear MF-related structures.
+  void releaseMemory() override {
+    ShapeBBs.clear();
+    CfgLiveInBBs.clear();
+    CfgNeedInsert.clear();
+    BBVisitedInfo.clear();
+  }
+
+  /// Insert ldtilecfg instructions.
+  bool runOnMachineFunction(MachineFunction &MF) override;
 
   static char ID;
 };
 
@@ -88,278 +163,210 @@
 INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
                       "Tile Register Pre-configure", false, false)
-INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
+INITIALIZE_PASS_DEPENDENCY(MachineLoopInfo)
 INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
                     "Tile Register Pre-configure", false, false)
 
-void X86PreTileConfig::getAnalysisUsage(AnalysisUsage &AU) const {
-  AU.setPreservesAll();
-  AU.addRequired<MachineDominatorTree>();
-  MachineFunctionPass::getAnalysisUsage(AU);
-}
-
-static void buildConfigMI(MachineBasicBlock::iterator MI, int FrameIdx,
-                          const TargetInstrInfo *TII, MachineRegisterInfo *MRI,
-                          const X86Subtarget *ST) {
-  auto *MBB = MI->getParent();
-
-  // Zero stack slot.
-  if (ST->hasAVX512()) {
-    Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
-    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORDZrr), Zmm)
-        .addReg(Zmm, RegState::Undef)
-        .addReg(Zmm, RegState::Undef);
-    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSZmr)),
-                      FrameIdx)
-        .addReg(Zmm);
-  } else if (ST->hasAVX2()) {
-    Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass);
-    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORYrr), Ymm)
-        .addReg(Ymm, RegState::Undef)
-        .addReg(Ymm, RegState::Undef);
-    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSYmr)),
-                      FrameIdx)
-        .addReg(Ymm);
-    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSYmr)),
-                      FrameIdx, 32)
-        .addReg(Ymm);
-  } else {
-    assert(ST->hasSSE2() && "AMX should assume SSE2 enabled");
-    Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass);
-    BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::PXORrr), Xmm)
-        .addReg(Xmm, RegState::Undef)
-        .addReg(Xmm, RegState::Undef);
-    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
-                      FrameIdx)
-        .addReg(Xmm);
-    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
-                      FrameIdx, 16)
-        .addReg(Xmm);
-    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
-                      FrameIdx, 32)
-        .addReg(Xmm);
-    addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::MOVUPSmr)),
-                      FrameIdx, 48)
-        .addReg(Xmm);
-  }
-
-  // build psuedo ldtilecfg
-  addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::LDTILECFG)),
-                    FrameIdx);
-}
-
-static ShapeT getShape(const MachineInstr &MI, MachineRegisterInfo *MRI) {
-  unsigned Opcode = MI.getOpcode();
-  switch (Opcode) {
-  default:
-    llvm_unreachable("Unexpected machine instruction on tile");
-  case X86::PTILELOADDV:
-  case X86::PTDPBSSDV:
-  case X86::PTDPBSUDV:
-  case X86::PTDPBUSDV:
-  case X86::PTDPBUUDV:
-  case X86::PTILEZEROV:
-  case X86::PTDPBF16PSV:
-    MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
-    MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
-    ShapeT Shape(&MO1, &MO2, MRI);
-    return Shape;
-  }
-}
-
-MachineInstr *X86PreTileConfig::getTileConfigPoint() {
-  DenseMap PhysShapeInfo;
-  MachineBasicBlock *MBB = nullptr;
-  DenseSet<const MachineInstr *> MIs;
-  for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
-    Register VirtReg = Register::index2VirtReg(i);
-    if (MRI->reg_nodbg_empty(VirtReg))
-      continue;
-    const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
-    if (RC.getID() != X86::TILERegClassID)
-      continue;
-
-    // Find the common dominator for all MI that define tile register.
-    for (const MachineOperand &MO : MRI->def_operands(VirtReg)) {
+void X86PreTileConfig::collectShapeInfo(ShapeT Shape) {
+  auto RecordShapeInfo = [this](MachineInstr *MI) {
+    MachineBasicBlock *MBB = MI->getParent();
+    MIRef MIR(MI, MBB);
+    if (BBVisitedInfo[MBB].LastShape < MIR)
+      BBVisitedInfo[MBB].LastShape = MIR;
+    ShapeBBs.insert(MBB);
+  };
+
+  for (auto *ShapeMO : {Shape.getRow(), Shape.getCol()}) {
+    Register ShapeReg = ShapeMO->getReg();
+    for (MachineOperand &MO : MRI->def_operands(ShapeReg)) {
       if (MO.isUndef())
         continue;
-      const auto *MI = MO.getParent();
-      // PHI or IMPLICIT_DEF instructiion.
-      // There must be a input tile before PHI instruction.
-      if (MI->isTransient())
+      MachineInstr *MI = MO.getParent();
+      if (MI->isMoveImmediate())
        continue;
-      if (!MBB)
-        MBB = const_cast<MachineBasicBlock *>(MI->getParent());
-      MBB = DomTree->findNearestCommonDominator(
-          MBB, const_cast<MachineBasicBlock *>(MI->getParent()));
-
-      // Collect the instructions that define shape.
-      ShapeT Shape = getShape(*MI, MRI);
-      std::array<MachineOperand *, 2> ShapeMOs = {Shape.getRow(),
-                                                  Shape.getCol()};
-      for (auto *ShapeMO : ShapeMOs) {
-        Register ShapeReg = ShapeMO->getReg();
-        for (const MachineOperand &MO : MRI->def_operands(ShapeReg)) {
-          const auto *ShapeMI = MO.getParent();
-          MIs.insert(ShapeMI);
+      if (MI->isPHI()) {
+        for (unsigned Index = 0; Index < MI->getNumOperands() / 2; ++Index) {
+          Register DefReg = MI->getOperand(Index * 2 + 1).getReg();
+          for (MachineOperand &MO2 : MRI->def_operands(DefReg)) {
+            if (MO2.isUndef())
+              continue;
+            MachineInstr *MI2 = MO2.getParent();
+            // We don't need to handle PHI instructions recursively. Each PHI
+            // instruction is visited once, so we have a chance to record the
+            // recursive shape defs later.
+            if (MI2->isMoveImmediate() || MI2->isPHI())
+              continue;
+            RecordShapeInfo(MI2);
+          }
         }
+        continue;
       }
+      RecordShapeInfo(MI);
     }
   }
-  if (!MBB)
-    return nullptr;
-  // This pass is before the pass of eliminating PHI node, so it
-  // is in SSA form.
-  assert(MRI->isSSA() && "Not SSA form in pre-tile config");
-  // Shape def should dominate tile config MBB.
-  //    def s         s1    s2
-  //     / \           \   /
-  //    /   \           \ /
-  //  conf   s3=phi(s1,s2)
-  //   |
-  //   c
-  //
-  for (const auto *MI : MIs) {
-    const MachineBasicBlock *ShapeMBB = MI->getParent();
-    if (DomTree->dominates(ShapeMBB, MBB))
-      continue;
-    if (MI->isMoveImmediate())
-      continue;
-    report_fatal_error(MF->getName() + ": Failed to config tile register, "
-                       "please define the shape earlier");
-  }
-
-  // ldtilecfg should be inserted after the MI that define the shape.
-  MachineBasicBlock::reverse_instr_iterator I, E;
-  for (I = MBB->instr_rbegin(), E = MBB->instr_rend(); I != E; ++I) {
-    auto *MI = &*I;
-    if (MIs.count(MI) && (!MI->isMoveImmediate()))
-      break;
-  }
-  MachineBasicBlock::iterator MII;
-  if (I == E)
-    MII = MBB->getFirstNonPHI();
-  else {
-    MII = MachineBasicBlock::iterator(&*I);
-    MII++;
-  }
-  return &*MII;
 }
 
-static bool isAMXInstruction(MachineBasicBlock::iterator MII) {
-  switch (MII->getOpcode()) {
-  default:
-    return false;
-  case X86::PTILELOADDV:
-  case X86::PTILESTOREDV:
-  case X86::PTDPBSSDV:
-  case X86::PTDPBSUDV:
-  case X86::PTDPBUSDV:
-  case X86::PTDPBUUDV:
-  case X86::PTILEZEROV:
-  case X86::PTDPBF16PSV:
-    return true;
+bool X86PreTileConfig::runOnMachineFunction(MachineFunction &MF) {
+  MRI = &MF.getRegInfo();
+  const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>();
+  const TargetInstrInfo *TII = ST.getInstrInfo();
+  const TargetRegisterInfo *TRI = ST.getRegisterInfo();
+  const MachineLoopInfo *MLI = &getAnalysis<MachineLoopInfo>();
+  const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
+
+  BitVector AMXRegs(TRI->getNumRegs());
+  for (unsigned I = 0; I < RC->getNumRegs(); I++)
+    AMXRegs.set(X86::TMM0 + I);
+
+  // Iterate MF to collect information.
+  for (auto &MBB : MF) {
+    size_t Pos = 0;
+    for (auto &MI : MBB) {
+      ++Pos;
+      if (isAMXInstruction(MI)) {
+        // If there's a call before the AMX, we need to reload the tile config.
+        if (BBVisitedInfo[&MBB].LastCall)
+          CfgNeedInsert.insert(BBVisitedInfo[&MBB].LastCall);
+        else /* Otherwise, the tile config must be live into this BB. */
+          BBVisitedInfo[&MBB].NeedTileCfgLiveIn = true;
+        // Always record the first AMX in case there's a shape def after it.
+        if (!BBVisitedInfo[&MBB].FirstAMX)
+          BBVisitedInfo[&MBB].FirstAMX = MIRef(&MI, &MBB, Pos);
+      } else if (MI.isCall() && isDestructiveCall(MI, AMXRegs)) {
+        // Record the call only if the callee clobbers any AMX register.
+        BBVisitedInfo[&MBB].LastCall = MIRef(&MI, &MBB, Pos);
+      }
+    }
+    if (BBVisitedInfo[&MBB].NeedTileCfgLiveIn) {
+      if (&MBB == &MF.front())
+        CfgNeedInsert.insert(MIRef(&MBB));
+      else
+        CfgLiveInBBs.push_back(&MBB);
+    }
   }
-struct BBInfo {
-  bool HasAMX = false;
-  bool HasCallBeforeAMX = false;
-  bool HasAMXBeforeCallInSuccs = false;
-  MachineInstr *LastCall = nullptr;
-
-  BBInfo() = default;
-  BBInfo(SmallSet &CfgNeedInsert, MachineBasicBlock *MBB,
-         MachineInstr *MI = nullptr) {
-    MachineBasicBlock::iterator MII = MI ? MI->getIterator() : MBB->begin();
-    for (auto E = MBB->end(); MII != E; ++MII) {
-      if (isAMXInstruction(MII)) {
-        HasAMX = true;
-        if (LastCall)
-          CfgNeedInsert.insert(LastCall);
-      } else if (MII->isCall()) {
-        LastCall = &*MII;
-        if (!HasAMX)
-          HasCallBeforeAMX = true;
+  // Update NeedTileCfgLiveIn for predecessors.
+  while (!CfgLiveInBBs.empty()) {
+    MachineBasicBlock *MBB = CfgLiveInBBs.pop_back_val();
+    for (auto *Pred : MBB->predecessors()) {
+      if (BBVisitedInfo[Pred].LastCall) {
+        CfgNeedInsert.insert(BBVisitedInfo[Pred].LastCall);
+      } else if (!BBVisitedInfo[Pred].NeedTileCfgLiveIn) {
+        BBVisitedInfo[Pred].NeedTileCfgLiveIn = true;
+        if (Pred == &MF.front())
+          CfgNeedInsert.insert(MIRef(Pred));
+        else
+          CfgLiveInBBs.push_back(Pred);
      }
    }
  }
-};
-
-static void reloadTileConfig(MachineInstr *MI, int FI,
-                             const TargetInstrInfo *TII,
-                             const TargetRegisterInfo *TRI) {
-  SmallSet CfgNeedInsert;
-  SmallVector WorkList;
-  DenseMap BBVisitedInfo;
-  MachineBasicBlock *MBB = MI->getParent();
-  BBVisitedInfo[MBB] = BBInfo(CfgNeedInsert, MBB, MI);
+  // There's no AMX instruction if we didn't find a tile config live-in point.
+  if (CfgNeedInsert.empty())
+    return false;
-  WorkList.push_back(MBB);
-  while (!WorkList.empty()) {
-    MBB = WorkList.pop_back_val();
-    for (auto I = MBB->succ_begin(), E = MBB->succ_end(); I != E; ++I) {
-      if (!BBVisitedInfo.count(*I)) {
-        BBVisitedInfo[*I] = BBInfo(CfgNeedInsert, *I);
-        WorkList.push_back(*I);
-      }
+  // For each BB, calculate how many shape-defining BBs can reach it.
+  unsigned ShapeBBNum = 0;
+  for (auto *MBB : ShapeBBs) {
+    SmallSet<MachineBasicBlock *, 8> VisitedBB;
+    SmallVector<MachineBasicBlock *, 8> WorkList({MBB});
+    while (!WorkList.empty()) {
+      MachineBasicBlock *MBB = WorkList.pop_back_val();
+      ++BBVisitedInfo[MBB].ShapeReachedCount;
+      for (auto *Succ : MBB->successors())
+        if (VisitedBB.insert(Succ).second &&
+            (!MLI->isLoopHeader(Succ) ||
+             MLI->getLoopFor(Succ)->getBottomBlock() != MBB))
+          WorkList.push_back(Succ);
    }
+    ++ShapeBBNum;
  }
-  WorkList.clear();
-  for (auto I : BBVisitedInfo) {
-    WorkList.push_back(I.first);
+  DebugLoc DL;
+  SmallSet<MIRef, 8> VisitedOrInserted;
+  int SS = MF.getFrameInfo().CreateStackObject(
+      ST.getTileConfigSize(), ST.getTileConfigAlignment(), false);
+
+  // Try to insert ldtilecfg for each tile config live-in point.
+  for (auto I : CfgNeedInsert) {
+    SmallSet<MIRef, 8> InsertPoints;
+    SmallVector<MIRef, 8> WorkList({I});
     while (!WorkList.empty()) {
-      MBB = WorkList.pop_back_val();
-      if (BBVisitedInfo[MBB].HasCallBeforeAMX ||
-          (!BBVisitedInfo[MBB].HasAMX &&
-           !BBVisitedInfo[MBB].HasAMXBeforeCallInSuccs))
-        continue;
-      for (auto I = MBB->pred_begin(), E = MBB->pred_end(); I != E; ++I) {
-        if (!BBVisitedInfo.count(*I) ||
-            BBVisitedInfo[*I].HasAMXBeforeCallInSuccs)
-          continue;
-        if (BBVisitedInfo[*I].LastCall)
-          CfgNeedInsert.insert(BBVisitedInfo[*I].LastCall);
-        BBVisitedInfo[*I].HasAMXBeforeCallInSuccs = true;
-        WorkList.push_back(*I);
+      MIRef I = WorkList.pop_back_val();
+      if (!VisitedOrInserted.count(I)) {
+        if (BBVisitedInfo[I.MBB].ShapeReachedCount == ShapeBBNum) {
+          // If all shape defs reach this BB, stop sinking and insert here.
+          InsertPoints.insert(I);
+        } else {
+          // Avoid visiting this point more than once.
+          VisitedOrInserted.insert(I);
+          // We cannot sink it across any AMX instruction.
+          if (BBVisitedInfo[I.MBB].FirstAMX)
+            REPORT_CONFIG_FAIL;
+          // Sink the insertion point along successors that need the tile
+          // config live in, since this BB isn't reachable from all shape defs.
+          for (auto *Succ : I.MBB->successors())
+            if (BBVisitedInfo[Succ].NeedTileCfgLiveIn)
+              WorkList.push_back(MIRef(Succ));
+        }
      }
    }
-  }
-  for (auto *I : CfgNeedInsert) {
-    BitVector UsableRegs(TRI->getNumRegs());
-    const TargetRegisterClass *RC = TRI->getRegClass(X86::TILERegClassID);
-    for (unsigned J = 0; J < RC->getNumRegs(); J++)
-      UsableRegs.set(X86::TMM0 + J);
-    for (MachineOperand &CallMO : I->operands()) {
-      if (CallMO.isRegMask())
-        UsableRegs.clearBitsInMask(CallMO.getRegMask());
+    // A given point may fork into several insertion points when the shape
+    // reachability conditions are not met.
+    for (MIRef I : InsertPoints) {
+      // Even if all shape defs reach this MBB, we still need to check whether
+      // an AMX instruction precedes a shape def in the same MBB.
+      if (BBVisitedInfo[I.MBB].FirstAMX &&
+          BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape)
+        REPORT_CONFIG_FAIL;
+      // Make sure we insert ldtilecfg after the last shape def in the MBB.
+      if (I < BBVisitedInfo[I.MBB].LastShape)
+        I = BBVisitedInfo[I.MBB].LastShape;
+      // The same point may be reached more than once. Record it to avoid
+      // inserting ldtilecfg multiple times.
+      if (VisitedOrInserted.insert(I).second) {
+        auto II = I.MI ?
I->getIterator() : I.MBB->instr_begin(); + addFrameReference(BuildMI(*I.MBB, ++II, DL, TII->get(X86::LDTILECFG)), + SS); + } } - if (!UsableRegs.none()) - addFrameReference(BuildMI(*I->getParent(), ++I->getIterator(), DebugLoc(), - TII->get(X86::LDTILECFG)), - FI); } -} -bool X86PreTileConfig::runOnMachineFunction(MachineFunction &mf) { - MF = &mf; - MRI = &mf.getRegInfo(); - ST = &mf.getSubtarget(); - TRI = ST->getRegisterInfo(); - TII = mf.getSubtarget().getInstrInfo(); - DomTree = &getAnalysis(); + // Zero stack slot. + MachineBasicBlock &MBB = MF.front(); + MachineInstr *MI = &*MBB.begin(); + if (ST.hasAVX512()) { + Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass); + BuildMI(MBB, MI, DL, TII->get(X86::VPXORDZrr), Zmm) + .addReg(Zmm, RegState::Undef) + .addReg(Zmm, RegState::Undef); + addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSZmr)), SS) + .addReg(Zmm); + } else if (ST.hasAVX2()) { + Register Ymm = MRI->createVirtualRegister(&X86::VR256RegClass); + BuildMI(MBB, MI, DL, TII->get(X86::VPXORYrr), Ymm) + .addReg(Ymm, RegState::Undef) + .addReg(Ymm, RegState::Undef); + addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS) + .addReg(Ymm); + addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::VMOVUPSYmr)), SS, 32) + .addReg(Ymm); + } else { + assert(ST.hasSSE2() && "AMX should assume SSE2 enabled"); + Register Xmm = MRI->createVirtualRegister(&X86::VR128RegClass); + BuildMI(MBB, MI, DL, TII->get(X86::PXORrr), Xmm) + .addReg(Xmm, RegState::Undef) + .addReg(Xmm, RegState::Undef); + addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS) + .addReg(Xmm); + addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 16) + .addReg(Xmm); + addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 32) + .addReg(Xmm); + addFrameReference(BuildMI(MBB, MI, DL, TII->get(X86::MOVUPSmr)), SS, 48) + .addReg(Xmm); + } - MachineInstr *MI = getTileConfigPoint(); - if (!MI) - return false; - unsigned Size = ST->getTileConfigSize(); - Align Alignment = ST->getTileConfigAlignment(); - int SS = mf.getFrameInfo().CreateStackObject(Size, Alignment, false); - buildConfigMI(MI, SS, TII, MRI, ST); - reloadTileConfig(MI, SS, TII, TRI); return true; } diff --git a/llvm/test/CodeGen/X86/AMX/amx-across-func.ll b/llvm/test/CodeGen/X86/AMX/amx-across-func.ll --- a/llvm/test/CodeGen/X86/AMX/amx-across-func.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-across-func.ll @@ -113,9 +113,10 @@ ; CHECK-NEXT: pushq %rbx ; CHECK-NEXT: subq $3016, %rsp # imm = 0xBC8 ; CHECK-NEXT: movl %edi, %r14d -; CHECK-NEXT: callq foo ; CHECK-NEXT: vpxord %zmm0, %zmm0, %zmm0 ; CHECK-NEXT: vmovdqu64 %zmm0, {{[0-9]+}}(%rsp) +; CHECK-NEXT: vzeroupper +; CHECK-NEXT: callq foo ; CHECK-NEXT: movb $1, {{[0-9]+}}(%rsp) ; CHECK-NEXT: movb $8, {{[0-9]+}}(%rsp) ; CHECK-NEXT: movw $8, {{[0-9]+}}(%rsp) @@ -133,7 +134,6 @@ ; CHECK-NEXT: tileloadd (%r15,%r12), %tmm0 ; CHECK-NEXT: movabsq $64, %rax ; CHECK-NEXT: tilestored %tmm0, 1024(%rsp,%rax) # 1024-byte Folded Spill -; CHECK-NEXT: vzeroupper ; CHECK-NEXT: callq foo ; CHECK-NEXT: ldtilecfg {{[0-9]+}}(%rsp) ; CHECK-NEXT: movabsq $64, %rax @@ -154,7 +154,6 @@ ; CHECK-NEXT: incl %r14d ; CHECK-NEXT: jmp .LBB2_8 ; CHECK-NEXT: .LBB2_4: -; CHECK-NEXT: vzeroupper ; CHECK-NEXT: callq foo ; CHECK-NEXT: ldtilecfg {{[0-9]+}}(%rsp) ; CHECK-NEXT: movl $32, %eax @@ -180,13 +179,13 @@ ; IPRA: # %bb.0: ; IPRA-NEXT: subq $72, %rsp ; IPRA-NEXT: movl %edi, %eax -; IPRA-NEXT: callq foo ; IPRA-NEXT: vpxord %zmm0, %zmm0, %zmm0 ; IPRA-NEXT: vmovdqu64 %zmm0, 
{{[0-9]+}}(%rsp) ; IPRA-NEXT: movb $1, {{[0-9]+}}(%rsp) ; IPRA-NEXT: movb $8, {{[0-9]+}}(%rsp) ; IPRA-NEXT: movw $8, {{[0-9]+}}(%rsp) ; IPRA-NEXT: ldtilecfg {{[0-9]+}}(%rsp) +; IPRA-NEXT: callq foo ; IPRA-NEXT: testl %edi, %edi ; IPRA-NEXT: jg .LBB2_4 ; IPRA-NEXT: # %bb.1: # %.preheader @@ -273,12 +272,15 @@ ; CHECK-NEXT: pushq %rbx ; CHECK-NEXT: subq $3024, %rsp # imm = 0xBD0 ; CHECK-NEXT: movl %edi, %ebx +; CHECK-NEXT: vpxord %zmm0, %zmm0, %zmm0 +; CHECK-NEXT: vmovdqu64 %zmm0, {{[0-9]+}}(%rsp) ; CHECK-NEXT: movl $buf, %r14d ; CHECK-NEXT: movl $32, %r15d ; CHECK-NEXT: movw $8, %bp ; CHECK-NEXT: movl $buf+2048, %r12d ; CHECK-NEXT: .p2align 4, 0x90 ; CHECK-NEXT: .LBB3_1: # =>This Inner Loop Header: Depth=1 +; CHECK-NEXT: vzeroupper ; CHECK-NEXT: callq foo ; CHECK-NEXT: movb $1, {{[0-9]+}}(%rsp) ; CHECK-NEXT: movb $8, {{[0-9]+}}(%rsp) @@ -287,13 +289,9 @@ ; CHECK-NEXT: testl %ebx, %ebx ; CHECK-NEXT: jle .LBB3_3 ; CHECK-NEXT: # %bb.2: # in Loop: Header=BB3_1 Depth=1 -; CHECK-NEXT: vpxord %zmm0, %zmm0, %zmm0 -; CHECK-NEXT: vmovdqu64 %zmm0, {{[0-9]+}}(%rsp) -; CHECK-NEXT: ldtilecfg {{[0-9]+}}(%rsp) ; CHECK-NEXT: tileloadd (%r14,%r15), %tmm0 ; CHECK-NEXT: movabsq $64, %rax ; CHECK-NEXT: tilestored %tmm0, 1024(%rsp,%rax) # 1024-byte Folded Spill -; CHECK-NEXT: vzeroupper ; CHECK-NEXT: callq foo ; CHECK-NEXT: ldtilecfg {{[0-9]+}}(%rsp) ; CHECK-NEXT: movabsq $64, %rax @@ -314,6 +312,12 @@ ; IPRA-LABEL: test_loop2: ; IPRA: # %bb.0: ; IPRA-NEXT: subq $72, %rsp +; IPRA-NEXT: vpxord %zmm0, %zmm0, %zmm0 +; IPRA-NEXT: vmovdqu64 %zmm0, {{[0-9]+}}(%rsp) +; IPRA-NEXT: movb $1, {{[0-9]+}}(%rsp) +; IPRA-NEXT: movb $8, {{[0-9]+}}(%rsp) +; IPRA-NEXT: movw $8, {{[0-9]+}}(%rsp) +; IPRA-NEXT: ldtilecfg {{[0-9]+}}(%rsp) ; IPRA-NEXT: movl $buf, %eax ; IPRA-NEXT: movl $32, %ecx ; IPRA-NEXT: movw $8, %dx @@ -324,12 +328,6 @@ ; IPRA-NEXT: testl %edi, %edi ; IPRA-NEXT: jle .LBB3_3 ; IPRA-NEXT: # %bb.2: # in Loop: Header=BB3_1 Depth=1 -; IPRA-NEXT: vpxord %zmm0, %zmm0, %zmm0 -; IPRA-NEXT: vmovdqu64 %zmm0, {{[0-9]+}}(%rsp) -; IPRA-NEXT: movb $1, {{[0-9]+}}(%rsp) -; IPRA-NEXT: movb $8, {{[0-9]+}}(%rsp) -; IPRA-NEXT: movw $8, {{[0-9]+}}(%rsp) -; IPRA-NEXT: ldtilecfg {{[0-9]+}}(%rsp) ; IPRA-NEXT: tileloadd (%rax,%rcx), %tmm0 ; IPRA-NEXT: callq foo ; IPRA-NEXT: tilestored %tmm0, (%rsi,%rcx) diff --git a/llvm/test/CodeGen/X86/AMX/amx-config.ll b/llvm/test/CodeGen/X86/AMX/amx-config.ll --- a/llvm/test/CodeGen/X86/AMX/amx-config.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-config.ll @@ -10,10 +10,10 @@ define dso_local void @test_api(i32 %0, i16 signext %1, i16 signext %2) { ; AVX512-LABEL: test_api: ; AVX512: # %bb.0: -; AVX512-NEXT: testl %edi, %edi -; AVX512-NEXT: movsbl %sil, %eax ; AVX512-NEXT: vpxord %zmm0, %zmm0, %zmm0 ; AVX512-NEXT: vmovdqu64 %zmm0, -{{[0-9]+}}(%rsp) +; AVX512-NEXT: testl %edi, %edi +; AVX512-NEXT: movsbl %sil, %eax ; AVX512-NEXT: movb $1, -{{[0-9]+}}(%rsp) ; AVX512-NEXT: movb %al, -{{[0-9]+}}(%rsp) ; AVX512-NEXT: movw %si, -{{[0-9]+}}(%rsp) @@ -43,11 +43,11 @@ ; ; AVX2-LABEL: test_api: ; AVX2: # %bb.0: -; AVX2-NEXT: testl %edi, %edi -; AVX2-NEXT: movsbl %sil, %eax ; AVX2-NEXT: vxorps %ymm0, %ymm0, %ymm0 ; AVX2-NEXT: vmovups %ymm0, -{{[0-9]+}}(%rsp) ; AVX2-NEXT: vmovups %ymm0, -{{[0-9]+}}(%rsp) +; AVX2-NEXT: testl %edi, %edi +; AVX2-NEXT: movsbl %sil, %eax ; AVX2-NEXT: movb $1, -{{[0-9]+}}(%rsp) ; AVX2-NEXT: movb %al, -{{[0-9]+}}(%rsp) ; AVX2-NEXT: movw %si, -{{[0-9]+}}(%rsp) @@ -77,13 +77,13 @@ ; ; SSE2-LABEL: test_api: ; SSE2: # %bb.0: -; SSE2-NEXT: testl %edi, %edi -; SSE2-NEXT: movsbl %sil, %eax ; 
SSE2-NEXT: xorps %xmm0, %xmm0 ; SSE2-NEXT: movups %xmm0, -{{[0-9]+}}(%rsp) ; SSE2-NEXT: movups %xmm0, -{{[0-9]+}}(%rsp) ; SSE2-NEXT: movups %xmm0, -{{[0-9]+}}(%rsp) ; SSE2-NEXT: movups %xmm0, -{{[0-9]+}}(%rsp) +; SSE2-NEXT: testl %edi, %edi +; SSE2-NEXT: movsbl %sil, %eax ; SSE2-NEXT: movb $1, -{{[0-9]+}}(%rsp) ; SSE2-NEXT: movb %al, -{{[0-9]+}}(%rsp) ; SSE2-NEXT: movw %si, -{{[0-9]+}}(%rsp) diff --git a/llvm/test/CodeGen/X86/AMX/amx-ldtilecfg-insert.ll b/llvm/test/CodeGen/X86/AMX/amx-ldtilecfg-insert.ll --- a/llvm/test/CodeGen/X86/AMX/amx-ldtilecfg-insert.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-ldtilecfg-insert.ll @@ -47,6 +47,8 @@ ; CHECK-NEXT: movl %edi, %ebp ; CHECK-NEXT: vpxord %zmm0, %zmm0, %zmm0 ; CHECK-NEXT: vmovdqu64 %zmm0, {{[0-9]+}}(%rsp) +; CHECK-NEXT: vzeroupper +; CHECK-NEXT: callq foo ; CHECK-NEXT: movb $1, {{[0-9]+}}(%rsp) ; CHECK-NEXT: movb $8, {{[0-9]+}}(%rsp) ; CHECK-NEXT: movw %bx, {{[0-9]+}}(%rsp) @@ -59,9 +61,6 @@ ; CHECK-NEXT: movb %bpl, {{[0-9]+}}(%rsp) ; CHECK-NEXT: movw $8, {{[0-9]+}}(%rsp) ; CHECK-NEXT: ldtilecfg {{[0-9]+}}(%rsp) -; CHECK-NEXT: vzeroupper -; CHECK-NEXT: callq foo -; CHECK-NEXT: ldtilecfg {{[0-9]+}}(%rsp) ; CHECK-NEXT: xorl %eax, %eax ; CHECK-NEXT: testb %al, %al ; CHECK-NEXT: jne .LBB1_3 @@ -116,6 +115,115 @@ ret void } +define dso_local void @test3(i16 signext %0, i16 signext %1) nounwind { +; CHECK-LABEL: test3: +; CHECK: # %bb.0: +; CHECK-NEXT: vpxord %zmm0, %zmm0, %zmm0 +; CHECK-NEXT: vmovdqu64 %zmm0, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: xorl %eax, %eax +; CHECK-NEXT: testb %al, %al +; CHECK-NEXT: jne .LBB2_2 +; CHECK-NEXT: # %bb.1: # %if.true +; CHECK-NEXT: incl %edi +; CHECK-NEXT: jmp .LBB2_3 +; CHECK-NEXT: .LBB2_2: # %if.false +; CHECK-NEXT: decl %edi +; CHECK-NEXT: .LBB2_3: # %exit +; CHECK-NEXT: movb $1, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movb %dil, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movw %si, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: ldtilecfg -{{[0-9]+}}(%rsp) +; CHECK-NEXT: tilezero %tmm0 +; CHECK-NEXT: movl $buf, %eax +; CHECK-NEXT: movl $32, %ecx +; CHECK-NEXT: tilestored %tmm0, (%rax,%rcx) +; CHECK-NEXT: tilerelease +; CHECK-NEXT: vzeroupper +; CHECK-NEXT: retq + br i1 undef, label %if.true, label %if.false + +if.true: + %3 = add i16 %0, 1 + br label %exit + +if.false: + %4 = sub i16 %0, 1 + br label %exit + +exit: + %5 = phi i16 [ %3, %if.true ], [ %4, %if.false ] + %6 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %5, i16 %1) + tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 0), i64 32, x86_amx %6) + ret void +} + +; TODO: There's PRA Tile Register Configure bug needs to fix later. 
+define dso_local void @test4(i16 signext %0, i16 signext %1) nounwind { +; CHECK-LABEL: test4: +; CHECK: # %bb.0: +; CHECK-NEXT: vpxord %zmm0, %zmm0, %zmm0 +; CHECK-NEXT: vmovdqu64 %zmm0, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: xorl %eax, %eax +; CHECK-NEXT: testb %al, %al +; CHECK-NEXT: jne .LBB3_3 +; CHECK-NEXT: # %bb.1: # %if.true +; CHECK-NEXT: incl %edi +; CHECK-NEXT: xorl %eax, %eax +; CHECK-NEXT: testb %al, %al +; CHECK-NEXT: jne .LBB3_4 +; CHECK-NEXT: .LBB3_2: # %amx2 +; CHECK-NEXT: movb $1, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movb %dil, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movw %si, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: ldtilecfg -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movl $32, %eax +; CHECK-NEXT: movl $buf+1024, %ecx +; CHECK-NEXT: tileloadd (%rcx,%rax), %tmm0 +; CHECK-NEXT: movl $buf, %ecx +; CHECK-NEXT: tilestored %tmm0, (%rcx,%rax) +; CHECK-NEXT: tilerelease +; CHECK-NEXT: vzeroupper +; CHECK-NEXT: retq +; CHECK-NEXT: .LBB3_3: # %if.false +; CHECK-NEXT: decl %edi +; CHECK-NEXT: xorl %eax, %eax +; CHECK-NEXT: testb %al, %al +; CHECK-NEXT: jne .LBB3_2 +; CHECK-NEXT: .LBB3_4: # %amx1 +; CHECK-NEXT: ldtilecfg -{{[0-9]+}}(%rsp) +; CHECK-NEXT: tilezero %tmm0 +; CHECK-NEXT: movl $buf, %eax +; CHECK-NEXT: movl $32, %ecx +; CHECK-NEXT: tilestored %tmm0, (%rax,%rcx) +; CHECK-NEXT: tilerelease +; CHECK-NEXT: vzeroupper +; CHECK-NEXT: retq + br i1 undef, label %if.true, label %if.false + +if.true: + %3 = add i16 %0, 1 + br i1 undef, label %amx1, label %amx2 + +if.false: + %4 = sub i16 %0, 1 + br i1 undef, label %amx2, label %amx1 + +amx1: + %5 = phi i16 [ %3, %if.true ], [ %4, %if.false ] + %6 = tail call x86_amx @llvm.x86.tilezero.internal(i16 %5, i16 %1) + tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 0), i64 32, x86_amx %6) + br label %exit + +amx2: + %7 = phi i16 [ %3, %if.true ], [ %4, %if.false ] + %8 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %7, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 1024), i64 32) + tail call void @llvm.x86.tilestored64.internal(i16 %7, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 0), i64 32, x86_amx %8) + br label %exit + +exit: + ret void +} + declare dso_local void @foo() nounwind declare x86_amx @llvm.x86.tilezero.internal(i16, i16) declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) diff --git a/llvm/test/CodeGen/X86/opt-pipeline.ll b/llvm/test/CodeGen/X86/opt-pipeline.ll --- a/llvm/test/CodeGen/X86/opt-pipeline.ll +++ b/llvm/test/CodeGen/X86/opt-pipeline.ll @@ -117,12 +117,12 @@ ; CHECK-NEXT: X86 EFLAGS copy lowering ; CHECK-NEXT: X86 WinAlloca Expander ; CHECK-NEXT: MachineDominator Tree Construction +; CHECK-NEXT: Machine Natural Loop Construction ; CHECK-NEXT: Tile Register Pre-configure ; CHECK-NEXT: Detect Dead Lanes ; CHECK-NEXT: Process Implicit Definitions ; CHECK-NEXT: Remove unreachable machine basic blocks ; CHECK-NEXT: Live Variable Analysis -; CHECK-NEXT: Machine Natural Loop Construction ; CHECK-NEXT: Eliminate PHI nodes for register allocation ; CHECK-NEXT: Two-Address instruction pass ; CHECK-NEXT: Slot index numbering