diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -709,7 +709,7 @@
 public:
   X86LowerAMXCast(Function &F) : Func(F), DT(nullptr) {}
 
-  void combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
+  bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST);
   bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD);
   bool combineLdSt(SmallVectorImpl<Instruction *> &Casts);
   bool combineAMXcast(TargetLibraryInfo *TLI);
@@ -922,12 +922,12 @@
 // -->
 // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
 //                                           i64 64, x86_amx %42)
-void X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
+bool X86LowerAMXCast::combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
   Value *Tile = Cast->getOperand(0);
   // TODO: If it is cast intrinsic or phi node, we can propagate the
   // shape information through def-use chain.
   if (!isAMXIntrinsic(Tile))
-    return;
+    return false;
   auto *II = cast<IntrinsicInst>(Tile);
   // Tile is output from AMX intrinsic. The first operand of the
   // intrinsic is row, the second operand of the intrinsic is column.
@@ -942,6 +942,7 @@
   std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
   Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, std::nullopt,
                           Args);
+  return true;
 }
 
 // %65 = load <256 x i32>, <256 x i32>* %p, align 64
@@ -1006,9 +1007,10 @@
         StoreInst *Store = dyn_cast<StoreInst>(U);
         if (!Store)
           continue;
-        combineCastStore(cast<IntrinsicInst>(Cast), Store);
-        DeadStores.push_back(Store);
-        Change = true;
+        if (combineCastStore(cast<IntrinsicInst>(Cast), Store)) {
+          DeadStores.push_back(Store);
+          Change = true;
+        }
       }
       for (auto *Store : DeadStores)
         Store->eraseFromParent();
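
For context, combineCastStore turns a tile-to-vector cast that feeds a store
into a direct @llvm.x86.tilestored64.internal call, taking the row/column shape
from the AMX intrinsic that defined the tile. When the tile is not defined by
an AMX intrinsic, no shape is available and the combine bails out; with the old
void return the caller still pushed the store onto DeadStores and erased it,
deleting a store that was never replaced. Returning bool lets the caller mark
the store dead only when the rewrite actually happened. A rough IR sketch of
the pattern (value names and the tile-producing intrinsic are illustrative, not
taken from an actual test):

  ; before: the tile is cast to a vector only to be stored
  %t = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %src, i64 64)
  %v = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %t)
  store <256 x i32> %v, <256 x i32>* %p, align 64

  ; after combineCastStore succeeds: the tile is stored directly and the
  ; original cast/store pair becomes dead
  call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p, i64 64, x86_amx %t)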