[X86][AMX] Add AMX-BF16 support.

Adds the __builtin_ia32_tdpbf16ps_internal builtin, the _tile_dpbf16ps
intrinsic helpers in amxintrin.h, the int_x86_tdpbf16ps_internal LLVM
intrinsic, the PTDPBF16PSV pseudo (expanded to TDPBF16PS after RA), ISel
support, tile-shape collection, and O0 scalarization of the new intrinsic
in X86LowerAMXIntrinsics (createTileDPLoops/lowerTileDP are templated over
the intrinsic ID so BSSD and BF16PS share the loop-nest builder).

Review fixes applied to the patch:
- amxintrin.h: attribute the two BF16 helpers with __DEFAULT_FN_ATTRS_BF16.
  The macro was defined but the helpers used __DEFAULT_FN_ATTRS_INT8, whose
  target("amx-int8") does not provide the amx-bf16 feature required by
  __builtin_ia32_tdpbf16ps_internal, so compilation would fail.
- X86InstrAMX.td: fixed "Pseduo" typo.
- X86LowerAMXIntrinsics.cpp: fixed grammar in an added comment.

diff --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def
--- a/clang/include/clang/Basic/BuiltinsX86_64.def
+++ b/clang/include/clang/Basic/BuiltinsX86_64.def
@@ -103,6 +103,7 @@
 // AMX internal builtin
 TARGET_BUILTIN(__builtin_ia32_tileloadd64_internal, "V256iUsUsvC*z", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tdpbssd_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
+TARGET_BUILTIN(__builtin_ia32_tdpbf16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-bf16")
 TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUs", "n", "amx-tile")
 // AMX
diff --git a/clang/lib/Headers/amxintrin.h b/clang/lib/Headers/amxintrin.h
--- a/clang/lib/Headers/amxintrin.h
+++ b/clang/lib/Headers/amxintrin.h
@@ -224,6 +224,9 @@
 #define __DEFAULT_FN_ATTRS_INT8                                                \
   __attribute__((__always_inline__, __nodebug__, __target__("amx-int8")))
 
+#define __DEFAULT_FN_ATTRS_BF16                                                \
+  __attribute__((__always_inline__, __nodebug__, __target__("amx-bf16")))
+
 typedef int _tile1024i __attribute__((__vector_size__(1024), __aligned__(64)));
 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
 _tile_loadd_internal(unsigned short m, unsigned short n, const void *base,
@@ -238,6 +241,12 @@
   return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2);
 }
 
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_BF16
+_tile_dpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                        _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2);
+}
+
 static __inline__ void __DEFAULT_FN_ATTRS_INT8
 _tile_stored_internal(unsigned short m, unsigned short n, void *base,
                       __SIZE_TYPE__ stride, _tile1024i tile) {
@@ -264,6 +273,13 @@
                                   src1.tile, src2.tile);
 }
 
+__DEFAULT_FN_ATTRS_BF16
+static void __tile_dpbf16ps(__tile1024i *dst, __tile1024i src1,
+                            __tile1024i src2) {
+  dst->tile = _tile_dpbf16ps_internal(src1.row, src2.col, src1.col, dst->tile,
+                                      src1.tile, src2.tile);
+}
+
 __DEFAULT_FN_ATTRS_TILE
 static void __tile_stored(void *base, __SIZE_TYPE__ stride, __tile1024i src) {
   _tile_stored_internal(src.row, src.col, base, stride, src.tile);
diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5053,6 +5053,12 @@
                         [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
                          llvm_x86amx_ty, llvm_x86amx_ty,
                          llvm_x86amx_ty], []>;
+  def int_x86_tdpbf16ps_internal :
+              GCCBuiltin<"__builtin_ia32_tdpbf16ps_internal">,
+              Intrinsic<[llvm_x86amx_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                         llvm_x86amx_ty, llvm_x86amx_ty,
+                         llvm_x86amx_ty], []>;
   def int_x86_tilestored64_internal :
               GCCBuiltin<"__builtin_ia32_tilestored64_internal">,
               Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty,
diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
--- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp
+++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -475,6 +475,14 @@
     MI.tieOperands(0, 1);
     return true;
   }
+  case X86::PTDPBF16PSV: {
+    MI.untieRegOperand(4);
+    for (unsigned i = 3; i > 0; --i)
+      MI.RemoveOperand(i);
+    MI.setDesc(TII->get(X86::TDPBF16PS));
+    MI.tieOperands(0, 1);
+    return true;
+  }
   case X86::PTILESTOREDV: {
     for (int i = 1; i >= 0; --i)
       MI.RemoveOperand(i);
diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -4638,6 +4638,23 @@
       ReplaceNode(Node, CNode);
       return;
     }
+    case Intrinsic::x86_tdpbf16ps_internal: {
+      if (!Subtarget->hasAMXTILE())
+        break;
+      SDValue Chain = Node->getOperand(0);
+      unsigned Opc = X86::PTDPBF16PSV;
+      SDValue Ops[] = {Node->getOperand(2),
+                       Node->getOperand(3),
+                       Node->getOperand(4),
+                       Node->getOperand(5),
+                       Node->getOperand(6),
+                       Node->getOperand(7),
+                       Chain};
+      MachineSDNode *CNode =
+          CurDAG->getMachineNode(Opc, dl, {MVT::x86amx, MVT::Other}, Ops);
+      ReplaceNode(Node, CNode);
+      return;
+    }
     case Intrinsic::x86_tilezero_internal: {
       if (!Subtarget->hasAMXTILE())
         break;
diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td
--- a/llvm/lib/Target/X86/X86InstrAMX.td
+++ b/llvm/lib/Target/X86/X86InstrAMX.td
@@ -136,5 +136,10 @@
                    [(int_x86_tdpbf16ps timm:$src1, timm:$src2,
                      timm:$src3)]>;
   }
+  // Pseudo instruction for RA.
+  let Constraints = "$src4 = $dst" in
+  def PTDPBF16PSV : PseudoI<(outs TILE: $dst), (ins GR16:$src1,
+                            GR16:$src2, GR16:$src3, TILE:$src4,
+                            TILE:$src5, TILE:$src6), []>;
 }
 } // HasAMXTILE, HasAMXBF16
diff --git a/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp b/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
@@ -22,7 +22,6 @@
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/Passes.h"
-
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/DataLayout.h"
@@ -209,11 +208,11 @@
   B.CreateStore(Elt, EltPtr);
 }
 
-static Value *createTileDPBSSDLoops(BasicBlock *Start, BasicBlock *End,
-                                    IRBuilderBase &B, DomTreeUpdater &DTU,
-                                    LoopInfo &LI, Value *Row, Value *Col,
-                                    Value *K, Value *Acc, Value *LHS,
-                                    Value *RHS) {
+template <Intrinsic::ID IntrID>
+static Value *createTileDPLoops(BasicBlock *Start, BasicBlock *End,
+                                IRBuilderBase &B, DomTreeUpdater &DTU,
+                                LoopInfo &LI, Value *Row, Value *Col, Value *K,
+                                Value *Acc, Value *LHS, Value *RHS) {
   Loop *RowLoop = LI.AllocateLoop();
   Loop *ColLoop = LI.AllocateLoop();
   Loop *InnerLoop = LI.AllocateLoop();
@@ -321,17 +320,40 @@
       B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
   Value *IdxB =
       B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
-
-  FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
-  FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
-  Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
-  Value *EltA = B.CreateExtractElement(VecA, IdxA);
-  Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
-  Value *EltB = B.CreateExtractElement(VecB, IdxB);
-  Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
-  Value *SubVecR = B.CreateAddReduce(B.CreateMul(
-      B.CreateSExt(SubVecA, V4I32Ty), B.CreateSExt(SubVecB, V4I32Ty)));
-  Value *ResElt = B.CreateAdd(EltC, SubVecR);
+  Value *ResElt = nullptr;
+  if (IntrID == Intrinsic::x86_tdpbssd_internal) {
+    FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
+    FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
+    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+    Value *EltA = B.CreateExtractElement(VecA, IdxA);
+    Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
+    Value *EltB = B.CreateExtractElement(VecB, IdxB);
+    Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
+    Value *SubVecR = B.CreateAddReduce(B.CreateMul(
+        B.CreateSExt(SubVecA, V4I32Ty), B.CreateSExt(SubVecB, V4I32Ty)));
+    ResElt = B.CreateAdd(EltC, SubVecR);
+  } else if (IntrID == Intrinsic::x86_tdpbf16ps_internal) {
+    FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
+    FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
+    Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+    Value *C_F32 = B.CreateBitCast(EltC, B.getFloatTy());
+    Value *EltA = B.CreateExtractElement(VecA, IdxA);
+    Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
+    Value *EltB = B.CreateExtractElement(VecB, IdxB);
+    Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
+    Value *ZeroV2I16 = Constant::getNullValue(V2I16Ty);
+    int ShuffleMask[4] = {2, 0, 3, 1};
+    Value *A_V2F32 = B.CreateBitCast(
+        B.CreateShuffleVector(SubVecA, ZeroV2I16, makeArrayRef(ShuffleMask)),
+        V2F32Ty);
+    Value *B_V2F32 = B.CreateBitCast(
+        B.CreateShuffleVector(SubVecB, ZeroV2I16, makeArrayRef(ShuffleMask)),
+        V2F32Ty);
+    Value *SubVecR = B.CreateFAddReduce(C_F32, B.CreateFMul(A_V2F32, B_V2F32));
+    ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
+  } else {
+    llvm_unreachable("it is not a tdpb intrinsic");
+  }
 
   Value *NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
   Value *NewVecD = B.CreateInsertElement(VecDPhi, ResElt, IdxC);
@@ -358,20 +380,20 @@
   DominatorTree *DT;
   LoopInfo *LI;
   bool lowerTileLoad(Instruction *TileLoad);
-  bool lowerTileDPBSSD(Instruction *TileDPBSSD);
+  template <Intrinsic::ID IntrID> bool lowerTileDP(Instruction *TileDP);
   bool lowerTileStore(Instruction *TileStore);
   bool lowerTileZero(Instruction *TileZero);
 };
 
-bool X86LowerAMXIntrinsics::lowerTileDPBSSD(Instruction *TileDPBSSD) {
+template <Intrinsic::ID IntrID>
+bool X86LowerAMXIntrinsics::lowerTileDP(Instruction *TileDP) {
   Value *M, *N, *K, *C, *A, *B;
-  match(TileDPBSSD, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>(
-                        m_Value(M), m_Value(N), m_Value(K), m_Value(C),
-                        m_Value(A), m_Value(B)));
+  match(TileDP, m_Intrinsic<IntrID>(m_Value(M), m_Value(N), m_Value(K),
+                                    m_Value(C), m_Value(A), m_Value(B)));
   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
-  Instruction *InsertI = TileDPBSSD;
-  IRBuilder<> BuilderPrepare(TileDPBSSD);
-  BuilderPrepare.SetInsertPoint(TileDPBSSD);
+  Instruction *InsertI = TileDP;
+  IRBuilder<> BuilderPrepare(TileDP);
+  BuilderPrepare.SetInsertPoint(TileDP);
   // We visit the loop with (m, n/4, k/4):
   // %n_dword = udiv i16 %n, 4
   // %k_dword = udiv i16 %k, 4
@@ -380,17 +402,16 @@
   BasicBlock *Start = InsertI->getParent();
   BasicBlock *End =
       SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
-  IRBuilder<> Builder(TileDPBSSD);
-  Value *ResVec = createTileDPBSSDLoops(Start, End, Builder, DTU, *LI, M,
-                                        NDWord, KDWord, C, A, B);
-  // we cannot assume there always be bitcast after tiledpbssd. So we need to
+  IRBuilder<> Builder(TileDP);
+  Value *ResVec = createTileDPLoops<IntrID>(Start, End, Builder, DTU, *LI, M,
+                                            NDWord, KDWord, C, A, B);
+  // We cannot assume there is always a bitcast after TileDP, so we need to
   // insert one bitcast as required
   Builder.SetInsertPoint(End->getFirstNonPHI());
   Value *ResAMX =
       Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
-  // Delete tiledpbssd intrinsic and do some clean-up.
-  for (auto UI = TileDPBSSD->use_begin(), UE = TileDPBSSD->use_end();
-       UI != UE;) {
+  // Delete TileDP intrinsic and do some clean-up.
+  for (auto UI = TileDP->use_begin(), UE = TileDP->use_end(); UI != UE;) {
     Instruction *I = cast<Instruction>((UI++)->getUser());
     Value *Vec;
     if (match(I, m_BitCast(m_Value(Vec)))) {
@@ -398,8 +419,8 @@
       I->eraseFromParent();
     }
   }
-  TileDPBSSD->replaceAllUsesWith(ResAMX);
-  TileDPBSSD->eraseFromParent();
+  TileDP->replaceAllUsesWith(ResAMX);
+  TileDP->eraseFromParent();
   return true;
 }
 
@@ -481,6 +502,7 @@
 bool X86LowerAMXIntrinsics::visit() {
   bool C = false;
   SmallVector<Instruction *, 8> TileDPBSSDs;
+  SmallVector<Instruction *, 8> TileDPBF16PSs;
   SmallVector<Instruction *, 8> TileLoads;
   SmallVector<Instruction *, 8> TileStores;
   SmallVector<Instruction *, 8> TileZeros;
@@ -489,6 +511,7 @@
     for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
       Instruction &Inst = *II++;
       if (match(&Inst, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>()) ||
+          match(&Inst, m_Intrinsic<Intrinsic::x86_tdpbf16ps_internal>()) ||
          match(&Inst, m_Intrinsic<Intrinsic::x86_tileloadd64_internal>()) ||
          match(&Inst, m_Intrinsic<Intrinsic::x86_tilestored64_internal>()) ||
          match(&Inst, m_Intrinsic<Intrinsic::x86_tilezero_internal>()))
@@ -504,7 +527,13 @@
       // %res = call x86_amx @llvm.x86.tdpbssd.internal(i16 m, i16 n, i16 k,
       // x86_amx, %amx1, ...)
      // %vec2 = bitcast x86_amx %res to <256 x i32>
-      C = lowerTileDPBSSD(Inst) || C;
+      C = lowerTileDP<Intrinsic::x86_tdpbssd_internal>(Inst) || C;
+    else if (match(Inst, m_Intrinsic<Intrinsic::x86_tdpbf16ps_internal>()))
+      // %amx1 = bitcast <256 x i32> %vec to x86_amx
+      // %res = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 m, i16 n, i16 k,
+      // x86_amx, %amx1, ...)
+      // %vec2 = bitcast x86_amx %res to <256 x i32>
+      C = lowerTileDP<Intrinsic::x86_tdpbf16ps_internal>(Inst) || C;
     else if (match(Inst, m_Intrinsic<Intrinsic::x86_tileloadd64_internal>()))
       // %17 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %13, i16 %14,
       // i8* %15, i64 %16)
diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -69,7 +69,8 @@
   }
   // a * b + c
   // The shape depends on which operand.
-  case Intrinsic::x86_tdpbssd_internal: {
+  case Intrinsic::x86_tdpbssd_internal:
+  case Intrinsic::x86_tdpbf16ps_internal: {
     switch (OpNo) {
     case 3:
       Row = II->getArgOperand(0);
diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp
--- a/llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -127,6 +127,7 @@
     llvm_unreachable("Unexpected machine instruction on tile");
   case X86::PTILELOADDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
     MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
@@ -221,6 +222,7 @@
   case X86::PTILELOADDV:
   case X86::PTILESTOREDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     return true;
   }
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp
--- a/llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -878,6 +878,7 @@
   // We only collect the tile shape that is defined.
   case X86::PTILELOADDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     MachineOperand &MO1 = MI->getOperand(1);
     MachineOperand &MO2 = MI->getOperand(2);