diff --git a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
--- a/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
+++ b/mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
@@ -36,6 +36,166 @@
   let dependentDialects = ["scf::SCFDialect"];
 }
 
+//===----------------------------------------------------------------------===//
+// ArmSME type definitions
+//===----------------------------------------------------------------------===//
+
+class SMETileType<Type datatype, list<int> dims, string description>
+  : ShapedContainerType<[datatype],
+      And<[IsVectorOfRankPred<[2]>, allDimsScalableVectorTypePred,
+           IsVectorOfShape<dims>]>,
+  description>;
+
+def nxnxv16i8  : SMETileType<I8,   [16, 16], "vector<[16]x[16]xi8>">;
+def nxnxv8i16  : SMETileType<I16,  [8, 8], "vector<[8]x[8]xi16>">;
+def nxnxv4i32  : SMETileType<I32,  [4, 4], "vector<[4]x[4]xi32>">;
+def nxnxv2i64  : SMETileType<I64,  [2, 2], "vector<[2]x[2]xi64>">;
+def nxnxv1i128 : SMETileType<I128, [1, 1], "vector<[1]x[1]xi128>">;
+
+def nxnxv8f16  : SMETileType<F16,  [8, 8], "vector<[8]x[8]xf16>">;
+def nxnxv8bf16 : SMETileType<BF16, [8, 8], "vector<[8]x[8]xbf16>">;
+def nxnxv4f32  : SMETileType<F32,  [4, 4], "vector<[4]x[4]xf32>">;
+def nxnxv2f64  : SMETileType<F64,  [2, 2], "vector<[2]x[2]xf64>">;
+
+def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
+                         nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;
+
+// A type constraint that verifies the bitwidth of the scalar integer returned
+// from 'arm_sme.get_tile_id' matches the element bitwidth of a "virtual tile".
+def TileElementWidthMatchesTileID : TypesMatchWith<
+  "`tile_id` has the same number of bits as elements in `vector`",
+  "vector", "tile_id",
+  "IntegerType::get("
+      "$_self.getContext(),"
+      "::llvm::isa<IntegerType>(::llvm::cast<VectorType>($_self).getElementType())"
+          "? ::llvm::cast<IntegerType>("
+                "::llvm::cast<VectorType>($_self).getElementType())"
+                ".getWidth()"
+          ": ::llvm::cast<FloatType>("
+                "::llvm::cast<VectorType>($_self).getElementType())"
+                ".getWidth())">;
+
+//===----------------------------------------------------------------------===//
+// ArmSME op definitions
+//===----------------------------------------------------------------------===//
+
+class ArmSME_Op<string mnemonic, list<Trait> traits = []> :
+  Op<ArmSME_Dialect, mnemonic, traits> {}
+
+def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure, TileElementWidthMatchesTileID]> {
+  let summary = "Cast from tile id to 2-d scalable vector type";
+  let description = [{
+    A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d
+    scalable vector type, which represents an SME "virtual tile". This would
+    normally be used when lowering operations that return "virtual tile" vector
+    types to model the output. This is required to preserve dataflow as SME
+    intrinsics have no return values.
+
+    Example:
+
+    Input:
+    ```mlir
+    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    ```
+
+    After lowering `vector.load`:
+    ```mlir
+    %tile_id = arm_sme.get_tile_id : i32
+    scf.for %vnum = %c0 to %num_vectors step %c1 {
+      // ...
+      "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+    }
+    %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    ```
+
+    In the example above, the `vector.load` can't be replaced with an SME
+    intrinsic that has no outputs since it is used by the `vector.store`.
+    However, by inserting a `cast_tile_to_vector` op after the load intrinsics
+    the `vector.load` can be replaced. This enables "local" rewrites on
+    individual vector ops, rather than "global" rewrites that would have to
+    look at the vector op uses and also lower them.
+
+    Canonicalization will look through `arm_sme.cast_tile_to_vector` and fold
+    the cast away if it comes from an `arm_sme.cast_vector_to_tile`.
+  }];
+  let arguments = (ins AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let results = (outs SMETile:$vector);
+  let assemblyFormat =
+    "$tile_id attr-dict `:` type($tile_id) `to` type($vector)";
+  let hasCanonicalizeMethod = 1;
+}
+
+def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure, TileElementWidthMatchesTileID]> {
+  let summary = "Cast from 2-d scalable vector type to tile id";
+  let description = [{
+    A `cast_vector_to_tile` operation does a cast from a 2-d scalable vector
+    type, which represents an SME "virtual tile", to a tile id. This is
+    required to preserve dataflow as the SME intrinsics have no return values.
+
+    Example:
+
+    Input:
+    ```mlir
+    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    ```
+
+    After lowering `vector.store`:
+    ```mlir
+    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
+    scf.for %vnum = %c0 to %num_vectors step %c1 {
+      // ...
+      %tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
+      "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
+    }
+    ```
+
+    Canonicalization will look through `arm_sme.cast_vector_to_tile` and fold
+    the cast away if it comes from an `arm_sme.cast_tile_to_vector`.
+  }];
+  let arguments = (ins SMETile:$vector);
+  let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let assemblyFormat =
+    "$vector attr-dict `:` type($vector) `to` type($tile_id)";
+  let hasCanonicalizeMethod = 1;
+}
+
+def GetTileID : ArmSME_Op<"get_tile_id", [Pure]> {
+  let summary = "Returns an SME \"virtual tile\" id";
+  let description = [{
+    A `get_tile_id` operation returns a scalar integer representing an SME
+    "virtual tile" id. The bitwidth of the scalar indicates the element
+    bitwidth of the "virtual tile".
+
+    The scope of a tile id is a function, and a tile id cannot be passed to or
+    returned from functions.
+
+    Example:
+    ```mlir
+    // Allocate and return an 8-bit element "virtual tile" id
+    %za0_b = arm_sme.get_tile_id : i8
+    ```
+
+    Example:
+    ```mlir
+    // Allocate and return two 16-bit element "virtual tile" ids
+    %za0_h = arm_sme.get_tile_id : i16
+    %za1_h = arm_sme.get_tile_id : i16
+    ```
+
+    Example:
+    ```mlir
+    // Allocate and return a 128-bit element "virtual tile" id
+    %za0_q = arm_sme.get_tile_id : i128
+    ```
+  }];
+
+  let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id);
+  let assemblyFormat = "attr-dict `:` type($tile_id)";
+}
+
 //===----------------------------------------------------------------------===//
 // ArmSME Intrinsic op definitions
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -308,6 +308,12 @@
 def IsScalableVectorTypePred : CPred<[{::llvm::isa<::mlir::VectorType>($_self) &&
                                    ::llvm::cast<VectorType>($_self).isScalable()}]>;
 
+// Whether a type is a VectorType and all dimensions are scalable.
+def allDimsScalableVectorTypePred : And<[
+  IsVectorTypePred,
+  CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>
+]>;
+
 // Whether a type is a TensorType.
 def IsTensorTypePred : CPred<"::llvm::isa<::mlir::TensorType>($_self)">;
 
@@ -488,6 +494,7 @@
 def I16 : I<16>;
 def I32 : I<32>;
 def I64 : I<64>;
+def I128 : I<128>;
 
 // Any signed integer type irrespective of its width.
 def AnySignedInteger : Type<
@@ -745,6 +752,10 @@
                            == }]
                          # allowedlength>)>]>;
 
+// Whether the shape of a vector matches the given `shape` list.
+class IsVectorOfShape<list<int> shape>
+  : CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(shape, ", ") # "})">;
+
 // Any vector where the number of elements is from the given
 // `allowedLengths` list
 class VectorOfLength<list<int> allowedLengths> : Type<
diff --git a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
--- a/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
+++ b/mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp
@@ -34,3 +34,35 @@
 #include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc"
       >();
 }
+
+// cast_vector_to_tile(cast_tile_to_vector(tile_id)) -> tile_id
+//
+// Always type-correct: both tile ids are constrained (by
+// TileElementWidthMatchesTileID) to the element bitwidth of the same vector.
+LogicalResult CastVectorToTile::canonicalize(CastVectorToTile op,
+                                             PatternRewriter &rewriter) {
+  if (auto castTileToVectorOp =
+          op.getVector().getDefiningOp<CastTileToVector>()) {
+    // Mutate IR through the rewriter so the pattern driver is notified.
+    rewriter.replaceOp(op, castTileToVectorOp.getTileId());
+    return success();
+  }
+  return failure();
+}
+
+// cast_tile_to_vector(cast_vector_to_tile(tile)) -> tile
+//
+// Only valid when the round trip preserves the vector type. The tile id only
+// encodes the element bitwidth, so e.g. vector<[8]x[8]xf16> -> i16 ->
+// vector<[8]x[8]xi16> passes verification but must not be folded, as the
+// replacement value would have a different type than the result it replaces.
+LogicalResult CastTileToVector::canonicalize(CastTileToVector op,
+                                             PatternRewriter &rewriter) {
+  auto castVectorToTileOp = op.getTileId().getDefiningOp<CastVectorToTile>();
+  if (!castVectorToTileOp ||
+      castVectorToTileOp.getVector().getType() != op.getVector().getType())
+    return failure();
+  // Mutate IR through the rewriter so the pattern driver is notified.
+  rewriter.replaceOp(op, castVectorToTileOp.getVector());
+  return success();
+}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -7,9 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Dialect/ArmSME/IR/ArmSME.h"
 #include "mlir/Dialect/ArmSME/Transforms/Transforms.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 
 using namespace mlir;
@@ -43,6 +45,22 @@
     return success();
   }
 };
+
+/// Lowers `arm_sme.get_tile_id` to an LLVM constant tile id.
+///
+/// The constant is built with the op's own result type (i8/i16/i32/i64/i128)
+/// so that users relying on the tile id bitwidth matching the tile element
+/// width stay type-correct after the replacement.
+struct GetTileIDConversion : public ConvertOpToLLVMPattern<GetTileID> {
+  using ConvertOpToLLVMPattern<GetTileID>::ConvertOpToLLVMPattern;
+  LogicalResult
+  matchAndRewrite(GetTileID op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // TODO: implement tile allocation, currently only tile 0 is supported.
+    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, op.getType(), 0);
+    return success();
+  }
+};
 } // namespace
 
 void mlir::populateArmSMELegalizeForLLVMExportPatterns(
@@ -52,9 +70,11 @@
 
 void mlir::configureArmSMELegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addLegalOp();
+  target.addLegalOp();
 
   // Mark 'func.func' ops as legal if either:
   //   1. no 'arm_za' function attribute is present.
diff --git a/mlir/test/Dialect/ArmSME/canonicalize.mlir b/mlir/test/Dialect/ArmSME/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/canonicalize.mlir
@@ -0,0 +1,40 @@
+// RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: @cast_vector_to_tile__cast_tile_to_vector
+// CHECK-SAME: %[[TILE_ID:.*]]: i8
+func.func @cast_vector_to_tile__cast_tile_to_vector(%tile_id_0 : i8) -> i8 {
+  // CHECK-NOT: arm_sme.cast_tile_to_vector
+  // CHECK-NOT: arm_sme.cast_vector_to_tile
+  // CHECK-NEXT: return %[[TILE_ID]] : i8
+  %tile = arm_sme.cast_tile_to_vector %tile_id_0 : i8 to vector<[16]x[16]xi8>
+  %tile_id_1 = arm_sme.cast_vector_to_tile %tile : vector<[16]x[16]xi8> to i8
+  return %tile_id_1 : i8
+}
+
+// -----
+
+// CHECK-LABEL: @cast_tile_to_vector__cast_vector_to_tile
+// CHECK-SAME: %[[TILE:.*]]: vector<[16]x[16]xi8>
+func.func @cast_tile_to_vector__cast_vector_to_tile(%tile_0 : vector<[16]x[16]xi8>) -> vector<[16]x[16]xi8> {
+  // CHECK-NOT: arm_sme.cast_vector_to_tile
+  // CHECK-NOT: arm_sme.cast_tile_to_vector
+  // CHECK-NEXT: return %[[TILE]] : vector<[16]x[16]xi8>
+  %tile_id = arm_sme.cast_vector_to_tile %tile_0 : vector<[16]x[16]xi8> to i8
+  %tile_1 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
+  return %tile_1 : vector<[16]x[16]xi8>
+}
+
+// -----
+
+// The tile id only encodes the element bitwidth, so a round trip that changes
+// the element type (f16 -> i16) must not be folded away.
+// CHECK-LABEL: @no_fold__cast_tile_to_vector__element_type_mismatch
+// CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xf16> to i16
+// CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xi16>
+func.func @no_fold__cast_tile_to_vector__element_type_mismatch(%tile_0 : vector<[8]x[8]xf16>) -> vector<[8]x[8]xi16> {
+  %tile_id = arm_sme.cast_vector_to_tile %tile_0 : vector<[8]x[8]xf16> to i16
+  %tile_1 = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xi16>
+  return %tile_1 : vector<[8]x[8]xi16>
+}
diff --git a/mlir/test/Dialect/ArmSME/invalid.mlir b/mlir/test/Dialect/ArmSME/invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/invalid.mlir
@@ -0,0 +1,73 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_tile_id_bitwidth(%tile_id : i8) -> vector<[8]x[8]xi16> {
+  // expected-error@+1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[8]x[8]xi16>
+  return %0 : vector<[8]x[8]xi16>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type_rank_1(%tile_id : i8) -> vector<[16]xi8> {
+  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8>
+  return %0 : vector<[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type_i4(%tile_id : i8) -> vector<[16]x[16]xi4> {
+  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x[16]xi4>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi4>
+  return %0 : vector<[16]x[16]xi4>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_0(%tile_id : i8) -> vector<16x[16]xi8> {
+  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<16x[16]xi8>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<16x[16]xi8>
+  return %0 : vector<16x[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector__bad_vector_type_non_scalable_dim_1(%tile_id : i8) -> vector<[16]x16xi8> {
+  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]x16xi8>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x16xi8>
+  return %0 : vector<[16]x16xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_bad_shape(%tile_id : i8) -> vector<[4]x[16]xi8> {
+  // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'}}
+  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
+  return %0 : vector<[4]x[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile__bad_tile_id_bitwidth(%vector : vector<[1]x[1]xi128>) -> i32 {
+  // expected-error@+1 {{op failed to verify that `tile_id` has the same number of bits as elements in `vector`}}
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[1]x[1]xi128> to i32
+  return %0 : i32
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 {
+  // expected-error@+1 {{op operand #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}}
+  %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8
+  return %0 : i8
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id__bad_type() -> i1 {
+  // expected-error@+1 {{op result #0 must be 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer}}
+  %0 = arm_sme.get_tile_id : i1
+  return %0 : i1
+}
diff --git a/mlir/test/Dialect/ArmSME/roundtrip.mlir b/mlir/test/Dialect/ArmSME/roundtrip.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/roundtrip.mlir
@@ -0,0 +1,185 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i8(%tile_id : i8) -> vector<[16]x[16]xi8> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8>
+  %vector = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>
+  return %vector : vector<[16]x[16]xi8>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i16(%tile_id : i16) -> vector<[8]x[8]xi16> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xi16>
+  %vector = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xi16>
+  return %vector : vector<[8]x[8]xi16>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i32(%tile_id : i32) -> vector<[4]x[4]xi32> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xi32>
+  %vector = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
+  return %vector : vector<[4]x[4]xi32>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i64(%tile_id : i64) -> vector<[2]x[2]xi64> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xi64>
+  %vector = arm_sme.cast_tile_to_vector %tile_id : i64 to vector<[2]x[2]xi64>
+  return %vector : vector<[2]x[2]xi64>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_i128(%tile_id : i128) -> vector<[1]x[1]xi128> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i128 to vector<[1]x[1]xi128>
+  %vector = arm_sme.cast_tile_to_vector %tile_id : i128 to vector<[1]x[1]xi128>
+  return %vector : vector<[1]x[1]xi128>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_f16(%tile_id : i16) -> vector<[8]x[8]xf16> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xf16>
+  %vector = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xf16>
+  return %vector : vector<[8]x[8]xf16>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_bf16(%tile_id : i16) -> vector<[8]x[8]xbf16> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i16 to vector<[8]x[8]xbf16>
+  %vector = arm_sme.cast_tile_to_vector %tile_id : i16 to vector<[8]x[8]xbf16>
+  return %vector : vector<[8]x[8]xbf16>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_f32(%tile_id : i32) -> vector<[4]x[4]xf32> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i32 to vector<[4]x[4]xf32>
+  %vector = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xf32>
+  return %vector : vector<[4]x[4]xf32>
+}
+
+// -----
+
+func.func @arm_sme_cast_tile_to_vector_f64(%tile_id : i64) -> vector<[2]x[2]xf64> {
+  // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i64 to vector<[2]x[2]xf64>
+  %vector = arm_sme.cast_tile_to_vector %tile_id : i64 to vector<[2]x[2]xf64>
+  return %vector : vector<[2]x[2]xf64>
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i8(%vector : vector<[16]x[16]xi8>) -> i8 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[16]x[16]xi8> to i8
+  %tile_id = arm_sme.cast_vector_to_tile %vector : vector<[16]x[16]xi8> to i8
+  return %tile_id : i8
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i16(%vector : vector<[8]x[8]xi16>) -> i16 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xi16> to i16
+  %tile_id = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xi16> to i16
+  return %tile_id : i16
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i32(%vector : vector<[4]x[4]xi32>) -> i32 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xi32> to i32
+  %tile_id = arm_sme.cast_vector_to_tile %vector : vector<[4]x[4]xi32> to i32
+  return %tile_id : i32
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i64(%vector : vector<[2]x[2]xi64>) -> i64 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xi64> to i64
+  %tile_id = arm_sme.cast_vector_to_tile %vector : vector<[2]x[2]xi64> to i64
+  return %tile_id : i64
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_i128(%vector : vector<[1]x[1]xi128>) -> i128 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[1]x[1]xi128> to i128
+  %tile_id = arm_sme.cast_vector_to_tile %vector : vector<[1]x[1]xi128> to i128
+  return %tile_id : i128
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_f16(%vector : vector<[8]x[8]xf16>) -> i16 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xf16> to i16
+  %tile_id = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xf16> to i16
+  return %tile_id : i16
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_bf16(%vector : vector<[8]x[8]xbf16>) -> i16 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[8]x[8]xbf16> to i16
+  %tile_id = arm_sme.cast_vector_to_tile %vector : vector<[8]x[8]xbf16> to i16
+  return %tile_id : i16
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_f32(%vector : vector<[4]x[4]xf32>) -> i32 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[4]x[4]xf32> to i32
+  %tile_id = arm_sme.cast_vector_to_tile %vector : vector<[4]x[4]xf32> to i32
+  return %tile_id : i32
+}
+
+// -----
+
+func.func @arm_sme_cast_vector_to_tile_f64(%vector : vector<[2]x[2]xf64>) -> i64 {
+  // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[2]x[2]xf64> to i64
+  %tile_id = arm_sme.cast_vector_to_tile %vector : vector<[2]x[2]xf64> to i64
+  return %tile_id : i64
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i8() -> i8 {
+  // CHECK: arm_sme.get_tile_id : i8
+  %tile_id = arm_sme.get_tile_id : i8
+  return %tile_id : i8
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i16() -> i16 {
+  // CHECK: arm_sme.get_tile_id : i16
+  %tile_id = arm_sme.get_tile_id : i16
+  return %tile_id : i16
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i32() -> i32 {
+  // CHECK: arm_sme.get_tile_id : i32
+  %tile_id = arm_sme.get_tile_id : i32
+  return %tile_id : i32
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i64() -> i64 {
+  // CHECK: arm_sme.get_tile_id : i64
+  %tile_id = arm_sme.get_tile_id : i64
+  return %tile_id : i64
+}
+
+// -----
+
+func.func @arm_sme_get_tile_id_i128() -> i128 {
+  // CHECK: arm_sme.get_tile_id : i128
+  %tile_id = arm_sme.get_tile_id : i128
+  return %tile_id : i128
+}