diff --git a/mlir/include/mlir/Conversion/AMXToLLVM/ConvertAMXToLLVM.h b/mlir/include/mlir/Conversion/AMXToLLVM/ConvertAMXToLLVM.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/AMXToLLVM/ConvertAMXToLLVM.h @@ -0,0 +1,23 @@ +//===- ConvertAMXToLLVM.h - Conversion Patterns from AMX to LLVM ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CONVERSION_AMXTOLLVM_CONVERTAMXTOLLVM_H_ +#define MLIR_CONVERSION_AMXTOLLVM_CONVERTAMXTOLLVM_H_ + +namespace mlir { + +class LLVMTypeConverter; +class OwningRewritePatternList; + +/// Collect a set of patterns to convert from the AMX dialect to LLVM. +void populateAMXToLLVMConversionPatterns(LLVMTypeConverter &converter, + OwningRewritePatternList &patterns); + +} // namespace mlir + +#endif // MLIR_CONVERSION_AMXTOLLVM_CONVERTAMXTOLLVM_H_ diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -494,7 +494,7 @@ operations. The lowering pass provides several options to control the kinds of optimizations that are allowed. It also provides options that enable the use of one or more architectural-specific dialects - (AVX512, ArmNeon, ArmSVE, etc.) in combination with the + (AMX, AVX512, ArmNeon, ArmSVE, etc.) in combination with the architectural-neutral vector dialect lowering. 
}]; @@ -509,6 +509,10 @@ "bool", /*default=*/"true", "Allows compiler to assume indices fit in 32-bit if that yields " "faster code">, + Option<"enableAMX", "enable-amx", + "bool", /*default=*/"false", + "Enables the use of AMX dialect while lowering the vector " + "dialect.">, Option<"enableAVX512", "enable-avx512", "bool", /*default=*/"false", "Enables the use of AVX512 dialect while lowering the vector " diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -23,7 +23,8 @@ struct LowerVectorToLLVMOptions { LowerVectorToLLVMOptions() : reassociateFPReductions(false), enableIndexOptimizations(true), - enableArmNeon(false), enableArmSVE(false), enableAVX512(false) {} + enableArmNeon(false), enableArmSVE(false), enableAMX(false), + enableAVX512(false) {} LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) { reassociateFPReductions = b; @@ -41,6 +42,10 @@ enableArmSVE = b; return *this; } + LowerVectorToLLVMOptions &setEnableAMX(bool b) { + enableAMX = b; + return *this; + } LowerVectorToLLVMOptions &setEnableAVX512(bool b) { enableAVX512 = b; return *this; @@ -50,6 +55,7 @@ bool enableIndexOptimizations; bool enableArmNeon; bool enableArmSVE; + bool enableAMX; bool enableAVX512; }; diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/AMX/AMX.td @@ -0,0 +1,206 @@ +//===-- AMXOps.td - AMX dialect operation definitions *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the AMX dialect. +// +// The Intel Advanced Matrix Extensions (AMX) provides a tile matrix +// multiply unit (TMUL), a tile control register (TILECFG), and eight +// tile registers TMM0 through TMM7 (TILEDATA). +// +// The AMX dialect provides a bridge between MLIR concepts, such as +// 2-d vector, operations, and memrefs, and the lower level details +// of Intel AMX, such as configuration setup, tile sizes, instructions, +// and tile release. +// +// Note that since configuration changes (implicit at dialect level) are +// costly, it is highly recommended to use the AMX dialect on same-shaped +// vectors, at least within a single method. +// +//===----------------------------------------------------------------------===// + +#ifndef AMX_OPS +#define AMX_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" + +//===----------------------------------------------------------------------===// +// AMX dialect definition +//===----------------------------------------------------------------------===// + +def AMX_Dialect : Dialect { + let name = "amx"; + let cppNamespace = "::mlir::amx"; +} + +//===----------------------------------------------------------------------===// +// AMX op definitions +//===----------------------------------------------------------------------===// + +class AMX_Op traits = []> : + Op {} + +def TileZeroOp : AMX_Op<"tilezero", [NoSideEffect]> { + let summary = "tile zero operation"; + let description = [{ + Zeroes the destination tile, with the shape defined by the 2-dim + vector type of the result. This is eventually lowered into the + 'tilezero' instruction with the corresponding tile configuration. 
+ + Example: + + ```mlir + %0 = amx.tilezero : vector<16x16xbf16> + ``` + }]; + let verifier = [{ return ::verify(*this); }]; + let results = (outs + VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res); + let extraClassDeclaration = [{ + VectorType getVectorType() { + return res().getType().cast(); + } + }]; + let assemblyFormat = "attr-dict `:` type($res)"; +} + +def TileLoadOp : AMX_Op<"tileload", [NoSideEffect]> { + let summary = "tile load operation"; + let description = [{ + Loads a tile from memory defined by a base and indices, with the + shape defined by the 2-dim vector type of the result. This is + eventually lowered into the 'tileloadd' instruction with the + corresponding tile configuration. + + Example: + + ```mlir + %0 = amx.tileload %arg0[%c0, %c0] : memref into vector<16x64xi8> + ``` + }]; + let verifier = [{ return ::verify(*this); }]; + let arguments = (ins Arg:$base, + Variadic:$indices); + let results = (outs + VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res); + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + VectorType getVectorType() { + return res().getType().cast(); + } + }]; + let assemblyFormat = "$base `[` $indices `]` attr-dict `:` " + "type($base) `into` type($res)"; +} + +def TileStoreOp : AMX_Op<"tilestore"> { + let summary = "tile store operation"; + let description = [{ + Stores a tile to memory defined by a base and indices, with the + shape defined by the 2-dim vector type of the value. This is + eventually lowered into the 'tilestored' instruction with the + corresponding tile configuration. 
+ + Example: + + ```mlir + amx.tilestore %arg1[%c0, %c0], %0 : memref, vector<16x64xi8> + ``` + }]; + let verifier = [{ return ::verify(*this); }]; + let arguments = (ins Arg:$base, + Variadic:$indices, + VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val); + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + VectorType getVectorType() { + return val().getType().cast(); + } + }]; + let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` " + "type($base) `,` type($val)"; +} + +def TileMulFOp : AMX_Op<"tilemulf", [NoSideEffect, AllTypesMatch<["acc", "res"]>]> { + let summary = "tile multiplication operation (floating-point)"; + let description = [{ + Multiplies a m x k tile with a k x n tile and accumulates the results + into a m x n destination tile. Supports f32 <- bf16 x bf16 (with + pairs of bf16). The operation is eventually lowered into the + 'tdpbf16ps' instruction with the corresponding tile configuration. + + Example: + + ```mlir + %0 = amx.tilemulf %a, %b, %c + : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> + ``` + }]; + let verifier = [{ return ::verify(*this); }]; + let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs, + VectorOfRankAndType<[2], [F32, BF16]>:$rhs, + VectorOfRankAndType<[2], [F32, BF16]>:$acc); + let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res); + let extraClassDeclaration = [{ + VectorType getLhsVectorType() { + return lhs().getType().cast(); + } + VectorType getRhsVectorType() { + return rhs().getType().cast(); + } + VectorType getVectorType() { + return res().getType().cast(); + } + }]; + let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` " + "type($lhs) `,` type($rhs) `,` type($acc) "; +} + +def TileMulIOp : AMX_Op<"tilemuli", [NoSideEffect, AllTypesMatch<["acc", "res"]>]> { + let summary = "tile multiplication operation (integer)"; + let description = [{ + Multiplies a m x k tile with a k x n tile and accumulates the results 
+ into a m x n destination tile. Supports all si32 <- s/ui8 x s/ui8 + combinations (4 bytes packed into dwords in the columns of both the + source operand tiles; the zero or sign extension is specified with + the attributes). The operation is eventually lowered into one of + the 'tdpbssd', 'tdpbsud', 'tdpbusd', or 'tdpbuud' instructions with + the corresponding tile configuration. + + Example: + + ```mlir + %0 = amx.tilemuli %a, %b, %c [true, true] + : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + ``` + }]; + let verifier = [{ return ::verify(*this); }]; + let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs, + VectorOfRankAndType<[2], [I32, I8]>:$rhs, + VectorOfRankAndType<[2], [I32, I8]>:$acc, + BoolArrayAttr:$zext); + let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res); + let extraClassDeclaration = [{ + VectorType getLhsVectorType() { + return lhs().getType().cast(); + } + VectorType getRhsVectorType() { + return rhs().getType().cast(); + } + VectorType getVectorType() { + return res().getType().cast(); + } + }]; + let assemblyFormat = "$lhs `,` $rhs `,` $acc $zext attr-dict `:` " + "type($lhs) `,` type($rhs) `,` type($acc) "; +} + +#endif // AMX_OPS diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h @@ -0,0 +1,26 @@ +//===- AMXDialect.h - MLIR Dialect for AMX ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for AMX in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AMX_AMXDIALECT_H_ +#define MLIR_DIALECT_AMX_AMXDIALECT_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/AMX/AMXDialect.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/AMX/AMX.h.inc" + +#endif // MLIR_DIALECT_AMX_AMXDIALECT_H_ diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt @@ -0,0 +1,2 @@ +add_mlir_dialect(AMX amx) +add_mlir_doc(AMX -gen-dialect-doc AMX Dialects/) diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(Async) add_subdirectory(ArmNeon) add_subdirectory(ArmSVE) +add_subdirectory(AMX) add_subdirectory(AVX512) add_subdirectory(Complex) add_subdirectory(DLTI) diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -35,6 +35,12 @@ mlir_tablegen(ROCDLConversions.inc -gen-llvmir-conversions) add_public_tablegen_target(MLIRROCDLConversionsIncGen) +add_mlir_dialect(LLVMAMX llvm_amx LLVMAMX) +add_mlir_doc(LLVMAMX -gen-dialect-doc LLVMAMX Dialects/) +set(LLVM_TARGET_DEFINITIONS LLVMAMX.td) +mlir_tablegen(LLVMAMXConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRLLVMAMXConversionsIncGen) + add_mlir_dialect(LLVMArmSVE llvm_arm_sve LLVMArmSVE) add_mlir_doc(LLVMArmSVE -gen-dialect-doc LLVMArmSve Dialects/) set(LLVM_TARGET_DEFINITIONS LLVMArmSVE.td) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td 
b/mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAMX.td @@ -0,0 +1,73 @@ +//===-- LLVMAMX.td - LLVMAMX dialect op definitions -------*- tablegen -*--===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the LLVMAMX dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVMIR_AMX_OPS +#define LLVMIR_AMX_OPS + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// LLVMAMX dialect definition +//===----------------------------------------------------------------------===// + +def LLVMAMX_Dialect : Dialect { + let name = "llvm_amx"; + let cppNamespace = "::mlir::LLVM"; +} + +//----------------------------------------------------------------------------// +// MLIR LLVM AMX intrinsics using the MLIR LLVM Dialect type system +//----------------------------------------------------------------------------// + +// The "internal" intrinsics are meant for compiler usage. +class LLVMAMX_IntrOp traits = []> : + LLVM_IntrOpBase; + +// +// Tile reset. +// + +def LLVM_x86_amx_tilezero : LLVMAMX_IntrOp<"tilezero", 1>, + Arguments<(ins LLVM_Type, LLVM_Type)>; + +// +// Tile memory operations. +// + +def LLVM_x86_amx_tileloadd64 : LLVMAMX_IntrOp<"tileloadd64", 1>, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_x86_amx_tilestored64 : LLVMAMX_IntrOp<"tilestored64", 0>, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; + +// +// Tile arithmetic operations. 
+// + +def LLVM_x86_amx_tdpbf16ps : LLVMAMX_IntrOp<"tdpbf16ps", 1>, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_x86_amx_tdpbssd : LLVMAMX_IntrOp<"tdpbssd", 1>, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_x86_amx_tdpbsud : LLVMAMX_IntrOp<"tdpbsud", 1>, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_x86_amx_tdpbusd : LLVMAMX_IntrOp<"tdpbusd", 1>, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; + +def LLVM_x86_amx_tdpbuud : LLVMAMX_IntrOp<"tdpbuud", 1>, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; + +#endif // LLVMIR_AMX_OPS diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAMXDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMAMXDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAMXDialect.h @@ -0,0 +1,24 @@ +//===- LLVMAMXDialect.h - MLIR Dialect for LLVMAMX --------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for LLVMAMX in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LLVMIR_LLVMAMXDIALECT_H_ +#define MLIR_DIALECT_LLVMIR_LLVMAMXDIALECT_H_ + +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMAMX.h.inc" + +#include "mlir/Dialect/LLVMIR/LLVMAMXDialect.h.inc" + +#endif // MLIR_DIALECT_LLVMIR_LLVMAMXDIALECT_H_ diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -14,6 +14,7 @@ #ifndef MLIR_INITALLDIALECTS_H_ #define MLIR_INITALLDIALECTS_H_ +#include "mlir/Dialect/AMX/AMXDialect.h" #include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" @@ -22,6 +23,7 @@ #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAMXDialect.h" #include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" @@ -50,12 +52,14 @@ // clang-format off registry.insert> 3; + auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0)); + auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes); + m = rewriter.create(loc, llvmInt16Type, mattr); + n = rewriter.create(loc, llvmInt16Type, nattr); +} + +/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer +/// shape may "envelop" the actual tile shape, and may be dynamically sized. 
+Value getStride(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &typeConverter, MemRefType mType, Value base, + Location loc) { + unsigned last = mType.getShape().size(); + assert(last >= 2); + Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64); + unsigned bytes = mType.getElementType().getIntOrFloatBitWidth() >> 3; + if (mType.isDynamicDim(last - 1)) { + // Dynamic size needs code to compute the stride at runtime. + MemRefDescriptor memrefDescriptor(base); + auto attr = rewriter.getI64IntegerAttr(bytes); + Value scale = rewriter.create(loc, llvmInt64Type, attr); + return rewriter.create( + loc, llvmInt64Type, scale, + memrefDescriptor.size(rewriter, loc, last - 1)); + } + // Use direct constant for static size. + auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last - 1) * bytes); + return rewriter.create(loc, llvmInt64Type, attr); +} + +/// Cast any pointer to the !llvm.ptr pointer type. +Value castPtr(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &typeConverter, Location loc, Value ptr) { + auto i8Ptr = LLVM::LLVMPointerType::get( + IntegerType::get(&typeConverter.getContext(), 8)); + return rewriter.create(loc, i8Ptr, ptr); +} + +struct TileZeroConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(TileZeroOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + VectorType vType = op.getVectorType(); + // Determine m x n tile sizes. + Value m, n; + getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc(), m, n); + // Replace operation with intrinsic. 
+ Type resType = typeConverter->convertType(vType); + rewriter.replaceOpWithNewOp(op, resType, m, n); + return success(); + } +}; + +struct TileLoadConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(TileLoadOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + TileLoadOp::Adaptor adaptor(operands); + MemRefType mType = op.getMemRefType(); + VectorType vType = op.getVectorType(); + // Determine m x n tile sizes. + Value m, n; + getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc(), m, n); + // Replace operation with intrinsic. + Value stride = getStride(rewriter, *getTypeConverter(), mType, + adaptor.base(), op.getLoc()); + Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(), + adaptor.indices(), rewriter); + ptr = castPtr(rewriter, *getTypeConverter(), op.getLoc(), ptr); + Type resType = typeConverter->convertType(vType); + rewriter.replaceOpWithNewOp(op, resType, m, n, + ptr, stride); + return success(); + } +}; + +struct TileStoreConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(TileStoreOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + TileStoreOp::Adaptor adaptor(operands); + MemRefType mType = op.getMemRefType(); + VectorType vType = op.getVectorType(); + // Determine m x n tile sizes. + Value m, n; + getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc(), m, n); + // Replace operation with intrinsic. 
+ Value stride = getStride(rewriter, *getTypeConverter(), mType, + adaptor.base(), op.getLoc()); + Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(), + adaptor.indices(), rewriter); + ptr = castPtr(rewriter, *getTypeConverter(), op.getLoc(), ptr); + rewriter.replaceOpWithNewOp( + op, m, n, ptr, stride, adaptor.val()); + return success(); + } +}; + +struct TileMulFConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(TileMulFOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + TileMulFOp::Adaptor adaptor(operands); + VectorType aType = op.getLhsVectorType(); + VectorType cType = op.getVectorType(); + // Determine m x n x k tile sizes. + Value m, n, k; + getTileSizes(rewriter, *getTypeConverter(), cType, op.getLoc(), m, n); + getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc(), m, k); + // Replace operation with intrinsic. + Type resType = typeConverter->convertType(cType); + rewriter.replaceOpWithNewOp( + op, resType, m, n, k, adaptor.acc(), adaptor.lhs(), adaptor.rhs()); + return success(); + } +}; + +struct TileMulIConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(TileMulIOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + TileMulIOp::Adaptor adaptor(operands); + VectorType aType = op.getLhsVectorType(); + VectorType bType = op.getRhsVectorType(); + VectorType cType = op.getVectorType(); + // Determine m x n x k tile sizes. + Value m, n, k, k2; + getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc(), m, k); + getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc(), k2, n); + // Replace operation with intrinsic. 
+ Type resType = typeConverter->convertType(cType); + bool zexta = op.zext()[0].cast().getValue(); + bool zextb = op.zext()[1].cast().getValue(); + if (zexta && zextb) + rewriter.replaceOpWithNewOp( + op, resType, m, n, k, adaptor.acc(), adaptor.lhs(), adaptor.rhs()); + else if (zexta && !zextb) + rewriter.replaceOpWithNewOp( + op, resType, m, n, k, adaptor.acc(), adaptor.lhs(), adaptor.rhs()); + else if (!zexta && zextb) + rewriter.replaceOpWithNewOp( + op, resType, m, n, k, adaptor.acc(), adaptor.lhs(), adaptor.rhs()); + else + rewriter.replaceOpWithNewOp( + op, resType, m, n, k, adaptor.acc(), adaptor.lhs(), adaptor.rhs()); + return success(); + } +}; + +} // namespace + +/// Populate the given list with patterns that convert from AMX to LLVM. +void mlir::populateAMXToLLVMConversionPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + patterns.insert(converter); +} diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -1,5 +1,6 @@ add_subdirectory(AffineToStandard) add_subdirectory(AsyncToLLVM) +add_subdirectory(AMXToLLVM) add_subdirectory(ComplexToLLVM) add_subdirectory(GPUCommon) add_subdirectory(GPUToNVVM) diff --git a/mlir/lib/Conversion/PassDetail.h b/mlir/lib/Conversion/PassDetail.h --- a/mlir/lib/Conversion/PassDetail.h +++ b/mlir/lib/Conversion/PassDetail.h @@ -30,6 +30,7 @@ namespace LLVM { class LLVMArmSVEDialect; +class LLVMAMXDialect; class LLVMAVX512Dialect; class LLVMDialect; } // end namespace LLVM diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -14,6 +14,8 @@ LINK_LIBS PUBLIC MLIRArmNeon + MLIRAMX + MLIRAMXToLLVM MLIRAVX512 MLIRAVX512Transforms MLIRArmSVE diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -10,13 +10,16 @@ #include "../PassDetail.h" +#include "mlir/Conversion/AMXToLLVM/ConvertAMXToLLVM.h" #include "mlir/Conversion/ArmSVEToLLVM/ArmSVEToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" +#include "mlir/Dialect/AMX/AMXDialect.h" #include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/Dialect/AVX512/Transforms.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAMXDialect.h" #include "mlir/Dialect/LLVMIR/LLVMArmSVEDialect.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" @@ -34,6 +37,7 @@ this->enableIndexOptimizations = options.enableIndexOptimizations; this->enableArmNeon = options.enableArmNeon; this->enableArmSVE = options.enableArmSVE; + this->enableAMX = options.enableAMX; this->enableAVX512 = options.enableAVX512; } // Override explicitly to allow conditional dialect dependence. 
@@ -43,6 +47,8 @@ registry.insert(); if (enableArmSVE) registry.insert(); + if (enableAMX) + registry.insert(); if (enableAVX512) registry.insert(); } @@ -102,6 +108,11 @@ }); populateArmSVEToLLVMConversionPatterns(converter, patterns); } + if (enableAMX) { + target.addLegalDialect(); + target.addIllegalDialect(); + populateAMXToLLVMConversionPatterns(converter, patterns); + } if (enableAVX512) { configureAVX512LegalizeForExportTarget(target); populateAVX512LegalizeForLLVMExportPatterns(converter, patterns); diff --git a/mlir/lib/Dialect/AMX/CMakeLists.txt b/mlir/lib/Dialect/AMX/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/AMX/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIRAMX + IR/AMXDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMX + + DEPENDS + MLIRAMXIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRSideEffectInterfaces + ) diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -0,0 +1,106 @@ +//===- AMXOps.cpp - MLIR AMX ops implementation ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AMX dialect and its operations. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +void amx::AMXDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/AMX/AMX.cpp.inc" + >(); +} + +/// Verify that AMX supports the implied tile shape. +static LogicalResult verifyTileSize(VectorType tp) { + const unsigned kMaxRows = 16; + const unsigned kBitsPerRow = 64 * 8; + unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth(); + if (tp.getDimSize(0) > kMaxRows || col > kBitsPerRow) + return failure(); + if (col & 0x1f) + return failure(); // should be multiple of 4 bytes + return success(); +} + +/// Verify that AMX supports the multiplication. +static LogicalResult verifyMultShape(VectorType atp, VectorType btp, + VectorType ctp, unsigned scale) { + unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale; + unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale; + unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1); + if (cm != am || cn != bn || ak != bk) + return failure(); + return success(); +} + +static LogicalResult verify(amx::TileZeroOp op) { + if (failed(verifyTileSize(op.getVectorType()))) + return op.emitOpError("unsupported tile size"); + return success(); +} + +static LogicalResult verify(amx::TileLoadOp op) { + if (failed(verifyTileSize(op.getVectorType()))) + return op.emitOpError("unsupported tile size"); + return success(); +} + +static LogicalResult verify(amx::TileStoreOp op) { + if (failed(verifyTileSize(op.getVectorType()))) + return op.emitOpError("unsupported tile size"); + return success(); +} + +static LogicalResult verify(amx::TileMulFOp op) { + VectorType aType = op.getLhsVectorType(); + VectorType bType = op.getRhsVectorType(); + VectorType cType = 
op.getVectorType(); + if (failed(verifyMultShape(aType, bType, cType, 1))) + return op.emitOpError("unexpected shape"); + if (failed(verifyTileSize(aType)) || failed(verifyTileSize(bType)) || + failed(verifyTileSize(cType))) + return op.emitOpError("unsupported tile size"); + Type ta = aType.getElementType(); + Type tb = bType.getElementType(); + Type tc = cType.getElementType(); + if (ta.isBF16() && tb.isBF16() && tc.isF32()) + return success(); + return op.emitOpError("unsupported type combination"); +} + +static LogicalResult verify(amx::TileMulIOp op) { + if (op.zext().size() != 2) + return op.emitOpError("unexpected zext length"); + VectorType aType = op.getLhsVectorType(); + VectorType bType = op.getRhsVectorType(); + VectorType cType = op.getVectorType(); + if (failed(verifyMultShape(aType, bType, cType, 2))) + return op.emitOpError("unexpected shape"); + if (failed(verifyTileSize(aType)) || failed(verifyTileSize(bType)) || + failed(verifyTileSize(cType))) + return op.emitOpError("unsupported tile size"); + Type ta = aType.getElementType(); + Type tb = bType.getElementType(); + Type tc = cType.getElementType(); + if (ta.isInteger(8) && tb.isInteger(8) && tc.isInteger(32)) + return success(); + return op.emitOpError("unsupported type combination"); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/AMX/AMX.cpp.inc" diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(ArmNeon) add_subdirectory(ArmSVE) add_subdirectory(Async) +add_subdirectory(AMX) add_subdirectory(AVX512) add_subdirectory(Complex) add_subdirectory(DLTI) diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -29,6 +29,27 @@ MLIRSupport ) +add_mlir_dialect_library(MLIRLLVMAMX + IR/LLVMAMXDialect.cpp + + ADDITIONAL_HEADER_DIRS + 
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR + + DEPENDS + MLIRLLVMAMXIncGen + MLIRLLVMAMXConversionsIncGen + intrinsics_gen + + LINK_COMPONENTS + AsmParser + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMIR + MLIRSideEffectInterfaces + ) + add_mlir_dialect_library(MLIRLLVMArmSVE IR/LLVMArmSVEDialect.cpp diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAMXDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAMXDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAMXDialect.cpp @@ -0,0 +1,31 @@ +//===- LLVMAMXDialect.cpp - MLIR LLVMAMX ops implementation ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the LLVMAMX dialect and its operations. +// +//===----------------------------------------------------------------------===// + +#include "llvm/IR/IntrinsicsX86.h" + +#include "mlir/Dialect/LLVMIR/LLVMAMXDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +void LLVM::LLVMAMXDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/LLVMIR/LLVMAMX.cpp.inc" + >(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/LLVMAMX.cpp.inc" diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -39,6 +39,7 @@ MLIRArmNeonToLLVMIRTranslation MLIRAVX512ToLLVMIRTranslation MLIRLLVMArmSVEToLLVMIRTranslation + MLIRLLVMAMXToLLVMIRTranslation MLIRLLVMToLLVMIRTranslation MLIRNVVMToLLVMIRTranslation MLIROpenMPToLLVMIRTranslation diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt 
b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -1,6 +1,7 @@ add_subdirectory(ArmNeon) add_subdirectory(AVX512) add_subdirectory(LLVMArmSVE) +add_subdirectory(LLVMAMX) add_subdirectory(LLVMIR) add_subdirectory(NVVM) add_subdirectory(OpenMP) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMAMX/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/LLVMAMX/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMAMX/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_translation_library(MLIRLLVMAMXToLLVMIRTranslation + LLVMAMXToLLVMIRTranslation.cpp + + DEPENDS + MLIRLLVMAMXConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMAMX + MLIRLLVMIR + MLIRSupport + MLIRTargetLLVMIRExport + ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMAMX/LLVMAMXToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMAMX/LLVMAMXToLLVMIRTranslation.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMAMX/LLVMAMXToLLVMIRTranslation.cpp @@ -0,0 +1,55 @@ +//===- LLVMAMXToLLVMIRTranslation.cpp - Translate LLVMAMX to LLVM IR ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the LLVMAMX dialect and LLVM IR. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/LLVMAMX/LLVMAMXToLLVMIRTranslation.h" +#include "mlir/Dialect/LLVMIR/LLVMAMXDialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicsX86.h" + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the LLVMAMX dialect to LLVM IR. +class LLVMAMXDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Translates the given operation to LLVM IR using the provided IR builder + /// and saving the state in `moduleTranslation`. + LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final { + Operation &opInst = *op; +#include "mlir/Dialect/LLVMIR/LLVMAMXConversions.inc" + + return failure(); + } +}; +} // end namespace + +void mlir::registerLLVMAMXDialectTranslation(DialectRegistry &registry) { + registry.insert<LLVM::LLVMAMXDialect>(); + registry.addDialectInterface<LLVM::LLVMAMXDialect, LLVMAMXDialectLLVMIRTranslationInterface>(); +} + +void mlir::registerLLVMAMXDialectTranslation(MLIRContext &context) { + DialectRegistry registry; + registerLLVMAMXDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -29,6 +29,7 @@ if (MLIR_INCLUDE_INTEGRATION_TESTS) set(INTEL_SDE_EXECUTABLE "" CACHE STRING "If set, arch-specific integration tests are run with Intel SDE.") + option(MLIR_RUN_AMX_TESTS "Run AMX tests.") option(MLIR_RUN_AVX512_TESTS "Run AVX512 tests.") # Passed to lit.site.cfg.py.in to set up the path where to find the libraries. 
set(MLIR_INTEGRATION_TEST_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) diff --git a/mlir/test/Conversion/AMXToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AMXToLLVM/convert-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Conversion/AMXToLLVM/convert-to-llvm.mlir @@ -0,0 +1,45 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-amx" | mlir-opt | FileCheck %s + +// CHECK-LABEL: muli( +// CHECK: llvm_amx.tilezero +// CHECK: llvm_amx.tileloadd64 +// CHECK: llvm_amx.tileloadd64 +// CHECK: llvm_amx.tdpbuud +// CHECK: llvm_amx.tilestored64 +// CHECK: llvm_amx.tdpbssd +// CHECK: llvm_amx.tilestored64 +// CHECK: llvm_amx.tdpbusd +// CHECK: llvm_amx.tilestored64 +// CHECK: llvm_amx.tdpbsud +// CHECK: llvm_amx.tilestored64 +func @muli(%arg0: memref, %arg1: memref) { + %0 = constant 0 : index + %1 = amx.tilezero : vector<16x64xi8> + %2 = amx.tileload %arg0[%0, %0] : memref into vector<16x64xi8> + %3 = amx.tileload %arg1[%0, %0] : memref into vector<16x16xi32> + %4 = amx.tilemuli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + amx.tilestore %arg1[%0, %0], %4 : memref, vector<16x16xi32> + %5 = amx.tilemuli %1, %2, %3 [false, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + amx.tilestore %arg1[%0, %0], %5 : memref, vector<16x16xi32> + %6 = amx.tilemuli %1, %2, %3 [true, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + amx.tilestore %arg1[%0, %0], %6 : memref, vector<16x16xi32> + %7 = amx.tilemuli %1, %2, %3 [false, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + amx.tilestore %arg1[%0, %0], %7 : memref, vector<16x16xi32> + return +} + +// CHECK-LABEL: mulf( +// CHECK: llvm_amx.tilezero +// CHECK: llvm_amx.tileloadd64 +// CHECK: llvm_amx.tileloadd64 +// CHECK: llvm_amx.tdpbf16ps +// CHECK: llvm_amx.tilestored64 +func @mulf(%arg0: memref, %arg1: memref) { + %0 = constant 0 : index + %1 = amx.tilezero : vector<16x32xbf16> + %2 = amx.tileload %arg0[%0, %0] : memref into 
vector<16x32xbf16> + %3 = amx.tileload %arg1[%0, %0] : memref into vector<16x16xf32> + %4 = amx.tilemulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> + amx.tilestore %arg1[%0, %0], %4 : memref, vector<16x16xf32> + return +} diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/AMX/invalid.mlir @@ -0,0 +1,36 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// ----- + +func @rowsize() { + // expected-error@+1 {{'amx.tilezero' op unsupported tile size}} + %0 = amx.tilezero : vector<17x16xbf16> +} + +// ----- + +func @colsize(%arg0: memref) { + %0 = constant 0 : index + // expected-error@+1 {{'amx.tileload' op unsupported tile size}} + %1 = amx.tileload %arg0[%0, %0] : memref into vector<16x65xi8> +} + +// ----- + +func @multsize() { + %0 = amx.tilezero : vector<8x8xbf16> + %1 = amx.tilezero : vector<8x8xbf16> + %2 = amx.tilezero : vector<4x4xf32> + // expected-error@+1 {{'amx.tilemulf' op unexpected shape}} + %3 = amx.tilemulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32> +} + +// ----- + +func @zextsize() { + %0 = amx.tilezero : vector<8x8xi8> + %1 = amx.tilezero : vector<8x8xi8> + %2 = amx.tilezero : vector<8x8xi32> + // expected-error@+1 {{'amx.tilemuli' op unexpected zext length}} + %3 = amx.tilemuli %0, %1, %2 [true] : vector<8x8xi8>, vector<8x8xi8>, vector<8x8xi32> +} diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/AMX/roundtrip.mlir @@ -0,0 +1,41 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s + +// CHECK-LABEL: tzero +// CHECK: amx.tilezero : vector<16x16xbf16> +// CHECK: amx.tilestore %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref, vector<16x16xbf16> +func @tzero(%arg0: memref) { + %0 = constant 0 : index + %1 = amx.tilezero : vector<16x16xbf16> + amx.tilestore %arg0[%0, %0], %1 : memref, 
vector<16x16xbf16> + return +} + +// CHECK-LABEL: tmulf +// CHECK: %[[x:.*]] = amx.tileload %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x32xbf16> +// CHECK: %[[z:.*]] = amx.tileload %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x16xf32> +// CHECK: %[[m:.*]] = amx.tilemulf %[[x]], %[[x]], %[[z]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK: amx.tilestore %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref, vector<16x16xf32> +func @tmulf(%arg0: memref, %arg1: memref) { + %0 = constant 0 : index + %1 = amx.tileload %arg0[%0, %0] : memref into vector<16x32xbf16> + %2 = amx.tileload %arg1[%0, %0] : memref into vector<16x16xf32> + %3 = amx.tilemulf %1, %1, %2 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> + amx.tilestore %arg1[%0, %0], %3 : memref, vector<16x16xf32> + return +} + +// CHECK-LABEL: tmuli +// CHECK: %[[x:.*]] = amx.tileload %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x64xi8> +// CHECK: %[[y:.*]] = amx.tileload %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x64xi8> +// CHECK: %[[z:.*]] = amx.tileload %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x16xi32> +// CHECK: %[[m:.*]] = amx.tilemuli %[[x]], %[[y]], %[[z]] [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> +// CHECK: amx.tilestore %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref, vector<16x16xi32> +func @tmuli(%arg0: memref, %arg1: memref, %arg2: memref) { + %0 = constant 0 : index + %1 = amx.tileload %arg0[%0, %0] : memref into vector<16x64xi8> + %2 = amx.tileload %arg1[%0, %0] : memref into vector<16x64xi8> + %3 = amx.tileload %arg2[%0, %0] : memref into vector<16x16xi32> + %4 = amx.tilemuli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + amx.tilestore %arg2[%0, %0], %4 : memref, vector<16x16xi32> + return +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg new file mode 100644 --- /dev/null +++ 
b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg @@ -0,0 +1,15 @@ +import sys + +# AMX tests must be enabled via build flag. +if config.mlir_run_amx_tests != 'ON': + config.unsupported = True + +# No JIT on win32. +if sys.platform == 'win32': + config.unsupported = True + +if config.intel_sde_executable: + # Run test in emulator (Intel SDE): AMX needs Sapphire Rapids CPU. + config.substitutions.append(('%lli', config.intel_sde_executable + ' -spr -- lli')) +else: + config.substitutions.append(('%lli', 'lli')) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir @@ -0,0 +1,83 @@ +// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-std-to-llvm | \ +// RUN: mlir-translate -mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// Note: To run this test, your CPU must support AMX. + +// Multiply into zeroed destination. +func @kernel1(%arg0: memref<2x4xbf16>, + %arg1: memref<2x4xbf16>, + %arg2: memref<2x2xf32>) { + %0 = constant 0 : index + %1 = amx.tileload %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> + %2 = amx.tileload %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> + %3 = amx.tilezero : vector<2x2xf32> + %4 = amx.tilemulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32> + amx.tilestore %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32> + return +} + +// Multiply and update into destination. 
+func @kernel2(%arg0: memref<2x4xbf16>, + %arg1: memref<2x4xbf16>, + %arg2: memref<2x2xf32>) { + %0 = constant 0 : index + %1 = amx.tileload %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> + %2 = amx.tileload %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> + %3 = amx.tileload %arg2[%0, %0] : memref<2x2xf32> into vector<2x2xf32> + %4 = amx.tilemulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32> + amx.tilestore %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32> + return +} + +func @entry() { + %f0 = constant 0.0: f32 + %c0 = constant 0: index + %c1 = constant 1: index + %c2 = constant 2: index + + // Set up memory. + %a = alloc() : memref<2x4xbf16> + %b = alloc() : memref<2x4xbf16> + %c = alloc() : memref<2x2xf32> + + %0 = std.constant dense<[[1.0, 2.0, 3.0, 4.0 ], + [5.0, 6.0, 7.0, 8.0 ]]> : vector<2x4xbf16> + vector.transfer_write %0, %a[%c0, %c0] : vector<2x4xbf16>, memref<2x4xbf16> + %1 = std.constant dense<[[ 9.0, 10.0, 11.0, 12.0 ], + [13.0, 14.0, 15.0, 16.0 ]]> : vector<2x4xbf16> + vector.transfer_write %1, %b[%c0, %c0] : vector<2x4xbf16>, memref<2x4xbf16> + + // Call kernel. + call @kernel1(%a, %b, %c) : (memref<2x4xbf16>, memref<2x4xbf16>, memref<2x2xf32>) -> () + + // Print and verify. + // + // CHECK: ( 124, 144 ) + // CHECK-NEXT: ( 308, 360 ) + scf.for %i = %c0 to %c2 step %c1 { + %av = vector.transfer_read %c[%i, %c0], %f0: memref<2x2xf32>, vector<2xf32> + vector.print %av : vector<2xf32> + } + + // Call kernel. + call @kernel2(%a, %b, %c) : (memref<2x4xbf16>, memref<2x4xbf16>, memref<2x2xf32>) -> () + + // Print and verify. + // + // CHECK-NEXT: ( 248, 288 ) + // CHECK-NEXT: ( 616, 720 ) + // + scf.for %i = %c0 to %c2 step %c1 { + %cv = vector.transfer_read %c[%i, %c0], %f0: memref<2x2xf32>, vector<2xf32> + vector.print %cv : vector<2xf32> + } + + // Release resources. 
+ dealloc %a : memref<2x4xbf16> + dealloc %b : memref<2x4xbf16> + dealloc %c : memref<2x2xf32> + + return +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir @@ -0,0 +1,83 @@ +// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-std-to-llvm | \ +// RUN: mlir-translate -mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// Note: To run this test, your CPU must support AMX. + +// Multiply into zeroed destination. +func @kernel1(%arg0: memref<2x8xi8>, + %arg1: memref<2x8xi8>, + %arg2: memref<2x2xi32>) { + %0 = constant 0 : index + %1 = amx.tileload %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8> + %2 = amx.tileload %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8> + %3 = amx.tilezero : vector<2x2xi32> + %4 = amx.tilemuli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32> + amx.tilestore %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32> + return +} + +// Multiply and update into destination. +func @kernel2(%arg0: memref<2x8xi8>, + %arg1: memref<2x8xi8>, + %arg2: memref<2x2xi32>) { + %0 = constant 0 : index + %1 = amx.tileload %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8> + %2 = amx.tileload %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8> + %3 = amx.tileload %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32> + %4 = amx.tilemuli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32> + amx.tilestore %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32> + return +} + +func @entry() { + %i0 = constant 0: i32 + %c0 = constant 0: index + %c1 = constant 1: index + %c2 = constant 2: index + + // Set up memory. 
+ %a = alloc() : memref<2x8xi8> + %b = alloc() : memref<2x8xi8> + %c = alloc() : memref<2x2xi32> + + %0 = std.constant dense<[[1 , 2, 3 , 4 , 5, 6, 7, 8], + [9, 10, 11, 12, 13, 14, 15, 16]]> : vector<2x8xi8> + vector.transfer_write %0, %a[%c0, %c0] : vector<2x8xi8>, memref<2x8xi8> + %1 = std.constant dense<[[17, 18, 19, 20, 21, 22, 23, 24], + [25, 26, 27, 28, 29, 30, 31, 32]]> : vector<2x8xi8> + vector.transfer_write %1, %b[%c0, %c0] : vector<2x8xi8>, memref<2x8xi8> + + // Call kernel. + call @kernel1(%a, %b, %c) : (memref<2x8xi8>, memref<2x8xi8>, memref<2x2xi32>) -> () + + // Print and verify. + // + // CHECK: ( 884, 1028 ) + // CHECK-NEXT: ( 2324, 2724 ) + scf.for %i = %c0 to %c2 step %c1 { + %av = vector.transfer_read %c[%i, %c0], %i0: memref<2x2xi32>, vector<2xi32> + vector.print %av : vector<2xi32> + } + + // Call kernel. + call @kernel2(%a, %b, %c) : (memref<2x8xi8>, memref<2x8xi8>, memref<2x2xi32>) -> () + + // Print and verify. + // + // CHECK-NEXT: ( 1768, 2056 ) + // CHECK-NEXT: ( 4648, 5448 ) + // + scf.for %i = %c0 to %c2 step %c1 { + %cv = vector.transfer_read %c[%i, %c0], %i0: memref<2x2xi32>, vector<2xi32> + vector.print %cv : vector<2xi32> + } + + // Release resources. + dealloc %a : memref<2x8xi8> + dealloc %b : memref<2x8xi8> + dealloc %c : memref<2x2xi32> + + return +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir @@ -0,0 +1,96 @@ +// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-std-to-llvm | \ +// RUN: mlir-translate -mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// Note: To run this test, your CPU must support AMX. 
+ +func @tilezero(%arg0: memref, %i: index, %j: index) { + %1 = amx.tilezero : vector<16x16xi32> + amx.tilestore %arg0[%i, %j], %1 : memref, vector<16x16xi32> + return +} + +func @entry() { + %i0 = constant 0: i32 + %i1 = constant 1: i32 + %c0 = constant 0: index + %c1 = constant 1: index + %c3 = constant 3: index + %c19 = constant 19: index + + // Set up memory. + %a = alloc(%c19, %c19) : memref + scf.for %i = %c0 to %c19 step %c1 { + scf.for %j = %c0 to %c19 step %c1 { + store %i1, %a[%i, %j] : memref + } + } + + // Call kernel. + call @tilezero(%a, %c1, %c1) : (memref, index, index) -> () + + // Print and verify that the tilezero is correctly strided within + // the enveloping 19x19 buffer. + // + // CHECK: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) + // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) + // + scf.for %i = %c0 to %c19 step %c1 { + %av = vector.transfer_read %a[%i, %c0], %i0: memref, vector<19xi32> + vector.print %av : vector<19xi32> + } + + // Call kernel with different indices. + call @tilezero(%a, %c0, %c3) : (memref, index, index) -> () + + // Print and verify that the tilezero is again correctly strided + // within the enveloping 19x19 buffer. + // + // CHECK-NEXT: ( 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) + // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1 ) + // + scf.for %i = %c0 to %c19 step %c1 { + %av = vector.transfer_read %a[%i, %c0], %i0: memref, vector<19xi32> + vector.print %av : vector<19xi32> + } + + // Release resources. + dealloc %a : memref + + return +} diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/LLVMIR/amx.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: define void @target(i8* %0) +// CHECK: %[[c:.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 16) +// CHECK: call void @llvm.x86.tilestored64.internal(i16 16, i16 16, i8* %0, i64 32, x86_amx %[[c]] +llvm.func @target(%ptr: !llvm.ptr) { + %c = llvm.mlir.constant(16 : i16) : i16 + %s = llvm.mlir.constant(32 : i64) : i64 + %0 = "llvm_amx.tilezero"(%c, %c) : (i16, i16) -> !llvm.array<16 x vector<16xbf16>> + "llvm_amx.tilestored64"(%c, %c, %ptr, %s, %0) : (i16, i16, !llvm.ptr, i64, !llvm.array<16 x vector<16xbf16>>) -> () + llvm.return +} + diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in --- a/mlir/test/lit.site.cfg.py.in +++ b/mlir/test/lit.site.cfg.py.in @@ -48,6 +48,7 @@ config.enable_bindings_python = @MLIR_BINDINGS_PYTHON_ENABLED@ config.mlir_integration_test_dir = "@MLIR_INTEGRATION_TEST_DIR@" config.intel_sde_executable = "@INTEL_SDE_EXECUTABLE@" +config.mlir_run_amx_tests = "@MLIR_RUN_AMX_TESTS@" config.mlir_run_avx512_tests = "@MLIR_RUN_AVX512_TESTS@" config.mlir_include_integration_tests = "@MLIR_INCLUDE_INTEGRATION_TESTS@" diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -2,6 +2,7 @@ // CHECK: Available Dialects: // CHECK-NEXT: acc // CHECK-NEXT: affine +// CHECK-NEXT: amx // CHECK-NEXT: arm_neon // CHECK-NEXT: arm_sve // CHECK-NEXT: async @@ -11,6 +12,7 @@ // CHECK-NEXT: gpu // CHECK-NEXT: linalg // CHECK-NEXT: llvm +// CHECK-NEXT: llvm_amx 
// CHECK-NEXT: llvm_arm_sve // CHECK-NEXT: math // CHECK-NEXT: nvvm