Index: mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
===================================================================
--- mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -15,6 +15,10 @@
 class LLVMTypeConverter;
 class RewritePatternSet;
 
+namespace arm_sme {
+void populateVectorTransferLoweringPatterns(RewritePatternSet &patterns);
+} // namespace arm_sme
+
 /// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM
 /// intrinsics.
 void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
Index: mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
===================================================================
--- mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -109,6 +109,7 @@
   if (armSME) {
     configureArmSMELegalizeForExportTarget(target);
     populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
+    arm_sme::populateVectorTransferLoweringPatterns(patterns);
   }
   if (amx) {
     configureAMXLegalizeForExportTarget(target);
Index: mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
   LegalizeForLLVMExport.cpp
+  LowerVectorOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -12,5 +13,6 @@
   MLIRArmSMEDialect
   MLIRFuncDialect
   MLIRLLVMCommonConversion
+  MLIRVectorDialect
   MLIRPass
 )
Index: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
===================================================================
--- mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -51,7 +51,7 @@
 
 void mlir::configureArmSMELegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addLegalOp<arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
+  target.addLegalOp<arm_sme::aarch64_sme_zero, arm_sme::aarch64_sme_za_enable, arm_sme::aarch64_sme_za_disable>();
 
   // Mark 'func.func' ops as legal if either:
Index: mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
@@ -0,0 +1,68 @@
+//===- LowerVectorOps.cpp - Lower vector ops to SME ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrite patterns to lower vector dialect ops to ArmSME.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+// 8-bit mask for the SME 'zero' intrinsic, one bit per 64-bit tile
+// ZA0.D-ZA7.D; 255 (0b11111111) zeroes the whole ZA array, which covers ZA0.B.
+static constexpr unsigned kZeroZAMask = 255;
+
+namespace {
+/// Lower `vector.transfer_write` op to `arm_sme.intr.zero` op. Currently only
+/// supports the 2d scalable vector type `vector<[16x16]xi8>` that maps to the
+/// ZA0.B SME tile. This will be extended to support more element types.
+struct TransferWriteToArmSMEZeroLowering
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  TransferWriteToArmSMEZeroLowering(MLIRContext *context)
+      : OpRewritePattern<vector::TransferWriteOp>(context) {}
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp write,
+                                PatternRewriter &rewriter) const override {
+    auto vType = write.getVectorType();
+    if (vType.getRank() != 2)
+      return failure();
+    if (vType.getShape() != ArrayRef<int64_t>({16, 16}))
+      return failure();
+    if (vType.getElementType() != rewriter.getI8Type())
+      return failure();
+    if (vType.getNumScalableDims() != 2)
+      return failure();
+
+    // 'arm_sme.intr.zero' zeroes the tile, so this lowering is only valid for
+    // writes of a zero splat.
+    auto constant = write.getVector().getDefiningOp<arith::ConstantOp>();
+    if (!constant)
+      return failure();
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constant.getValue());
+    if (!denseAttr || !denseAttr.isSplat() ||
+        !denseAttr.getSplatValue<APInt>().isZero())
+      return failure();
+
+    auto tile = rewriter.create<arith::ConstantOp>(
+        write.getLoc(), rewriter.getI32Type(),
+        rewriter.getI32IntegerAttr(kZeroZAMask));
+    rewriter.replaceOpWithNewOp<arm_sme::aarch64_sme_zero>(write, tile);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::arm_sme::populateVectorTransferLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<TransferWriteToArmSMEZeroLowering>(patterns.getContext());
+}
Index: mlir/test/Dialect/ArmSME/vector_ops.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/ArmSME/vector_ops.mlir
@@ -0,0 +1,88 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: @transfer_write_2d_zero_i8
+// CHECK: %[[C255:.*]] = arith.constant 255 : i32
+// CHECK: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
+func.func @transfer_write_2d_zero_i8() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim, %dim) : memref<?x?xi8>
+  %cst = arith.constant dense<0> : vector<[16x16]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi8>, memref<?x?xi8>
+  memref.dealloc %0 : memref<?x?xi8>
+  return
+}
+
+// -----
+
+// The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
+// lowering only occurs for vector types of correct rank, shape, element size
+// and number of scalable dims.
+// CHECK-LABEL: @transfer_write_2d_zero__bad_type
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_type() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim, %dim) : memref<?x?xi4>
+  %cst = arith.constant dense<0> : vector<[16x16]xi4>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi4>, memref<?x?xi4>
+  memref.dealloc %0 : memref<?x?xi4>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_shape() {
+  %c0 = arith.constant 0 : index
+  %c8 = arith.constant 8 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c8, %vscale : index
+  %0 = memref.alloc(%dim, %dim) : memref<?x?xi8>
+  %cst = arith.constant dense<0> : vector<[8x8]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<[8x8]xi8>, memref<?x?xi8>
+  memref.dealloc %0 : memref<?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_rank() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim, %dim, %dim) : memref<?x?x?xi8>
+  %cst = arith.constant dense<0> : vector<[16x16x16]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16x16x16]xi8>, memref<?x?x?xi8>
+  memref.dealloc %0 : memref<?x?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_num_scalable_dims
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_num_scalable_dims() {
+  %c0 = arith.constant 0 : index
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %dim = arith.muli %c16, %vscale : index
+  %0 = memref.alloc(%dim) : memref<16x?xi8>
+  %cst = arith.constant dense<0> : vector<16x[16]xi8>
+  vector.transfer_write %cst, %0[%c0, %c0] {in_bounds = [true, true]} : vector<16x[16]xi8>, memref<16x?xi8>
+  memref.dealloc %0 : memref<16x?xi8>
+  return
+}
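
Illustration (not part of the patch, mirroring the @transfer_write_2d_zero_i8
test above; %mem and the other SSA names are placeholders): a zero-splat write
of the ZA0.B-sized tile type, such as

  %cst = arith.constant dense<0> : vector<[16x16]xi8>
  vector.transfer_write %cst, %mem[%c0, %c0] {in_bounds = [true, true]}
      : vector<[16x16]xi8>, memref<?x?xi8>

is rewritten by -convert-vector-to-llvm="enable-arm-sme" into the tile-zeroing
intrinsic, where the mask 255 (0b11111111) selects all eight 64-bit ZA tiles:

  %c255 = arith.constant 255 : i32
  "arm_sme.intr.zero"(%c255) : (i32) -> ()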