diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -238,7 +238,7 @@
-    arm_sme.tile_store %0, %arg0[%c0, %c0] : vector<[16]x[16]xi8>, memref<?x?xi8>
+    arm_sme.tile_store %0, %arg0[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
     ```
   }];
-  let arguments = (ins nxnxv16i8:$valueToStore,
+  let arguments = (ins SMETile:$valueToStore,
     Arg<AnyMemRef, "the reference to store to", [MemWrite]>:$base,
     Variadic<Index>:$indices);
   let extraClassDeclaration = [{
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
--- a/mlir/test/Dialect/ArmSME/roundtrip.mlir
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -193,10 +193,82 @@
 }
 
 // -----
 
-func.func @arm_sme_store_tile(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) -> () {
+func.func @arm_sme_tile_store_i8(%tile : vector<[16]x[16]xi8>, %dest : memref<?x?xi8>) -> () {
   // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi8>, vector<[16]x[16]xi8>
   %c0 = arith.constant 0 : index
   arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi8>, vector<[16]x[16]xi8>
   return
 }
+
+// -----
+
+func.func @arm_sme_tile_store_i16(%tile : vector<[8]x[8]xi16>, %dest : memref<?x?xi16>) -> () {
+  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi16>, vector<[8]x[8]xi16>
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi16>, vector<[8]x[8]xi16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_i32(%tile : vector<[4]x[4]xi32>, %dest : memref<?x?xi32>) -> () {
+  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi32>, vector<[4]x[4]xi32>
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi32>, vector<[4]x[4]xi32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_i64(%tile : vector<[2]x[2]xi64>, %dest : memref<?x?xi64>) -> () {
+  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi64>, vector<[2]x[2]xi64>
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi64>, vector<[2]x[2]xi64>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_i128(%tile : vector<[1]x[1]xi128>, %dest : memref<?x?xi128>) -> () {
+  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xi128>, vector<[1]x[1]xi128>
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xi128>, vector<[1]x[1]xi128>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_f16(%tile : vector<[8]x[8]xf16>, %dest : memref<?x?xf16>) -> () {
+  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf16>, vector<[8]x[8]xf16>
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf16>, vector<[8]x[8]xf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_bf16(%tile : vector<[8]x[8]xbf16>, %dest : memref<?x?xbf16>) -> () {
+  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xbf16>, vector<[8]x[8]xbf16>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_f32(%tile : vector<[4]x[4]xf32>, %dest : memref<?x?xf32>) -> () {
+  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf32>, vector<[4]x[4]xf32>
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf32>, vector<[4]x[4]xf32>
+  return
+}
+
+// -----
+
+func.func @arm_sme_tile_store_f64(%tile : vector<[2]x[2]xf64>, %dest : memref<?x?xf64>) -> () {
+  // CHECK: arm_sme.tile_store {{.*}} : memref<?x?xf64>, vector<[2]x[2]xf64>
+  %c0 = arith.constant 0 : index
+  arm_sme.tile_store %tile, %dest[%c0, %c0] : memref<?x?xf64>, vector<[2]x[2]xf64>
+  return
+}