diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -30,6 +30,9 @@
 createEnableArmStreamingPass(const ArmStreaming mode = ArmStreaming::Default,
                              const bool enableZA = false);
 
+/// Pass that replaces 'arm_sme.get_tile_id' ops with constant tile ids.
+std::unique_ptr<Pass> createTileAllocationPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.td
@@ -39,4 +39,16 @@
   let dependentDialects = ["func::FuncDialect"];
 }
 
+def TileAllocation
+    : Pass<"allocate-arm-sme-tiles", "mlir::func::FuncOp"> {
+  let summary = "Allocate SME tiles";
+  let description = [{
+    This pass does tile allocation for SME "virtual tiles". It is run at the
+    'func.func' op level, replacing 'arm_sme.get_tile_id' ops with constant
+    tile ids of the same integer type. An error is emitted if no tiles remain.
+  }];
+  let constructor = "mlir::arm_sme::createTileAllocationPass()";
+  let dependentDialects = ["func::FuncDialect", "arith::ArithDialect"];
+}
+
 #endif // MLIR_DIALECT_ARMSME_TRANSFORMS_PASSES_TD
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
   EnableArmStreaming.cpp
   LegalizeForLLVMExport.cpp
+  TileAllocation.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ArmSME/Transforms
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -50,17 +50,6 @@
     return success();
   }
 };
-
-struct GetTileIDConversion : public ConvertOpToLLVMPattern<GetTileID> {
-  using ConvertOpToLLVMPattern<GetTileID>::ConvertOpToLLVMPattern;
-  LogicalResult
-  matchAndRewrite(GetTileID op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    // TODO: implement tile allocation, currently only tile 0 is supported.
-    rewriter.replaceOpWithNewOp<LLVM::ConstantOp>(op, rewriter.getI32Type(), 0);
-    return success();
-  }
-};
 } // namespace
 
 /// Lower 'arm_sme.zero'. Use 'arm_sme.cast_tile_to_vector' to model the return
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -0,0 +1,198 @@
+//===- TileAllocation.cpp - Allocate SME ZA tiles -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass allocates SME tiles at the 'func.func' op level for
+// 'arm_sme.get_tile_id' ops. It does this using a 16-bit tile mask that has a
+// bit for each 128-bit element tile (ZA0.Q-ZA15.Q), the smallest ZA tile
+// granule.
+//
+// The 128-bit tiles overlap with other element tiles as follows (see section
+// B2.3.2 of SME spec [1]):
+//
+//   Tile    Overlaps
+//   ---------------------------------------------------------------------------
+//   ZA0.B   ZA0.Q, ZA1.Q, ZA2.Q, ZA3.Q, ZA4.Q, ZA5.Q, ZA6.Q, ZA7.Q, ZA8.Q,
+//           ZA9.Q, ZA10.Q, ZA11.Q, ZA12.Q, ZA13.Q, ZA14.Q, ZA15.Q
+//   ZA0.H   ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
+//   ZA1.H   ZA1.Q, ZA3.Q, ZA5.Q, ZA7.Q, ZA9.Q, ZA11.Q, ZA13.Q, ZA15.Q
+//   ZA0.S   ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
+//   ZA1.S   ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
+//   ZA2.S   ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
+//   ZA3.S   ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
+//   ZA0.D   ZA0.Q, ZA8.Q
+//   ZA1.D   ZA1.Q, ZA9.Q
+//   ZA2.D   ZA2.Q, ZA10.Q
+//   ZA3.D   ZA3.Q, ZA11.Q
+//   ZA4.D   ZA4.Q, ZA12.Q
+//   ZA5.D   ZA5.Q, ZA13.Q
+//   ZA6.D   ZA6.Q, ZA14.Q
+//   ZA7.D   ZA7.Q, ZA15.Q
+//
+// The tiles in use are tracked via a function attribute 'arm_sme.tiles_in_use'
+// that is initialized during the first 'arm_sme.get_tile_id' rewrite and
+// updated on each subsequent rewrite.
+//
+// [1] https://developer.arm.com/documentation/ddi0616/aa
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/BitmaskEnum.h"
+
+#define DEBUG_TYPE "allocate-arm-sme-tiles"
+
+namespace mlir {
+namespace arm_sme {
+#define GEN_PASS_DEF_TILEALLOCATION
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h.inc"
+} // namespace arm_sme
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::arm_sme;
+
+namespace {
+
+/// Name of the 'func.func' attribute holding the (i32-encoded) TileMask of
+/// 128-bit tiles already allocated in that function.
+static constexpr char kTilesInUseAttr[] = "arm_sme.tiles_in_use";
+
+/// One bit per 128-bit tile: ZA0.Q is the most significant bit (0x8000) and
+/// ZA15.Q the least significant (0x1). Every other enumerator is the union of
+/// the 128-bit tiles that the named tile overlaps (see the table above).
+enum class TileMask : unsigned {
+  // clang-format off
+  kZA0B  = 0xffff, // 1111 1111 1111 1111
+
+  kZA0H  = 0xaaaa, // 1010 1010 1010 1010
+  kZA1H  = 0x5555, // 0101 0101 0101 0101
+
+  kZA0S  = 0x8888, // 1000 1000 1000 1000
+  kZA1S  = 0x4444, // 0100 0100 0100 0100
+  kZA2S  = 0x2222, // 0010 0010 0010 0010
+  kZA3S  = 0x1111, // 0001 0001 0001 0001
+
+  kZA0D  = 0x8080, // 1000 0000 1000 0000
+  kZA1D  = 0x4040, // 0100 0000 0100 0000
+  kZA2D  = 0x2020, // 0010 0000 0010 0000
+  kZA3D  = 0x1010, // 0001 0000 0001 0000
+  kZA4D  = 0x808,  // 0000 1000 0000 1000
+  kZA5D  = 0x404,  // 0000 0100 0000 0100
+  kZA6D  = 0x202,  // 0000 0010 0000 0010
+  kZA7D  = 0x101,  // 0000 0001 0000 0001
+
+  kZA0Q  = 0x8000, // 1000 0000 0000 0000
+  kZA1Q  = 0x4000, // 0100 0000 0000 0000
+  kZA2Q  = 0x2000, // 0010 0000 0000 0000
+  kZA3Q  = 0x1000, // 0001 0000 0000 0000
+  kZA4Q  = 0x800,  // 0000 1000 0000 0000
+  kZA5Q  = 0x400,  // 0000 0100 0000 0000
+  kZA6Q  = 0x200,  // 0000 0010 0000 0000
+  kZA7Q  = 0x100,  // 0000 0001 0000 0000
+  kZA8Q  = 0x80,   // 0000 0000 1000 0000
+  kZA9Q  = 0x40,   // 0000 0000 0100 0000
+  kZA10Q = 0x20,   // 0000 0000 0010 0000
+  kZA11Q = 0x10,   // 0000 0000 0001 0000
+  kZA12Q = 0x8,    // 0000 0000 0000 1000
+  kZA13Q = 0x4,    // 0000 0000 0000 0100
+  kZA14Q = 0x2,    // 0000 0000 0000 0010
+  kZA15Q = 0x1,    // 0000 0000 0000 0001
+
+  kNone  = 0x0,    // 0000 0000 0000 0000
+  // clang-format on
+
+  LLVM_MARK_AS_BITMASK_ENUM(kZA0B)
+};
+
+/// Returns the candidate masks for a tile whose element type is 'type', in
+/// tile-number order: entry i is the mask of ZA<i>.{B,H,S,D,Q}. 'type' must be
+/// an 8, 16, 32, 64 or 128-bit integer type (cast<> asserts otherwise, and
+/// any other width is unreachable).
+static ArrayRef<TileMask> getMasks(Type type) {
+  // Constant-initialized tables: no dynamic initialization, heap allocation or
+  // exit-time destructors, unlike function-local static SmallVectors.
+  static constexpr TileMask ZA_B_MASKS[] = {TileMask::kZA0B};
+  static constexpr TileMask ZA_H_MASKS[] = {TileMask::kZA0H, TileMask::kZA1H};
+  static constexpr TileMask ZA_S_MASKS[] = {TileMask::kZA0S, TileMask::kZA1S,
+                                            TileMask::kZA2S, TileMask::kZA3S};
+  static constexpr TileMask ZA_D_MASKS[] = {
+      TileMask::kZA0D, TileMask::kZA1D, TileMask::kZA2D, TileMask::kZA3D,
+      TileMask::kZA4D, TileMask::kZA5D, TileMask::kZA6D, TileMask::kZA7D};
+  static constexpr TileMask ZA_Q_MASKS[] = {
+      TileMask::kZA0Q,  TileMask::kZA1Q,  TileMask::kZA2Q,  TileMask::kZA3Q,
+      TileMask::kZA4Q,  TileMask::kZA5Q,  TileMask::kZA6Q,  TileMask::kZA7Q,
+      TileMask::kZA8Q,  TileMask::kZA9Q,  TileMask::kZA10Q, TileMask::kZA11Q,
+      TileMask::kZA12Q, TileMask::kZA13Q, TileMask::kZA14Q, TileMask::kZA15Q};
+  switch (cast<IntegerType>(type).getWidth()) {
+  default:
+    llvm_unreachable("unexpected type!");
+  case 8:
+    return ZA_B_MASKS;
+  case 16:
+    return ZA_H_MASKS;
+  case 32:
+    return ZA_S_MASKS;
+  case 64:
+    return ZA_D_MASKS;
+  case 128:
+    return ZA_Q_MASKS;
+  }
+}
+
+/// Allocates the lowest-numbered free tile for 'tileIDOp's type. On success it
+/// writes the tile number to 'tileID' and marks the tile's 128-bit tiles in
+/// 'tilesInUse'. Otherwise it emits an error on 'tileIDOp' and returns failure.
+static LogicalResult getTile(GetTileID tileIDOp, TileMask &tilesInUse,
+                             unsigned &tileID) {
+  // First-fit: take the first tile none of whose 128-bit tiles are in use.
+  auto masks = getMasks(tileIDOp.getType());
+  for (const auto &it : llvm::enumerate(masks)) {
+    const TileMask tileMask = it.value();
+    if ((tilesInUse & tileMask) == TileMask::kNone) {
+      tilesInUse |= tileMask;
+      tileID = it.index();
+      return success();
+    }
+  }
+  return tileIDOp.emitError("ran out of SME virtual tiles!");
+}
+
+/// Replaces an 'arm_sme.get_tile_id' op with an 'arith.constant' of the op's
+/// own result type holding the allocated tile number. The updated set of used
+/// tiles is stored on the enclosing 'func.func' in 'kTilesInUseAttr'.
+struct GetTileIDConversion : public OpRewritePattern<GetTileID> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GetTileID tileIDOp,
+                                PatternRewriter &rewriter) const override {
+    auto funcOp = tileIDOp->getParentOfType<func::FuncOp>();
+    // Continue from the tiles already allocated in this function, if any.
+    TileMask tilesInUse = TileMask::kNone;
+    if (auto tilesInUseAttr =
+            funcOp->getAttrOfType<IntegerAttr>(kTilesInUseAttr))
+      tilesInUse = static_cast<TileMask>(tilesInUseAttr.getInt());
+
+    unsigned tileID = 0;
+    if (failed(getTile(tileIDOp, tilesInUse, tileID)))
+      return failure();
+
+    const auto tilesInUseBits = static_cast<unsigned>(tilesInUse);
+    funcOp->setAttr(kTilesInUseAttr,
+                    rewriter.getI32IntegerAttr(tilesInUseBits));
+
+    auto tileType = tileIDOp.getType();
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(
+        tileIDOp, tileType, rewriter.getIntegerAttr(tileType, tileID));
+    return success();
+  }
+};
+
+/// Rewrites every 'arm_sme.get_tile_id' in a 'func.func' to an
+/// 'arith.constant' tile id. 'arm_sme.get_tile_id' is marked illegal, so the
+/// pass fails if any op cannot be given a tile.
+struct TileAllocationPass
+    : public arm_sme::impl::TileAllocationBase<TileAllocationPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    ConversionTarget target(getContext());
+    patterns.add<GetTileIDConversion>(patterns.getContext());
+    target.addLegalOp<arith::ConstantOp>();
+    target.addIllegalOp<GetTileID>();
+    if (failed(applyPartialConversion(getOperation(), target,
+                                      std::move(patterns))))
+      signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::arm_sme::createTileAllocationPass() {
+  return std::make_unique<TileAllocationPass>();
+}
diff --git a/mlir/test/Dialect/ArmSME/tile-allocation.mlir b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/ArmSME/tile-allocation.mlir
@@ -0,0 +1,377 @@
+// RUN: mlir-opt %s -allocate-arm-sme-tiles -split-input-file -verify-diagnostics | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: mixed_tiles
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65534 : i32}
+func.func @mixed_tiles() {
+  // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
+  // CHECK-NEXT: arith.constant 0
+  %za0_h = arm_sme.get_tile_id : i16
+  // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
+  // CHECK-NEXT: arith.constant 1
+  %za1_s = arm_sme.get_tile_id : i32
+  // ZA3.Q, ZA11.Q
+  // CHECK-NEXT: arith.constant 3
+  %za3_d = arm_sme.get_tile_id : i64
+  // ZA7.Q
+  // CHECK-NEXT: arith.constant 7
+  %za7_q = arm_sme.get_tile_id : i128
+  // ZA15.Q is still free.
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_b
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_b() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_b = arm_sme.get_tile_id : i8
+  return
+}
+
+// -----
+
+func.func @za_b__out_of_tiles() {
+  %za0_b = arm_sme.get_tile_id : i8
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i8
+  return
+}
+
+// -----
+
+func.func @za_b_overlapping_za_q() {
+  %za0_b = arm_sme.get_tile_id : i8
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i128
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za0_h
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 43690 : i32}
+func.func @za0_h() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_h = arm_sme.get_tile_id : i16
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_h
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_h() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_h = arm_sme.get_tile_id : i16
+  // CHECK-NEXT: arith.constant 1
+  %za1_h = arm_sme.get_tile_id : i16
+  return
+}
+
+// -----
+
+func.func @za_h__out_of_tiles() {
+  %za0_h = arm_sme.get_tile_id : i16
+  %za1_h = arm_sme.get_tile_id : i16
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i16
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_h_overlapping_za_s
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_h_overlapping_za_s() {
+  // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
+  // CHECK-NEXT: arith.constant 0
+  %za0_h = arm_sme.get_tile_id : i16
+  // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
+  // CHECK-NEXT: arith.constant 1
+  %za1_s = arm_sme.get_tile_id : i32
+  // ZA3.Q, ZA7.Q, ZA11.Q, ZA15.Q
+  // CHECK-NEXT: arith.constant 3
+  %za3_s = arm_sme.get_tile_id : i32
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_h_overlapping_za_d
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_h_overlapping_za_d() {
+  // ZA0.Q, ZA2.Q, ZA4.Q, ZA6.Q, ZA8.Q, ZA10.Q, ZA12.Q, ZA14.Q
+  // CHECK-NEXT: arith.constant 0
+  %za0_h = arm_sme.get_tile_id : i16
+  // ZA1.Q, ZA9.Q
+  // CHECK-NEXT: arith.constant 1
+  %za1_d = arm_sme.get_tile_id : i64
+  // ZA3.Q, ZA11.Q
+  // CHECK-NEXT: arith.constant 3
+  %za3_d = arm_sme.get_tile_id : i64
+  // ZA5.Q, ZA13.Q
+  // CHECK-NEXT: arith.constant 5
+  %za5_d = arm_sme.get_tile_id : i64
+  // ZA7.Q, ZA15.Q
+  // CHECK-NEXT: arith.constant 7
+  %za7_d = arm_sme.get_tile_id : i64
+  return
+}
+
+// -----
+
+func.func @za_h_overlapping_za_q() {
+  %za0_h = arm_sme.get_tile_id : i16
+  %za1_q = arm_sme.get_tile_id : i128
+  %za3_q = arm_sme.get_tile_id : i128
+  %za5_q = arm_sme.get_tile_id : i128
+  %za7_q = arm_sme.get_tile_id : i128
+  %za9_q = arm_sme.get_tile_id : i128
+  %za11_q = arm_sme.get_tile_id : i128
+  %za13_q = arm_sme.get_tile_id : i128
+  %za15_q = arm_sme.get_tile_id : i128
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i128
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za0_s
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 34952 : i32}
+func.func @za0_s() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_s = arm_sme.get_tile_id : i32
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_s
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_s() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: arith.constant 1
+  %za1_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: arith.constant 2
+  %za2_s = arm_sme.get_tile_id : i32
+  // CHECK-NEXT: arith.constant 3
+  %za3_s = arm_sme.get_tile_id : i32
+  return
+}
+
+// -----
+
+func.func @za_s__out_of_tiles() {
+  %za0_s = arm_sme.get_tile_id : i32
+  %za1_s = arm_sme.get_tile_id : i32
+  %za2_s = arm_sme.get_tile_id : i32
+  %za3_s = arm_sme.get_tile_id : i32
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i32
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_s_overlapping_za_d
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_s_overlapping_za_d() {
+  // ZA0.Q, ZA4.Q, ZA8.Q, ZA12.Q
+  // CHECK-NEXT: arith.constant 0
+  %za0_s = arm_sme.get_tile_id : i32
+  // ZA1.Q, ZA5.Q, ZA9.Q, ZA13.Q
+  // CHECK-NEXT: arith.constant 1
+  %za1_s = arm_sme.get_tile_id : i32
+  // ZA2.Q, ZA6.Q, ZA10.Q, ZA14.Q
+  // CHECK-NEXT: arith.constant 2
+  %za2_s = arm_sme.get_tile_id : i32
+  // ZA3.Q, ZA11.Q
+  // CHECK-NEXT: arith.constant 3
+  %za3_d = arm_sme.get_tile_id : i64
+  // ZA7.Q, ZA15.Q
+  // CHECK-NEXT: arith.constant 7
+  %za7_d = arm_sme.get_tile_id : i64
+  return
+}
+
+// -----
+
+func.func @za_s_overlapping_za_q() {
+  %za0_s = arm_sme.get_tile_id : i32
+  %za1_q = arm_sme.get_tile_id : i128
+  %za2_q = arm_sme.get_tile_id : i128
+  %za3_q = arm_sme.get_tile_id : i128
+  %za5_q = arm_sme.get_tile_id : i128
+  %za6_q = arm_sme.get_tile_id : i128
+  %za7_q = arm_sme.get_tile_id : i128
+  %za9_q = arm_sme.get_tile_id : i128
+  %za10_q = arm_sme.get_tile_id : i128
+  %za11_q = arm_sme.get_tile_id : i128
+  %za13_q = arm_sme.get_tile_id : i128
+  %za14_q = arm_sme.get_tile_id : i128
+  %za15_q = arm_sme.get_tile_id : i128
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i128
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za0_d
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 32896 : i32}
+func.func @za0_d() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_d = arm_sme.get_tile_id : i64
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_d
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_d() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 1
+  %za1_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 2
+  %za2_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 3
+  %za3_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 4
+  %za4_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 5
+  %za5_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 6
+  %za6_d = arm_sme.get_tile_id : i64
+  // CHECK-NEXT: arith.constant 7
+  %za7_d = arm_sme.get_tile_id : i64
+  return
+}
+
+// -----
+
+func.func @za_d__out_of_tiles() {
+  %za0_d = arm_sme.get_tile_id : i64
+  %za1_d = arm_sme.get_tile_id : i64
+  %za2_d = arm_sme.get_tile_id : i64
+  %za3_d = arm_sme.get_tile_id : i64
+  %za4_d = arm_sme.get_tile_id : i64
+  %za5_d = arm_sme.get_tile_id : i64
+  %za6_d = arm_sme.get_tile_id : i64
+  %za7_d = arm_sme.get_tile_id : i64
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i64
+  return
+}
+
+// -----
+
+func.func @za_d_overlapping_za_q() {
+  %za0_d = arm_sme.get_tile_id : i64
+  %za1_q = arm_sme.get_tile_id : i128
+  %za2_q = arm_sme.get_tile_id : i128
+  %za3_q = arm_sme.get_tile_id : i128
+  %za4_q = arm_sme.get_tile_id : i128
+  %za5_q = arm_sme.get_tile_id : i128
+  %za6_q = arm_sme.get_tile_id : i128
+  %za7_q = arm_sme.get_tile_id : i128
+  %za9_q = arm_sme.get_tile_id : i128
+  %za10_q = arm_sme.get_tile_id : i128
+  %za11_q = arm_sme.get_tile_id : i128
+  %za12_q = arm_sme.get_tile_id : i128
+  %za13_q = arm_sme.get_tile_id : i128
+  %za14_q = arm_sme.get_tile_id : i128
+  %za15_q = arm_sme.get_tile_id : i128
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i128
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za0_q
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 32768 : i32}
+func.func @za0_q() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_q = arm_sme.get_tile_id : i128
+  return
+}
+
+// -----
+
+// CHECK-LABEL: za_q
+// CHECK-SAME: attributes {arm_sme.tiles_in_use = 65535 : i32}
+func.func @za_q() {
+  // CHECK-NEXT: arith.constant 0
+  %za0_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 1
+  %za1_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 2
+  %za2_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 3
+  %za3_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 4
+  %za4_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 5
+  %za5_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 6
+  %za6_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 7
+  %za7_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 8
+  %za8_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 9
+  %za9_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 10
+  %za10_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 11
+  %za11_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 12
+  %za12_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 13
+  %za13_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 14
+  %za14_q = arm_sme.get_tile_id : i128
+  // CHECK-NEXT: arith.constant 15
+  %za15_q = arm_sme.get_tile_id : i128
+  return
+}
+
+// -----
+
+func.func @za_q__out_of_tiles() {
+  %za0_q = arm_sme.get_tile_id : i128
+  %za1_q = arm_sme.get_tile_id : i128
+  %za2_q = arm_sme.get_tile_id : i128
+  %za3_q = arm_sme.get_tile_id : i128
+  %za4_q = arm_sme.get_tile_id : i128
+  %za5_q = arm_sme.get_tile_id : i128
+  %za6_q = arm_sme.get_tile_id : i128
+  %za7_q = arm_sme.get_tile_id : i128
+  %za8_q = arm_sme.get_tile_id : i128
+  %za9_q = arm_sme.get_tile_id : i128
+  %za10_q = arm_sme.get_tile_id : i128
+  %za11_q = arm_sme.get_tile_id : i128
+  %za12_q = arm_sme.get_tile_id : i128
+  %za13_q = arm_sme.get_tile_id : i128
+  %za14_q = arm_sme.get_tile_id : i128
+  %za15_q = arm_sme.get_tile_id : i128
+  // expected-error@+2 {{failed to legalize operation 'arm_sme.get_tile_id' that was explicitly marked illegal}}
+  // expected-error@+1 {{ran out of SME virtual tiles!}}
+  %next_tile = arm_sme.get_tile_id : i128
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSME/vector-ops.mlir
@@ -1,5 +1,6 @@
-// RUN: mlir-opt %s -convert-vector-to-arm-sme -enable-arm-streaming="mode=locally enable-za" \
-// RUN:   -convert-vector-to-llvm="enable-arm-sme" -test-lower-to-llvm | \
+// RUN: mlir-opt %s -enable-arm-streaming="mode=locally enable-za" \
+// RUN:   -convert-vector-to-arm-sme -convert-vector-to-llvm="enable-arm-sme" \
+// RUN:   -allocate-arm-sme-tiles -test-lower-to-llvm | \
 // RUN: mlir-translate -mlir-to-llvmir | \
 // RUN: %lli_aarch64_cmd --march=aarch64 --mattr="+sve,+sme" \
 // RUN:   --entry-function=entry \