diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -18,6 +18,7 @@
 //
 #include "X86.h"
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/Passes.h"
@@ -81,18 +82,47 @@
   return std::make_pair(Row, Col);
 }
 
+// LD is a load instruction. Scan forward from the instruction after LD
+// and return the first instruction the load cannot sink past: one that
+// may throw, or a store that may alias the loaded location. If there is
+// no such instruction, return the user (when it is in the same basic
+// block) or the block's terminator. Sinking does not cross basic blocks.
+// TODO: support sinking across basic blocks.
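+//
+// For illustration, on hypothetical IR such as
+//   %t = load x86_amx, x86_amx* %p
+//   store i16 8, i16* %q    ; proven no-alias with %p: keep scanning
+//   store i32 8, i32* %r    ; may alias %p: sinking stops here
+// the store to %r is returned, and the rewritten tile load is
+// materialized just before it.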
+static Instruction *findSafePointToSink(Instruction &LD, Instruction &User,
+                                        AliasAnalysis *AA) {
+  Instruction *Start = &*(++LD.getIterator());
+  Instruction *End = nullptr;
+  BasicBlock *BB = LD.getParent();
+  if (LD.getParent() == User.getParent())
+    End = &User;
+  else
+    End = BB->getTerminator();
+
+  for (Instruction &Inst :
+       make_range(Start->getIterator(), End->getIterator())) {
+    if (Inst.mayThrow())
+      return &Inst;
+    if (!isa<StoreInst>(&Inst))
+      continue;
+    MemoryLocation LoadLoc = MemoryLocation::get(cast<LoadInst>(&LD));
+    MemoryLocation StoreLoc = MemoryLocation::get(cast<StoreInst>(&Inst));
+    if (!AA->isNoAlias(LoadLoc, StoreLoc))
+      return &Inst;
+  }
+
+  return End;
+}
+
 // %1 = load x86_amx, x86_amx* %0, align 64
 // %2 = call x86_amx @llvm.x86.tdpbssd.internal(%1, %1, %1, ...)
 // -->
 // %1 = call x86_amx @llvm.x86.tileloadd64.internal()
 // %2 = call x86_amx @llvm.x86.tdpbssd.internal(%1, %1, %1, ...)
-static void transformTileLoad(LoadInst *LD) {
+static void transformTileLoad(LoadInst *LD, AliasAnalysis *AA) {
   Value *Row = nullptr, *Col = nullptr;
   Use &U = *(LD->use_begin());
   unsigned OpNo = U.getOperandNo();
   auto *II = cast<IntrinsicInst>(U.getUser());
   std::tie(Row, Col) = getShape(II, OpNo);
-  IRBuilder<> Builder(LD);
+  auto *InsertPt = findSafePointToSink(*LD, *II, AA);
+  IRBuilder<> Builder(InsertPt);
   // Use the maximum column size (64 bytes) as the stride.
   Value *Stride = Builder.getInt64(64);
   Value *I8Ptr =
@@ -225,9 +255,10 @@
 namespace {
 
 class X86LowerAMXType {
   Function &Func;
+  AliasAnalysis *AA;
 
 public:
-  X86LowerAMXType(Function &F) : Func(F) {}
+  X86LowerAMXType(Function &F, AliasAnalysis *AA) : Func(F), AA(AA) {}
   bool visit();
 };
@@ -249,7 +280,7 @@
       // %0 = bitcast <256 x i32>* %tile to x86_amx*
       // %1 = load x86_amx, x86_amx* %0, align 64
       if (LD) {
-        transformTileLoad(LD);
+        transformTileLoad(LD, AA);
         DeadInsts.push_back(LD);
       }
       auto *ST = dyn_cast<StoreInst>(I);
@@ -327,12 +358,14 @@
   }
 
   bool runOnFunction(Function &F) override {
-    X86LowerAMXType LAT(F);
+    auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
+    X86LowerAMXType LAT(F, AA);
     bool C = LAT.visit();
     return C;
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<AAResultsWrapperPass>();
     AU.setPreservesCFG();
   }
 };
@@ -343,6 +376,7 @@
 char X86LowerAMXTypeLegacyPass::ID = 0;
 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, false)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false, false)
diff --git a/llvm/test/CodeGen/X86/AMX/amx-type.ll b/llvm/test/CodeGen/X86/AMX/amx-type.ll
--- a/llvm/test/CodeGen/X86/AMX/amx-type.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-type.ll
@@ -8,6 +8,194 @@
 @buf = dso_local global [1024 x i8] zeroinitializer, align 16
 @buf2 = dso_local global [1024 x i8] zeroinitializer, align 16
 
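+; All three tile loads sink to just before their tdpbssd user, where they
+; are rewritten as tileloadd64 with the shape taken from the user.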
+define dso_local void @test_tile_sink(%struct.__tile_str* %a, %struct.__tile_str* %b, %struct.__tile_str* %c) #0 {
+; CHECK-LABEL: @test_tile_sink(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[B_COLPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[B:%.*]], i64 0, i32 1
+; CHECK-NEXT: [[B_COL:%.*]] = load i16, i16* [[B_COLPTR]], align 2
+; CHECK-NEXT: [[B_VPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[B]], i64 0, i32 2
+; CHECK-NEXT: [[B_AMXPTR:%.*]] = bitcast <256 x i32>* [[B_VPTR]] to x86_amx*
+; CHECK-NEXT: [[A_ROWPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[A:%.*]], i64 0, i32 0
+; CHECK-NEXT: [[A_ROW:%.*]] = load i16, i16* [[A_ROWPTR]], align 64
+; CHECK-NEXT: [[A_COLPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[A]], i64 0, i32 1
+; CHECK-NEXT: [[A_COL:%.*]] = load i16, i16* [[A_COLPTR]], align 2
+; CHECK-NEXT: [[A_VPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[A]], i64 0, i32 2
+; CHECK-NEXT: [[A_AMXPTR:%.*]] = bitcast <256 x i32>* [[A_VPTR]] to x86_amx*
+; CHECK-NEXT: [[C_VPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[C:%.*]], i64 0, i32 2
+; CHECK-NEXT: [[C_AMXPTR:%.*]] = bitcast <256 x i32>* [[C_VPTR]] to x86_amx*
+; CHECK-NEXT: [[TMP0:%.*]] = bitcast x86_amx* [[B_AMXPTR]] to i8*
+; CHECK-NEXT: [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_COL]], i16 [[B_COL]], i8* [[TMP0]], i64 64)
+; CHECK-NEXT: [[TMP2:%.*]] = bitcast x86_amx* [[A_AMXPTR]] to i8*
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[A_COL]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast x86_amx* [[C_AMXPTR]] to i8*
+; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[B_COL]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT: [[RES:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[A_ROW]], i16 [[B_COL]], i16 [[A_COL]], x86_amx [[TMP5]], x86_amx [[TMP3]], x86_amx [[TMP1]]) [[ATTR1:#.*]]
+; CHECK-NEXT: ret void
+;
+entry:
+  %b.colptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %b, i64 0, i32 1
+  %b.col = load i16, i16* %b.colptr, align 2
+  %b.vptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %b, i64 0, i32 2
+  %b.amxptr = bitcast <256 x i32>* %b.vptr to x86_amx*
+  %b.tile = load x86_amx, x86_amx* %b.amxptr, align 64
+  %a.rowptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %a, i64 0, i32 0
+  %a.row = load i16, i16* %a.rowptr, align 64
+  %a.colptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %a, i64 0, i32 1
+  %a.col = load i16, i16* %a.colptr, align 2
+  %a.vptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %a, i64 0, i32 2
+  %a.amxptr = bitcast <256 x i32>* %a.vptr to x86_amx*
+  %a.tile = load x86_amx, x86_amx* %a.amxptr, align 64
+  %c.vptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %c, i64 0, i32 2
+  %c.amxptr = bitcast <256 x i32>* %c.vptr to x86_amx*
+  %c.tile = load x86_amx, x86_amx* %c.amxptr, align 64, !tbaa !2
+  %res = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %a.row, i16 %b.col, i16 %a.col, x86_amx %c.tile, x86_amx %a.tile, x86_amx %b.tile) #2
+  ret void
+}
+
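+; The i16 store to %b.colptr provably does not alias the 1024-byte tile at
+; %b.amxptr, so the load of %b.tile is allowed to sink past it.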
+define dso_local void @test_tile_sink_noalias(%struct.__tile_str* %a, %struct.__tile_str* %b, %struct.__tile_str* %c) #0 {
+; CHECK-LABEL: @test_tile_sink_noalias(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[B_COLPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[B:%.*]], i64 0, i32 1
+; CHECK-NEXT: [[B_COL:%.*]] = load i16, i16* [[B_COLPTR]], align 2
+; CHECK-NEXT: [[B_VPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[B]], i64 0, i32 2
+; CHECK-NEXT: [[B_AMXPTR:%.*]] = bitcast <256 x i32>* [[B_VPTR]] to x86_amx*
+; CHECK-NEXT: store i16 8, i16* [[B_COLPTR]], align 2
+; CHECK-NEXT: [[A_ROWPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[A:%.*]], i64 0, i32 0
+; CHECK-NEXT: [[A_ROW:%.*]] = load i16, i16* [[A_ROWPTR]], align 64
+; CHECK-NEXT: [[A_COLPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[A]], i64 0, i32 1
+; CHECK-NEXT: [[A_COL:%.*]] = load i16, i16* [[A_COLPTR]], align 2
+; CHECK-NEXT: [[A_VPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[A]], i64 0, i32 2
+; CHECK-NEXT: [[A_AMXPTR:%.*]] = bitcast <256 x i32>* [[A_VPTR]] to x86_amx*
+; CHECK-NEXT: [[C_VPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[C:%.*]], i64 0, i32 2
+; CHECK-NEXT: [[C_AMXPTR:%.*]] = bitcast <256 x i32>* [[C_VPTR]] to x86_amx*
+; CHECK-NEXT: [[TMP0:%.*]] = bitcast x86_amx* [[B_AMXPTR]] to i8*
+; CHECK-NEXT: [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_COL]], i16 [[B_COL]], i8* [[TMP0]], i64 64)
+; CHECK-NEXT: [[TMP2:%.*]] = bitcast x86_amx* [[A_AMXPTR]] to i8*
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[A_COL]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast x86_amx* [[C_AMXPTR]] to i8*
+; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[B_COL]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT: [[RES:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[A_ROW]], i16 [[B_COL]], i16 [[A_COL]], x86_amx [[TMP5]], x86_amx [[TMP3]], x86_amx [[TMP1]]) [[ATTR1]]
+; CHECK-NEXT: ret void
+;
+entry:
+  %b.colptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %b, i64 0, i32 1
+  %b.col = load i16, i16* %b.colptr, align 2
+  %b.vptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %b, i64 0, i32 2
+  %b.amxptr = bitcast <256 x i32>* %b.vptr to x86_amx*
+  %b.tile = load x86_amx, x86_amx* %b.amxptr, align 64
+  ; test noalias
+  store i16 8, i16* %b.colptr
+  %a.rowptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %a, i64 0, i32 0
+  %a.row = load i16, i16* %a.rowptr, align 64
+  %a.colptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %a, i64 0, i32 1
+  %a.col = load i16, i16* %a.colptr, align 2
+  %a.vptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %a, i64 0, i32 2
+  %a.amxptr = bitcast <256 x i32>* %a.vptr to x86_amx*
+  %a.tile = load x86_amx, x86_amx* %a.amxptr, align 64
+  %c.vptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %c, i64 0, i32 2
+  %c.amxptr = bitcast <256 x i32>* %c.vptr to x86_amx*
+  %c.tile = load x86_amx, x86_amx* %c.amxptr, align 64, !tbaa !2
+  %res = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %a.row, i16 %b.col, i16 %a.col, x86_amx %c.tile, x86_amx %a.tile, x86_amx %b.tile) #2
+  ret void
+}
+
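+; The i32 store through %b.scalarptr may alias the tile at %b.amxptr, so
+; the sunk load of %b.tile must be materialized before the store.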
+define dso_local void @test_tile_sink_alias(%struct.__tile_str* %a, %struct.__tile_str* %b, %struct.__tile_str* %c) #0 {
+; CHECK-LABEL: @test_tile_sink_alias(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[B_COLPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[B:%.*]], i64 0, i32 1
+; CHECK-NEXT: [[B_COL:%.*]] = load i16, i16* [[B_COLPTR]], align 2
+; CHECK-NEXT: [[B_VPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[B]], i64 0, i32 2
+; CHECK-NEXT: [[B_AMXPTR:%.*]] = bitcast <256 x i32>* [[B_VPTR]] to x86_amx*
+; CHECK-NEXT: [[A_COLPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[A:%.*]], i64 0, i32 1
+; CHECK-NEXT: [[A_COL:%.*]] = load i16, i16* [[A_COLPTR]], align 2
+; CHECK-NEXT: [[B_SCALARPTR:%.*]] = bitcast <256 x i32>* [[B_VPTR]] to i32*
+; CHECK-NEXT: [[TMP0:%.*]] = bitcast x86_amx* [[B_AMXPTR]] to i8*
+; CHECK-NEXT: [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_COL]], i16 [[B_COL]], i8* [[TMP0]], i64 64)
+; CHECK-NEXT: store i32 8, i32* [[B_SCALARPTR]], align 4
+; CHECK-NEXT: [[A_ROWPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[A]], i64 0, i32 0
+; CHECK-NEXT: [[A_ROW:%.*]] = load i16, i16* [[A_ROWPTR]], align 64
+; CHECK-NEXT: [[A_VPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[A]], i64 0, i32 2
+; CHECK-NEXT: [[A_AMXPTR:%.*]] = bitcast <256 x i32>* [[A_VPTR]] to x86_amx*
+; CHECK-NEXT: [[C_VPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[C:%.*]], i64 0, i32 2
+; CHECK-NEXT: [[C_AMXPTR:%.*]] = bitcast <256 x i32>* [[C_VPTR]] to x86_amx*
+; CHECK-NEXT: [[TMP2:%.*]] = bitcast x86_amx* [[A_AMXPTR]] to i8*
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[A_COL]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast x86_amx* [[C_AMXPTR]] to i8*
+; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[B_COL]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT: [[RES:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[A_ROW]], i16 [[B_COL]], i16 [[A_COL]], x86_amx [[TMP5]], x86_amx [[TMP3]], x86_amx [[TMP1]]) [[ATTR1]]
+; CHECK-NEXT: ret void
+;
+entry:
+  %b.colptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %b, i64 0, i32 1
+  %b.col = load i16, i16* %b.colptr, align 2
+  %b.vptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %b, i64 0, i32 2
+  %b.amxptr = bitcast <256 x i32>* %b.vptr to x86_amx*
+  %b.tile = load x86_amx, x86_amx* %b.amxptr, align 64
+  %a.colptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %a, i64 0, i32 1
+  %a.col = load i16, i16* %a.colptr, align 2
+  ; test alias
+  %b.scalarptr = bitcast <256 x i32>* %b.vptr to i32*
+  store i32 8, i32* %b.scalarptr
+  %a.rowptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %a, i64 0, i32 0
+  %a.row = load i16, i16* %a.rowptr, align 64
+  %a.vptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %a, i64 0, i32 2
+  %a.amxptr = bitcast <256 x i32>* %a.vptr to x86_amx*
+  %a.tile = load x86_amx, x86_amx* %a.amxptr, align 64
+  %c.vptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %c, i64 0, i32 2
+  %c.amxptr = bitcast <256 x i32>* %c.vptr to x86_amx*
+  %c.tile = load x86_amx, x86_amx* %c.amxptr, align 64, !tbaa !2
+  %res = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %a.row, i16 %b.col, i16 %a.col, x86_amx %c.tile, x86_amx %a.tile, x86_amx %b.tile) #2
+  ret void
+}
+
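+; The loads' user lives in %dotprod, so sinking stops at the entry block's
+; terminator: all tileloadd64 calls stay in %entry.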
+define dso_local void @test_tile_sink_across_bb(%struct.__tile_str* %a, %struct.__tile_str* %b, %struct.__tile_str* %c) #0 {
+; CHECK-LABEL: @test_tile_sink_across_bb(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[B_COLPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[B:%.*]], i64 0, i32 1
+; CHECK-NEXT: [[B_COL:%.*]] = load i16, i16* [[B_COLPTR]], align 2
+; CHECK-NEXT: [[B_VPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[B]], i64 0, i32 2
+; CHECK-NEXT: [[B_AMXPTR:%.*]] = bitcast <256 x i32>* [[B_VPTR]] to x86_amx*
+; CHECK-NEXT: [[A_COLPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[A:%.*]], i64 0, i32 1
+; CHECK-NEXT: [[A_COL:%.*]] = load i16, i16* [[A_COLPTR]], align 2
+; CHECK-NEXT: [[A_ROWPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[A]], i64 0, i32 0
+; CHECK-NEXT: [[A_ROW:%.*]] = load i16, i16* [[A_ROWPTR]], align 64
+; CHECK-NEXT: [[A_VPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[A]], i64 0, i32 2
+; CHECK-NEXT: [[A_AMXPTR:%.*]] = bitcast <256 x i32>* [[A_VPTR]] to x86_amx*
+; CHECK-NEXT: [[C_VPTR:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[C:%.*]], i64 0, i32 2
+; CHECK-NEXT: [[C_AMXPTR:%.*]] = bitcast <256 x i32>* [[C_VPTR]] to x86_amx*
+; CHECK-NEXT: [[TMP0:%.*]] = bitcast x86_amx* [[B_AMXPTR]] to i8*
+; CHECK-NEXT: [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_COL]], i16 [[B_COL]], i8* [[TMP0]], i64 64)
+; CHECK-NEXT: [[TMP2:%.*]] = bitcast x86_amx* [[A_AMXPTR]] to i8*
+; CHECK-NEXT: [[TMP3:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[A_COL]], i8* [[TMP2]], i64 64)
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast x86_amx* [[C_AMXPTR]] to i8*
+; CHECK-NEXT: [[TMP5:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[A_ROW]], i16 [[B_COL]], i8* [[TMP4]], i64 64)
+; CHECK-NEXT: br label [[DOTPROD:%.*]]
+; CHECK: dotprod:
+; CHECK-NEXT: [[RES:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[A_ROW]], i16 [[B_COL]], i16 [[A_COL]], x86_amx [[TMP5]], x86_amx [[TMP3]], x86_amx [[TMP1]]) [[ATTR1]]
+; CHECK-NEXT: ret void
+;
+entry:
+  %b.colptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %b, i64 0, i32 1
+  %b.col = load i16, i16* %b.colptr, align 2
+  %b.vptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %b, i64 0, i32 2
+  %b.amxptr = bitcast <256 x i32>* %b.vptr to x86_amx*
+  %b.tile = load x86_amx, x86_amx* %b.amxptr, align 64
+  %a.colptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %a, i64 0, i32 1
+  %a.col = load i16, i16* %a.colptr, align 2
+  %a.rowptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %a, i64 0, i32 0
+  %a.row = load i16, i16* %a.rowptr, align 64
+  %a.vptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %a, i64 0, i32 2
+  %a.amxptr = bitcast <256 x i32>* %a.vptr to x86_amx*
+  %a.tile = load x86_amx, x86_amx* %a.amxptr, align 64
+  %c.vptr = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %c, i64 0, i32 2
+  %c.amxptr = bitcast <256 x i32>* %c.vptr to x86_amx*
+  %c.tile = load x86_amx, x86_amx* %c.amxptr, align 64, !tbaa !2
+  br label %dotprod
+
+dotprod:                                          ; preds = %entry
+  %res = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %a.row, i16 %b.col, i16 %a.col, x86_amx %c.tile, x86_amx %a.tile, x86_amx %b.tile) #2
+  ret void
+}
+
 define dso_local void @test_amx_store(<256 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) #2 {
 ; CHECK-LABEL: @test_amx_store(
 ; CHECK-NEXT: entry: