diff --git a/clang/include/clang/Basic/BuiltinsX86_64.def b/clang/include/clang/Basic/BuiltinsX86_64.def --- a/clang/include/clang/Basic/BuiltinsX86_64.def +++ b/clang/include/clang/Basic/BuiltinsX86_64.def @@ -94,6 +94,11 @@ TARGET_BUILTIN(__builtin_ia32_cvtusi2ss64, "V4fV4fUOiIi", "ncV:128:", "avx512f") TARGET_BUILTIN(__builtin_ia32_directstore_u64, "vULi*ULi", "n", "movdiri") +// AMX internal builtin +TARGET_BUILTIN(__builtin_ia32_tileloadd64_internal, "V256iUsUsvC*z", "n", "amx-tile") +TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUsV256i", "n", "amx-tile") +TARGET_BUILTIN(__builtin_ia32_tdpbssd_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8") +TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile") // AMX TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile") TARGET_BUILTIN(__builtin_ia32_tile_storeconfig, "vvC*", "n", "amx-tile") diff --git a/clang/lib/Headers/amxintrin.h b/clang/lib/Headers/amxintrin.h --- a/clang/lib/Headers/amxintrin.h +++ b/clang/lib/Headers/amxintrin.h @@ -66,6 +66,8 @@ __builtin_ia32_tilerelease(); } +#undef __DEFAULT_FN_ATTRS + /// Load tile rows from memory specified by "base" address and "stride" into /// destination tile "dst" using the tile configuration previously configured /// via "_tile_loadconfig". @@ -219,6 +221,56 @@ #define _tile_dpbf16ps(dst, src0, src1) \ __builtin_ia32_tdpbf16ps((dst), (src0), (src1)) +#define __DEFAULT_FN_ATTRS \ + __attribute__((__always_inline__, __nodebug__, __target__("amx-int8"))) + +/// This is the new intrinsic interface. +typedef int _tile_data __attribute__((__vector_size__(1024), __aligned__(64))); +static __inline__ _tile_data __DEFAULT_FN_ATTRS +_tile_loadd_internal(short m, short n, const void *base, int stride) { + return __builtin_ia32_tileloadd64_internal(m, n, base, + (__SIZE_TYPE__)(stride)); +} + +static __inline__ _tile_data __DEFAULT_FN_ATTRS +_tile_zero_internal(short m, short n, _tile_data tile) { + return __builtin_ia32_tilezero_internal(m, n, tile); +} + +static __inline__ _tile_data __DEFAULT_FN_ATTRS +_tile_dpbssd_internal(short m, short n, short k, _tile_data dst, + _tile_data src1, _tile_data src2) { + return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2); +} + +static __inline__ void __DEFAULT_FN_ATTRS _tile_stored_internal( + short m, short n, void *base, int stride, _tile_data tile) { + return __builtin_ia32_tilestored64_internal(m, n, base, + (__SIZE_TYPE__)(stride), tile); +} + +typedef struct __tile_str { + const short row; + const short col; + _tile_data tile; +} __tile; + +__DEFAULT_FN_ATTRS +void __tile_loadd(__tile *dst, const void *base, long stride) { + dst->tile = _tile_loadd_internal(dst->row, dst->col, base, stride); +} + +__DEFAULT_FN_ATTRS +void __tile_dpbsud(__tile *dst, __tile src1, __tile src2) { + dst->tile = _tile_dpbssd_internal(src1.row, src2.col, src1.col, dst->tile, + src1.tile, src2.tile); +} + +__DEFAULT_FN_ATTRS +void __tile_stored(void *base, long stride, __tile src) { + _tile_stored_internal(src.row, src.col, base, stride, src.tile); +} + #undef __DEFAULT_FN_ATTRS #endif /* __x86_64__ */ diff --git a/clang/test/CodeGen/AMX/amx_api.c b/clang/test/CodeGen/AMX/amx_api.c new file mode 100644 --- /dev/null +++ b/clang/test/CodeGen/AMX/amx_api.c @@ -0,0 +1,31 @@ +// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown -target-feature +avx512f -target-feature +amx-int8 \ +// RUN: -target-feature +amx-bf16 -emit-llvm -o - -Werror -pedantic | FileCheck %s --check-prefixes=CHECK + +#include <immintrin.h> +
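+// a is (row x 8) and b is (8 x col); __tile_dpbsud below accumulates their
+// product into c (row x col), which is then stored back to buf.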
+char buf[1024]; +#define STRIDE 32 + +char buf2[1024]; + +void test_api(int cond, short row, short col) { +//CHECK-LABEL: @test_api +//CHECK: call <256 x i32> @llvm.x86.tileloadd64.internal +//CHECK: call <256 x i32> @llvm.x86.tdpbssd.internal +//CHECK: call void @llvm.x86.tilestored64.internal + __tile a = {row, 8}; + __tile b = {8, col}; + __tile c = {row, col}; + + if(cond) { + __tile_loadd(&a, buf, STRIDE); + __tile_loadd(&b, buf, STRIDE); + __tile_loadd(&c, buf, STRIDE); + } else { + __tile_loadd(&a, buf2, STRIDE); + __tile_loadd(&b, buf2, STRIDE); + __tile_loadd(&c, buf2, STRIDE); + } + __tile_dpbsud(&c, a, b); + __tile_stored(buf, STRIDE, c); +} diff --git a/llvm/include/llvm/CodeGen/LiveIntervalUnion.h b/llvm/include/llvm/CodeGen/LiveIntervalUnion.h --- a/llvm/include/llvm/CodeGen/LiveIntervalUnion.h +++ b/llvm/include/llvm/CodeGen/LiveIntervalUnion.h @@ -104,6 +104,9 @@ void verify(LiveVirtRegBitSet& VisitedVRegs); #endif + // Get any virtual register that is assign to this physical unit + LiveInterval *getOneVReg() const; + /// Query interferences between a single live virtual register and a live /// interval union. class Query { diff --git a/llvm/include/llvm/CodeGen/LiveRegMatrix.h b/llvm/include/llvm/CodeGen/LiveRegMatrix.h --- a/llvm/include/llvm/CodeGen/LiveRegMatrix.h +++ b/llvm/include/llvm/CodeGen/LiveRegMatrix.h @@ -41,6 +41,7 @@ const TargetRegisterInfo *TRI; LiveIntervals *LIS; VirtRegMap *VRM; + MachineRegisterInfo *MRI; // UserTag changes whenever virtual registers have been modified. unsigned UserTag = 0; @@ -152,6 +153,8 @@ /// Directly access the live interval unions per regunit. /// This returns an array indexed by the regunit number. LiveIntervalUnion *getLiveUnions() { return &Matrix[0]; } + + Register getOneVReg(unsigned PhysReg) const; }; } // end namespace llvm diff --git a/llvm/include/llvm/CodeGen/Passes.h b/llvm/include/llvm/CodeGen/Passes.h --- a/llvm/include/llvm/CodeGen/Passes.h +++ b/llvm/include/llvm/CodeGen/Passes.h @@ -490,6 +490,8 @@ /// The pass fixups statepoint machine instruction to replace usage of /// caller saved registers with stack slots. extern char &FixupStatepointCallerSavedID; + + FunctionPass *createX86LowerAMXTypePass(); } // End llvm namespace #endif diff --git a/llvm/include/llvm/CodeGen/TileShapeInfo.h b/llvm/include/llvm/CodeGen/TileShapeInfo.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/CodeGen/TileShapeInfo.h @@ -0,0 +1,103 @@ +//===- llvm/CodeGen/TileShapeInfo.h - ---------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines TileShapeInfo for AMX. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CODEGEN_TILESHAPEINFO_H +#define LLVM_CODEGEN_TILESHAPEINFO_H + +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/MachineOperand.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/Register.h" +#include + +using namespace llvm; + +namespace llvm { + +class ShapeT { +public: + ShapeT(MachineOperand *Row, MachineOperand *Col, + const MachineRegisterInfo *MRI = nullptr) + : Row(Row), Col(Col) { + if (MRI) + deduceImm(MRI); + } + ShapeT() + : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape), + ColImm(InvalidImmShape) {} + bool operator==(const ShapeT &Shape) { + MachineOperand *R = Shape.Row; + MachineOperand *C = Shape.Col; + if (!R || !C) + return false; + if (!Row || !Col) + return false; + if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg()) + return true; + if ((RowImm != InvalidImmShape) && (Shape.getRowImm() != InvalidImmShape) && + (ColImm != InvalidImmShape) && (Shape.getColImm() != InvalidImmShape)) { + return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm(); + } + return false; + } + + bool operator!=(const ShapeT &Shape) { return !(*this == Shape); } + + ShapeT &operator=(const ShapeT &RHS) { + Row = RHS.Row; + Col = RHS.Col; + RowImm = RHS.RowImm; + ColImm = RHS.ColImm; + return *this; + } + + MachineOperand *getRow() const { return Row; } + + MachineOperand *getCol() const { return Col; } + + int64_t getRowImm() const { return RowImm; } + + int64_t getColImm() const { return ColImm; } + + bool isValid() { return (Row != nullptr) && (Col != nullptr); } + + void deduceImm(const MachineRegisterInfo *MRI) { + // All def must be the same value, otherwise it is invalid MIs. + // Find the immediate. + // TODO copy propagation. + auto GetImm = [&](Register Reg) { + int64_t Imm = InvalidImmShape; + for (const MachineOperand &DefMO : MRI->def_operands(Reg)) { + const auto *MI = DefMO.getParent(); + if (MI->isMoveImmediate()) { + Imm = MI->getOperand(1).getImm(); + break; + } + } + return Imm; + }; + RowImm = GetImm(Row->getReg()); + ColImm = GetImm(Col->getReg()); + } + +private: + static constexpr int64_t InvalidImmShape = -1; + MachineOperand *Row; + MachineOperand *Col; + int64_t RowImm; + int64_t ColImm; +}; + +} // namespace llvm + +#endif diff --git a/llvm/include/llvm/CodeGen/VirtRegMap.h b/llvm/include/llvm/CodeGen/VirtRegMap.h --- a/llvm/include/llvm/CodeGen/VirtRegMap.h +++ b/llvm/include/llvm/CodeGen/VirtRegMap.h @@ -19,6 +19,7 @@ #include "llvm/ADT/IndexedMap.h" #include "llvm/CodeGen/MachineFunctionPass.h" #include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/CodeGen/TileShapeInfo.h" #include "llvm/Pass.h" #include @@ -60,6 +61,10 @@ /// mapping. IndexedMap Virt2SplitMap; + /// Virt2ShapeMap - For X86 AMX register whose register is bound shape + /// information. + DenseMap Virt2ShapeMap; + /// createSpillSlot - Allocate a spill slot for RC from MFI. 
unsigned createSpillSlot(const TargetRegisterClass *RC); @@ -107,6 +112,21 @@ /// the specified physical register void assignVirt2Phys(Register virtReg, MCPhysReg physReg); + bool isShapeMapEmpty() const { return Virt2ShapeMap.empty(); } + + bool hasShape(Register virtReg) const { + return getShape(virtReg).isValid(); + } + + ShapeT getShape(Register virtReg) const { + assert(virtReg.isVirtual()); + return Virt2ShapeMap.lookup(virtReg); + } + + void assignVirt2Shape(Register virtReg, ShapeT shape) { + Virt2ShapeMap[virtReg.id()] = shape; + } + /// clears the specified virtual register's, physical /// register mapping void clearVirt(Register virtReg) { @@ -133,6 +153,9 @@ /// records virtReg is a split live interval from SReg. void setIsSplitFromReg(Register virtReg, unsigned SReg) { Virt2SplitMap[virtReg.id()] = SReg; + if (hasShape(SReg)) { + Virt2ShapeMap[virtReg.id()] = getShape(SReg); + } } /// returns the live interval virtReg is split from. diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -289,6 +289,7 @@ def llvm_v16i32_ty : LLVMType; // 16 x i32 def llvm_v32i32_ty : LLVMType; // 32 x i32 def llvm_v64i32_ty : LLVMType; // 64 x i32 +def llvm_v256i32_ty : LLVMType; //256 x i32 def llvm_v1i64_ty : LLVMType; // 1 x i64 def llvm_v2i64_ty : LLVMType; // 2 x i64 diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td --- a/llvm/include/llvm/IR/IntrinsicsX86.td +++ b/llvm/include/llvm/IR/IntrinsicsX86.td @@ -4977,3 +4977,26 @@ def int_x86_tdpbf16ps : GCCBuiltin<"__builtin_ia32_tdpbf16ps">, Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], []>; } + +// AMX - internal intrinsics +let TargetPrefix = "x86" in { + def int_x86_tileloadd64_internal : + GCCBuiltin<"__builtin_ia32_tileloadd64_internal">, + Intrinsic<[llvm_v256i32_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty, llvm_i64_ty], + []>; + def int_x86_tilezero_internal : + GCCBuiltin<"__builtin_ia32_tilezero_internal">, + Intrinsic<[llvm_v256i32_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_v256i32_ty], []>; + def int_x86_tdpbssd_internal : + GCCBuiltin<"__builtin_ia32_tdpbssd_internal">, + Intrinsic<[llvm_v256i32_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, + llvm_v256i32_ty, llvm_v256i32_ty, + llvm_v256i32_ty], []>; + def int_x86_tilestored64_internal : + GCCBuiltin<"__builtin_ia32_tilestored64_internal">, + Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_v256i32_ty], []>; +} diff --git a/llvm/lib/CodeGen/InlineSpiller.cpp b/llvm/lib/CodeGen/InlineSpiller.cpp --- a/llvm/lib/CodeGen/InlineSpiller.cpp +++ b/llvm/lib/CodeGen/InlineSpiller.cpp @@ -1556,6 +1556,8 @@ VRM.assignVirt2Phys(New, VRM.getPhys(Old)); else if (VRM.getStackSlot(Old) != VirtRegMap::NO_STACK_SLOT) VRM.assignVirt2StackSlot(New, VRM.getStackSlot(Old)); + else if (VRM.hasShape(Old)) + VRM.assignVirt2Shape(New, VRM.getShape(Old)); else llvm_unreachable("VReg should be assigned either physreg or stackslot"); } diff --git a/llvm/lib/CodeGen/LiveIntervalUnion.cpp b/llvm/lib/CodeGen/LiveIntervalUnion.cpp --- a/llvm/lib/CodeGen/LiveIntervalUnion.cpp +++ b/llvm/lib/CodeGen/LiveIntervalUnion.cpp @@ -99,6 +99,16 @@ } #endif //!NDEBUG +LiveInterval *LiveIntervalUnion::getOneVReg() const { + if (empty()) + return nullptr; + for (LiveSegments::const_iterator SI = Segments.begin(); SI.valid(); ++SI) { + // return the first valid live interval + return SI.value(); + } + return nullptr; +} + // Scan the vector of 
interfering virtual registers in this union. Assume it's // quite small. bool LiveIntervalUnion::Query::isSeenInterference(LiveInterval *VirtReg) const { diff --git a/llvm/lib/CodeGen/LiveRegMatrix.cpp b/llvm/lib/CodeGen/LiveRegMatrix.cpp --- a/llvm/lib/CodeGen/LiveRegMatrix.cpp +++ b/llvm/lib/CodeGen/LiveRegMatrix.cpp @@ -54,6 +54,7 @@ bool LiveRegMatrix::runOnMachineFunction(MachineFunction &MF) { TRI = MF.getSubtarget().getRegisterInfo(); + MRI = &MF.getRegInfo(); LIS = &getAnalysis(); VRM = &getAnalysis(); @@ -221,3 +222,13 @@ } return false; } + +Register LiveRegMatrix::getOneVReg(unsigned PhysReg) const { + LiveInterval *VRegInterval = nullptr; + for (MCRegUnitIterator Unit(PhysReg, TRI); Unit.isValid(); ++Unit) { + if ((VRegInterval = Matrix[*Unit].getOneVReg())) + return VRegInterval->reg(); + } + + return MCRegister::NoRegister; +} diff --git a/llvm/lib/CodeGen/VirtRegMap.cpp b/llvm/lib/CodeGen/VirtRegMap.cpp --- a/llvm/lib/CodeGen/VirtRegMap.cpp +++ b/llvm/lib/CodeGen/VirtRegMap.cpp @@ -68,6 +68,7 @@ Virt2PhysMap.clear(); Virt2StackSlotMap.clear(); Virt2SplitMap.clear(); + Virt2ShapeMap.clear(); grow(); return false; diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp --- a/llvm/lib/IR/Function.cpp +++ b/llvm/lib/IR/Function.cpp @@ -826,7 +826,8 @@ IIT_SUBDIVIDE4_ARG = 45, IIT_VEC_OF_BITCASTS_TO_INT = 46, IIT_V128 = 47, - IIT_BF16 = 48 + IIT_BF16 = 48, + IIT_V256 = 49 }; static void DecodeIITType(unsigned &NextElt, ArrayRef Infos, @@ -920,6 +921,10 @@ OutputTable.push_back(IITDescriptor::getVector(128, IsScalableVector)); DecodeIITType(NextElt, Infos, Info, OutputTable); return; + case IIT_V256: + OutputTable.push_back(IITDescriptor::getVector(256, IsScalableVector)); + DecodeIITType(NextElt, Infos, Info, OutputTable); + return; case IIT_V512: OutputTable.push_back(IITDescriptor::getVector(512, IsScalableVector)); DecodeIITType(NextElt, Infos, Info, OutputTable); diff --git a/llvm/lib/Target/X86/CMakeLists.txt b/llvm/lib/Target/X86/CMakeLists.txt --- a/llvm/lib/Target/X86/CMakeLists.txt +++ b/llvm/lib/Target/X86/CMakeLists.txt @@ -30,6 +30,8 @@ X86CmovConversion.cpp X86DomainReassignment.cpp X86DiscriminateMemOps.cpp + X86LowerAMXType.cpp + X86TileConfig.cpp X86ExpandPseudo.cpp X86FastISel.cpp X86FixupBWInsts.cpp diff --git a/llvm/lib/Target/X86/X86.h b/llvm/lib/Target/X86/X86.h --- a/llvm/lib/Target/X86/X86.h +++ b/llvm/lib/Target/X86/X86.h @@ -76,6 +76,8 @@ /// Return a pass that expands WinAlloca pseudo-instructions. FunctionPass *createX86WinAllocaExpander(); +FunctionPass *createX86TileConfigPass(); + /// Return a pass that inserts int3 at the end of the function if it ends with a /// CALL instruction. The pass does the same for each funclet as well. 
This /// ensures that the open interval of function start and end PCs contains all @@ -162,6 +164,8 @@ void initializeX86PartialReductionPass(PassRegistry &); void initializeX86SpeculativeLoadHardeningPassPass(PassRegistry &); void initializeX86SpeculativeExecutionSideEffectSuppressionPass(PassRegistry &); +void initializeX86TileConfigPass(PassRegistry &); +void initializeX86LowerAMXTypeLegacyPassPass(PassRegistry &); namespace X86AS { enum : unsigned { diff --git a/llvm/lib/Target/X86/X86ExpandPseudo.cpp b/llvm/lib/Target/X86/X86ExpandPseudo.cpp --- a/llvm/lib/Target/X86/X86ExpandPseudo.cpp +++ b/llvm/lib/Target/X86/X86ExpandPseudo.cpp @@ -468,6 +468,26 @@ case TargetOpcode::ICALL_BRANCH_FUNNEL: ExpandICallBranchFunnel(&MBB, MBBI); return true; + case X86::PTILELOADDV: { + for (unsigned i = 2; i > 0; --i) + MI.RemoveOperand(i); + MI.setDesc(TII->get(X86::TILELOADD)); + return true; + } + case X86::PTDPBSSDV: { + MI.untieRegOperand(4); + for (unsigned i = 3; i > 0; --i) + MI.RemoveOperand(i); + MI.setDesc(TII->get(X86::TDPBSSD)); + MI.tieOperands(0, 1); + return true; + } + case X86::PTILESTOREDV: { + for (int i = 1; i >= 0; --i) + MI.RemoveOperand(i); + MI.setDesc(TII->get(X86::TILESTORED)); + return true; + } } llvm_unreachable("Previous switch has a fallthrough?"); } diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp --- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp +++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp @@ -4480,6 +4480,52 @@ switch (Opcode) { default: break; + case ISD::INTRINSIC_W_CHAIN: { + unsigned IntNo = Node->getConstantOperandVal(1); + switch (IntNo) { + default: + break; + case Intrinsic::x86_tileloadd64_internal: { + if (!Subtarget->hasAMXTILE()) + break; + unsigned Opc = X86::PTILELOADDV; + // _tile_loadd_internal(row, col, buf, STRIDE) + SDValue Base = Node->getOperand(4); + SDValue Scale = getI8Imm(1, dl); + SDValue Index = Node->getOperand(5); + SDValue Disp = CurDAG->getTargetConstant(0, dl, MVT::i32); + SDValue Segment = CurDAG->getRegister(0, MVT::i16); + SDValue Chain = Node->getOperand(0); + MachineSDNode *CNode; + SDValue Ops[] = {Node->getOperand(2), + Node->getOperand(3), + Base, + Scale, + Index, + Disp, + Segment, + Chain}; + CNode = CurDAG->getMachineNode(Opc, dl, {MVT::v256i32, MVT::Other}, Ops); + ReplaceNode(Node, CNode); + return; + } + case Intrinsic::x86_tdpbssd_internal: { + if (!Subtarget->hasAMXTILE()) + break; + unsigned Opc = X86::PTDPBSSDV; + SDValue Ops[] = {Node->getOperand(2), Node->getOperand(3), + Node->getOperand(4), Node->getOperand(5), + Node->getOperand(6), Node->getOperand(7)}; + MachineSDNode *CNode = + CurDAG->getMachineNode(Opc, dl, + {MVT::v256i32, MVT::Other}, + Ops); + ReplaceNode(Node, CNode); + return; + } + } + break; + } case ISD::INTRINSIC_VOID: { unsigned IntNo = Node->getConstantOperandVal(1); switch (IntNo) { @@ -4534,6 +4580,29 @@ break; } + case Intrinsic::x86_tilestored64_internal: { + unsigned Opc = X86::PTILESTOREDV; + // _tile_stored_internal(row, col, buf, STRIDE, c) + SDValue Base = Node->getOperand(4); + SDValue Scale = getI8Imm(1, dl); + SDValue Index = Node->getOperand(5); + SDValue Disp = CurDAG->getTargetConstant(0, dl, MVT::i32); + SDValue Segment = CurDAG->getRegister(0, MVT::i16); + SDValue Chain = Node->getOperand(0); + MachineSDNode *CNode; + SDValue Ops[] = {Node->getOperand(2), + Node->getOperand(3), + Base, + Scale, + Index, + Disp, + Segment, + Node->getOperand(6), + Chain}; + CNode = CurDAG->getMachineNode(Opc, dl, MVT::Other, Ops); + ReplaceNode(Node, 
CNode); + return; + } case Intrinsic::x86_tileloadd64: case Intrinsic::x86_tileloaddt164: case Intrinsic::x86_tilestored64: { diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -1885,6 +1885,10 @@ setOperationAction(ISD::TRUNCATE, MVT::v16i64, Custom); } + if (Subtarget.hasAMXTILE()) { + addRegisterClass(MVT::v256i32, &X86::TILERegClass); + } + // We want to custom lower some of our intrinsics. setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom); setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom); @@ -5271,6 +5275,12 @@ // width. if (MemVT.getSizeInBits() > Subtarget.getPreferVectorWidth()) return false; + + // Don't merge to x86 amx tile, as we only map MVT::v256i32 + // to x86 amx tile on amx intrinsics. + if (MemVT == MVT::v256i32) + return false; + return true; } diff --git a/llvm/lib/Target/X86/X86InstrAMX.td b/llvm/lib/Target/X86/X86InstrAMX.td --- a/llvm/lib/Target/X86/X86InstrAMX.td +++ b/llvm/lib/Target/X86/X86InstrAMX.td @@ -23,6 +23,7 @@ def STTILECFG : I <0x49, MRM0m, (outs), (ins opaquemem:$src), "sttilecfg\t$src", [(int_x86_sttilecfg addr:$src)]>, VEX, T8PD; + let mayLoad = 1 in def TILELOADD : I<0x4b, MRMSrcMemFSIB, (outs TILE:$dst), (ins sibmem:$src), "tileloadd\t{$src, $dst|$dst, $src}", []>, @@ -34,6 +35,7 @@ let Defs = [TMM0,TMM1,TMM2,TMM3,TMM4,TMM5,TMM6,TMM7] in def TILERELEASE : I<0x49, MRM_C0, (outs), (ins), "tilerelease", [(int_x86_tilerelease)]>, VEX, T8PS; + let mayStore = 1 in def TILESTORED : I<0x4b, MRMDestMemFSIB, (outs), (ins sibmem:$dst, TILE:$src), "tilestored\t{$src, $dst|$dst, $src}", []>, @@ -42,6 +44,11 @@ "tilezero\t$dst", []>, VEX, T8XD; + def PTILESTOREDV : PseudoI<(outs), (ins GR16:$src1, + GR16:$src2, opaquemem:$src3, TILE:$src4), []>; + def PTILELOADDV : PseudoI<(outs TILE: $dst), (ins GR16:$src1, + GR16:$src2, opaquemem:$src3), []>; + let usesCustomInserter = 1 in { // Pseudo instructions, using immediates instead of tile registers. // To be translated to the actual instructions in X86ISelLowering.cpp @@ -76,6 +83,11 @@ VEX_4V, T8PS; } + let Constraints = "$src4 = $dst" in + def PTDPBSSDV : PseudoI<(outs TILE: $dst), (ins GR16:$src1, + GR16:$src2, GR16:$src3, TILE:$src4, + TILE:$src5, TILE:$src6), []>; + let usesCustomInserter = 1 in { // Pseudo instructions, using immediates instead of tile registers. 
// To be translated to the actual instructions in X86ISelLowering.cpp diff --git a/llvm/lib/Target/X86/X86InstrInfo.cpp b/llvm/lib/Target/X86/X86InstrInfo.cpp --- a/llvm/lib/Target/X86/X86InstrInfo.cpp +++ b/llvm/lib/Target/X86/X86InstrInfo.cpp @@ -3758,13 +3758,27 @@ const MachineFunction &MF = *MBB.getParent(); assert(MF.getFrameInfo().getObjectSize(FrameIdx) >= TRI->getSpillSize(*RC) && "Stack slot too small for store"); - unsigned Alignment = std::max(TRI->getSpillSize(*RC), 16); - bool isAligned = - (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) || - RI.canRealignStack(MF); - unsigned Opc = getStoreRegOpcode(SrcReg, RC, isAligned, Subtarget); - addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx) - .addReg(SrcReg, getKillRegState(isKill)); + if (RC->getID() != X86::TILERegClassID) { + unsigned Alignment = std::max(TRI->getSpillSize(*RC), 16); + bool isAligned = + (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) || + RI.canRealignStack(MF); + unsigned Opc = getStoreRegOpcode(SrcReg, RC, isAligned, Subtarget); + addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx) + .addReg(SrcReg, getKillRegState(isKill)); + } else { + unsigned Opc = X86::TILESTORED; + // tilestored %tmm, (%sp, %idx) + MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo(); + Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass); + MachineInstr *NewMI = + BuildMI(MBB, MI, DebugLoc(), get(X86::MOV64ri), VirtReg).addImm(64); + NewMI = addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx) + .addReg(SrcReg, getKillRegState(isKill)); + MachineOperand &MO = NewMI->getOperand(2); + MO.setReg(VirtReg); + MO.setIsKill(true); + } } void X86InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB, @@ -3772,13 +3786,28 @@ Register DestReg, int FrameIdx, const TargetRegisterClass *RC, const TargetRegisterInfo *TRI) const { - const MachineFunction &MF = *MBB.getParent(); - unsigned Alignment = std::max(TRI->getSpillSize(*RC), 16); - bool isAligned = - (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) || - RI.canRealignStack(MF); - unsigned Opc = getLoadRegOpcode(DestReg, RC, isAligned, Subtarget); - addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc), DestReg), FrameIdx); + if (RC->getID() != X86::TILERegClassID) { + const MachineFunction &MF = *MBB.getParent(); + unsigned Alignment = std::max(TRI->getSpillSize(*RC), 16); + bool isAligned = + (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) || + RI.canRealignStack(MF); + unsigned Opc = getLoadRegOpcode(DestReg, RC, isAligned, Subtarget); + addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc), DestReg), + FrameIdx); + } else { + unsigned Opc = X86::TILELOADD; + // tileloadd (%sp, %idx), %tmm + MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo(); + Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass); + MachineInstr *NewMI = + BuildMI(MBB, MI, DebugLoc(), get(X86::MOV64ri), VirtReg).addImm(64); + NewMI = addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc), DestReg), + FrameIdx); + MachineOperand &MO = NewMI->getOperand(3); + MO.setReg(VirtReg); + MO.setIsKill(true); + } } bool X86InstrInfo::analyzeCompare(const MachineInstr &MI, Register &SrcReg, diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -0,0 +1,277 @@ +#include "X86.h" +#include "llvm/ADT/DenseSet.h" +#include 
"llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/CodeGen/ValueTypes.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IntrinsicsX86.h" +#include "llvm/InitializePasses.h" +#include "llvm/Pass.h" + +using namespace llvm; + +#define DEBUG_TYPE "lower-amx-type" + +namespace { +class X86LowerAMXType { + Function &Func; + const DataLayout &DL; + DenseSet LDSet; + DenseSet STSet; + DenseMap> LoadMap; + +public: + X86LowerAMXType(Function &F) : Func(F), DL(F.getParent()->getDataLayout()) {} + bool Visit(); + bool VisitLD(); + bool VisitST(); + void SplitST(Instruction *Inst); + void SplitLD(Instruction *Inst); +}; + +// Split v256i32 load/store to 2 v128i32, so that ISel can +// lower it to proper vector size. +void X86LowerAMXType::SplitST(Instruction *Inst) { + StoreInst *ST = dyn_cast(Inst); + IRBuilder<> Builder(ST); + LLVMContext &Ctx = Builder.getContext(); + Type *Ty = ST->getValueOperand()->getType(); + EVT VT = EVT::getEVT(Ty); + EVT HalfVT = VT.getHalfNumVectorElementsVT(Ctx); + Type *HalfTy = HalfVT.getTypeForEVT(Ctx); + + LoadInst *Lo, *Hi; + std::tie(Lo, Hi) = LoadMap[ST->getValueOperand()]; + Value *Ptr = ST->getPointerOperand(); + PointerType *HalfPtrTy = HalfTy->getPointerTo(ST->getPointerAddressSpace()); + Value *HalfPtr = Builder.CreateBitCast(Ptr, HalfPtrTy); + // The HW require the alignment for AMX tile is 64, but front-end generate + // code for the vector alignment which is the vector size. + TypeSize HalfTySize = HalfTy->getPrimitiveSizeInBits() / 8; + Align Alignment = std::min(Lo->getAlign(), Align(HalfTySize)); + Builder.CreateAlignedStore(Lo, HalfPtr, Alignment, ST->isVolatile()); + + HalfPtr = Builder.CreateGEP(HalfTy, HalfPtr, Builder.getInt32(1)); + Builder.CreateAlignedStore(Hi, HalfPtr, Alignment, ST->isVolatile()); +} + +bool X86LowerAMXType::VisitST() { + if (STSet.empty()) + return false; + for (auto *Inst : STSet) { + Value *Row, *Col; + const IntrinsicInst *II = dyn_cast(Inst->getOperand(0)); + if (!II) + Row = Col = nullptr; + else { + switch (II->getIntrinsicID()) { + default: + Row = Col = nullptr; + break; + case Intrinsic::x86_tileloadd64_internal: + case Intrinsic::x86_tdpbssd_internal: { + Row = II->getArgOperand(0); + Col = II->getArgOperand(1); + break; + } + } + } + if (!Row) { + SplitST(Inst); + continue; + } + IRBuilder<> Builder(Inst); + LLVMContext &Ctx = Builder.getContext(); + // Use the maximun column as stride. It must be the same with load stride. 
+ Value *Stride = Builder.getInt64(64); + Value *I8Ptr = + Builder.CreateBitCast(Inst->getOperand(1), Type::getInt8PtrTy(Ctx)); + std::array Args = {Row, Col, I8Ptr, Stride, + Inst->getOperand(0)}; + + Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); + } + return true; +} + +void X86LowerAMXType::SplitLD(Instruction *Inst) { + LoadInst *LD = dyn_cast(Inst); + IRBuilder<> Builder(LD); + LLVMContext &Ctx = Builder.getContext(); + Type *Ty = LD->getType(); + EVT VT = EVT::getEVT(Ty); + EVT HalfVT = VT.getHalfNumVectorElementsVT(Ctx); + Type *HalfTy = HalfVT.getTypeForEVT(Ctx); + + Value *Ptr = LD->getPointerOperand(); + PointerType *HalfPtrTy = HalfTy->getPointerTo(LD->getPointerAddressSpace()); + Value *HalfPtr = Builder.CreateBitCast(Ptr, HalfPtrTy); + // The HW require the alignment for AMX tile is 64, but front-end generate + // code for the vector alignment which is the vector size. + TypeSize HalfTySize = HalfTy->getPrimitiveSizeInBits() / 8; + Align Alignment = std::min(LD->getAlign(), Align(HalfTySize)); + auto *Lo = Builder.CreateAlignedLoad(HalfTy, HalfPtr, Alignment, + LD->isVolatile()); + + HalfPtr = Builder.CreateGEP(HalfTy, HalfPtr, Builder.getInt32(1)); + auto *Hi = Builder.CreateAlignedLoad(HalfTy, HalfPtr, Alignment, + LD->isVolatile()); + + LoadMap[Inst] = std::make_pair(Lo, Hi); +} + +bool X86LowerAMXType::VisitLD() { + if (LDSet.empty()) + return false; + for (auto &Inst : LDSet) { + int Count = 0; + Value *NewInst = nullptr; + // The user should be all AMX intrinsics or all LLVM instruction. + // Don't support it is used by both AMX intrinsics and LLVM instructions. + for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) { + Use &U = *I++; + const IntrinsicInst *II = dyn_cast(U.getUser()); + if (!II) { + Count++; + continue; + } + if (NewInst) + continue; + Value *Row, *Col; + switch (II->getIntrinsicID()) { + default: + report_fatal_error("Non-AMX intrinsic use tile type."); + break; + case Intrinsic::x86_tdpbssd_internal: { + unsigned OpNo = U.getOperandNo(); + switch (OpNo) { + case 3: + Row = II->getArgOperand(0); + Col = II->getArgOperand(1); + break; + case 4: + Row = II->getArgOperand(0); + Col = II->getArgOperand(2); + break; + case 5: + Row = II->getArgOperand(2); + Col = II->getArgOperand(1); + break; + } + break; + } + case Intrinsic::x86_tilestored64_internal: { + Row = II->getArgOperand(0); + Col = II->getArgOperand(1); + break; + } + } + assert(Count == 0 && "Can NOT mix amx intrinsic and LLVM instruction"); + // FIXME: The shape def should be ahead of load. + IRBuilder<> Builder(Inst); + LLVMContext &Ctx = Builder.getContext(); + // Use the maximun column as stride. + Value *Stride = Builder.getInt64(64); + Value *I8Ptr = + Builder.CreateBitCast(Inst->getOperand(0), Type::getInt8PtrTy(Ctx)); + std::array Args = {Row, Col, I8Ptr, Stride}; + + NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, + None, Args); + + Inst->replaceAllUsesWith(NewInst); + } + if (!NewInst) + SplitLD(Inst); + } + return true; +} + +bool X86LowerAMXType::Visit() { + bool C; + auto IsAMXType = [](FixedVectorType *VTy) { + if (!VTy) + return false; + if (!VTy->getScalarType()->isIntegerTy(32)) + return false; + if (VTy->getNumElements() != 256) + return false; + + return true; + }; + + for (BasicBlock &BB : Func) { + for (Instruction &Inst : BB) { + LoadInst *LD = dyn_cast(&Inst); + // Check load instruction. 
+ // %3 = load <256 x i32>, <256 x i32>* %1, align 64 + if (LD) { + FixedVectorType *VTy = dyn_cast(Inst.getType()); + if (!IsAMXType(VTy)) + continue; + LDSet.insert(&Inst); + continue; + } + // Check store instruction. + // store <256 x i32> %3, <256 x i32>* %2, align 64 + StoreInst *ST = dyn_cast(&Inst); + if (!ST) + continue; + FixedVectorType *VTy = + dyn_cast(ST->getOperand(0)->getType()); + if (!IsAMXType(VTy)) + continue; + STSet.insert(&Inst); + } + } + + C = VisitLD() | VisitST(); + for (auto *Inst : STSet) + Inst->eraseFromParent(); + for (auto *Inst : LDSet) + Inst->eraseFromParent(); + return C; +} +} // anonymous namespace + +namespace { + +class X86LowerAMXTypeLegacyPass : public FunctionPass { +public: + static char ID; + + X86LowerAMXTypeLegacyPass() : FunctionPass(ID) { + initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + X86LowerAMXType LAT(F); + bool C = LAT.Visit(); + return C; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + } + +private: + Function *F; +}; + +} // anonymous namespace + +static const char pass_name[] = "Lower AMX type for load/store"; +char X86LowerAMXTypeLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, pass_name, false, + false) +INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, pass_name, false, + false) + +FunctionPass *llvm::createX86LowerAMXTypePass() { + return new X86LowerAMXTypeLegacyPass(); +} diff --git a/llvm/lib/Target/X86/X86MachineFunctionInfo.h b/llvm/lib/Target/X86/X86MachineFunctionInfo.h --- a/llvm/lib/Target/X86/X86MachineFunctionInfo.h +++ b/llvm/lib/Target/X86/X86MachineFunctionInfo.h @@ -17,6 +17,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/CodeGen/CallingConvLower.h" #include "llvm/CodeGen/MachineFunction.h" +#include "llvm/CodeGen/TileShapeInfo.h" namespace llvm { diff --git a/llvm/lib/Target/X86/X86RegisterInfo.h b/llvm/lib/Target/X86/X86RegisterInfo.h --- a/llvm/lib/Target/X86/X86RegisterInfo.h +++ b/llvm/lib/Target/X86/X86RegisterInfo.h @@ -141,6 +141,11 @@ Register getFramePtr() const { return FramePtr; } // FIXME: Move to FrameInfok unsigned getSlotSize() const { return SlotSize; } + + bool getRegAllocationHints(Register VirtReg, ArrayRef Order, + SmallVectorImpl &Hints, + const MachineFunction &MF, const VirtRegMap *VRM, + const LiveRegMatrix *Matrix) const override; }; } // End llvm namespace diff --git a/llvm/lib/Target/X86/X86RegisterInfo.cpp b/llvm/lib/Target/X86/X86RegisterInfo.cpp --- a/llvm/lib/Target/X86/X86RegisterInfo.cpp +++ b/llvm/lib/Target/X86/X86RegisterInfo.cpp @@ -18,6 +18,8 @@ #include "X86Subtarget.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/CodeGen/LiveRegMatrix.h" #include "llvm/CodeGen/MachineFrameInfo.h" #include "llvm/CodeGen/MachineFunction.h" #include "llvm/CodeGen/MachineFunctionPass.h" @@ -812,3 +814,77 @@ StackReg = getX86SubSuperRegister(StackReg, 32); return StackReg; } + +static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM, + const MachineRegisterInfo *MRI) { + if (VRM->hasShape(VirtReg)) + return VRM->getShape(VirtReg); + + const MachineOperand &Def = *MRI->def_begin(VirtReg); + MachineInstr *MI = const_cast(Def.getParent()); + unsigned OpCode = MI->getOpcode(); + switch (OpCode) { + default: + llvm_unreachable("Unexpected machine instruction on tile register!"); + break; + // We only collect the tile shape that is defined. 
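+  // Both pseudos carry the GR16 row and column registers as machine operands
+  // 1 and 2 (operand 0 is the tile def), so those two operands give the shape.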
+ case X86::PTILELOADDV: + case X86::PTDPBSSDV: + MachineOperand &MO1 = MI->getOperand(1); + MachineOperand &MO2 = MI->getOperand(2); + ShapeT Shape(&MO1, &MO2, MRI); + VRM->assignVirt2Shape(VirtReg, Shape); + return Shape; + } +} + +bool X86RegisterInfo::getRegAllocationHints(Register VirtReg, + ArrayRef Order, + SmallVectorImpl &Hints, + const MachineFunction &MF, + const VirtRegMap *VRM, + const LiveRegMatrix *Matrix) const { + const MachineRegisterInfo *MRI = &MF.getRegInfo(); + const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg); + bool BaseImplRetVal = TargetRegisterInfo::getRegAllocationHints( + VirtReg, Order, Hints, MF, VRM, Matrix); + + if (RC.getID() != X86::TILERegClassID) + return BaseImplRetVal; + + ShapeT VirtShape = getTileShape(VirtReg, const_cast(VRM), MRI); + auto AddHint = [&](MCPhysReg PhysReg) { + Register VReg = Matrix->getOneVReg(PhysReg); + if (VReg == MCRegister::NoRegister) { // Not allocated yet + Hints.push_back(PhysReg); + return; + } + ShapeT PhysShape = getTileShape(VReg, const_cast(VRM), MRI); + if (PhysShape == VirtShape) + Hints.push_back(PhysReg); + }; + + SmallSet CopyHints; + CopyHints.insert(Hints.begin(), Hints.end()); + Hints.clear(); + for (auto Hint : CopyHints) { + if (RC.contains(Hint) && !MRI->isReserved(Hint)) + AddHint(Hint); + } + for (MCPhysReg PhysReg : Order) { + if (!MRI->isReserved(PhysReg)) + AddHint(PhysReg); + } + +#define DEBUG_TYPE "tile-hint" + LLVM_DEBUG({ + dbgs() << "Hints for virtual register " << format_hex(VirtReg, 8) << "\n"; + for (auto Hint : Hints) { + dbgs() << "tmm" << Hint << ","; + } + dbgs() << "\n"; + }); +#undef DEBUG_TYPE + + return true; +} diff --git a/llvm/lib/Target/X86/X86RegisterInfo.td b/llvm/lib/Target/X86/X86RegisterInfo.td --- a/llvm/lib/Target/X86/X86RegisterInfo.td +++ b/llvm/lib/Target/X86/X86RegisterInfo.td @@ -633,6 +633,6 @@ def BNDR : RegisterClass<"X86", [v2i64], 128, (sequence "BND%u", 0, 3)>; // Tiles -let isAllocatable = 0 in -def TILE : RegisterClass<"X86", [untyped], 0, +let CopyCost = -1 in // Don't allow copy tile register +def TILE : RegisterClass<"X86", [v256i32], 8192, (sequence "TMM%u", 0, 7)> {let Size = 8192;} diff --git a/llvm/lib/Target/X86/X86Subtarget.h b/llvm/lib/Target/X86/X86Subtarget.h --- a/llvm/lib/Target/X86/X86Subtarget.h +++ b/llvm/lib/Target/X86/X86Subtarget.h @@ -457,6 +457,8 @@ /// entry to the function and which must be maintained by every function. Align stackAlignment = Align(4); + Align tileConfigAlignment = Align(4); + /// Max. memset / memcpy size that is turned into rep/movs, rep/stos ops. /// // FIXME: this is a known good value for Yonah. How about others? @@ -540,6 +542,9 @@ return &getInstrInfo()->getRegisterInfo(); } + unsigned getTileConfigSize() const { return 64; } + Align getTileConfigAlignment() const { return tileConfigAlignment; } + /// Returns the minimum alignment known to hold of the /// stack frame on entry to the function and which must be maintained by every /// function for this subtarget. 
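For reference, getTileConfigSize() above is the size of the in-memory block that ldtilecfg consumes and that X86TileConfig fills in below. The sketch here is not part of the patch; it spells out the "palette 1" tile-configuration layout that the RowOffset = 48 + Index and ColOffset = 16 + Index * 2 stores in X86TileConfig::tileConfig() assume.

#include <stdint.h>

// Illustrative layout of the 64-byte tile configuration block ("palette 1").
struct TileConfigBlock {
  uint8_t Palette;        // byte 0
  uint8_t StartRow;       // byte 1
  uint8_t Reserved0[14];  // bytes 2-15
  uint16_t ColsB[8];      // bytes 16-31: bytes per row of tmm0..tmm7 (ColOffset = 16 + 2 * i)
  uint8_t Reserved1[16];  // bytes 32-47
  uint8_t Rows[8];        // bytes 48-55: row count of tmm0..tmm7 (RowOffset = 48 + i)
  uint8_t Reserved2[8];   // bytes 56-63
};
static_assert(sizeof(TileConfigBlock) == 64, "ldtilecfg reads a 64-byte block");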
diff --git a/llvm/lib/Target/X86/X86TargetMachine.cpp b/llvm/lib/Target/X86/X86TargetMachine.cpp --- a/llvm/lib/Target/X86/X86TargetMachine.cpp +++ b/llvm/lib/Target/X86/X86TargetMachine.cpp @@ -62,6 +62,7 @@ RegisterTargetMachine Y(getTheX86_64Target()); PassRegistry &PR = *PassRegistry::getPassRegistry(); + initializeX86LowerAMXTypeLegacyPassPass(PR); initializeGlobalISel(PR); initializeWinEHStatePassPass(PR); initializeFixupBWInstPassPass(PR); @@ -71,6 +72,7 @@ initializeX86FixupSetCCPassPass(PR); initializeX86CallFrameOptimizationPass(PR); initializeX86CmovConverterPassPass(PR); + initializeX86TileConfigPass(PR); initializeX86ExpandPseudoPass(PR); initializeX86ExecutionDomainFixPass(PR); initializeX86DomainReassignmentPass(PR); @@ -378,6 +380,7 @@ void addPreEmitPass() override; void addPreEmitPass2() override; void addPreSched2() override; + bool addPreRewrite() override; std::unique_ptr getCSEConfig() const override; }; @@ -406,6 +409,7 @@ void X86PassConfig::addIRPasses() { addPass(createAtomicExpandPass()); + addPass(createX86LowerAMXTypePass()); TargetPassConfig::addIRPasses(); @@ -564,6 +568,11 @@ addPass(createX86LoadValueInjectionRetHardeningPass()); } +bool X86PassConfig::addPreRewrite() { + addPass(createX86TileConfigPass()); + return true; +} + std::unique_ptr X86PassConfig::getCSEConfig() const { return getStandardCSEConfigForOpt(TM->getOptLevel()); } diff --git a/llvm/lib/Target/X86/X86TileConfig.cpp b/llvm/lib/Target/X86/X86TileConfig.cpp new file mode 100644 --- /dev/null +++ b/llvm/lib/Target/X86/X86TileConfig.cpp @@ -0,0 +1,292 @@ +//===-- X86TileConfig.cpp - Tile Register Configure----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// + +#include "X86.h" +#include "X86InstrBuilder.h" +#include "X86MachineFunctionInfo.h" +#include "X86RegisterInfo.h" +#include "X86Subtarget.h" +#include "llvm/CodeGen/LiveIntervals.h" +#include "llvm/CodeGen/MachineDominators.h" +#include "llvm/CodeGen/MachineFrameInfo.h" +#include "llvm/CodeGen/MachineFunctionPass.h" +#include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/MachineRegisterInfo.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/CodeGen/TargetInstrInfo.h" +#include "llvm/CodeGen/TargetRegisterInfo.h" +#include "llvm/CodeGen/TileShapeInfo.h" +#include "llvm/CodeGen/VirtRegMap.h" +#include "llvm/InitializePasses.h" + +using namespace llvm; + +#define DEBUG_TYPE "tile-config" + +namespace { + +class X86TileConfig : public MachineFunctionPass { + // context + MachineFunction *MF = nullptr; + const X86Subtarget *ST = nullptr; + const TargetRegisterInfo *TRI; + const TargetInstrInfo *TII; + MachineDominatorTree *DomTree = nullptr; + MachineRegisterInfo *MRI = nullptr; + VirtRegMap *VRM = nullptr; + LiveIntervals *LIS = nullptr; + + MachineInstr &getTileConfigPoint(); + void tileConfig(); + +public: + X86TileConfig() : MachineFunctionPass(ID) {} + + /// Return the pass name. + StringRef getPassName() const override { return "Tile Register Configure"; } + + /// X86TileConfig analysis usage. + void getAnalysisUsage(AnalysisUsage &AU) const override; + + /// Perform register allocation. 
+ bool runOnMachineFunction(MachineFunction &mf) override; + + MachineFunctionProperties getRequiredProperties() const override { + return MachineFunctionProperties().set( + MachineFunctionProperties::Property::NoPHIs); + } + + static char ID; +}; + +} // end anonymous namespace + +char X86TileConfig::ID = 0; + +INITIALIZE_PASS_BEGIN(X86TileConfig, "tileconfig", "Tile Register Configure", + false, false) +INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree) +INITIALIZE_PASS_DEPENDENCY(VirtRegMap) +INITIALIZE_PASS_END(X86TileConfig, "tileconfig", "Tile Register Configure", + false, false) + +void X86TileConfig::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired(); + AU.addRequired(); + AU.addPreserved(); + AU.addRequired(); + AU.setPreservesAll(); + MachineFunctionPass::getAnalysisUsage(AU); +} + +static unsigned getTilePhysRegIndex(Register PhysReg) { + assert((PhysReg >= X86::TMM0 && X86::TMM0 <= X86::TMM7) && + "Tile register number is invalid"); + return (PhysReg - X86::TMM0); +} + +static MachineInstr *buildConfigMI(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MI, int FrameIdx, + const TargetInstrInfo *TII) { + return addFrameReference( + BuildMI(MBB, MI, DebugLoc(), TII->get(X86::LDTILECFG)), FrameIdx); +} + +static MachineInstr * +storeRegToStackSlot(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI, + Register SrcReg, unsigned BitSize, int FrameIdx, int Offset, + const TargetInstrInfo *TII, const TargetRegisterClass *RC, + const TargetRegisterInfo *TRI) { + + unsigned SubIdx = (BitSize == 8) ? X86::sub_8bit : X86::sub_16bit; + unsigned Opc = (BitSize == 8) ? X86::MOV8mr : X86::MOV16mr; + MachineInstr *NewMI = + addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)), FrameIdx, + Offset) + .addReg(SrcReg); + MachineOperand &MO = NewMI->getOperand(5); + if (BitSize < TRI->getRegSizeInBits(*RC)) + MO.setSubReg(SubIdx); + return NewMI; +} + +static MachineInstr *storeImmToStackSlot(MachineBasicBlock &MBB, + MachineBasicBlock::iterator MI, + int64_t Imm, unsigned BitSize, + int FrameIdx, int Offset, + const TargetInstrInfo *TII) { + unsigned Opc = (BitSize == 8) ? X86::MOV8mi : X86::MOV16mi; + return addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)), + FrameIdx, Offset) + .addImm(Imm); +} + +MachineInstr &X86TileConfig::getTileConfigPoint() { + DenseMap PhysShapeInfo; + MachineBasicBlock *MBB = nullptr; + DenseSet MIs; + for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) { + unsigned VirtReg = Register::index2VirtReg(i); + if (MRI->reg_nodbg_empty(VirtReg)) + continue; + const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg); + if (RC.getID() != X86::TILERegClassID) + continue; + + // FIXME: The region split should be done before the Greedy RA. + // Here we assume only one config for all tile registers. + // + // Find the common dominator for all MI that define tile register. + for (const MachineOperand &MO : MRI->def_operands(VirtReg)) { + if (MO.isUndef()) + continue; + const auto *MI = MO.getParent(); + if (!MBB) + MBB = const_cast(MI->getParent()); + MBB = DomTree->findNearestCommonDominator( + MBB, const_cast(MI->getParent())); + } + // Collect the instructions that define shape. 
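+    // The shape of each tile virtual register was recorded in the VirtRegMap
+    // by getTileShape() in X86RegisterInfo.cpp while allocation hints were
+    // computed, so it can simply be looked up here.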
+ ShapeT Shape = VRM->getShape(VirtReg); + std::array ShapeMOs = {Shape.getRow(), Shape.getCol()}; + for (auto *ShapeMO : ShapeMOs) { + Register ShapeReg = ShapeMO->getReg(); + for (const MachineOperand &MO : MRI->def_operands(ShapeReg)) { + const auto *ShapeMI = MO.getParent(); + MIs.insert(ShapeMI); + } + } + +#if !defined(NDEBUG) + Register PhysReg = VRM->getPhys(VirtReg); + if (PhysShapeInfo.count(PhysReg)) + assert(PhysShapeInfo[PhysReg] == VRM->getShape(VirtReg) && + "The physical register is assigned to virtual registers" + "with different shape"); +#endif + } + // Shape def should dominate tile config MBB. + // TODO: Improve for shape that is immediate. + for (const auto *MI : MIs) { + const MachineBasicBlock *ShapeMBB = MI->getParent(); + if (DomTree->dominates(ShapeMBB, MBB)) + continue; + if (MI->isMoveImmediate()) + continue; + report_fatal_error("Failed to config tile register, " + "please define the shape earlier"); + } + + // ldtilecfg should be inserted after the MI that define the shape. + MachineBasicBlock::reverse_instr_iterator I, E; + for (I = MBB->instr_rbegin(), E = MBB->instr_rend(); I != E; ++I) { + auto *MI = &*I; + if (MIs.count(MI) && (!MI->isMoveImmediate())) + break; + } + MachineBasicBlock::iterator MII; + if (I == E) + MII = MBB->getFirstNonPHI(); + else { + MII = MachineBasicBlock::iterator(&*I); + MII++; + } + return *MII; +} + +void X86TileConfig::tileConfig() { + MachineInstr &MI = getTileConfigPoint(); + MachineBasicBlock *MBB = MI.getParent(); + // Allocate stack buffer to config + unsigned Size = ST->getTileConfigSize(); + Align Alignment = ST->getTileConfigAlignment(); + + int SS = MF->getFrameInfo().CreateStackObject(Size, Alignment, false); + BitVector PhysRegs(TRI->getNumRegs()); + + // Insert ldtilecfg to the MBB + for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) { + unsigned VirtReg = Register::index2VirtReg(i); + if (MRI->reg_nodbg_empty(VirtReg)) + continue; + const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg); + if (RC.getID() != X86::TILERegClassID) + continue; + Register PhysReg = VRM->getPhys(VirtReg); + if (PhysRegs.test(PhysReg)) + continue; + PhysRegs.set(PhysReg); + ShapeT Shape = VRM->getShape(VirtReg); + Register RowReg = Shape.getRow()->getReg(); + Register ColReg = Shape.getCol()->getReg(); + + unsigned Index = getTilePhysRegIndex(PhysReg); + int RowOffset = 48 + Index; + int ColOffset = 16 + Index * 2; + + unsigned BitSize = 8; + for (const auto &Pair : {std::make_pair(RowReg, RowOffset), + std::make_pair(ColReg, ColOffset)}) { + int64_t Imm; + int ImmCount = 0; + // All def must be the same value, otherwise it is invalid MIs. + // Immediate is prefered. + for (const MachineOperand &MO : MRI->def_operands(Pair.first)) { + const auto *Inst = MO.getParent(); + if (Inst->isMoveImmediate()) { + ImmCount++; + Imm = Inst->getOperand(1).getImm(); + break; + } + } + auto StoreConfig = [&](int Offset) { + MachineInstr *NewMI = nullptr; + if (ImmCount) + NewMI = storeImmToStackSlot(*MBB, MI, Imm, BitSize, SS, Offset, TII); + else { + const TargetRegisterClass *RC = MRI->getRegClass(Pair.first); + NewMI = storeRegToStackSlot(*MBB, MI, Pair.first, BitSize, SS, Offset, + TII, RC, TRI); + } + SlotIndex SIdx = LIS->InsertMachineInstrInMaps(*NewMI); + if (!ImmCount) { + // Extend the live interval. 
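+          // The shape register is read by the store just inserted, which may
+          // be later than its last original use, so its live range must be
+          // extended to reach that slot.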
+ SmallVector EndPoints = {SIdx.getRegSlot()}; + LiveInterval &Int = LIS->getInterval(Pair.first); + LIS->extendToIndices(Int, EndPoints); + } + }; + StoreConfig(Pair.second); + BitSize += 8; + } + } + MachineInstr *NewMI = buildConfigMI(*MBB, MI, SS, TII); + LIS->InsertMachineInstrInMaps(*NewMI); +} + +bool X86TileConfig::runOnMachineFunction(MachineFunction &mf) { + LLVM_DEBUG(dbgs() << "********** TILE REGISTER CONFIGURE**********\n" + << "********** Function: " << mf.getName() << '\n'); + MF = &mf; + MRI = &mf.getRegInfo(); + ST = &mf.getSubtarget(); + TRI = ST->getRegisterInfo(); + TII = mf.getSubtarget().getInstrInfo(); + DomTree = &getAnalysis(); + VRM = &getAnalysis(); + LIS = &getAnalysis(); + + if (VRM->isShapeMapEmpty()) + return false; + + tileConfig(); + return true; +} + +FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); } diff --git a/llvm/test/CodeGen/X86/AMX/amx-config.ll b/llvm/test/CodeGen/X86/AMX/amx-config.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/X86/AMX/amx-config.ll @@ -0,0 +1,72 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-int8 -verify-machineinstrs | FileCheck %s + +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@buf = dso_local global [1024 x i8] zeroinitializer, align 16 +@buf2 = dso_local global [1024 x i8] zeroinitializer, align 16 + +; Function Attrs: nounwind uwtable +define dso_local void @test_api(i32 %0, i16 signext %1, i16 signext %2) local_unnamed_addr #2 { +; CHECK-LABEL: test_api: +; CHECK: # %bb.0: +; CHECK-NEXT: movsbl %sil, %eax +; CHECK-NEXT: movb %al, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movw %si, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movb %al, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movw %dx, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movb %al, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: movw %dx, -{{[0-9]+}}(%rsp) +; CHECK-NEXT: ldtilecfg -{{[0-9]+}}(%rsp) +; CHECK-NEXT: testl %edi, %edi +; CHECK-NEXT: je .LBB0_2 +; CHECK-NEXT: # %bb.1: +; CHECK-NEXT: movl $buf, %ecx +; CHECK-NEXT: jmp .LBB0_3 +; CHECK-NEXT: .LBB0_2: +; CHECK-NEXT: movl $buf2, %ecx +; CHECK-NEXT: .LBB0_3: +; CHECK-NEXT: movl $32, %edi +; CHECK-NEXT: tileloadd (%rcx,%rdi), %tmm0 +; CHECK-NEXT: tileloadd (%rcx,%rdi), %tmm2 +; CHECK-NEXT: tileloadd (%rcx,%rdi), %tmm1 +; CHECK-NEXT: tdpbssd %tmm2, %tmm0, %tmm1 +; CHECK-NEXT: movl $buf, %ecx +; CHECK-NEXT: movl $32, %esi +; CHECK-NEXT: tilestored %tmm1, (%rcx,%rsi) +; CHECK-NEXT: retq + %4 = icmp eq i32 %0, 0 + %5 = shl i16 %1, 8 + %6 = ashr exact i16 %5, 8 + br i1 %4, label %11, label %7 + +7: ; preds = %3 + %8 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %9 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + br label %15 + +11: ; preds = %3 + %12 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 + %13 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 
32) #3 + %14 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 + br label %15 + +15: ; preds = %11, %7 + %16 = phi <256 x i32> [ %12, %11 ], [ %8, %7 ] + %17 = phi <256 x i32> [ %13, %11 ], [ %9, %7 ] + %18 = phi <256 x i32> [ %14, %11 ], [ %10, %7 ] + %19 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %6, i16 %2, i16 %1, <256 x i32> %18, <256 x i32> %16, <256 x i32> %17) #3 + tail call void @llvm.x86.tilestored64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, <256 x i32> %19) #3 + ret void +} + +declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3 + +declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3 + +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3 + +attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+avx,+avx2,+avx512f,+cx8,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #3 = { nounwind } diff --git a/llvm/test/CodeGen/X86/AMX/amx-spill.ll b/llvm/test/CodeGen/X86/AMX/amx-spill.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/X86/AMX/amx-spill.ll @@ -0,0 +1,107 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py +; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-int8 -verify-machineinstrs | FileCheck %s + +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@buf = dso_local global [1024 x i8] zeroinitializer, align 16 +@buf2 = dso_local global [1024 x i8] zeroinitializer, align 16 + +define dso_local void @test_api(i32 %0, i16 signext %1, i16 signext %2) local_unnamed_addr #2 { +; CHECK-LABEL: test_api: +; CHECK: # %bb.0: +; CHECK-NEXT: subq $2936, %rsp # imm = 0xB78 +; CHECK-NEXT: .cfi_def_cfa_offset 2944 +; CHECK-NEXT: movb %dl, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw %dx, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb %dl, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw %dx, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb %sil, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw %dx, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb %sil, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw %dx, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb %dl, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw %dx, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb %dl, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw %dx, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb %sil, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw %si, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movb %sil, {{[0-9]+}}(%rsp) +; CHECK-NEXT: movw %dx, {{[0-9]+}}(%rsp) +; CHECK-NEXT: ldtilecfg {{[0-9]+}}(%rsp) +; CHECK-NEXT: movl $buf, %r8d +; CHECK-NEXT: movl $32, %eax +; CHECK-NEXT: tileloadd (%r8,%rax), %tmm1 +; CHECK-NEXT: tileloadd (%r8,%rax), %tmm1 +; CHECK-NEXT: movabsq $64, %rcx +; CHECK-NEXT: tilestored %tmm1, 896(%rsp,%rcx) # 1024-byte Folded Spill +; CHECK-NEXT: tileloadd (%r8,%rax), %tmm3 +; CHECK-NEXT: tileloadd (%r8,%rax), %tmm4 +; CHECK-NEXT: tileloadd (%r8,%rax), 
%tmm2 +; CHECK-NEXT: tileloadd (%r8,%rax), %tmm5 +; CHECK-NEXT: tileloadd (%r8,%rax), %tmm0 +; CHECK-NEXT: testl %edi, %edi +; CHECK-NEXT: je .LBB0_2 +; CHECK-NEXT: # %bb.1: +; CHECK-NEXT: tileloadd (%r8,%rax), %tmm6 +; CHECK-NEXT: tileloadd (%r8,%rax), %tmm7 +; CHECK-NEXT: tileloadd (%r8,%rax), %tmm1 +; CHECK-NEXT: jmp .LBB0_3 +; CHECK-NEXT: .LBB0_2: +; CHECK-NEXT: movl $buf2, %ecx +; CHECK-NEXT: tileloadd (%rcx,%rax), %tmm6 +; CHECK-NEXT: tileloadd (%rcx,%rax), %tmm7 +; CHECK-NEXT: tileloadd (%rcx,%rax), %tmm1 +; CHECK-NEXT: .LBB0_3: +; CHECK-NEXT: tdpbssd %tmm7, %tmm6, %tmm1 +; CHECK-NEXT: movabsq $64, %rax +; CHECK-NEXT: tileloadd 896(%rsp,%rax), %tmm7 # 1024-byte Folded Reload +; CHECK-NEXT: tdpbssd %tmm7, %tmm1, %tmm3 +; CHECK-NEXT: tdpbssd %tmm4, %tmm3, %tmm2 +; CHECK-NEXT: tdpbssd %tmm5, %tmm2, %tmm0 +; CHECK-NEXT: movl $buf, %eax +; CHECK-NEXT: movl $32, %ecx +; CHECK-NEXT: tilestored %tmm0, (%rax,%rcx) +; CHECK-NEXT: addq $2936, %rsp # imm = 0xB78 +; CHECK-NEXT: .cfi_def_cfa_offset 8 +; CHECK-NEXT: retq + %4 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %5 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %6 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %7 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %8 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %9 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %11 = icmp eq i32 %0, 0 + br i1 %11, label %16, label %12 + +12: ; preds = %3 + %13 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %14 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %15 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + br label %20 + +16: ; preds = %3 + %17 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 + %18 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 + %19 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 + br label %20 + +20: ; preds = %16, %12 + %21 = phi <256 x i32> [ %17, %16 ], [ %13, %12 ] + %22 = phi <256 x i32> [ %18, %16 ], [ %14, %12 ] + %23 = phi <256 x i32> [ %19, %16 ], [ %15, %12 ] + %24 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %1, <256 x i32> %23, 
<256 x i32> %21, <256 x i32> %22) #3 + %25 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, <256 x i32> %6, <256 x i32> %24, <256 x i32> %5) #3 + %26 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, <256 x i32> %8, <256 x i32> %25, <256 x i32> %7) #3 + %27 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %2, i16 %2, i16 %2, <256 x i32> %10, <256 x i32> %26, <256 x i32> %9) #3 + tail call void @llvm.x86.tilestored64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, <256 x i32> %27) #3 + ret void +} + +declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3 +declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3 +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3 + +attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #3 = { nounwind } diff --git a/llvm/test/CodeGen/X86/AMX/amx-type.ll b/llvm/test/CodeGen/X86/AMX/amx-type.ll new file mode 100644 --- /dev/null +++ b/llvm/test/CodeGen/X86/AMX/amx-type.ll @@ -0,0 +1,143 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -lower-amx-type %s -S | FileCheck %s +target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +%struct.__tile_str = type { i16, i16, <256 x i32> } + +@buf = dso_local global [1024 x i8] zeroinitializer, align 16 +@buf2 = dso_local global [1024 x i8] zeroinitializer, align 16 + +define dso_local void @test_load(i8* %in, i8* %out) local_unnamed_addr #2 { +; CHECK-LABEL: @test_load( +; CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[IN:%.*]] to <256 x i32>* +; CHECK-NEXT: [[TMP2:%.*]] = bitcast i8* [[OUT:%.*]] to <256 x i32>* +; CHECK-NEXT: [[TMP3:%.*]] = bitcast <256 x i32>* [[TMP1]] to <128 x i32>* +; CHECK-NEXT: [[TMP4:%.*]] = load <128 x i32>, <128 x i32>* [[TMP3]], align 64 +; CHECK-NEXT: [[TMP5:%.*]] = getelementptr <128 x i32>, <128 x i32>* [[TMP3]], i32 1 +; CHECK-NEXT: [[TMP6:%.*]] = load <128 x i32>, <128 x i32>* [[TMP5]], align 64 +; CHECK-NEXT: [[TMP7:%.*]] = bitcast <256 x i32>* [[TMP2]] to <128 x i32>* +; CHECK-NEXT: store <128 x i32> [[TMP4]], <128 x i32>* [[TMP7]], align 64 +; CHECK-NEXT: [[TMP8:%.*]] = getelementptr <128 x i32>, <128 x i32>* [[TMP7]], i32 1 +; CHECK-NEXT: store <128 x i32> [[TMP6]], <128 x i32>* [[TMP8]], align 64 +; CHECK-NEXT: ret void +; + %1 = bitcast i8* %in to <256 x i32>* + %2 = bitcast i8* %out to <256 x i32>* + %3 = load <256 x i32>, <256 x i32>* %1, align 64, !tbaa !8 + store <256 x i32> %3, <256 x i32>* %2, align 64, !tbaa !8 + ret void +} + +define dso_local void @__tile_loadd(%struct.__tile_str* nocapture %0, i8* %1, i64 %2) local_unnamed_addr #0 { +; CHECK-LABEL: @__tile_loadd( +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 0 +; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2:!tbaa 
!.*]] +; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 1 +; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7:!tbaa !.*]] +; CHECK-NEXT: [[TMP8:%.*]] = shl i64 [[TMP2:%.*]], 32 +; CHECK-NEXT: [[TMP9:%.*]] = ashr exact i64 [[TMP8]], 32 +; CHECK-NEXT: [[TMP10:%.*]] = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP1:%.*]], i64 [[TMP9]]) [[ATTR3:#.*]] +; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 2 +; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8* +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64, <256 x i32> [[TMP10]]) +; CHECK-NEXT: ret void +; + %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 0 + %5 = load i16, i16* %4, align 64, !tbaa !2 + %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 1 + %7 = load i16, i16* %6, align 2, !tbaa !7 + %8 = shl i64 %2, 32 + %9 = ashr exact i64 %8, 32 + %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %5, i16 %7, i8* %1, i64 %9) #3 + %11 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2 + store <256 x i32> %10, <256 x i32>* %11, align 64, !tbaa !8 + ret void +} + +define dso_local void @__tile_dpbsud(%struct.__tile_str* nocapture %0, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr #0 { +; CHECK-LABEL: @__tile_dpbsud( +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP1:%.*]], i64 0, i32 0 +; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2]] +; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 1 +; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7]] +; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1 +; CHECK-NEXT: [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 2, [[TBAA7]] +; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2 +; CHECK-NEXT: [[TMP11:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8* +; CHECK-NEXT: [[TMP12:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP11]], i64 64) +; CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2 +; CHECK-NEXT: [[TMP14:%.*]] = bitcast <256 x i32>* [[TMP13]] to i8* +; CHECK-NEXT: [[TMP15:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP14]], i64 64) +; CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2 +; CHECK-NEXT: [[TMP17:%.*]] = bitcast <256 x i32>* [[TMP16]] to i8* +; CHECK-NEXT: [[TMP18:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP7]], i8* [[TMP17]], i64 64) +; CHECK-NEXT: [[TMP19:%.*]] = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], <256 x i32> [[TMP12]], <256 x i32> [[TMP15]], <256 x i32> [[TMP18]]) [[ATTR3]] +; CHECK-NEXT: [[TMP20:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8* +; CHECK-NEXT: call void 
@llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP20]], i64 64, <256 x i32> [[TMP19]]) +; CHECK-NEXT: ret void +; + %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 0 + %5 = load i16, i16* %4, align 64, !tbaa !2 + %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1 + %7 = load i16, i16* %6, align 2, !tbaa !7 + %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 1 + %9 = load i16, i16* %8, align 2, !tbaa !7 + %10 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2 + %11 = load <256 x i32>, <256 x i32>* %10, align 64, !tbaa !8 + %12 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 2 + %13 = load <256 x i32>, <256 x i32>* %12, align 64, !tbaa !8 + %14 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2 + %15 = load <256 x i32>, <256 x i32>* %14, align 64, !tbaa !8 + %16 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %5, i16 %7, i16 %9, <256 x i32> %11, <256 x i32> %13, <256 x i32> %15) #3 + store <256 x i32> %16, <256 x i32>* %10, align 64, !tbaa !8 + ret void +} + +define dso_local void @__tile_stored(i8* %0, i64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr #1 { +; CHECK-LABEL: @__tile_stored( +; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0 +; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2]] +; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 1 +; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7]] +; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2 +; CHECK-NEXT: [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP8]] to i8* +; CHECK-NEXT: [[TMP10:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP9]], i64 64) +; CHECK-NEXT: [[TMP11:%.*]] = shl i64 [[TMP1:%.*]], 32 +; CHECK-NEXT: [[TMP12:%.*]] = ashr exact i64 [[TMP11]], 32 +; CHECK-NEXT: tail call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP0:%.*]], i64 [[TMP12]], <256 x i32> [[TMP10]]) [[ATTR3]] +; CHECK-NEXT: ret void +; + %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 0 + %5 = load i16, i16* %4, align 64, !tbaa !2 + %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1 + %7 = load i16, i16* %6, align 2, !tbaa !7 + %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2 + %9 = load <256 x i32>, <256 x i32>* %8, align 64, !tbaa !8 + %10 = shl i64 %1, 32 + %11 = ashr exact i64 %10, 32 + tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %11, <256 x i32> %9) #3 + ret void +} + +declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3 +declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3 +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3 + +attributes #0 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" 
"no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #1 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #2 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+avx,+avx2,+avx512f,+cx8,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } +attributes #3 = { nounwind } + +!llvm.module.flags = !{!0} +!llvm.ident = !{!1} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{!"clang version 12.0.0 (ssh://git-amr-1.devtools.intel.com:29418/dpd_icl-llvm_project_worldread f3c78a3f053379a2511e00e9ce2c13383ea3f835)"} +!2 = !{!3, !4, i64 0} +!3 = !{!"__tile_str", !4, i64 0, !4, i64 2, !5, i64 1024} +!4 = !{!"short", !5, i64 0} +!5 = !{!"omnipotent char", !6, i64 0} +!6 = !{!"Simple C/C++ TBAA"} +!7 = !{!3, !4, i64 2} +!8 = !{!5, !5, i64 0} diff --git a/llvm/test/CodeGen/X86/O0-pipeline.ll b/llvm/test/CodeGen/X86/O0-pipeline.ll --- a/llvm/test/CodeGen/X86/O0-pipeline.ll +++ b/llvm/test/CodeGen/X86/O0-pipeline.ll @@ -18,6 +18,7 @@ ; CHECK-NEXT: Pre-ISel Intrinsic Lowering ; CHECK-NEXT: FunctionPass Manager ; CHECK-NEXT: Expand Atomic instructions +; CHECK-NEXT: Lower AMX type for load/store ; CHECK-NEXT: Module Verifier ; CHECK-NEXT: Lower Garbage Collection Instructions ; CHECK-NEXT: Shadow Stack GC Lowering diff --git a/llvm/test/CodeGen/X86/opt-pipeline.ll b/llvm/test/CodeGen/X86/opt-pipeline.ll --- a/llvm/test/CodeGen/X86/opt-pipeline.ll +++ b/llvm/test/CodeGen/X86/opt-pipeline.ll @@ -24,6 +24,7 @@ ; CHECK-NEXT: Pre-ISel Intrinsic Lowering ; CHECK-NEXT: FunctionPass Manager ; CHECK-NEXT: Expand Atomic instructions +; CHECK-NEXT: Lower AMX type for load/store ; CHECK-NEXT: Module Verifier ; CHECK-NEXT: Dominator Tree Construction ; CHECK-NEXT: Basic Alias Analysis (stateless AA impl) @@ -141,6 +142,7 @@ ; CHECK-NEXT: Lazy Machine Block Frequency Analysis ; CHECK-NEXT: Machine Optimization Remark Emitter ; CHECK-NEXT: Greedy Register Allocator +; CHECK-NEXT: Tile Register Configure ; CHECK-NEXT: Virtual Register Rewriter ; CHECK-NEXT: Stack Slot Coloring ; CHECK-NEXT: Machine Copy Propagation Pass diff --git a/llvm/utils/TableGen/IntrinsicEmitter.cpp b/llvm/utils/TableGen/IntrinsicEmitter.cpp --- a/llvm/utils/TableGen/IntrinsicEmitter.cpp +++ b/llvm/utils/TableGen/IntrinsicEmitter.cpp @@ -246,7 +246,8 @@ IIT_SUBDIVIDE4_ARG = 45, IIT_VEC_OF_BITCASTS_TO_INT = 46, IIT_V128 = 47, - IIT_BF16 = 48 + IIT_BF16 = 48, + IIT_V256 = 49 }; static void 
EncodeFixedValueType(MVT::SimpleValueType VT, @@ -384,6 +385,7 @@ case 32: Sig.push_back(IIT_V32); break; case 64: Sig.push_back(IIT_V64); break; case 128: Sig.push_back(IIT_V128); break; + case 256: Sig.push_back(IIT_V256); break; case 512: Sig.push_back(IIT_V512); break; case 1024: Sig.push_back(IIT_V1024); break; }
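The IIT_V256 tag added in the hunk above covers only the encoder side of the intrinsic type table; the reader in llvm/lib/IR/Function.cpp keeps a mirrored IIT_Info enum and a DecodeIITType() switch that has to recognize the new tag as well. That hunk is not shown in the diff above, so the following is only a sketch of what the matching decode case would look like, modeled on the existing IIT_V128 handling in upstream LLVM of this vintage (the surrounding names, OutputTable, IITDescriptor::getVector, IsScalableVector and NextElt, come from that upstream code, not from this patch):

  // llvm/lib/IR/Function.cpp, DecodeIITType(): sketch only, not a hunk of
  // this patch. IIT_V256 pushes a 256-element vector descriptor and then
  // recurses to decode the element type, exactly as IIT_V128 does.
  case IIT_V256:
    OutputTable.push_back(IITDescriptor::getVector(256, IsScalableVector));
    DecodeIITType(NextElt, Infos, Info, OutputTable);
    return;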