Index: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
===================================================================
--- mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
+++ mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_ARMSME_IR_ARMSME_H
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
Index: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
===================================================================
--- mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -33,6 +33,7 @@
     https://developer.arm.com/documentation/ddi0616
     https://developer.arm.com/documentation/ddi0602/2023-03/SME-Instructions
   }];
+  let dependentDialects = ["scf::SCFDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -119,6 +120,11 @@
 def LLVM_aarch64_sme_st1d_vert : ArmSME_IntrStoreOp<"st1d.vert">;
 def LLVM_aarch64_sme_st1q_vert : ArmSME_IntrStoreOp<"st1q.vert">;
 
+def LLVM_aarch64_sme_str
+    : ArmSME_IntrOp<"str">,
+      Arguments<(ins Arg<I32, "Index">,
+                 Arg<LLVM_AnyPointer, "Store address", [MemWrite]>)>;
+
 def LLVM_aarch64_sme_za_enable : ArmSME_IntrOp<"za.enable">;
 def LLVM_aarch64_sme_za_disable : ArmSME_IntrOp<"za.disable">;
Index: mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
===================================================================
--- mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
+++ mlir/include/mlir/Dialect/ArmSME/Transforms/Transforms.h
@@ -15,6 +15,11 @@
 class LLVMTypeConverter;
 class RewritePatternSet;
 
+namespace arm_sme {
+void populateVectorTransferLoweringPatterns(LLVMTypeConverter &converter,
+                                            RewritePatternSet &patterns);
+} // namespace arm_sme
+
 /// Collect a set of patterns to lower ArmSME ops to ops that map to LLVM
 /// intrinsics.
 void populateArmSMELegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
Index: mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
===================================================================
--- mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -109,6 +109,7 @@
   if (armSME) {
     configureArmSMELegalizeForExportTarget(target);
     populateArmSMELegalizeForLLVMExportPatterns(converter, patterns);
+    arm_sme::populateVectorTransferLoweringPatterns(converter, patterns);
   }
   if (amx) {
     configureAMXLegalizeForExportTarget(target);
Index: mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
+++ mlir/lib/Dialect/ArmSME/IR/CMakeLists.txt
@@ -10,5 +10,6 @@
   LINK_LIBS PUBLIC
   MLIRIR
   MLIRLLVMDialect
+  MLIRSCFDialect
   MLIRSideEffectInterfaces
   )
Index: mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
   LegalizeForLLVMExport.cpp
+  LowerVectorOps.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
@@ -12,5 +13,7 @@
   MLIRArmSMEDialect
   MLIRFuncDialect
   MLIRLLVMCommonConversion
+  MLIRVectorDialect
+  MLIRSCFDialect
   MLIRPass
   )
Index: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
===================================================================
--- mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 
 using namespace mlir;
 using namespace mlir::arm_sme;
@@ -51,7 +52,8 @@
 void mlir::configureArmSMELegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addLegalOp();
   // Mark 'func.func' ops as legal if either:
Index: mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp
@@ -0,0 +1,111 @@
+//===- LowerVectorOps.cpp - Lower vector ops to SME ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrite patterns to lower vector dialect ops to ArmSME.
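+//
+// As an illustrative sketch (not verbatim output of the pass), a transfer
+// write of a zero splat such as:
+//
+//   %zero = arith.constant dense<0> : vector<[16x16]xi8>
+//   vector.transfer_write %zero, %mem[%c0, %c0] {in_bounds = [true, true]}
+//     : vector<[16x16]xi8>, memref<?x?xi8>
+//
+// is rewritten to an 'arm_sme.intr.zero' of the whole ZA array, followed by
+// an 'scf.for' over the ZA vectors that stores each of them back to %mem via
+// 'arm_sme.intr.str'.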
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+static constexpr unsigned kZeroZAMask = 255;
+
+/// Returns true if 'val' is a splat of zero, false otherwise.
+static bool isSplatZero(Type elemType, DenseElementsAttr val) {
+  if (llvm::isa<FloatType>(elemType))
+    return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
+  if (llvm::isa<IntegerType>(elemType))
+    return val && val.isSplat() && val.getSplatValue<APInt>().isZero();
+  return false;
+}
+
+namespace {
+/// Lower `vector.transfer_write` op to `arm_sme.intr.zero` op. Currently only
+/// supports the 2d scalable vector type `vector<[16x16]xi8>` that maps to the
+/// ZA0.B SME tile. This will be extended to support more element types.
+struct TransferWriteToArmSMEZeroLowering
+    : public ConvertOpToLLVMPattern<vector::TransferWriteOp> {
+  using ConvertOpToLLVMPattern<vector::TransferWriteOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp write, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto vType = write.getVectorType();
+    if (vType.getRank() != 2)
+      return failure();
+    if (vType.getShape() != ArrayRef<int64_t>({16, 16}))
+      return failure();
+    if (vType.getElementType() != rewriter.getI8Type())
+      return failure();
+    if (vType.getNumScalableDims() != 2)
+      return failure();
+
+    auto memRefType = llvm::dyn_cast<MemRefType>(write.getSource().getType());
+    if (!memRefType)
+      return failure();
+
+    auto constant = write.getVector().getDefiningOp<arith::ConstantOp>();
+    if (!constant)
+      return failure();
+
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constant.getValueAttr());
+    if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
+      return failure();
+
+    auto loc = write.getLoc();
+
+    // Create 'arm_sme.intr.zero' intrinsic to zero ZA.
+    auto tile = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask));
+    rewriter.create<arm_sme::aarch64_sme_zero>(loc, tile);
+
+    // Create loop that iterates from 0 to SVLB-1 inclusive (the number of
+    // vectors in ZA) and stores each ZA vector to memory.
+    auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+    auto minElems = rewriter.create<arith::ConstantIndexOp>(loc, 16);
+    auto vscale =
+        rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType());
+    auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    auto upperBound = rewriter.create<arith::MulIOp>(loc, minElems, vscale);
+    auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step);
+    rewriter.setInsertionPointToStart(forOp.getBody());
+
+    // Create 'arm_sme.intr.str' intrinsic to store ZA vector.
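+    // Note: the store address below is the memref element at [vnum, 0], i.e.
+    // the aligned base plus vnum times the row stride, so each iteration of
+    // the loop writes one ZA vector to the next row of the memref.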
+    auto vnum = rewriter.create<UnrealizedConversionCastOp>(
+        loc, rewriter.getI64Type(), forOp.getInductionVar());
+    auto offset =
+        rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
+    Value ptr =
+        getStridedElementPtr(loc, memRefType, adaptor.getSource(),
+                             ValueRange{vnum.getResult(0), offset}, rewriter);
+    auto idx = rewriter.create<arith::ConstantOp>(
+        loc, rewriter.getZeroAttr(rewriter.getI32Type()));
+    rewriter.create<arm_sme::aarch64_sme_str>(loc, idx, ptr);
+
+    rewriter.eraseOp(write);
+
+    return success();
+  }
+};
+} // namespace
+
+void mlir::arm_sme::populateVectorTransferLoweringPatterns(
+    LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+  patterns.add<TransferWriteToArmSMEZeroLowering>(converter);
+}
Index: mlir/test/Dialect/ArmSME/vector_ops.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/ArmSME/vector_ops.mlir
@@ -0,0 +1,116 @@
+// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: @transfer_write_2d_zero_i8
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8>)
+// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<?x?xi8> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[C255:.*]] = arith.constant 255 : i32
+// CHECK-NEXT: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> ()
+// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index
+// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64
+// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index
+// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index
+// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] {
+// CHECK-NEXT:   %[[VNUM_I64:.*]] = builtin.unrealized_conversion_cast %[[VNUM]] : index to i64
+// CHECK-NEXT:   %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK-NEXT:   %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT:   %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK-NEXT:   %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]] : i64
+// CHECK-NEXT:   %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]] : i64
+// CHECK-NEXT:   %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8
+// CHECK-NEXT:   %[[C0_2:.*]] = arith.constant 0 : i32
+// CHECK-NEXT:   "arm_sme.intr.str"(%[[C0_2]], %[[GEP]]) : (i32, !llvm.ptr) -> ()
+func.func @transfer_write_2d_zero_i8(%arg0 : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[16x16]xi8>
+  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi8>, memref<?x?xi8>
+  return
+}
+
+// -----
+
+// The following tests check that the 'vector.transfer_write' -> 'arm_sme.intr.zero'
+// lowering only occurs for vector types of correct rank, shape, element size
+// and number of scalable dims.
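+// The only type currently lowered is vector<[16x16]xi8> written to a rank-2
+// memref of i8; every other case below must be left untouched.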
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_type
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_type(%arg0 : memref<?x?xi4>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[16x16]xi4>
+  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi4>, memref<?x?xi4>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_shape
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[8x8]xi8>
+  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[8x8]xi8>, memref<?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_rank
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref<?x?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[16x16x16]xi8>
+  vector.transfer_write %cst, %arg0[%c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<[16x16x16]xi8>, memref<?x?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__bad_num_scalable_dims
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__bad_num_scalable_dims(%arg0 : memref<16x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<16x[16]xi8>
+  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<16x[16]xi8>, memref<16x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__non_memref_type
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor<?x?xi8>) -> tensor<?x?xi8> {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0> : vector<[16x16]xi8>
+  %0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi8>, tensor<?x?xi8>
+  return %0 : tensor<?x?xi8>
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__non_zero_value
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref<?x?xi8>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<1> : vector<[16x16]xi8>
+  vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi8>, memref<?x?xi8>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @transfer_write_2d_zero__vec_unknown_defining_op
+// CHECK: vector.transfer_write
+// CHECK-NOT: arm_sme.intr.zero
+func.func @transfer_write_2d_zero__vec_unknown_defining_op(%arg0 : memref<?x?xi8>, %arg1 : vector<[16x16]xi8>) {
+  %c0 = arith.constant 0 : index
+  vector.transfer_write %arg1, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi8>, memref<?x?xi8>
+  return
+}
Index: mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector_ops.mlir
===================================================================
--- /dev/null
+++ mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector_ops.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// RUN:   -convert-vector-to-llvm="enable-arm-sme" -test-lower-to-llvm | \
+// RUN: mlir-translate -mlir-to-llvmir | \
+// RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
+// RUN:   --entry-function=entry \
+// RUN:   --dlopen=%mlir_native_utils_lib_dir/libmlir_c_runner_utils%shlibext | \
+// RUN: FileCheck %s
+
+func.func @entry() -> i32 {
+  %i0 = arith.constant 0 : i32
+  %i1 = arith.constant 1 : i8
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+
+  %c16 = arith.constant 16 : index
+  %vscale = vector.vscale
+  %svl_b = arith.muli %c16, %vscale : index
+
+  // Allocate memory and fill with ones.
+  //
+  // NOTE: this has to be written as nested loops with the upper bound scaled
+  // by vscale, since 'vector.transfer_write' doesn't support scalable vectors,
+  // there's no lowering to SME for a splat of one, and 'vector.store' also
+  // isn't working yet.
+  %za_b = memref.alloca(%svl_b, %svl_b) : memref<?x?xi8>
+  scf.for %i = %c0 to %svl_b step %c1 {
+    scf.for %j = %c0 to %svl_b step %c1 {
+      memref.store %i1, %za_b[%i, %j] : memref<?x?xi8>
+    }
+  }
+
+  // Verify memory is ones by doing a mul reduction with initial value of one.
+  %init_0 = arith.constant 1 : i64
+  %mul_reduce = scf.for %iv = %c0 to %svl_b step %c1
+      iter_args(%iter = %init_0) -> (i64) {
+    %za_b_vec = vector.load %za_b[%iv, %c0] : memref<?x?xi8>, vector<[16]xi8>
+
+    %inner_mul_reduce = scf.for %iv2 = %c0 to %svl_b step %c1
+        iter_args(%inner_iter = %init_0) -> (i64) {
+      %t = vector.extractelement %za_b_vec[%iv2 : index] : vector<[16]xi8>
+      %t_i64 = arith.extui %t : i8 to i64
+      %inner_mul_reduce_next = arith.muli %inner_iter, %t_i64 : i64
+      scf.yield %inner_mul_reduce_next : i64
+    }
+
+    %mul_reduce_next = arith.muli %iter, %inner_mul_reduce : i64
+    scf.yield %mul_reduce_next : i64
+  }
+
+  // CHECK: 1
+  vector.print %mul_reduce : i64
+
+  // This will get lowered to:
+  //
+  //   zero {za}
+  //   for vnum = 0; vnum < SVLb; ++vnum;
+  //     str za[vnum], [ptr]
+  //   ...
+  //
+  %cst_0 = arith.constant dense<0> : vector<[16x16]xi8>
+  vector.transfer_write %cst_0, %za_b[%c0, %c0] {in_bounds = [true, true]} : vector<[16x16]xi8>, memref<?x?xi8>
+
+  // Verify memory is zeroed by doing an add reduction with initial value of
+  // zero.
+  %init_1 = arith.constant 0 : i64
+  %add_reduce = scf.for %iv = %c0 to %svl_b step %c1
+      iter_args(%iter = %init_1) -> (i64) {
+    %za_b_vec = vector.load %za_b[%iv, %c0] : memref<?x?xi8>, vector<[16]xi8>
+
+    %inner_add_reduce = scf.for %iv2 = %c0 to %svl_b step %c1
+        iter_args(%inner_iter = %init_1) -> (i64) {
+      %t = vector.extractelement %za_b_vec[%iv2 : index] : vector<[16]xi8>
+      %t_i64 = arith.extui %t : i8 to i64
+      %inner_add_reduce_next = arith.addi %inner_iter, %t_i64 : i64
+      scf.yield %inner_add_reduce_next : i64
+    }
+
+    %add_reduce_next = arith.addi %iter, %inner_add_reduce : i64
+    scf.yield %add_reduce_next : i64
+  }
+
+  // CHECK-NEXT: 0
+  vector.print %add_reduce : i64
+
+  return %i0 : i32
+}