diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -494,7 +494,7 @@ operations. The lowering pass provides several options to control the kinds of optimizations that are allowed. It also provides options that enable the use of one or more architectural-specific dialects - (AVX512, ArmNeon, ArmSVE, etc.) in combination with the + (AMX, AVX512, ArmNeon, ArmSVE, etc.) in combination with the architectural-neutral vector dialect lowering. }]; @@ -509,6 +509,10 @@ "bool", /*default=*/"true", "Allows compiler to assume indices fit in 32-bit if that yields " "faster code">, + Option<"enableAMX", "enable-amx", + "bool", /*default=*/"false", + "Enables the use of AMX dialect while lowering the vector " + "dialect.">, Option<"enableAVX512", "enable-avx512", "bool", /*default=*/"false", "Enables the use of AVX512 dialect while lowering the vector " diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -23,7 +23,8 @@ struct LowerVectorToLLVMOptions { LowerVectorToLLVMOptions() : reassociateFPReductions(false), enableIndexOptimizations(true), - enableArmNeon(false), enableArmSVE(false), enableAVX512(false) {} + enableArmNeon(false), enableArmSVE(false), enableAMX(false), + enableAVX512(false) {} LowerVectorToLLVMOptions &setReassociateFPReductions(bool b) { reassociateFPReductions = b; @@ -41,6 +42,10 @@ enableArmSVE = b; return *this; } + LowerVectorToLLVMOptions &setEnableAMX(bool b) { + enableAMX = b; + return *this; + } LowerVectorToLLVMOptions &setEnableAVX512(bool b) { enableAVX512 = b; return *this; @@ -50,6 +55,7 @@ bool enableIndexOptimizations; bool enableArmNeon; bool enableArmSVE; + bool enableAMX; bool 
enableAVX512; }; diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/AMX/AMX.td @@ -0,0 +1,292 @@ +//===-- AMX.td - AMX dialect operation definitions *- tablegen -*----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the AMX dialect. +// +// The Intel Advanced Matrix Extensions (AMX) provides a tile matrix +// multiply unit (TMUL), a tile control register (TILECFG), and eight +// tile registers TMM0 through TMM7 (TILEDATA). +// +// The AMX dialect provides a bridge between MLIR concepts, such as +// 2-d vector, operations, and memrefs, and the lower level details +// of Intel AMX, such as configuration setup, tile sizes, instructions, +// and tile release. +// +// Note that since configuration changes (implicit at dialect level) are +// costly, it is highly recommended to use the AMX dialect on same-shaped +// vectors, at least within a single method. +// +// https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html +// +//===----------------------------------------------------------------------===// + +#ifndef AMX +#define AMX + +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +//===----------------------------------------------------------------------===// +// AMX dialect definition. +//===----------------------------------------------------------------------===// + +def AMX_Dialect : Dialect { + let name = "amx"; + let cppNamespace = "::mlir::amx"; +} + +//===----------------------------------------------------------------------===// +// AMX Op and IntrOp definitions. 
+//===----------------------------------------------------------------------===// + +class AMX_Op traits = []> : + Op {} + +// The "internal" intrinsics are meant for compiler usage. +class AMX_IntrOp traits = []> : + LLVM_IntrOpBase; + +//===----------------------------------------------------------------------===// +// AMX Op definitions (user facing). +//===----------------------------------------------------------------------===// + +// +// Tile reset. +// + +def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> { + let summary = "tile zero operation"; + let description = [{ + Zeroes the destination tile, with the shape defined by the 2-dim + vector type of the result. This is eventually lowered into the + "tilezero" instruction with the corresponding tile configuration. + + Example: + + ```mlir + %0 = amx.tile_zero : vector<16x16xbf16> + ``` + }]; + let verifier = [{ return ::verify(*this); }]; + let results = (outs + VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res); + let extraClassDeclaration = [{ + VectorType getVectorType() { + return res().getType().cast(); + } + }]; + let assemblyFormat = "attr-dict `:` type($res)"; +} + +// +// Tile memory operations. +// + +def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> { + let summary = "tile load operation"; + let description = [{ + Loads a tile from memory defined by a base and indices, with the + shape defined by the 2-dim vector type of the result. This is + eventually lowered into the "tileloadd" instruction with the + corresponding tile configuration. 
+ + Example: + + ```mlir + %0 = amx.tile_load %arg0[%c0, %c0] : memref into vector<16x64xi8> + ``` + }]; + let verifier = [{ return ::verify(*this); }]; + let arguments = (ins Arg:$base, + Variadic:$indices); + let results = (outs + VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res); + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + VectorType getVectorType() { + return res().getType().cast(); + } + }]; + let assemblyFormat = "$base `[` $indices `]` attr-dict `:` " + "type($base) `into` type($res)"; +} + +def TileStoreOp : AMX_Op<"tile_store"> { + let summary = "tile store operation"; + let description = [{ + Stores a tile to memory defined by a base and indices, with the + shape defined by the 2-dim vector type of the value. This is + eventually lowered into the "tilestored" instruction with the + corresponding tile configuration. + + Example: + + ```mlir + amx.tile_store %arg1[%c0, %c0], %0 : memref, vector<16x64xi8> + ``` + }]; + let verifier = [{ return ::verify(*this); }]; + let arguments = (ins Arg:$base, + Variadic:$indices, + VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val); + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return base().getType().cast(); + } + VectorType getVectorType() { + return val().getType().cast(); + } + }]; + let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` " + "type($base) `,` type($val)"; +} + +// +// Tile arithmetic operations. +// + +def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"]>]> { + let summary = "tile multiplication operation (floating-point)"; + let description = [{ + Multiplies a "m x k" tile with a "k x n" tile and accumulates the results + into a "m x n" destination tile. Supports "f32 <- bf16 x bf16" (with + pairs of "bf16"). The operation is eventually lowered into the + "tdpbf16ps" instruction with the corresponding tile configuration. 
+ + Example: + + ```mlir + %0 = amx.tile_mulf %a, %b, %c + : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> + ``` + }]; + let verifier = [{ return ::verify(*this); }]; + let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs, + VectorOfRankAndType<[2], [F32, BF16]>:$rhs, + VectorOfRankAndType<[2], [F32, BF16]>:$acc); + let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res); + let extraClassDeclaration = [{ + VectorType getLhsVectorType() { + return lhs().getType().cast(); + } + VectorType getRhsVectorType() { + return rhs().getType().cast(); + } + VectorType getVectorType() { + return res().getType().cast(); + } + }]; + let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` " + "type($lhs) `,` type($rhs) `,` type($acc) "; +} + +def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]>]> { + let summary = "tile multiplication operation (integer)"; + let description = [{ + Multiplies a "m x k" tile with a "k x n" tile and accumulates the results + into a "m x n" destination tile. Supports all "si32 <- s/ui8 x s/ui8" + combinations (4 bytes packed into dwords in the columns of both the + source operand tiles; the zero or sign extension is specified with + the attributes). The operation is eventually lowered into one of + the "tdpbssd", "tdpbsud", "tdpbusd", or "tdpbuud" instructions with + the corresponding tile configuration. 
+ + Example: + + ```mlir + %0 = amx.tile_muli %a, %b, %c [true, true] + : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + ``` + }]; + let verifier = [{ return ::verify(*this); }]; + let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs, + VectorOfRankAndType<[2], [I32, I8]>:$rhs, + VectorOfRankAndType<[2], [I32, I8]>:$acc, + BoolArrayAttr:$zext); + let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res); + let extraClassDeclaration = [{ + VectorType getLhsVectorType() { + return lhs().getType().cast(); + } + VectorType getRhsVectorType() { + return rhs().getType().cast(); + } + VectorType getVectorType() { + return res().getType().cast(); + } + }]; + let assemblyFormat = "$lhs `,` $rhs `,` $acc $zext attr-dict `:` " + "type($lhs) `,` type($rhs) `,` type($acc) "; +} + +//===----------------------------------------------------------------------===// +// AMX IntrOp definitions (LLVM compiler facing). +//===----------------------------------------------------------------------===// + +// +// Tile reset. Parameters define the tile size. +// + +def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>, + Arguments<(ins LLVM_AnyInteger, LLVM_AnyInteger)>; + +// +// Tile memory operations. Parameters define the tile size, +// base address, and stride between consecutive rows for the +// memory operation. +// + +def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>, + Arguments<(ins LLVM_AnyInteger, + LLVM_AnyInteger, LLVM_AnyPointer, LLVM_AnyInteger)>; + +def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>, + Arguments<(ins LLVM_AnyInteger, + LLVM_AnyInteger, LLVM_AnyPointer, LLVM_AnyInteger, LLVM_Type)>; + +// +// Tile multiplication operations (series of dot products). Parameters +// define the tile sizes and source and destination tiles for the +// operation. Note that the prefix "tdp" stands for tile dot product. +// + +// Dot product of bf16 tiles into f32 tile. 
+def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>, + Arguments<(ins LLVM_AnyInteger, + LLVM_AnyInteger, + LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; + +// Dot product of i8 tiles into i32 tile (with sign/sign extension). +def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>, + Arguments<(ins LLVM_AnyInteger, + LLVM_AnyInteger, + LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; + +// Dot product of i8 tiles into i32 tile (with sign/zero extension). +def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>, + Arguments<(ins LLVM_AnyInteger, + LLVM_AnyInteger, + LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; + +// Dot product of i8 tiles into i32 tile (with zero/sign extension). +def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>, + Arguments<(ins LLVM_AnyInteger, + LLVM_AnyInteger, + LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; + +// Dot product of i8 tiles into i32 tile (with zero/zero extension). +def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>, + Arguments<(ins LLVM_AnyInteger, + LLVM_AnyInteger, + LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>; + +#endif // AMX diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h @@ -0,0 +1,26 @@ +//===- AMXDialect.h - MLIR Dialect for AMX ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for AMX in MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AMX_AMXDIALECT_H_ +#define MLIR_DIALECT_AMX_AMXDIALECT_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/AMX/AMXDialect.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/AMX/AMX.h.inc" + +#endif // MLIR_DIALECT_AMX_AMXDIALECT_H_ diff --git a/mlir/include/mlir/Dialect/AMX/CMakeLists.txt b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/AMX/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect(AMX amx) +add_mlir_doc(AMX -gen-dialect-doc AMX Dialects/) + +set(LLVM_TARGET_DEFINITIONS AMX.td) +mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRAMXConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/AMX/Transforms.h @@ -0,0 +1,29 @@ +//===- Transforms.h - AMX Dialect Transformation Entrypoints ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AMX_TRANSFORMS_H +#define MLIR_DIALECT_AMX_TRANSFORMS_H + +namespace mlir { + +class LLVMConversionTarget; +class LLVMTypeConverter; +class OwningRewritePatternList; + +/// Collect a set of patterns to lower AMX ops to ops that map to LLVM +/// intrinsics. +void populateAMXLegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns); + +/// Configure the target to support lowering AMX ops to ops that map to LLVM +/// intrinsics. 
+void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target); + +} // namespace mlir + +#endif // MLIR_DIALECT_AMX_TRANSFORMS_H diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(Async) add_subdirectory(ArmNeon) add_subdirectory(ArmSVE) +add_subdirectory(AMX) add_subdirectory(AVX512) add_subdirectory(Complex) add_subdirectory(DLTI) diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -14,6 +14,7 @@ #ifndef MLIR_INITALLDIALECTS_H_ #define MLIR_INITALLDIALECTS_H_ +#include "mlir/Dialect/AMX/AMXDialect.h" #include "mlir/Dialect/AVX512/AVX512Dialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" @@ -50,6 +51,7 @@ // clang-format off registry.insertenableIndexOptimizations = options.enableIndexOptimizations; this->enableArmNeon = options.enableArmNeon; this->enableArmSVE = options.enableArmSVE; + this->enableAMX = options.enableAMX; this->enableAVX512 = options.enableAVX512; } // Override explicitly to allow conditional dialect dependence. 
@@ -43,6 +46,8 @@ registry.insert(); if (enableArmSVE) registry.insert(); + if (enableAMX) + registry.insert(); if (enableAVX512) registry.insert(); } @@ -102,6 +107,10 @@ }); populateArmSVEToLLVMConversionPatterns(converter, patterns); } + if (enableAMX) { + configureAMXLegalizeForExportTarget(target); + populateAMXLegalizeForLLVMExportPatterns(converter, patterns); + } if (enableAVX512) { configureAVX512LegalizeForExportTarget(target); populateAVX512LegalizeForLLVMExportPatterns(converter, patterns); diff --git a/mlir/lib/Dialect/AMX/CMakeLists.txt b/mlir/lib/Dialect/AMX/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/AMX/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -0,0 +1,106 @@ +//===- AMXDialect.cpp - MLIR AMX ops implementation -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the AMX dialect and its operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" + +using namespace mlir; + +void amx::AMXDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/AMX/AMX.cpp.inc" + >(); +} + +/// Verify that AMX supports the implied tile shape. 
+static LogicalResult verifyTileSize(Operation *op, VectorType tp) { + const unsigned kMaxRows = 16; + const unsigned kBitsPerRow = 64 * 8; + unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth(); + if (tp.getDimSize(0) > kMaxRows) + return op->emitOpError("bad row height: ") << tp.getDimSize(0); + if (col > kBitsPerRow || col & 0x1f) + return op->emitOpError("bad column width: ") << (col >> 3); + return success(); +} + +/// Verify that AMX supports the multiplication. +static LogicalResult verifyMultShape(Operation *op, VectorType atp, + VectorType btp, VectorType ctp, + unsigned scale) { + unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale; + unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale; + unsigned cm = ctp.getDimSize(0), cn = ctp.getDimSize(1); + if (cm != am || cn != bn || ak != bk) + return op->emitOpError("bad mult shape: ") + << cm << " x " << cn << " x " << ak; + return success(); +} + +static LogicalResult verify(amx::TileZeroOp op) { + return verifyTileSize(op, op.getVectorType()); +} + +static LogicalResult verify(amx::TileLoadOp op) { + unsigned rank = op.getMemRefType().getRank(); + if (llvm::size(op.indices()) != rank) + return op.emitOpError("requires ") << rank << " indices"; + return verifyTileSize(op, op.getVectorType()); +} + +static LogicalResult verify(amx::TileStoreOp op) { + unsigned rank = op.getMemRefType().getRank(); + if (llvm::size(op.indices()) != rank) + return op.emitOpError("requires ") << rank << " indices"; + return verifyTileSize(op, op.getVectorType()); +} + +static LogicalResult verify(amx::TileMulFOp op) { + VectorType aType = op.getLhsVectorType(); + VectorType bType = op.getRhsVectorType(); + VectorType cType = op.getVectorType(); + if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) || + failed(verifyTileSize(op, cType)) || + failed(verifyMultShape(op, aType, bType, cType, 1))) + return failure(); + Type ta = aType.getElementType(); + Type tb = 
bType.getElementType(); + Type tc = cType.getElementType(); + if (!ta.isBF16() || !tb.isBF16() || !tc.isF32()) + return op.emitOpError("unsupported type combination"); + return success(); +} + +static LogicalResult verify(amx::TileMulIOp op) { + if (op.zext().size() != 2) + return op.emitOpError("unexpected zext length"); + VectorType aType = op.getLhsVectorType(); + VectorType bType = op.getRhsVectorType(); + VectorType cType = op.getVectorType(); + if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) || + failed(verifyTileSize(op, cType)) || + failed(verifyMultShape(op, aType, bType, cType, 2))) + return failure(); + Type ta = aType.getElementType(); + Type tb = bType.getElementType(); + Type tc = cType.getElementType(); + if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32)) + return op.emitOpError("unsupported type combination"); + return success(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/AMX/AMX.cpp.inc" diff --git a/mlir/lib/Dialect/AMX/IR/CMakeLists.txt b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/AMX/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MLIRAMX + AMXDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMX + + DEPENDS + MLIRAMXIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMIR + MLIRSideEffectInterfaces + ) diff --git a/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/AMX/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_dialect_library(MLIRAMXTransforms + LegalizeForLLVMExport.cpp + + DEPENDS + MLIRAMXConversionsIncGen + + LINK_LIBS PUBLIC + MLIRAMX + MLIRIR + MLIRLLVMIR + MLIRStandardToLLVM + ) diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp new file mode 100644 --- /dev/null +++ 
b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp @@ -0,0 +1,232 @@ +//===- LegalizeForLLVMExport.cpp - Prepare AMX for LLVM translation ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AMX/Transforms.h" + +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" + +using namespace mlir; +using namespace mlir::amx; + +namespace { + +/// Maps the 2-dim vector shape to the two 16-bit tile sizes. The first +/// dimension directly translates into the number of rows of the tiles. +/// The second dimension needs to be scaled by the number of bytes per element. +std::pair getTileSizes(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &typeConverter, + VectorType vType, Location loc) { + Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16); + unsigned width = vType.getElementType().getIntOrFloatBitWidth(); + assert(llvm::isPowerOf2_64(width) && width >= 8); + unsigned bytes = width >> 3; + auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0)); + auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes); + return std::make_pair( + rewriter.create(loc, llvmInt16Type, mattr), + rewriter.create(loc, llvmInt16Type, nattr)); +} + +/// Verifies if the stride matches proper tile access. 
+LogicalResult verifyStride(MemRefType mType) { + if (mType.getRank() < 2) + return failure(); + int64_t last = mType.getRank() - 1; + int64_t offset; + SmallVector strides; + if (failed(getStridesAndOffset(mType, strides, offset)) || strides[last] != 1) + return failure(); + return success(); +} + +/// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer +/// shape may "envelop" the actual tile shape, and may be dynamically sized. +Value getStride(ConversionPatternRewriter &rewriter, + LLVMTypeConverter &typeConverter, MemRefType mType, Value base, + Location loc) { + assert(mType.getRank() >= 2); + int64_t last = mType.getRank() - 1; + Type llvmInt64Type = IntegerType::get(&typeConverter.getContext(), 64); + unsigned width = mType.getElementType().getIntOrFloatBitWidth(); + assert(llvm::isPowerOf2_64(width) && width >= 8); + unsigned bytes = width >> 3; + if (mType.isDynamicDim(last)) { + // Dynamic size needs code to compute the stride at runtime. + MemRefDescriptor memrefDescriptor(base); + auto attr = rewriter.getI64IntegerAttr(bytes); + Value scale = rewriter.create(loc, llvmInt64Type, attr); + return rewriter.create( + loc, llvmInt64Type, scale, memrefDescriptor.size(rewriter, loc, last)); + } + // Use direct constant for static size. + auto attr = rewriter.getI64IntegerAttr(mType.getDimSize(last) * bytes); + return rewriter.create(loc, llvmInt64Type, attr); +} + +/// Cast any pointer to the !llvm.ptr pointer type. +Value castPtr(ConversionPatternRewriter &rewriter, Location loc, Value ptr) { + auto i8Ptr = + LLVM::LLVMPointerType::get(IntegerType::get(ptr.getContext(), 8)); + return rewriter.create(loc, i8Ptr, ptr); +} + +struct TileZeroConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(TileZeroOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + VectorType vType = op.getVectorType(); + // Determine m x n tile sizes. 
+ std::pair tsz = + getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc()); + // Replace operation with intrinsic. + Type resType = typeConverter->convertType(vType); + rewriter.replaceOpWithNewOp(op, resType, tsz.first, + tsz.second); + return success(); + } +}; + +struct TileLoadConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(TileLoadOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + TileLoadOp::Adaptor adaptor(operands); + MemRefType mType = op.getMemRefType(); + VectorType vType = op.getVectorType(); + // Determine m x n tile sizes. + std::pair tsz = + getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc()); + // Determine stride. + if (failed(verifyStride(mType))) + return failure(); + Value stride = getStride(rewriter, *getTypeConverter(), mType, + adaptor.base(), op.getLoc()); + // Replace operation with intrinsic. + Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(), + adaptor.indices(), rewriter); + ptr = castPtr(rewriter, op.getLoc(), ptr); + Type resType = typeConverter->convertType(vType); + rewriter.replaceOpWithNewOp( + op, resType, tsz.first, tsz.second, ptr, stride); + return success(); + } +}; + +struct TileStoreConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(TileStoreOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + TileStoreOp::Adaptor adaptor(operands); + MemRefType mType = op.getMemRefType(); + VectorType vType = op.getVectorType(); + // Determine m x n tile sizes. + std::pair tsz = + getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc()); + // Determine stride. + if (failed(verifyStride(mType))) + return failure(); + Value stride = getStride(rewriter, *getTypeConverter(), mType, + adaptor.base(), op.getLoc()); + // Replace operation with intrinsic. 
+ Value ptr = getStridedElementPtr(op.getLoc(), mType, adaptor.base(), + adaptor.indices(), rewriter); + ptr = castPtr(rewriter, op.getLoc(), ptr); + rewriter.replaceOpWithNewOp( + op, tsz.first, tsz.second, ptr, stride, adaptor.val()); + return success(); + } +}; + +struct TileMulFConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(TileMulFOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + TileMulFOp::Adaptor adaptor(operands); + VectorType aType = op.getLhsVectorType(); + VectorType bType = op.getRhsVectorType(); + VectorType cType = op.getVectorType(); + // Determine m x n x k tile sizes. + std::pair tsza = + getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); + std::pair tszb = + getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); + // Replace operation with intrinsic. + Type resType = typeConverter->convertType(cType); + rewriter.replaceOpWithNewOp( + op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(), + adaptor.lhs(), adaptor.rhs()); + return success(); + } +}; + +struct TileMulIConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(TileMulIOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + TileMulIOp::Adaptor adaptor(operands); + VectorType aType = op.getLhsVectorType(); + VectorType bType = op.getRhsVectorType(); + VectorType cType = op.getVectorType(); + // Determine m x n x k tile sizes. + std::pair tsza = + getTileSizes(rewriter, *getTypeConverter(), aType, op.getLoc()); + std::pair tszb = + getTileSizes(rewriter, *getTypeConverter(), bType, op.getLoc()); + // Replace operation with intrinsic. 
+ Type resType = typeConverter->convertType(cType); + bool zexta = op.zext()[0].cast().getValue(); + bool zextb = op.zext()[1].cast().getValue(); + if (zexta && zextb) + rewriter.replaceOpWithNewOp( + op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(), + adaptor.lhs(), adaptor.rhs()); + else if (zexta && !zextb) + rewriter.replaceOpWithNewOp( + op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(), + adaptor.lhs(), adaptor.rhs()); + else if (!zexta && zextb) + rewriter.replaceOpWithNewOp( + op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(), + adaptor.lhs(), adaptor.rhs()); + else + rewriter.replaceOpWithNewOp( + op, resType, tsza.first, tszb.second, tsza.second, adaptor.acc(), + adaptor.lhs(), adaptor.rhs()); + return success(); + } +}; + +} // namespace + +void mlir::populateAMXLegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, OwningRewritePatternList &patterns) { + // Registry::registerPatterns(converter, patterns); + patterns.insert(converter); +} + +void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) { + // Registry::configureTarget(target); + target.addLegalOp(); + target.addIllegalOp(); +} diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(ArmNeon) add_subdirectory(ArmSVE) add_subdirectory(Async) +add_subdirectory(AMX) add_subdirectory(AVX512) add_subdirectory(Complex) add_subdirectory(DLTI) diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -37,6 +37,7 @@ LINK_LIBS PUBLIC MLIRArmNeonToLLVMIRTranslation + MLIRAMXToLLVMIRTranslation MLIRAVX512ToLLVMIRTranslation MLIRLLVMArmSVEToLLVMIRTranslation MLIRLLVMToLLVMIRTranslation diff --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp 
b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.cpp @@ -0,0 +1,55 @@ +//===- AMXToLLVMIRTranslation.cpp - Translate AMX to LLVM IR --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the AMX dialect and LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicsX86.h" + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the AMX dialect to LLVM IR. +class AMXDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Translates the given operation to LLVM IR using the provided IR builder + /// and saving the state in `moduleTranslation`. 
+ LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final { + Operation &opInst = *op; +#include "mlir/Dialect/AMX/AMXConversions.inc" + + return failure(); + } +}; +} // end namespace + +void mlir::registerAMXDialectTranslation(DialectRegistry &registry) { + registry.insert(); + registry.addDialectInterface(); +} + +void mlir::registerAMXDialectTranslation(MLIRContext &context) { + DialectRegistry registry; + registerAMXDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/AMX/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_translation_library(MLIRAMXToLLVMIRTranslation + AMXToLLVMIRTranslation.cpp + + DEPENDS + MLIRAMXConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRAMX + MLIRLLVMIR + MLIRSupport + MLIRTargetLLVMIRExport + ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(ArmNeon) +add_subdirectory(AMX) add_subdirectory(AVX512) add_subdirectory(LLVMArmSVE) add_subdirectory(LLVMIR) diff --git a/mlir/test/CMakeLists.txt b/mlir/test/CMakeLists.txt --- a/mlir/test/CMakeLists.txt +++ b/mlir/test/CMakeLists.txt @@ -29,6 +29,7 @@ if (MLIR_INCLUDE_INTEGRATION_TESTS) set(INTEL_SDE_EXECUTABLE "" CACHE STRING "If set, arch-specific integration tests are run with Intel SDE.") + option(MLIR_RUN_AMX_TESTS "Run AMX tests.") option(MLIR_RUN_AVX512_TESTS "Run AVX512 tests.") # Passed to lit.site.cfg.py.in to set up the path where to find the libraries. 
set(MLIR_INTEGRATION_TEST_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) diff --git a/mlir/test/Dialect/AMX/invalid.mlir b/mlir/test/Dialect/AMX/invalid.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/AMX/invalid.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// ----- + +func @rowheight() { + // expected-error@+1 {{'amx.tile_zero' op bad row height: 17}} + %0 = amx.tile_zero : vector<17x16xbf16> +} + +// ----- + +func @colwidth() { + // expected-error@+1 {{'amx.tile_zero' op bad column width: 65}} + %0 = amx.tile_zero : vector<16x65xi8> +} + +// ----- + +func @col4bytemultiple() { + // expected-error@+1 {{'amx.tile_zero' op bad column width: 5}} + %0 = amx.tile_zero : vector<16x5xi8> +} + +// ----- + +func @memtilesize(%arg0: memref) { + %0 = constant 0 : index + // expected-error@+1 {{'amx.tile_load' op bad column width: 68}} + %1 = amx.tile_load %arg0[%0, %0] : memref into vector<16x17xf32> +} + +// ----- + +func @memindexsize(%arg0: memref) { + %0 = constant 0 : index + // expected-error@+1 {{'amx.tile_load' op requires 2 indices}} + %1 = amx.tile_load %arg0[%0] : memref into vector<16x16xf32> +} + +// ----- + +func @multsize() { + %0 = amx.tile_zero : vector<8x8xbf16> + %1 = amx.tile_zero : vector<8x8xbf16> + %2 = amx.tile_zero : vector<4x4xf32> + // expected-error@+1 {{'amx.tile_mulf' op bad mult shape: 4 x 4 x 4}} + %3 = amx.tile_mulf %0, %1, %2 : vector<8x8xbf16>, vector<8x8xbf16>, vector<4x4xf32> +} + +// ----- + +func @zextsize() { + %0 = amx.tile_zero : vector<8x8xi8> + %1 = amx.tile_zero : vector<8x8xi8> + %2 = amx.tile_zero : vector<8x8xi32> + // expected-error@+1 {{'amx.tile_muli' op unexpected zext length}} + %3 = amx.tile_muli %0, %1, %2 [true] : vector<8x8xi8>, vector<8x8xi8>, vector<8x8xi32> +} diff --git a/mlir/test/Dialect/AMX/legalize-for-llvm.mlir b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/AMX/legalize-for-llvm.mlir @@ -0,0 +1,45 @@ +// 
RUN: mlir-opt %s -convert-vector-to-llvm="enable-amx" | mlir-opt | FileCheck %s + +// CHECK-LABEL: muli( +// CHECK: amx.tilezero +// CHECK: amx.tileloadd64 +// CHECK: amx.tileloadd64 +// CHECK: amx.tdpbuud +// CHECK: amx.tilestored64 +// CHECK: amx.tdpbssd +// CHECK: amx.tilestored64 +// CHECK: amx.tdpbusd +// CHECK: amx.tilestored64 +// CHECK: amx.tdpbsud +// CHECK: amx.tilestored64 +func @muli(%arg0: memref, %arg1: memref) { + %0 = constant 0 : index + %1 = amx.tile_zero : vector<16x64xi8> + %2 = amx.tile_load %arg0[%0, %0] : memref into vector<16x64xi8> + %3 = amx.tile_load %arg1[%0, %0] : memref into vector<16x16xi32> + %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + amx.tile_store %arg1[%0, %0], %4 : memref, vector<16x16xi32> + %5 = amx.tile_muli %1, %2, %3 [false, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + amx.tile_store %arg1[%0, %0], %5 : memref, vector<16x16xi32> + %6 = amx.tile_muli %1, %2, %3 [true, false] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + amx.tile_store %arg1[%0, %0], %6 : memref, vector<16x16xi32> + %7 = amx.tile_muli %1, %2, %3 [false, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + amx.tile_store %arg1[%0, %0], %7 : memref, vector<16x16xi32> + return +} + +// CHECK-LABEL: mulf( +// CHECK: amx.tilezero +// CHECK: amx.tileloadd64 +// CHECK: amx.tileloadd64 +// CHECK: amx.tdpbf16ps +// CHECK: amx.tilestored64 +func @mulf(%arg0: memref, %arg1: memref) { + %0 = constant 0 : index + %1 = amx.tile_zero : vector<16x32xbf16> + %2 = amx.tile_load %arg0[%0, %0] : memref into vector<16x32xbf16> + %3 = amx.tile_load %arg1[%0, %0] : memref into vector<16x16xf32> + %4 = amx.tile_mulf %1, %2, %3 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> + amx.tile_store %arg1[%0, %0], %4 : memref, vector<16x16xf32> + return +} diff --git a/mlir/test/Dialect/AMX/roundtrip.mlir b/mlir/test/Dialect/AMX/roundtrip.mlir new file mode 100644 --- /dev/null +++ 
b/mlir/test/Dialect/AMX/roundtrip.mlir @@ -0,0 +1,41 @@ +// RUN: mlir-opt -verify-diagnostics %s | mlir-opt | FileCheck %s + +// CHECK-LABEL: tzero +// CHECK: amx.tile_zero : vector<16x16xbf16> +// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %{{.*}} : memref, vector<16x16xbf16> +func @tzero(%arg0: memref) { + %0 = constant 0 : index + %1 = amx.tile_zero : vector<16x16xbf16> + amx.tile_store %arg0[%0, %0], %1 : memref, vector<16x16xbf16> + return +} + +// CHECK-LABEL: tmulf +// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x32xbf16> +// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x16xf32> +// CHECK: %[[m:.*]] = amx.tile_mulf %[[x]], %[[x]], %[[z]] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> +// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref, vector<16x16xf32> +func @tmulf(%arg0: memref, %arg1: memref) { + %0 = constant 0 : index + %1 = amx.tile_load %arg0[%0, %0] : memref into vector<16x32xbf16> + %2 = amx.tile_load %arg1[%0, %0] : memref into vector<16x16xf32> + %3 = amx.tile_mulf %1, %1, %2 : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> + amx.tile_store %arg1[%0, %0], %3 : memref, vector<16x16xf32> + return +} + +// CHECK-LABEL: tmuli +// CHECK: %[[x:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x64xi8> +// CHECK: %[[y:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x64xi8> +// CHECK: %[[z:.*]] = amx.tile_load %{{.*}}[%{{.*}}, %{{.*}}] : memref into vector<16x16xi32> +// CHECK: %[[m:.*]] = amx.tile_muli %[[x]], %[[y]], %[[z]] [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> +// CHECK: amx.tile_store %{{.*}}[%{{.*}}, %{{.*}}], %[[m]] : memref, vector<16x16xi32> +func @tmuli(%arg0: memref, %arg1: memref, %arg2: memref) { + %0 = constant 0 : index + %1 = amx.tile_load %arg0[%0, %0] : memref into vector<16x64xi8> + %2 = amx.tile_load %arg1[%0, %0] : memref into vector<16x64xi8> + 
%3 = amx.tile_load %arg2[%0, %0] : memref into vector<16x16xi32> + %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> + amx.tile_store %arg2[%0, %0], %4 : memref, vector<16x16xi32> + return +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/lit.local.cfg @@ -0,0 +1,15 @@ +import sys + +# AMX tests must be enabled via build flag. +if config.mlir_run_amx_tests != 'ON': + config.unsupported = True + +# No JIT on win32. +if sys.platform == 'win32': + config.unsupported = True + +if config.intel_sde_executable: + # Run test in emulator (Intel SDE): AMX needs Sapphire Rapids CPU. + config.substitutions.append(('%lli', config.intel_sde_executable + ' -spr -- lli')) +else: + config.substitutions.append(('%lli', 'lli')) diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-mulf.mlir @@ -0,0 +1,83 @@ +// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-std-to-llvm | \ +// RUN: mlir-translate -mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// Note: To run this test, your CPU must support AMX. + +// Multiply into zeroed destination. 
+func @kernel1(%arg0: memref<2x4xbf16>, + %arg1: memref<2x4xbf16>, + %arg2: memref<2x2xf32>) { + %0 = constant 0 : index + %1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> + %2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> + %3 = amx.tile_zero : vector<2x2xf32> + %4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32> + amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32> + return +} + +// Multiply and update into destination. +func @kernel2(%arg0: memref<2x4xbf16>, + %arg1: memref<2x4xbf16>, + %arg2: memref<2x2xf32>) { + %0 = constant 0 : index + %1 = amx.tile_load %arg0[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> + %2 = amx.tile_load %arg1[%0, %0] : memref<2x4xbf16> into vector<2x4xbf16> + %3 = amx.tile_load %arg2[%0, %0] : memref<2x2xf32> into vector<2x2xf32> + %4 = amx.tile_mulf %1, %2, %3 : vector<2x4xbf16>, vector<2x4xbf16>, vector<2x2xf32> + amx.tile_store %arg2[%0, %0], %4 : memref<2x2xf32>, vector<2x2xf32> + return +} + +func @entry() { + %f0 = constant 0.0: f32 + %c0 = constant 0: index + %c1 = constant 1: index + %c2 = constant 2: index + + // Set up memory. + %a = alloc() : memref<2x4xbf16> + %b = alloc() : memref<2x4xbf16> + %c = alloc() : memref<2x2xf32> + + %0 = std.constant dense<[[1.0, 2.0, 3.0, 4.0 ], + [5.0, 6.0, 7.0, 8.0 ]]> : vector<2x4xbf16> + vector.transfer_write %0, %a[%c0, %c0] : vector<2x4xbf16>, memref<2x4xbf16> + %1 = std.constant dense<[[ 9.0, 10.0, 11.0, 12.0 ], + [13.0, 14.0, 15.0, 16.0 ]]> : vector<2x4xbf16> + vector.transfer_write %1, %b[%c0, %c0] : vector<2x4xbf16>, memref<2x4xbf16> + + // Call kernel. + call @kernel1(%a, %b, %c) : (memref<2x4xbf16>, memref<2x4xbf16>, memref<2x2xf32>) -> () + + // Print and verify. 
+ // + // CHECK: ( 124, 144 ) + // CHECK-NEXT: ( 308, 360 ) + scf.for %i = %c0 to %c2 step %c1 { + %av = vector.transfer_read %c[%i, %c0], %f0: memref<2x2xf32>, vector<2xf32> + vector.print %av : vector<2xf32> + } + + // Call kernel. + call @kernel2(%a, %b, %c) : (memref<2x4xbf16>, memref<2x4xbf16>, memref<2x2xf32>) -> () + + // Print and verify. + // + // CHECK-NEXT: ( 248, 288 ) + // CHECK-NEXT: ( 616, 720 ) + // + scf.for %i = %c0 to %c2 step %c1 { + %cv = vector.transfer_read %c[%i, %c0], %f0: memref<2x2xf32>, vector<2xf32> + vector.print %cv : vector<2xf32> + } + + // Release resources. + dealloc %a : memref<2x4xbf16> + dealloc %b : memref<2x4xbf16> + dealloc %c : memref<2x2xf32> + + return +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-muli.mlir @@ -0,0 +1,83 @@ +// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-std-to-llvm | \ +// RUN: mlir-translate -mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// Note: To run this test, your CPU must support AMX. + +// Multiply into zeroed destination. +func @kernel1(%arg0: memref<2x8xi8>, + %arg1: memref<2x8xi8>, + %arg2: memref<2x2xi32>) { + %0 = constant 0 : index + %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8> + %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8> + %3 = amx.tile_zero : vector<2x2xi32> + %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32> + amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32> + return +} + +// Multiply and update into destination. 
+func @kernel2(%arg0: memref<2x8xi8>, + %arg1: memref<2x8xi8>, + %arg2: memref<2x2xi32>) { + %0 = constant 0 : index + %1 = amx.tile_load %arg0[%0, %0] : memref<2x8xi8> into vector<2x8xi8> + %2 = amx.tile_load %arg1[%0, %0] : memref<2x8xi8> into vector<2x8xi8> + %3 = amx.tile_load %arg2[%0, %0] : memref<2x2xi32> into vector<2x2xi32> + %4 = amx.tile_muli %1, %2, %3 [true, true] : vector<2x8xi8>, vector<2x8xi8>, vector<2x2xi32> + amx.tile_store %arg2[%0, %0], %4 : memref<2x2xi32>, vector<2x2xi32> + return +} + +func @entry() { + %i0 = constant 0: i32 + %c0 = constant 0: index + %c1 = constant 1: index + %c2 = constant 2: index + + // Set up memory. + %a = alloc() : memref<2x8xi8> + %b = alloc() : memref<2x8xi8> + %c = alloc() : memref<2x2xi32> + + %0 = std.constant dense<[[1 , 2, 3 , 4 , 5, 6, 7, 8], + [9, 10, 11, 12, 13, 14, 15, 16]]> : vector<2x8xi8> + vector.transfer_write %0, %a[%c0, %c0] : vector<2x8xi8>, memref<2x8xi8> + %1 = std.constant dense<[[17, 18, 19, 20, 21, 22, 23, 24], + [25, 26, 27, 28, 29, 30, 31, 32]]> : vector<2x8xi8> + vector.transfer_write %1, %b[%c0, %c0] : vector<2x8xi8>, memref<2x8xi8> + + // Call kernel. + call @kernel1(%a, %b, %c) : (memref<2x8xi8>, memref<2x8xi8>, memref<2x2xi32>) -> () + + // Print and verify. + // + // CHECK: ( 884, 1028 ) + // CHECK-NEXT: ( 2324, 2724 ) + scf.for %i = %c0 to %c2 step %c1 { + %av = vector.transfer_read %c[%i, %c0], %i0: memref<2x2xi32>, vector<2xi32> + vector.print %av : vector<2xi32> + } + + // Call kernel. + call @kernel2(%a, %b, %c) : (memref<2x8xi8>, memref<2x8xi8>, memref<2x2xi32>) -> () + + // Print and verify. + // + // CHECK-NEXT: ( 1768, 2056 ) + // CHECK-NEXT: ( 4648, 5448 ) + // + scf.for %i = %c0 to %c2 step %c1 { + %cv = vector.transfer_read %c[%i, %c0], %i0: memref<2x2xi32>, vector<2xi32> + vector.print %cv : vector<2xi32> + } + + // Release resources. 
+ dealloc %a : memref<2x8xi8> + dealloc %b : memref<2x8xi8> + dealloc %c : memref<2x2xi32> + + return +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/AMX/test-tilezero.mlir @@ -0,0 +1,96 @@ +// RUN: mlir-opt %s -convert-vector-to-scf -lower-affine -convert-scf-to-std -convert-vector-to-llvm="enable-amx" -convert-std-to-llvm | \ +// RUN: mlir-translate -mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --mattr="+amx-tile,+amx-int8,+amx-bf16" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +// Note: To run this test, your CPU must support AMX. + +func @tilezero(%arg0: memref, %i: index, %j: index) { + %1 = amx.tile_zero : vector<16x16xi32> + amx.tile_store %arg0[%i, %j], %1 : memref, vector<16x16xi32> + return +} + +func @entry() { + %i0 = constant 0: i32 + %i1 = constant 1: i32 + %c0 = constant 0: index + %c1 = constant 1: index + %c3 = constant 3: index + %c19 = constant 19: index + + // Set up memory. + %a = alloc(%c19, %c19) : memref + scf.for %i = %c0 to %c19 step %c1 { + scf.for %j = %c0 to %c19 step %c1 { + store %i1, %a[%i, %j] : memref + } + } + + // Call kernel. + call @tilezero(%a, %c1, %c1) : (memref, index, index) -> () + + // Print and verify that the tilezero is correctly strided within + // the enveloping 19x19 buffer. 
+ // + // CHECK: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) + // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) + // + scf.for %i = %c0 to %c19 step %c1 { + %av = vector.transfer_read %a[%i, %c0], %i0: memref, vector<19xi32> + vector.print %av : vector<19xi32> + } + + // Call kernel with different indices. + call @tilezero(%a, %c0, %c3) : (memref, index, index) -> () + + // Print and verify that the tilezero is again correctly strided + // within the enveloping 19x19 buffer. 
+ // + // CHECK-NEXT: ( 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1 ) + // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) + // CHECK-NEXT: ( 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 ) + // + scf.for %i = %c0 to %c19 step %c1 { + %av = vector.transfer_read %a[%i, %c0], %i0: memref, vector<19xi32> + vector.print %av : vector<19xi32> + } + + // Release resources. 
+ dealloc %a : memref + + return +} diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Target/LLVMIR/amx.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-translate --mlir-to-llvmir %s | FileCheck %s + +// CHECK-LABEL: define void @target(i8* %0) +// CHECK: %[[c:.*]] = call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 16) +// CHECK: call void @llvm.x86.tilestored64.internal(i16 16, i16 16, i8* %0, i64 32, x86_amx %[[c]] +llvm.func @target(%ptr: !llvm.ptr) { + %c = llvm.mlir.constant(16 : i16) : i16 + %s = llvm.mlir.constant(32 : i64) : i64 + %0 = "amx.tilezero"(%c, %c) : (i16, i16) -> !llvm.array<16 x vector<16xbf16>> + "amx.tilestored64"(%c, %c, %ptr, %s, %0) : (i16, i16, !llvm.ptr, i64, !llvm.array<16 x vector<16xbf16>>) -> () + llvm.return +} + diff --git a/mlir/test/lit.site.cfg.py.in b/mlir/test/lit.site.cfg.py.in --- a/mlir/test/lit.site.cfg.py.in +++ b/mlir/test/lit.site.cfg.py.in @@ -48,6 +48,7 @@ config.enable_bindings_python = @MLIR_BINDINGS_PYTHON_ENABLED@ config.mlir_integration_test_dir = "@MLIR_INTEGRATION_TEST_DIR@" config.intel_sde_executable = "@INTEL_SDE_EXECUTABLE@" +config.mlir_run_amx_tests = "@MLIR_RUN_AMX_TESTS@" config.mlir_run_avx512_tests = "@MLIR_RUN_AVX512_TESTS@" config.mlir_include_integration_tests = "@MLIR_INCLUDE_INTEGRATION_TESTS@" diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -2,6 +2,7 @@ // CHECK: Available Dialects: // CHECK-NEXT: acc // CHECK-NEXT: affine +// CHECK-NEXT: amx // CHECK-NEXT: arm_neon // CHECK-NEXT: arm_sve // CHECK-NEXT: async