diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td --- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td +++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td @@ -203,19 +203,22 @@ def ZeroOp : ArmSME_Op<"zero", [Pure]> { let summary = "Initialize the two-dimensional ZA array with 0s"; - let results = (outs nxnxv16i8:$res); + let results = (outs SMETile:$res); let description = [{ Initialise ZA with 0. This operation is convenient wrapper for the SME `zero` intrinsic and instruction. - NOTE: At the moment it is assumed that the element type is `i8` and that - there's only one "virtual tile". - - Example: + Example 1: Zero an 8-bit element ZA tile. ```mlir %0 = arm_sme.zero : vector<[16]x[16]xi8> ``` + + Example 2: Zero a 64-bit element ZA tile. + + ```mlir + %0 = arm_sme.zero : vector<[2]x[2]xi64> + ``` }]; let extraClassDeclaration = [{ VectorType getVectorType() { diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp --- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp +++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp @@ -15,8 +15,6 @@ using namespace mlir; -static constexpr unsigned kMinNumElts = 16; - /// Returns true if 'val' is a splat of zero, false otherwise. static bool isSplatZero(Type elemType, DenseElementsAttr val) { if (llvm::isa(elemType)) @@ -96,15 +94,7 @@ LogicalResult matchAndRewrite(arith::ConstantOp constantOp, PatternRewriter &rewriter) const final { auto vType = dyn_cast(constantOp.getType()); - if (!vType) - return failure(); - if (vType.getRank() != 2) - return failure(); - if (vType.getShape() != ArrayRef({kMinNumElts, kMinNumElts})) - return failure(); - if (vType.getElementType() != rewriter.getI8Type()) - return failure(); - if (vType.getScalableDims().size() != 2) + if (!vType || !arm_sme::isValidSMETileVectorType(vType)) return failure(); auto denseAttr = dyn_cast(constantOp.getValueAttr()); diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir --- a/mlir/test/Dialect/ArmSME/roundtrip.mlir +++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -184,7 +184,7 @@ // ----- -func.func @arm_sme_zero() { +func.func @arm_sme_zero_i8() { // CHECK: arm_sme.zero : vector<[16]x[16]xi8> %0 = arm_sme.zero : vector<[16]x[16]xi8> return @@ -192,6 +192,70 @@ // ----- +func.func @arm_sme_zero_i16() { + // CHECK: arm_sme.zero : vector<[8]x[8]xi16> + %0 = arm_sme.zero : vector<[8]x[8]xi16> + return +} + +// ----- + +func.func @arm_sme_zero_i32() { + // CHECK: arm_sme.zero : vector<[4]x[4]xi32> + %0 = arm_sme.zero : vector<[4]x[4]xi32> + return +} + +// ----- + +func.func @arm_sme_zero_i64() { + // CHECK: arm_sme.zero : vector<[2]x[2]xi64> + %0 = arm_sme.zero : vector<[2]x[2]xi64> + return +} + +// ----- + +func.func @arm_sme_zero_i128() { + // CHECK: arm_sme.zero : vector<[1]x[1]xi128> + %0 = arm_sme.zero : vector<[1]x[1]xi128> + return +} + +// ----- + +func.func @arm_sme_zero_f16() { + // CHECK: arm_sme.zero : vector<[8]x[8]xf16> + %0 = arm_sme.zero : vector<[8]x[8]xf16> + return +} + +// ----- + +func.func @arm_sme_zero_bf16() { + // CHECK: arm_sme.zero : vector<[8]x[8]xbf16> + %0 = arm_sme.zero : vector<[8]x[8]xbf16> + return +} + +// ----- + +func.func @arm_sme_zero_f32() { + // CHECK: arm_sme.zero : vector<[4]x[4]xf32> + %0 = arm_sme.zero : vector<[4]x[4]xf32> + return +} + +// ----- + +func.func @arm_sme_zero_f64() { + // CHECK: arm_sme.zero : vector<[2]x[2]xf64> + %0 = arm_sme.zero : vector<[2]x[2]xf64> + return +} + +// ----- + func.func @arm_sme_tile_load_i8(%src : memref) { // CHECK: arm_sme.tile_load {{.*}} : memref, vector<[16]x[16]xi8> %c0 = arith.constant 0 : index diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir --- a/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir @@ -104,16 +104,6 @@ // ----- -// CHECK-LABEL: @arith_constant_dense_2d_zero_i8 -// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8> -func.func @arith_constant_dense_2d_zero_i8() { - %zero = arith.constant dense<0> : vector<[16]x[16]xi8> - "prevent.dce"(%zero) : (vector<[16]x[16]xi8>) -> () - return -} - -// ----- - // The following tests check the 'vector.transfer_write' -> 'arm_sme.intr.zero' // lowering only occurs for vector types of correct rank, shape, element size // and number of scalable dims. @@ -163,3 +153,87 @@ %0 = vector.transfer_write %cst, %arg0[%c0, %c0] {in_bounds = [true, true]} : vector<[16]x[16]xi8>, tensor return %0 : tensor } + +// ============================================================================= +// arith.constant dense<0> to arm_sme.zero +// ============================================================================= + +// ----- + +// CHECK-LABEL: @arith_constant_dense_2d_zero_i8 +// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[16]x[16]xi8> +func.func @arith_constant_dense_2d_zero_i8() { + %zero = arith.constant dense<0> : vector<[16]x[16]xi8> + "prevent.dce"(%zero) : (vector<[16]x[16]xi8>) -> () + return +} + +// ----- + +// CHECK-LABEL: @arith_constant_dense_2d_zero_i16 +// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[8]x[8]xi16> +func.func @arith_constant_dense_2d_zero_i16() { + %zero = arith.constant dense<0> : vector<[8]x[8]xi16> + "prevent.dce"(%zero) : (vector<[8]x[8]xi16>) -> () + return +} + +// ----- + +// CHECK-LABEL: @arith_constant_dense_2d_zero_i32 +// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xi32> +func.func @arith_constant_dense_2d_zero_i32() { + %zero = arith.constant dense<0> : vector<[4]x[4]xi32> + "prevent.dce"(%zero) : (vector<[4]x[4]xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: @arith_constant_dense_2d_zero_i64 +// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[2]x[2]xi64> +func.func @arith_constant_dense_2d_zero_i64() { + %zero = arith.constant dense<0> : vector<[2]x[2]xi64> + "prevent.dce"(%zero) : (vector<[2]x[2]xi64>) -> () + return +} + +// ----- + +// CHECK-LABEL: @arith_constant_dense_2d_zero_f16 +// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[8]x[8]xf16> +func.func @arith_constant_dense_2d_zero_f16() { + %zero = arith.constant dense<0.0> : vector<[8]x[8]xf16> + "prevent.dce"(%zero) : (vector<[8]x[8]xf16>) -> () + return +} + +// ----- + +// CHECK-LABEL: @arith_constant_dense_2d_zero_bf16 +// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[8]x[8]xbf16> +func.func @arith_constant_dense_2d_zero_bf16() { + %zero = arith.constant dense<0.0> : vector<[8]x[8]xbf16> + "prevent.dce"(%zero) : (vector<[8]x[8]xbf16>) -> () + return +} + +// ----- + +// CHECK-LABEL: @arith_constant_dense_2d_zero_f32 +// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[4]x[4]xf32> +func.func @arith_constant_dense_2d_zero_f32() { + %zero = arith.constant dense<0.0> : vector<[4]x[4]xf32> + "prevent.dce"(%zero) : (vector<[4]x[4]xf32>) -> () + return +} + +// ----- + +// CHECK-LABEL: @arith_constant_dense_2d_zero_f64 +// CHECK: %[[ZERO:.*]] = arm_sme.zero : vector<[2]x[2]xf64> +func.func @arith_constant_dense_2d_zero_f64() { + %zero = arith.constant dense<0.0> : vector<[2]x[2]xf64> + "prevent.dce"(%zero) : (vector<[2]x[2]xf64>) -> () + return +}