diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h
--- a/llvm/include/llvm/CodeGen/Passes.h
+++ b/llvm/include/llvm/CodeGen/Passes.h
@@ -489,6 +489,8 @@
 /// The pass transform load/store <256 x i32> to AMX load/store intrinsics
 /// or split the data to two <128 x i32>.
 FunctionPass *createX86LowerAMXTypePass();
+
+FunctionPass *createX86LowerAMXIntrinsicsPass();
 } // End llvm namespace
 
 #endif
diff --git a/llvm/lib/Target/X86/CMakeLists.txt b/llvm/lib/Target/X86/CMakeLists.txt
--- a/llvm/lib/Target/X86/CMakeLists.txt
+++ b/llvm/lib/Target/X86/CMakeLists.txt
@@ -33,6 +33,7 @@
   X86DomainReassignment.cpp
   X86DiscriminateMemOps.cpp
   X86LowerAMXType.cpp
+  X86LowerAMXIntrinsics.cpp
   X86TileConfig.cpp
   X86PreTileConfig.cpp
   X86ExpandPseudo.cpp
diff --git a/llvm/lib/Target/X86/X86.h b/llvm/lib/Target/X86/X86.h
--- a/llvm/lib/Target/X86/X86.h
+++ b/llvm/lib/Target/X86/X86.h
@@ -169,6 +169,7 @@
 void initializeX86PreTileConfigPass(PassRegistry &);
 void initializeX86TileConfigPass(PassRegistry &);
 void initializeX86LowerAMXTypeLegacyPassPass(PassRegistry &);
+void initializeX86LowerAMXIntrinsicsLegacyPassPass(PassRegistry &);
 
 namespace X86AS {
 enum : unsigned {
diff --git a/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp b/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
@@ -0,0 +1,380 @@
+//===- X86LowerAMXIntrinsics.cpp - Lower AMX intrinsics ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file Pass to transform AMX intrinsics to scalar operations.
+/// This pass is only enabled with -O0. At -O0, the definitions of the tile
+/// shapes sit right next to the AMX intrinsics that use them, so we are not
+/// able to find a point that post-dominates all the shape definitions and
+/// dominates all the AMX intrinsics. To decouple the dependency on the shape,
+/// we transform the AMX intrinsics into scalar operations so that compilation
+/// does not fail. In the long term, we should improve fast register allocation
+/// so that it can allocate AMX registers.
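+///
+/// As an illustrative sketch (the exact expected output is in
+/// llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll), a tile load such as
+///
+///   %amx = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
+///                                                      i8* %ptr, i64 %stride)
+///   %vec = bitcast x86_amx %amx to <256 x i32>
+///
+/// is rewritten into a rows/cols loop nest that loads one i32 element per
+/// iteration and inserts it into a <256 x i32> vector, while a call to
+/// llvm.x86.tdpbssd.internal becomes a three-level loop nest that performs
+/// the <4 x i8> multiply-add for each dword of the accumulator tile.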
+//===----------------------------------------------------------------------===//
+//
+#include "X86.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsX86.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "lower-amx-intrinsics"
+
+static BasicBlock *CreateLoop(BasicBlock *Preheader, BasicBlock *Exit,
+                              Value *Bound, Value *Step, StringRef Name,
+                              IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
+                              LoopInfo &LI) {
+  LLVMContext &Ctx = Preheader->getContext();
+  BasicBlock *Header = BasicBlock::Create(
+      Preheader->getContext(), Name + ".header", Preheader->getParent(), Exit);
+  BasicBlock *Body = BasicBlock::Create(Header->getContext(), Name + ".body",
+                                        Header->getParent(), Exit);
+  BasicBlock *Latch = BasicBlock::Create(Header->getContext(), Name + ".latch",
+                                         Header->getParent(), Exit);
+
+  Type *I16Ty = Type::getInt16Ty(Ctx);
+  BranchInst::Create(Body, Header);
+  BranchInst::Create(Latch, Body);
+  PHINode *IV =
+      PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator());
+  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);
+
+  B.SetInsertPoint(Latch);
+  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
+  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
+  BranchInst::Create(Header, Exit, Cond, Latch);
+  IV->addIncoming(Inc, Latch);
+
+  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
+  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
+  PreheaderBr->setSuccessor(0, Header);
+  DTU.applyUpdatesPermissive({
+      {DominatorTree::Delete, Preheader, Tmp},
+      {DominatorTree::Insert, Header, Body},
+      {DominatorTree::Insert, Body, Latch},
+      {DominatorTree::Insert, Latch, Header},
+      {DominatorTree::Insert, Latch, Exit},
+      {DominatorTree::Insert, Preheader, Header},
+  });
+
+  L->addBasicBlockToLoop(Header, LI);
+  L->addBasicBlockToLoop(Body, LI);
+  L->addBasicBlockToLoop(Latch, LI);
+  return Body;
+}
+
+static Value *CreateTileLoadLoops(BasicBlock *Start, BasicBlock *End,
+                                  IRBuilderBase &B, DomTreeUpdater &DTU,
+                                  LoopInfo &LI, Value *Row, Value *Col,
+                                  Value *Ptr, Value *Stride) {
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *ColLoop = LI.AllocateLoop();
+  RowLoop->addChildLoop(ColLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(RowLoop);
+  else
+    LI.addTopLevelLoop(RowLoop);
+
+  BasicBlock *RowBody =
+      CreateLoop(Start, End, Row, B.getInt16(1), "rows", B, DTU, RowLoop, LI);
+  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
+
+  uint16_t ColStep = B.getInt32Ty()->getPrimitiveSizeInBits() / 8;
+  BasicBlock *ColBody = CreateLoop(RowBody, RowLatch, Col, B.getInt16(ColStep),
+                                   "cols", B, DTU, ColLoop, LI);
+
+  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
+  BasicBlock *ColumnLoopHeader = ColBody->getSinglePredecessor();
+  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
+  Value *CurrentRow = &*RowLoopHeader->begin();
+  Value *CurrentCol = &*ColumnLoopHeader->begin();
+
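+  // At this point the loop nest is rows.header -> rows.body -> cols.header ->
+  // cols.body -> cols.latch -> rows.latch; the rows IV steps by 1 and the
+  // cols IV steps by 4 (the byte width of one i32 element).
+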
+  // cols.header:
+  //  %vecphi = phi <256 x i32> [ undef, %rows.body ], [ %vec2, %cols.latch ]
+  B.SetInsertPoint(ColumnLoopHeader->getTerminator());
+  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
+  Value *UndefVec = UndefValue::get(V256I32Ty);
+  PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
+  VecPhi->addIncoming(UndefVec, RowBody);
+
+  // cols.body:
+  //  %elt = load i32, i32* %ptr
+  //  %mul = mul i16 %row.iv, 16
+  //  %add = add i16 %mul, %col.iv
+  //  %vec2 = insertelement <256 x i32> %vecphi, i32 %elt, i16 %idx
+  B.SetInsertPoint(ColBody->getTerminator());
+  Type *EltTy = V256I32Ty->getElementType();
+  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
+  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
+  Value *Offset =
+      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
+  unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
+  Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
+  Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
+  Value *Elt = B.CreateLoad(EltTy, EltPtr);
+  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
+  Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
+  VecPhi->addIncoming(ResVec, ColLoopLatch);
+
+  return ResVec;
+}
+
+static Value *CreateTileDPBSSDLoops(BasicBlock *Start, BasicBlock *End,
+                                    IRBuilderBase &B, DomTreeUpdater &DTU,
+                                    LoopInfo &LI, Value *Row, Value *Col,
+                                    Value *K, Value *Acc, Value *LHS,
+                                    Value *RHS) {
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *ColLoop = LI.AllocateLoop();
+  Loop *InnerLoop = LI.AllocateLoop();
+  ColLoop->addChildLoop(InnerLoop);
+  RowLoop->addChildLoop(ColLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(RowLoop);
+  else
+    LI.addTopLevelLoop(RowLoop);
+
+  BasicBlock *RowBody =
+      CreateLoop(Start, End, Row, B.getInt16(1), "rows", B, DTU, RowLoop, LI);
+  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
+
+  BasicBlock *ColBody = CreateLoop(RowBody, RowLatch, Col, B.getInt16(1),
+                                   "cols", B, DTU, ColLoop, LI);
+  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
+
+  uint16_t KStep = B.getInt32Ty()->getPrimitiveSizeInBits() / 8;
+  B.SetInsertPoint(ColBody->getTerminator());
+  Value *BoundK = B.CreateUDiv(K, B.getInt16(KStep));
+  BasicBlock *InnerBody =
+      CreateLoop(ColBody, ColLoopLatch, BoundK, B.getInt16(1), "inner", B, DTU,
+                 InnerLoop, LI);
+
+  BasicBlock *ColumnLoopHeader = ColBody->getSinglePredecessor();
+  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
+  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
+  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
+  Value *CurrentRow = &*RowLoopHeader->begin();
+  Value *CurrentCol = &*ColumnLoopHeader->begin();
+  Value *CurrentInner = &*InnerLoopHeader->begin();
+
+  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
+  Type *EltTy = V256I32Ty->getElementType();
+  Value *VecC, *VecA, *VecB;
+  if (auto *BitCast = dyn_cast<BitCastInst>(Acc))
+    VecC = BitCast->getOperand(0);
+  assert(VecC->getType()->isVectorTy() && "bitcast from non-v256i32 to x86amx");
+  // TODO else create BitCast from x86amx to v256i32.
+  // Store x86amx to memory, and reload from memory
+  // to vector. However with -O0, it doesn't happen.
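+  // The LHS and RHS tile operands are unwrapped to their <256 x i32> sources
+  // in the same way as the accumulator above.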
+  if (auto *BitCast = dyn_cast<BitCastInst>(LHS))
+    VecA = BitCast->getOperand(0);
+  assert(VecA->getType()->isVectorTy() && "bitcast from non-v256i32 to x86amx");
+  if (auto *BitCast = dyn_cast<BitCastInst>(RHS))
+    VecB = BitCast->getOperand(0);
+  assert(VecB->getType()->isVectorTy() && "bitcast from non-v256i32 to x86amx");
+
+  // Generate the PHI vector for C.
+  B.SetInsertPoint(InnerLoopHeader->getTerminator());
+  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
+  VecCPhi->addIncoming(VecC, ColBody);
+
+  // Generate the accumulate-multiply in the inner body.
+  B.SetInsertPoint(InnerBody->getTerminator());
+  Value *IdxC =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
+  Value *IdxA =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
+  Value *IdxB =
+      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
+
+  FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
+  Value *EltC = B.CreateExtractElement(VecA, IdxC);
+  Value *SubVecC = B.CreateBitCast(EltC, V4I8Ty);
+  Value *EltA = B.CreateExtractElement(VecA, IdxA);
+  Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
+  Value *EltB = B.CreateExtractElement(VecA, IdxB);
+  Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
+  Value *SubVecR = B.CreateAdd(B.CreateMul(SubVecA, SubVecB), SubVecC);
+  Value *ResElt = B.CreateBitCast(SubVecR, EltTy);
+  Value *NewVecC = B.CreateInsertElement(VecC, ResElt, IdxC);
+  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
+
+  return NewVecC;
+}
+
+namespace {
+class X86LowerAMXIntrinsics {
+  Function &Func;
+
+public:
+  X86LowerAMXIntrinsics(Function &F, DominatorTree *DT, LoopInfo *LI)
+      : Func(F), DT(DT), LI(LI) {}
+  bool visit();
+
+private:
+  DominatorTree *DT;
+  LoopInfo *LI;
+  bool lowerTileLoad(Instruction *TileLoad);
+  bool lowerTileDPBSSD(Instruction *TileDPBSSD);
+};
+
+bool X86LowerAMXIntrinsics::lowerTileDPBSSD(Instruction *TileDPBSSD) {
+  Value *M, *N, *K, *C, *A, *B;
+  match(TileDPBSSD, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>(
+                        m_Value(M), m_Value(N), m_Value(K), m_Value(C),
+                        m_Value(A), m_Value(B)));
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  Instruction *InsertI = TileDPBSSD;
+  BasicBlock *Start = InsertI->getParent();
+  BasicBlock *End =
+      SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
+  IRBuilder<> Builder(TileDPBSSD);
+  Value *ResVec =
+      CreateTileDPBSSDLoops(Start, End, Builder, DTU, *LI, M, N, K, C, A, B);
+
+  // Delete the tdpbssd intrinsic and the bitcast instruction.
+  for (auto UI = TileDPBSSD->use_begin(), UE = TileDPBSSD->use_end();
+       UI != UE;) {
+    Instruction *I = cast<Instruction>((UI++)->getUser());
+    Value *Vec;
+    if (match(I, m_BitCast(m_Value(Vec)))) {
+      I->replaceAllUsesWith(ResVec);
+      I->eraseFromParent();
+    }
+  }
+  TileDPBSSD->eraseFromParent();
+  return true;
+}
+
+bool X86LowerAMXIntrinsics::lowerTileLoad(Instruction *TileLoad) {
+  Value *M, *N, *Ptr, *Stride;
+  match(TileLoad, m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
+                      m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  Instruction *InsertI = TileLoad;
+  BasicBlock *Start = InsertI->getParent();
+  BasicBlock *End =
+      SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
+  IRBuilder<> Builder(TileLoad);
+  Value *ResVec =
+      CreateTileLoadLoops(Start, End, Builder, DTU, *LI, M, N, Ptr, Stride);
+
+  // Delete the tileloadd64 intrinsic and the bitcast instruction.
+  for (auto UI = TileLoad->use_begin(), UE = TileLoad->use_end(); UI != UE;) {
+    Instruction *I = cast<Instruction>((UI++)->getUser());
+    Value *Vec;
+    if (match(I, m_BitCast(m_Value(Vec)))) {
+      I->replaceAllUsesWith(ResVec);
+      I->eraseFromParent();
+    }
+  }
+  TileLoad->eraseFromParent();
+  return true;
+}
+
+bool X86LowerAMXIntrinsics::visit() {
+  bool C = false;
+  SmallVector<Instruction *, 8> TileDPBSSDs;
+  SmallVector<Instruction *, 8> TileLoads;
+  SmallVector<Instruction *, 8> TileStores;
+
+  for (BasicBlock *BB : post_order(&Func)) {
+    for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
+         II != IE;) {
+      Instruction &Inst = *II++;
+      if (match(&Inst, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>())) {
+        // %amx1 = bitcast <256 x i32> %vec to x86_amx
+        // %res = call x86_amx @llvm.x86.tdpbssd.internal(i16 %m, i16 %n, i16 %k,
+        //                                                x86_amx %amx1, ...)
+        // %vec2 = bitcast x86_amx %res to <256 x i32>
+        TileDPBSSDs.push_back(&Inst);
+      } else if (match(&Inst,
+                       m_Intrinsic<Intrinsic::x86_tileloadd64_internal>())) {
+        // %17 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %13, i16 %14,
+        //                                                   i8* %15, i64 %16)
+        // %18 = bitcast x86_amx %17 to <256 x i32>
+        TileLoads.push_back(&Inst);
+      } else if (match(&Inst,
+                       m_Intrinsic<Intrinsic::x86_tilestored64_internal>())) {
+        // %89 = bitcast <256 x i32> %88 to x86_amx
+        // call void @llvm.x86.tilestored64.internal(i16 %84, i16 %85, i8* %86,
+        //                                           i64 %87, x86_amx %89)
+        // lowerTileStore();
+        TileStores.push_back(&Inst);
+      }
+    }
+  }
+
+  for (auto *Inst : TileLoads) {
+    C |= lowerTileLoad(Inst);
+  }
+  for (auto *Inst : TileDPBSSDs) {
+    C |= lowerTileDPBSSD(Inst);
+  }
+
+  return C;
+}
+} // anonymous namespace
+
+namespace {
+
+class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
+    initializeX86LowerAMXIntrinsicsLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+
+    X86LowerAMXIntrinsics LAT(F, &DT, &LI);
+    bool C = LAT.visit();
+    return C;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addPreserved<DominatorTreeWrapperPass>();
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.addPreserved<LoopInfoWrapperPass>();
+  }
+};
+
+} // anonymous namespace
+
+static const char PassName[] = "Lower AMX intrinsics";
+char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
+                    false, false)
+
+FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
+  return new X86LowerAMXIntrinsicsLegacyPass();
+}
diff --git a/llvm/lib/Target/X86/X86TargetMachine.cpp b/llvm/lib/Target/X86/X86TargetMachine.cpp
--- a/llvm/lib/Target/X86/X86TargetMachine.cpp
+++ b/llvm/lib/Target/X86/X86TargetMachine.cpp
@@ -62,6 +62,7 @@
   RegisterTargetMachine<X86TargetMachine> Y(getTheX86_64Target());
 
   PassRegistry &PR = *PassRegistry::getPassRegistry();
+  initializeX86LowerAMXIntrinsicsLegacyPassPass(PR);
   initializeX86LowerAMXTypeLegacyPassPass(PR);
   initializeGlobalISel(PR);
   initializeWinEHStatePassPass(PR);
@@ -410,7 +411,12 @@
 void X86PassConfig::addIRPasses() {
   addPass(createAtomicExpandPass());
-  addPass(createX86LowerAMXTypePass());
+
+  if (TM->getOptLevel() == CodeGenOpt::None)
+    addPass(createX86LowerAMXIntrinsicsPass());
+  else {
+    addPass(createX86LowerAMXTypePass());
+  }
 
   TargetPassConfig::addIRPasses();
diff --git a/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll b/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll
@@ -0,0 +1,116 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-amx-intrinsics %s -S | FileCheck %s
+
+define dso_local void @test_amx_load(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr) {
+; CHECK-LABEL: @test_amx_load(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: br label [[ROWS_HEADER:%.*]]
+; CHECK: rows.header:
+; CHECK-NEXT: [[ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[ROWS_STEP:%.*]], [[ROWS_LATCH:%.*]] ]
+; CHECK-NEXT: br label [[ROWS_BODY:%.*]]
+; CHECK: rows.body:
+; CHECK-NEXT: br label [[COLS_HEADER:%.*]]
+; CHECK: cols.header:
+; CHECK-NEXT: [[COLS_IV:%.*]] = phi i16 [ 0, [[ROWS_BODY]] ], [ [[COLS_STEP:%.*]], [[COLS_LATCH:%.*]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <256 x i32> [ undef, [[ROWS_BODY]] ], [ [[TMP9:%.*]], [[COLS_LATCH]] ]
+; CHECK-NEXT: br label [[COLS_BODY:%.*]]
+; CHECK: cols.body:
+; CHECK-NEXT: [[TMP0:%.*]] = zext i16 [[ROWS_IV]] to i64
+; CHECK-NEXT: [[TMP1:%.*]] = zext i16 [[COLS_IV]] to i64
+; CHECK-NEXT: [[TMP2:%.*]] = mul i64 [[TMP0]], [[STRIDE:%.*]]
+; CHECK-NEXT: [[TMP3:%.*]] = add i64 [[TMP2]], [[TMP1]]
+; CHECK-NEXT: [[TMP4:%.*]] = bitcast i8* [[PTR:%.*]] to i32*
+; CHECK-NEXT: [[TMP5:%.*]] = getelementptr i32, i32* [[TMP4]], i64 [[TMP3]]
+; CHECK-NEXT: [[TMP6:%.*]] = load i32, i32* [[TMP5]], align 4
+; CHECK-NEXT: [[TMP7:%.*]] = mul i16 [[ROWS_IV]], 16
+; CHECK-NEXT: [[TMP8:%.*]] = add i16 [[TMP7]], [[COLS_IV]]
+; CHECK-NEXT: [[TMP9]] = insertelement <256 x i32> [[VEC_PHI]], i32 [[TMP6]], i16 [[TMP8]]
+; CHECK-NEXT: br label [[COLS_LATCH]]
+; CHECK: cols.latch:
+; CHECK-NEXT: [[COLS_STEP]] = add i16 [[COLS_IV]], 4
+; CHECK-NEXT: [[COLS_COND:%.*]] = icmp ne i16 [[COLS_STEP]], [[COL:%.*]]
+; CHECK-NEXT: br i1 [[COLS_COND]], label [[COLS_HEADER]], label [[ROWS_LATCH]]
+; CHECK: rows.latch:
+; CHECK-NEXT: [[ROWS_STEP]] = add i16 [[ROWS_IV]], 1
+; CHECK-NEXT: [[ROWS_COND:%.*]] = icmp ne i16 [[ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT: br i1 [[ROWS_COND]], label [[ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK: continue:
+; CHECK-NEXT: store <256 x i32> [[TMP9]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT: ret void
+;
+entry:
+  %amx = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %ptr, i64 %stride)
+  %vec = bitcast x86_amx %amx to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+define dso_local void @test_amx_dp(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) {
+; CHECK-LABEL: @test_amx_dp(
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[A_AMX:%.*]] = bitcast <256 x i32> [[A:%.*]] to x86_amx
+; CHECK-NEXT: [[B_AMX:%.*]] = bitcast <256 x i32> [[B:%.*]] to x86_amx
+; CHECK-NEXT: [[C_AMX:%.*]] = bitcast <256 x i32> [[C:%.*]] to x86_amx
+; CHECK-NEXT: br label [[ROWS_HEADER:%.*]]
+; CHECK: rows.header:
+; CHECK-NEXT: [[ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[ROWS_STEP:%.*]], [[ROWS_LATCH:%.*]] ]
+; CHECK-NEXT: br label [[ROWS_BODY:%.*]]
+; CHECK: rows.body:
+; CHECK-NEXT: br label [[COLS_HEADER:%.*]]
+; CHECK: cols.header:
+; CHECK-NEXT: [[COLS_IV:%.*]] = phi i16 [ 0, [[ROWS_BODY]] ], [ [[COLS_STEP:%.*]], [[COLS_LATCH:%.*]] ]
+; CHECK-NEXT: br label [[COLS_BODY:%.*]]
+; CHECK: cols.body:
+; CHECK-NEXT: [[TMP0:%.*]] = udiv i16 [[K:%.*]], 4
+; CHECK-NEXT: br label [[INNER_HEADER:%.*]]
+; CHECK: inner.header:
+; CHECK-NEXT: [[INNER_IV:%.*]] = phi i16 [ 0, [[COLS_BODY]] ], [ [[INNER_STEP:%.*]], [[INNER_LATCH:%.*]] ]
+; CHECK-NEXT: [[VEC_PHI:%.*]] = phi <256 x i32> [ [[C]], [[COLS_BODY]] ], [ [[TMP16:%.*]], [[INNER_LATCH]] ]
+; CHECK-NEXT: br label [[INNER_BODY:%.*]]
+; CHECK: inner.body:
+; CHECK-NEXT: [[TMP1:%.*]] = mul i16 [[ROWS_IV]], 16
+; CHECK-NEXT: [[TMP2:%.*]] = add i16 [[TMP1]], [[COLS_IV]]
+; CHECK-NEXT: [[TMP3:%.*]] = mul i16 [[ROWS_IV]], 16
+; CHECK-NEXT: [[TMP4:%.*]] = add i16 [[TMP3]], [[INNER_IV]]
+; CHECK-NEXT: [[TMP5:%.*]] = mul i16 [[INNER_IV]], 16
+; CHECK-NEXT: [[TMP6:%.*]] = add i16 [[TMP5]], [[COLS_IV]]
+; CHECK-NEXT: [[TMP7:%.*]] = extractelement <256 x i32> [[A]], i16 [[TMP2]]
+; CHECK-NEXT: [[TMP8:%.*]] = bitcast i32 [[TMP7]] to <4 x i8>
+; CHECK-NEXT: [[TMP9:%.*]] = extractelement <256 x i32> [[A]], i16 [[TMP4]]
+; CHECK-NEXT: [[TMP10:%.*]] = bitcast i32 [[TMP9]] to <4 x i8>
+; CHECK-NEXT: [[TMP11:%.*]] = extractelement <256 x i32> [[A]], i16 [[TMP6]]
+; CHECK-NEXT: [[TMP12:%.*]] = bitcast i32 [[TMP11]] to <4 x i8>
+; CHECK-NEXT: [[TMP13:%.*]] = mul <4 x i8> [[TMP10]], [[TMP12]]
+; CHECK-NEXT: [[TMP14:%.*]] = add <4 x i8> [[TMP13]], [[TMP8]]
+; CHECK-NEXT: [[TMP15:%.*]] = bitcast <4 x i8> [[TMP14]] to i32
+; CHECK-NEXT: [[TMP16]] = insertelement <256 x i32> [[C]], i32 [[TMP15]], i16 [[TMP2]]
+; CHECK-NEXT: br label [[INNER_LATCH]]
+; CHECK: inner.latch:
+; CHECK-NEXT: [[INNER_STEP]] = add i16 [[INNER_IV]], 1
+; CHECK-NEXT: [[INNER_COND:%.*]] = icmp ne i16 [[INNER_STEP]], [[TMP0]]
+; CHECK-NEXT: br i1 [[INNER_COND]], label [[INNER_HEADER]], label [[COLS_LATCH]]
+; CHECK: cols.latch:
+; CHECK-NEXT: [[COLS_STEP]] = add i16 [[COLS_IV]], 1
+; CHECK-NEXT: [[COLS_COND:%.*]] = icmp ne i16 [[COLS_STEP]], [[COL:%.*]]
+; CHECK-NEXT: br i1 [[COLS_COND]], label [[COLS_HEADER]], label [[ROWS_LATCH]]
+; CHECK: rows.latch:
+; CHECK-NEXT: [[ROWS_STEP]] = add i16 [[ROWS_IV]], 1
+; CHECK-NEXT: [[ROWS_COND:%.*]] = icmp ne i16 [[ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT: br i1 [[ROWS_COND]], label [[ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK: continue:
+; CHECK-NEXT: store <256 x i32> [[TMP16]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT: ret void
+;
+entry:
+  %a.amx = bitcast <256 x i32> %a to x86_amx
+  %b.amx = bitcast <256 x i32> %b to x86_amx
+  %c.amx = bitcast <256 x i32> %c to x86_amx
+  %acc = call x86_amx @llvm.x86.tdpbssd.internal(i16 %row, i16 %col, i16 %k, x86_amx %c.amx, x86_amx %a.amx, x86_amx %b.amx)
+  %vec = bitcast x86_amx %acc to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
diff --git a/llvm/test/CodeGen/X86/O0-pipeline.ll b/llvm/test/CodeGen/X86/O0-pipeline.ll
--- a/llvm/test/CodeGen/X86/O0-pipeline.ll
+++ b/llvm/test/CodeGen/X86/O0-pipeline.ll
@@ -18,7 +18,9 @@
 ; CHECK-NEXT: Pre-ISel Intrinsic Lowering
 ; CHECK-NEXT: FunctionPass Manager
 ; CHECK-NEXT: Expand Atomic instructions
-; CHECK-NEXT: Lower AMX type for load/store
+; CHECK-NEXT: Dominator Tree Construction
+; CHECK-NEXT: Natural Loop Information
+; CHECK-NEXT: Lower AMX intrinsics
 ; CHECK-NEXT: Module Verifier
 ; CHECK-NEXT: Lower Garbage Collection Instructions
 ; CHECK-NEXT: Shadow Stack GC Lowering