diff --git a/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp b/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp --- a/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp @@ -205,6 +205,8 @@ S.getSUnit()->removePred(SP); } +enum class Direction { BOTTOM_UP = 0, TOP_DOWN = 1 }; + typedef std::pair<SUnit *, SmallVector<int, 4>> SUToCandSGsPair; typedef SmallVector<SUToCandSGsPair, 4> SUsToCandSGsVec; @@ -254,6 +256,9 @@ // How many branches we have explored uint64_t BranchesExplored = 0; + // The direction in which we process the candidate SchedGroups per SU + Direction ProcessDirection; + // Update indices to fit next conflicting instruction void advancePosition(); // Recede indices to attempt to find better fit for previous conflicting @@ -290,9 +295,11 @@ PipelineSolver(DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups, DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs, - ScheduleDAGMI *DAG) + ScheduleDAGMI *DAG, + Direction ProcessDirection = Direction::BOTTOM_UP) : DAG(DAG), SyncedInstrs(SyncedInstrs), - SyncedSchedGroups(SyncedSchedGroups) { + SyncedSchedGroups(SyncedSchedGroups), + ProcessDirection(ProcessDirection) { for (auto &PipelineInstrs : SyncedInstrs) { if (PipelineInstrs.second.size() > 0) { @@ -367,10 +374,14 @@ // Preserve the order of barrier for subsequent SchedGroupBarrier mutations for (auto &SyncPipeline : BestPipeline) { for (auto &SG : SyncPipeline) { + LLVM_DEBUG(dbgs() << "Printing SchedGroups\n"); + LLVM_DEBUG(dbgs() << "SchedGroup with SGID " << SG.getSGID() + << " has: \n"); SUnit *SGBarr = nullptr; for (auto &SU : SG.Collection) { if (SU->getInstr()->getOpcode() == AMDGPU::SCHED_GROUP_BARRIER) SGBarr = SU; + LLVM_DEBUG(dbgs() << "SU(" << SU->NodeNum << ")\n"); } // Command line requested IGroupLP doesn't have SGBarr if (!SGBarr) @@ -381,12 +392,16 @@ } for (auto &SyncPipeline : BestPipeline) { - auto I = SyncPipeline.rbegin(); - auto E = SyncPipeline.rend(); - for (; I != E; ++I) { - auto &GroupA = *I; - for (auto J = std::next(I); J != E; ++J) { - auto &GroupB = *J; + for (int I = 0; I < 
(int)SyncPipeline.size(); I++) { + int Idx = ProcessDirection == Direction::BOTTOM_UP + ? SyncPipeline.size() - 1 - I + : I; + auto &GroupA = SyncPipeline[Idx]; + for (auto J = I + 1; J < (int)SyncPipeline.size(); J++) { + int Jdx = ProcessDirection == Direction::BOTTOM_UP + ? SyncPipeline.size() - 1 - J + : J; + auto &GroupB = SyncPipeline[Jdx]; GroupA.link(GroupB); } } @@ -399,18 +414,16 @@ int AddedCost = 0; bool MakePred = false; - // The groups in the pipeline are in reverse order. Thus, - // by traversing them from last to first, we are traversing - // them in the order as they were introduced in the code. After we - // pass the group the SU is being assigned to, it should be - // linked as a predecessor of the subsequent SchedGroups - auto GroupNo = (int)SyncPipeline.size() - 1; - for (; GroupNo >= 0; GroupNo--) { - if (SyncPipeline[GroupNo].getSGID() == SGID) { + for (int I = 0; I < (int)SyncPipeline.size(); I++) { + int Idx = ProcessDirection == Direction::BOTTOM_UP + ? SyncPipeline.size() - 1 - I + : I; + auto Group = &SyncPipeline[Idx]; + if (Group->getSGID() == SGID) { MakePred = true; continue; } - auto Group = &SyncPipeline[GroupNo]; + AddedCost += Group->link(*SU, MakePred, AddedEdges); assert(AddedCost >= 0); } @@ -494,11 +507,12 @@ SUToCandSGsPair &CurrSU, SmallVectorImpl<std::pair<int, int>> &ReadyList, SmallVectorImpl<SchedGroup> &SyncPipeline) { assert(CurrSU.second.size() >= 1); - auto I = CurrSU.second.rbegin(); - auto E = CurrSU.second.rend(); - for (; I != E; ++I) { + for (int I = 0; I < (int)CurrSU.second.size(); I++) { + int Idx = ProcessDirection == Direction::BOTTOM_UP + ? 
CurrSU.second.size() - 1 - I + : I; std::vector<std::pair<SUnit *, SUnit *>> AddedEdges; - int CandSGID = *I; + int CandSGID = CurrSU.second[Idx]; SchedGroup *Match; for (auto &SG : SyncPipeline) { if (SG.getSGID() == CandSGID) @@ -507,15 +521,15 @@ if (UseCostHeur) { if (Match->isFull()) { - ReadyList.push_back(std::pair(*I, MissPenalty)); + ReadyList.push_back(std::pair(CandSGID, MissPenalty)); continue; } int TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges); - ReadyList.push_back(std::pair(*I, TempCost)); + ReadyList.push_back(std::pair(CandSGID, TempCost)); removeEdges(AddedEdges); } else - ReadyList.push_back(std::pair(*I, -1)); + ReadyList.push_back(std::pair(CandSGID, -1)); } if (UseCostHeur) { @@ -634,15 +648,12 @@ LLVM_DEBUG(dbgs() << "Fitting SU(" << CurrSU.first->NodeNum << ") in Pipeline # " << CurrSyncGroupIdx << "\n"); - // Since we have added the potential SchedGroups from bottom up, but - // traversed the DAG from top down, parse over the groups from last to - // first. If we fail to do this for the greedy algorithm, the solution will - // likely not be good in more complex cases. - auto I = CurrSU.second.rbegin(); - auto E = CurrSU.second.rend(); - for (; I != E; ++I) { + for (int I = 0; I < (int)CurrSU.second.size(); I++) { + int Idx = ProcessDirection == Direction::BOTTOM_UP + ? CurrSU.second.size() - 1 - I + : I; std::vector<std::pair<SUnit *, SUnit *>> AddedEdges; - int CandSGID = *I; + int CandSGID = CurrSU.second[Idx]; SchedGroup *Match; for (auto &SG : SyncPipeline) { if (SG.getSGID() == CandSGID) @@ -721,9 +732,11 @@ } makePipeline(); + LLVM_DEBUG(dbgs() << "After applying mutation\n"); + LLVM_DEBUG(DAG->dump()); } -enum IGLPStrategyID : int { MFMASmallGemmOptID = 0 }; +enum IGLPStrategyID : int { MFMASmallGemmOptID = 0, DemoOptID = 2 }; // Implement a IGLP scheduling strategy. class IGLPStrategy { @@ -741,6 +754,8 @@ // Returns true if this strategy should be applied to a ScheduleDAG. 
virtual bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) = 0; + virtual Direction getDirection() = 0; + IGLPStrategy(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII) : DAG(DAG), TII(TII) {} @@ -748,6 +763,9 @@ }; class MFMASmallGemmOpt final : public IGLPStrategy { +private: + Direction OptDir = Direction::BOTTOM_UP; + public: void applyIGLPStrategy( DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs, @@ -755,6 +773,8 @@ bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) override { return true; } + Direction getDirection() override { return OptDir; } + MFMASmallGemmOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII) : IGLPStrategy(DAG, TII) {} }; @@ -781,12 +801,53 @@ } } +class DemoOpt final : public IGLPStrategy { +private: + Direction OptDir = Direction::TOP_DOWN; + +public: + void applyIGLPStrategy( + DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs, + DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) override; + + bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) override { return true; } + + Direction getDirection() override { return OptDir; } + + DemoOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII) + : IGLPStrategy(DAG, TII) {} +}; + +void DemoOpt::applyIGLPStrategy( + DenseMap<int, SUnitsToCandidateSGsMap> &SyncedInstrs, + DenseMap<int, SmallVector<SchedGroup, 4>> &SyncedSchedGroups) { + // Count the number of MFMA instructions. 
+ unsigned MFMACount = 0; + for (const MachineInstr &I : *DAG) + if (TII->isMFMAorWMMA(I)) + ++MFMACount; + + const unsigned PipelineSyncID = 0; + SchedGroup *SG = nullptr; + for (unsigned I = 0; I < MFMACount * 3; ++I) { + SG = &SyncedSchedGroups[PipelineSyncID].emplace_back( + SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII); + SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]); + + SG = &SyncedSchedGroups[PipelineSyncID].emplace_back( + SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII); + SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]); + } +} + static std::unique_ptr<IGLPStrategy> createIGLPStrategy(IGLPStrategyID ID, ScheduleDAGInstrs *DAG, const SIInstrInfo *TII) { switch (ID) { case MFMASmallGemmOptID: return std::make_unique<MFMASmallGemmOpt>(DAG, TII); + case DemoOptID: + return std::make_unique<DemoOpt>(DAG, TII); } llvm_unreachable("Unknown IGLPStrategyID"); @@ -806,6 +867,13 @@ // Used to track instructions that can be mapped to multiple sched groups DenseMap<int, SUnitsToCandidateSGsMap> SyncedInstrs; + // The order in which the PipelineSolver should process the candidate + // SchedGroup for a PipelineInstr. BOTTOM_UP will try to add SUs to the last + // created SchedGroup first, and will consider that as the ultimate + // predecessor group when linking. TOP_DOWN instead links and processes the + // first created SchedGroup first. + Direction SolverDirection = Direction::BOTTOM_UP; + // Add DAG edges that enforce SCHED_BARRIER ordering. 
void addSchedBarrierEdges(SUnit &SU); @@ -1034,7 +1102,7 @@ } if (foundSB || foundIGLP) { - PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG); + PipelineSolver PS(SyncedSchedGroups, SyncedInstrs, DAG, SolverDirection); // PipelineSolver performs the mutation by adding the edges it // determined as the best PS.solve(); @@ -1114,8 +1182,10 @@ IGLPStrategyID StrategyID = (IGLPStrategyID)SU.getInstr()->getOperand(0).getImm(); auto S = createIGLPStrategy(StrategyID, DAG, TII); - if (S->shouldApplyStrategy(DAG)) + if (S->shouldApplyStrategy(DAG)) { + SolverDirection = S->getDirection(); S->applyIGLPStrategy(SyncedInstrs, SyncedSchedGroups); + } } } // namespace diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.opt.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.opt.ll --- a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.opt.ll +++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.opt.ll @@ -147,6 +147,144 @@ ret void } + +define amdgpu_kernel void @test_iglp_opt_rev_mfma_gemm(ptr addrspace(3) noalias %in, ptr addrspace(3) noalias %out) #0 { +; GCN-LABEL: test_iglp_opt_rev_mfma_gemm: +; GCN: ; %bb.0: ; %entry +; GCN-NEXT: s_load_dwordx2 s[0:1], s[0:1], 0x24 +; GCN-NEXT: v_lshlrev_b32_e32 v0, 7, v0 +; GCN-NEXT: v_mov_b32_e32 v3, 2.0 +; GCN-NEXT: ; iglp_opt mask(0x00000002) +; GCN-NEXT: s_waitcnt lgkmcnt(0) +; GCN-NEXT: v_add_u32_e32 v1, s0, v0 +; GCN-NEXT: v_add_u32_e32 v2, 0x6000, v1 +; GCN-NEXT: ds_read_b128 a[28:31], v2 offset:57456 +; GCN-NEXT: ds_read_b128 a[24:27], v2 offset:57440 +; GCN-NEXT: ds_read_b128 a[20:23], v2 offset:57424 +; GCN-NEXT: ds_read_b128 a[16:19], v2 offset:57408 +; GCN-NEXT: ds_read_b128 a[0:3], v2 offset:57344 +; GCN-NEXT: ds_read_b128 a[4:7], v2 offset:57360 +; GCN-NEXT: ds_read_b128 a[8:11], v2 offset:57376 +; GCN-NEXT: ds_read_b128 a[12:15], v2 offset:57392 +; GCN-NEXT: v_mov_b32_e32 v2, 1.0 +; GCN-NEXT: ds_read_b128 a[60:63], v1 offset:49264 +; GCN-NEXT: ds_read_b128 a[56:59], v1 offset:49248 +; GCN-NEXT: ds_read_b128 a[52:55], v1 
offset:49232 +; GCN-NEXT: ds_read_b128 a[48:51], v1 offset:49216 +; GCN-NEXT: ds_read_b128 a[44:47], v1 offset:49200 +; GCN-NEXT: ds_read_b128 a[40:43], v1 offset:49184 +; GCN-NEXT: ds_read_b128 a[36:39], v1 offset:49168 +; GCN-NEXT: ds_read_b128 a[32:35], v1 offset:49152 +; GCN-NEXT: s_waitcnt lgkmcnt(8) +; GCN-NEXT: v_mfma_f32_32x32x1f32 a[0:31], v2, v3, a[0:31] +; GCN-NEXT: ds_read_b128 a[156:159], v1 offset:112 +; GCN-NEXT: ds_read_b128 a[152:155], v1 offset:96 +; GCN-NEXT: ds_read_b128 a[68:71], v1 offset:24592 +; GCN-NEXT: ds_read_b128 a[64:67], v1 offset:24576 +; GCN-NEXT: v_add_u32_e32 v0, s1, v0 +; GCN-NEXT: s_waitcnt lgkmcnt(4) +; GCN-NEXT: v_mfma_f32_32x32x1f32 a[32:63], v2, v3, a[32:63] +; GCN-NEXT: ds_read_b128 a[148:151], v1 offset:80 +; GCN-NEXT: ds_read_b128 a[144:147], v1 offset:64 +; GCN-NEXT: ds_read_b128 a[128:131], v1 +; GCN-NEXT: ds_read_b128 a[132:135], v1 offset:16 +; GCN-NEXT: ds_read_b128 a[136:139], v1 offset:32 +; GCN-NEXT: ds_read_b128 a[140:143], v1 offset:48 +; GCN-NEXT: s_waitcnt lgkmcnt(0) +; GCN-NEXT: v_mfma_f32_32x32x1f32 a[128:159], v2, v3, a[128:159] +; GCN-NEXT: ds_read_b128 a[124:127], v1 offset:8304 +; GCN-NEXT: ds_read_b128 a[120:123], v1 offset:8288 +; GCN-NEXT: ds_read_b128 a[116:119], v1 offset:8272 +; GCN-NEXT: ds_read_b128 a[112:115], v1 offset:8256 +; GCN-NEXT: ds_read_b128 a[108:111], v1 offset:8240 +; GCN-NEXT: ds_read_b128 a[104:107], v1 offset:8224 +; GCN-NEXT: ds_read_b128 a[100:103], v1 offset:8208 +; GCN-NEXT: ds_read_b128 a[96:99], v1 offset:8192 +; GCN-NEXT: s_waitcnt lgkmcnt(0) +; GCN-NEXT: v_mfma_f32_32x32x1f32 a[96:127], v2, v3, a[96:127] +; GCN-NEXT: ds_read_b128 a[92:95], v1 offset:24688 +; GCN-NEXT: ds_read_b128 a[88:91], v1 offset:24672 +; GCN-NEXT: ds_read_b128 a[84:87], v1 offset:24656 +; GCN-NEXT: ds_read_b128 a[80:83], v1 offset:24640 +; GCN-NEXT: ds_read_b128 a[76:79], v1 offset:24624 +; GCN-NEXT: ds_read_b128 a[72:75], v1 offset:24608 +; GCN-NEXT: s_nop 2 +; GCN-NEXT: ds_write_b128 v0, a[156:159] 
offset:112 +; GCN-NEXT: ds_write_b128 v0, a[152:155] offset:96 +; GCN-NEXT: ds_write_b128 v0, a[148:151] offset:80 +; GCN-NEXT: ds_write_b128 v0, a[144:147] offset:64 +; GCN-NEXT: ds_write_b128 v0, a[140:143] offset:48 +; GCN-NEXT: ds_write_b128 v0, a[136:139] offset:32 +; GCN-NEXT: ds_write_b128 v0, a[132:135] offset:16 +; GCN-NEXT: ds_write_b128 v0, a[128:131] +; GCN-NEXT: v_mov_b32_e32 v0, s1 +; GCN-NEXT: s_waitcnt lgkmcnt(8) +; GCN-NEXT: v_mfma_f32_32x32x1f32 a[64:95], v2, v3, a[64:95] +; GCN-NEXT: ds_write_b128 v0, a[56:59] offset:24672 +; GCN-NEXT: ds_write_b128 v0, a[60:63] offset:24688 +; GCN-NEXT: ds_write_b128 v0, a[48:51] offset:24640 +; GCN-NEXT: ds_write_b128 v0, a[120:123] offset:8288 +; GCN-NEXT: ds_write_b128 v0, a[124:127] offset:8304 +; GCN-NEXT: ds_write_b128 v0, a[112:115] offset:8256 +; GCN-NEXT: ds_write_b128 v0, a[116:119] offset:8272 +; GCN-NEXT: ds_write_b128 v0, a[104:107] offset:8224 +; GCN-NEXT: ds_write_b128 v0, a[108:111] offset:8240 +; GCN-NEXT: ds_write_b128 v0, a[96:99] offset:8192 +; GCN-NEXT: ds_write_b128 v0, a[100:103] offset:8208 +; GCN-NEXT: ds_write_b128 v0, a[52:55] offset:24656 +; GCN-NEXT: ds_write_b128 v0, a[40:43] offset:24608 +; GCN-NEXT: ds_write_b128 v0, a[44:47] offset:24624 +; GCN-NEXT: ds_write_b128 v0, a[32:35] offset:24576 +; GCN-NEXT: ds_write_b128 v0, a[36:39] offset:24592 +; GCN-NEXT: ds_write_b128 v0, a[24:27] offset:32864 +; GCN-NEXT: ds_write_b128 v0, a[28:31] offset:32880 +; GCN-NEXT: ds_write_b128 v0, a[16:19] offset:32832 +; GCN-NEXT: ds_write_b128 v0, a[88:91] offset:16480 +; GCN-NEXT: ds_write_b128 v0, a[92:95] offset:16496 +; GCN-NEXT: ds_write_b128 v0, a[80:83] offset:16448 +; GCN-NEXT: ds_write_b128 v0, a[84:87] offset:16464 +; GCN-NEXT: ds_write_b128 v0, a[72:75] offset:16416 +; GCN-NEXT: ds_write_b128 v0, a[76:79] offset:16432 +; GCN-NEXT: ds_write_b128 v0, a[64:67] offset:16384 +; GCN-NEXT: ds_write_b128 v0, a[68:71] offset:16400 +; GCN-NEXT: ds_write_b128 v0, a[20:23] offset:32848 +; GCN-NEXT: 
ds_write_b128 v0, a[8:11] offset:32800 +; GCN-NEXT: ds_write_b128 v0, a[12:15] offset:32816 +; GCN-NEXT: ds_write_b128 v0, a[0:3] offset:32768 +; GCN-NEXT: ds_write_b128 v0, a[4:7] offset:32784 +; GCN-NEXT: s_endpgm +entry: + call void @llvm.amdgcn.iglp.opt(i32 2) + %idx = call i32 @llvm.amdgcn.workitem.id.x() + %load.0.addr = getelementptr <32 x float>, ptr addrspace(3) %in, i32 %idx + %load.0 = load <32 x float>, ptr addrspace(3) %load.0.addr + %load.1.addr = getelementptr <32 x float>, ptr addrspace(3) %load.0.addr, i32 64 + %load.1 = load <32 x float>, ptr addrspace(3) %load.1.addr + %load.2.addr = getelementptr <32 x float>, ptr addrspace(3) %load.1.addr, i32 128 + %load.2 = load <32 x float>, ptr addrspace(3) %load.2.addr + %load.3.addr = getelementptr <32 x float>, ptr addrspace(3) %load.2.addr, i32 192 + %load.3 = load <32 x float>, ptr addrspace(3) %load.3.addr + %load.4.addr = getelementptr <32 x float>, ptr addrspace(3) %load.3.addr, i32 256 + %load.4 = load <32 x float>, ptr addrspace(3) %load.4.addr + %mai.0 = tail call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float 1.0, float 2.0, <32 x float> %load.0, i32 0, i32 0, i32 0) + %mai.1 = tail call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float 1.0, float 2.0, <32 x float> %load.1, i32 0, i32 0, i32 0) + %mai.2 = tail call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float 1.0, float 2.0, <32 x float> %load.2, i32 0, i32 0, i32 0) + %mai.3 = tail call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float 1.0, float 2.0, <32 x float> %load.3, i32 0, i32 0, i32 0) + %mai.4 = tail call <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float 1.0, float 2.0, <32 x float> %load.4, i32 0, i32 0, i32 0) + %store.0.addr = getelementptr <32 x float>, ptr addrspace(3) %out, i32 %idx + store <32 x float> %mai.0, ptr addrspace(3) %store.0.addr + %store.1.addr = getelementptr <32 x float>, ptr addrspace(3) %out, i32 64 + store <32 x float> %mai.1, ptr addrspace(3) %store.1.addr + %store.2.addr = getelementptr <32 x 
float>, ptr addrspace(3) %out, i32 128 + store <32 x float> %mai.2, ptr addrspace(3) %store.2.addr + %store.3.addr = getelementptr <32 x float>, ptr addrspace(3) %out, i32 192 + store <32 x float> %mai.3, ptr addrspace(3) %store.3.addr + %store.4.addr = getelementptr <32 x float>, ptr addrspace(3) %out, i32 256 + store <32 x float> %mai.4, ptr addrspace(3) %store.4.addr + ret void +} + + declare void @llvm.amdgcn.iglp.opt(i32) #1 declare i32 @llvm.amdgcn.workitem.id.x() #1 declare <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float, float, <32 x float>, i32, i32, i32) #1