diff --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def
--- a/clang/include/clang/Basic/BuiltinsX86_64.def
+++ b/clang/include/clang/Basic/BuiltinsX86_64.def
@@ -103,6 +103,7 @@
 // AMX internal builtin
 TARGET_BUILTIN(__builtin_ia32_tileloadd64_internal, "V256iUsUsvC*z", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tdpbssd_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
+TARGET_BUILTIN(__builtin_ia32_tdpbf16ps_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-bf16")
 TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUs", "n", "amx-tile")
 // AMX
diff --git a/clang/lib/Headers/amxintrin.h b/clang/lib/Headers/amxintrin.h
--- a/clang/lib/Headers/amxintrin.h
+++ b/clang/lib/Headers/amxintrin.h
@@ -224,6 +224,9 @@
 #define __DEFAULT_FN_ATTRS_INT8                                                \
   __attribute__((__always_inline__, __nodebug__, __target__("amx-int8")))
 
+#define __DEFAULT_FN_ATTRS_BF16                                                \
+  __attribute__((__always_inline__, __nodebug__, __target__("amx-bf16")))
+
 typedef int _tile1024i __attribute__((__vector_size__(1024), __aligned__(64)));
 static __inline__ _tile1024i __DEFAULT_FN_ATTRS_INT8
 _tile_loadd_internal(unsigned short m, unsigned short n, const void *base,
@@ -238,6 +241,12 @@
   return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2);
 }
 
+static __inline__ _tile1024i __DEFAULT_FN_ATTRS_BF16
+_tile_dpbf16ps_internal(unsigned short m, unsigned short n, unsigned short k,
+                        _tile1024i dst, _tile1024i src1, _tile1024i src2) {
+  return __builtin_ia32_tdpbf16ps_internal(m, n, k, dst, src1, src2);
+}
+
 static __inline__ void __DEFAULT_FN_ATTRS_INT8
 _tile_stored_internal(unsigned short m, unsigned short n, void *base,
                       __SIZE_TYPE__ stride, _tile1024i tile) {
@@ -264,6 +273,13 @@
                                  src1.tile, src2.tile);
 }
 
+__DEFAULT_FN_ATTRS_BF16
+static void __tile_dpbf16ps(__tile1024i *dst, __tile1024i src1,
+                            __tile1024i src2) {
+  dst->tile = _tile_dpbf16ps_internal(src1.row, src2.col, src1.col, dst->tile,
+                                      src1.tile, src2.tile);
+}
+
 __DEFAULT_FN_ATTRS_TILE
 static void __tile_stored(void *base, __SIZE_TYPE__ stride, __tile1024i src) {
   _tile_stored_internal(src.row, src.col, base, stride, src.tile);
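
For reference, a minimal usage sketch of the new `__tile_dpbf16ps` helper (not part of the patch). It assumes the `__tile1024i` wrapper plus the `__tile_loadd`/`__tile_stored` helpers defined elsewhere in amxintrin.h, compilation with `-mamx-tile -mamx-bf16`, and that the caller has already loaded a matching tile configuration; the shapes, buffer names, and stride are illustrative only.

```c
#include <immintrin.h>
#include <stddef.h>

/* Hypothetical tile GEMM step: C (fp32) += A (bf16 pairs) * B (bf16 pairs).
   Each row is 64 bytes: 32 bf16 values per row in A/B, 16 fp32 in C. */
void tile_dp_bf16(const void *a, const void *b, void *c, size_t stride) {
  __tile1024i ta = {16, 64}; /* 16 rows x 64 bytes */
  __tile1024i tb = {16, 64};
  __tile1024i tc = {16, 64};

  __tile_loadd(&ta, a, stride);
  __tile_loadd(&tb, b, stride);
  __tile_loadd(&tc, c, stride);
  __tile_dpbf16ps(&tc, ta, tb); /* tc += ta * tb, accumulated in fp32 */
  __tile_stored(c, stride, tc);
}
```
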
diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td
--- a/llvm/include/llvm/IR/IntrinsicsX86.td
+++ b/llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5053,6 +5053,12 @@
                     [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
                      llvm_x86amx_ty, llvm_x86amx_ty,
                      llvm_x86amx_ty], []>;
+  def int_x86_tdpbf16ps_internal :
+              GCCBuiltin<"__builtin_ia32_tdpbf16ps_internal">,
+              Intrinsic<[llvm_x86amx_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                         llvm_x86amx_ty, llvm_x86amx_ty,
+                         llvm_x86amx_ty], []>;
   def int_x86_tilestored64_internal :
               GCCBuiltin<"__builtin_ia32_tilestored64_internal">,
               Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty,
diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
--- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp
+++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -475,6 +475,14 @@
     MI.tieOperands(0, 1);
     return true;
   }
+  case X86::PTDPBF16PSV: {
+    MI.untieRegOperand(4);
+    for (unsigned i = 3; i > 0; --i)
+      MI.RemoveOperand(i);
+    MI.setDesc(TII->get(X86::TDPBF16PS));
+    MI.tieOperands(0, 1);
+    return true;
+  }
   case X86::PTILESTOREDV: {
     for (int i = 1; i >= 0; --i)
       MI.RemoveOperand(i);
diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
--- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -4638,6 +4638,23 @@
       ReplaceNode(Node, CNode);
       return;
     }
+    case Intrinsic::x86_tdpbf16ps_internal: {
+      if (!Subtarget->hasAMXTILE())
+        break;
+      SDValue Chain = Node->getOperand(0);
+      unsigned Opc = X86::PTDPBF16PSV;
+      SDValue Ops[] = {Node->getOperand(2),
+                       Node->getOperand(3),
+                       Node->getOperand(4),
+                       Node->getOperand(5),
+                       Node->getOperand(6),
+                       Node->getOperand(7),
+                       Chain};
+      MachineSDNode *CNode =
+          CurDAG->getMachineNode(Opc, dl, {MVT::x86amx, MVT::Other}, Ops);
+      ReplaceNode(Node, CNode);
+      return;
+    }
     case Intrinsic::x86_tilezero_internal: {
       if (!Subtarget->hasAMXTILE())
         break;
diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td
--- a/llvm/lib/Target/X86/X86InstrAMX.td
+++ b/llvm/lib/Target/X86/X86InstrAMX.td
@@ -136,5 +136,10 @@
                     [(int_x86_tdpbf16ps timm:$src1, timm:$src2,
                       timm:$src3)]>;
   }
+  // Pseudo instruction for RA.
+  let Constraints = "$src4 = $dst" in
+  def PTDPBF16PSV : PseudoI<(outs TILE:$dst), (ins GR16:$src1,
+                            GR16:$src2, GR16:$src3, TILE:$src4,
+                            TILE:$src5, TILE:$src6), []>;
 }
 } // HasAMXTILE, HasAMXBF16
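
The `PTDPBF16PSV` expansion above can be pictured as the following before/after operand rewrite. This is MIR-like pseudocode with invented register names, not output from the patch:

```
; Before expansion: the RA pseudo still carries the three GR16 shape
; operands so X86PreTileConfig/X86RegisterInfo can recover the tile shapes.
%dst:tile = PTDPBF16PSV %m:gr16, %n:gr16, %k:gr16,
            %acc:tile(tied to %dst), %lhs:tile, %rhs:tile

; After expansion: operands 1-3 are removed, the accumulator is re-tied to
; the destination, and the real instruction is emitted.
%dst:tile = TDPBF16PS %acc:tile(tied to %dst), %lhs:tile, %rhs:tile
```
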
diff --git a/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp b/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
@@ -306,6 +306,111 @@
   return NewVecC;
 }
 
+static Value *createTileDPBF16PSLoops(BasicBlock *Start, BasicBlock *End,
+                                      IRBuilderBase &B, DomTreeUpdater &DTU,
+                                      LoopInfo &LI, Value *Row, Value *Col,
+                                      Value *K, Value *Acc, Value *LHS,
+                                      Value *RHS) {
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *ColLoop = LI.AllocateLoop();
+  Loop *InnerLoop = LI.AllocateLoop();
+  ColLoop->addChildLoop(InnerLoop);
+  RowLoop->addChildLoop(ColLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(RowLoop);
+  else
+    LI.addTopLevelLoop(RowLoop);
+
+  BasicBlock *RowBody =
+      createLoop(Start, End, Row, B.getInt16(1), "tiledpbf16ps.unroll.rows", B,
+                 DTU, RowLoop, LI);
+  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
+
+  BasicBlock *ColBody =
+      createLoop(RowBody, RowLatch, Col, B.getInt16(1),
+                 "tiledpbf16ps.unroll.cols", B, DTU, ColLoop, LI);
+  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
+
+  B.SetInsertPoint(ColBody->getTerminator());
+  BasicBlock *InnerBody =
+      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
+                 "tiledpbf16ps.unroll.inner", B, DTU, InnerLoop, LI);
+
+  BasicBlock *ColumnLoopHeader = ColBody->getSinglePredecessor();
+  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
+  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
+  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
+  Value *CurrentRow = &*RowLoopHeader->begin();
+  Value *CurrentCol = &*ColumnLoopHeader->begin();
+  Value *CurrentInner = &*InnerLoopHeader->begin();
+
+  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
+  Value *VecC, *VecA, *VecB;
+  if (auto *BitCast = dyn_cast<BitCastInst>(Acc))
+    VecC = BitCast->getOperand(0);
+  assert(VecC->getType()->isVectorTy() &&
+         "bitcast from non-v256i32 to x86amx");
+  // TODO: else create a bitcast from x86amx to v256i32 by storing the x86amx
+  // value to memory and reloading it as a vector. With -O0 this case does
+  // not occur.
+  if (auto *BitCast = dyn_cast<BitCastInst>(LHS))
+    VecA = BitCast->getOperand(0);
+  assert(VecA->getType()->isVectorTy() &&
+         "bitcast from non-v256i32 to x86amx");
+  if (auto *BitCast = dyn_cast<BitCastInst>(RHS))
+    VecB = BitCast->getOperand(0);
+  assert(VecB->getType()->isVectorTy() &&
+         "bitcast from non-v256i32 to x86amx");
+
+  // tiledpbf16ps.unroll.rows.header:
+  //   %vec.phi.row = phi <256 x i32> [ %VecC, %continue ],
+  //                      [ %NewVecC, %tiledpbf16ps.unroll.rows.latch ]
+  B.SetInsertPoint(RowLoopHeader->getTerminator());
+  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
+  VecCPhiRowLoop->addIncoming(VecC, Start);
+
+  // tiledpbf16ps.unroll.cols.header:
+  //   %vec.phi.col = phi <256 x i32> [ %vec.phi.row, %...rows.body ],
+  //                      [ %NewVecC, %tiledpbf16ps.unroll.cols.latch ]
+  B.SetInsertPoint(ColumnLoopHeader->getTerminator());
+  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.col");
+  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
+
+  // Generate the PHI vector for C in the inner loop header.
+  B.SetInsertPoint(InnerLoopHeader->getTerminator());
+  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
+  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);
+
+  // Generate the multiply-accumulate in the inner body.
+  B.SetInsertPoint(InnerBody->getTerminator());
+  Value *IdxC =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
+  Value *IdxA =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
+  Value *IdxB =
+      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
+
+  // Each i32 element of A and B packs two bf16 values. Widen each pair to
+  // <2 x float> with a 16-bit left shift (bf16 is the high half of an fp32),
+  // multiply, and fadd-reduce onto the fp32 element of C.
+  FixedVectorType *V2I16Ty = FixedVectorType::get(B.getInt16Ty(), 2);
+  FixedVectorType *V2I32Ty = FixedVectorType::get(B.getInt32Ty(), 2);
+  FixedVectorType *V2F32Ty = FixedVectorType::get(B.getFloatTy(), 2);
+  Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+  Value *EltCF32 = B.CreateBitCast(EltC, B.getFloatTy());
+  Value *EltA = B.CreateExtractElement(VecA, IdxA);
+  Value *SubVecA = B.CreateBitCast(EltA, V2I16Ty);
+  Value *EltB = B.CreateExtractElement(VecB, IdxB);
+  Value *SubVecB = B.CreateBitCast(EltB, V2I16Ty);
+  Value *ShiftAmt = B.CreateVectorSplat(2, B.getInt32(16));
+  Value *AV2F32 = B.CreateBitCast(
+      B.CreateShl(B.CreateZExt(SubVecA, V2I32Ty), ShiftAmt), V2F32Ty);
+  Value *BV2F32 = B.CreateBitCast(
+      B.CreateShl(B.CreateZExt(SubVecB, V2I32Ty), ShiftAmt), V2F32Ty);
+  Value *SubVecR = B.CreateFAddReduce(EltCF32, B.CreateFMul(AV2F32, BV2F32));
+  Value *ResElt = B.CreateBitCast(SubVecR, B.getInt32Ty());
+  Value *NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
+  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
+  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
+  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
+
+  return NewVecC;
+}
+
 namespace {
 class X86LowerAMXIntrinsics {
   Function &Func;
@@ -320,6 +425,7 @@
   LoopInfo *LI;
   bool lowerTileLoad(Instruction *TileLoad);
   bool lowerTileDPBSSD(Instruction *TileDPBSSD);
+  bool lowerTileDPBF16PS(Instruction *TileDPBF16PS);
   bool lowerTileStore(Instruction *TileStore);
   bool lowerTileZero(Instruction *TileZero);
 };
@@ -359,6 +465,41 @@
   return true;
 }
 
+bool X86LowerAMXIntrinsics::lowerTileDPBF16PS(Instruction *TileDPBF16PS) {
+  Value *M, *N, *K, *C, *A, *B;
+  match(TileDPBF16PS, m_Intrinsic<Intrinsic::x86_tdpbf16ps_internal>(
+                          m_Value(M), m_Value(N), m_Value(K), m_Value(C),
+                          m_Value(A), m_Value(B)));
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  Instruction *InsertI = TileDPBF16PS;
+  IRBuilder<> PreBuilder(TileDPBF16PS);
+  PreBuilder.SetInsertPoint(TileDPBF16PS);
+  // We visit the loops with (m, n/4, k/4), since the column counts are in
+  // bytes and each dword packs two bf16 values:
+  // %n_dword = udiv i16 %n, 4
+  // %k_dword = udiv i16 %k, 4
+  Value *NDWord = PreBuilder.CreateUDiv(N, PreBuilder.getInt16(4));
+  Value *KDWord = PreBuilder.CreateUDiv(K, PreBuilder.getInt16(4));
+  BasicBlock *Start = InsertI->getParent();
+  BasicBlock *End =
+      SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
+  IRBuilder<> Builder(TileDPBF16PS);
+  Value *ResVec = createTileDPBF16PSLoops(Start, End, Builder, DTU, *LI, M,
+                                          NDWord, KDWord, C, A, B);
+
+  // Replace the bitcast users of the intrinsic's x86amx result with the
+  // lowered vector, then erase the intrinsic itself.
+  for (auto UI = TileDPBF16PS->use_begin(), UE = TileDPBF16PS->use_end();
+       UI != UE;) {
+    Instruction *I = cast<Instruction>((UI++)->getUser());
+    Value *Vec;
+    if (match(I, m_BitCast(m_Value(Vec)))) {
+      I->replaceAllUsesWith(ResVec);
+      I->eraseFromParent();
+    }
+  }
+  TileDPBF16PS->eraseFromParent();
+  return true;
+}
+
 bool X86LowerAMXIntrinsics::lowerTileLoad(Instruction *TileLoad) {
   Value *M, *N, *Ptr, *Stride;
   match(TileLoad, m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
@@ -432,6 +573,7 @@
 bool X86LowerAMXIntrinsics::visit() {
   bool C = false;
   SmallVector<Instruction *, 8> TileDPBSSDs;
+  SmallVector<Instruction *, 8> TileDPBF16PSs;
   SmallVector<Instruction *, 8> TileLoads;
   SmallVector<Instruction *, 8> TileStores;
   SmallVector<Instruction *, 8> TileZeros;
@@ -446,6 +588,12 @@
         // x86_amx, %amx1, ...)
         // %vec2 = bitcast x86_amx %res to <256 x i32>
         TileDPBSSDs.push_back(&Inst);
+      } else if (match(&Inst,
+                       m_Intrinsic<Intrinsic::x86_tdpbf16ps_internal>())) {
+        // %amx1 = bitcast <256 x i32> %vec to x86_amx
+        // %res = call x86_amx @llvm.x86.tdpbf16ps.internal(i16 m, i16 n,
+        //        i16 k, x86_amx, %amx1, ...)
+        // %vec2 = bitcast x86_amx %res to <256 x i32>
+        TileDPBF16PSs.push_back(&Inst);
       } else if (match(&Inst,
                        m_Intrinsic<Intrinsic::x86_tileloadd64_internal>())) {
         // %17 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %13, i16 %14,
@@ -473,6 +621,9 @@
   for (auto *Inst : TileDPBSSDs) {
     C |= lowerTileDPBSSD(Inst);
   }
+  for (auto *Inst : TileDPBF16PSs) {
+    C |= lowerTileDPBF16PS(Inst);
+  }
   for (auto *Inst : TileStores) {
     C |= lowerTileStore(Inst);
   }
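
As a cross-check on the loop nest above, here is a scalar C sketch of the per-element computation (names invented; the real pass emits IR, not C): each i32 of A and B packs two bf16 values, each bf16 is widened to fp32 by shifting its bits into the high half of a 32-bit word, and the two products are accumulated onto the fp32 element of C.

```c
#include <stdint.h>
#include <string.h>

/* Widen one bf16 bit pattern to fp32: bf16 is the upper 16 bits of an
   IEEE-754 binary32, so a 16-bit left shift recovers the float. This mirrors
   the CreateZExt + CreateShl + CreateBitCast sequence in the pass. */
static float bf16_to_fp32(uint16_t bf) {
  uint32_t bits = (uint32_t)bf << 16;
  float f;
  memcpy(&f, &bits, sizeof(f));
  return f;
}

/* One inner-loop step: c += a.lo * b.lo + a.hi * b.hi, where a and b each
   pack two bf16 values into an i32 (mirroring the CreateFAddReduce of the
   <2 x float> product onto the scalar C element). */
static float dpbf16ps_step(float c, uint32_t a, uint32_t b) {
  float a0 = bf16_to_fp32((uint16_t)(a & 0xFFFFu));
  float a1 = bf16_to_fp32((uint16_t)(a >> 16));
  float b0 = bf16_to_fp32((uint16_t)(b & 0xFFFFu));
  float b1 = bf16_to_fp32((uint16_t)(b >> 16));
  return c + (a0 * b0 + a1 * b1);
}
```
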
diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp
--- a/llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -67,7 +67,8 @@
   }
   // a * b + c
   // The shape depends on which operand.
-  case Intrinsic::x86_tdpbssd_internal: {
+  case Intrinsic::x86_tdpbssd_internal:
+  case Intrinsic::x86_tdpbf16ps_internal: {
     switch (OpNo) {
     case 3:
       Row = II->getArgOperand(0);
diff --git a/llvm/lib/Target/X86/X86PreTileConfig.cpp b/llvm/lib/Target/X86/X86PreTileConfig.cpp
--- a/llvm/lib/Target/X86/X86PreTileConfig.cpp
+++ b/llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -127,6 +127,7 @@
     llvm_unreachable("Unexpected machine instruction on tile");
   case X86::PTILELOADDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
     MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
@@ -221,6 +222,7 @@
   case X86::PTILELOADDV:
   case X86::PTILESTOREDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     return true;
   }
diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp
--- a/llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -873,6 +873,7 @@
   // We only collect the tile shape that is defined.
   case X86::PTILELOADDV:
   case X86::PTDPBSSDV:
+  case X86::PTDPBF16PSV:
   case X86::PTILEZEROV:
     MachineOperand &MO1 = MI->getOperand(1);
     MachineOperand &MO2 = MI->getOperand(2);