diff --git a/clang/test/CodeGen/X86/amx_api.c b/clang/test/CodeGen/X86/amx_api.c --- a/clang/test/CodeGen/X86/amx_api.c +++ b/clang/test/CodeGen/X86/amx_api.c @@ -11,8 +11,8 @@ // This is an example code and integration test. void test_api(int cond, short row, short col) { //CHECK-LABEL: @test_api - //CHECK: call <256 x i32> @llvm.x86.tileloadd64.internal - //CHECK: call <256 x i32> @llvm.x86.tdpbssd.internal + //CHECK: call x86_amx @llvm.x86.tileloadd64.internal + //CHECK: call x86_amx @llvm.x86.tdpbssd.internal //CHECK: call void @llvm.x86.tilestored64.internal __tile1024i a = {row, 8}; __tile1024i b = {8, col}; @@ -33,19 +33,22 @@ void test_tile_loadd(short row, short col) { //CHECK-LABEL: @test_tile_loadd - //CHECK: call <256 x i32> @llvm.x86.tileloadd64.internal + //CHECK: call x86_amx @llvm.x86.tileloadd64.internal + //CHECK-NEXT: {{%.*}} = bitcast x86_amx {{%.*}} to <256 x i32> __tile1024i a = {row, col}; __tile_loadd(&a, buf, STRIDE); } void test_tile_dpbsud(__tile1024i a, __tile1024i b, __tile1024i c) { //CHECK-LABEL: @test_tile_dpbsud - //CHECK: call <256 x i32> @llvm.x86.tdpbssd.internal + //CHECK: call x86_amx @llvm.x86.tdpbssd.internal + //CHECK-NEXT: {{%.*}} = bitcast x86_amx {{%.*}} to <256 x i32> __tile_dpbsud(&c, a, b); } void test_tile_stored(__tile1024i c) { //CHECK-LABEL: @test_tile_stored - //CHECK: call void @llvm.x86.tilestored64.internal + //CHECK: {{%.*}} = bitcast <256 x i32> {{%.*}} to x86_amx + //CHECK-NEXT: call void @llvm.x86.tilestored64.internal __tile_stored(buf, STRIDE, c); } diff --git a/llvm/include/llvm-c/Core.h b/llvm/include/llvm-c/Core.h --- a/llvm/include/llvm-c/Core.h +++ b/llvm/include/llvm-c/Core.h @@ -160,6 +160,7 @@ LLVMVectorTypeKind, /**< Fixed width SIMD vector type */ LLVMMetadataTypeKind, /**< Metadata */ LLVMX86_MMXTypeKind, /**< X86 MMX */ + LLVMX86_AMXTypeKind, /**< X86 AMX */ LLVMTokenTypeKind, /**< Tokens */ LLVMScalableVectorTypeKind, /**< Scalable SIMD vector type */ LLVMBFloatTypeKind /**< 16 bit brain 
floating point type */ @@ -1493,6 +1494,11 @@ */ LLVMTypeRef LLVMX86MMXTypeInContext(LLVMContextRef C); +/** + * Create a X86 AMX type in a context. + */ +LLVMTypeRef LLVMX86AMXTypeInContext(LLVMContextRef C); + /** * Create a token type in a context. */ @@ -1510,6 +1516,7 @@ LLVMTypeRef LLVMVoidType(void); LLVMTypeRef LLVMLabelType(void); LLVMTypeRef LLVMX86MMXType(void); +LLVMTypeRef LLVMX86AMXType(void); /** * @} diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h --- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h +++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h @@ -168,7 +168,8 @@ TYPE_CODE_TOKEN = 22, // TOKEN - TYPE_CODE_BFLOAT = 23 // BRAIN FLOATING POINT + TYPE_CODE_BFLOAT = 23, // BRAIN FLOATING POINT + TYPE_CODE_X86_AMX = 24 // X86 AMX }; enum OperandBundleTagCode { diff --git a/llvm/include/llvm/CodeGen/ValueTypes.td b/llvm/include/llvm/CodeGen/ValueTypes.td --- a/llvm/include/llvm/CodeGen/ValueTypes.td +++ b/llvm/include/llvm/CodeGen/ValueTypes.td @@ -196,6 +196,7 @@ def exnref : ValueType<0 , 161>; // WebAssembly's exnref type def funcref : ValueType<0 , 162>; // WebAssembly's funcref type def externref : ValueType<0 , 163>; // WebAssembly's externref type +def x86amx : ValueType<8192, 164>; // X86 AMX value def token : ValueType<0 , 248>; // TokenTy diff --git a/llvm/include/llvm/IR/DataLayout.h b/llvm/include/llvm/IR/DataLayout.h --- a/llvm/include/llvm/IR/DataLayout.h +++ b/llvm/include/llvm/IR/DataLayout.h @@ -690,6 +690,8 @@ case Type::PPC_FP128TyID: case Type::FP128TyID: return TypeSize::Fixed(128); + case Type::X86_AMXTyID: + return TypeSize::Fixed(8192); // In memory objects this is always aligned to a higher boundary, but // only 80 bits contain information. 
case Type::X86_FP80TyID: diff --git a/llvm/include/llvm/IR/Intrinsics.h b/llvm/include/llvm/IR/Intrinsics.h --- a/llvm/include/llvm/IR/Intrinsics.h +++ b/llvm/include/llvm/IR/Intrinsics.h @@ -125,7 +125,8 @@ VecElementArgument, Subdivide2Argument, Subdivide4Argument, - VecOfBitcastsToInt + VecOfBitcastsToInt, + AMX } Kind; union { diff --git a/llvm/include/llvm/IR/Intrinsics.td b/llvm/include/llvm/IR/Intrinsics.td --- a/llvm/include/llvm/IR/Intrinsics.td +++ b/llvm/include/llvm/IR/Intrinsics.td @@ -255,6 +255,8 @@ def llvm_x86mmx_ty : LLVMType; def llvm_ptrx86mmx_ty : LLVMPointerType; // <1 x i64>* +def llvm_x86amx_ty : LLVMType; + def llvm_v2i1_ty : LLVMType; // 2 x i1 def llvm_v4i1_ty : LLVMType; // 4 x i1 def llvm_v8i1_ty : LLVMType; // 8 x i1 diff --git a/llvm/include/llvm/IR/IntrinsicsX86.td b/llvm/include/llvm/IR/IntrinsicsX86.td --- a/llvm/include/llvm/IR/IntrinsicsX86.td +++ b/llvm/include/llvm/IR/IntrinsicsX86.td @@ -5041,6 +5041,22 @@ Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty], [ImmArg>, ImmArg>, ImmArg>]>; + // AMX - internal intrinsics + def int_x86_tileloadd64_internal : + GCCBuiltin<"__builtin_ia32_tileloadd64_internal">, + Intrinsic<[llvm_x86amx_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty, llvm_i64_ty], + []>; + def int_x86_tdpbssd_internal : + GCCBuiltin<"__builtin_ia32_tdpbssd_internal">, + Intrinsic<[llvm_x86amx_ty], + [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, + llvm_x86amx_ty, llvm_x86amx_ty, + llvm_x86amx_ty], []>; + def int_x86_tilestored64_internal : + GCCBuiltin<"__builtin_ia32_tilestored64_internal">, + Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty, + llvm_i64_ty, llvm_x86amx_ty], []>; } //===----------------------------------------------------------------------===// @@ -5055,20 +5071,4 @@ Intrinsic<[llvm_i8_ty], [], []>; def int_x86_senduipi : GCCBuiltin<"__builtin_ia32_senduipi">, Intrinsic<[], [llvm_i64_ty], []>; -// AMX - internal intrinsics - def int_x86_tileloadd64_internal : - 
GCCBuiltin<"__builtin_ia32_tileloadd64_internal">, - Intrinsic<[llvm_v256i32_ty], - [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty, llvm_i64_ty], - []>; - def int_x86_tdpbssd_internal : - GCCBuiltin<"__builtin_ia32_tdpbssd_internal">, - Intrinsic<[llvm_v256i32_ty], - [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty, - llvm_v256i32_ty, llvm_v256i32_ty, - llvm_v256i32_ty], []>; - def int_x86_tilestored64_internal : - GCCBuiltin<"__builtin_ia32_tilestored64_internal">, - Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty, - llvm_i64_ty, llvm_v256i32_ty], []>; } diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h --- a/llvm/include/llvm/IR/Type.h +++ b/llvm/include/llvm/IR/Type.h @@ -65,6 +65,7 @@ LabelTyID, ///< Labels MetadataTyID, ///< Metadata X86_MMXTyID, ///< MMX vectors (64 bits, X86 specific) + X86_AMXTyID, ///< AMX vectors (8192 bits, X86 specific) TokenTyID, ///< Tokens // Derived types... see DerivedTypes.h file. @@ -182,6 +183,9 @@ /// Return true if this is X86 MMX. bool isX86_MMXTy() const { return getTypeID() == X86_MMXTyID; } + /// Return true if this is X86 AMX. + bool isX86_AMXTy() const { return getTypeID() == X86_AMXTyID; } + /// Return true if this is a FP type or a vector of FP. bool isFPOrFPVectorTy() const { return getScalarType()->isFloatingPointTy(); } @@ -252,7 +256,7 @@ /// includes all first-class types except struct and array types. bool isSingleValueType() const { return isFloatingPointTy() || isX86_MMXTy() || isIntegerTy() || - isPointerTy() || isVectorTy(); + isPointerTy() || isVectorTy() || isX86_AMXTy(); } /// Return true if the type is an aggregate type. This means it is valid as @@ -268,8 +272,8 @@ bool isSized(SmallPtrSetImpl *Visited = nullptr) const { // If it's a primitive, it is always sized. 
if (getTypeID() == IntegerTyID || isFloatingPointTy() || - getTypeID() == PointerTyID || - getTypeID() == X86_MMXTyID) + getTypeID() == PointerTyID || getTypeID() == X86_MMXTyID || + getTypeID() == X86_AMXTyID) return true; // If it is not something that can have a size (e.g. a function or label), // it doesn't have a size. @@ -405,6 +409,7 @@ static Type *getFP128Ty(LLVMContext &C); static Type *getPPC_FP128Ty(LLVMContext &C); static Type *getX86_MMXTy(LLVMContext &C); + static Type *getX86_AMXTy(LLVMContext &C); static Type *getTokenTy(LLVMContext &C); static IntegerType *getIntNTy(LLVMContext &C, unsigned N); static IntegerType *getInt1Ty(LLVMContext &C); @@ -460,6 +465,7 @@ static PointerType *getFP128PtrTy(LLVMContext &C, unsigned AS = 0); static PointerType *getPPC_FP128PtrTy(LLVMContext &C, unsigned AS = 0); static PointerType *getX86_MMXPtrTy(LLVMContext &C, unsigned AS = 0); + static PointerType *getX86_AMXPtrTy(LLVMContext &C, unsigned AS = 0); static PointerType *getIntNPtrTy(LLVMContext &C, unsigned N, unsigned AS = 0); static PointerType *getInt1PtrTy(LLVMContext &C, unsigned AS = 0); static PointerType *getInt8PtrTy(LLVMContext &C, unsigned AS = 0); diff --git a/llvm/include/llvm/Support/MachineValueType.h b/llvm/include/llvm/Support/MachineValueType.h --- a/llvm/include/llvm/Support/MachineValueType.h +++ b/llvm/include/llvm/Support/MachineValueType.h @@ -247,9 +247,10 @@ exnref = 161, // WebAssembly's exnref type funcref = 162, // WebAssembly's funcref type externref = 163, // WebAssembly's externref type + x86amx = 164, // This is an X86 AMX value FIRST_VALUETYPE = 1, // This is always the beginning of the list. - LAST_VALUETYPE = 164, // This always remains at the end of the list. + LAST_VALUETYPE = 165, // This always remains at the end of the list. // This is the current maximum for LAST_VALUETYPE. 
// MVT::MAX_ALLOWED_VALUETYPE is used for asserts and to size bit vectors @@ -966,6 +967,7 @@ case v256i32: case v128i64: case v256f32: + case x86amx: case v128f64: return TypeSize::Fixed(8192); case v512i32: case v256i64: diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp --- a/llvm/lib/Analysis/ConstantFolding.cpp +++ b/llvm/lib/Analysis/ConstantFolding.cpp @@ -105,9 +105,9 @@ "Invalid constantexpr bitcast!"); // Catch the obvious splat cases. - if (C->isNullValue() && !DestTy->isX86_MMXTy()) + if (C->isNullValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy()) return Constant::getNullValue(DestTy); - if (C->isAllOnesValue() && !DestTy->isX86_MMXTy() && + if (C->isAllOnesValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy() && !DestTy->isPtrOrPtrVectorTy()) // Don't get ones for ptr types! return Constant::getAllOnesValue(DestTy); @@ -358,12 +358,13 @@ // Catch the obvious splat cases (since all-zeros can coerce non-integral // pointers legally). - if (C->isNullValue() && !DestTy->isX86_MMXTy()) + if (C->isNullValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy()) return Constant::getNullValue(DestTy); if (C->isAllOnesValue() && (DestTy->isIntegerTy() || DestTy->isFloatingPointTy() || DestTy->isVectorTy()) && - !DestTy->isX86_MMXTy() && !DestTy->isPtrOrPtrVectorTy()) + !DestTy->isX86_AMXTy() && !DestTy->isX86_MMXTy() && + !DestTy->isPtrOrPtrVectorTy()) // Get ones when the input is trivial, but // only for supported types inside getAllOnesValue. return Constant::getAllOnesValue(DestTy); @@ -575,14 +576,16 @@ C = FoldBitCast(C, MapTy->getPointerTo(AS), DL); if (Constant *Res = FoldReinterpretLoadFromConstPtr(C, MapTy, DL)) { - if (Res->isNullValue() && !LoadTy->isX86_MMXTy()) + if (Res->isNullValue() && !LoadTy->isX86_MMXTy() && + !LoadTy->isX86_AMXTy()) // Materializing a zero can be done trivially without a bitcast return Constant::getNullValue(LoadTy); Type *CastTy = LoadTy->isPtrOrPtrVectorTy() ? 
DL.getIntPtrType(LoadTy) : LoadTy; Res = FoldBitCast(Res, CastTy, DL); if (LoadTy->isPtrOrPtrVectorTy()) { // For vector of pointer, we needed to first convert to a vector of integer, then do vector inttoptr - if (Res->isNullValue() && !LoadTy->isX86_MMXTy()) + if (Res->isNullValue() && !LoadTy->isX86_MMXTy() && + !LoadTy->isX86_AMXTy()) return Constant::getNullValue(LoadTy); if (DL.isNonIntegralPointerType(LoadTy->getScalarType())) // Be careful not to replace a load of an addrspace value with an inttoptr here diff --git a/llvm/lib/AsmParser/LLLexer.cpp b/llvm/lib/AsmParser/LLLexer.cpp --- a/llvm/lib/AsmParser/LLLexer.cpp +++ b/llvm/lib/AsmParser/LLLexer.cpp @@ -840,6 +840,7 @@ TYPEKEYWORD("label", Type::getLabelTy(Context)); TYPEKEYWORD("metadata", Type::getMetadataTy(Context)); TYPEKEYWORD("x86_mmx", Type::getX86_MMXTy(Context)); + TYPEKEYWORD("x86_amx", Type::getX86_AMXTy(Context)); TYPEKEYWORD("token", Type::getTokenTy(Context)); #undef TYPEKEYWORD diff --git a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp --- a/llvm/lib/Bitcode/Reader/BitcodeReader.cpp +++ b/llvm/lib/Bitcode/Reader/BitcodeReader.cpp @@ -1763,6 +1763,9 @@ case bitc::TYPE_CODE_X86_MMX: // X86_MMX ResultTy = Type::getX86_MMXTy(Context); break; + case bitc::TYPE_CODE_X86_AMX: // X86_AMX + ResultTy = Type::getX86_AMXTy(Context); + break; case bitc::TYPE_CODE_TOKEN: // TOKEN ResultTy = Type::getTokenTy(Context); break; diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp --- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp +++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp @@ -913,6 +913,7 @@ case Type::LabelTyID: Code = bitc::TYPE_CODE_LABEL; break; case Type::MetadataTyID: Code = bitc::TYPE_CODE_METADATA; break; case Type::X86_MMXTyID: Code = bitc::TYPE_CODE_X86_MMX; break; + case Type::X86_AMXTyID: Code = bitc::TYPE_CODE_X86_AMX; break; case Type::TokenTyID: Code = bitc::TYPE_CODE_TOKEN; break; case Type::IntegerTyID: 
// INTEGER: [width] diff --git a/llvm/lib/CodeGen/ValueTypes.cpp b/llvm/lib/CodeGen/ValueTypes.cpp --- a/llvm/lib/CodeGen/ValueTypes.cpp +++ b/llvm/lib/CodeGen/ValueTypes.cpp @@ -164,6 +164,7 @@ case MVT::Other: return "ch"; case MVT::Glue: return "glue"; case MVT::x86mmx: return "x86mmx"; + case MVT::x86amx: return "x86amx"; case MVT::Metadata: return "Metadata"; case MVT::Untyped: return "Untyped"; case MVT::exnref: return "exnref"; @@ -195,6 +196,7 @@ case MVT::f128: return Type::getFP128Ty(Context); case MVT::ppcf128: return Type::getPPC_FP128Ty(Context); case MVT::x86mmx: return Type::getX86_MMXTy(Context); + case MVT::x86amx: return Type::getX86_AMXTy(Context); case MVT::v1i1: return FixedVectorType::get(Type::getInt1Ty(Context), 1); case MVT::v2i1: @@ -501,6 +503,7 @@ case Type::DoubleTyID: return MVT(MVT::f64); case Type::X86_FP80TyID: return MVT(MVT::f80); case Type::X86_MMXTyID: return MVT(MVT::x86mmx); + case Type::X86_AMXTyID: return MVT(MVT::x86amx); case Type::FP128TyID: return MVT(MVT::f128); case Type::PPC_FP128TyID: return MVT(MVT::ppcf128); case Type::PointerTyID: return MVT(MVT::iPTR); diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp --- a/llvm/lib/IR/AsmWriter.cpp +++ b/llvm/lib/IR/AsmWriter.cpp @@ -609,6 +609,7 @@ case Type::LabelTyID: OS << "label"; return; case Type::MetadataTyID: OS << "metadata"; return; case Type::X86_MMXTyID: OS << "x86_mmx"; return; + case Type::X86_AMXTyID: OS << "x86_amx"; return; case Type::TokenTyID: OS << "token"; return; case Type::IntegerTyID: OS << 'i' << cast(Ty)->getBitWidth(); diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -535,7 +535,7 @@ return UndefValue::get(DestTy); } - if (V->isNullValue() && !DestTy->isX86_MMXTy() && + if (V->isNullValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy() && opc != Instruction::AddrSpaceCast) return Constant::getNullValue(DestTy); diff --git 
a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp --- a/llvm/lib/IR/Core.cpp +++ b/llvm/lib/IR/Core.cpp @@ -512,6 +512,8 @@ return LLVMVectorTypeKind; case Type::X86_MMXTyID: return LLVMX86_MMXTypeKind; + case Type::X86_AMXTyID: + return LLVMX86_AMXTypeKind; case Type::TokenTyID: return LLVMTokenTypeKind; case Type::ScalableVectorTyID: @@ -623,6 +625,9 @@ LLVMTypeRef LLVMX86MMXTypeInContext(LLVMContextRef C) { return (LLVMTypeRef) Type::getX86_MMXTy(*unwrap(C)); } +LLVMTypeRef LLVMX86AMXTypeInContext(LLVMContextRef C) { + return (LLVMTypeRef) Type::getX86_AMXTy(*unwrap(C)); +} LLVMTypeRef LLVMHalfType(void) { return LLVMHalfTypeInContext(LLVMGetGlobalContext()); @@ -648,6 +653,9 @@ LLVMTypeRef LLVMX86MMXType(void) { return LLVMX86MMXTypeInContext(LLVMGetGlobalContext()); } +LLVMTypeRef LLVMX86AMXType(void) { + return LLVMX86AMXTypeInContext(LLVMGetGlobalContext()); +} /*--.. Operations on function types ........................................--*/ diff --git a/llvm/lib/IR/DataLayout.cpp b/llvm/lib/IR/DataLayout.cpp --- a/llvm/lib/IR/DataLayout.cpp +++ b/llvm/lib/IR/DataLayout.cpp @@ -810,6 +810,8 @@ Alignment = PowerOf2Ceil(Alignment); return Align(Alignment); } + case Type::X86_AMXTyID: + return Align(64); default: llvm_unreachable("Bad type for getAlignment!!!"); } diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp --- a/llvm/lib/IR/Function.cpp +++ b/llvm/lib/IR/Function.cpp @@ -764,6 +764,7 @@ case Type::FP128TyID: Result += "f128"; break; case Type::PPC_FP128TyID: Result += "ppcf128"; break; case Type::X86_MMXTyID: Result += "x86mmx"; break; + case Type::X86_AMXTyID: Result += "x86amx"; break; case Type::IntegerTyID: Result += "i" + utostr(cast(Ty)->getBitWidth()); break; @@ -848,7 +849,8 @@ IIT_V128 = 47, IIT_BF16 = 48, IIT_STRUCT9 = 49, - IIT_V256 = 50 + IIT_V256 = 50, + IIT_AMX = 51 }; static void DecodeIITType(unsigned &NextElt, ArrayRef Infos, @@ -871,6 +873,9 @@ case IIT_MMX: OutputTable.push_back(IITDescriptor::get(IITDescriptor::MMX, 0)); 
return; + case IIT_AMX: + OutputTable.push_back(IITDescriptor::get(IITDescriptor::AMX, 0)); + return; case IIT_TOKEN: OutputTable.push_back(IITDescriptor::get(IITDescriptor::Token, 0)); return; @@ -1108,6 +1113,7 @@ case IITDescriptor::Void: return Type::getVoidTy(Context); case IITDescriptor::VarArg: return Type::getVoidTy(Context); case IITDescriptor::MMX: return Type::getX86_MMXTy(Context); + case IITDescriptor::AMX: return Type::getX86_AMXTy(Context); case IITDescriptor::Token: return Type::getTokenTy(Context); case IITDescriptor::Metadata: return Type::getMetadataTy(Context); case IITDescriptor::Half: return Type::getHalfTy(Context); @@ -1287,6 +1293,7 @@ case IITDescriptor::Void: return !Ty->isVoidTy(); case IITDescriptor::VarArg: return true; case IITDescriptor::MMX: return !Ty->isX86_MMXTy(); + case IITDescriptor::AMX: return !Ty->isX86_AMXTy(); case IITDescriptor::Token: return !Ty->isTokenTy(); case IITDescriptor::Metadata: return !Ty->isMetadataTy(); case IITDescriptor::Half: return !Ty->isHalfTy(); diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h --- a/llvm/lib/IR/LLVMContextImpl.h +++ b/llvm/lib/IR/LLVMContextImpl.h @@ -1418,7 +1418,7 @@ // Basic type instances. 
Type VoidTy, LabelTy, HalfTy, BFloatTy, FloatTy, DoubleTy, MetadataTy, TokenTy; - Type X86_FP80Ty, FP128Ty, PPC_FP128Ty, X86_MMXTy; + Type X86_FP80Ty, FP128Ty, PPC_FP128Ty, X86_MMXTy, X86_AMXTy; IntegerType Int1Ty, Int8Ty, Int16Ty, Int32Ty, Int64Ty, Int128Ty; BumpPtrAllocator Alloc; diff --git a/llvm/lib/IR/LLVMContextImpl.cpp b/llvm/lib/IR/LLVMContextImpl.cpp --- a/llvm/lib/IR/LLVMContextImpl.cpp +++ b/llvm/lib/IR/LLVMContextImpl.cpp @@ -35,6 +35,7 @@ FP128Ty(C, Type::FP128TyID), PPC_FP128Ty(C, Type::PPC_FP128TyID), X86_MMXTy(C, Type::X86_MMXTyID), + X86_AMXTy(C, Type::X86_AMXTyID), Int1Ty(C, 1), Int8Ty(C, 8), Int16Ty(C, 16), diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp --- a/llvm/lib/IR/Type.cpp +++ b/llvm/lib/IR/Type.cpp @@ -49,6 +49,7 @@ case LabelTyID : return getLabelTy(C); case MetadataTyID : return getMetadataTy(C); case X86_MMXTyID : return getX86_MMXTy(C); + case X86_AMXTyID : return getX86_AMXTy(C); case TokenTyID : return getTokenTy(C); default: return nullptr; @@ -81,6 +82,14 @@ Ty->getPrimitiveSizeInBits().getFixedSize() == 64) return true; + // 8192-bit fixed width vector types can be losslessly converted to x86amx. + if (((isa(this)) && Ty->isX86_AMXTy()) && + getPrimitiveSizeInBits().getFixedSize() == 8192) + return true; + if ((isX86_AMXTy() && isa(Ty)) && + Ty->getPrimitiveSizeInBits().getFixedSize() == 8192) + return true; + // At this point we have only various mismatches of the first class types // remaining and ptr->ptr. Just select the lossless conversions. Everything // else is not lossless. 
Conservatively assume we can't losslessly convert @@ -120,6 +129,7 @@ case Type::FP128TyID: return TypeSize::Fixed(128); case Type::PPC_FP128TyID: return TypeSize::Fixed(128); case Type::X86_MMXTyID: return TypeSize::Fixed(64); + case Type::X86_AMXTyID: return TypeSize::Fixed(8192); case Type::IntegerTyID: return TypeSize::Fixed(cast(this)->getBitWidth()); case Type::FixedVectorTyID: @@ -179,6 +189,7 @@ Type *Type::getFP128Ty(LLVMContext &C) { return &C.pImpl->FP128Ty; } Type *Type::getPPC_FP128Ty(LLVMContext &C) { return &C.pImpl->PPC_FP128Ty; } Type *Type::getX86_MMXTy(LLVMContext &C) { return &C.pImpl->X86_MMXTy; } +Type *Type::getX86_AMXTy(LLVMContext &C) { return &C.pImpl->X86_AMXTy; } IntegerType *Type::getInt1Ty(LLVMContext &C) { return &C.pImpl->Int1Ty; } IntegerType *Type::getInt8Ty(LLVMContext &C) { return &C.pImpl->Int8Ty; } @@ -223,6 +234,10 @@ return getX86_MMXTy(C)->getPointerTo(AS); } +PointerType *Type::getX86_AMXPtrTy(LLVMContext &C, unsigned AS) { + return getX86_AMXTy(C)->getPointerTo(AS); +} + PointerType *Type::getIntNPtrTy(LLVMContext &C, unsigned N, unsigned AS) { return getIntNTy(C, N)->getPointerTo(AS); } diff --git a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp --- a/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp +++ b/llvm/lib/Target/X86/X86ISelDAGToDAG.cpp @@ -4618,7 +4618,7 @@ Segment, CFG, Chain}; - CNode = CurDAG->getMachineNode(Opc, dl, {MVT::v256i32, MVT::Other}, Ops); + CNode = CurDAG->getMachineNode(Opc, dl, {MVT::x86amx, MVT::Other}, Ops); ReplaceNode(Node, CNode); return; } @@ -4637,7 +4637,7 @@ CFG, Chain}; MachineSDNode *CNode = - CurDAG->getMachineNode(Opc, dl, {MVT::v256i32, MVT::Other}, Ops); + CurDAG->getMachineNode(Opc, dl, {MVT::x86amx, MVT::Other}, Ops); ReplaceNode(Node, CNode); return; } diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -1898,7 
+1898,7 @@ } if (Subtarget.hasAMXTILE()) { - addRegisterClass(MVT::v256i32, &X86::TILERegClass); + addRegisterClass(MVT::x86amx, &X86::TILERegClass); } // We want to custom lower some of our intrinsics. @@ -5346,11 +5346,6 @@ if (MemVT.getSizeInBits() > Subtarget.getPreferVectorWidth()) return false; - // Don't merge to x86 amx tile, as we only map MVT::v256i32 - // to x86 amx tile on amx intrinsics. - if (MemVT == MVT::v256i32) - return false; - return true; } diff --git a/llvm/lib/Target/X86/X86LowerAMXType.cpp b/llvm/lib/Target/X86/X86LowerAMXType.cpp --- a/llvm/lib/Target/X86/X86LowerAMXType.cpp +++ b/llvm/lib/Target/X86/X86LowerAMXType.cpp @@ -6,20 +6,20 @@ // //===----------------------------------------------------------------------===// // -/// \file Pass to transform <256 x i32> -/// <256 x i32> is mapped to AMX tile register on X86, AMX instruction set only -/// provides simple operation on tile register. The basic elementwise operation -/// is not supported by AMX. Since we define the AMX tile as vector <256 x i32> +/// \file Pass to transform <256 x i32> load/store +/// <256 x i32> is bitcasted to x86_amx on X86, and AMX instruction set only +/// provides simple operation on x86_amx. The basic elementwise operation +/// is not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32> /// and only AMX intrinsics can operate on the type, we need transform -/// load/store <256 x i32> instruction to AMX load/store. Besides, we split -/// <256 x i32> to 2 <128 x i32> if the vector is not used or defined by AMX -/// intrinsics, so that in instruction selection it can be lowered to proper -/// size which HW can support. +/// load/store <256 x i32> instruction to AMX load/store. If the bitcast can +/// not be combined with load/store, we transform the bitcast to amx load/store +/// and <256 x i32> store/load. 
// //===----------------------------------------------------------------------===// // #include "X86.h" -#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/CodeGen/Passes.h" @@ -30,231 +30,288 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/IntrinsicsX86.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" using namespace llvm; +using namespace PatternMatch; #define DEBUG_TYPE "lower-amx-type" -namespace { -class X86LowerAMXType { - Function &Func; - const DataLayout &DL; - DenseSet LDSet; - DenseSet STSet; - DenseMap> LoadMap; - -public: - X86LowerAMXType(Function &F) : Func(F), DL(F.getParent()->getDataLayout()) {} - bool visit(); - bool visitLD(); - bool visitST(); - void splitST(Instruction *Inst); - void splitLD(Instruction *Inst); -}; +static AllocaInst *CreateAllocaInst(IRBuilder<> &Builder, BasicBlock *BB) { + Function &F = *BB->getParent(); + Module *M = BB->getModule(); + const DataLayout &DL = M->getDataLayout(); -// Split v256i32 load/store to 2 v128i32, so that ISel can -// lower it to proper vector size. 
-void X86LowerAMXType::splitST(Instruction *Inst) { - StoreInst *ST = dyn_cast(Inst); - IRBuilder<> Builder(ST); + Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false); LLVMContext &Ctx = Builder.getContext(); - Type *Ty = ST->getValueOperand()->getType(); - EVT VT = EVT::getEVT(Ty); - EVT HalfVT = VT.getHalfNumVectorElementsVT(Ctx); - Type *HalfTy = HalfVT.getTypeForEVT(Ctx); - - LoadInst *Lo, *Hi; - std::tie(Lo, Hi) = LoadMap[ST->getValueOperand()]; - Value *Ptr = ST->getPointerOperand(); - PointerType *HalfPtrTy = HalfTy->getPointerTo(ST->getPointerAddressSpace()); - Value *HalfPtr = Builder.CreateBitCast(Ptr, HalfPtrTy); - // The HW require the alignment for AMX tile is 64, but front-end generate - // code for the vector alignment which is the vector size. - uint64_t HalfTySize = HalfTy->getPrimitiveSizeInBits().getFixedSize() / 8; - Align Alignment = std::min(Lo->getAlign(), Align(HalfTySize)); - Builder.CreateAlignedStore(Lo, HalfPtr, Alignment, ST->isVolatile()); - - HalfPtr = Builder.CreateGEP(HalfTy, HalfPtr, Builder.getInt32(1)); - Builder.CreateAlignedStore(Hi, HalfPtr, Alignment, ST->isVolatile()); + auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx)); + unsigned AllocaAS = DL.getAllocaAddrSpace(); + AllocaInst *AllocaRes = + new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front()); + AllocaRes->setAlignment(AllocaAlignment); + return AllocaRes; } -bool X86LowerAMXType::visitST() { - if (STSet.empty()) - return false; - for (auto *Inst : STSet) { - Value *Row, *Col; - const IntrinsicInst *II = dyn_cast(Inst->getOperand(0)); - if (!II) - Row = Col = nullptr; - else { - switch (II->getIntrinsicID()) { - default: - Row = Col = nullptr; - break; - case Intrinsic::x86_tileloadd64_internal: - case Intrinsic::x86_tdpbssd_internal: { - Row = II->getArgOperand(0); - Col = II->getArgOperand(1); - break; - } - } - } - if (!Row) { - splitST(Inst); - continue; +static std::pair getShape(IntrinsicInst *II, unsigned OpNo) { + 
Value *Row = nullptr, *Col = nullptr; + switch (II->getIntrinsicID()) { + default: + llvm_unreachable("Expect amx intrinsics"); + case Intrinsic::x86_tileloadd64_internal: + case Intrinsic::x86_tilestored64_internal: { + Row = II->getArgOperand(0); + Col = II->getArgOperand(1); + break; + } + // a * b + c + // The shape depends on which operand. + case Intrinsic::x86_tdpbssd_internal: { + switch (OpNo) { + case 3: + Row = II->getArgOperand(0); + Col = II->getArgOperand(1); + break; + case 4: + Row = II->getArgOperand(0); + Col = II->getArgOperand(2); + break; + case 5: + Row = II->getArgOperand(2); + Col = II->getArgOperand(1); + break; } - IRBuilder<> Builder(Inst); - LLVMContext &Ctx = Builder.getContext(); - // Use the maximun column as stride. It must be the same with load stride. - Value *Stride = Builder.getInt64(64); - Value *I8Ptr = - Builder.CreateBitCast(Inst->getOperand(1), Type::getInt8PtrTy(Ctx)); - std::array Args = {Row, Col, I8Ptr, Stride, - Inst->getOperand(0)}; - - Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); + break; } - return true; + } + + return std::make_pair(Row, Col); } -void X86LowerAMXType::splitLD(Instruction *Inst) { - LoadInst *LD = dyn_cast(Inst); - IRBuilder<> Builder(LD); - LLVMContext &Ctx = Builder.getContext(); - Type *Ty = LD->getType(); - EVT VT = EVT::getEVT(Ty); - EVT HalfVT = VT.getHalfNumVectorElementsVT(Ctx); - Type *HalfTy = HalfVT.getTypeForEVT(Ctx); +// %src = load <256 x i32>, <256 x i32>* %addr, align 64 +// %2 = bitcast <256 x i32> %src to x86_amx +// --> +// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, +// i8* %addr, i64 %stride64) +static void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) { + Value *Row = nullptr, *Col = nullptr; + Use &U = *(Bitcast->use_begin()); + unsigned OpNo = U.getOperandNo(); + auto *II = cast(U.getUser()); + std::tie(Row, Col) = getShape(II, OpNo); + IRBuilder<> Builder(Bitcast); + // Use the maximun column as stride. 
+ Value *Stride = Builder.getInt64(64); + Value *I8Ptr = + Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy()); + std::array Args = {Row, Col, I8Ptr, Stride}; - Value *Ptr = LD->getPointerOperand(); - PointerType *HalfPtrTy = HalfTy->getPointerTo(LD->getPointerAddressSpace()); - Value *HalfPtr = Builder.CreateBitCast(Ptr, HalfPtrTy); - // The HW require the alignment for AMX tile is 64, but front-end generate - // code for the vector alignment which is the vector size. - uint64_t HalfTySize = HalfTy->getPrimitiveSizeInBits().getFixedSize() / 8; - Align Alignment = std::min(LD->getAlign(), Align(HalfTySize)); - auto *Lo = - Builder.CreateAlignedLoad(HalfTy, HalfPtr, Alignment, LD->isVolatile()); + Value *NewInst = + Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args); + Bitcast->replaceAllUsesWith(NewInst); +} - HalfPtr = Builder.CreateGEP(HalfTy, HalfPtr, Builder.getInt32(1)); - auto *Hi = - Builder.CreateAlignedLoad(HalfTy, HalfPtr, Alignment, LD->isVolatile()); +// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, +// %stride); +// %13 = bitcast x86_amx %src to <256 x i32> +// store <256 x i32> %13, <256 x i32>* %addr, align 64 +// --> +// call void @llvm.x86.tilestored64.internal(%row, %col, %addr, +// %stride64, %src) +static void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) { - LoadMap[Inst] = std::make_pair(Lo, Hi); + Value *Tile = Bitcast->getOperand(0); + auto *II = cast(Tile); + // Tile is output from AMX intrinsic. The first operand of the + // intrinsic is row, the second operand of the intrinsic is column. + Value *Row = II->getOperand(0); + Value *Col = II->getOperand(1); + IRBuilder<> Builder(ST); + // Use the maximum column as stride. It must be the same with load + // stride.
+ Value *Stride = Builder.getInt64(64); + Value *I8Ptr = + Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy()); + std::array Args = {Row, Col, I8Ptr, Stride, Tile}; + Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); + if (Bitcast->hasOneUse()) + return; + // %13 = bitcast x86_amx %src to <256 x i32> + // store <256 x i32> %13, <256 x i32>* %addr, align 64 + // %add = add <256 x i32> %13, %src2 + // --> + // %13 = bitcast x86_amx %src to <256 x i32> + // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, + // %stride64, %src) + // %14 = load <256 x i32>, <256 x i32>* %addr + // %add = add <256 x i32> %14, %src2 + Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1)); + Bitcast->replaceAllUsesWith(Vec); } -bool X86LowerAMXType::visitLD() { - if (LDSet.empty()) - return false; - for (auto &Inst : LDSet) { - int Count = 0; - Value *NewInst = nullptr; - // The user should be all AMX intrinsics or all LLVM instruction. - // Don't support it is used by both AMX intrinsics and LLVM instructions.
- for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) { - Use &U = *I++; - const IntrinsicInst *II = dyn_cast(U.getUser()); - if (!II) { - Count++; - continue; - } - if (NewInst) - continue; - Value *Row, *Col; - switch (II->getIntrinsicID()) { - default: - report_fatal_error("Non-AMX intrinsic use tile type."); - break; - case Intrinsic::x86_tdpbssd_internal: { - unsigned OpNo = U.getOperandNo(); - switch (OpNo) { - case 3: - Row = II->getArgOperand(0); - Col = II->getArgOperand(1); - break; - case 4: - Row = II->getArgOperand(0); - Col = II->getArgOperand(2); - break; - case 5: - Row = II->getArgOperand(2); - Col = II->getArgOperand(1); - break; - } - break; - } - case Intrinsic::x86_tilestored64_internal: { - Row = II->getArgOperand(0); - Col = II->getArgOperand(1); - break; - } - } - assert(Count == 0 && "Can NOT mix amx intrinsic and LLVM instruction"); - // FIXME: The shape def should be ahead of load. - IRBuilder<> Builder(Inst); - LLVMContext &Ctx = Builder.getContext(); - // Use the maximun column as stride. - Value *Stride = Builder.getInt64(64); - Value *I8Ptr = - Builder.CreateBitCast(Inst->getOperand(0), Type::getInt8PtrTy(Ctx)); - std::array Args = {Row, Col, I8Ptr, Stride}; +// transform bitcast to instructions. 
+static bool transformBitcast(BitCastInst *Bitcast) { + IRBuilder<> Builder(Bitcast); + AllocaInst *AllocaAddr; + Value *I8Ptr, *Stride; + auto *Src = Bitcast->getOperand(0); - NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, - None, Args); + auto Prepare = [&]() { + AllocaAddr = CreateAllocaInst(Builder, Bitcast->getParent()); + I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy()); + Stride = Builder.getInt64(64); + }; - Inst->replaceAllUsesWith(NewInst); - } - if (!NewInst) - splitLD(Inst); + if (Bitcast->getType()->isX86_AMXTy()) { + // %2 = bitcast <256 x i32> %src to x86_amx + // --> + // %addr = alloca <256 x i32>, align 64 + // store <256 x i32> %src, <256 x i32>* %addr, align 64 + // %addr2 = bitcast <256 x i32>* to i8* + // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, + // i8* %addr2, + // i64 64) + Use &U = *(Bitcast->use_begin()); + unsigned OpNo = U.getOperandNo(); + auto *II = dyn_cast(U.getUser()); + if (!II) + return false; // May be bitcast from x86amx to <256 x i32>. + Prepare(); + Builder.CreateStore(Src, AllocaAddr); + // TODO we can pick an constant operand for the shape. + Value *Row = nullptr, *Col = nullptr; + std::tie(Row, Col) = getShape(II, OpNo); + std::array Args = {Row, Col, I8Ptr, Stride}; + Value *NewInst = Builder.CreateIntrinsic( + Intrinsic::x86_tileloadd64_internal, None, Args); + Bitcast->replaceAllUsesWith(NewInst); + } else { + // %2 = bitcast x86_amx %src to <256 x i32> + // --> + // %addr = alloca <256 x i32>, align 64 + // %addr2 = bitcast <256 x i32>* to i8* + // call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, + // i8* %addr2, i64 %stride) + // %2 = load <256 x i32>, <256 x i32>* %addr, align 64 + auto *II = dyn_cast(Src); + if (!II) + return false; // May be bitcast from <256 x i32> to x86amx. 
+ Prepare(); + Value *Row = II->getOperand(0); + Value *Col = II->getOperand(1); + std::array Args = {Row, Col, I8Ptr, Stride, Src}; + Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args); + Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr); + Bitcast->replaceAllUsesWith(NewInst); } + return true; } +namespace { +class X86LowerAMXType { + Function &Func; + +public: + X86LowerAMXType(Function &F) : Func(F) {} + bool visit(); +}; + bool X86LowerAMXType::visit() { - bool C; - auto IsAMXType = [](FixedVectorType *VTy) { - if (!VTy) - return false; - if (!VTy->getScalarType()->isIntegerTy(32)) - return false; - if (VTy->getNumElements() != 256) - return false; + SmallVector DeadInsts; - return true; - }; + for (BasicBlock *BB : post_order(&Func)) { + for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend(); + II != IE;) { + Instruction &Inst = *II++; + auto *Bitcast = dyn_cast(&Inst); + if (!Bitcast) + continue; - for (BasicBlock &BB : Func) { - for (Instruction &Inst : BB) { - LoadInst *LD = dyn_cast(&Inst); - // Check load instruction. - // %3 = load <256 x i32>, <256 x i32>* %1, align 64 - if (LD) { - FixedVectorType *VTy = dyn_cast(Inst.getType()); - if (!IsAMXType(VTy)) + Value *Src = Bitcast->getOperand(0); + if (Bitcast->getType()->isX86_AMXTy()) { + if (Bitcast->user_empty()) { + DeadInsts.push_back(Bitcast); continue; - LDSet.insert(&Inst); - continue; + } + LoadInst *LD = dyn_cast(Src); + if (!LD) { + if (transformBitcast(Bitcast)) + DeadInsts.push_back(Bitcast); + continue; + } + // If load has mutli-user, duplicate a vector load. 
+ // %src = load <256 x i32>, <256 x i32>* %addr, align 64 + // %2 = bitcast <256 x i32> %src to x86_amx + // %add = add <256 x i32> %src, <256 x i32> %src2 + // --> + // %src = load <256 x i32>, <256 x i32>* %addr, align 64 + // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, + // i8* %addr, i64 %stride64) + // %add = add <256 x i32> %src, <256 x i32> %src2 + + // If load has one user, the load will be eliminated in DAG ISel. + // %src = load <256 x i32>, <256 x i32>* %addr, align 64 + // %2 = bitcast <256 x i32> %src to x86_amx + // --> + // %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, + // i8* %addr, i64 %stride64) + combineLoadBitcast(LD, Bitcast); + DeadInsts.push_back(Bitcast); + if (LD->hasOneUse()) + DeadInsts.push_back(LD); + } else if (Src->getType()->isX86_AMXTy()) { + if (Bitcast->user_empty()) { + DeadInsts.push_back(Bitcast); + continue; + } + StoreInst *ST = nullptr; + for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end(); + UI != UE;) { + Value *I = (UI++)->getUser(); + ST = dyn_cast(I); + if (ST) + break; + } + if (!ST) { + if (transformBitcast(Bitcast)) + DeadInsts.push_back(Bitcast); + continue; + } + // If bitcast (%13) has one use, combine bitcast and store to amx store. + // %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr, + // %stride); + // %13 = bitcast x86_amx %src to <256 x i32> + // store <256 x i32> %13, <256 x i32>* %addr, align 64 + // --> + // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, + // %stride64, %13) + // + // If bitcast (%13) has multi-use, transform as below. 
+ // %13 = bitcast x86_amx %src to <256 x i32> + // store <256 x i32> %13, <256 x i32>* %addr, align 64 + // %add = <256 x i32> %13, <256 x i32> %src2 + // --> + // %13 = bitcast x86_amx %src to <256 x i32> + // call void @llvm.x86.tilestored64.internal(%row, %col, %addr, + // %stride64, %13) + // %14 = load <256 x i32>, %addr + // %add = <256 x i32> %14, <256 x i32> %src2 + // + combineBitcastStore(Bitcast, ST); + // Delete user first. + DeadInsts.push_back(ST); + DeadInsts.push_back(Bitcast); } - // Check store instruction. - // store <256 x i32> %3, <256 x i32>* %2, align 64 - StoreInst *ST = dyn_cast(&Inst); - if (!ST) - continue; - FixedVectorType *VTy = - dyn_cast(ST->getOperand(0)->getType()); - if (!IsAMXType(VTy)) - continue; - STSet.insert(&Inst); } } - C = visitLD() | visitST(); - for (auto *Inst : STSet) - Inst->eraseFromParent(); - for (auto *Inst : LDSet) + bool C = !DeadInsts.empty(); + + for (auto *Inst : DeadInsts) Inst->eraseFromParent(); + return C; } } // anonymous namespace diff --git a/llvm/lib/Target/X86/X86RegisterInfo.td b/llvm/lib/Target/X86/X86RegisterInfo.td --- a/llvm/lib/Target/X86/X86RegisterInfo.td +++ b/llvm/lib/Target/X86/X86RegisterInfo.td @@ -637,7 +637,7 @@ // Tiles let CopyCost = -1 in // Don't allow copying of tile registers -def TILE : RegisterClass<"X86", [v256i32], 8192, +def TILE : RegisterClass<"X86", [x86amx], 8192, (sequence "TMM%u", 0, 7)> {let Size = 8192;} def TILECFG : RegisterClass<"X86", [untyped], 512, (add TMMCFG)> { let CopyCost = -1; // Don't allow copying of tile config registers. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -1115,6 +1115,10 @@ // Fold away bit casts of the stored value by storing the original type. 
if (auto *BC = dyn_cast(V)) { V = BC->getOperand(0); + // Don't transform when the type is x86_amx, it make the pass that lower + // x86_amx type happy. + if (BC->getType()->isX86_AMXTy() || V->getType()->isX86_AMXTy()) + return false; if (!SI.isAtomic() || isSupportedAtomicType(V->getType())) { combineStoreToNewValue(IC, SI, V); return true; diff --git a/llvm/test/CodeGen/X86/AMX/amx-across-func.ll b/llvm/test/CodeGen/X86/AMX/amx-across-func.ll --- a/llvm/test/CodeGen/X86/AMX/amx-across-func.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-across-func.ll @@ -71,20 +71,20 @@ ; CHECK-NEXT: .cfi_def_cfa_offset 8 ; CHECK-NEXT: tilerelease ; CHECK-NEXT: retq - %3 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %0, i16 8, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 0), i64 32) #4 - %4 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 8, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 1024), i64 32) #4 + %3 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %0, i16 8, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 0), i64 32) #4 + %4 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 1024), i64 32) #4 tail call void (...) 
@foo() #4 - %5 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32) #4 - %6 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %0, i16 %1, i16 8, <256 x i32> %5, <256 x i32> %3, <256 x i32> %4) #4 - tail call void @llvm.x86.tilestored64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32, <256 x i32> %6) #4 + %5 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32) #4 + %6 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %0, i16 %1, i16 8, x86_amx %5, x86_amx %3, x86_amx %4) #4 + tail call void @llvm.x86.tilestored64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32, x86_amx %6) #4 ret void } declare dso_local void @foo(...) local_unnamed_addr #3 -declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #4 -declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #4 -declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #4 +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #4 +declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) #4 +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) #4 attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" 
"use-soft-float"="false" } attributes #3 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } diff --git a/llvm/test/CodeGen/X86/AMX/amx-config.ll b/llvm/test/CodeGen/X86/AMX/amx-config.ll --- a/llvm/test/CodeGen/X86/AMX/amx-config.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-config.ll @@ -47,31 +47,31 @@ br i1 %4, label %11, label %7 7: ; preds = %3 - %8 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 - %9 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 - %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %8 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %9 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 br label %15 11: ; preds = %3 - %12 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 - %13 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds 
([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 - %14 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 + %12 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 + %13 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 + %14 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 br label %15 15: ; preds = %11, %7 - %16 = phi <256 x i32> [ %12, %11 ], [ %8, %7 ] - %17 = phi <256 x i32> [ %13, %11 ], [ %9, %7 ] - %18 = phi <256 x i32> [ %14, %11 ], [ %10, %7 ] - %19 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %6, i16 %2, i16 %1, <256 x i32> %18, <256 x i32> %16, <256 x i32> %17) #3 - tail call void @llvm.x86.tilestored64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, <256 x i32> %19) #3 + %16 = phi x86_amx [ %12, %11 ], [ %8, %7 ] + %17 = phi x86_amx [ %13, %11 ], [ %9, %7 ] + %18 = phi x86_amx [ %14, %11 ], [ %10, %7 ] + %19 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %6, i16 %2, i16 %1, x86_amx %18, x86_amx %16, x86_amx %17) #3 + tail call void @llvm.x86.tilestored64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, x86_amx %19) #3 ret void } -declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3 +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3 -declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3 +declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) #3 -declare void 
@llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3 +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) #3 attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+avx,+avx2,+avx512f,+cx8,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } attributes #3 = { nounwind } diff --git a/llvm/test/CodeGen/X86/AMX/amx-intrinsic-chain.ll b/llvm/test/CodeGen/X86/AMX/amx-intrinsic-chain.ll --- a/llvm/test/CodeGen/X86/AMX/amx-intrinsic-chain.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-intrinsic-chain.ll @@ -37,23 +37,23 @@ ; CHECK-NEXT: vzeroupper ; CHECK-NEXT: retq entry: - %a1 = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %A_mem, i64 64) + %a1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %A_mem, i64 64) %addr = getelementptr inbounds i8, i8* %A_mem, i64 1024 - %a2 = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %addr, i64 64) - %c1 = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %C_mem, i64 64) + %a2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %addr, i64 64) + %c1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %C_mem, i64 64) %caddr = getelementptr inbounds i8, i8* %C_mem, i64 1024 - %c2 = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %caddr, i64 64) + %c2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %caddr, i64 64) br label %dotpd dotpd: - 
%b = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %B_mem, i64 64) - %dp1 = call <256 x i32> @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, <256 x i32> %c1, <256 x i32> %a1, <256 x i32> %b) - call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* nonnull %C_mem, i64 64, <256 x i32> %dp1) - %dp2 = call <256 x i32> @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, <256 x i32> %c2, <256 x i32> %a2, <256 x i32> %b) - call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* nonnull %caddr, i64 64, <256 x i32> %dp2) + %b = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %B_mem, i64 64) + %dp1 = call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %c1, x86_amx %a1, x86_amx %b) + call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* nonnull %C_mem, i64 64, x86_amx %dp1) + %dp2 = call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %c2, x86_amx %a2, x86_amx %b) + call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* nonnull %caddr, i64 64, x86_amx %dp2) ret void } -declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) -declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) -declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) +declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) diff --git a/llvm/test/CodeGen/X86/AMX/amx-spill.ll b/llvm/test/CodeGen/X86/AMX/amx-spill.ll --- a/llvm/test/CodeGen/X86/AMX/amx-spill.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-spill.ll @@ -70,43 +70,43 @@ ; CHECK-NEXT: tilerelease ; CHECK-NEXT: vzeroupper ; CHECK-NEXT: retq - %4 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 
0), i64 32) #3 - %5 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 - %6 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 - %7 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 - %8 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 - %9 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 - %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %4 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %5 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %6 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %7 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %8 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %9 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 
%2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 %11 = icmp eq i32 %0, 0 br i1 %11, label %16, label %12 12: ; preds = %3 - %13 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 - %14 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 - %15 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %13 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %14 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 + %15 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3 br label %20 16: ; preds = %3 - %17 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 - %18 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 - %19 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 + %17 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 + %18 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 + %19 = tail call 
x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3 br label %20 20: ; preds = %16, %12 - %21 = phi <256 x i32> [ %17, %16 ], [ %13, %12 ] - %22 = phi <256 x i32> [ %18, %16 ], [ %14, %12 ] - %23 = phi <256 x i32> [ %19, %16 ], [ %15, %12 ] - %24 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %1, <256 x i32> %23, <256 x i32> %21, <256 x i32> %22) #3 - %25 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, <256 x i32> %6, <256 x i32> %24, <256 x i32> %5) #3 - %26 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, <256 x i32> %8, <256 x i32> %25, <256 x i32> %7) #3 - %27 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %2, i16 %2, i16 %2, <256 x i32> %10, <256 x i32> %26, <256 x i32> %9) #3 - tail call void @llvm.x86.tilestored64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, <256 x i32> %27) #3 + %21 = phi x86_amx [ %17, %16 ], [ %13, %12 ] + %22 = phi x86_amx [ %18, %16 ], [ %14, %12 ] + %23 = phi x86_amx [ %19, %16 ], [ %15, %12 ] + %24 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %1, x86_amx %23, x86_amx %21, x86_amx %22) #3 + %25 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, x86_amx %6, x86_amx %24, x86_amx %5) #3 + %26 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, x86_amx %8, x86_amx %25, x86_amx %7) #3 + %27 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %2, i16 %2, i16 %2, x86_amx %10, x86_amx %26, x86_amx %9) #3 + tail call void @llvm.x86.tilestored64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, x86_amx %27) #3 ret void } -declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3 -declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) 
#3 -declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3 +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3 +declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) #3 +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) #3 attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } attributes #3 = { nounwind } diff --git a/llvm/test/CodeGen/X86/AMX/amx-type.ll b/llvm/test/CodeGen/X86/AMX/amx-type.ll --- a/llvm/test/CodeGen/X86/AMX/amx-type.ll +++ b/llvm/test/CodeGen/X86/AMX/amx-type.ll @@ -8,18 +8,104 @@ @buf = dso_local global [1024 x i8] zeroinitializer, align 16 @buf2 = dso_local global [1024 x i8] zeroinitializer, align 16 +; test bitcast x86_amx to <256 x i32> +define dso_local void @test_user_empty(i16 %m, i16 %n, i8 *%buf, i64 %s) #2 { +; CHECK-LABEL: @test_user_empty( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N:%.*]], i8* [[BUF:%.*]], i64 [[S:%.*]]) [[ATTR3:#.*]] +; CHECK-NEXT: ret void +; +entry: + %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %buf, i64 %s) #3 + %t2 = bitcast x86_amx %t1 to <256 x i32> + ret void +} + +; test bitcast <256 x i32> to x86_amx +define dso_local void @test_user_empty2(<256 x i32> %in) #2 { +; CHECK-LABEL: @test_user_empty2( +; CHECK-NEXT: entry: +; CHECK-NEXT: ret void +; +entry: + %t = bitcast <256 x i32> %in to x86_amx + ret void +} + +define dso_local <256 x i32> 
@test_amx_load_bitcast(<256 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) #2 { +; CHECK-LABEL: @test_amx_load_bitcast( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[T1:%.*]] = load <256 x i32>, <256 x i32>* [[IN:%.*]], align 64 +; CHECK-NEXT: [[TMP0:%.*]] = bitcast <256 x i32>* [[IN]] to i8* +; CHECK-NEXT: [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N:%.*]], i8* [[TMP0]], i64 64) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP1]]) [[ATTR3]] +; CHECK-NEXT: ret <256 x i32> [[T1]] +; +entry: + %t1 = load <256 x i32>, <256 x i32>* %in, align 64 + %t2 = bitcast <256 x i32> %t1 to x86_amx + call void @llvm.x86.tilestored64.internal(i16 %m, i16 %n, i8* %buf, i64 %s, x86_amx %t2) #3 + ret <256 x i32> %t1 +} + +define dso_local <256 x i32> @test_amx_bitcast_store(<256 x i32>* %out, i16 %m, i16 %n, i8 *%buf, i64 %s) #2 { +; CHECK-LABEL: @test_amx_bitcast_store( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[M]], i8* [[BUF:%.*]], i64 [[S:%.*]]) [[ATTR3]] +; CHECK-NEXT: [[TMP0:%.*]] = bitcast <256 x i32>* [[OUT:%.*]] to i8* +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], i8* [[TMP0]], i64 64, x86_amx [[T1]]) +; CHECK-NEXT: [[TMP1:%.*]] = load <256 x i32>, <256 x i32>* [[OUT]], align 1024 +; CHECK-NEXT: ret <256 x i32> [[TMP1]] +; +entry: + %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %m, i8* %buf, i64 %s) #3 + %t2 = bitcast x86_amx %t1 to <256 x i32> + store <256 x i32> %t2, <256 x i32>* %out + ret <256 x i32> %t2 +} + +define dso_local void @test_src_add(<256 x i32> %x, <256 x i32> %y, i16 %r, i16 %c, i8* %buf, i64 %s) #2 { +; CHECK-LABEL: @test_src_add( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[ADD:%.*]] = add <256 x i32> [[Y:%.*]], [[X:%.*]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] 
to i8* +; CHECK-NEXT: store <256 x i32> [[ADD]], <256 x i32>* [[TMP0]], align 1024 +; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[R:%.*]], i16 [[C:%.*]], i8* [[TMP1]], i64 64) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[R]], i16 [[C]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP2]]) [[ATTR3]] +; CHECK-NEXT: ret void +; +entry: + %add = add <256 x i32> %y, %x + %t = bitcast <256 x i32> %add to x86_amx + call void @llvm.x86.tilestored64.internal(i16 %r, i16 %c, i8* %buf, i64 %s, x86_amx %t) #3 + ret void +} + +define dso_local void @test_src_add2(<256 x i32> %x, i16 %r, i16 %c, i8* %buf, i64 %s) #2 { +; CHECK-LABEL: @test_src_add2( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[TMP0:%.*]] = alloca <256 x i32>, align 64 +; CHECK-NEXT: [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[R:%.*]], i16 [[C:%.*]], i8* [[BUF:%.*]], i64 [[S:%.*]]) [[ATTR3]] +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8* +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[R]], i16 [[C]], i8* [[TMP1]], i64 64, x86_amx [[T1]]) +; CHECK-NEXT: [[TMP2:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0]], align 1024 +; CHECK-NEXT: [[ADD:%.*]] = add <256 x i32> [[TMP2]], [[X:%.*]] +; CHECK-NEXT: ret void +; +entry: + %t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %r, i16 %c, i8* %buf, i64 %s) #3 + %t2 = bitcast x86_amx %t1 to <256 x i32> + %add = add <256 x i32> %t2, %x + ret void +} + define dso_local void @test_load(i8* %in, i8* %out) local_unnamed_addr #2 { ; CHECK-LABEL: @test_load( ; CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[IN:%.*]] to <256 x i32>* ; CHECK-NEXT: [[TMP2:%.*]] = bitcast i8* [[OUT:%.*]] to <256 x i32>* -; CHECK-NEXT: [[TMP3:%.*]] = bitcast <256 x i32>* [[TMP1]] to <128 x i32>* -; CHECK-NEXT: [[TMP4:%.*]] = load <128 x i32>, <128 x i32>* [[TMP3]], align 64 -; CHECK-NEXT: [[TMP5:%.*]] = getelementptr <128 x i32>, <128 x i32>* [[TMP3]], i32 1 -; CHECK-NEXT: [[TMP6:%.*]] = load <128 x i32>, <128 
x i32>* [[TMP5]], align 64 -; CHECK-NEXT: [[TMP7:%.*]] = bitcast <256 x i32>* [[TMP2]] to <128 x i32>* -; CHECK-NEXT: store <128 x i32> [[TMP4]], <128 x i32>* [[TMP7]], align 64 -; CHECK-NEXT: [[TMP8:%.*]] = getelementptr <128 x i32>, <128 x i32>* [[TMP7]], i32 1 -; CHECK-NEXT: store <128 x i32> [[TMP6]], <128 x i32>* [[TMP8]], align 64 +; CHECK-NEXT: [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 64, [[TBAA2:!tbaa !.*]] +; CHECK-NEXT: store <256 x i32> [[TMP3]], <256 x i32>* [[TMP2]], align 64, [[TBAA2]] ; CHECK-NEXT: ret void ; %1 = bitcast i8* %in to <256 x i32>* @@ -29,18 +115,33 @@ ret void } +define dso_local <256 x i32> @foo(<256 x i32>* nocapture readonly byval(<256 x i32>) align 1024 %0, <256 x i32>* nocapture readonly byval(<256 x i32>) align 1024 %1) local_unnamed_addr #0 { +; CHECK-LABEL: @foo( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[X:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0:%.*]], align 1024, [[TBAA5:!tbaa !.*]] +; CHECK-NEXT: [[Y:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1:%.*]], align 1024, [[TBAA5]] +; CHECK-NEXT: [[ADD:%.*]] = add <256 x i32> [[Y]], [[X]] +; CHECK-NEXT: ret <256 x i32> [[ADD]] +; +entry: + %x = load <256 x i32>, <256 x i32>* %0, align 1024, !tbaa !2 + %y = load <256 x i32>, <256 x i32>* %1, align 1024, !tbaa !2 + %add = add <256 x i32> %y, %x + ret <256 x i32> %add +} + define dso_local void @__tile_loadd(%struct.__tile_str* nocapture %0, i8* %1, i64 %2) local_unnamed_addr #0 { ; CHECK-LABEL: @__tile_loadd( ; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 0 -; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2:!tbaa !.*]] +; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA5]] ; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 1 -; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7:!tbaa !.*]] +; CHECK-NEXT: 
[[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA8:!tbaa !.*]] ; CHECK-NEXT: [[TMP8:%.*]] = shl i64 [[TMP2:%.*]], 32 ; CHECK-NEXT: [[TMP9:%.*]] = ashr exact i64 [[TMP8]], 32 -; CHECK-NEXT: [[TMP10:%.*]] = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP1:%.*]], i64 [[TMP9]]) [[ATTR3:#.*]] +; CHECK-NEXT: [[TMP10:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP1:%.*]], i64 [[TMP9]]) [[ATTR3]] ; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 2 ; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8* -; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64, <256 x i32> [[TMP10]]) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64, x86_amx [[TMP10]]) ; CHECK-NEXT: ret void ; %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 0 @@ -49,32 +150,33 @@ %7 = load i16, i16* %6, align 2, !tbaa !7 %8 = shl i64 %2, 32 %9 = ashr exact i64 %8, 32 - %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %5, i16 %7, i8* %1, i64 %9) #3 - %11 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2 - store <256 x i32> %10, <256 x i32>* %11, align 64, !tbaa !8 + %10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %5, i16 %7, i8* %1, i64 %9) #3 + %11 = bitcast x86_amx %10 to <256 x i32> + %12 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2 + store <256 x i32> %11, <256 x i32>* %12, align 64, !tbaa !8 ret void } define dso_local void @__tile_dpbsud(%struct.__tile_str* nocapture %0, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr #0 { ; CHECK-LABEL: @__tile_dpbsud( ; CHECK-NEXT: 
[[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP1:%.*]], i64 0, i32 0 -; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2]] +; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA5]] ; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 1 -; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7]] +; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA8]] ; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1 -; CHECK-NEXT: [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 2, [[TBAA7]] +; CHECK-NEXT: [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 2, [[TBAA8]] ; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2 ; CHECK-NEXT: [[TMP11:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8* -; CHECK-NEXT: [[TMP12:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP11]], i64 64) +; CHECK-NEXT: [[TMP12:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP11]], i64 64) ; CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2 ; CHECK-NEXT: [[TMP14:%.*]] = bitcast <256 x i32>* [[TMP13]] to i8* -; CHECK-NEXT: [[TMP15:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP14]], i64 64) +; CHECK-NEXT: [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP14]], i64 64) ; CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2 ; CHECK-NEXT: [[TMP17:%.*]] = bitcast <256 x i32>* [[TMP16]] to i8* -; CHECK-NEXT: [[TMP18:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP7]], i8* [[TMP17]], 
i64 64) -; CHECK-NEXT: [[TMP19:%.*]] = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], <256 x i32> [[TMP12]], <256 x i32> [[TMP15]], <256 x i32> [[TMP18]]) [[ATTR3]] +; CHECK-NEXT: [[TMP18:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP7]], i8* [[TMP17]], i64 64) +; CHECK-NEXT: [[TMP19:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], x86_amx [[TMP12]], x86_amx [[TMP15]], x86_amx [[TMP18]]) [[ATTR3]] ; CHECK-NEXT: [[TMP20:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8* -; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP20]], i64 64, <256 x i32> [[TMP19]]) +; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP20]], i64 64, x86_amx [[TMP19]]) ; CHECK-NEXT: ret void ; %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 0 @@ -85,27 +187,31 @@ %9 = load i16, i16* %8, align 2, !tbaa !7 %10 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2 %11 = load <256 x i32>, <256 x i32>* %10, align 64, !tbaa !8 - %12 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 2 - %13 = load <256 x i32>, <256 x i32>* %12, align 64, !tbaa !8 - %14 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2 - %15 = load <256 x i32>, <256 x i32>* %14, align 64, !tbaa !8 - %16 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %5, i16 %7, i16 %9, <256 x i32> %11, <256 x i32> %13, <256 x i32> %15) #3 - store <256 x i32> %16, <256 x i32>* %10, align 64, !tbaa !8 + %12 = bitcast <256 x i32> %11 to x86_amx + %13 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 2 + %14 = load <256 x i32>, <256 x i32>* %13, align 64, !tbaa !8 + %15 = bitcast <256 x i32> %14 to x86_amx + %16 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2 + %17 = 
load <256 x i32>, <256 x i32>* %16, align 64, !tbaa !8 + %18 = bitcast <256 x i32> %17 to x86_amx + %19 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %5, i16 %7, i16 %9, x86_amx %12, x86_amx %15, x86_amx %18) #3 + %20 = bitcast x86_amx %19 to <256 x i32> + store <256 x i32> %20, <256 x i32>* %10, align 64, !tbaa !8 ret void } define dso_local void @__tile_stored(i8* %0, i64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr #1 { ; CHECK-LABEL: @__tile_stored( ; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0 -; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2]] +; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA5]] ; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 1 -; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7]] +; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA8]] ; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2 ; CHECK-NEXT: [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP8]] to i8* -; CHECK-NEXT: [[TMP10:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP9]], i64 64) +; CHECK-NEXT: [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP9]], i64 64) ; CHECK-NEXT: [[TMP11:%.*]] = shl i64 [[TMP1:%.*]], 32 ; CHECK-NEXT: [[TMP12:%.*]] = ashr exact i64 [[TMP11]], 32 -; CHECK-NEXT: tail call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP0:%.*]], i64 [[TMP12]], <256 x i32> [[TMP10]]) [[ATTR3]] +; CHECK-NEXT: tail call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP0:%.*]], i64 [[TMP12]], x86_amx [[TMP10]]) [[ATTR3]] ; CHECK-NEXT: ret void ; %4 = getelementptr inbounds %struct.__tile_str, 
%struct.__tile_str* %2, i64 0, i32 0 @@ -114,15 +220,16 @@ %7 = load i16, i16* %6, align 2, !tbaa !7 %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2 %9 = load <256 x i32>, <256 x i32>* %8, align 64, !tbaa !8 - %10 = shl i64 %1, 32 - %11 = ashr exact i64 %10, 32 - tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %11, <256 x i32> %9) #3 + %10 = bitcast <256 x i32> %9 to x86_amx + %11 = shl i64 %1, 32 + %12 = ashr exact i64 %11, 32 + tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %12, x86_amx %10) #3 ret void } -declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3 -declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3 -declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3 +declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3 +declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) #3 +declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) #3 attributes #0 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } attributes #1 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" 
"stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" } diff --git a/llvm/utils/TableGen/CodeGenTarget.cpp b/llvm/utils/TableGen/CodeGenTarget.cpp --- a/llvm/utils/TableGen/CodeGenTarget.cpp +++ b/llvm/utils/TableGen/CodeGenTarget.cpp @@ -76,6 +76,7 @@ case MVT::f128: return "MVT::f128"; case MVT::ppcf128: return "MVT::ppcf128"; case MVT::x86mmx: return "MVT::x86mmx"; + case MVT::x86amx: return "MVT::x86amx"; case MVT::Glue: return "MVT::Glue"; case MVT::isVoid: return "MVT::isVoid"; case MVT::v1i1: return "MVT::v1i1"; diff --git a/llvm/utils/TableGen/IntrinsicEmitter.cpp b/llvm/utils/TableGen/IntrinsicEmitter.cpp --- a/llvm/utils/TableGen/IntrinsicEmitter.cpp +++ b/llvm/utils/TableGen/IntrinsicEmitter.cpp @@ -248,7 +248,8 @@ IIT_V128 = 47, IIT_BF16 = 48, IIT_STRUCT9 = 49, - IIT_V256 = 50 + IIT_V256 = 50, + IIT_AMX = 51 }; static void EncodeFixedValueType(MVT::SimpleValueType VT, @@ -276,6 +277,7 @@ case MVT::token: return Sig.push_back(IIT_TOKEN); case MVT::Metadata: return Sig.push_back(IIT_METADATA); case MVT::x86mmx: return Sig.push_back(IIT_MMX); + case MVT::x86amx: return Sig.push_back(IIT_AMX); // MVT::OtherVT is used to mean the empty struct type here. case MVT::Other: return Sig.push_back(IIT_EMPTYSTRUCT); // MVT::isVoid is used to represent varargs here.