Index: mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt =================================================================== --- mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt +++ mlir/include/mlir/Dialect/ArmSME/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(IR) add_subdirectory(Transforms) Index: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h @@ -0,0 +1,37 @@ +//===- ArmSME.h - MLIR Dialect for Arm SME ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the Target dialect for ArmSME in MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_ARMSME_ARMSMEDIALECT_H +#define MLIR_DIALECT_ARMSME_ARMSMEDIALECT_H + +#include "mlir/Bytecode/BytecodeOpInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.h.inc" + +namespace mlir { +namespace arm_sme { +/// Side-effect resource for SME tiles (TileResource in ArmSME.td). +class SMETile : public SideEffects::Resource::Base<SMETile> { +public: + StringRef getName() final { return "SMETile"; } +}; +} // namespace arm_sme +} // namespace mlir + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSME/IR/ArmSME.h.inc" + +#endif // MLIR_DIALECT_ARMSME_ARMSMEDIALECT_H Index: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td @@ -0,0 +1,269 @@ +//===-- ArmSME.td - ArmSME dialect operation definitions ---*- tablegen -*-===// +// +// Part of
the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the basic operations for the ArmSME dialect. +// Currently contains the following ops: +// * zero - allocates a tile initialized with zeros +// * mopa/mops - matrix outer product accumulate/subtract +// * load.tile - allocates a tile initialized with data from memory +// * store.tile - stores tile to memory while deallocating it. +// +// There are a few outstanding TODOs: +// * Move the streaming mode function attribute into here +// * Think a bit more about how to share more functionality with vector dialect +// * Save/restore functionality for cross-function tile usage +// * Implement a more robust tile allocation scheme +// * Implement lowering of more general vector ops on SME tiles, e.g. extract +// +//===----------------------------------------------------------------------===// + +#ifndef ARMSME_OPS +#define ARMSME_OPS + +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Dialect/LLVMIR/LLVMOpBase.td" + +//===----------------------------------------------------------------------===// +// ArmSME dialect definition +//===----------------------------------------------------------------------===// + +def ArmSME_Dialect : Dialect { + let name = "arm_sme"; + let cppNamespace = "::mlir::arm_sme"; + let summary = "Basic dialect to target Arm SME architectures"; + let description = [{ + This dialect contains the definitions necessary to target specific Arm SME + scalable matrix operations. + + Source: + https://developer.arm.com/documentation/ddi0616/aa + }]; + // FIXME: Is this necessary?
+ let dependentDialects = ["arm_sve::ArmSVEDialect"]; +} + +//===----------------------------------------------------------------------===// +// ArmSME misc definitions +//===----------------------------------------------------------------------===// + +// Effects +def TileResource : Resource<"SMETile">; +def TileAlloc : MemAlloc; +def TileWrite : MemWrite; +def TileRead : MemRead; +def TileFree : MemFree; + +def Predicate : ScalableVectorOfLengthAndType<[16, 8, 4, 2], [I1]>; +def SMEVector : ScalableVectorOfLengthAndType< + [16, 8, 4, 2], [SI8, SI16, UI8, UI16, BF16, F16, F32, F64]>; +def SMETile : AllOfType<[ + ScalableVectorOf<[SI8, SI16, SI32, SI64, UI8, UI16, UI32, UI64, + BF16, F16, F32, F64]>, + ScalableVectorOfRank<[2]>, + ScalableVectorOfLength<[256, 64, 16, 4]>]>; + +//===----------------------------------------------------------------------===// +// ArmSME op definitions +//===----------------------------------------------------------------------===// + +class ArmSME_Op traits = []> : + Op {} + +def ZeroOp : ArmSME_Op<"zero"> { + let summary = "Returns a SME tile with its contents zeroed."; + let description = [{ + Allocates a tile in the background. This will impose restrictions on the + returned vector (for now): + * The vector can only be read from before being written to. After which + reading from that vector will become invalid. + * Once overwritten (e.g. mopa), the new returned tile will represent the + same physical tile.
+ }]; + let results = (outs Res: $tile); + let assemblyFormat = "attr-dict `:` type($tile)"; +} + +class MOPOpBase + : ArmSME_Op, + AllShapesMatch<["lhsPred", "lhs"]>, + AllShapesMatch<["rhsPred", "rhs"]>, + AllElementTypesMatch<["lhs", "rhs"]>, + AllElementCountsMatch<["lhs", "rhs"]>]> { + let arguments = (ins + Arg:$tile, + Predicate:$lhsPred, + Predicate:$rhsPred, + SMEVector:$lhs, + SMEVector:$rhs + ); + let results = (outs Res:$result); + let extraClassDeclaration = [{ + bool isAccumulate() { return }] # accumulate # [{; + } + bool isSubtract() { return }] # !not(accumulate) # [{; } + bool isWidening() { + // Only f32/f64 inputs map to the non-widening MOPA/MOPS forms; all other + // supported input element types (si8/ui8, si16/ui16, f16, bf16) widen. + auto elTy = + cast<VectorType>(this->getLhs().getType()).getElementType(); + return !(elTy.isF32() || elTy.isF64()); + } + }]; + let assemblyFormat =[{ $tile`,` $lhsPred`,` $rhsPred`,` $lhs`,` $rhs attr-dict + `:` custom<MOPOperandTypes>(type($tile), + type($lhsPred), type($rhsPred), type($lhs), + type($rhs)) }]; + let hasVerifier = 1; +} + +def MopaOp : MOPOpBase<"mopa", /*accumulate=*/true> { + let summary = "Vector-vector outer product and accumulate op"; + let description = [{ + MOPA: Outer product accumulate. + + This op maps to the *MOPA instructions. It takes two scalable vector + operands whose outer product is accumulated into the `$tile` operand. One + masking predicate operand is also provided for each of the vector operands, + such that elements marked inactive by a predicate will not update the + corresponding row/column of the result tile. The predicate types are not + spelled out in the assembly format; they are inferred from the vector types. + + There are two variations of MOPA instructions - widening and non-widening. + + Non-widening MOPAs will take a 1D vector of f32 or f64 as input and + accumulate into 32b and 64b tiles respectively (za*s and za*d).
+ + Widening MOPAs will pack two f16/bf16 or four (signed or unsigned) i8 + elements into a single 32b lane of the vector and accumulate into 32b tiles + (za*s), or they will pack four (signed or unsigned) i16 elements into a 64b + lane and accumulate into 64b tiles (za*d). Hence widening MOPAs will take + 2D scalable vectors as input, e.g. `<[4x2]xf16>, <[2x4]xsi16>, <[4x4]xsi8>` + + Example: Assume `vscale == 2`, `%lhs = %rhs = <1, 2, 3, 4> : vector<[2]xf64>`, + `%lhsPred = %rhsPred = <1, 1, 0, 1> : vector<[2]xi1>`, then: + ``` + %zero = arm_sme.zero : vector<[2x2]xf64> + %tile = arm_sme.mopa %zero, %lhsPred, %rhsPred, %lhs, %rhs + : vector<[2x2]xf64>, vector<[2]xf64>, vector<[2]xf64> + ``` + + Would result in `%tile` containing: + ``` + 1 2 0 4 + 2 4 0 8 + 0 0 0 0 + 4 8 0 16 + ``` + }]; +} + +def MopsOp : MOPOpBase<"mops", /*accumulate=*/false> { + let summary = "Vector-vector outer product and subtract op"; + let description = [{ + MOPS: Outer product subtract. + + Similar to the MOPA instruction, except that the outer product is subtracted + from the tile. + }]; +} + +def LoadTile : ArmSME_Op<"load.tile"> { + let summary = "Loads a 2D vector from memory into a SME tile"; + let description = [{ + This op will allocate a tile (similar to `arm_sme.zero`) in addition to + loading a tile from memory. This means the vector returned from this op + is subject to the same constraints as `arm_sme.zero`. + }]; + let arguments = (ins Arg:$base, + Variadic:$indices); + let results = (outs Res:$tile); + let assemblyFormat = "$base `[` $indices `]` attr-dict " + "`:` type($base) `,` type($tile)"; +} + +def StoreTile : ArmSME_Op<"store.tile"> { + let summary = "Stores a 2D vector into memory from a SME tile"; + let description = [{ + This op will deallocate a tile (is this necessary?) in addition to + storing the tile.
+ }]; + let arguments = (ins Arg:$tile, + Arg:$base, + Variadic:$indices); + let assemblyFormat = "$tile `,` $base `[` $indices `]` attr-dict " + "`:` type($base) `,` type($tile)"; +} + +//===----------------------------------------------------------------------===// +// ArmSME Intrinsic op definitions +//===----------------------------------------------------------------------===// + +class ArmSME_IntrOverloadedOp overloadOperands = [], + list traits = []> + : LLVM_IntrOpBase< + /*Dialect dialect=*/ArmSME_Dialect, + /*string opName=*/"intr." #mnemonic, + /*string enumName=*/"aarch64_sme_" #!subst(".", "_", mnemonic), + /*list overloadedResults=*/[], + /*list overloadedOperands=*/overloadOperands, + /*list traits=*/traits, + /*int numResults=*/0>; + +// Zero +def ZeroIntrOp : ArmSME_IntrOverloadedOp<"zero", []>, + Arguments<(ins Arg)>; + +// MOP's +class ArmSME_IntrMopOverloadedOp + : ArmSME_IntrOverloadedOp, + Arguments<(ins Arg, + Arg, + Arg, + Arg, + Arg)>; + +def FmopaIntrOp : ArmSME_IntrMopOverloadedOp<"mopa">; +def FmopsIntrOp : ArmSME_IntrMopOverloadedOp<"mops">; +def FmopaWidenIntrOp : ArmSME_IntrMopOverloadedOp<"mopa.wide">; +def FmopsWidenIntrOp : ArmSME_IntrMopOverloadedOp<"mops.wide">; +def SmopaIntrOp : ArmSME_IntrMopOverloadedOp<"smopa.wide">; +def SmopsIntrOp : ArmSME_IntrMopOverloadedOp<"smops.wide">; +def UmopaIntrOp : ArmSME_IntrMopOverloadedOp<"umopa.wide">; +def UmopsIntrOp : ArmSME_IntrMopOverloadedOp<"umops.wide">; +def SUmopaIntrOp : ArmSME_IntrMopOverloadedOp<"sumopa.wide">; +def SUmopsIntrOp : ArmSME_IntrMopOverloadedOp<"sumops.wide">; +def USmopaIntrOp : ArmSME_IntrMopOverloadedOp<"usmopa.wide">; +def USmopsIntrOp : ArmSME_IntrMopOverloadedOp<"usmops.wide">; + +// Loads +class ArmSME_IntrLoadOverloadedOp + : ArmSME_IntrOverloadedOp, + Arguments<(ins Arg, + Arg, + Arg, + Arg)>; + +def LoadHorizontalWordsIntrOp : ArmSME_IntrLoadOverloadedOp<"ld1w.horiz">; +def LoadHorizontalDoublesIntrOp : ArmSME_IntrLoadOverloadedOp<"ld1d.horiz">; +def
LoadVerticalWordsIntrOp : ArmSME_IntrLoadOverloadedOp<"ld1w.vert">; +def LoadVerticalDoublesIntrOp : ArmSME_IntrLoadOverloadedOp<"ld1d.vert">; + +// Stores +class ArmSME_IntrStoreOverloadedOp + : ArmSME_IntrOverloadedOp, + Arguments<(ins Arg, + Arg, + Arg, + Arg)>; + +def StoreHorizontalWordsIntrOp : ArmSME_IntrStoreOverloadedOp<"st1w.horiz">; +def StoreHorizontalDoublesIntrOp : ArmSME_IntrStoreOverloadedOp<"st1d.horiz">; +def StoreVerticalWordsIntrOp : ArmSME_IntrStoreOverloadedOp<"st1w.vert">; +def StoreVerticalDoublesIntrOp : ArmSME_IntrStoreOverloadedOp<"st1d.vert">; + +#endif // ARMSME_OPS Index: mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/include/mlir/Dialect/ArmSME/IR/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_dialect(ArmSME arm_sme ArmSME) +add_mlir_doc(ArmSME ArmSME Dialects/ -gen-dialect-doc -dialect=arm_sme) + +set(LLVM_TARGET_DEFINITIONS ArmSME.td) +mlir_tablegen(ArmSMEConversions.inc -gen-llvmir-conversions) +add_public_tablegen_target(MLIRArmSMEConversionsIncGen) Index: mlir/include/mlir/InitAllDialects.h =================================================================== --- mlir/include/mlir/InitAllDialects.h +++ mlir/include/mlir/InitAllDialects.h @@ -23,6 +23,7 @@ #include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" #include "mlir/Dialect/Async/IR/Async.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -117,6 +118,7 @@ pdl_interp::PDLInterpDialect, quant::QuantizationDialect, spirv::SPIRVDialect, + arm_sme::ArmSMEDialect, arm_sve::ArmSVEDialect, vector::VectorDialect, NVVM::NVVMDialect, Index: mlir/lib/Dialect/ArmSME/CMakeLists.txt =================================================================== ---
mlir/lib/Dialect/ArmSME/CMakeLists.txt +++ mlir/lib/Dialect/ArmSME/CMakeLists.txt @@ -1 +1,2 @@ +add_subdirectory(IR) add_subdirectory(Transforms) Index: mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp =================================================================== --- /dev/null +++ mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp @@ -0,0 +1,202 @@ +//===- ArmSME.cpp - MLIR ArmSME dialect implementation --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the ArmSME dialect and its operations. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSVE/ArmSVEDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Support/LogicalResult.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace mlir::arm_sme; + +//===----------------------------------------------------------------------===// +// Custom Printer/Parser/Verifiers +//===----------------------------------------------------------------------===// + +/// Custom parser for MOPA and MOPS types - Since the predicates are always +/// going to be the same shape as their respective vectors (with i1 elements), +/// they are omitted from the assembly format and reconstructed here.
+static ParseResult parseMOPOperandTypes(OpAsmParser &parser, Type &tileTy, + Type &lhsPredTy, Type &rhsPredTy, + Type &lhsTy, Type &rhsTy) { + MLIRContext *ctx = parser.getContext(); + // Parse tile type + ParseResult result = parser.parseType(tileTy); + if (result) + return result; + + result = parser.parseComma(); + if (result) + return result; + + // Parse LHS type. The location is recorded before parsing so that a + // diagnostic points at the offending type rather than the token after it. + llvm::SMLoc loc = parser.getCurrentLocation(); + result = parser.parseType(lhsTy); + if (result) + return result; + + // Construct lhsPredTy: same shape as the LHS vector, with i1 elements. + VectorType vecTy = dyn_cast<VectorType>(lhsTy); + StringRef expectVectorErr = "expected vector type"; + if (!vecTy) + return parser.emitError(loc, expectVectorErr); + lhsPredTy = vecTy.clone(IntegerType::get(ctx, 1)); + + result = parser.parseComma(); + if (result) + return result; + + // Parse RHS type (location recorded first, as for the LHS). + loc = parser.getCurrentLocation(); + result = parser.parseType(rhsTy); + if (result) + return result; + + // Construct rhsPredTy: same shape as the RHS vector, with i1 elements. + vecTy = dyn_cast<VectorType>(rhsTy); + if (!vecTy) + return parser.emitError(loc, expectVectorErr); + rhsPredTy = vecTy.clone(IntegerType::get(ctx, 1)); + + return result; +} + +/// Custom printer for MOPA and MOPS types, the inverse of +/// parseMOPOperandTypes: only the tile, LHS and RHS types are printed, since +/// the predicate types are reconstructed from the LHS/RHS types when parsing. +static void printMOPOperandTypes(OpAsmPrinter &printer, Operation *op, + Type tileTy, Type lhsPredTy, Type rhsPredTy, + Type lhsTy, Type rhsTy) { + (void)op; + (void)lhsPredTy; + (void)rhsPredTy; + printer << tileTy << ", " << lhsTy << ", " << rhsTy; +} + +/// Verifier for MOPA/MOPS, in addition to the constraints declared in ODS. +/// `isWidening` selects between the checks for the widening forms (i8, i16, +/// f16 and bf16 inputs) and the non-widening forms (f32 and f64 inputs). +static LogicalResult verifyMOPOps(Operation *op, Value tile, Value lhs, + Value rhs, bool isWidening) { + VectorType tileTy = cast<VectorType>(tile.getType()); + VectorType lhsTy = cast<VectorType>(lhs.getType()); + VectorType rhsTy = cast<VectorType>(rhs.getType()); + Type tileElTy = tileTy.getElementType(); + Type vecElTy = lhsTy.getElementType(); + + // Verify tile type + unsigned tileElWidth = tileElTy.getIntOrFloatBitWidth(); + if (tileElWidth != 32 && tileElWidth != 64) + return op->emitOpError("expected 32 or 64-bit output tiles for MOPA/MOPS"); + if (tileTy.getShape()[0] !=
tileTy.getShape()[1]) + return op->emitOpError("expecting square tiles for SME ops"); + if (tileTy.getShape()[0] != lhsTy.getShape()[0] || + tileTy.getShape()[1] != rhsTy.getShape().back()) + return op->emitOpError("invalid shapes for outer product"); + + // Integer sign check + if (vecElTy.isSignedInteger() != tileElTy.isSignedInteger() || + vecElTy.isUnsignedInteger() != tileElTy.isUnsignedInteger()) + return op->emitOpError( + "expecting tile and vector operands to have same signedness"); + + // Scalable dimension check + if (lhsTy.getNumScalableDims() != lhsTy.getRank() || + rhsTy.getNumScalableDims() != rhsTy.getRank() || + tileTy.getNumScalableDims() != tileTy.getRank()) + return op->emitOpError( + "expecting all dimensions of all operands to be scalable"); + + // Vector length check (ODS already forces LHS and RHS to match in element + // type and count, so checking the LHS covers both operands). + if (lhsTy.getNumElements() * vecElTy.getIntOrFloatBitWidth() != 128) + return op->emitOpError( + "expecting input operands to have unit vector length of 128"); + + // For widening MOPA/MOPS + if (isWidening) { + // Input vectors should be of rank 2, with specific shapes + if (lhsTy.getRank() != 2 || rhsTy.getRank() != 2) + return op->emitOpError( + "widening sme outer product instructions expects 2d input vectors"); + + if (vecElTy.isInteger(8)) { // SI8/UI8 accumulates to SI32/UI32 + if (!tileElTy.isInteger(32)) + return op->emitOpError("expecting i8 to accumulate into i32"); + + // LHS and RHS should all be <[4x4]xi8> + if (lhsTy.getShape()[0] != 4 || lhsTy != rhsTy) + return op->emitOpError( + "expecting shape of [4x4] for 8-bit integer operands"); + + } else if (vecElTy.isInteger(16)) { // SI16/UI16 accumulates to SI64/UI64 + if (!tileElTy.isInteger(64)) + return op->emitOpError("expecting i16 to accumulate into i64"); + + // LHS should be [2x4], while RHS should be [4x2] + if (lhsTy.getShape()[0] != 2 || rhsTy.getShape()[0] != 4) + return op->emitOpError("expecting outer product of LHS [2x4] and RHS " + "[4x2] for 16-bit integer operands"); + + } else { // BF16 and F16
should all accumulate to F32 + if (!tileElTy.isF32()) + return op->emitOpError("expecting f16/bf16 to accumulate into f32"); + + // LHS and RHS should be [4x2] and [2x4] respectively + if (lhsTy.getShape()[0] != 4 || rhsTy.getShape()[0] != 2) + return op->emitOpError("expecting outer product of LHS [4x2] and RHS " + "[2x4] for 16-bit floating point operands"); + } + return success(); + } // End verification of widening MOPA/MOPS + + if (lhsTy.getRank() != 1 || rhsTy.getRank() != 1) + return op->emitOpError( + "expecting non-widening MOPA/MOPS to have vector operands of rank 1"); + + if (vecElTy != tileElTy) + return op->emitOpError("expecting same input and tile element types of " + "non-widening MOPA/MOPS"); + return success(); +} + +LogicalResult MopaOp::verify() { + return verifyMOPOps(*this, getTile(), getLhs(), getRhs(), isWidening()); +} + +LogicalResult MopsOp::verify() { + return verifyMOPOps(*this, getTile(), getLhs(), getRhs(), isWidening()); +} + +//===----------------------------------------------------------------------===// +// Tablegen Definitions +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/ArmSME/IR/ArmSMEDialect.cpp.inc" + +#define GET_OP_CLASSES +#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc" + +#define GET_TYPEDEF_CLASSES +#include "mlir/Dialect/ArmSME/IR/ArmSMETypes.cpp.inc" + +void ArmSMEDialect::initialize() { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc" + >(); +} Index: mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt =================================================================== --- /dev/null +++ mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt @@ -0,0 +1,15 @@ +add_mlir_dialect_library(MLIRArmSMEDialect + ArmSME.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME + + DEPENDS + MLIRArmSMEIncGen + + LINK_LIBS PUBLIC + MLIRArmSVEDialect + MLIRIR + MLIRLLVMDialect + MLIRSideEffectInterfaces +) Index:
mlir/test/Dialect/ArmSME/roundtrip.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -0,0 +1,85 @@ +// RUN: mlir-opt %s | mlir-opt | FileCheck %s +func.func @ui8mopa(%tile : vector<[4x4]xui32>, %pred : vector<[4x4]xi1>, + %operand : vector<[4x4]xui8>) { +// CHECK: arm_sme.mopa +// CHECK-SAME: <[4x4]xui32> + %out = arm_sme.mopa %tile, %pred, %pred, %operand, %operand : + vector<[4x4]xui32>, vector<[4x4]xui8>, vector<[4x4]xui8> + return +} + +func.func @ui16mopa(%tile : vector<[2x2]xui64>, %lpred : vector<[2x4]xi1>, + %rpred : vector<[4x2]xi1>, %lhs : vector<[2x4]xui16>, + %rhs : vector<[4x2]xui16>) { +// CHECK: arm_sme.mopa +// CHECK-SAME: <[2x2]xui64> + %out = arm_sme.mopa %tile, %lpred, %rpred, %lhs, %rhs : + vector<[2x2]xui64>, vector<[2x4]xui16>, vector<[4x2]xui16> + return +} + +func.func @si8mopa(%tile : vector<[4x4]xsi32>, %pred : vector<[4x4]xi1>, + %operand : vector<[4x4]xsi8>) { +// CHECK: arm_sme.mopa +// CHECK-SAME: <[4x4]xsi32> + %out = arm_sme.mopa %tile, %pred, %pred, %operand, %operand : + vector<[4x4]xsi32>, vector<[4x4]xsi8>, vector<[4x4]xsi8> + return +} + +func.func @si16mopa(%tile : vector<[2x2]xsi64>, %lpred : vector<[2x4]xi1>, + %rpred : vector<[4x2]xi1>, %lhs : vector<[2x4]xsi16>, + %rhs : vector<[4x2]xsi16>) { +// CHECK: arm_sme.mopa +// CHECK-SAME: <[2x2]xsi64> + %out = arm_sme.mopa %tile, %lpred, %rpred, %lhs, %rhs : + vector<[2x2]xsi64>, vector<[2x4]xsi16>, vector<[4x2]xsi16> + return +} + +func.func @bf16mopa(%tile : vector<[4x4]xf32>, %lpred : vector<[4x2]xi1>, + %rpred : vector<[2x4]xi1>, %lhs : vector<[4x2]xbf16>, + %rhs : vector<[2x4]xbf16>) { +// CHECK: arm_sme.mopa +// CHECK-SAME: <[4x4]xf32> + %out = arm_sme.mopa %tile, %lpred, %rpred, %lhs, %rhs : + vector<[4x4]xf32>, vector<[4x2]xbf16>, vector<[2x4]xbf16> + return +} + +func.func @f16mopa(%tile : vector<[4x4]xf32>, %lpred : vector<[4x2]xi1>, + %rpred : vector<[2x4]xi1>, %lhs : vector<[4x2]xf16>,
+ %rhs : vector<[2x4]xf16>) { +// CHECK: arm_sme.mopa +// CHECK-SAME: <[4x4]xf32> + %out = arm_sme.mopa %tile, %lpred, %rpred, %lhs, %rhs : + vector<[4x4]xf32>, vector<[4x2]xf16>, vector<[2x4]xf16> + return +} + +func.func @f32mopa(%tile : vector<[4x4]xf32>, %pred : vector<[4]xi1>, + %operand : vector<[4]xf32>) { +// CHECK: arm_sme.mopa +// CHECK-SAME: <[4x4]xf32> + %out = arm_sme.mopa %tile, %pred, %pred, %operand, %operand : + vector<[4x4]xf32>, vector<[4]xf32>, vector<[4]xf32> + return +} + +func.func @f64mopa(%tile : vector<[2x2]xf64>, %pred : vector<[2]xi1>, + %operand : vector<[2]xf64>) { +// CHECK: arm_sme.mopa +// CHECK-SAME: <[2x2]xf64> + %out = arm_sme.mopa %tile, %pred, %pred, %operand, %operand : + vector<[2x2]xf64>, vector<[2]xf64>, vector<[2]xf64> + return +} + +func.func @f32mops(%tile : vector<[4x4]xf32>, %pred : vector<[4]xi1>, + %operand : vector<[4]xf32>) { +// CHECK: arm_sme.mops +// CHECK-SAME: <[4x4]xf32> + %out = arm_sme.mops %tile, %pred, %pred, %operand, %operand : + vector<[4x4]xf32>, vector<[4]xf32>, vector<[4]xf32> + return +} Index: mlir/test/mlir-opt/commandline.mlir =================================================================== --- mlir/test/mlir-opt/commandline.mlir +++ mlir/test/mlir-opt/commandline.mlir @@ -6,6 +6,7 @@ // CHECK-SAME: amx // CHECK-SAME: arith // CHECK-SAME: arm_neon +// CHECK-SAME: arm_sme // CHECK-SAME: arm_sve // CHECK-SAME: async // CHECK-SAME: bufferization