diff --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def
--- a/clang/include/clang/Basic/BuiltinsX86_64.def
+++ b/clang/include/clang/Basic/BuiltinsX86_64.def
@@ -103,6 +103,7 @@
 // AMX internal builtin
 TARGET_BUILTIN(__builtin_ia32_tileloadd64_internal, "V256iUsUsvC*z", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tdpbssd_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
+TARGET_BUILTIN(__builtin_ia32_tdpbf16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-bf16")
 TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUs", "n", "amx-tile")
 // AMX
diff --git a/clang/lib/Headers/amxintrin.h b/clang/lib/Headers/amxintrin.h
--- a/clang/lib/Headers/amxintrin.h
+++ b/clang/lib/Headers/amxintrin.h
@@ -224,6 +224,9 @@
 #define __DEFAULT_FN_ATTRS_INT8                                                \
   __attribute__((__always_inline__, __nodebug__, __target__("amx-int8")))
 
+#define __DEFAULT_FN_ATTRS_BF16                                                \
+  __attribute__((__always_inline__, __nodebug__, __target__("amx-bf16")))
+
 typedef int _tile1024i __attribute__((__vector_size__(1024), __aligned__(64)));
 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
 _tile_loadd_internal(unsigned short m, unsigned short n, const void *base,
@@ -238,6 +241,12 @@
   return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2);
 }
 
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_BF16
+_tile_dpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                        _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2);
+}
+
 static __inline__ void __DEFAULT_FN_ATTRS_INT8
 _tile_stored_internal(unsigned short m, unsigned short n, void *base,
                       __SIZE_TYPE__ stride, _tile1024i tile) {
@@ -264,6 +273,13 @@
                                   src1.tile, src2.tile);
 }
 
+__DEFAULT_FN_ATTRS_BF16
+static void __tile_dpbf16ps(__tile1024i *dst, __tile1024i src1,
+                            __tile1024i src2) {
+  dst->tile = _tile_dpbf16ps_internal(src1.row, src2.col, src1.col, dst->tile,
+                                      src1.tile, src2.tile);
+}
+
 __DEFAULT_FN_ATTRS_TILE
 static void __tile_stored(void *base, __SIZE_TYPE__ stride, __tile1024i src) {
   _tile_stored_internal(src.row, src.col, base, stride, src.tile);
diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5053,6 +5053,12 @@
                         [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
                          llvm_x86amx_ty, llvm_x86amx_ty,
                          llvm_x86amx_ty], []>;
+  def int_x86_tdpbf16ps_internal :
+              GCCBuiltin<"__builtin_ia32_tdpbf16ps_internal">,
+              Intrinsic<[llvm_x86amx_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                         llvm_x86amx_ty, llvm_x86amx_ty,
+                         llvm_x86amx_ty], []>;
   def int_x86_tilestored64_internal :
               GCCBuiltin<"__builtin_ia32_tilestored64_internal">,
               Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty,
diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
--- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp
+++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -475,6 +475,14 @@
     MI.tieOperands(0, 1);
     return true;
   }
+  case X86::PTDPBF16PSV: {
+    MI.untieRegOperand(4);
+    for (unsigned i = 3; i > 0; --i)
+      MI.RemoveOperand(i);
+    MI.setDesc(TII->get(X86::TDPBF16PS));
+    MI.tieOperands(0, 1);
+    return true;
+  }
   case X86::PTILESTOREDV: {
     for (int i = 1; i >= 0; --i)
       MI.RemoveOperand(i);
diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -4638,6 +4638,23 @@
       ReplaceNode(Node, CNode);
       return;
     }
+    case Intrinsic::x86_tdpbf16ps_internal: {
+      if (!Subtarget->hasAMXTILE())
+        break;
+      SDValue Chain = Node->getOperand(0);
+      unsigned Opc = X86::PTDPBF16PSV;
+      SDValue Ops[] = {Node->getOperand(2),
+                       Node->getOperand(3),
+                       Node->getOperand(4),
+                       Node->getOperand(5),
+                       Node->getOperand(6),
+                       Node->getOperand(7),
+                       Chain};
+      MachineSDNode *CNode =
+          CurDAG->getMachineNode(Opc, dl, {MVT::x86amx, MVT::Other}, Ops);
+      ReplaceNode(Node, CNode);
+      return;
+    }
     case Intrinsic::x86_tilezero_internal: {
       if (!Subtarget->hasAMXTILE())
         break;
diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td
--- a/llvm/lib/Target/X86/X86InstrAMX.td
+++ b/llvm/lib/Target/X86/X86InstrAMX.td
@@ -136,5 +136,10 @@
                       [(int_x86_tdpbf16ps timm:$src1, timm:$src2, timm:$src3)]>;
   }
 
+  // Pseudo instruction for RA.
+  let Constraints = "$src4 = $dst" in
+  def PTDPBF16PSV : PseudoI<(outs TILE: $dst), (ins GR16:$src1,
+                            GR16:$src2, GR16:$src3, TILE:$src4,
+                            TILE:$src5, TILE:$src6), []>;
 }
 } // HasAMXTILE, HasAMXBF16
diff --git a/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp b/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
@@ -22,7 +22,6 @@
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/Passes.h"
-
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/DataLayout.h"
@@ -181,11 +180,16 @@
   }
 }
 
-static Value *createTileDPBSSDLoops(BasicBlock *Start, BasicBlock *End,
-                                    IRBuilderBase &B, DomTreeUpdater &DTU,
-                                    LoopInfo &LI, Value *Row, Value *Col,
-                                    Value *K, Value *Acc, Value *LHS,
-                                    Value *RHS) {
+template <Intrinsic::ID IntrID,
+          typename = typename std::enable_if<
+              IntrID == Intrinsic::x86_tdpbssd_internal ||
+              IntrID == Intrinsic::x86_tdpbf16ps_internal>::type>
+static Value *createTileDPLoops(BasicBlock *Start, BasicBlock *End,
+                                IRBuilderBase &B, DomTreeUpdater &DTU,
+                                LoopInfo &LI, Value *Row, Value *Col, Value *K,
+                                Value *Acc, Value *LHS, Value *RHS) {
+  std::string IntrinName =
+      IntrID == Intrinsic::x86_tdpbssd_internal ? "tiledpbssd" : "tdpbf16ps";
   Loop *RowLoop = LI.AllocateLoop();
   Loop *ColLoop = LI.AllocateLoop();
   Loop *InnerLoop = LI.AllocateLoop();
@@ -197,19 +201,19 @@
   LI.addTopLevelLoop(RowLoop);
 
   BasicBlock *RowBody =
-      createLoop(Start, End, Row, B.getInt16(1), "tiledpbssd.scalarize.rows", B,
-                 DTU, RowLoop, LI);
+      createLoop(Start, End, Row, B.getInt16(1), IntrinName + ".scalarize.rows",
+                 B, DTU, RowLoop, LI);
   BasicBlock *RowLatch = RowBody->getSingleSuccessor();
 
   BasicBlock *ColBody =
       createLoop(RowBody, RowLatch, Col, B.getInt16(1),
-                 "tiledpbssd.scalarize.cols", B, DTU, ColLoop, LI);
+                 IntrinName + ".scalarize.cols", B, DTU, ColLoop, LI);
   BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
 
   B.SetInsertPoint(ColBody->getTerminator());
   BasicBlock *InnerBody =
       createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
-                 "tiledpbssd.scalarize.inner", B, DTU, InnerLoop, LI);
+                 IntrinName + ".scalarize.inner", B, DTU, InnerLoop, LI);
 
   BasicBlock *ColumnLoopHeader = ColBody->getSinglePredecessor();
   BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
@@ -275,22 +279,6 @@
   PHINode *VecDPhi = B.CreatePHI(V256I32Ty, 2, "vec.d.inner.phi");
   VecDPhi->addIncoming(VecDPhiColLoop, ColBody);
 
-  // tiledpbssd.scalarize.inner.body:
-  // calculate idxa, idxb, idxc
-  //  %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
-  //  %elta = extractelement <256 x i32> %veca, i16 %idxa
-  //  %eltav4i8 = bitcast i32 %elta to <4 x i8>
-  //  %eltb = extractelement <256 x i32> %vecb, i16 %idxb
-  //  %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
-  //  %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
-  //  %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
-  //  %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
-  //  %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
-  //  %neweltc = add i32 %elt, %acc
-  //  %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
-  //  i16 %idxc
-  //  %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
-  //  i16 %idxc
   B.SetInsertPoint(InnerBody->getTerminator());
   Value *IdxC =
       B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
@@ -298,17 +286,77 @@
   Value *IdxA =
       B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
   Value *IdxB =
       B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
-
-  FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
-  FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
-  Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
-  Value *EltA = B.CreateExtractElement(VecA, IdxA);
-  Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
-  Value *EltB = B.CreateExtractElement(VecB, IdxB);
-  Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
-  Value *SubVecR = B.CreateAddReduce(B.CreateMul(
-      B.CreateSExt(SubVecA, V4I32Ty), B.CreateSExt(SubVecB, V4I32Ty)));
-  Value *ResElt = B.CreateAdd(EltC, SubVecR);
+  Value *ResElt = nullptr;
+  if (IntrID == Intrinsic::x86_tdpbssd_internal) {
+    // tiledpbssd.scalarize.inner.body:
+    // calculate idxa, idxb, idxc
+    //  %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
+    //  %elta = extractelement <256 x i32> %veca, i16 %idxa
+    //  %eltav4i8 = bitcast i32 %elta to <4 x i8>
+    //  %eltb = extractelement <256 x i32> %vecb, i16 %idxb
+    //  %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
+    //  %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
+    //  %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
+    //  %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
+    //  %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %131)
+    //  %neweltc = add i32 %elt, %acc
+    //  %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
+    //  i16 %idxc
+    //  %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
+    //  i16 %idxc
+    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
+    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
+    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+    Value *EltA = B.CreateExtractElement(VecA, IdxA);
+    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
+    Value *EltB = B.CreateExtractElement(VecB, IdxB);
+    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
+    Value *SubVecR = B.CreateAddReduce(B.CreateMul(
+        B.CreateSExt(SubVecA, V4I32Ty), B.CreateSExt(SubVecB, V4I32Ty)));
+    ResElt = B.CreateAdd(EltC, SubVecR);
+  } else if (IntrID == Intrinsic::x86_tdpbf16ps_internal) {
+    // tiledpbf16ps.scalarize.inner.body:
+    // calculate idxa, idxb, idxc
+    //  %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
+    //  %eltcf32 = bitcast i32 %eltc to float
+    //  %elta = extractelement <256 x i32> %veca, i16 %idxa
+    //  %eltav2i16 = bitcast i32 %elta to <2 x i16>
+    //  %eltb = extractelement <256 x i32> %vecb, i16 %idxb
+    //  %eltbv2i16 = bitcast i32 %eltb to <2 x i16>
+    //  %shufflea = shufflevector <2 x i16> %elta, <2 x i16> zeroinitializer,
+    //  <4 x i32> <i32 2, i32 0, i32 3, i32 1>
+    //  %eltav2f32 = bitcast <4 x i16> %shufflea to <2 x float>
+    //  %shuffleb = shufflevector <2 x i16> %eltb, <2 x i16> zeroinitializer,
+    //  <4 x i32> <i32 2, i32 0, i32 3, i32 1>
+    //  %eltbv2f32 = bitcast <4 x i16> %shuffleb to <2 x float>
+    //  %mulab = fmul <2 x float> %eltav2f32, %eltbv2f32
+    //  %acc = call float
+    //  @llvm.vector.reduce.fadd.v2f32(float %eltcf32, <2 x float> %mulab)
+    //  %neweltc = bitcast float %acc to i32
+    //  %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
+    //  i16 %idxc
+    //  %NewVecD = insertelement <256 x i32> %vec.d.inner.phi, i32 %neweltc,
+    //  i16 %idxc
+    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
+    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
+    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+    Value *C_F32 = B.CreateBitCast(EltC, B.getFloatTy());
+    Value *EltA = B.CreateExtractElement(VecA, IdxA);
+    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
+    Value *EltB = B.CreateExtractElement(VecB, IdxB);
+    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
+    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
+    int ShuffleMask[4] = {2, 0, 3, 1};
+    Value *A_V2F32 = B.CreateBitCast(
+        B.CreateShuffleVector(SubVecA, ZeroV2I16, makeArrayRef(ShuffleMask)),
+        V2F32Ty);
+    Value *B_V2F32 = B.CreateBitCast(
+        B.CreateShuffleVector(SubVecB, ZeroV2I16, makeArrayRef(ShuffleMask)),
+        V2F32Ty);
+    Value *SubVecR = B.CreateFAddReduce(C_F32, B.CreateFMul(A_V2F32, B_V2F32));
+    ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
+  } else {
+    llvm_unreachable("it is not a tdpb intrinsic");
+  }
   Value *NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
   Value *NewVecD = B.CreateInsertElement(VecDPhi, ResElt, IdxC);
@@ -340,20 +388,27 @@
                 IntrID == Intrinsic::x86_tilestored64_internal>::type>
   bool lowerTileLoadStore(Instruction *TileLoad);
   bool lowerTileLoad(Instruction *TileLoad);
-  bool lowerTileDPBSSD(Instruction *TileDPBSSD);
+  template <Intrinsic::ID IntrID,
+            typename = typename std::enable_if<
+                IntrID == Intrinsic::x86_tdpbssd_internal ||
+                IntrID == Intrinsic::x86_tdpbf16ps_internal>::type>
+  bool lowerTileDP(Instruction *TileDP);
   bool lowerTileStore(Instruction *TileStore);
   bool lowerTileZero(Instruction *TileZero);
 };
 
-bool X86LowerAMXIntrinsics::lowerTileDPBSSD(Instruction *TileDPBSSD) {
+template <Intrinsic::ID IntrID, typename>
+bool X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
   Value *M, *N, *K, *C, *A, *B;
-  match(TileDPBSSD, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>(
-                        m_Value(M), m_Value(N), m_Value(K), m_Value(C),
-                        m_Value(A), m_Value(B)));
+  match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
+                                    m_Value(C), m_Value(A), m_Value(B)));
   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
-  Instruction *InsertI = TileDPBSSD;
-  IRBuilder<> PreBuilder(TileDPBSSD);
-  PreBuilder.SetInsertPoint(TileDPBSSD);
+  Instruction *InsertI = TileDP;
+  IRBuilder<> PreBuilder(TileDP);
+  PreBuilder.SetInsertPoint(TileDP);
   // We visit the loop with (m, n/4, k/4):
   // %n_dword = udiv i16 %n, 4
   // %k_dword = udiv i16 %k, 4
@@ -362,17 +417,16 @@
   BasicBlock *Start = InsertI->getParent();
   BasicBlock *End =
       SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
-  IRBuilder<> Builder(TileDPBSSD);
-  Value *ResVec = createTileDPBSSDLoops(Start, End, Builder, DTU, *LI, M,
-                                        NDWord, KDWord, C, A, B);
-  // we cannot assume there always be bitcast after tiledpbssd. So we need to
+  IRBuilder<> Builder(TileDP);
+  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, DTU, *LI, M,
+                                            NDWord, KDWord, C, A, B);
+  // We cannot assume there is always a bitcast after the tile DP, so we need to
   // insert one bitcast as required
   Builder.SetInsertPoint(End->getFirstNonPHI());
   Value *ResAMX =
       Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
-  // Delete tiledpbssd intrinsic and do some clean-up.
-  for (auto UI = TileDPBSSD->use_begin(), UE = TileDPBSSD->use_end();
-       UI != UE;) {
+  // Delete the tile DP intrinsic and do some clean-up.
+  for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) {
     Instruction *I = cast<Instruction>((UI++)->getUser());
     Value *Vec;
     if (match(I, m_BitCast(m_Value(Vec)))) {
@@ -380,8 +434,8 @@
       I->eraseFromParent();
     }
   }
-  TileDPBSSD->replaceAllUsesWith(ResAMX);
-  TileDPBSSD->eraseFromParent();
+  TileDP->replaceAllUsesWith(ResAMX);
+  TileDP->eraseFromParent();
   return true;
 }
 
@@ -460,6 +514,7 @@
     case Intrinsic::x86_tileloadd64_internal:
     case Intrinsic::x86_tilestored64_internal:
    case Intrinsic::x86_tilezero_internal:
+    case Intrinsic::x86_tdpbf16ps_internal:
      WorkList.push_back(Inst);
      break;
    default:
@@ -472,7 +527,10 @@
   for (auto *Inst : WorkList) {
     switch (Inst->getIntrinsicID()) {
     case Intrinsic::x86_tdpbssd_internal:
-      C = lowerTileDPBSSD(Inst) || C;
+      C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
+      break;
+    case Intrinsic::x86_tdpbf16ps_internal:
+      C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
       break;
     case Intrinsic::x86_tileloadd64_internal:
       C = lowerTileLoadStore(Inst) || C;
diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -69,7 +69,8 @@
   }
   // a * b + c
   // The shape depends on which operand.
-  case Intrinsic::x86_tdpbssd_internal: {
+  case Intrinsic::x86_tdpbssd_internal:
+  case Intrinsic::x86_tdpbf16ps_internal: {
     switch (OpNo) {
     case 3:
       Row = II->getArgOperand(0);
diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp
--- a/llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -127,6 +127,7 @@
     llvm_unreachable("Unexpected machine instruction on tile");
   case X86::PTILELOADDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
     MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
@@ -221,6 +222,7 @@
   case X86::PTILELOADDV:
   case X86::PTILESTOREDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     return true;
   }
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp
--- a/llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -878,6 +878,7 @@
     // We only collect the tile shape that is defined.
   case X86::PTILELOADDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     MachineOperand &MO1 = MI->getOperand(1);
     MachineOperand &MO2 = MI->getOperand(2);
diff --git a/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll b/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll
--- a/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll
+++ b/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll
@@ -97,8 +97,8 @@
   ret void
 }
 
-define dso_local void @test_amx_dp(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 {
-; CHECK-LABEL: @test_amx_dp(
+define dso_local void @test_amx_dpbssd(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_dpbssd(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[A_AMX:%.*]] = bitcast <256 x i32> [[A:%.*]] to x86_amx
 ; CHECK-NEXT:    [[B_AMX:%.*]] = bitcast <256 x i32> [[B:%.*]] to x86_amx
@@ -172,6 +172,84 @@
   ret void
 }
 
+define dso_local void @test_amx_dpbf16ps(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_dpbf16ps(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A_AMX:%.*]] = bitcast <256 x i32> [[A:%.*]] to x86_amx
+; CHECK-NEXT:    [[B_AMX:%.*]] = bitcast <256 x i32> [[B:%.*]] to x86_amx
+; CHECK-NEXT:    [[C_AMX:%.*]] = bitcast <256 x i32> [[C:%.*]] to x86_amx
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i16 [[COL:%.*]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i16 [[K:%.*]], 2
+; CHECK-NEXT:    br label [[TDPBF16PS_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tdpbf16ps.scalarize.rows.header:
+; CHECK-NEXT:    [[TDPBF16PS_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TDPBF16PS_SCALARIZE_ROWS_STEP:%.*]], [[TDPBF16PS_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_PHI_ROW:%.*]] = phi <256 x i32> [ [[C]], [[ENTRY]] ], [ [[TMP21:%.*]], [[TDPBF16PS_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    [[VEC_D_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP22:%.*]], [[TDPBF16PS_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    br label [[TDPBF16PS_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tdpbf16ps.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TDPBF16PS_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tdpbf16ps.scalarize.cols.header:
+; CHECK-NEXT:    [[TDPBF16PS_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TDPBF16PS_SCALARIZE_ROWS_BODY]] ], [ [[TDPBF16PS_SCALARIZE_COLS_STEP:%.*]], [[TDPBF16PS_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_ROW]], [[TDPBF16PS_SCALARIZE_ROWS_BODY]] ], [ [[TMP21]], [[TDPBF16PS_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    [[VEC_D_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_D_PHI_ROW]], [[TDPBF16PS_SCALARIZE_ROWS_BODY]] ], [ [[TMP22]], [[TDPBF16PS_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    br label [[TDPBF16PS_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tdpbf16ps.scalarize.cols.body:
+; CHECK-NEXT:    br label [[TDPBF16PS_SCALARIZE_INNER_HEADER:%.*]]
+; CHECK:       tdpbf16ps.scalarize.inner.header:
+; CHECK-NEXT:    [[TDPBF16PS_SCALARIZE_INNER_IV:%.*]] = phi i16 [ 0, [[TDPBF16PS_SCALARIZE_COLS_BODY]] ], [ [[TDPBF16PS_SCALARIZE_INNER_STEP:%.*]], [[TDPBF16PS_SCALARIZE_INNER_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_INNER_PHI:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_COL]], [[TDPBF16PS_SCALARIZE_COLS_BODY]] ], [ [[TMP21]], [[TDPBF16PS_SCALARIZE_INNER_LATCH]] ]
+; CHECK-NEXT:    [[VEC_D_INNER_PHI:%.*]] = phi <256 x i32> [ [[VEC_D_PHI_COL]], [[TDPBF16PS_SCALARIZE_COLS_BODY]] ], [ [[TMP22]], [[TDPBF16PS_SCALARIZE_INNER_LATCH]] ]
+; CHECK-NEXT:    br label [[TDPBF16PS_SCALARIZE_INNER_BODY:%.*]]
+; CHECK:       tdpbf16ps.scalarize.inner.body:
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i16 [[TDPBF16PS_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP3:%.*]] = add i16 [[TMP2]], [[TDPBF16PS_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i16 [[TDPBF16PS_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP5:%.*]] = add i16 [[TMP4]], [[TDPBF16PS_SCALARIZE_INNER_IV]]
+; CHECK-NEXT:    [[TMP6:%.*]] = mul i16 [[TDPBF16PS_SCALARIZE_INNER_IV]], 16
+; CHECK-NEXT:    [[TMP7:%.*]] = add i16 [[TMP6]], [[TDPBF16PS_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <256 x i32> [[VEC_C_INNER_PHI]], i16 [[TMP3]]
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast i32 [[TMP8]] to float
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <256 x i32> [[A]], i16 [[TMP5]]
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast i32 [[TMP10]] to <2 x i16>
+; CHECK-NEXT:    [[TMP12:%.*]] = extractelement <256 x i32> [[B]], i16 [[TMP7]]
+; CHECK-NEXT:    [[TMP13:%.*]] = bitcast i32 [[TMP12]] to <2 x i16>
+; CHECK-NEXT:    [[TMP14:%.*]] = shufflevector <2 x i16> [[TMP11]], <2 x i16> zeroinitializer, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
+; CHECK-NEXT:    [[TMP15:%.*]] = bitcast <4 x i16> [[TMP14]] to <2 x float>
+; CHECK-NEXT:    [[TMP16:%.*]] = shufflevector <2 x i16> [[TMP13]], <2 x i16> zeroinitializer, <4 x i32> <i32 2, i32 0, i32 3, i32 1>
+; CHECK-NEXT:    [[TMP17:%.*]] = bitcast <4 x i16> [[TMP16]] to <2 x float>
+; CHECK-NEXT:    [[TMP18:%.*]] = fmul <2 x float> [[TMP15]], [[TMP17]]
+; CHECK-NEXT:    [[TMP19:%.*]] = call float @llvm.vector.reduce.fadd.v2f32(float [[TMP9]], <2 x float> [[TMP18]])
+; CHECK-NEXT:    [[TMP20:%.*]] = bitcast float [[TMP19]] to i32
+; CHECK-NEXT:    [[TMP21]] = insertelement <256 x i32> [[VEC_C_INNER_PHI]], i32 [[TMP20]], i16 [[TMP3]]
+; CHECK-NEXT:    [[TMP22]] = insertelement <256 x i32> [[VEC_D_INNER_PHI]], i32 [[TMP20]], i16 [[TMP3]]
+; CHECK-NEXT:    br label [[TDPBF16PS_SCALARIZE_INNER_LATCH]]
+; CHECK:       tdpbf16ps.scalarize.inner.latch:
+; CHECK-NEXT:    [[TDPBF16PS_SCALARIZE_INNER_STEP]] = add i16 [[TDPBF16PS_SCALARIZE_INNER_IV]], 1
+; CHECK-NEXT:    [[TDPBF16PS_SCALARIZE_INNER_COND:%.*]] = icmp ne i16 [[TDPBF16PS_SCALARIZE_INNER_STEP]], [[TMP1]]
+; CHECK-NEXT:    br i1 [[TDPBF16PS_SCALARIZE_INNER_COND]], label [[TDPBF16PS_SCALARIZE_INNER_HEADER]], label [[TDPBF16PS_SCALARIZE_COLS_LATCH]]
+; CHECK:       tdpbf16ps.scalarize.cols.latch:
+; CHECK-NEXT:    [[TDPBF16PS_SCALARIZE_COLS_STEP]] = add i16 [[TDPBF16PS_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TDPBF16PS_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TDPBF16PS_SCALARIZE_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT:    br i1 [[TDPBF16PS_SCALARIZE_COLS_COND]], label [[TDPBF16PS_SCALARIZE_COLS_HEADER]], label [[TDPBF16PS_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tdpbf16ps.scalarize.rows.latch:
+; CHECK-NEXT:    [[TDPBF16PS_SCALARIZE_ROWS_STEP]] = add i16 [[TDPBF16PS_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TDPBF16PS_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TDPBF16PS_SCALARIZE_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT:    br i1 [[TDPBF16PS_SCALARIZE_ROWS_COND]], label [[TDPBF16PS_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    [[TMP23:%.*]] = bitcast <256 x i32> [[TMP22]] to x86_amx
+; CHECK-NEXT:    store <256 x i32> [[TMP22]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a.amx = bitcast <256 x i32> %a to x86_amx
+  %b.amx = bitcast <256 x i32> %b to x86_amx
+  %c.amx = bitcast <256 x i32> %c to x86_amx
+  %acc = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 %row, i16 %col, i16 %k, x86_amx %c.amx, x86_amx %a.amx, x86_amx %b.amx)
+  %vec = bitcast x86_amx %acc to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
 define dso_local void @test_amx_store(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr, <256 x i32> %vec) #0 {
 ; CHECK-LABEL: @test_amx_store(
 ; CHECK-NEXT:  entry:
@@ -232,6 +310,7 @@
 declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
 declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
 declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare x86_amx @llvm.x86.tdpbf16ps.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
 declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
 
 attributes #0 = { noinline nounwind optnone }
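
Usage sketch for the new user-level API: a minimal example of how __tile_dpbf16ps from amxintrin.h could be driven. The function name, buffer layout, and shapes below are illustrative assumptions, not code from the patch; it presumes A holds 16x32 bf16, B is pre-interleaved in bf16 pairs (VNNI layout), C holds 16x16 fp32, and the OS has granted AMX state. Compile with -mamx-tile -mamx-bf16.

#include <immintrin.h>

// Hypothetical example: C (16x16 fp32) += A (16x32 bf16) * B (32x16 bf16).
void tile_dp_bf16(const void *a, const void *b, void *c) {
  __tile1024i ta = {16, 64}; // 16 rows x 64 bytes = 32 bf16 per row
  __tile1024i tb = {16, 64}; // K/2 rows of pair-interleaved bf16
  __tile1024i tc = {16, 64}; // 16 rows x 64 bytes = 16 fp32 per row
  __tile_loadd(&ta, a, 64);
  __tile_loadd(&tb, b, 64);
  __tile_loadd(&tc, c, 64);
  __tile_dpbf16ps(&tc, ta, tb); // the wrapper added by this patch
  __tile_stored(c, 64, tc);
}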
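For reference on the scalarized lowering in X86LowerAMXIntrinsics: bf16 is the upper 16 bits of an fp32 bit pattern, so the {2, 0, 3, 1} shuffle with a zero vector moves each bf16 lane into the high half of a 32-bit lane, and the bitcast to <2 x float> yields the widened values. A scalar sketch of one inner-loop element update follows; the helper names are our own (assumptions, not the pass's code).

#include <cstdint>
#include <cstring>

// Widen bf16 to fp32 by restoring the truncated low 16 bits as zeros --
// the same effect as the shuffle-with-zero trick in the lowered IR.
static float bf16_to_fp32(uint16_t bf) {
  uint32_t bits = uint32_t(bf) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// One C element update: each i32 lane of the A/B tile vectors carries two
// adjacent bf16 values; both products accumulate into the fp32 C element,
// in the same order as the ordered fadd reduction emitted by the pass.
uint32_t dpbf16ps_elt(uint32_t c, uint32_t a, uint32_t b) {
  float acc;
  std::memcpy(&acc, &c, sizeof(acc));
  acc += bf16_to_fp32(uint16_t(a)) * bf16_to_fp32(uint16_t(b));
  acc += bf16_to_fp32(uint16_t(a >> 16)) * bf16_to_fp32(uint16_t(b >> 16));
  uint32_t out;
  std::memcpy(&out, &acc, sizeof(out));
  return out;
}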