diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -74,6 +74,26 @@
          match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
 }
 
+static bool isAMXIntrinsic(User *I) {
+  auto *II = dyn_cast<IntrinsicInst>(I);
+  if (!II)
+    return false;
+  switch (II->getIntrinsicID()) {
+  default:
+    return false;
+  case Intrinsic::x86_tilezero_internal:
+  case Intrinsic::x86_tileloadd64_internal:
+  case Intrinsic::x86_tileloaddt164_internal:
+  case Intrinsic::x86_tilestored64_internal:
+  case Intrinsic::x86_tdpbssd_internal:
+  case Intrinsic::x86_tdpbsud_internal:
+  case Intrinsic::x86_tdpbusd_internal:
+  case Intrinsic::x86_tdpbuud_internal:
+  case Intrinsic::x86_tdpbf16ps_internal:
+    return true;
+  }
+}
+
 static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                            Type *Ty) {
   Function &F = *BB->getParent();
@@ -162,6 +182,34 @@
   return std::make_pair(Row, Col);
 }
 
+static std::pair<Value *, Value *> getShape(PHINode *Phi) {
+  Use &U = *(Phi->use_begin());
+  unsigned OpNo = U.getOperandNo();
+  User *V = U.getUser();
+  while (V) {
+    if (isAMXCast(dyn_cast<Instruction>(V))) {
+      if (V->use_empty())
+        break;
+      Use &U = *(V->use_begin());
+      OpNo = U.getOperandNo();
+      V = U.getUser();
+      continue;
+    } else if (isAMXIntrinsic(V)) {
+      return getShape(cast<IntrinsicInst>(V), OpNo);
+    } else if (isa<PHINode>(V)) {
+      if (V->use_empty())
+        break;
+      Use &U = *(V->use_begin());
+      V = U.getUser();
+      continue;
+    } else {
+      break;
+    }
+  }
+
+  return std::make_pair(nullptr, nullptr);
+}
+
 namespace {
 class X86LowerAMXType {
   Function &Func;
@@ -720,11 +768,33 @@
   OldPhiNodes.insert(PN);
   while (!PhiWorklist.empty()) {
     auto *OldPN = PhiWorklist.pop_back_val();
-    for (Value *IncValue : OldPN->incoming_values()) {
+    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
+      Value *IncValue = OldPN->getIncomingValue(I);
       // TODO: currently, We ignore cases where it is a const. In the future, we
       // might support const.
-      if (isa<Constant>(IncValue))
-        return false;
+      if (isa<Constant>(IncValue)) {
+        auto *IncConst = dyn_cast<Constant>(IncValue);
+        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
+          return false;
+        Value *Row = nullptr, *Col = nullptr;
+        std::tie(Row, Col) = getShape(OldPN);
+        // TODO: If it is not constant, the Row and Col must dominate the
+        // tilezero that we are going to create.
+        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
+          return false;
+        // Create tilezero at the end of incoming block.
+        auto *Block = OldPN->getIncomingBlock(I);
+        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
+        Instruction *NewInst = Builder.CreateIntrinsic(
+            Intrinsic::x86_tilezero_internal, None, {Row, Col});
+        NewInst->moveBefore(&*Iter);
+        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
+                                          {IncValue->getType()}, {NewInst});
+        NewInst->moveBefore(&*Iter);
+        // Replace IncValue with the new value.
+        OldPN->setIncomingValue(I, NewInst);
+        IncValue = NewInst;
+      }
 
       if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
        if (OldPhiNodes.insert(PNode))
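Note on the getShape(PHINode *) overload added above: it follows the phi's first use through any chain of AMX casts and nested phis until it reaches a real AMX intrinsic, then defers to the existing getShape(IntrinsicInst *, unsigned) to read the shape off that operand position. A minimal sketch of the kind of IR it is meant to resolve (hypothetical values, not taken from the tests below):

  %vec = phi <256 x i32> [ zeroinitializer, %entry ], [ %v, %l1 ]
  %t = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %vec)
  call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* %buf, i64 1024, x86_amx %t)

Starting at %vec's first use (the cast), the walker steps to the cast's first use (the tilestored64 call) and returns that intrinsic's row/column operands, i16 8 and i16 32. Only the first use is followed at each step; if the chain dead-ends, the function returns {nullptr, nullptr} and the caller bails out conservatively.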
diff --git a/llvm/test/CodeGen/X86/AMX/amx-combine-undef.ll b/llvm/test/CodeGen/X86/AMX/amx-combine-undef.ll
--- a/llvm/test/CodeGen/X86/AMX/amx-combine-undef.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-combine-undef.ll
@@ -4,21 +4,14 @@
 define void @foo_undef(i8 *%buf) {
 ; CHECK-LABEL: @foo_undef(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP0:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
 ; CHECK-NEXT:    br i1 undef, label [[L1:%.*]], label [[L2:%.*]]
 ; CHECK:       l1:
 ; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[TMP2]], i64 32, x86_amx [[T1]])
-; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
 ; CHECK-NEXT:    br i1 undef, label [[L2]], label [[EXIT:%.*]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[T3:%.*]] = phi <256 x i32> [ undef, [[ENTRY:%.*]] ], [ [[TMP3]], [[L1]] ]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T3]], <256 x i32>* [[TMP0]], align 1024
-; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, i8* [[TMP4]], i64 32)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[BUF:%.*]], i64 1024, x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP1:%.*]] = phi x86_amx [ [[TMP0]], [[ENTRY:%.*]] ], [ [[T1]], [[L1]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[BUF:%.*]], i64 1024, x86_amx [[TMP1]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret void
@@ -44,21 +37,14 @@
 define void @foo_zero(i8 *%buf) {
 ; CHECK-LABEL: @foo_zero(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP0:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
 ; CHECK-NEXT:    br i1 undef, label [[L1:%.*]], label [[L2:%.*]]
 ; CHECK:       l1:
 ; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[TMP2]], i64 32, x86_amx [[T1]])
-; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
 ; CHECK-NEXT:    br i1 undef, label [[L2]], label [[EXIT:%.*]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[T3:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY:%.*]] ], [ [[TMP3]], [[L1]] ]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T3]], <256 x i32>* [[TMP0]], align 1024
-; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, i8* [[TMP4]], i64 32)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[BUF:%.*]], i64 1024, x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP1:%.*]] = phi x86_amx [ [[TMP0]], [[ENTRY:%.*]] ], [ [[T1]], [[L1]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[BUF:%.*]], i64 1024, x86_amx [[TMP1]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret void
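The two tests above show the intended effect: an undef or zeroinitializer incoming value no longer pins the phi to the <256 x i32> domain, which previously forced a tilestored64/tileloadd64 round trip through stack allocas. Instead, a tilezero is materialized at the end of the incoming block and the phi is rewritten over x86_amx. Roughly, for @foo_zero (a simplified sketch; %z, %p, and %reload are hypothetical names):

  ; before: vector phi keeps the value out of the AMX domain
  %t3 = phi <256 x i32> [ zeroinitializer, %entry ], [ %reload, %l1 ]
  ; after: the constant incoming value becomes a tilezero placed in %entry
  %z = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
  %p = phi x86_amx [ %z, %entry ], [ %t1, %l1 ]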
diff --git a/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
--- a/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
+++ b/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
@@ -187,33 +187,26 @@
 define void @fail_to_combine_amx_cast_and_phi_due_to_const_value() {
 ; CHECK-LABEL: @fail_to_combine_amx_cast_and_phi_due_to_const_value(
 ; CHECK-NEXT:  wrapper_entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = alloca <110 x i32>, align 64
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <110 x i32>, align 64
-; CHECK-NEXT:    [[TMP2:%.*]] = alloca <560 x i8>, align 64
-; CHECK-NEXT:    [[TMP3:%.*]] = alloca <616 x i8>, align 64
-; CHECK-NEXT:    [[TMP4:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT:    [[TMP2:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 11, i16 40)
 ; CHECK-NEXT:    br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
 ; CHECK:       for.body.i.lr.ph.i:
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <110 x i32>* [[TMP4]] to i8*
-; CHECK-NEXT:    store <110 x i32> undef, <110 x i32>* [[TMP4]], align 512
-; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP5]], i64 40)
-; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <616 x i8>* [[TMP3]] to i8*
-; CHECK-NEXT:    store <616 x i8> undef, <616 x i8>* [[TMP3]], align 1024
-; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP7]], i64 56)
-; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <560 x i8>* [[TMP2]] to i8*
-; CHECK-NEXT:    store <560 x i8> undef, <560 x i8>* [[TMP2]], align 1024
-; CHECK-NEXT:    [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP9]], i64 40)
-; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP6]], x86_amx [[TMP8]], x86_amx [[TMP10]])
-; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <110 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP12]], i64 40, x86_amx [[TMP11]])
-; CHECK-NEXT:    [[TMP13:%.*]] = load <110 x i32>, <110 x i32>* [[TMP1]], align 512
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <110 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT:    store <110 x i32> undef, <110 x i32>* [[TMP2]], align 512
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP4]], i64 40)
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <616 x i8>* [[TMP1]] to i8*
+; CHECK-NEXT:    store <616 x i8> undef, <616 x i8>* [[TMP1]], align 1024
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP6]], i64 56)
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <560 x i8>* [[TMP0]] to i8*
+; CHECK-NEXT:    store <560 x i8> undef, <560 x i8>* [[TMP0]], align 1024
+; CHECK-NEXT:    [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP8]], i64 40)
+; CHECK-NEXT:    [[TMP10:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP5]], x86_amx [[TMP7]], x86_amx [[TMP9]])
 ; CHECK-NEXT:    br label [[FOR_COND_CLEANUP_I_I]]
 ; CHECK:       for.cond.cleanup.i.i:
-; CHECK-NEXT:    [[EVILPHI:%.*]] = phi <110 x i32> [ undef, [[WRAPPER_ENTRY:%.*]] ], [ [[TMP13]], [[FOR_BODY_I_LR_PH_I]] ]
-; CHECK-NEXT:    [[TMP14:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8*
-; CHECK-NEXT:    store <110 x i32> [[EVILPHI]], <110 x i32>* [[TMP0]], align 512
-; CHECK-NEXT:    [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP14]], i64 40)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP15]])
+; CHECK-NEXT:    [[TMP11:%.*]] = phi x86_amx [ [[TMP3]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP10]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP11]])
 ; CHECK-NEXT:    ret void
 ;
 wrapper_entry:
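As the TODO in the C++ hunk notes, the transform is limited to compile-time-constant shapes: the tilezero is created at the end of the incoming block, and non-constant Row/Col values would have to dominate that insertion point, which is not yet verified. A case like the following sketch (hypothetical, with run-time shape values) therefore still makes the combine return false:

  ; getShape(OldPN) reaches the tilestored64 call and returns %row/%col,
  ; which are not Constants, so the phi is left in the vector domain.
  %p = phi <110 x i32> [ undef, %entry ], [ %v, %body ]
  %t = call x86_amx @llvm.x86.cast.vector.to.tile.v110i32(<110 x i32> %p)
  call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %buf, i64 %stride, x86_amx %t)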