diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -74,6 +74,24 @@
   match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
 }
 
+static bool isAMXIntrinsic(User *I) {
+  auto *II = dyn_cast<IntrinsicInst>(I);
+  if (!II)
+    return false;
+  if (isAMXCast(II))
+    return false;
+  // Check if the return type or a parameter is x86_amx. If it is x86_amx,
+  // the intrinsic must be an x86 amx intrinsic.
+  if (II->getType()->isX86_AMXTy())
+    return true;
+  for (Value *V : II->args()) {
+    if (V->getType()->isX86_AMXTy())
+      return true;
+  }
+
+  return false;
+}
+
 static AllocaInst *createAllocaInstAtEntry(IRBuilder<> &Builder, BasicBlock *BB,
                                            Type *Ty) {
   Function &F = *BB->getParent();
@@ -162,6 +180,36 @@
   return std::make_pair(Row, Col);
 }
 
+static std::pair<Value *, Value *> getShape(PHINode *Phi) {
+  Use &U = *(Phi->use_begin());
+  unsigned OpNo = U.getOperandNo();
+  User *V = U.getUser();
+  // TODO: We don't traverse all users. To keep the algorithm simple, we just
+  // follow the first user. If we can find the shape, return it; otherwise
+  // return nullptr and the optimization for undef/zero is abandoned.
+  while (V) {
+    if (isAMXCast(dyn_cast<Instruction>(V))) {
+      if (V->use_empty())
+        break;
+      Use &U = *(V->use_begin());
+      OpNo = U.getOperandNo();
+      V = U.getUser();
+    } else if (isAMXIntrinsic(V)) {
+      return getShape(cast<IntrinsicInst>(V), OpNo);
+    } else if (isa<PHINode>(V)) {
+      if (V->use_empty())
+        break;
+      Use &U = *(V->use_begin());
+      V = U.getUser();
+    } else {
+      break;
+    }
+  }
+
+  return std::make_pair(nullptr, nullptr);
+}
+
 namespace {
 class X86LowerAMXType {
   Function &Func;
@@ -720,11 +768,33 @@
   OldPhiNodes.insert(PN);
   while (!PhiWorklist.empty()) {
     auto *OldPN = PhiWorklist.pop_back_val();
-    for (Value *IncValue : OldPN->incoming_values()) {
+    for (unsigned I = 0; I < OldPN->getNumOperands(); ++I) {
+      Value *IncValue = OldPN->getIncomingValue(I);
       // TODO: currently, We ignore cases where it is a const. In the future, we
       // might support const.
-      if (isa<Constant>(IncValue))
-        return false;
+      if (isa<Constant>(IncValue)) {
+        auto *IncConst = dyn_cast<Constant>(IncValue);
+        if (!isa<UndefValue>(IncValue) && !IncConst->isZeroValue())
+          return false;
+        Value *Row = nullptr, *Col = nullptr;
+        std::tie(Row, Col) = getShape(OldPN);
+        // TODO: If they are not constants, Row and Col must dominate the
+        // tilezero that we are going to create.
+        if (!Row || !Col || !isa<Constant>(Row) || !isa<Constant>(Col))
+          return false;
+        // Create tilezero at the end of the incoming block.
+        auto *Block = OldPN->getIncomingBlock(I);
+        BasicBlock::iterator Iter = Block->getTerminator()->getIterator();
+        Instruction *NewInst = Builder.CreateIntrinsic(
+            Intrinsic::x86_tilezero_internal, None, {Row, Col});
+        NewInst->moveBefore(&*Iter);
+        NewInst = Builder.CreateIntrinsic(Intrinsic::x86_cast_tile_to_vector,
+                                          {IncValue->getType()}, {NewInst});
+        NewInst->moveBefore(&*Iter);
+        // Replace IncValue with the new value.
+        OldPN->setIncomingValue(I, NewInst);
+        IncValue = NewInst;
+      }
 
       if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
         if (OldPhiNodes.insert(PNode))
diff --git a/llvm/test/CodeGen/X86/AMX/amx-combine-undef.ll b/llvm/test/CodeGen/X86/AMX/amx-combine-undef.ll
--- a/llvm/test/CodeGen/X86/AMX/amx-combine-undef.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-combine-undef.ll
@@ -4,21 +4,14 @@
 define void @foo_undef(i8 *%buf) {
 ; CHECK-LABEL: @foo_undef(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP0:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
 ; CHECK-NEXT:    br i1 undef, label [[L1:%.*]], label [[L2:%.*]]
 ; CHECK:       l1:
 ; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[TMP2]], i64 32, x86_amx [[T1]])
-; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
 ; CHECK-NEXT:    br i1 undef, label [[L2]], label [[EXIT:%.*]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[T3:%.*]] = phi <256 x i32> [ undef, [[ENTRY:%.*]] ], [ [[TMP3]], [[L1]] ]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T3]], <256 x i32>* [[TMP0]], align 1024
-; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, i8* [[TMP4]], i64 32)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[BUF:%.*]], i64 1024, x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP1:%.*]] = phi x86_amx [ [[TMP0]], [[ENTRY:%.*]] ], [ [[T1]], [[L1]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[BUF:%.*]], i64 1024, x86_amx [[TMP1]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret void
@@ -44,21 +37,14 @@
 define void @foo_zero(i8 *%buf) {
 ; CHECK-LABEL: @foo_zero(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    [[TMP0:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
 ; CHECK-NEXT:    br i1 undef, label [[L1:%.*]], label [[L2:%.*]]
 ; CHECK:       l1:
 ; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <256 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[TMP2]], i64 32, x86_amx [[T1]])
-; CHECK-NEXT:    [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 1024
 ; CHECK-NEXT:    br i1 undef, label [[L2]], label [[EXIT:%.*]]
 ; CHECK:       l2:
-; CHECK-NEXT:    [[T3:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY:%.*]] ], [ [[TMP3]], [[L1]] ]
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
-; CHECK-NEXT:    store <256 x i32> [[T3]], <256 x i32>* [[TMP0]], align 1024
-; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 32, i8* [[TMP4]], i64 32)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[BUF:%.*]], i64 1024, x86_amx [[TMP5]])
+; CHECK-NEXT:    [[TMP1:%.*]] = phi x86_amx [ [[TMP0]], [[ENTRY:%.*]] ], [ [[T1]], [[L1]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[BUF:%.*]], i64 1024, x86_amx [[TMP1]])
 ; CHECK-NEXT:    br label [[EXIT]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret void
@@ -163,8 +149,8 @@
   ret void
 }
 
-define void @foo_noshape(i8 *%buf) {
-; CHECK-LABEL: @foo_noshape(
+define void @noshape(i8 *%buf) {
+; CHECK-LABEL: @noshape(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
 ; CHECK-NEXT:    br i1 undef, label [[L1:%.*]], label [[L2:%.*]]
@@ -202,6 +188,48 @@
   ret void
 }
 
+define void @noshape2(i8 *%buf) {
+; CHECK-LABEL: @noshape2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <256 x i32>, align 64
+; CHECK-NEXT:    br i1 undef, label [[L1:%.*]], label [[L2:%.*]]
+; CHECK:       l1:
+; CHECK-NEXT:    [[T1:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* [[TMP1]], i64 32, x86_amx [[T1]])
+; CHECK-NEXT:    [[TMP2:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0]], align 1024
+; CHECK-NEXT:    br i1 undef, label [[L2]], label [[EXIT:%.*]]
+; CHECK:       l2:
+; CHECK-NEXT:    [[T3:%.*]] = phi <256 x i32> [ undef, [[ENTRY:%.*]] ], [ [[TMP2]], [[L1]] ]
+; CHECK-NEXT:    [[T6:%.*]] = call <256 x i32> @llvm.abs.v256i32(<256 x i32> [[T3]], i1 true)
+; CHECK-NEXT:    [[P:%.*]] = bitcast i8* [[BUF:%.*]] to <256 x i32>*
+; CHECK-NEXT:    store <256 x i32> [[T6]], <256 x i32>* [[P]], align 1024
+; CHECK-NEXT:    br label [[EXIT]]
+; CHECK:       exit:
+; CHECK-NEXT:    ret void
+;
+entry:
+  br i1 undef, label %l1, label %l2
+
+l1:
+  %t1 = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
+  %t2 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t1)
+  br i1 undef, label %l2, label %exit
+
+l2:
+  %t3 = phi <256 x i32> [ undef, %entry ], [ %t2, %l1 ]
+  %t4 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t3)
+  %t5 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t4)
+  %t6 = call <256 x i32> @llvm.abs.v256i32(<256 x i32> %t5, i1 1)
+  %p = bitcast i8* %buf to <256 x i32>*
+  store <256 x i32> %t6, <256 x i32>* %p
+  br label %exit
+
+exit:
+  ret void
+}
+
+declare <256 x i32> @llvm.abs.v256i32(<256 x i32>, i1)
 declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
 declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
 declare <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx)
diff --git a/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll b/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
--- a/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
+++ b/llvm/test/CodeGen/X86/AMX/lat-combine-amx-bitcast.ll
@@ -187,33 +187,26 @@
 define void @fail_to_combine_amx_cast_and_phi_due_to_const_value() {
 ; CHECK-LABEL: @fail_to_combine_amx_cast_and_phi_due_to_const_value(
 ; CHECK-NEXT:  wrapper_entry:
-; CHECK-NEXT:    [[TMP0:%.*]] = alloca <110 x i32>, align 64
-; CHECK-NEXT:    [[TMP1:%.*]] = alloca <110 x i32>, align 64
-; CHECK-NEXT:    [[TMP2:%.*]] = alloca <560 x i8>, align 64
-; CHECK-NEXT:    [[TMP3:%.*]] = alloca <616 x i8>, align 64
-; CHECK-NEXT:    [[TMP4:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP0:%.*]] = alloca <560 x i8>, align 64
+; CHECK-NEXT:    [[TMP1:%.*]] = alloca <616 x i8>, align 64
+; CHECK-NEXT:    [[TMP2:%.*]] = alloca <110 x i32>, align 64
+; CHECK-NEXT:    [[TMP3:%.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 11, i16 40)
 ; CHECK-NEXT:    br i1 undef, label [[FOR_COND_CLEANUP_I_I:%.*]], label [[FOR_BODY_I_LR_PH_I:%.*]]
 ; CHECK:       for.body.i.lr.ph.i:
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <110 x i32>* [[TMP4]] to i8*
-; CHECK-NEXT:    store <110 x i32> undef, <110 x i32>* [[TMP4]], align 512
-; CHECK-NEXT:    [[TMP6:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP5]], i64 40)
-; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <616 x i8>* [[TMP3]] to i8*
-; CHECK-NEXT:    store <616 x i8> undef, <616 x i8>* [[TMP3]], align 1024
-; CHECK-NEXT:    [[TMP8:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP7]], i64 56)
-; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <560 x i8>* [[TMP2]] to i8*
-; CHECK-NEXT:    store <560 x i8> undef, <560 x i8>* [[TMP2]], align 1024
-; CHECK-NEXT:    [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP9]], i64 40)
-; CHECK-NEXT:    [[TMP11:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP6]], x86_amx [[TMP8]], x86_amx [[TMP10]])
-; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <110 x i32>* [[TMP1]] to i8*
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* [[TMP12]], i64 40, x86_amx [[TMP11]])
-; CHECK-NEXT:    [[TMP13:%.*]] = load <110 x i32>, <110 x i32>* [[TMP1]], align 512
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <110 x i32>* [[TMP2]] to i8*
+; CHECK-NEXT:    store <110 x i32> undef, <110 x i32>* [[TMP2]], align 512
+; CHECK-NEXT:    [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP4]], i64 40)
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <616 x i8>* [[TMP1]] to i8*
+; CHECK-NEXT:    store <616 x i8> undef, <616 x i8>* [[TMP1]], align 1024
+; CHECK-NEXT:    [[TMP7:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 56, i8* [[TMP6]], i64 56)
+; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <560 x i8>* [[TMP0]] to i8*
+; CHECK-NEXT:    store <560 x i8> undef, <560 x i8>* [[TMP0]], align 1024
+; CHECK-NEXT:    [[TMP9:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 14, i16 40, i8* [[TMP8]], i64 40)
+; CHECK-NEXT:    [[TMP10:%.*]] = call x86_amx @llvm.x86.tdpbssd.internal(i16 11, i16 40, i16 56, x86_amx [[TMP5]], x86_amx [[TMP7]], x86_amx [[TMP9]])
 ; CHECK-NEXT:    br label [[FOR_COND_CLEANUP_I_I]]
 ; CHECK:       for.cond.cleanup.i.i:
-; CHECK-NEXT:    [[EVILPHI:%.*]] = phi <110 x i32> [ undef, [[WRAPPER_ENTRY:%.*]] ], [ [[TMP13]], [[FOR_BODY_I_LR_PH_I]] ]
-; CHECK-NEXT:    [[TMP14:%.*]] = bitcast <110 x i32>* [[TMP0]] to i8*
-; CHECK-NEXT:    store <110 x i32> [[EVILPHI]], <110 x i32>* [[TMP0]], align 512
-; CHECK-NEXT:    [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 11, i16 40, i8* [[TMP14]], i64 40)
-; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP15]])
+; CHECK-NEXT:    [[TMP11:%.*]] = phi x86_amx [ [[TMP3]], [[WRAPPER_ENTRY:%.*]] ], [ [[TMP10]], [[FOR_BODY_I_LR_PH_I]] ]
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 11, i16 40, i8* undef, i64 undef, x86_amx [[TMP11]])
 ; CHECK-NEXT:    ret void
 ;
 wrapper_entry:
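
Note: the net effect of the patch is that a phi over a tile-sized vector with an undef or zeroinitializer incoming value no longer blocks the cast-and-phi combine; the constant is materialized as a tilezero in its incoming block. A minimal before/after IR sketch, distilled from the foo_undef test above (illustrative only, not part of the patch). Before the combine, the undef incoming value keeps the phi on <256 x i32>, so the tile has to round-trip through memory:

  l2:
    %t3 = phi <256 x i32> [ undef, %entry ], [ %t2, %l1 ]
    %t4 = call x86_amx @llvm.x86.cast.vector.to.tile.v256i32(<256 x i32> %t3)
    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* %buf, i64 1024, x86_amx %t4)

After the combine, the undef operand becomes a tilezero at the end of its incoming block, the phi is rewritten on x86_amx, and the casts and the store/load round trip disappear:

  entry:
    ; tilezero materialized at the end of the undef value's incoming block;
    ; the (i16 8, i16 32) shape is recovered by getShape(PHINode *) from the
    ; phi's first AMX user, the tilestored64 below
    %0 = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
    br i1 undef, label %l1, label %l2
  l1:
    %t1 = call x86_amx @llvm.x86.tilezero.internal(i16 8, i16 32)
    br i1 undef, label %l2, label %exit
  l2:
    %1 = phi x86_amx [ %0, %entry ], [ %t1, %l1 ]
    call void @llvm.x86.tilestored64.internal(i16 8, i16 32, i8* %buf, i64 1024, x86_amx %1)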