diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -976,6 +976,10 @@ "bool", /*default=*/"false", "Enables the use of ArmNeon dialect while lowering the vector " "dialect.">, + Option<"armSME", "enable-arm-sme", + "bool", /*default=*/"false", + "Enables the use of ArmSME dialect while lowering the vector " + "dialect.">, Option<"armSVE", "enable-arm-sve", "bool", /*default=*/"false", "Enables the use of ArmSVE dialect while lowering the vector " diff --git a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h --- a/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h +++ b/mlir/include/mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h @@ -38,6 +38,10 @@ armNeon = b; return *this; } + LowerVectorToLLVMOptions &enableArmSME(bool b = true) { + armSME = b; + return *this; + } LowerVectorToLLVMOptions &enableArmSVE(bool b = true) { armSVE = b; return *this; @@ -54,6 +58,7 @@ bool reassociateFPReductions{false}; bool force32BitVectorIndices{true}; bool armNeon{false}; + bool armSME{false}; bool armSVE{false}; bool amx{false}; bool x86Vector{false}; diff --git a/mlir/include/mlir/Dialect/ArmSME/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/ArmSME.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/ArmSME.td @@ -0,0 +1,225 @@ +//===-- ArmSME.td - ArmSME dialect operation definitions ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the ArmSME dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef ARMSME_OPS +#define ARMSME_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// ArmSME dialect definition +//===----------------------------------------------------------------------===// + +def ArmSME_Dialect : Dialect { + let name = "arm_sme"; + let cppNamespace = "::mlir::arm_sme"; + let summary = "Basic dialect to target Arm SME architectures"; + let description = [{ + This dialect contains the definitions necessary to target specific Arm SME + scalable vector operations. + + Source: + https://developer.arm.com/documentation/ddi0616/aa + }]; + let dependentDialects = ["arm_sve::ArmSVEDialect"]; +} + +//===----------------------------------------------------------------------===// +// ArmSME Tile enum definitions +//===----------------------------------------------------------------------===// + +def ZA0D : I32EnumAttrCase<"za0d", 1>; +def ZA1D : I32EnumAttrCase<"za1d", 2>; +def ZA2D : I32EnumAttrCase<"za2d", 4>; +def ZA3D : I32EnumAttrCase<"za3d", 8>; +def ZA4D : I32EnumAttrCase<"za4d", 16>; +def ZA5D : I32EnumAttrCase<"za5d", 32>; +def ZA6D : I32EnumAttrCase<"za6d", 64>; +def ZA7D : I32EnumAttrCase<"za7d", 128>; +def ZA0S : I32EnumAttrCase<"za0s", 17>; // = ZA0D | ZA4D +def ZA1S : I32EnumAttrCase<"za1s", 34>; // = ZA1D | ZA5D +def ZA2S : I32EnumAttrCase<"za2s", 68>; // = ZA2D | ZA6D +def ZA3S : I32EnumAttrCase<"za3s", 136>; // = ZA3D | ZA7D + +def ArmSME_TileAttr : I32EnumAttr<"TileEnum", + "Enum representation the SME matrix tiles", + [ZA0D, ZA1D, ZA2D, ZA3D, ZA4D, ZA5D, ZA6D, + ZA7D, ZA0S, ZA1S, ZA2S, ZA3S]> { + let cppNamespace = "::mlir::arm_sme"; +} + +//===----------------------------------------------------------------------===// +// ArmSME op definitions 
+//===----------------------------------------------------------------------===// + +class ArmSME_Op traits = []> : + Op {} + +def Predicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2], [I1]>; +def SMEVector : ScalableVectorOfLengthAndType< + [16, 8, 4, 2], [SI8, SI16, UI8, UI16, BF16, F16, F32, F64]>; +def TileList : TypedArrayAttrBase; + +class MOPOpBase + : ArmSME_Op]> { + let arguments = (ins + ArmSME_TileAttr:$tile, + Predicate:$lhsPred, + Predicate:$rhsPred, + SMEVector:$lhs, + SMEVector:$rhs + ); + let extraClassDeclaration = [{ + bool isAccumulate() { return }] # accumulate # [{; +} + bool isSubtract() { return }] # !not(accumulate) # [{; } + bool isWidening() { + auto elTy = this->getLhs().getType().cast().getElementType(); + if (elTy.isF32() || elTy.isF64()) + return false; + else + return true; + } + }]; + let assemblyFormat =[{ $tile`,` $lhsPred`,` $rhsPred`,` $lhs`,` $rhs attr-dict + `:` type($lhsPred)`,` type($rhsPred)`,` type($lhs)`,` type($rhs) }]; + let hasVerifier = 1; +} + +def MopaOp : MOPOpBase<"mopa", /*accumulate=*/true> { + let summary = "Vector-vector outer product and accumulate op"; + let description = [{ + MOPA: Outer product product accumulate. + + This function maps to the *MOPA instructions, it takes scalable vector + operands which will be used to compute the outer product matrix. Two + masking predicate operands for each of the floating point operands will also + be provided, such that elements marked inactive by the predicate will not + update the corresponding row/column in the result matrix tile, specified by + the attribute. + + Theere are two variations of MOPA instructions - widening and non-widening. + + Non-widening MOPAs will take a 1D vector of f32 or f64 as input and + accumulate into 32b and 64b tiles respectively (za*s and za*d). 
+ + Widening MOPAs will pack two f16/bf16 or four (signed or unsigned) i8 + elements into a single 32b lane of the vector and accumulate into 32b tiles + (za*s); or they will pack four (signed or unsigned) i16 elements into a 64b + lane and accumulate into 64b tiles (za*d). Hence widening MOPAs will take + 2D scalable vectors as input, e.g. `<[4x2]xf16>, <[2x4]xsi16>, <[4x4]xsi8>` + + Example: Assume `vscale == 2`, `%lhs = %rhs = <1, 2, 3, 4> : <[2]xf64>`, + `%lhsPred = %rhsPred = <1, 1, 0, 1> : <[2]xi1>`, then: + ``` + arm_sme.zero za0d + arm_sme.mopa za0d, %lhsPred, %rhsPred, %lhs, %rhs + : vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64> + ``` + + Would result in za0d containing: + ``` + 1 2 0 4 + 2 4 0 8 + 0 0 0 0 + 4 8 0 16 + ``` + }]; +} + +def MopsOp : MOPOpBase<"mops", /*accumulate=*/false> { + let summary = "Vector-vector outer product and subtract op"; + let description = [{ + MOPS: Outer product and subtract. + + This op maps to the *MOPS instructions. It functions similarly to + the *MOPA instructions, but differs in that it subtracts the outer product + computed from the input vectors from the existing values within the tile + provided. + }]; +} + +def ZeroOp : ArmSME_Op<"zero"> { + let summary = "Zeroes a list of SME matrix tiles"; + let description = [{ + ZERO: Sets the contents of the specified matrix tiles to zero. + + Source: + https://developer.arm.com/documentation/ddi0616/aa + }]; + let arguments = (ins TileList:$tiles); + let assemblyFormat = "custom($tiles) attr-dict"; +} + +//===----------------------------------------------------------------------===// +// ArmSME Intrinsic op definitions +//===----------------------------------------------------------------------===// + +class ArmSME_IntrOverloadedOp overloadOperands = []> + : LLVM_IntrOpBase< + /*Dialect dialect=*/ArmSME_Dialect, + /*string opName=*/"intr." 
#mnemonic, + /*string enumName=*/"aarch64_sme_" #!subst(".", "_", mnemonic), + /*list overloadedResults=*/[], + /*list overloadedOperands=*/overloadOperands, + /*list traits=*/[], + /*int numResults=*/0>; + +def ZeroIntrOp : ArmSME_IntrOverloadedOp<"zero">, + Arguments<(ins Arg)>; + +class ArmSME_IntrMopOverloadedOp + : ArmSME_IntrOverloadedOp, + Arguments<(ins Arg, + Arg, + Arg, + Arg, + Arg)>; + +def FmopaIntrOp : ArmSME_IntrMopOverloadedOp<"mopa">; +def FmopsIntrOp : ArmSME_IntrMopOverloadedOp<"mops">; + +def FmopaWidenIntrOp : ArmSME_IntrMopOverloadedOp<"mopa_wide">; +def FmopsWidenIntrOp : ArmSME_IntrMopOverloadedOp<"mops_wide">; + +def SmopaIntrOp : ArmSME_IntrMopOverloadedOp<"smopa_wide">; +def SmopsIntrOp : ArmSME_IntrMopOverloadedOp<"smops_wide">; +def UmopaIntrOp : ArmSME_IntrMopOverloadedOp<"umopa_wide">; +def UmopsIntrOp : ArmSME_IntrMopOverloadedOp<"umops_wide">; +def SUmopaIntrOp : ArmSME_IntrMopOverloadedOp<"sumopa_wide">; +def SUmopsIntrOp : ArmSME_IntrMopOverloadedOp<"sumops_wide">; +def USmopaIntrOp : ArmSME_IntrMopOverloadedOp<"usmopa_wide">; +def USmopsIntrOp : ArmSME_IntrMopOverloadedOp<"usmops_wide">; + +class ArmSME_IntrLoadStoreOverloadedOp + : ArmSME_IntrOverloadedOp, + Arguments<(ins Arg, + Arg, + Arg, Arg)>; + +// Loads +def LoadHorizontalBytesIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"ld1b_horiz">; +def LoadHorizontalHalfsIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"ld1h_horiz">; +def LoadHorizontalWordsIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"ld1w_horiz">; +def LoadHorizontalDoublesIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"ld1d_horiz">; +def LoadHorizontalQuadsIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"ld1q_horiz">; + +// Stores +def StoreVerticalBytesIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"st1b_vert">; +def StoreVerticalHalfsIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"st1h_vert">; +def StoreVerticalWordsIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"st1w_vert">; +def StoreVerticalDoublesIntrOp : 
ArmSME_IntrLoadStoreOverloadedOp<"st1d_vert">; +def StoreVerticalQuadsIntrOp : ArmSME_IntrLoadStoreOverloadedOp<"st1q_vert">; + +#endif // ARMSME_OPS diff --git a/mlir/include/mlir/Dialect/ArmSME/ArmSMEDialect.h b/mlir/include/mlir/Dialect/ArmSME/ArmSMEDialect.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/ArmSMEDialect.h @@ -0,0 +1,27 @@ +//===- ArmSMEDialect.h - MLIR Dialect for Arm SME ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for ArmSME in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSME_ARMSMEDIALECT_H +#define MLIR_DIALECT_ARMSME_ARMSMEDIALECT_H + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/ArmSME/ArmSMEDialect.h.inc" +#include "mlir/Dialect/ArmSME/ArmSMEEnums.h.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSME/ArmSME.h.inc" + +#endif // MLIR_DIALECT_ARMSME_ARMSMEDIALECT_H diff --git a/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt b/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt @@ -0,0 +1,8 @@ +add_mlir_dialect(ArmSME arm_sme ArmSME) +add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme) + +set(LLVM_TARGET_DEFINITIONS ArmSME.td) +mlir_tablegen(ArmSMEEnums.h.inc -gen-enum-decls) +mlir_tablegen(ArmSMEEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRArmSMEConversionsIncGen) diff --git 
a/mlir/include/mlir/Dialect/ArmSME/Transforms.h b/mlir/include/mlir/Dialect/ArmSME/Transforms.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/ArmSME/Transforms.h @@ -0,0 +1,29 @@ +//===- Transforms.h - ArmSME Dialect Transformation Entrypoints -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H +#define MLIR_DIALECT_ARMSME_TRANSFORMS_H + +namespace mlir { + +class LLVMConversionTarget; +class LLVMTypeConverter; +class RewritePatternSet; + +/// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM +/// intrinsics. +void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, + RewritePatternSet &patterns); + +/// Configure the target to support lowering ArmSME ops to ops that map to LLVM +/// intrinsics. 
+void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target); + +} // namespace mlir + +#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(Affine) add_subdirectory(Arith) add_subdirectory(ArmNeon) +add_subdirectory(ArmSME) add_subdirectory(ArmSVE) add_subdirectory(Async) add_subdirectory(Bufferization) diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -21,6 +21,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSME/ArmSMEDialect.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -103,6 +104,7 @@ pdl_interp::PDLInterpDialect, quant::QuantizationDialect, spirv::SPIRVDialect, + arm_sme::ArmSMEDialect, arm_sve::ArmSVEDialect, vector::VectorDialect, NVVM::NVVMDialect, diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h --- a/mlir/include/mlir/Target/LLVMIR/Dialect/All.h +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/All.h @@ -16,6 +16,7 @@ #include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" @@ -32,6 +33,7 @@ static inline void 
registerAllToLLVMIRTranslations(DialectRegistry &registry) { registerArmNeonDialectTranslation(registry); registerAMXDialectTranslation(registry); + registerArmSMEDialectTranslation(registry); registerArmSVEDialectTranslation(registry); registerLLVMDialectTranslation(registry); registerNVVMDialectTranslation(registry); diff --git a/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h b/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h @@ -0,0 +1,31 @@ +//=======- ArmSMEToLLVMIRTranslation.h - ArmSME to LLVM IR --*- C++ -*-=======// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This provides registration calls for ArmSME dialect to LLVM IR translation. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H +#define MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H + +namespace mlir { + +class DialectRegistry; +class MLIRContext; + +/// Register the ArmSME dialect and the translation from it to the LLVM IR in +/// the given registry. +void registerArmSMEDialectTranslation(DialectRegistry &registry); + +/// Register the ArmSME dialect and the translation from it in the registry +/// associated with the given context. 
+void registerArmSMEDialectTranslation(MLIRContext &context); + +} // namespace mlir + +#endif // MLIR_TARGET_LLVMIR_DIALECT_ARMSME_ARMSMETOLLVMIRTRANSLATION_H diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt --- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt @@ -15,6 +15,8 @@ LINK_LIBS PUBLIC MLIRArithDialect MLIRArmNeonDialect + MLIRArmSMEDialect + MLIRArmSMETransforms MLIRArmSVEDialect MLIRArmSVETransforms MLIRAMXDialect diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -14,6 +14,8 @@ #include "mlir/Dialect/AMX/Transforms.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSME/ArmSMEDialect.h" +#include "mlir/Dialect/ArmSME/Transforms.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/ArmSVE/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -40,6 +42,7 @@ this->reassociateFPReductions = options.reassociateFPReductions; this->force32BitVectorIndices = options.force32BitVectorIndices; this->armNeon = options.armNeon; + this->armSME = options.armSME; this->armSVE = options.armSVE; this->amx = options.amx; this->x86Vector = options.x86Vector; @@ -51,7 +54,10 @@ registry.insert(); if (armNeon) registry.insert(); - if (armSVE) + if (armSME) { + registry.insert(); + registry.insert(); + } else if (armSVE) registry.insert(); if (amx) registry.insert(); @@ -99,7 +105,12 @@ // can be translated to LLVM IR so there is no conversion necessary. 
target.addLegalDialect(); } - if (armSVE) { + if (armSME) { + configureArmSMELegalizeForExportTarget(target); + populateArmSMELegalizeForLLVMExportPatterns(converter, patterns); + configureArmSVELegalizeForExportTarget(target); + populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); + } else if (armSVE) { configureArmSVELegalizeForExportTarget(target); populateArmSVELegalizeForLLVMExportPatterns(converter, patterns); } diff --git a/mlir/lib/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/CMakeLists.txt @@ -0,0 +1,2 @@ +add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSMEDialect.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSMEDialect.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/IR/ArmSMEDialect.cpp @@ -0,0 +1,176 @@ +//===- ArmSMEDialect.cpp - MLIR ArmSME dialect implementation -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the ArmSME dialect and its operations. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSME/ArmSMEDialect.h" +#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::arm_sme; + +//===----------------------------------------------------------------------===// +// Custom printer/parser for list of SME Tile enums +//===----------------------------------------------------------------------===// + +namespace { + +void printTileEnumList(OpAsmPrinter &printer, Operation *op, ArrayAttr tiles) { + (void)op; + llvm::interleaveComma(tiles, printer, [&](Attribute elem) { + auto tile = elem.cast().getValue(); + printer << stringifyTileEnum(tile); + }); +} + +ParseResult parseTileEnumList(OpAsmParser &parser, ArrayAttr &tiles) { + SmallVector tileStorage; + auto parseTileEnumAttr = [&]() -> ParseResult { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return failure(); + Optional maybeTile = symbolizeTileEnum(keyword); + if (!maybeTile) + return parser.emitError(parser.getCurrentLocation(), + "invalid SME tile name"); + auto tileAttr = TileEnumAttr::get(parser.getContext(), *maybeTile); + tileStorage.push_back(tileAttr); + return success(); + }; + auto loc = parser.getCurrentLocation(); + if (parser.parseCommaSeparatedList(parseTileEnumAttr)) + return parser.emitError(loc, "expected list of SME tiles"); + tiles = ArrayAttr::get(parser.getContext(), tileStorage); + return success(); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// Tablegen Definitions +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSME/ArmSMEDialect.cpp.inc" 
+#include "mlir/Dialect/ArmSME/ArmSMEEnums.cpp.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSME/ArmSME.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/ArmSME/ArmSMETypes.cpp.inc" + +void ArmSMEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/ArmSME/ArmSME.cpp.inc" + >(); +} + +//===----------------------------------------------------------------------===// +// Custom Verifier +//===----------------------------------------------------------------------===// + +/// Additional verification shared by the MOPA and MOPS ops. +static LogicalResult verifyMOP(TileEnum tile, Type lhsTy, Type rhsTy, + bool isWidening, Operation *op) { + auto lhsVecTy = lhsTy.cast(); + auto rhsVecTy = rhsTy.cast(); + if (lhsVecTy.getNumScalableDims() != lhsVecTy.getRank() || + rhsVecTy.getNumScalableDims() != rhsVecTy.getRank()) + return op->emitOpError("expecting all dimensions to be scalable"); + Type lhsElTy = lhsVecTy.getElementType(); + Type rhsElTy = rhsVecTy.getElementType(); + + const llvm::DenseSet b32Tiles( + {TileEnum::za0s, TileEnum::za1s, TileEnum::za2s, TileEnum::za3s}); + const llvm::DenseSet b64Tiles( + {TileEnum::za0d, TileEnum::za1d, TileEnum::za2d, TileEnum::za3d, + TileEnum::za4d, TileEnum::za5d, TileEnum::za6d, TileEnum::za7d}); + // Verify that both operands have the same element bit width + unsigned elWidth = lhsElTy.getIntOrFloatBitWidth(); + if (elWidth != rhsElTy.getIntOrFloatBitWidth()) + return op->emitOpError("invalid vector element type"); + + // Verify valid vector unit length: + constexpr unsigned sveUnitVecWidth = 128; + if (elWidth * lhsVecTy.getNumElements() != sveUnitVecWidth) + return op->emitOpError( + "expected operand vector length to be multiples of 128 bits"); + + if (isWidening) { + // Check element types - integer types can be either signed or unsigned for + // both operands, otherwise the types must match. 
+ if (lhsVecTy.getRank() != 2) + return op->emitOpError( + "expecting widening MOP ops to have 2D vector operands"); + auto lhsShape = lhsVecTy.getShape(); + if (lhsElTy.isBF16() || lhsElTy.isF16()) { + // widening fmop*/bfmop* + if (!b32Tiles.contains(tile)) + return op->emitOpError( + "expecting 16b float types to accumulate into 32b tiles"); + if (rhsElTy != lhsElTy) + return op->emitOpError("mismatching lhs and rhs vector element types"); + if (lhsShape[0] != 4) + return op->emitOpError("invalid vector shape for widening MOP"); + } else if (lhsElTy.isInteger(8)) { + // 8->32-bit smop*/umop*/sumop*/usmop* + if (!b32Tiles.contains(tile)) + return op->emitOpError( + "expecting 8b int types to accumulate into 32b tiles"); + if (!rhsElTy.isInteger(8)) + return op->emitOpError( + "expecting lhs and rhs element types to be of same integer width"); + if (lhsShape[0] != 4) + return op->emitOpError("invalid vector shape for widening MOP"); + } else if (lhsElTy.isInteger(16)) { + // 16->64-bit smop*/umop*/sumop*/usmop* + if (!b64Tiles.contains(tile)) + return op->emitOpError( + "expecting 16b int types to accumulate into 64b tiles"); + if (!rhsElTy.isInteger(16)) + return op->emitOpError( + "expecting lhs and rhs element types to be of same integer width"); + if (lhsShape[0] != 2) + return op->emitOpError("invalid vector shape for widening MOP"); + } + return success(); + } + + // non-widening fmop*: the tile width must match the *element* type width + if (lhsVecTy != rhsVecTy) + return op->emitOpError("expecting lhs and rhs operands to have the same " + "type for non-widening MOP"); + if (lhsVecTy.getRank() != 1) + return op->emitOpError("expecting 1D vector operands for non-widening MOP"); + if (lhsElTy.isF32() && !b32Tiles.contains(tile)) + return op->emitOpError("expecting f32 MOP to accumulate into 32b tiles"); + if (lhsElTy.isF64() && !b64Tiles.contains(tile)) + return op->emitOpError("expecting f64 MOP to accumulate into 64b tiles"); + + return success(); +} + 
+LogicalResult MopaOp::verify() { + return 
verifyMOP(getTile(), getLhs().getType(), getRhs().getType(), + isWidening(), getOperation()); +} + +LogicalResult MopsOp::verify() { + return verifyMOP(getTile(), getLhs().getType(), getRhs().getType(), + isWidening(), getOperation()); +} diff --git a/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_dialect_library(MLIRArmSMEDialect + ArmSMEDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME + + DEPENDS + MLIRArmSMEIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect + MLIRSideEffectInterfaces + ) diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIRArmSMETransforms + LegalizeForLLVMExport.cpp + + DEPENDS + MLIRArmSMEConversionsIncGen + + LINK_LIBS PUBLIC + MLIRArmSMEDialect + MLIRFuncDialect + MLIRIR + MLIRLLVMCommonConversion + MLIRLLVMDialect + ) diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -0,0 +1,148 @@ +//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSME/ArmSMEDialect.h" +#include "mlir/Dialect/ArmSME/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/SmallVector.h" + +using namespace mlir; +using namespace mlir::arm_sme; + +template +class MOPLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(MOPTy op, typename MOPTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + (void)adaptor; + Location loc = op.getLoc(); + SmallVector operands; + auto tile = static_cast(op.getTile()); + // Operands: + // Tile number + operands.push_back( + rewriter.create(loc, rewriter.getI32Type(), tile) + .getResult()); + if (op.isWidening()) { + return op.emitOpError("lowering of widening SME outer product " + "instructions not yet supported"); + } + // Predicates + operands.push_back(op.getLhsPred()); + operands.push_back(op.getRhsPred()); + // Input vectors + operands.push_back(op.getLhs()); + operands.push_back(op.getRhs()); + Type lhsElTy = + op.getLhs().getType().template cast().getElementType(); + Type rhsElTy = + op.getRhs().getType().template cast().getElementType(); + ValueRange operandsRange(operands); + switch (op.isAccumulate()) { + case true: + // MOPA ops + if (lhsElTy.isF32() || lhsElTy.isF64()) + rewriter.create(loc, TypeRange{}, operandsRange); + else if (lhsElTy.isF16() || lhsElTy.isBF16()) + rewriter.create(loc, TypeRange{}, operandsRange); + else if (lhsElTy.isSignedInteger() && rhsElTy.isSignedInteger()) + 
rewriter.create(loc, TypeRange{}, operandsRange); + else if (lhsElTy.isSignedInteger() && rhsElTy.isUnsignedInteger()) + rewriter.create(loc, TypeRange{}, operandsRange); + else if (lhsElTy.isUnsignedInteger() && rhsElTy.isSignedInteger()) + rewriter.create(loc, TypeRange{}, operandsRange); + else if (lhsElTy.isUnsignedInteger() && rhsElTy.isUnsignedInteger()) + rewriter.create(loc, TypeRange{}, operandsRange); + else + return op.emitOpError("unsupported SME vector element type"); + break; + case false: + // MOPS ops + if (lhsElTy.isF32() || lhsElTy.isF64()) + rewriter.create(loc, TypeRange{}, operandsRange); + else if (lhsElTy.isF16() || lhsElTy.isBF16()) + rewriter.create(loc, TypeRange{}, operandsRange); + else if (lhsElTy.isSignedInteger() && rhsElTy.isSignedInteger()) + rewriter.create(loc, TypeRange{}, operandsRange); + else if (lhsElTy.isSignedInteger() && rhsElTy.isUnsignedInteger()) + rewriter.create(loc, TypeRange{}, operandsRange); + else if (lhsElTy.isUnsignedInteger() && rhsElTy.isSignedInteger()) + rewriter.create(loc, TypeRange{}, operandsRange); + else if (lhsElTy.isUnsignedInteger() && rhsElTy.isUnsignedInteger()) + rewriter.create(loc, TypeRange{}, operandsRange); + else + return op.emitOpError("unsupported SME vector element type"); + } + rewriter.eraseOp(op); + return LogicalResult::success(); + } +}; + +class ZeroOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(ZeroOp op, ZeroOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + (void)adaptor; + Location loc = op.getLoc(); + ArrayAttr tiles = op.getTiles(); + uint32_t tileNum = 0; + for (auto tile : tiles) { + auto tileEnum = tile.cast().getValue(); + tileNum |= static_cast(tileEnum); + } + Value tileVal = + rewriter.create(loc, rewriter.getI32Type(), tileNum); + rewriter.create(loc, tileVal); + rewriter.eraseOp(op); + return LogicalResult::success(); + } +}; + +/// Populate the given list with 
patterns that convert from ArmSME to LLVM. +void mlir::populateArmSMELegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + // Populate conversion patterns + // clang-format off + patterns.add, + MOPLowering, + ZeroOpLowering>(converter, &converter.getContext()); + // clang-format on +} + +void mlir::configureArmSMELegalizeForExportTarget( + LLVMConversionTarget &target) { + // clang-format off + target.addLegalOp(); + target.addIllegalOp(); + // clang-format on +} diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(AMDGPU) add_subdirectory(Arith) add_subdirectory(ArmNeon) +add_subdirectory(ArmSME) add_subdirectory(ArmSVE) add_subdirectory(Async) add_subdirectory(AMX) diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -39,6 +39,7 @@ LINK_LIBS PUBLIC MLIRArmNeonToLLVMIRTranslation + MLIRArmSMEToLLVMIRTranslation MLIRArmSVEToLLVMIRTranslation MLIRAMXToLLVMIRTranslation MLIRX86VectorToLLVMIRTranslation diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.cpp @@ -0,0 +1,56 @@ +//======- ArmSMEToLLVMIRTranslation.cpp - Translate ArmSME to LLVM IR -=======// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the ArmSME dialect and LLVM IR. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h" +#include "mlir/Dialect/ArmSME/ArmSMEDialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicsAArch64.h" + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the ArmSME dialect to LLVM IR. +class ArmSMEDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Translates the given operation to LLVM IR using the provided IR builder + /// and saving the state in `moduleTranslation`. + LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final { + Operation &opInst = *op; +#include "mlir/Dialect/ArmSME/ArmSMEConversions.inc" + + return failure(); + } +}; +} // namespace + +void mlir::registerArmSMEDialectTranslation(DialectRegistry ®istry) { + registry.insert(); + registry.addExtension(+[](MLIRContext *ctx, arm_sme::ArmSMEDialect *dialect) { + dialect->addInterfaces(); + }); +} + +void mlir::registerArmSMEDialectTranslation(MLIRContext &context) { + DialectRegistry registry; + registerArmSMEDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/ArmSME/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_translation_library(MLIRArmSMEToLLVMIRTranslation + ArmSMEToLLVMIRTranslation.cpp + + DEPENDS + MLIRArmSMEConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRArmSMEDialect + MLIRLLVMDialect + 
MLIRSupport
+  MLIRTargetLLVMIRExport
+  )
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(ArmNeon)
+add_subdirectory(ArmSME)
 add_subdirectory(ArmSVE)
 add_subdirectory(AMX)
 add_subdirectory(LLVMIR)
diff --git a/mlir/test/Dialect/ArmSME/lower-llvm.mlir b/mlir/test/Dialect/ArmSME/lower-llvm.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/lower-llvm.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -convert-func-to-llvm | mlir-translate --mlir-to-llvmir | FileCheck %s
+func.func @arm_sme_lower(%0 : vector<[2]xf64>, %1 : vector<[4]xf32>) {
+  %c = arith.constant 128 : index
+  %pred.64 = vector.create_mask %c : vector<[2]xi1>
+  // CHECK: call void @llvm.aarch64.sme.zero(i32 255)
+  arm_sme.zero za0d, za1d, za2d, za3d, za4d, za5d, za6d, za7d, za0s, za1s, za2s, za3s // za0d..za7d carry enum values 1..128; za0s..za3s are unions of those, so the OR is 255.
+  // CHECK: call void @llvm.aarch64.sme.mopa.nxv2f64(i32 1,
+  arm_sme.mopa za0d, %pred.64, %pred.64, %0, %0 : vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64> // za0d has enum value 1.
+  %pred.32 = vector.create_mask %c : vector<[4]xi1>
+  // CHECK: call void @llvm.aarch64.sme.mopa.nxv4f32(i32 17,
+  arm_sme.mopa za0s, %pred.32, %pred.32, %1, %1 : vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32> // za0s has enum value 17 (za0d | za4d).
+  return
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -0,0 +1,25 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// Parses the ops below, prints them, and re-parses the printed form, so the
+// printer and parser are exercised against each other; op names are then matched.
+// CHECK-LABEL: func.func @arm_sme_ops
+func.func @arm_sme_ops(%0 : vector<[2]xf64>,
+                       %1 : vector<[4]xf32>,
+                       %2 : vector<[4x2]xf16>,
+                       %3 : vector<[2x4]xsi16>) {
+  %c = arith.constant 128 : index
+  %pred.64 = vector.create_mask %c : vector<[2]xi1>
+  // CHECK: arm_sme.zero
+  arm_sme.zero za0d, za1d, za2d, za3d, za4d, za5d, za6d, za7d, za0s, za1s, za2s, za3s
+  // CHECK: arm_sme.mopa
+  arm_sme.mopa za0d, %pred.64, %pred.64, %0, %0 : vector<[2]xi1>, vector<[2]xi1>, vector<[2]xf64>, vector<[2]xf64>
+  %pred.32 = vector.create_mask %c : vector<[4]xi1>
+  // CHECK: arm_sme.mopa
+  arm_sme.mopa za0s, %pred.32, %pred.32, %1, %1 : vector<[4]xi1>, vector<[4]xi1>, vector<[4]xf32>, vector<[4]xf32>
+  %pred.16 = vector.create_mask %c, %c : vector<[4x2]xi1>
+  // CHECK: arm_sme.mopa
+  arm_sme.mopa za1s, %pred.16, %pred.16, %2, %2 : vector<[4x2]xi1>, vector<[4x2]xi1>, vector<[4x2]xf16>, vector<[4x2]xf16>
+  %pred.i16 = vector.create_mask %c, %c : vector<[2x4]xi1>
+  // CHECK: arm_sme.mopa
+  arm_sme.mopa za1d, %pred.i16, %pred.i16, %3, %3 : vector<[2x4]xi1>, vector<[2x4]xi1>, vector<[2x4]xsi16>, vector<[2x4]xsi16>
+  return
+}