diff --git a/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp b/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp --- a/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUIGroupLP.cpp @@ -81,7 +81,9 @@ }; typedef DenseMap> SUnitsToCandidateSGsMap; - +typedef function_ref, + const SIInstrInfo *)> + InstructionRuleType; // Classify instructions into groups to enable fine tuned control over the // scheduler. These groups may be more specific than current SchedModel // instruction classes. @@ -95,6 +97,9 @@ // Maximum number of SUnits that can be added to this group. std::optional MaxSize; + // The different rules each instruction in this SchedGroup must conform to + std::optional> Rules; + // SchedGroups will only synchronize with other SchedGroups that have the same // SyncID. int SyncID = 0; @@ -145,6 +150,18 @@ // Returns true if no more instructions may be added to this group. bool isFull() const { return MaxSize && Collection.size() >= *MaxSize; } + // Returns true if the SU matches all rules + bool allowedByRules(const SUnit *SU) const { + if (!Rules.has_value()) + return true; + for (auto &Rule : *Rules) { + if (!Rule(SU, Collection, TII)) { + return false; + } + } + return true; + } + // Add SU to the SchedGroup. 
void add(SUnit &SU) { LLVM_DEBUG(dbgs() << "For SchedGroup with mask " @@ -176,14 +193,17 @@ SchedGroupMask getMask() { return SGMask; } SchedGroup(SchedGroupMask SGMask, std::optional MaxSize, + std::optional> Rules, ScheduleDAGInstrs *DAG, const SIInstrInfo *TII) - : SGMask(SGMask), MaxSize(MaxSize), DAG(DAG), TII(TII) { + : SGMask(SGMask), MaxSize(MaxSize), Rules(Rules), DAG(DAG), TII(TII) { SGID = NumSchedGroups++; } - SchedGroup(SchedGroupMask SGMask, std::optional MaxSize, int SyncID, - ScheduleDAGInstrs *DAG, const SIInstrInfo *TII) - : SGMask(SGMask), MaxSize(MaxSize), SyncID(SyncID), DAG(DAG), TII(TII) { + SchedGroup(SchedGroupMask SGMask, std::optional MaxSize, + std::optional> Rules, + int SyncID, ScheduleDAGInstrs *DAG, const SIInstrInfo *TII) + : SGMask(SGMask), MaxSize(MaxSize), Rules(Rules), SyncID(SyncID), + DAG(DAG), TII(TII) { SGID = NumSchedGroups++; } }; @@ -569,6 +589,9 @@ if (Match->isFull()) continue; + if (!Match->allowedByRules(CurrSU.first)) + continue; + LLVM_DEBUG(dbgs() << "Assigning to SchedGroup with Mask " << (int)Match->getMask() << "and ID " << CandSGID << "\n"); @@ -656,6 +679,11 @@ LLVM_DEBUG(dbgs() << "SGID # " << CandSGID << " is full\n"); continue; } + if (!Match->allowedByRules(CurrSU.first)) { + LLVM_DEBUG(dbgs() << "SGID # " << CandSGID + << " has conflicting rule\n"); + continue; + } TempCost = addEdges(SyncPipeline, CurrSU.first, CandSGID, AddedEdges); LLVM_DEBUG(dbgs() << "Cost of Group " << TempCost << "\n"); if (TempCost < BestNodeCost || BestNodeCost == -1) { @@ -723,7 +751,7 @@ makePipeline(); } -enum IGLPStrategyID : int { MFMASmallGemmOptID = 0 }; +enum IGLPStrategyID : int { MFMASmallGemmOptID = 0, DemoID = 1 }; // Implement a IGLP scheduling strategy. 
class IGLPStrategy { @@ -772,11 +800,82 @@ SchedGroup *SG = nullptr; for (unsigned I = 0; I < MFMACount * 3; ++I) { SG = &SyncedSchedGroups[PipelineSyncID].emplace_back( - SchedGroupMask::DS, 2, PipelineSyncID, DAG, TII); + SchedGroupMask::DS, 2, std::nullopt, PipelineSyncID, DAG, TII); SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]); SG = &SyncedSchedGroups[PipelineSyncID].emplace_back( - SchedGroupMask::MFMA, 1, PipelineSyncID, DAG, TII); + SchedGroupMask::MFMA, 1, std::nullopt, PipelineSyncID, DAG, TII); + SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]); + } +} + +class DemoOpt final : public IGLPStrategy { +public: + void applyIGLPStrategy( + DenseMap &SyncedInstrs, + DenseMap> &SyncedSchedGroups) override; + + bool shouldApplyStrategy(ScheduleDAGInstrs *DAG) override { return true; } + + DemoOpt(ScheduleDAGInstrs *DAG, const SIInstrInfo *TII) + : IGLPStrategy(DAG, TII) {} +}; + +void DemoOpt::applyIGLPStrategy( + DenseMap &SyncedInstrs, + DenseMap> &SyncedSchedGroups) { + // Build five synced SchedGroups that each pair one MFMA with one DS write. 
+ const unsigned PipelineSyncID = 0; + SchedGroup *SG = nullptr; + + // The SchedGroup has 1 MFMA and 1 DS_W, where the DS_W is a successor of the + // MFMA + InstructionRuleType Rule1 = [](const SUnit *SU, ArrayRef Collection, + const SIInstrInfo *TII) { + auto MI = SU->getInstr(); + if (MI->getOpcode() == TargetOpcode::BUNDLE) + return false; + if (TII->isDS(*MI) && MI->mayStore()) { + if (Collection.size() > 1) + return false; + if (Collection.size() == 0) + return true; + + auto OtherElt = Collection[0]; + if (TII->isMFMAorWMMA(*OtherElt->getInstr())) { + for (auto &S : OtherElt->Succs) { + if (S.getSUnit() == SU) + return true; + } + } + return false; + } + if (TII->isMFMAorWMMA(*MI)) { + if (Collection.size() > 1) + return false; + if (Collection.size() == 0) + return true; + + auto OtherElt = Collection[0]; + if (TII->isDS(*OtherElt->getInstr()) && OtherElt->getInstr()->mayStore()) { + for (auto &S : OtherElt->Preds) { + if (S.getSUnit() == SU) + return true; + } + } + return false; + } + return false; + }; + + SmallVector DemoRules; + DemoRules.push_back(Rule1); + + auto Mask = SchedGroupMask::MFMA | SchedGroupMask::DS_WRITE; + + for (unsigned I = 0; I < 5; ++I) { + SG = &SyncedSchedGroups[PipelineSyncID].emplace_back( + Mask, 2, DemoRules, PipelineSyncID, DAG, TII); SG->initSchedGroup(SyncedInstrs[SG->getSyncID()]); } } @@ -787,6 +886,8 @@ switch (ID) { case MFMASmallGemmOptID: return std::make_unique(DAG, TII); + case DemoID: + return std::make_unique(DAG, TII); } llvm_unreachable("Unknown IGLPStrategyID"); @@ -844,6 +945,7 @@ bool SchedGroup::canAddMI(const MachineInstr &MI) const { bool Result = false; + if (MI.isMetaInstruction()) Result = false; @@ -950,6 +1052,17 @@ bool SchedGroup::canAddSU(SUnit &SU) const { MachineInstr &MI = *SU.getInstr(); + + // At SchedGroup init time, collections will be empty. 
Thus, any rule + // inspecting the stored contents of collections will not be relevant during + // SchedGroup initialization + if (Rules.has_value()) { + for (auto &Rule : *Rules) { + if (!Rule(&SU, Collection, TII)) + return false; + } + } + if (MI.getOpcode() != TargetOpcode::BUNDLE) return canAddMI(MI); @@ -1049,7 +1162,7 @@ resetEdges(SchedBarrier, DAG); auto InvertedMask = invertSchedBarrierMask((SchedGroupMask)MI.getOperand(0).getImm()); - SchedGroup SG(InvertedMask, std::nullopt, DAG, TII); + SchedGroup SG(InvertedMask, std::nullopt, std::nullopt, DAG, TII); SG.initSchedGroup(); // Preserve original instruction ordering relative to the SCHED_BARRIER. SG.link( @@ -1104,8 +1217,8 @@ int32_t Size = SGB.getOperand(1).getImm(); int32_t SyncID = SGB.getOperand(2).getImm(); - auto &SG = SyncedSchedGroups[SyncID].emplace_back((SchedGroupMask)SGMask, - Size, SyncID, DAG, TII); + auto &SG = SyncedSchedGroups[SyncID].emplace_back( + (SchedGroupMask)SGMask, Size, std::nullopt, SyncID, DAG, TII); SG.initSchedGroup(RIter, SyncedInstrs[SG.getSyncID()]); } diff --git a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.opt.ll b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.opt.ll --- a/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.opt.ll +++ b/llvm/test/CodeGen/AMDGPU/llvm.amdgcn.iglp.opt.ll @@ -147,9 +147,99 @@ ret void } + +define amdgpu_kernel void @test_iglp_opt_demo(ptr addrspace(3) noalias %in, ptr addrspace(3) %in2, ptr addrspace(3) noalias %out) #0 { +; GCN-LABEL: test_iglp_opt_demo: +; GCN: ; %bb.0: ; %entry +; GCN-NEXT: s_load_dwordx2 s[2:3], s[0:1], 0x24 +; GCN-NEXT: v_lshlrev_b32_e32 v4, 6, v0 +; GCN-NEXT: s_load_dword s0, s[0:1], 0x2c +; GCN-NEXT: ; iglp_opt mask(0x00000001) +; GCN-NEXT: s_waitcnt lgkmcnt(0) +; GCN-NEXT: v_mov_b32_e32 v0, s3 +; GCN-NEXT: ds_read2_b64 v[0:3], v0 offset1:1 +; GCN-NEXT: v_add_u32_e32 v5, s2, v4 +; GCN-NEXT: ds_read_b128 a[16:19], v5 +; GCN-NEXT: ds_read_b128 a[20:23], v5 offset:16 +; GCN-NEXT: ds_read_b128 a[24:27], v5 offset:32 +; GCN-NEXT: 
ds_read_b128 a[28:31], v5 offset:48 +; GCN-NEXT: v_add_u32_e32 v4, s0, v4 +; GCN-NEXT: ds_read_b128 a[12:15], v5 offset:4144 +; GCN-NEXT: ds_read_b128 a[8:11], v5 offset:4128 +; GCN-NEXT: s_waitcnt lgkmcnt(2) +; GCN-NEXT: v_mfma_f32_32x32x8f16 a[16:31], v[0:1], v[2:3], a[16:31] cbsz:1 abid:2 blgp:3 +; GCN-NEXT: ds_read_b128 a[4:7], v5 offset:4112 +; GCN-NEXT: ds_read_b128 a[0:3], v5 offset:4096 +; GCN-NEXT: ds_read_b128 a[40:43], v5 offset:12336 +; GCN-NEXT: ds_read_b128 a[36:39], v5 offset:12320 +; GCN-NEXT: ds_read_b128 a[32:35], v5 offset:12304 +; GCN-NEXT: v_mov_b32_e32 v6, s0 +; GCN-NEXT: s_nop 7 +; GCN-NEXT: s_nop 4 +; GCN-NEXT: ds_write_b128 v4, a[28:31] offset:48 +; GCN-NEXT: ds_read_b128 a[28:31], v5 offset:12288 +; GCN-NEXT: s_waitcnt lgkmcnt(5) +; GCN-NEXT: v_mfma_f32_32x32x8f16 a[0:15], v[0:1], v[2:3], a[0:15] cbsz:1 abid:2 blgp:3 +; GCN-NEXT: ds_write_b128 v4, a[24:27] offset:32 +; GCN-NEXT: ds_write_b128 v4, a[20:23] offset:16 +; GCN-NEXT: ds_write_b128 v4, a[16:19] +; GCN-NEXT: s_waitcnt lgkmcnt(3) +; GCN-NEXT: v_mfma_f32_32x32x8f16 a[28:43], v[0:1], v[2:3], a[28:43] cbsz:1 abid:2 blgp:3 +; GCN-NEXT: s_nop 7 +; GCN-NEXT: s_nop 5 +; GCN-NEXT: ds_write_b128 v6, a[8:11] offset:4128 +; GCN-NEXT: ds_write_b128 v6, a[12:15] offset:4144 +; GCN-NEXT: ds_write_b128 v6, a[0:3] offset:4096 +; GCN-NEXT: ds_write_b128 v6, a[4:7] offset:4112 +; GCN-NEXT: s_nop 0 +; GCN-NEXT: ds_write_b128 v6, a[36:39] offset:8224 +; GCN-NEXT: ds_write_b128 v6, a[40:43] offset:8240 +; GCN-NEXT: ds_write_b128 v6, a[28:31] offset:8192 +; GCN-NEXT: ds_write_b128 v6, a[32:35] offset:8208 +; GCN-NEXT: s_endpgm +entry: + call void @llvm.amdgcn.iglp.opt(i32 1) + %idx = call i32 @llvm.amdgcn.workitem.id.x() + %load.0.addr = getelementptr <16 x float>, ptr addrspace(3) %in, i32 %idx + %load.0 = load <16 x float>, ptr addrspace(3) %load.0.addr + %load.1.addr = getelementptr <16 x float>, ptr addrspace(3) %load.0.addr, i32 64 + %load.1 = load <16 x float>, ptr addrspace(3) %load.1.addr + 
%load.2.addr = getelementptr <16 x float>, ptr addrspace(3) %load.1.addr, i32 128 + %load.2 = load <16 x float>, ptr addrspace(3) %load.2.addr + %load.3.addr = getelementptr <16 x float>, ptr addrspace(3) %load.1.addr, i32 128 + %load.3 = load <16 x float>, ptr addrspace(3) %load.2.addr + %load.4.addr = getelementptr <16 x float>, ptr addrspace(3) %load.1.addr, i32 128 + %load.4 = load <16 x float>, ptr addrspace(3) %load.2.addr + %load.5.addr = getelementptr <16 x float>, ptr addrspace(3) %load.1.addr, i32 128 + %load.5 = load <16 x float>, ptr addrspace(3) %load.2.addr + %in2.0 = load <4 x half>, ptr addrspace(3) %in2 + %inp.1 = getelementptr <4 x half>, ptr addrspace(3) %in2, i64 1 + %in2.1 = load <4 x half>, ptr addrspace(3) %inp.1 + %mai.0 = tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> %in2.0, <4 x half> %in2.1, <16 x float> %load.0, i32 1, i32 2, i32 3) + %mai.1 = tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> %in2.0, <4 x half> %in2.1, <16 x float> %load.1, i32 1, i32 2, i32 3) + %mai.2 = tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> %in2.0, <4 x half> %in2.1, <16 x float> %load.2, i32 1, i32 2, i32 3) + %mai.3 = tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> %in2.0, <4 x half> %in2.1, <16 x float> %load.3, i32 1, i32 2, i32 3) + %mai.4 = tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> %in2.0, <4 x half> %in2.1, <16 x float> %load.4, i32 1, i32 2, i32 3) + %mai.5 = tail call <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half> %in2.0, <4 x half> %in2.1, <16 x float> %load.5, i32 1, i32 2, i32 3) + %store.0.addr = getelementptr <16 x float>, ptr addrspace(3) %out, i32 %idx + store <16 x float> %mai.0, ptr addrspace(3) %store.0.addr + %store.1.addr = getelementptr <16 x float>, ptr addrspace(3) %out, i32 64 + store <16 x float> %mai.1, ptr addrspace(3) %store.1.addr + %store.2.addr = getelementptr <16 x float>, ptr addrspace(3) %out, i32 128 + store <16 x 
float> %mai.2, ptr addrspace(3) %store.2.addr + %store.3.addr = getelementptr <16 x float>, ptr addrspace(3) %out, i32 128 + store <16 x float> %mai.2, ptr addrspace(3) %store.3.addr + %store.4.addr = getelementptr <16 x float>, ptr addrspace(3) %out, i32 128 + store <16 x float> %mai.2, ptr addrspace(3) %store.4.addr + %store.5.addr = getelementptr <16 x float>, ptr addrspace(3) %out, i32 128 + store <16 x float> %mai.2, ptr addrspace(3) %store.5.addr + ret void +} + declare void @llvm.amdgcn.iglp.opt(i32) #1 declare i32 @llvm.amdgcn.workitem.id.x() #1 declare <32 x float> @llvm.amdgcn.mfma.f32.32x32x1f32(float, float, <32 x float>, i32, i32, i32) #1 +declare <16 x float> @llvm.amdgcn.mfma.f32.32x32x8f16(<4 x half>, <4 x half>, <16 x float>, i32, i32, i32) #1 attributes #0 = { nounwind "amdgpu-flat-work-group-size"="1,256" } attributes #1 = { convergent nounwind }