diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp
--- a/llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -57,6 +57,9 @@
          ++I, ++Pos)
       MI = &*I;
   }
+  MIRef(MachineInstr *MI)
+      : MI(MI), MBB(MI->getParent()),
+        Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
   MIRef(MachineInstr *MI, MachineBasicBlock *MBB)
       : MI(MI), MBB(MBB),
         Pos(std::distance(MBB->instr_begin(), ++MI->getIterator())) {}
@@ -66,6 +69,7 @@
   bool operator==(const MIRef &RHS) const {
     return MI == RHS.MI && MBB == RHS.MBB;
   }
+  bool operator!=(const MIRef &RHS) const { return !(*this == RHS); }
   bool operator<(const MIRef &RHS) const {
     return MBB < RHS.MBB || (MBB == RHS.MBB && Pos < RHS.Pos);
   }
@@ -77,7 +81,7 @@
 struct BBInfo {
   MIRef FirstAMX;
   MIRef LastCall;
-  MIRef LastShape;
+  bool HasAMXRegLiveIn = false;
   bool TileCfgForbidden = false;
   bool NeedTileCfgLiveIn = false;
 };
@@ -86,8 +90,8 @@
   MachineRegisterInfo *MRI;
   const MachineLoopInfo *MLI;
   SmallSet<MIRef, 8> DefVisited;
-  SmallSet<MachineBasicBlock *, 8> ShapeBBs;
   DenseMap<MachineBasicBlock *, BBInfo> BBVisitedInfo;
+  DenseMap<MachineBasicBlock *, SmallVector<MIRef, 8>> ShapeBBs;
 
   /// Check if the callee will clobber AMX registers.
   bool isDestructiveCall(MachineInstr &MI, BitVector UsableRegs) {
@@ -124,6 +128,32 @@
   /// Collect the shape def information for later use.
   void collectShapeInfo(MachineInstr &MI);
 
+  /// Try to hoist shapes defined below AMX instructions.
+  bool hoistShapesInBB(MachineBasicBlock *MBB, SmallVectorImpl<MIRef> &Shapes) {
+    MIRef &FirstAMX = BBVisitedInfo[MBB].FirstAMX;
+    auto FirstShapeBelowAMX = llvm::lower_bound(Shapes, FirstAMX);
+    auto InsertPoint = FirstAMX.MI->getIterator();
+    for (auto I = FirstShapeBelowAMX, E = Shapes.end(); I != E; ++I) {
+      // Do not hoist instructions that access memory.
+      if (I->MI->mayLoadOrStore())
+        return false;
+      for (auto &MO : I->MI->operands()) {
+        if (MO.isDef())
+          continue;
+        // Do not hoist if a source is defined below the first AMX instruction.
+        // TODO: We can handle isMoveImmediate MI here.
+        if (MO.isReg() && MIRef(MRI->getVRegDef(MO.getReg())) > FirstAMX)
+          return false;
+        // TODO: Maybe need more checks here.
+      }
+      MBB->insert(InsertPoint, I->MI->removeFromParent());
+    }
+    // We only need to mark the last shape in the BB now.
+    Shapes.clear();
+    Shapes.push_back(MIRef(&*--InsertPoint, MBB));
+    return true;
+  }
+
 public:
   X86PreTileConfig() : MachineFunctionPass(ID) {}
 
@@ -165,9 +195,9 @@
 void X86PreTileConfig::collectShapeInfo(MachineInstr &MI) {
   auto RecordShape = [&](MachineInstr *MI, MachineBasicBlock *MBB) {
     MIRef MIR(MI, MBB);
-    if (BBVisitedInfo[MBB].LastShape < MIR)
-      BBVisitedInfo[MBB].LastShape = MIR;
-    ShapeBBs.insert(MBB);
+    auto I = llvm::lower_bound(ShapeBBs[MBB], MIR);
+    if (I == ShapeBBs[MBB].end() || *I != MIR)
+      ShapeBBs[MBB].insert(I, MIR);
   };
 
   SmallVector<Register, 8> WorkList(
@@ -229,6 +259,10 @@
       else
         CfgLiveInBBs.push_back(&MBB);
     }
+    if (BBVisitedInfo[&MBB].FirstAMX || BBVisitedInfo[&MBB].HasAMXRegLiveIn)
+      for (auto *Succ : MBB.successors())
+        if (!isLoopBackEdge(Succ, &MBB))
+          BBVisitedInfo[Succ].HasAMXRegLiveIn = true;
   }
 
   // Update NeedTileCfgLiveIn for predecessors.
@@ -252,8 +286,17 @@
     return false;
 
   // Avoid to insert ldtilecfg before any shape defs.
-  SmallVector<MachineBasicBlock *, 8> WorkList(
-      make_range(ShapeBBs.begin(), ShapeBBs.end()));
+  SmallVector<MachineBasicBlock *, 8> WorkList;
+  for (auto &I : ShapeBBs) {
+    // TODO: We can hoist shapes across BBs here.
+    if (BBVisitedInfo[I.first].HasAMXRegLiveIn)
+      REPORT_CONFIG_FAIL
+    if (BBVisitedInfo[I.first].FirstAMX &&
+        BBVisitedInfo[I.first].FirstAMX < I.second.back() &&
+        !hoistShapesInBB(I.first, I.second))
+      REPORT_CONFIG_FAIL
+    WorkList.push_back(I.first);
+  }
   while (!WorkList.empty()) {
     MachineBasicBlock *MBB = WorkList.pop_back_val();
     for (auto *Pred : MBB->predecessors()) {
@@ -282,9 +325,6 @@
     } else {
       // Avoid the BB to be multi visited.
       VisitedOrInserted.insert(I);
-      // We cannot sink it across any AMX instruction.
-      if (BBVisitedInfo[I.MBB].FirstAMX)
-        REPORT_CONFIG_FAIL;
       // Sink the inserting point along the chain with NeedTileCfgLiveIn =
       // true when MBB isn't all shapes reachable.
       for (auto *Succ : I.MBB->successors())
@@ -296,14 +336,9 @@
 
   // A given point might be forked due to shape conditions are not met.
   for (MIRef I : InsertPoints) {
-    // Even MBB is all shapes reachable, we still need to check if there's
-    // AMX that intersects with shapes in the same MBB.
-    if (BBVisitedInfo[I.MBB].FirstAMX &&
-        BBVisitedInfo[I.MBB].FirstAMX < BBVisitedInfo[I.MBB].LastShape)
-      REPORT_CONFIG_FAIL;
     // Make sure we insert ldtilecfg after the last shape def in MBB.
-    if (I < BBVisitedInfo[I.MBB].LastShape)
-      I = BBVisitedInfo[I.MBB].LastShape;
+    if (ShapeBBs.count(I.MBB) && I < ShapeBBs[I.MBB].back())
+      I = ShapeBBs[I.MBB].back();
     // There're chances the MBB is sunk more than once. Record it to avoid
     // multi insert.
     if (VisitedOrInserted.insert(I).second) {
diff --git a/llvm/test/CodeGen/X86/AMX/amx-sched.ll b/llvm/test/CodeGen/X86/AMX/amx-sched.ll
--- a/llvm/test/CodeGen/X86/AMX/amx-sched.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-sched.ll
@@ -2,6 +2,7 @@
 
 define <256 x i32> @test_shape_sched(i16 %m, i16 %n, i16 %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b) nounwind {
 ; Just to make sure shape def is not scheduled across ldtilecfg.
+; CHECK-LABEL: test_shape_sched:
 ; CHECK: ldtilecfg
 ; CHECK-NOT: movw
   %c1 = bitcast <256 x i32> %c to x86_amx
@@ -12,5 +13,19 @@
   ret <256 x i32> %res
 }
 
+define <256 x i32> @test_shape_sched2(i16 %m, i16 %n, i16 %k, i8* %c, i8* %a, i8* %b) nounwind {
+; Just to make sure shape def is not scheduled across ldtilecfg.
+; CHECK-LABEL: test_shape_sched2:
+; CHECK: ldtilecfg
+; CHECK-NOT: movw
+  %aa = lshr i16 %k, 2
+  %c1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %c, i64 64)
+  %a1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %k, i8* %a, i64 64)
+  %b1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %aa, i16 %n, i8* %b, i64 64)
+  %t = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k, x86_amx %c1, x86_amx %a1, x86_amx %b1)
+  %res = bitcast x86_amx %t to <256 x i32>
+  ret <256 x i32> %res
+}
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
 declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)