diff --git a/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp b/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp
--- a/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp
@@ -80,8 +80,21 @@
   LLVM_MARK_AS_BITMASK_ENUM(/* LargestFlag = */ ALL)
 };
 
+class SchedGroup;
+
 typedef DenseMap> SUnitsToCandidateSGsMap;
 
+// A constraint an SU must satisfy to be added to a SchedGroup. It is called
+// with the candidate SU, the SchedGroup's current members, TII, the
+// SchedGroups of the pipeline being solved, and the SchedGroup's SGID.
+// Rules are stored in the SchedGroup and invoked later, at pipeline solve
+// time, so this must be an owning callable: llvm::function_ref does not own
+// its callee and would dangle once the lambda it was built from is destroyed.
+typedef std::function<bool(const SUnit *, const ArrayRef<SUnit *>,
+                           const SIInstrInfo *, SmallVectorImpl<SchedGroup> &,
+                           unsigned)>
+    InstructionRuleType;
+
 // Classify instructions into groups to enable fine tuned control over the
 // scheduler. These groups may be more specific than current SchedModel
 // instruction classes.
@@ -102,11 +115,12 @@
   // SGID is used to map instructions to candidate SchedGroups
   unsigned SGID;
 
+  // The different rules each instruction in this SchedGroup must conform to
+  SmallVector<InstructionRuleType, 4> Rules;
+
   // Count of the number of created SchedGroups, used to initialize SGID.
   static unsigned NumSchedGroups;
 
-  ScheduleDAGInstrs *DAG;
-
   const SIInstrInfo *TII;
 
   // Try to add and edge from SU A to SU B.
@@ -120,6 +134,8 @@
   // Collection of SUnits that are classified as members of this group.
   SmallVector Collection;
 
+  ScheduleDAGInstrs *DAG;
+
   // Returns true if SU can be added to this SchedGroup.
   bool canAddSU(SUnit &SU) const;
 
@@ -145,6 +161,22 @@
   // Returns true if no more instructions may be added to this group.
   bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; }
 
+  // Append a constraint that SUs must meet in order to fit into this
+  // SchedGroup. Since many rules involve the relationship between a SchedGroup
+  // and the SUnits in other SchedGroups, rules are checked at Pipeline Solve
+  // time (rather than SchedGroup init time).
+  void addRule(const InstructionRuleType &NewRule) { Rules.push_back(NewRule); }
+
+  // Returns true if SU satisfies every rule attached to this SchedGroup. A
+  // SchedGroup without rules accepts any SU.
+  bool allowedByRules(const SUnit *SU,
+                      SmallVectorImpl<SchedGroup> &SyncPipe) const {
+    return std::all_of(Rules.begin(), Rules.end(),
+                       [&](const InstructionRuleType &Rule) {
+                         return Rule(SU, Collection, TII, SyncPipe, SGID);
+                       });
+  }
+
   // Add SU to the SchedGroup.
   void add(SUnit &SU) {
     LLVM_DEBUG(dbgs() << "For SchedGroup with mask "
@@ -177,13 +209,13 @@
 
   SchedGroup(SchedGroupMask SGMask, std::optional MaxSize,
              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
-      : SGMask(SGMask), MaxSize(MaxSize), DAG(DAG), TII(TII) {
+      : SGMask(SGMask), MaxSize(MaxSize), TII(TII), DAG(DAG) {
     SGID = NumSchedGroups++;
   }
 
   SchedGroup(SchedGroupMask SGMask, std::optional MaxSize, int SyncID,
              ScheduleDAGInstrs *DAG, const SIInstrInfo *TII)
-      : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), DAG(DAG), TII(TII) {
+      : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), TII(TII), DAG(DAG) {
     SGID = NumSchedGroups++;
   }
 };
@@ -609,6 +641,9 @@
     if (Match->isFull())
       continue;
 
+    if (!Match->allowedByRules(CurrSU.first, SyncPipeline))
+      continue;
+
     LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask "
                       << (int)Match->getMask() << "and ID " << CandSGID
                       << "\n");
@@ -692,6 +727,10 @@
       LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n");
       continue;
     }
+    if (!Match->allowedByRules(CurrSU.first, SyncPipeline)) {
+      LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " has conflicting rule\n");
+      continue;
+    }
     TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges);
     LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n");
     if (TempCost < BestNodeCost || BestNodeCost == -1) {
@@ -861,13 +900,49 @@
 
   const unsigned PipelineSyncID = 0;
   SchedGroup *SG = nullptr;
-  for (unsigned I = 0; I < MFMACount * 3; ++I) {
+
+  // Only accept an SU that is a direct successor (via SUnit::Succs) of some SU
+  // already assigned to the SchedGroup created immediately before this one
+  // (SGID - 1). BUNDLE instructions are always rejected.
+  InstructionRuleType Rule1 =
+      [](const SUnit *SU, ArrayRef<SUnit *> Collection, const SIInstrInfo *TII,
+         SmallVectorImpl<SchedGroup> &SyncPipe, unsigned SGID) {
+        auto MI = SU->getInstr();
+        if (MI->getOpcode() == TargetOpcode::BUNDLE)
+          return false;
+
+        // SGIDs come from NumSchedGroups++, so stop at the first match.
+        SchedGroup *OtherGroup = nullptr;
+        for (auto &PipeSG : SyncPipe) {
+          if (PipeSG.getSGID() == (int)SGID - 1) {
+            OtherGroup = &PipeSG;
+            break;
+          }
+        }
+
+        if (!OtherGroup)
+          return false;
+
+        return std::any_of(OtherGroup->Collection.begin(),
+                           OtherGroup->Collection.end(), [&SU](SUnit *Elt) {
+                             return std::any_of(Elt->Succs.begin(),
+                                                Elt->Succs.end(),
+                                                [&SU](SDep &Succ) {
+                                                  return Succ.getSUnit() == SU;
+                                                });
+                           });
+      };
+
+  // Each iteration of pipeline has 1 MFMA and 1 DS_W, where the DS_W is a
+  // successor of the MFMA
+  for (unsigned I = 0; I < MFMACount; ++I) {
     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
         SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII);
     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
 
     SG = &SyncedSchedGroups[PipelineSyncID].emplace_back(
-        SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII);
+        SchedGroupMask::DS_WRITE, 1, PipelineSyncID, DAG, TII);
+    SG->addRule(Rule1);
     SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]);
   }
 }
@@ -879,7 +954,7 @@
   case MFMASmallGemmOptID:
     return std::make_unique(DAG, TII);
   case DemoOptID:
-    return std::make_unique(DAG, TII);
+    return std::make_unique(DAG, TII);
   }
 
   llvm_unreachable("Unknown IGLPStrategyID");
diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.opt.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.opt.ll
--- a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.opt.ll
+++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.opt.ll
@@ -153,45 +153,21 @@
 ; GCN: ; %bb.0: ; %entry
 ; GCN-NEXT: s_load_dwordx2 s[0:1], s[0:1], 0x24
 ; GCN-NEXT: v_lshlrev_b32_e32 v0, 7, v0
+; GCN-NEXT: v_mov_b32_e32 v2, 1.0
 ; GCN-NEXT: v_mov_b32_e32 v3, 2.0
 ; GCN-NEXT: ; iglp_opt mask(0x00000001)
 ; GCN-NEXT: s_waitcnt lgkmcnt(0)
 ; GCN-NEXT: v_add_u32_e32 v1, s0, v0
-; GCN-NEXT: v_add_u32_e32 v2, 0x6000, v1
-; GCN-NEXT: ds_read_b128 a[28:31], v2 offset:57456
-; GCN-NEXT: ds_read_b128 a[24:27], v2
offset:57440 -; GCN-NEXT: ds_read_b128 a[20:23], v2 offset:57424 -; GCN-NEXT: ds_read_b128 a[16:19], v2 offset:57408 -; GCN-NEXT: ds_read_b128 a[0:3], v2 offset:57344 -; GCN-NEXT: ds_read_b128 a[4:7], v2 offset:57360 -; GCN-NEXT: ds_read_b128 a[8:11], v2 offset:57376 -; GCN-NEXT: ds_read_b128 a[12:15], v2 offset:57392 -; GCN-NEXT: v_mov_b32_e32 v2, 1.0 -; GCN-NEXT: ds_read_b128 a[60:63], v1 offset:49264 -; GCN-NEXT: ds_read_b128 a[56:59], v1 offset:49248 -; GCN-NEXT: ds_read_b128 a[52:55], v1 offset:49232 -; GCN-NEXT: ds_read_b128 a[48:51], v1 offset:49216 -; GCN-NEXT: ds_read_b128 a[44:47], v1 offset:49200 -; GCN-NEXT: ds_read_b128 a[40:43], v1 offset:49184 -; GCN-NEXT: ds_read_b128 a[36:39], v1 offset:49168 -; GCN-NEXT: ds_read_b128 a[32:35], v1 offset:49152 -; GCN-NEXT: s_waitcnt lgkmcnt(8) -; GCN-NEXT: v_mfma_f32_32x32x1f32 a[0:31], v2, v3, a[0:31] -; GCN-NEXT: ds_read_b128 a[156:159], v1 offset:112 -; GCN-NEXT: ds_read_b128 a[152:155], v1 offset:96 -; GCN-NEXT: ds_read_b128 a[68:71], v1 offset:24592 -; GCN-NEXT: ds_read_b128 a[64:67], v1 offset:24576 -; GCN-NEXT: v_add_u32_e32 v0, s1, v0 -; GCN-NEXT: s_waitcnt lgkmcnt(4) -; GCN-NEXT: v_mfma_f32_32x32x1f32 a[32:63], v2, v3, a[32:63] -; GCN-NEXT: ds_read_b128 a[148:151], v1 offset:80 -; GCN-NEXT: ds_read_b128 a[144:147], v1 offset:64 -; GCN-NEXT: ds_read_b128 a[128:131], v1 -; GCN-NEXT: ds_read_b128 a[132:135], v1 offset:16 -; GCN-NEXT: ds_read_b128 a[136:139], v1 offset:32 -; GCN-NEXT: ds_read_b128 a[140:143], v1 offset:48 +; GCN-NEXT: ds_read_b128 a[28:31], v1 offset:112 +; GCN-NEXT: ds_read_b128 a[24:27], v1 offset:96 +; GCN-NEXT: ds_read_b128 a[20:23], v1 offset:80 +; GCN-NEXT: ds_read_b128 a[16:19], v1 offset:64 +; GCN-NEXT: ds_read_b128 a[0:3], v1 +; GCN-NEXT: ds_read_b128 a[4:7], v1 offset:16 +; GCN-NEXT: ds_read_b128 a[8:11], v1 offset:32 +; GCN-NEXT: ds_read_b128 a[12:15], v1 offset:48 ; GCN-NEXT: s_waitcnt lgkmcnt(0) -; GCN-NEXT: v_mfma_f32_32x32x1f32 a[128:159], v2, v3, a[128:159] +; GCN-NEXT: 
v_mfma_f32_32x32x1f32 a[0:31], v2, v3, a[0:31] ; GCN-NEXT: ds_read_b128 a[124:127], v1 offset:8304 ; GCN-NEXT: ds_read_b128 a[120:123], v1 offset:8288 ; GCN-NEXT: ds_read_b128 a[116:119], v1 offset:8272 @@ -200,30 +176,47 @@ ; GCN-NEXT: ds_read_b128 a[104:107], v1 offset:8224 ; GCN-NEXT: ds_read_b128 a[100:103], v1 offset:8208 ; GCN-NEXT: ds_read_b128 a[96:99], v1 offset:8192 -; GCN-NEXT: s_waitcnt lgkmcnt(0) -; GCN-NEXT: v_mfma_f32_32x32x1f32 a[96:127], v2, v3, a[96:127] +; GCN-NEXT: v_add_u32_e32 v0, s1, v0 ; GCN-NEXT: ds_read_b128 a[92:95], v1 offset:24688 ; GCN-NEXT: ds_read_b128 a[88:91], v1 offset:24672 ; GCN-NEXT: ds_read_b128 a[84:87], v1 offset:24656 ; GCN-NEXT: ds_read_b128 a[80:83], v1 offset:24640 ; GCN-NEXT: ds_read_b128 a[76:79], v1 offset:24624 ; GCN-NEXT: ds_read_b128 a[72:75], v1 offset:24608 -; GCN-NEXT: s_nop 2 -; GCN-NEXT: ds_write_b128 v0, a[156:159] offset:112 -; GCN-NEXT: ds_write_b128 v0, a[152:155] offset:96 -; GCN-NEXT: ds_write_b128 v0, a[148:151] offset:80 -; GCN-NEXT: ds_write_b128 v0, a[144:147] offset:64 -; GCN-NEXT: ds_write_b128 v0, a[140:143] offset:48 -; GCN-NEXT: ds_write_b128 v0, a[136:139] offset:32 -; GCN-NEXT: ds_write_b128 v0, a[132:135] offset:16 -; GCN-NEXT: ds_write_b128 v0, a[128:131] +; GCN-NEXT: s_nop 3 +; GCN-NEXT: ds_write_b128 v0, a[28:31] offset:112 +; GCN-NEXT: s_waitcnt lgkmcnt(7) +; GCN-NEXT: v_mfma_f32_32x32x1f32 a[96:127], v2, v3, a[96:127] +; GCN-NEXT: ds_read_b128 a[68:71], v1 offset:24592 +; GCN-NEXT: ds_read_b128 a[64:67], v1 offset:24576 +; GCN-NEXT: ds_write_b128 v0, a[24:27] offset:96 +; GCN-NEXT: ds_write_b128 v0, a[20:23] offset:80 +; GCN-NEXT: ds_write_b128 v0, a[16:19] offset:64 +; GCN-NEXT: ds_write_b128 v0, a[12:15] offset:48 +; GCN-NEXT: ds_write_b128 v0, a[8:11] offset:32 +; GCN-NEXT: ds_write_b128 v0, a[4:7] offset:16 +; GCN-NEXT: ds_write_b128 v0, a[0:3] ; GCN-NEXT: v_mov_b32_e32 v0, s1 -; GCN-NEXT: s_waitcnt lgkmcnt(8) -; GCN-NEXT: v_mfma_f32_32x32x1f32 a[64:95], v2, v3, a[64:95] -; GCN-NEXT: 
ds_write_b128 v0, a[56:59] offset:24672 -; GCN-NEXT: ds_write_b128 v0, a[60:63] offset:24688 -; GCN-NEXT: ds_write_b128 v0, a[48:51] offset:24640 +; GCN-NEXT: ds_read_b128 a[28:31], v1 offset:49264 +; GCN-NEXT: ds_read_b128 a[24:27], v1 offset:49248 +; GCN-NEXT: ds_read_b128 a[20:23], v1 offset:49232 +; GCN-NEXT: ds_read_b128 a[16:19], v1 offset:49216 +; GCN-NEXT: ds_read_b128 a[12:15], v1 offset:49200 +; GCN-NEXT: ds_read_b128 a[8:11], v1 offset:49184 +; GCN-NEXT: ds_read_b128 a[4:7], v1 offset:49168 +; GCN-NEXT: ds_read_b128 a[0:3], v1 offset:49152 +; GCN-NEXT: v_add_u32_e32 v4, 0x6000, v1 ; GCN-NEXT: ds_write_b128 v0, a[120:123] offset:8288 +; GCN-NEXT: s_waitcnt lgkmcnt(14) +; GCN-NEXT: v_mfma_f32_32x32x1f32 a[64:95], v2, v3, a[64:95] +; GCN-NEXT: ds_read_b128 a[60:63], v4 offset:57456 +; GCN-NEXT: ds_read_b128 a[56:59], v4 offset:57440 +; GCN-NEXT: ds_read_b128 a[52:55], v4 offset:57424 +; GCN-NEXT: ds_read_b128 a[48:51], v4 offset:57408 +; GCN-NEXT: ds_read_b128 a[32:35], v4 offset:57344 +; GCN-NEXT: ds_read_b128 a[36:39], v4 offset:57360 +; GCN-NEXT: ds_read_b128 a[40:43], v4 offset:57376 +; GCN-NEXT: ds_read_b128 a[44:47], v4 offset:57392 ; GCN-NEXT: ds_write_b128 v0, a[124:127] offset:8304 ; GCN-NEXT: ds_write_b128 v0, a[112:115] offset:8256 ; GCN-NEXT: ds_write_b128 v0, a[116:119] offset:8272 @@ -231,15 +224,10 @@ ; GCN-NEXT: ds_write_b128 v0, a[108:111] offset:8240 ; GCN-NEXT: ds_write_b128 v0, a[96:99] offset:8192 ; GCN-NEXT: ds_write_b128 v0, a[100:103] offset:8208 -; GCN-NEXT: ds_write_b128 v0, a[52:55] offset:24656 -; GCN-NEXT: ds_write_b128 v0, a[40:43] offset:24608 -; GCN-NEXT: ds_write_b128 v0, a[44:47] offset:24624 -; GCN-NEXT: ds_write_b128 v0, a[32:35] offset:24576 -; GCN-NEXT: ds_write_b128 v0, a[36:39] offset:24592 -; GCN-NEXT: ds_write_b128 v0, a[24:27] offset:32864 -; GCN-NEXT: ds_write_b128 v0, a[28:31] offset:32880 -; GCN-NEXT: ds_write_b128 v0, a[16:19] offset:32832 +; GCN-NEXT: s_nop 3 ; GCN-NEXT: ds_write_b128 v0, a[88:91] offset:16480 
+; GCN-NEXT: s_waitcnt lgkmcnt(14) +; GCN-NEXT: v_mfma_f32_32x32x1f32 a[0:31], v2, v3, a[0:31] ; GCN-NEXT: ds_write_b128 v0, a[92:95] offset:16496 ; GCN-NEXT: ds_write_b128 v0, a[80:83] offset:16448 ; GCN-NEXT: ds_write_b128 v0, a[84:87] offset:16464 @@ -247,11 +235,28 @@ ; GCN-NEXT: ds_write_b128 v0, a[76:79] offset:16432 ; GCN-NEXT: ds_write_b128 v0, a[64:67] offset:16384 ; GCN-NEXT: ds_write_b128 v0, a[68:71] offset:16400 -; GCN-NEXT: ds_write_b128 v0, a[20:23] offset:32848 -; GCN-NEXT: ds_write_b128 v0, a[8:11] offset:32800 -; GCN-NEXT: ds_write_b128 v0, a[12:15] offset:32816 -; GCN-NEXT: ds_write_b128 v0, a[0:3] offset:32768 -; GCN-NEXT: ds_write_b128 v0, a[4:7] offset:32784 +; GCN-NEXT: s_nop 7 +; GCN-NEXT: s_nop 3 +; GCN-NEXT: ds_write_b128 v0, a[24:27] offset:24672 +; GCN-NEXT: s_waitcnt lgkmcnt(14) +; GCN-NEXT: v_mfma_f32_32x32x1f32 a[32:63], v2, v3, a[32:63] +; GCN-NEXT: ds_write_b128 v0, a[28:31] offset:24688 +; GCN-NEXT: ds_write_b128 v0, a[16:19] offset:24640 +; GCN-NEXT: ds_write_b128 v0, a[20:23] offset:24656 +; GCN-NEXT: ds_write_b128 v0, a[8:11] offset:24608 +; GCN-NEXT: ds_write_b128 v0, a[12:15] offset:24624 +; GCN-NEXT: ds_write_b128 v0, a[0:3] offset:24576 +; GCN-NEXT: ds_write_b128 v0, a[4:7] offset:24592 +; GCN-NEXT: s_nop 7 +; GCN-NEXT: s_nop 3 +; GCN-NEXT: ds_write_b128 v0, a[56:59] offset:32864 +; GCN-NEXT: ds_write_b128 v0, a[60:63] offset:32880 +; GCN-NEXT: ds_write_b128 v0, a[48:51] offset:32832 +; GCN-NEXT: ds_write_b128 v0, a[52:55] offset:32848 +; GCN-NEXT: ds_write_b128 v0, a[40:43] offset:32800 +; GCN-NEXT: ds_write_b128 v0, a[44:47] offset:32816 +; GCN-NEXT: ds_write_b128 v0, a[32:35] offset:32768 +; GCN-NEXT: ds_write_b128 v0, a[36:39] offset:32784 ; GCN-NEXT: s_endpgm entry: call void @llvm.amdgcn.iglp.opt(i32 1)