diff --git a/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp b/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp --- a/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp @@ -181,11 +181,16 @@ } } -static Value *createTileDPBSSDLoops(BasicBlock *Start, BasicBlock *End, - IRBuilderBase &B, DomTreeUpdater &DTU, - LoopInfo &LI, Value *Row, Value *Col, - Value *K, Value *Acc, Value *LHS, - Value *RHS) { +template <Intrinsic::ID IntrID, + typename = typename std::enable_if< + IntrID == Intrinsic::x86_tdpbssd_internal || + IntrID == Intrinsic::x86_tdpbf16ps_internal>::type> +static Value *createTileDPLoops(BasicBlock *Start, BasicBlock *End, + IRBuilderBase &B, DomTreeUpdater &DTU, + LoopInfo &LI, Value *Row, Value *Col, Value *K, + Value *Acc, Value *LHS, Value *RHS) { + std::string IntrinName = + IntrID == Intrinsic::x86_tdpbssd_internal ? "tiledpbssd" : "tdpbf16ps"; Loop *RowLoop = LI.AllocateLoop(); Loop *ColLoop = LI.AllocateLoop(); Loop *InnerLoop = LI.AllocateLoop(); @@ -197,19 +202,19 @@ LI.addTopLevelLoop(RowLoop); BasicBlock *RowBody = - createLoop(Start, End, Row, B.getInt16(1), "tiledpbssd.scalarize.rows", B, - DTU, RowLoop, LI); + createLoop(Start, End, Row, B.getInt16(1), IntrinName + ".scalarize.rows", + B, DTU, RowLoop, LI); BasicBlock *RowLatch = RowBody->getSingleSuccessor(); BasicBlock *ColBody = createLoop(RowBody, RowLatch, Col, B.getInt16(1), - "tiledpbssd.scalarize.cols", B, DTU, ColLoop, LI); + IntrinName + ".scalarize.cols", B, DTU, ColLoop, LI); BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor(); B.SetInsertPoint(ColBody->getTerminator()); BasicBlock *InnerBody = createLoop(ColBody, ColLoopLatch, K, B.getInt16(1), - "tiledpbssd.scalarize.inner", B, DTU, InnerLoop, LI); + IntrinName + ".scalarize.inner", B, DTU, InnerLoop, LI); BasicBlock *ColLoopHeader = ColBody->getSinglePredecessor(); BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor(); @@ -273,39 +278,82 @@ PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi"); VecCPhi->addIncoming(VecCPhiColLoop, ColBody); - // tiledpbssd.scalarize.inner.body: - // calculate idxa, idxb - // %eltc 
= extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc - // %elta = extractelement <256 x i32> %veca, i16 %idxa - // %eltav4i8 = bitcast i32 %elta to <4 x i8> - // %eltb = extractelement <256 x i32> %vecb, i16 %idxb - // %eltbv4i8 = bitcast i32 %eltb to <4 x i8> - // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32> - // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32> - // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32 - // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131) - // %neweltc = add i32 %elt, %acc - // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, - // i16 %idxc - B.SetInsertPoint(InnerBody->getTerminator()); Value *IdxA = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner); Value *IdxB = B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol); - - FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4); - FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4); - Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); - Value *EltA = B.CreateExtractElement(VecA, IdxA); - Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty); - Value *EltB = B.CreateExtractElement(VecB, IdxB); - Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty); - Value *SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty); - Value *SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty); - Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB)); - Value *ResElt = B.CreateAdd(EltC, SubVecR); - Value *NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); + Value *NewVecC = nullptr; + + if (IntrID == Intrinsic::x86_tdpbssd_internal) { + // tiledpbssd.scalarize.inner.body: + // calculate idxa, idxb + // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc + // %elta = extractelement <256 x i32> %veca, i16 %idxa + // %eltav4i8 = bitcast i32 %elta to <4 x i8> + // %eltb = extractelement <256 x i32> %vecb, i16 %idxb + // %eltbv4i8 = bitcast i32 %eltb to <4 x i8> + // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x 
i32> + // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32> + // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32 + // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab) + // %neweltc = add i32 %eltc, %acc + // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc, + // i16 %idxc + FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4); + FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4); + Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); + Value *EltA = B.CreateExtractElement(VecA, IdxA); + Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty); + Value *EltB = B.CreateExtractElement(VecB, IdxB); + Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty); + Value *SEXTSubVecB = B.CreateSExt(SubVecB, V4I32Ty); + Value *SEXTSubVecA = B.CreateSExt(SubVecA, V4I32Ty); + Value *SubVecR = B.CreateAddReduce(B.CreateMul(SEXTSubVecA, SEXTSubVecB)); + Value *ResElt = B.CreateAdd(EltC, SubVecR); + NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); + } else if (IntrID == Intrinsic::x86_tdpbf16ps_internal) { + // tdpbf16ps.scalarize.inner.body: + // calculate idxa, idxb + // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc + // %eltcf32 = bitcast i32 %eltc to float + // %elta = extractelement <256 x i32> %veca, i16 %idxa + // %eltav2i16 = bitcast i32 %elta to <2 x i16> + // %eltb = extractelement <256 x i32> %vecb, i16 %idxb + // %eltbv2i16 = bitcast i32 %eltb to <2 x i16> + // %shufflea = shufflevector <2 x i16> %eltav2i16, <2 x i16> zeroinitializer, + // <4 x i32> <i32 2, i32 0, i32 3, i32 1> + // %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float> + // %shuffleb = shufflevector <2 x i16> %eltbv2i16, <2 x i16> zeroinitializer, + // <4 x i32> <i32 2, i32 0, i32 3, i32 1> + // %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float> + // %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32 + // %acc = call float + // @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab) + // %neweltc = bitcast float %acc to i32 + // %NewVecC = 
insertelement <256 x i32> %vec.c.inner.phi, 
i32 %neweltc, + // i16 %idxc + // The {2, 0, 3, 1} mask puts each bf16 in the high half of an f32 lane + // (low half zero): an exact bf16 -> f32 widening on little-endian x86. + FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2); + FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2); + Value *EltC = B.CreateExtractElement(VecCPhi, IdxC); + Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy()); + Value *EltA = B.CreateExtractElement(VecA, IdxA); + Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty); + Value *EltB = B.CreateExtractElement(VecB, IdxB); + Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty); + Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty); + int ShuffleMask[4] = {2, 0, 3, 1}; + auto ShuffleArray = makeArrayRef(ShuffleMask); + Value *AV2F32 = B.CreateBitCast( + B.CreateShuffleVector(SubVecA, ZeroV2I16, ShuffleArray), V2F32Ty); + Value *BV2F32 = B.CreateBitCast( + B.CreateShuffleVector(SubVecB, ZeroV2I16, ShuffleArray), V2F32Ty); + Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32)); + Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty()); + NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC); + } // tiledpbssd.scalarize.cols.latch: // %NewEltC = extractelement <256 x i32> %vec.c.phi.col, i16 %idxc @@ -338,19 +386,26 @@ LoopInfo *LI; template <bool IsTileLoad> bool lowerTileLoadStore(Instruction *TileLoadStore); - bool lowerTileDPBSSD(Instruction *TileDPBSSD); + template <Intrinsic::ID IntrID> + typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || + IntrID == Intrinsic::x86_tdpbf16ps_internal, + bool>::type + lowerTileDP(Instruction *TileDP); bool lowerTileZero(Instruction *TileZero); }; -bool X86LowerAMXIntrinsics::lowerTileDPBSSD(Instruction *TileDPBSSD) { +template <Intrinsic::ID IntrID> +typename std::enable_if<IntrID == Intrinsic::x86_tdpbssd_internal || + IntrID == Intrinsic::x86_tdpbf16ps_internal, + bool>::type +X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) { Value *M, *N, *K, *C, *A, *B; - match(TileDPBSSD, 
m_Intrinsic<Intrinsic::x86_tdpbssd_internal>( - m_Value(M), m_Value(N), m_Value(K), m_Value(C), - m_Value(A), m_Value(B))); + match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K), + m_Value(C), m_Value(A), m_Value(B))); DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy); - Instruction 
*InsertI = TileDPBSSD; - IRBuilder<> PreBuilder(TileDPBSSD); - PreBuilder.SetInsertPoint(TileDPBSSD); + Instruction *InsertI = TileDP; + IRBuilder<> PreBuilder(TileDP); + PreBuilder.SetInsertPoint(TileDP); // We visit the loop with (m, n/4, k/4): // %n_dword = lshr i16 %n, 2 // %k_dword = lshr i16 %k, 2 @@ -359,17 +414,16 @@ BasicBlock *Start = InsertI->getParent(); BasicBlock *End = SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue"); - IRBuilder<> Builder(TileDPBSSD); - Value *ResVec = createTileDPBSSDLoops(Start, End, Builder, DTU, *LI, M, - NDWord, KDWord, C, A, B); - // we cannot assume there always be bitcast after tiledpbssd. So we need to + IRBuilder<> Builder(TileDP); + Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, DTU, *LI, M, + NDWord, KDWord, C, A, B); + // We cannot assume there is always a bitcast after TileDP, so we need to // insert one bitcast as required Builder.SetInsertPoint(End->getFirstNonPHI()); Value *ResAMX = Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext())); - // Delete tiledpbssd intrinsic and do some clean-up. - for (auto UI = TileDPBSSD->use_begin(), UE = TileDPBSSD->use_end(); - UI != UE;) { + // Delete TileDP intrinsic and do some clean-up. 
+ for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) { Instruction *I = cast<Instruction>((UI++)->getUser()); Value *Vec; if (match(I, m_BitCast(m_Value(Vec)))) { @@ -377,8 +431,8 @@ I->eraseFromParent(); } } - TileDPBSSD->replaceAllUsesWith(ResAMX); - TileDPBSSD->eraseFromParent(); + TileDP->replaceAllUsesWith(ResAMX); + TileDP->eraseFromParent(); return true; } @@ -456,6 +510,7 @@ case Intrinsic::x86_tileloadd64_internal: case Intrinsic::x86_tilestored64_internal: case Intrinsic::x86_tilezero_internal: + case Intrinsic::x86_tdpbf16ps_internal: WorkList.push_back(Inst); break; default: @@ -468,7 +523,10 @@ for (auto *Inst : WorkList) { switch (Inst->getIntrinsicID()) { case Intrinsic::x86_tdpbssd_internal: - C = lowerTileDPBSSD(Inst) || C; + C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C; + break; + case Intrinsic::x86_tdpbf16ps_internal: + C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C; break; case Intrinsic::x86_tileloadd64_internal: C = lowerTileLoadStore<true>(Inst) || C; diff --git a/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll b/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll --- a/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll @@ -97,8 +97,8 @@ ret void } -define dso_local void @test_amx_dp(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 { -; CHECK-LABEL: @test_amx_dp( +define dso_local void @test_amx_dpbssd(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 { +; CHECK-LABEL: @test_amx_dpbssd( ; CHECK-NEXT: entry: ; CHECK-NEXT: [[A_AMX:%.*]] = bitcast <256 x i32> [[A:%.*]] to x86_amx ; CHECK-NEXT: [[B_AMX:%.*]] = bitcast <256 x i32> [[B:%.*]] to x86_amx @@ -172,6 +172,84 @@ ret void } +define dso_local void @test_amx_dpbf16ps(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x 
i32>* %vptr) #0 { +; CHECK-LABEL: @test_amx_dpbf16ps( +; 
CHECK-NEXT: entry: +; CHECK-NEXT: [[A_AMX:%.*]] = bitcast <256 x i32> [[A:%.*]] to x86_amx +; CHECK-NEXT: [[B_AMX:%.*]] = bitcast <256 x i32> [[B:%.*]] to x86_amx +; CHECK-NEXT: [[C_AMX:%.*]] = bitcast <256 x i32> [[C:%.*]] to x86_amx +; CHECK-NEXT: [[TMP0:%.*]] = lshr i16 [[COL:%.*]], 2 +; CHECK-NEXT: [[TMP1:%.*]] = lshr i16 [[K:%.*]], 2 +; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_ROWS_HEADER:%.*]] +; CHECK: tdpbf16ps.scalarize.rows.header: +; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TDPBF16PS_SCALARIZE_ROWS_STEP:%.*]], [[TDPBF16PS_SCALARIZE_ROWS_LATCH:%.*]] ] +; CHECK-NEXT: [[VEC_C_PHI_ROW:%.*]] = phi <256 x i32> [ [[C]], [[ENTRY]] ], [ [[TMP21:%.*]], [[TDPBF16PS_SCALARIZE_ROWS_LATCH]] ] +; CHECK-NEXT: [[VEC_D_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP23:%.*]], [[TDPBF16PS_SCALARIZE_ROWS_LATCH]] ] +; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_ROWS_BODY:%.*]] +; CHECK: tdpbf16ps.scalarize.rows.body: +; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_COLS_HEADER:%.*]] +; CHECK: tdpbf16ps.scalarize.cols.header: +; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TDPBF16PS_SCALARIZE_ROWS_BODY]] ], [ [[TDPBF16PS_SCALARIZE_COLS_STEP:%.*]], [[TDPBF16PS_SCALARIZE_COLS_LATCH:%.*]] ] +; CHECK-NEXT: [[VEC_C_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_ROW]], [[TDPBF16PS_SCALARIZE_ROWS_BODY]] ], [ [[TMP21]], [[TDPBF16PS_SCALARIZE_COLS_LATCH]] ] +; CHECK-NEXT: [[VEC_D_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_D_PHI_ROW]], [[TDPBF16PS_SCALARIZE_ROWS_BODY]] ], [ [[TMP23]], [[TDPBF16PS_SCALARIZE_COLS_LATCH]] ] +; CHECK-NEXT: [[TMP2:%.*]] = mul i16 [[TDPBF16PS_SCALARIZE_ROWS_IV]], 16 +; CHECK-NEXT: [[TMP3:%.*]] = add i16 [[TMP2]], [[TDPBF16PS_SCALARIZE_COLS_IV]] +; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_COLS_BODY:%.*]] +; CHECK: tdpbf16ps.scalarize.cols.body: +; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_INNER_HEADER:%.*]] +; CHECK: tdpbf16ps.scalarize.inner.header: +; CHECK-NEXT: 
[[TDPBF16PS_SCALARIZE_INNER_IV:%.*]] = phi i16 [ 0, [[TDPBF16PS_SCALARIZE_COLS_BODY]] ], [ [[TDPBF16PS_SCALARIZE_INNER_STEP:%.*]], [[TDPBF16PS_SCALARIZE_INNER_LATCH:%.*]] ] +; CHECK-NEXT: [[VEC_C_INNER_PHI:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_COL]], [[TDPBF16PS_SCALARIZE_COLS_BODY]] ], [ [[TMP21]], [[TDPBF16PS_SCALARIZE_INNER_LATCH]] ] +; CHECK-NEXT: br label [[TDPBF16PS_SCALARIZE_INNER_BODY:%.*]] +; CHECK: tdpbf16ps.scalarize.inner.body: +; CHECK-NEXT: [[TMP4:%.*]] = mul i16 [[TDPBF16PS_SCALARIZE_ROWS_IV]], 16 +; CHECK-NEXT: [[TMP5:%.*]] = add i16 [[TMP4]], [[TDPBF16PS_SCALARIZE_INNER_IV]] +; CHECK-NEXT: [[TMP6:%.*]] = mul i16 [[TDPBF16PS_SCALARIZE_INNER_IV]], 16 +; CHECK-NEXT: [[TMP7:%.*]] = add i16 [[TMP6]], [[TDPBF16PS_SCALARIZE_COLS_IV]] +; CHECK-NEXT: [[TMP8:%.*]] = extractelement <256 x i32> [[VEC_C_INNER_PHI]], i16 [[TMP3]] +; CHECK-NEXT: [[TMP9:%.*]] = bitcast i32 [[TMP8]] to float +; CHECK-NEXT: [[TMP10:%.*]] = extractelement <256 x i32> [[A]], i16 [[TMP5]] +; CHECK-NEXT: [[TMP11:%.*]] = bitcast i32 [[TMP10]] to <2 x i16> +; CHECK-NEXT: [[TMP12:%.*]] = extractelement <256 x i32> [[B]], i16 [[TMP7]] +; CHECK-NEXT: [[TMP13:%.*]] = bitcast i32 [[TMP12]] to <2 x i16> +; CHECK-NEXT: [[TMP14:%.*]] = shufflevector <2 x i16> [[TMP11]], <2 x i16> zeroinitializer, <4 x i32> <i32 2, i32 0, i32 3, i32 1> +; CHECK-NEXT: [[TMP15:%.*]] = bitcast <4 x i16> [[TMP14]] to <2 x float> +; CHECK-NEXT: [[TMP16:%.*]] = shufflevector <2 x i16> [[TMP13]], <2 x i16> zeroinitializer, <4 x i32> <i32 2, i32 0, i32 3, i32 1> +; CHECK-NEXT: [[TMP17:%.*]] = bitcast <4 x i16> [[TMP16]] to <2 x float> +; CHECK-NEXT: [[TMP18:%.*]] = fmul <2 x float> [[TMP15]], [[TMP17]] +; CHECK-NEXT: [[TMP19:%.*]] = call float @llvm.vector.reduce.fadd.v2f32(float [[TMP9]], <2 x float> [[TMP18]]) +; CHECK-NEXT: [[TMP20:%.*]] = bitcast float [[TMP19]] to i32 +; CHECK-NEXT: [[TMP21]] = insertelement <256 x i32> [[VEC_C_INNER_PHI]], i32 [[TMP20]], i16 [[TMP3]] +; CHECK-NEXT: br label 
[[TDPBF16PS_SCALARIZE_INNER_LATCH]] +; CHECK: tdpbf16ps.scalarize.inner.latch: +; 
CHECK-NEXT: [[TDPBF16PS_SCALARIZE_INNER_STEP]] = add i16 [[TDPBF16PS_SCALARIZE_INNER_IV]], 1 +; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_INNER_COND:%.*]] = icmp ne i16 [[TDPBF16PS_SCALARIZE_INNER_STEP]], [[TMP1]] +; CHECK-NEXT: br i1 [[TDPBF16PS_SCALARIZE_INNER_COND]], label [[TDPBF16PS_SCALARIZE_INNER_HEADER]], label [[TDPBF16PS_SCALARIZE_COLS_LATCH]] +; CHECK: tdpbf16ps.scalarize.cols.latch: +; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_COLS_STEP]] = add i16 [[TDPBF16PS_SCALARIZE_COLS_IV]], 1 +; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TDPBF16PS_SCALARIZE_COLS_STEP]], [[TMP0]] +; CHECK-NEXT: [[TMP22:%.*]] = extractelement <256 x i32> [[TMP21]], i16 [[TMP3]] +; CHECK-NEXT: [[TMP23]] = insertelement <256 x i32> [[VEC_D_PHI_COL]], i32 [[TMP22]], i16 [[TMP3]] +; CHECK-NEXT: br i1 [[TDPBF16PS_SCALARIZE_COLS_COND]], label [[TDPBF16PS_SCALARIZE_COLS_HEADER]], label [[TDPBF16PS_SCALARIZE_ROWS_LATCH]] +; CHECK: tdpbf16ps.scalarize.rows.latch: +; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_ROWS_STEP]] = add i16 [[TDPBF16PS_SCALARIZE_ROWS_IV]], 1 +; CHECK-NEXT: [[TDPBF16PS_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TDPBF16PS_SCALARIZE_ROWS_STEP]], [[ROW:%.*]] +; CHECK-NEXT: br i1 [[TDPBF16PS_SCALARIZE_ROWS_COND]], label [[TDPBF16PS_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]] +; CHECK: continue: +; CHECK-NEXT: [[TMP24:%.*]] = bitcast <256 x i32> [[TMP23]] to x86_amx +; CHECK-NEXT: store <256 x i32> [[TMP23]], <256 x i32>* [[VPTR:%.*]], align 64 +; CHECK-NEXT: ret void +; +entry: + %a.amx = bitcast <256 x i32> %a to x86_amx + %b.amx = bitcast <256 x i32> %b to x86_amx + %c.amx = bitcast <256 x i32> %c to x86_amx + %acc = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %row, i16 %col, i16 %k, x86_amx %c.amx, x86_amx %a.amx, x86_amx %b.amx) + %vec = bitcast x86_amx %acc to <256 x i32> + store <256 x i32> %vec, <256 x i32>* %vptr, align 64 + ret void +} + define dso_local void @test_amx_store(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr, <256 x 
i32> %vec) #0 { ; CHECK-LABEL: @test_amx_store( ; CHECK-NEXT: entry: @@ -232,6 +310,7 @@ declare x86_amx @llvm.x86.tilezero.internal(i16, i16) declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) attributes #0 = { noinline nounwind optnone }