diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h --- a/mlir/include/mlir/Conversion/Passes.h +++ b/mlir/include/mlir/Conversion/Passes.h @@ -59,6 +59,7 @@ #include "mlir/Conversion/VectorToGPU/VectorToGPU.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" +#include "mlir/Conversion/VectorToSME/VectorToSME.h" #include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h" namespace mlir { diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -1061,6 +1061,22 @@ ]; } +//===----------------------------------------------------------------------===// +// VectorToSME +//===----------------------------------------------------------------------===// + +def ConvertVectorToSME : Pass<"convert-vector-to-sme"> { + let summary = "Lower the operations from the vector dialect into the ArmSME " + "dialect"; + let description = [{ + Pass that converts vector dialect operations into equivalent ArmSME dialect + operations. + }]; + let dependentDialects = [ + "arm_sme::ArmSMEDialect" + ]; +} + //===----------------------------------------------------------------------===// // VectorToLLVM //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Conversion/VectorToSME/VectorToSME.h b/mlir/include/mlir/Conversion/VectorToSME/VectorToSME.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Conversion/VectorToSME/VectorToSME.h @@ -0,0 +1,25 @@ +//===- VectorToSME.h - Convert vector to ArmSME dialect -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_CONVERSION_VECTORTOSME_VECTORTOSME_H_ +#define MLIR_CONVERSION_VECTORTOSME_VECTORTOSME_H_ + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +class Pass; + +#define GEN_PASS_DECL_CONVERTVECTORTOSME +#include "mlir/Conversion/Passes.h.inc" + +/// Collect a set of patterns to lower Vector ops to ArmSME ops that map to LLVM +/// intrinsics. +void populateVectorToSMEPatterns(RewritePatternSet &patterns, MLIRContext &ctx); + +} // namespace mlir + +#endif // MLIR_CONVERSION_VECTORTOSME_VECTORTOSME_H_ diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td @@ -36,6 +36,48 @@ let dependentDialects = ["scf::SCFDialect"]; } +//===----------------------------------------------------------------------===// +// ArmSME custom op definitions +//===----------------------------------------------------------------------===// + +class ArmSME_Op traits = []> : + Op {} + +def ZeroOp : ArmSME_Op<"zero"> { + let summary = "initialize ZA to 0"; + let description = [{ + Initialise ZA to 0. + + Example: + + ```mlir + arm_sme.zero + ``` + }]; + let assemblyFormat = "attr-dict"; +} + +def TileStoreOp : ArmSME_Op<"tile_store"> { + let summary = "tile store operation"; + let description = [{ + Store a 2D SME tile to memory. 
+ + Example: + + ```mlir + arm_sme.tile_store %arg1[%c0, %c0] : memref + ``` + }]; + let arguments = (ins Arg:$base, + Variadic:$indices); + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return ::llvm::cast(getBase().getType()); + } + }]; + let assemblyFormat = "$base `[` $indices `]` attr-dict `:` type($base)"; +} + //===----------------------------------------------------------------------===// // ArmSME Intrinsic op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -50,3 +50,4 @@ add_subdirectory(VectorToGPU) add_subdirectory(VectorToSCF) add_subdirectory(VectorToSPIRV) +add_subdirectory(VectorToSME) diff --git a/mlir/lib/Conversion/VectorToSME/CMakeLists.txt b/mlir/lib/Conversion/VectorToSME/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/VectorToSME/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_conversion_library(MLIRVectorToSME + VectorToSME.cpp + VectorToSMEPass.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToSME + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRArmSMEDialect + MLIRTransforms + ) diff --git a/mlir/lib/Conversion/VectorToSME/VectorToSME.cpp b/mlir/lib/Conversion/VectorToSME/VectorToSME.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/VectorToSME/VectorToSME.cpp @@ -0,0 +1,92 @@ +//===- VectorToSME.cpp - Conversion from Vector to the SME dialect --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/VectorToSME/VectorToSME.h" + +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Target/LLVMIR/TypeToLLVM.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/Casting.h" +#include + +using namespace mlir; +using namespace mlir::vector; + +static constexpr unsigned kMinNumElts = 16; + +/// Returns true if 'val' is a splat of zero, false otherwise. +static bool isSplatZero(Type elemType, DenseElementsAttr val) { + if (llvm::isa(elemType)) + return val && val.isSplat() && val.getSplatValue().isZero(); + if (llvm::isa(elemType)) + return val && val.isSplat() && val.getSplatValue().isZero(); + return false; +} + +/// Look at `vector.transfer_write` operations and convert suitable candidates +/// to ArmSME operations, e.g.: +/// +/// %cst = arith.constant dense<0> : vector<[16]x[16]xi8> +/// vector.transfer_write %cst, %arg0 : vector<[16]x[16]xi8>, memref +/// +/// is converted to: +/// +/// arm_sme.zero +/// arm_sme.tile_store %arg0[%c0, %c0] : memref +struct TransferWriteToArmSME + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, + PatternRewriter &rewriter) const final { + auto vType = writeOp.getVectorType(); + if (vType.getRank() != 2) + return failure(); + if (vType.getShape() != ArrayRef({kMinNumElts, kMinNumElts})) + return failure(); + if (vType.getElementType() != rewriter.getI8Type()) + return failure(); + if (vType.getScalableDims().size() != 2) + return failure(); + + auto loc = writeOp.getLoc(); + + auto memRefType = 
llvm::dyn_cast(writeOp.getSource().getType()); + if (!memRefType) + return failure(); + + auto constant = writeOp.getVector().getDefiningOp(); + if (!constant) + return failure(); + + auto denseAttr = dyn_cast(constant.getValueAttr()); + if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr)) + return failure(); + + auto zero = rewriter.create(loc); + (void)zero; + + rewriter.create(loc, writeOp.getSource(), + writeOp.getIndices()); + + rewriter.eraseOp(constant); + rewriter.eraseOp(writeOp); + + return success(); + } +}; + +void mlir::populateVectorToSMEPatterns(RewritePatternSet &patterns, + MLIRContext &ctx) { + patterns.add(&ctx); +} diff --git a/mlir/lib/Conversion/VectorToSME/VectorToSMEPass.cpp b/mlir/lib/Conversion/VectorToSME/VectorToSMEPass.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/VectorToSME/VectorToSMEPass.cpp @@ -0,0 +1,60 @@ +//===- VectorToSMEPass.cpp - Conversion from Vector to the SME dialect ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/VectorToSME/VectorToSME.h" + +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSME/Transforms/Transforms.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +namespace mlir { +#define GEN_PASS_DEF_CONVERTVECTORTOSME +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; +using namespace mlir::vector; + +namespace { +struct LowerVectorToSMEPass + : public impl::ConvertVectorToSMEBase { + + using Base::Base; + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + void runOnOperation() override; +}; +} // namespace + +void LowerVectorToSMEPass::runOnOperation() { + // Convert eligible vector dialect ops to ArmSME dialect ops. 
+ RewritePatternSet patterns(&getContext()); + + LLVMConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + + populateVectorToSMEPatterns(patterns, getContext()); + + if (failed( + applyPartialConversion(getOperation(), target, std::move(patterns)))) + signalPassFailure(); +} diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -7,14 +7,21 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" using namespace mlir; using namespace mlir::arm_sme; +static constexpr unsigned kMinNumElts = 16; +static constexpr unsigned kZeroZAMask = 255; + namespace { /// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func' /// ops to enable the ZA storage array. 
@@ -45,11 +52,6 @@ }; } // namespace -void mlir::populateArmSMELegalizeForLLVMExportPatterns( - LLVMTypeConverter &converter, RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); -} - void mlir::configureArmSMELegalizeForExportTarget( LLVMConversionTarget &target) { target.addLegalOphasAttr("arm_za") || hasDisableZA; }); } + +struct ZeroOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(ZeroOp zero, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = zero.getLoc(); + + // Create 'arm_sme.intr.zero' intrinsic to zero ZA. + auto tile = rewriter.create( + loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask)); + rewriter.create(loc, tile); + + rewriter.eraseOp(zero); + return success(); + } +}; + +struct TileStoreOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(TileStoreOp store, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto memRefType = llvm::dyn_cast(store.getMemRefType()); + if (!memRefType) + return failure(); + + auto loc = store.getLoc(); + + // Create loop that iterates from 0 to SVLB-1 inclusive (the number of + // vectors in ZA) and stores each ZA vector to memory. + auto step = rewriter.create(loc, 1); + auto minElems = rewriter.create(loc, kMinNumElts); + auto vscale = + rewriter.create(loc, rewriter.getIndexType()); + auto lowerBound = rewriter.create(loc, 0); + auto upperBound = rewriter.create(loc, minElems, vscale); + auto forOp = rewriter.create(loc, lowerBound, upperBound, step); + rewriter.setInsertionPointToStart(forOp.getBody()); + + // Create 'arm_sme.intr.str' intrinsic to store ZA vector. 
+ auto vnumI64 = rewriter.create( + loc, rewriter.getI64Type(), forOp.getInductionVar()); + auto offset = + rewriter.create(loc, rewriter.getI64Type(), 0); + Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(), + ValueRange{vnumI64, offset}, rewriter); + auto vnumI32 = rewriter.create( + loc, rewriter.getI32Type(), forOp.getInductionVar()); + rewriter.create(loc, vnumI32, ptr); + + rewriter.eraseOp(store); + return success(); + } +}; + +void mlir::populateArmSMELegalizeForLLVMExportPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); + patterns.add(converter); +} diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp --- a/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LowerVectorOps.cpp @@ -75,28 +75,8 @@ auto tile = rewriter.create( loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask)); rewriter.create(loc, tile); - - // Create loop that iterates from 0 to SVLB-1 inclusive (the number of - // vectors in ZA) and stores each ZA vector to memory. - auto step = rewriter.create(loc, 1); - auto minElems = rewriter.create(loc, kMinNumElts); - auto vscale = - rewriter.create(loc, rewriter.getIndexType()); - auto lowerBound = rewriter.create(loc, 0); - auto upperBound = rewriter.create(loc, minElems, vscale); - auto forOp = rewriter.create(loc, lowerBound, upperBound, step); - rewriter.setInsertionPointToStart(forOp.getBody()); - - // Create 'arm_sme.intr.str' intrinsic to store ZA vector. 
- auto vnumI64 = rewriter.create( - loc, rewriter.getI64Type(), forOp.getInductionVar()); - auto offset = - rewriter.create(loc, rewriter.getI64Type(), 0); - Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getSource(), - ValueRange{vnumI64, offset}, rewriter); - auto vnumI32 = rewriter.create( - loc, rewriter.getI32Type(), forOp.getInductionVar()); - rewriter.create(loc, vnumI32, ptr); + rewriter.create(loc, write.getSource(), + write.getIndices()); rewriter.eraseOp(write); diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir @@ -0,0 +1,29 @@ +// RUN: mlir-opt %s -convert-vector-to-sme -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s + +// CHECK-LABEL: @transfer_write_2d_zero_i8 +// CHECK-SAME: %[[ARG0:.*]]: memref) +// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK: %[[C255:.*]] = arith.constant 255 : i32 +// CHECK-NEXT: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> () +// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index +// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index +// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 +// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index +// CHECK-NEXT: %[[C0_0:.*]] = arith.constant 0 : index +// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index +// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { +// CHECK-NEXT: %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64 +// CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64 +// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// 
CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> +// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]] : i64 +// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]] : i64 +// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 +// CHECK-NEXT: %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index to i32 +// CHECK-NEXT: "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> () +func.func @transfer_write_2d_zero_i8(%arg0 : memref) { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0> : vector<[16]x[16]xi8> + vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref + return +} diff --git a/mlir/test/Dialect/ArmSME/vector-ops.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir rename from mlir/test/Dialect/ArmSME/vector-ops.mlir rename to mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir --- a/mlir/test/Dialect/ArmSME/vector-ops.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir @@ -1,27 +1,14 @@ -// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" -split-input-file | mlir-opt | FileCheck %s +// RUN: mlir-opt %s -convert-vector-to-sme -split-input-file | mlir-opt | FileCheck %s -// CHECK-LABEL: @transfer_write_2d_zero_i8 -// CHECK-SAME: %[[ARG0:.*]]: memref) -// CHECK-NEXT: %[[MEM_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK: %[[C255:.*]] = arith.constant 255 : i32 -// CHECK-NEXT: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> () -// CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index -// CHECK-NEXT: %[[MIN_ZA_VECTORS:.*]] = arith.constant 16 : index -// CHECK-NEXT: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 -// CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index -// CHECK-NEXT: 
%[[C0_0:.*]] = arith.constant 0 : index -// CHECK-NEXT: %[[NUM_ZA_VECTORS:.*]] = arith.muli %[[MIN_ZA_VECTORS]], %[[VSCALE_IDX]] : index -// CHECK-NEXT: scf.for %[[VNUM:.*]] = %[[C0_0]] to %[[NUM_ZA_VECTORS]] step %[[C1]] { -// CHECK-NEXT: %[[VNUM_I64:.*]] = arith.index_castui %[[VNUM]] : index to i64 -// CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64 -// CHECK-NEXT: %[[ALIGNED_BASE:.*]] = llvm.extractvalue %[[MEM_DESC]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEM_DESC]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> -// CHECK-NEXT: %[[OFF0:.*]] = llvm.mul %[[VNUM_I64]], %[[STRIDE0]] : i64 -// CHECK-NEXT: %[[OFF1:.*]] = llvm.add %[[OFF0]], %[[C0_1]] : i64 -// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALIGNED_BASE]]{{\[}}%[[OFF1]]] : (!llvm.ptr, i64) -> !llvm.ptr, i8 -// CHECK-NEXT: %[[VNUM_I32:.*]] = arith.index_castui %[[VNUM]] : index to i32 -// CHECK-NEXT: "arm_sme.intr.str"(%[[VNUM_I32]], %[[GEP]]) : (i32, !llvm.ptr) -> () -func.func @transfer_write_2d_zero_i8(%arg0 : memref) { + +// CHECK-LABEL: @transfer_write_2d_zero +// CHECK: vector.transfer_write +func.func @transfer_write_2d_zero(%arg0 : memref) { +// CHECK-LABEL: func.func @transfer_write_2d_zero( +// CHECK-SAME: %[[VAL_0:.*]]: memref) { +// CHECK: %[[VAL_1:.*]] = arith.constant 0 : index +// CHECK: arm_sme.zero +// CHECK: arm_sme.tile_store %[[VAL_0]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : memref %c0 = arith.constant 0 : index %cst = arith.constant dense<0> : vector<[16]x[16]xi8> vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref @@ -36,7 +23,6 @@ // CHECK-LABEL: @transfer_write_2d_zero__bad_type // CHECK: vector.transfer_write -// CHECK-NOT: arm_sme.intr.zero func.func @transfer_write_2d_zero__bad_type(%arg0 : memref) { %c0 = arith.constant 0 : index %cst = arith.constant dense<0> : vector<[16]x[16]xi4> @@ -48,7 +34,7 @@ // CHECK-LABEL: 
@transfer_write_2d_zero__bad_shape // CHECK: vector.transfer_write -// CHECK-NOT: arm_sme.intr.zero +// CHECK-NOT: arm_sme.tile_store func.func @transfer_write_2d_zero__bad_shape(%arg0 : memref) { %c0 = arith.constant 0 : index %cst = arith.constant dense<0> : vector<[8]x[8]xi8> @@ -60,7 +46,7 @@ // CHECK-LABEL: @transfer_write_2d_zero__bad_rank // CHECK: vector.transfer_write -// CHECK-NOT: arm_sme.intr.zero +// CHECK-NOT: arm_sme.tile_store func.func @transfer_write_2d_zero__bad_rank(%arg0 : memref) { %c0 = arith.constant 0 : index %cst = arith.constant dense<0> : vector<[16]x[16]x[16]xi8> @@ -72,7 +58,7 @@ // CHECK-LABEL: @transfer_write_2d_zero__non_memref_type // CHECK: vector.transfer_write -// CHECK-NOT: arm_sme.intr.zero +// CHECK-NOT: arm_sme.tile_store func.func @transfer_write_2d_zero__non_memref_type(%arg0 : tensor) -> tensor { %c0 = arith.constant 0 : index %cst = arith.constant dense<0> : vector<[16]x[16]xi8> @@ -84,7 +70,7 @@ // CHECK-LABEL: @transfer_write_2d_zero__non_zero_value // CHECK: vector.transfer_write -// CHECK-NOT: arm_sme.intr.zero +// CHECK-NOT: arm_sme.tile_store func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref) { %c0 = arith.constant 0 : index %cst = arith.constant dense<1> : vector<[16]x[16]xi8> @@ -96,7 +82,7 @@ // CHECK-LABEL: @transfer_write_2d_zero__vec_unknown_defining_op // CHECK: vector.transfer_write -// CHECK-NOT: arm_sme.intr.zero +// CHECK-NOT: arm_sme.tile_store func.func @transfer_write_2d_zero__vec_unknown_defining_op(%arg0 : memref, %arg1 : vector<[16]x[16]xi8>) { %c0 = arith.constant 0 : index vector.transfer_write %arg1, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref