Index: mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp =================================================================== --- mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -28,19 +28,15 @@ namespace { -/// Conversion pattern for vector.transfer_write. Currently only supports: +/// Conversion pattern for vector.transfer_write. /// -/// %cst = arith.constant dense<0> : vector<[16]x[16]xi8> -/// vector.transfer_write %cst, %arg0 : vector<[16]x[16]xi8>, memref +/// vector.transfer_write %vector, %source[%c0, %c0] : vector<[16]x[16]xi8>, +/// memref /// /// is converted to: /// -/// %0 = arm_sme.zero : vector<[16]x[16]xi8> -/// arm_sme.tile_store %arg0[%c0, %c0], %0 : memref, -/// vector<[16]x[16]xi8> -/// -/// The conversion from arith.constant dense<0> to arm_sme.zero is done in -/// ConstantOpToArmSMELowering. +/// arm_sme.tile_store %vector, %source[%c0, %c0] : memref, +/// vector<[16]x[16]xi8> struct TransferWriteToArmSMELowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -48,26 +44,12 @@ LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const final { auto vType = writeOp.getVectorType(); - if (vType.getRank() != 2) - return failure(); - if (vType.getShape() != ArrayRef({kMinNumElts, kMinNumElts})) - return failure(); - if (vType.getElementType() != rewriter.getI8Type()) - return failure(); - if (vType.getScalableDims().size() != 2) + if (!arm_sme::isValidSMETileVectorType(vType)) return failure(); if (!llvm::isa(writeOp.getSource().getType())) return failure(); - auto constant = writeOp.getVector().getDefiningOp(); - if (!constant) - return failure(); - - auto denseAttr = dyn_cast(constant.getValueAttr()); - if (!denseAttr || !isSplatZero(vType.getElementType(), denseAttr)) - return failure(); - rewriter.replaceOpWithNewOp( writeOp, writeOp.getVector(), writeOp.getSource(), writeOp.getIndices()); Index: mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir =================================================================== --- mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir +++ mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir @@ -1,15 +1,104 @@ // RUN: mlir-opt %s -convert-vector-to-arm-sme -split-input-file -allow-unregistered-dialect | FileCheck %s -// CHECK-LABEL: func.func @transfer_write_2d_zero( -// CHECK-SAME: %[[ARG_0:.*]]: memref) { -func.func @transfer_write_2d_zero(%arg0 : memref) { -// CHECK: %[[C_0:.*]] = arith.constant 0 : index -// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8> -// CHECK: arm_sme.tile_store %[[ZERO]], %[[ARG_0]][%[[C_0]], %[[C_0]]] : memref, vector<[16]x[16]xi8> -// CHECK: return +// CHECK-LABEL: func.func @transfer_write_2d_i8( +// CHECK-SAME: %[[VECTOR:.*]]: vector<[16]x[16]xi8>, +// CHECK-SAME: %[[DEST:.*]]: memref) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref, vector<[16]x[16]xi8> +func.func @transfer_write_2d_i8(%vector : vector<[16]x[16]xi8>, %dest : memref) { %c0 = arith.constant 0 : index - %cst = arith.constant dense<0> : vector<[16]x[16]xi8> - vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref + vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref + return +} + +// ----- + +// CHECK-LABEL: func.func @transfer_write_2d_i16( +// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xi16>, +// CHECK-SAME: %[[DEST:.*]]: memref) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref, vector<[8]x[8]xi16> +func.func @transfer_write_2d_i16(%vector : vector<[8]x[8]xi16>, %dest : memref) { + %c0 = arith.constant 0 : index + vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xi16>, memref + return +} + +// ----- + +// CHECK-LABEL: func.func @transfer_write_2d_i32( +// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xi32>, +// CHECK-SAME: %[[DEST:.*]]: memref) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref, vector<[4]x[4]xi32> +func.func @transfer_write_2d_i32(%vector : vector<[4]x[4]xi32>, %dest : memref) { + %c0 = arith.constant 0 : index + vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xi32>, memref + return +} + +// ----- + +// CHECK-LABEL: func.func @transfer_write_2d_i64( +// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xi64>, +// CHECK-SAME: %[[DEST:.*]]: memref) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref, vector<[2]x[2]xi64> +func.func @transfer_write_2d_i64(%vector : vector<[2]x[2]xi64>, %dest : memref) { + %c0 = arith.constant 0 : index + vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[2]x[2]xi64>, memref + return +} + +// ----- + +// CHECK-LABEL: func.func @transfer_write_2d_f16( +// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xf16>, +// CHECK-SAME: %[[DEST:.*]]: memref) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref, vector<[8]x[8]xf16> +func.func @transfer_write_2d_f16(%vector : vector<[8]x[8]xf16>, %dest : memref) { + %c0 = arith.constant 0 : index + vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xf16>, memref + return +} + +// ----- + +// CHECK-LABEL: func.func @transfer_write_2d_bf16( +// CHECK-SAME: %[[VECTOR:.*]]: vector<[8]x[8]xbf16>, +// CHECK-SAME: %[[DEST:.*]]: memref) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref, vector<[8]x[8]xbf16> +func.func @transfer_write_2d_bf16(%vector : vector<[8]x[8]xbf16>, %dest : memref) { + %c0 = arith.constant 0 : index + vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[8]x[8]xbf16>, memref + return +} + +// ----- + +// CHECK-LABEL: func.func @transfer_write_2d_f32( +// CHECK-SAME: %[[VECTOR:.*]]: vector<[4]x[4]xf32>, +// CHECK-SAME: %[[DEST:.*]]: memref) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref, vector<[4]x[4]xf32> +func.func @transfer_write_2d_f32(%vector : vector<[4]x[4]xf32>, %dest : memref) { + %c0 = arith.constant 0 : index + vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[4]x[4]xf32>, memref + return +} + +// ----- + +// CHECK-LABEL: func.func @transfer_write_2d_f64( +// CHECK-SAME: %[[VECTOR:.*]]: vector<[2]x[2]xf64>, +// CHECK-SAME: %[[DEST:.*]]: memref) { +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: arm_sme.tile_store %[[VECTOR]], %[[DEST]]{{\[}}%[[C0]], %[[C0]]] : memref, vector<[2]x[2]xf64> +func.func @transfer_write_2d_f64(%vector : vector<[2]x[2]xf64>, %dest : memref) { + %c0 = arith.constant 0 : index + vector.transfer_write %vector, %dest[%c0, %c0] {in_bounds = [true, true]} : vector<[2]x[2]xf64>, memref return } @@ -74,27 +163,3 @@ %0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor return %0 : tensor } - -// ----- - -// CHECK-LABEL: @transfer_write_2d_zero__non_zero_value -// CHECK: vector.transfer_write -// CHECK-NOT: arm_sme.zero -// CHECK-NOT: arm_sme.tile_store -func.func @transfer_write_2d_zero__non_zero_value(%arg0 : memref) { - %c0 = arith.constant 0 : index - %cst = arith.constant dense<1> : vector<[16]x[16]xi8> - vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref - return -} - -// ----- - -// CHECK-LABEL: @transfer_write_2d_zero__vec_unknown_defining_op -// CHECK: vector.transfer_write -// CHECK-NOT: arm_sme.tile_store -func.func @transfer_write_2d_zero__vec_unknown_defining_op(%arg0 : memref, %arg1 : vector<[16]x[16]xi8>) { - %c0 = arith.constant 0 : index - vector.transfer_write %arg1, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, memref - return -}