Index: mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td =================================================================== --- mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td +++ mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td @@ -36,6 +36,167 @@ let dependentDialects = ["scf::SCFDialect"]; } +//===----------------------------------------------------------------------===// +// ArmSME type definitions +//===----------------------------------------------------------------------===// + +class SMETileType<Type datatype, list<int> dims, string description> + : ShapedContainerType<[datatype], + And<[IsVectorOfRankPred<[2]>, + CPred<[{::llvm::cast<::mlir::VectorType>($_self).allDimsScalable()}]>, + CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})">]>, + description>; + +def nxnxv16i8 : SMETileType<I8, [16, 16], "vector<[16]x[16]xi8>">; +def nxnxv8i16 : SMETileType<I16, [8, 8], "vector<[8]x[8]xi16>">; +def nxnxv4i32 : SMETileType<I32, [4, 4], "vector<[4]x[4]xi32>">; +def nxnxv2i64 : SMETileType<I64, [2, 2], "vector<[2]x[2]xi64>">; +def nxnxv1i128 : SMETileType<I128, [1, 1], "vector<[1]x[1]xi128>">; + +def nxnxv8f16 : SMETileType<F16, [8, 8], "vector<[8]x[8]xf16>">; +def nxnxv8bf16 : SMETileType<BF16, [8, 8], "vector<[8]x[8]xbf16>">; +def nxnxv4f32 : SMETileType<F32, [4, 4], "vector<[4]x[4]xf32>">; +def nxnxv2f64 : SMETileType<F64, [2, 2], "vector<[2]x[2]xf64>">; + +def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128, + nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>; + +//===----------------------------------------------------------------------===// +// ArmSME op definitions +//===----------------------------------------------------------------------===// + +class ArmSME_Op<string mnemonic, list<Trait> traits = []> : + Op<ArmSME_Dialect, mnemonic, traits> {} + +def CastTileToVector : ArmSME_Op<"cast_tile_to_vector", [Pure]> { + let summary = "Cast from tile id to 2-d scalable vector type"; + let description = [{ + A `cast_tile_to_vector` operation does a cast from a tile id to a 2-d + scalable vector type, which represents an SME "virtual tile". This is used + in conjunction with `cast_vector_to_tile` to preserve dataflow and type + legality when lowering vector ops that have both inputs and outputs, to SME + intrinsics that have only inputs. 
+ + Example: + ```mlir + + // input + %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32> + vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32> + + // lower vector.load -> SME intrinsics + + %tile_id = arm_sme.get_tile_id : i32 + scf.for %vnum = %c0 to %num_vectors step %c1 { + // ... + "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () + } + %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32> + vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32> + ``` + + In the example above, the `vector.load` can't be replaced with an SME + intrinsic that has no outputs since it is used by the `vector.store`. + However, by inserting a `cast_tile_to_vector` op after the load intrinsics + the `vector.load` can be replaced. This enables "local" rewrites on + individual vector ops, rather than "global" rewrites that would have to + look at the vector op uses and also lower them. + + The opposite is true for the `vector.store`, when lowered to intrinsics + they would be preceded by a `cast_vector_to_tile` op. Once the lowering is + complete the canonicalizer will fold the casts away. The + `cast_vector_to_tile` op example shows the other half of the lowering. + + These casts are expected to be folded, but may persist if there's an + incomplete lowering where a vector op has been lowered to SME but the uses + haven't, much like if `-reconcile-unrealized-casts` fails. Currently these + cast ops cannot be lowered to LLVM, but may be in the future. 
+ }]; + let arguments = (ins AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id); + let results = (outs SMETile:$vector); + let assemblyFormat = + "$tile_id attr-dict `:` type($tile_id) `to` type($vector)"; +} + +def CastVectorToTile : ArmSME_Op<"cast_vector_to_tile", [Pure]> { + let summary = "Cast from 2-d scalable vector type to tile id"; + let description = [{ + A `cast_vector_to_tile` operation does a cast from a 2-d scalable vector + type, which represents an SME "virtual tile", to a tile id. This is used in + conjunction with `cast_tile_to_vector` to preserve dataflow and type + legality when lowering vector ops that have both inputs and outputs, to SME + intrinsics that have only inputs. + + Example: + ```mlir + + // input + %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32> + vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32> + + // lower vector.load -> SME intrinsics + %tile_id = arm_sme.get_tile_id : i32 + scf.for %vnum = %c0 to %num_vectors step %c1 { + // ... + "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () + } + %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32> + vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32> + + // lower vector.store -> SME intrinsics + %tile_id_0 = arm_sme.get_tile_id : i32 + scf.for %vnum = %c0 to %num_vectors step %c1 { + // ... + "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id_0, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () + } + %tile = arm_sme.cast_tile_to_vector %tile_id_0 : i32 to vector<[4]x[4]xi32> + scf.for %vnum = %c0 to %num_vectors step %c1 { + // ... + %tile_id_1 = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32 + "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id_1, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> () + } + + // canonicalization will look through `cast_vector_to_tile` and fold the + // cast ops away. 
+ ``` + }]; + let arguments = (ins SMETile:$vector); + let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id); + let assemblyFormat = + "$vector attr-dict `:` type($vector) `to` type($tile_id)"; + let hasCanonicalizeMethod = 1; +} + +def GetTileID : ArmSME_Op<"get_tile_id", [Pure]> { + let summary = "Returns an SME \"virtual tile\" id that is not in use"; + let description = [{ + A `get_tile_id` operation returns a scalar integer of given type + representing a tile id of an SME "virtual tile" that is not in use. + + Example: + ```mlir + // Allocate an 8-bit element ZA tile + %za0_b = arm_sme.get_tile_id : i8 + ``` + + Example: + ```mlir + // Allocate two 16-bit element ZA tiles + %za0_h = arm_sme.get_tile_id : i16 + %za1_h = arm_sme.get_tile_id : i16 + ``` + + Example: + ```mlir + // Allocate a 128-bit element ZA tile + %za0_q = arm_sme.get_tile_id : i128 + ``` + }]; + + let results = (outs AnyTypeOf<[I8, I16, I32, I64, I128]>:$tile_id); + let assemblyFormat = "attr-dict `:` type($tile_id)"; +} + //===----------------------------------------------------------------------===// // ArmSME Intrinsic op definitions //===----------------------------------------------------------------------===// Index: mlir/include/mlir/IR/OpBase.td =================================================================== --- mlir/include/mlir/IR/OpBase.td +++ mlir/include/mlir/IR/OpBase.td @@ -488,6 +488,7 @@ def I16 : I<16>; def I32 : I<32>; def I64 : I<64>; +def I128 : I<128>; // Any signed integer type irrespective of its width. 
def AnySignedInteger : Type< Index: mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp =================================================================== --- mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp +++ mlir/lib/Dialect/ArmSME/IR/ArmSME.cpp @@ -34,3 +34,17 @@ #include "mlir/Dialect/ArmSME/IR/ArmSME.cpp.inc" >(); } + +// cast_vector_to_tile(cast_tile_to_vector(tile_id)) -> tile_id +// Only folds when the operand is produced by a cast_tile_to_vector whose +// tile id has the same type as this op's result. +LogicalResult CastVectorToTile::canonicalize(CastVectorToTile op, + PatternRewriter &rewriter) { + // getDefiningOp<T>() yields null for block arguments and other producers. + auto vectorOp = op.getVector().getDefiningOp<CastTileToVector>(); + if (!vectorOp || + vectorOp.getTileId().getType() != op.getTileId().getType()) + return failure(); + rewriter.replaceOp(op, vectorOp.getTileId()); + return success(); +} Index: mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp =================================================================== --- mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -7,9 +7,11 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/ArmSME/IR/ArmSME.h" #include "mlir/Dialect/ArmSME/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" using namespace mlir; @@ -43,6 +45,18 @@ return success(); } }; + +struct GetTileIDConversion : public ConvertOpToLLVMPattern<GetTileID> { + using ConvertOpToLLVMPattern<GetTileID>::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(GetTileID op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // TODO: implement tile allocation, currently only tile 0 is supported. 
+ // Build the constant with the op's own result type (i8 .. i128), not i32. + rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, op.getType(), 0); + return success(); + } +}; } // namespace void mlir::populateArmSMELegalizeForLLVMExportPatterns( @@ -52,9 +66,11 @@ void mlir::configureArmSMELegalizeForExportTarget( LLVMConversionTarget &target) { - target.addLegalOp(); + target.addLegalOp(); // Mark 'func.func' ops as legal if either: // 1. no 'arm_za' function attribute is present. Index: mlir/test/Dialect/ArmSME/canonicalize.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/ArmSME/canonicalize.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-opt -canonicalize -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s + +// ----- + +// CHECK-LABEL: @canonicalize_casts +// CHECK-SAME: %[[TILE_ID:.*]]: i8 +func.func @canonicalize_casts(%tile_id_0 : i8) -> i8 { + // CHECK-NOT: arm_sme.cast_tile_to_vector + // CHECK-NOT: arm_sme.cast_vector_to_tile + %tile = arm_sme.cast_tile_to_vector %tile_id_0 : i8 to vector<[16]x[16]xi8> + %tile_id_1 = arm_sme.cast_vector_to_tile %tile : vector<[16]x[16]xi8> to i8 + // CHECK-NEXT: return %[[TILE_ID]] : i8 + return %tile_id_1 : i8 +} Index: mlir/test/Dialect/ArmSME/invalid.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/ArmSME/invalid.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +// ----- + +func.func @arm_sme_cast_tile_to_vector__bad_vector_type(%tile_id : i8) -> vector<[16]xi8> { + // expected-error@+1 {{op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> 
of 64-bit float values, but got 'vector<[16]xi8>'}} + %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]xi8> + return %0 : vector<[16]xi8> +} + +// ----- + +func.func @arm_sme_cast_vector_to_tile__bad_rank_1d(%vector : vector<[16]xi8>) -> i8 { + // expected-error@+1 {{op operand #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[16]xi8>'}} + %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]xi8> to i8 + return %0 : i8 +} + +// ----- + +func.func @arm_sme_get_tile_id__bad_type() -> i1 { + // expected-error@+1 {{op result #0 must be 8-bit signless integer or 16-bit signless integer or 32-bit signless integer or 64-bit signless integer or 128-bit signless integer}} + %0 = arm_sme.get_tile_id : i1 + return %0 : i1 +} Index: mlir/test/Dialect/ArmSME/roundtrip.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/ArmSME/roundtrip.mlir @@ -0,0 +1,25 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt | FileCheck %s + +// ----- + +func.func @arm_sme_cast_tile_to_vector(%tile_id : i8) -> vector<[16]x[16]xi8> { + // CHECK: arm_sme.cast_tile_to_vector {{.*}} : i8 to vector<[16]x[16]xi8> + %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8> + return %0 : vector<[16]x[16]xi8> +} + +// ----- + +func.func @arm_sme_cast_vector_to_tile(%vector : vector<[16]x[16]xi8>) -> i8 { + // CHECK: arm_sme.cast_vector_to_tile {{.*}} : vector<[16]x[16]xi8> to i8 + %0 = arm_sme.cast_vector_to_tile %vector : vector<[16]x[16]xi8> to i8 + return %0 : 
i8 +} + +// ----- + +func.func @arm_sme_get_tile_id() -> i32 { + // CHECK: arm_sme.get_tile_id : i32 + %0 = arm_sme.get_tile_id : i32 + return %0 : i32 +}