Index: mlir/include/mlir/Conversion/Passes.td
===================================================================
--- mlir/include/mlir/Conversion/Passes.td
+++ mlir/include/mlir/Conversion/Passes.td
@@ -1092,6 +1092,10 @@
            "bool", /*default=*/"false",
            "Enables the use of ArmSVE dialect while lowering the vector "
            "dialect.">,
+    Option<"armSME", "enable-arm-sme",
+           "bool", /*default=*/"false",
+           "Enables the use of ArmSME dialect while lowering the vector "
+           "dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
Index: mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -0,0 +1,27 @@
+//===- Transforms.h - ArmSME Dialect Transformation Entrypoints -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_ARMSME_TRANSFORMS_H
+#define MLIR_DIALECT_ARMSME_TRANSFORMS_H
+
+namespace mlir {
+
+class LLVMConversionTarget;
+class RewritePatternSet;
+
+namespace arm_sme {
+void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
+} // namespace arm_sme
+
+/// Configure the target to support lowering ArmSME ops to ops that map to LLVM
+/// intrinsics.
+void configureArmSMELegalizeForExportTarget(LLVMConversionTarget &target);
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_ARMSME_TRANSFORMS_H
Index: mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
===================================================================
--- mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -15,6 +15,8 @@
   LINK_LIBS PUBLIC
   MLIRArithDialect
   MLIRArmNeonDialect
+  MLIRArmSMEDialect
+  MLIRArmSMETransforms
   MLIRArmSVEDialect
   MLIRArmSVETransforms
   MLIRAMXDialect
Index: mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
===================================================================
--- mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,8 @@
 #include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/ArmSVE/ArmSVEDialect.h"
 #include "mlir/Dialect/ArmSVE/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -49,6 +51,8 @@
       registry.insert<arm_neon::ArmNeonDialect>();
     if (armSVE)
       registry.insert<arm_sve::ArmSVEDialect>();
+    if (armSME)
+      registry.insert<arm_sme::ArmSMEDialect>();
     if (amx)
       registry.insert<amx::AMXDialect>();
     if (x86Vector)
@@ -102,6 +106,10 @@
     configureArmSVELegalizeForExportTarget(target);
     populateArmSVELegalizeForLLVMExportPatterns(converter, patterns);
   }
+  if (armSME) {
+    configureArmSMELegalizeForExportTarget(target);
+    arm_sme::populateVectorTransferLoweringPatterns(patterns);
+  }
   if (amx) {
     configureAMXLegalizeForExportTarget(target);
     populateAMXLegalizeForLLVMExportPatterns(converter, patterns);
Index: mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
+  LegalizeForLLVMExport.cpp
+  LowerVectorOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -8,6 +10,11 @@
   MLIRArmSMETransformsIncGen
 
   LINK_LIBS PUBLIC
+  MLIRArmSMEDialect
   MLIRFuncDialect
+  MLIRLLVMDialect
+  MLIRVectorDialect
+  MLIRLLVMCommonConversion
+  MLIRIR
   MLIRPass
 )
Index: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -0,0 +1,19 @@
+//===- LegalizeForLLVMExport.cpp - Prepare ArmSME for LLVM translation ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+void mlir::configureArmSMELegalizeForExportTarget(
+    LLVMConversionTarget &target) {
+  target.addLegalOp<aarch64_sme_zero>();
+}
Index: mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
@@ -0,0 +1,55 @@
+//===- LowerVectorOps.cpp - Lower vector ops to SME -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrite patterns to lower vector dialect ops to ArmSME.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+static constexpr unsigned kZeroZAMask = 255;
+
+namespace {
+/// Lower `vector.transfer_write` op to `arm_sme.intr.zero` op. Currently only
+/// supports 2d scalable vector type `vector<[16x16]xi8>` that maps to the ZA0.B
+/// SME tile. This will be extended to support more element types.
+struct TransferWriteToArmSMEZeroLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  TransferWriteToArmSMEZeroLowering(MLIRContext *context)
+      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    auto vType = write.getVectorType();
+    if (vType.getRank() != 2)
+      return failure();
+    if (vType.getShape() != ArrayRef<int64_t>({16, 16}))
+      return failure();
+    if (vType.getElementType() != rewriter.getI8Type())
+      return failure();
+    if (vType.getNumScalableDims() != 2)
+      return failure();
+    auto tile = rewriter.create<arith::ConstantOp>(
+        write.getLoc(), rewriter.getI32Type(),
+        rewriter.getI32IntegerAttr(kZeroZAMask));
+    rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_zero>(write, tile);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::arm_sme::populateVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferWriteToArmSMEZeroLowering>(patterns.getContext());
+}
Index: mlir/test/Dialect/ArmSME/vector_ops.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/ArmSME/vector_ops.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: @transfer_write_2d_zero_i8
+// CHECK: %[[C255:.*]] = arith.constant 255 : i32
+// CHECK: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
+func.func @transfer_write_2d_zero_i8() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim, %dim) : memref<?x?xi8>
+  %cst = arith.constant dense<0> : vector<[16x16]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi8>, memref<?x?xi8>
+  memref.dealloc %0 : memref<?x?xi8>
+  return
+}
+
+// -----
+
+// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
+// lowering only occurs for vector types of correct rank, shape, element size
+// and number of scalable dims.
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_type
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_type() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim, %dim) : memref<?x?xi4>
+  %cst = arith.constant dense<0> : vector<[16x16]xi4>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi4>, memref<?x?xi4>
+  memref.dealloc %0 : memref<?x?xi4>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_shape() {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c8, %vscale : index
+  %0 = memref.alloc(%dim, %dim) : memref<?x?xi8>
+  %cst = arith.constant dense<0> : vector<[8x8]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<[8x8]xi8>, memref<?x?xi8>
+  memref.dealloc %0 : memref<?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_rank() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim, %dim, %dim) : memref<?x?x?xi8>
+  %cst = arith.constant dense<0> : vector<[16x16x16]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16x16x16]xi8>, memref<?x?x?xi8>
+  memref.dealloc %0 : memref<?x?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_num_scalable_dims
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_num_scalable_dims() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim) : memref<16x?xi8>
+  %cst = arith.constant dense<0> : vector<16x[16]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<16x[16]xi8>, memref<16x?xi8>
+  memref.dealloc %0 : memref<16x?xi8>
+  return
+}