diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -74,7 +74,7 @@
          match(II, m_Intrinsic<Intrinsic::x86_cast_tile_to_vector>(m_Value()));
 }
 
-static bool isAMXInstrinsic(User *I) {
+static bool isAMXInstrinsic(Value *I) {
   auto *II = dyn_cast<IntrinsicInst>(I);
   if (!II)
     return false;
@@ -908,6 +908,58 @@
   return true;
 }
 
+// %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %42)
+// store <256 x i32> %43, <256 x i32>* %p, align 64
+// -->
+// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
+//                                           i64 64, x86_amx %42)
+static bool combineCastStore(IntrinsicInst *Cast, StoreInst *ST) {
+  Value *Tile = Cast->getOperand(0);
+  // TODO: If Tile is a cast intrinsic or a phi node, just return for now.
+  if (!isAMXInstrinsic(Tile))
+    return false;
+  auto *II = cast<IntrinsicInst>(Tile);
+  // Tile is the output of an AMX intrinsic. The first operand of the
+  // intrinsic is the row, the second operand is the column.
+  Value *Row = II->getOperand(0);
+  Value *Col = II->getOperand(1);
+  IRBuilder<> Builder(ST);
+  // Use the maximum column as stride. It must be the same as the load
+  // stride.
+  Value *Stride = Builder.getInt64(64);
+  Value *I8Ptr =
+      Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
+  std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
+  Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+  return true;
+}
+
+// %65 = load <256 x i32>, <256 x i32>* %p, align 64
+// %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
+// -->
+// %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
+//                                                   i8* %p, i64 64)
+static bool combineLoadCast(IntrinsicInst *Cast, LoadInst *LD) {
+  Value *Row = nullptr, *Col = nullptr;
+  Use &U = *(Cast->use_begin());
+  unsigned OpNo = U.getOperandNo();
+  auto *II = cast<IntrinsicInst>(U.getUser());
+  if (!isAMXInstrinsic(II))
+    return false;
+  std::tie(Row, Col) = getShape(II, OpNo);
+  IRBuilder<> Builder(Cast);
+  // Use the maximum column as stride.
+  Value *Stride = Builder.getInt64(64);
+  Value *I8Ptr =
+      Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
+  std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
+
+  Value *NewInst =
+      Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
+  Cast->replaceAllUsesWith(NewInst);
+  return true;
+}
+
 bool X86LowerAMXCast::combineAMXcast(TargetLibraryInfo *TLI) {
   bool Change = false;
   // Collect tile cast instruction.
@@ -943,6 +995,41 @@
         II->replaceAllUsesWith(Inst->getOperand(0));
         Change = true;
       }
+      IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
+      // %43 = call <256 x i32> @llvm.x86.cast.tile.to.vector(x86_amx %42)
+      // store <256 x i32> %43, <256 x i32>* %p, align 64
+      // -->
+      // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p,
+      //                                           i64 64, x86_amx %42)
+      if (II->getIntrinsicID() == Intrinsic::x86_cast_tile_to_vector) {
+        SmallVector<Instruction *, 2> DeadStores;
+        for (User *U : Inst->users()) {
+          StoreInst *Store = dyn_cast<StoreInst>(U);
+          if (!Store)
+            continue;
+          if (combineCastStore(cast<IntrinsicInst>(Inst), Store)) {
+            DeadStores.push_back(Store);
+            Change = true;
+          }
+        }
+        for (auto *Store : DeadStores)
+          Store->eraseFromParent();
+      } else { // x86_cast_vector_to_tile
+        SmallVector<Instruction *, 2> DeadLoads;
+        LoadInst *Load = dyn_cast<LoadInst>(Inst->getOperand(0));
+        if (!Load || !Load->hasOneUse())
+          continue;
+        // %65 = load <256 x i32>, <256 x i32>* %p, align 64
+        // %66 = call x86_amx @llvm.x86.cast.vector.to.tile(<256 x i32> %65)
+        // -->
+        // %66 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
+        //                                                   i8* %p, i64 64)
+        if (combineLoadCast(cast<IntrinsicInst>(Inst), Load)) {
+          // Set the operand to null so that the load instruction can be erased.
+          Inst->setOperand(0, nullptr);
+          Load->eraseFromParent();
+        }
+      }
     }
   };
 
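Illustrative note (a hypothetical IR fragment, not taken from this patch or its tests): the %row/%col shape operands of the new tilestored64/tileloadd64 calls are not invented by the combine. combineCastStore reads them from operands 0 and 1 of the AMX intrinsic that defines the tile, and combineLoadCast queries getShape() on the AMX intrinsic that consumes the cast. For example, with a tdpbssd producer feeding a cast-and-store pair:

  %c = call x86_amx @llvm.x86.tdpbssd.internal(i16 %row, i16 %col, i16 %k,
                                               x86_amx %acc, x86_amx %a, x86_amx %b)
  %v = call <256 x i32> @llvm.x86.cast.tile.to.vector.v256i32(x86_amx %c)
  store <256 x i32> %v, <256 x i32>* %p, align 64
  ; -->
  call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %p, i64 64, x86_amx %c)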