diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp --- a/llvm/lib/Target/X86/X86PreTileConfig.cpp +++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp @@ -263,9 +263,10 @@ } struct BBInfo { - bool HasAMX = false; - bool HasCallBeforeAMX = false; - bool HasAMXBeforeCallInSuccs = false; + // Note: This is true only when there's no call instructions between the end + // of current BB and the AMX instructions. + bool HasAMXInSuccs = false; + bool PropagateToPred = false; MachineInstr *LastCall = nullptr; BBInfo() = default; @@ -274,13 +275,12 @@ MachineBasicBlock::iterator MII = MI ? MI->getIterator() : MBB->begin(); for (auto E = MBB->end(); MII != E; ++MII) { if (isAMXInstruction(MII)) { - HasAMX = true; if (LastCall) CfgNeedInsert.insert(LastCall); + else + PropagateToPred = true; } else if (MII->isCall()) { LastCall = &*MII; - if (!HasAMX) - HasCallBeforeAMX = true; } } } @@ -296,6 +296,13 @@ MachineBasicBlock *MBB = MI->getParent(); BBVisitedInfo[MBB] = BBInfo(CfgNeedInsert, MBB, MI); + // The entry BB is special, since it always has a ldtilecfg before AMX + // instruction. We don't need to insert another ldtilecfg even if its + // predecessor BBs have calls. + // FIXME: This case happens only when the entry BB is in a loop. We need to + // hoist the first ldtilecfg out of the loop in future. + BBVisitedInfo[MBB].PropagateToPred = false; + WorkList.push_back(MBB); while (!WorkList.empty()) { MBB = WorkList.pop_back_val(); @@ -312,18 +319,19 @@ WorkList.push_back(I.first); while (!WorkList.empty()) { MBB = WorkList.pop_back_val(); - if (BBVisitedInfo[MBB].HasCallBeforeAMX || - (!BBVisitedInfo[MBB].HasAMX && - !BBVisitedInfo[MBB].HasAMXBeforeCallInSuccs)) - continue; - for (auto I = MBB->pred_begin(), E = MBB->pred_end(); I != E; ++I) { - if (!BBVisitedInfo.count(*I) || - BBVisitedInfo[*I].HasAMXBeforeCallInSuccs) - continue; - if (BBVisitedInfo[*I].LastCall) - CfgNeedInsert.insert(BBVisitedInfo[*I].LastCall); - BBVisitedInfo[*I].HasAMXBeforeCallInSuccs = true; - WorkList.push_back(*I); + if (BBVisitedInfo[MBB].PropagateToPred) { + for (auto I = MBB->pred_begin(), E = MBB->pred_end(); I != E; ++I) { + if (!BBVisitedInfo.count(*I) || BBVisitedInfo[*I].HasAMXInSuccs) + continue; + BBVisitedInfo[*I].HasAMXInSuccs = true; + if (BBVisitedInfo[*I].LastCall) { + CfgNeedInsert.insert(BBVisitedInfo[*I].LastCall); + if (!BBVisitedInfo[*I].PropagateToPred) + continue; + } + BBVisitedInfo[*I].PropagateToPred = true; + WorkList.push_back(*I); + } } } } diff --git a/llvm/test/CodeGen/X86/AMX/amx-across-func.ll b/llvm/test/CodeGen/X86/AMX/amx-across-func.ll --- a/llvm/test/CodeGen/X86/AMX/amx-across-func.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-across-func.ll @@ -280,15 +280,14 @@ ; CHECK-NEXT: .p2align 4, 0x90 ; CHECK-NEXT: .LBB3_1: # =>This Inner Loop Header: Depth=1 ; CHECK-NEXT: callq foo -; CHECK-NEXT: movb $1, {{[0-9]+}}(%rsp) -; CHECK-NEXT: movb $8, {{[0-9]+}}(%rsp) -; CHECK-NEXT: movw $8, {{[0-9]+}}(%rsp) -; CHECK-NEXT: ldtilecfg {{[0-9]+}}(%rsp) ; CHECK-NEXT: testl %ebx, %ebx ; CHECK-NEXT: jle .LBB3_3 ; CHECK-NEXT: # %bb.2: # in Loop: Header=BB3_1 Depth=1 ; CHECK-NEXT: vpxord %zmm0, %zmm0, %zmm0 ; CHECK-NEXT: vmovdqu64 %zmm0, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $1, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb $8, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw $8, {{[0-9]+}}(%rsp) ; CHECK-NEXT: ldtilecfg {{[0-9]+}}(%rsp) ; CHECK-NEXT: tileloadd (%r14,%r15), %tmm0 ; CHECK-NEXT: movabsq $64, %rax