diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -28,8 +28,7 @@ namespace { -/// Look at `vector.transfer_write` operations and convert suitable candidates -/// to ArmSME operations, e.g.: +/// Conversion pattern for vector.transfer_write. Currently only supports: /// /// %cst = arith.constant dense<0> : vector<[16]x[16]xi8> /// vector.transfer_write %cst, %arg0 : vector<[16]x[16]xi8>, memref @@ -40,6 +39,8 @@ /// arm_sme.tile_store %arg0[%c0, %c0], %0 : memref, /// vector<[16]x[16]xi8> /// +/// The conversion from arith.constant dense<0> to arm_sme.zero is done in +/// ConstantOpToArmSMELowering. struct TransferWriteToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -56,8 +57,6 @@ if (vType.getScalableDims().size() != 2) return failure(); - auto loc = writeOp.getLoc(); - if (!llvm::isa(writeOp.getSource().getType())) return failure(); @@ -69,10 +68,9 @@ if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr)) return failure(); - auto zero = rewriter.create(loc, vType); - rewriter.replaceOpWithNewOp( - writeOp, zero, writeOp.getSource(), writeOp.getIndices()); + writeOp, writeOp.getVector(), writeOp.getSource(), + writeOp.getIndices()); return success(); } }; @@ -109,10 +107,38 @@ } }; +/// Conversion pattern for dense arith.constant. 
+struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
+  using OpRewritePattern<arith::ConstantOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::ConstantOp constantOp,
+                                PatternRewriter &rewriter) const final {
+    auto vType = dyn_cast<VectorType>(constantOp.getType());
+    if (!vType)
+      return failure();
+    if (vType.getRank() != 2)
+      return failure();
+    if (vType.getShape() != ArrayRef<int64_t>({kMinNumElts, kMinNumElts}))
+      return failure();
+    if (vType.getElementType() != rewriter.getI8Type())
+      return failure();
+    // Both dims must be scalable (getScalableDims().size() is just the rank).
+    if (llvm::count(vType.getScalableDims(), true) != 2)
+      return failure();
+
+    auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+    if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<arm_sme::ZeroOp>(constantOp, vType);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                                           MLIRContext &ctx) {
   patterns.add<TransferWriteToArmSMELowering,
-               VectorStoreToArmSMELowering>(&ctx);
+               VectorStoreToArmSMELowering, ConstantOpToArmSMELowering>(&ctx);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir
@@ -1,5 +1,4 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file | mlir-opt | FileCheck %s
-
+// RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s
 
 // CHECK-LABEL: func.func @transfer_write_2d_zero(
 // CHECK-SAME: %[[ARG_0:.*]]: memref<?x?xi8>) {
@@ -16,6 +15,16 @@
 
 // -----
 
+// CHECK-LABEL: @arith_constant_dense_2d_zero_i8
+// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8>
+func.func @arith_constant_dense_2d_zero_i8() {
+  %zero = arith.constant dense<0> : vector<[16]x[16]xi8>
+  "prevent.dce"(%zero) : (vector<[16]x[16]xi8>) -> ()
+  return
+}
+
+// -----
+
 // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero'
 // lowering only occurs for vector types of correct rank, shape, element
size // and number of scalable dims. @@ -70,6 +79,7 @@ // CHECK-LABEL: @transfer_write_2d_zero__non_zero_value // CHECK: vector.transfer_write +// CHECK-NOT: arm_sme.zero // CHECK-NOT: arm_sme.tile_store func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref) { %c0 = arith.constant 0 : index