diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -20,8 +20,6 @@ using namespace mlir; using namespace mlir::arm_sme; -static constexpr unsigned kZeroZAMask = 255; - namespace { /// Insert 'llvm.aarch64.sme.za.enable' intrinsic at the start of 'func.func' /// ops to enable the ZA storage array. @@ -51,21 +49,41 @@ } }; -/// Lower 'arm_sme.zero'. Use 'arm_sme.cast_tile_to_vector' to model the return -/// value. The latter is a nop, which should be folded away (e.g. during -/// canonicalisation). +/// Zero-extends or truncates `tile`, which should be an `arm_sme::GetTileID` or +/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar +/// integer, to an i32 that can be passed as the `tile` parameter to the SME +/// intrinsics. Returns `tile` unchanged if it is already i32. +Value castTileIDToI32(Value tile, Location loc, + ConversionPatternRewriter &rewriter) { + assert((isa( + tile.getDefiningOp())) && + "expected ArmSME GetTileID or CastVectorToTile op!"); + unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth(); + if (tileElementWidth < 32) + return rewriter.create(loc, rewriter.getI32Type(), tile); + if (tileElementWidth > 32) + return rewriter.create(loc, rewriter.getI32Type(), tile); + return tile; +} + +/// Lower 'arm_sme.zero' to SME intrinsics.
/// /// BEFORE: /// ```mlir -/// %0 = arm_sme.zero : vector<[16]x[16]xi8> +/// %v = arm_sme.zero : vector<[4]x[4]xi32> /// ``` /// /// AFTER: /// ```mlir -/// %1 = arm_sme.get_tile_id : i8 -/// %2 = arm_sme.cast_tile_to_vector %1 : i8 to vector<[16]x[16]xi8> -/// "arm_sme.intr.zero"(%c255_i32) : (i32) -> () +/// %tile_id = arm_sme.get_tile_id : i32 +/// %zero_mask = arith.shli %c17_i32, %tile_id : i32 +/// "arm_sme.intr.zero"(%zero_mask) : (i32) -> () +/// %v = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32> /// ``` +/// +/// The 'arith.shli' (which generates the mask) folds to a constant after tile +/// allocation and canonicalization. 'arm_sme.cast_tile_to_vector' models the +/// return; it is a nop but may remain if its result is used by non-SME ops. struct ZeroOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -75,18 +93,50 @@ auto loc = zero.getLoc(); // Get Tile ID for the `zero` intrinsic. - // TODO: Map this to a valid `mask` for the `zero` intrinsic. auto tileId = rewriter.create( loc, zero.getVectorType().getElementType()); - // Create 'arm_sme.intr.zero' intrinsic to zero ZA. - // FIXME: Replace the hard-coded mask with a valid value based - // on `tileId`. - auto mask = rewriter.create( - loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(kZeroZAMask)); - rewriter.create(loc, mask); - - // Create `CastTileToVectorOp` to use it as the output + auto tileElementWidth = tileId.getType().getIntOrFloatBitWidth(); + + // Get the base mask for the tile based on its element size. The base + // mask zeroes the first tile of that size (i.e. tile ID 0). + // These masks are derived from: + // https://developer.arm.com/documentation/ddi0602/2022-06/SME-Instructions/ZERO--Zero-a-list-of-64-bit-element-ZA-tiles- + auto baseMaskForSize = [&] { + switch (tileElementWidth) { + case 8: + // Zeroing the 8-bit ZA0.B tile is equivalent to zeroing all eight + // 64-bit element tiles named ZA0.D to ZA7.D.
+ return 0b1111'1111; + case 16: + // Zeroing the 16-bit ZA0.H tile is equivalent to zeroing 64-bit element + // tiles named ZA0.D, ZA2.D, ZA4.D, and ZA6.D. + // Shift this left once for ZA1.H. + return 0b0101'0101; + case 32: + // Zeroing the 32-bit ZA0.S tile is equivalent to zeroing 64-bit + // element tiles named ZA0.D and ZA4.D. + // Shift left by 1, 2, or 3 respectively for ZA1.S, ZA2.S, ZA3.S. + return 0b0001'0001; + case 64: + // Zeroing one of the 64-bit tiles ZA0.D to ZA7.D just requires setting + // its bit. Any other width (e.g. 128-bit) reaches llvm_unreachable below. + return 0b0000'0001; + default: + llvm_unreachable("bad element size"); + } + }(); + auto maskType = rewriter.getI32Type(); + auto baseMask = rewriter.create( + loc, maskType, rewriter.getIntegerAttr(maskType, baseMaskForSize)); + + // The actual mask is the base mask shifted left by the tile ID (as i32). + // This folds to a constant once tile allocation fixes the tile ID. + auto tileMask = rewriter.create( + loc, baseMask, castTileIDToI32(tileId, loc, rewriter)); + rewriter.create(loc, tileMask); + + // Create `CastTileToVectorOp` to use as the output. rewriter.replaceOpWithNewOp(zero, zero.getType(), tileId); @@ -94,23 +144,6 @@ } }; -/// Extends or truncates `tile`, which should be an `arm_sme::GetTileID` or -/// `arm_sme::CastVectorToTile` op returning an 8/16/32/64/128-bit scalar -/// integer, to an i32 that can be passed as the `tile` parameter to the SME -/// intrinsics. Or returns `tile` if already i32. -Value castTileIDToI32(Value tile, Location loc, - ConversionPatternRewriter &rewriter) { - assert((isa( - tile.getDefiningOp())) && - "expected ArmSME GetTileID or CastVectorToTile op!"); - unsigned tileElementWidth = tile.getType().getIntOrFloatBitWidth(); - if (tileElementWidth < 32) - return rewriter.create(loc, rewriter.getI32Type(), tile); - if (tileElementWidth > 32) - return rewriter.create(loc, rewriter.getI32Type(), tile); - return tile; -} - /// Lower `arm_sme.load_tile_slice` to SME intrinsics.
struct LoadTileSliceToArmSMELowering : public ConvertOpToLLVMPattern { diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir @@ -0,0 +1,129 @@ +// RUN: mlir-opt %s -convert-vector-to-llvm="enable-arm-sme" \ +// RUN: -allocate-arm-sme-tiles -canonicalize \ +// RUN: -allow-unregistered-dialect -split-input-file \ +// RUN: | FileCheck %s + +// ----- + +// CHECK-LABEL: zero_za_b +func.func @zero_za_b() { + // CHECK-DAG: %[[TILE_ID:.*]] = arith.constant 0 : i8 + // CHECK-DAG: %[[ZERO_MASK:.*]] = arith.constant 255 : i32 + + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA0B:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID]] : i8 to vector<[16]x[16]xi8> + %zero_za0b = arm_sme.zero : vector<[16]x[16]xi8> + "prevent.dce"(%zero_za0b) : (vector<[16]x[16]xi8>) -> () + return +} + +// ----- + +// CHECK-LABEL: zero_za_h +func.func @zero_za_h() { + // CHECK-DAG: %[[TILE_ID_ZA0H:.*]] = arith.constant 0 : i16 + // CHECK-DAG: %[[TILE_ID_ZA1H:.*]] = arith.constant 1 : i16 + + // CHECK-DAG: %[[ZERO_MASK_ZA0H:.*]] = arith.constant 85 : i32 + // CHECK-DAG: %[[ZERO_MASK_ZA1H:.*]] = arith.constant 170 : i32 + + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0H]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA0H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0H]] : i16 to vector<[8]x[8]xi16> + %zero_za0h = arm_sme.zero : vector<[8]x[8]xi16> + "prevent.dce"(%zero_za0h) : (vector<[8]x[8]xi16>) -> () + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1H]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xi16> + %zero_za1h = arm_sme.zero : vector<[8]x[8]xi16> + "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xi16>) -> () + return +} + +// ----- + +// CHECK-LABEL: zero_za_s +func.func @zero_za_s() { + // CHECK-DAG: %[[TILE_ID_ZA0S:.*]] = arith.constant 0 : i32 + // CHECK-DAG:
%[[TILE_ID_ZA1S:.*]] = arith.constant 1 : i32 + // CHECK-DAG: %[[TILE_ID_ZA2S:.*]] = arith.constant 2 : i32 + // CHECK-DAG: %[[TILE_ID_ZA3S:.*]] = arith.constant 3 : i32 + + // CHECK-DAG: %[[ZERO_MASK_ZA0S:.*]] = arith.constant 17 : i32 + // CHECK-DAG: %[[ZERO_MASK_ZA1S:.*]] = arith.constant 34 : i32 + // CHECK-DAG: %[[ZERO_MASK_ZA2S:.*]] = arith.constant 68 : i32 + // CHECK-DAG: %[[ZERO_MASK_ZA3S:.*]] = arith.constant 136 : i32 + + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0S]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA0S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0S]] : i32 to vector<[4]x[4]xi32> + %zero_za0s = arm_sme.zero : vector<[4]x[4]xi32> + "prevent.dce"(%zero_za0s) : (vector<[4]x[4]xi32>) -> () + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1S]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA1S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1S]] : i32 to vector<[4]x[4]xi32> + %zero_za1s = arm_sme.zero : vector<[4]x[4]xi32> + "prevent.dce"(%zero_za1s) : (vector<[4]x[4]xi32>) -> () + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2S]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA2S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2S]] : i32 to vector<[4]x[4]xi32> + %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32> + "prevent.dce"(%zero_za2s) : (vector<[4]x[4]xi32>) -> () + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3S]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xi32> + %zero_za3s = arm_sme.zero : vector<[4]x[4]xi32> + "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xi32>) -> () + return +} + +// ----- + +// CHECK-LABEL: zero_za_d +func.func @zero_za_d() { + // CHECK-DAG: %[[TILE_ID_ZA0D:.*]] = arith.constant 0 : i64 + // CHECK-DAG: %[[TILE_ID_ZA1D:.*]] = arith.constant 1 : i64 + // CHECK-DAG: %[[TILE_ID_ZA2D:.*]] = arith.constant 2 : i64 + // CHECK-DAG: %[[TILE_ID_ZA3D:.*]] = arith.constant 3 : i64 + // CHECK-DAG: %[[TILE_ID_ZA4D:.*]] = arith.constant 4 : i64 + // CHECK-DAG:
%[[TILE_ID_ZA5D:.*]] = arith.constant 5 : i64 + // CHECK-DAG: %[[TILE_ID_ZA6D:.*]] = arith.constant 6 : i64 + // CHECK-DAG: %[[TILE_ID_ZA7D:.*]] = arith.constant 7 : i64 + + // CHECK-DAG: %[[ZERO_MASK_ZA0D:.*]] = arith.constant 1 : i32 + // CHECK-DAG: %[[ZERO_MASK_ZA1D:.*]] = arith.constant 2 : i32 + // CHECK-DAG: %[[ZERO_MASK_ZA2D:.*]] = arith.constant 4 : i32 + // CHECK-DAG: %[[ZERO_MASK_ZA3D:.*]] = arith.constant 8 : i32 + // CHECK-DAG: %[[ZERO_MASK_ZA4D:.*]] = arith.constant 16 : i32 + // CHECK-DAG: %[[ZERO_MASK_ZA5D:.*]] = arith.constant 32 : i32 + // CHECK-DAG: %[[ZERO_MASK_ZA6D:.*]] = arith.constant 64 : i32 + // CHECK-DAG: %[[ZERO_MASK_ZA7D:.*]] = arith.constant 128 : i32 + + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA0D]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA0D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA0D]] : i64 to vector<[2]x[2]xi64> + %zero_za0d = arm_sme.zero : vector<[2]x[2]xi64> + "prevent.dce"(%zero_za0d) : (vector<[2]x[2]xi64>) -> () + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1D]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA1D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1D]] : i64 to vector<[2]x[2]xi64> + %zero_za1d = arm_sme.zero : vector<[2]x[2]xi64> + "prevent.dce"(%zero_za1d) : (vector<[2]x[2]xi64>) -> () + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA2D]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA2D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA2D]] : i64 to vector<[2]x[2]xi64> + %zero_za2d = arm_sme.zero : vector<[2]x[2]xi64> + "prevent.dce"(%zero_za2d) : (vector<[2]x[2]xi64>) -> () + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3D]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA3D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3D]] : i64 to vector<[2]x[2]xi64> + %zero_za3d = arm_sme.zero : vector<[2]x[2]xi64> + "prevent.dce"(%zero_za3d) : (vector<[2]x[2]xi64>) -> () + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA4D]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA4D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA4D]] : i64 to
vector<[2]x[2]xi64> + %zero_za4d = arm_sme.zero : vector<[2]x[2]xi64> + "prevent.dce"(%zero_za4d) : (vector<[2]x[2]xi64>) -> () + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA5D]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA5D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA5D]] : i64 to vector<[2]x[2]xi64> + %zero_za5d = arm_sme.zero : vector<[2]x[2]xi64> + "prevent.dce"(%zero_za5d) : (vector<[2]x[2]xi64>) -> () + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA6D]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA6D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA6D]] : i64 to vector<[2]x[2]xi64> + %zero_za6d = arm_sme.zero : vector<[2]x[2]xi64> + "prevent.dce"(%zero_za6d) : (vector<[2]x[2]xi64>) -> () + // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA7D]]) : (i32) -> () + // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xi64> + %zero_za7d = arm_sme.zero : vector<[2]x[2]xi64> + "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xi64>) -> () + return +} diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir --- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir +++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir @@ -9,8 +9,10 @@ // CHECK-DAG: %[[C255:.*]] = arith.constant 255 : i32 // CHECK-DAG: %[[PTRUE_ALL:.*]] = arith.constant dense<true> : vector<[16]xi1> // CHECK-DAG: %[[C0_I64:.*]] = builtin.unrealized_conversion_cast %[[C0]] : index to i64 -// CHECK-DAG: "arm_sme.intr.zero"(%[[C255]]) : (i32) -> () // CHECK-DAG: %[[TILE_ID:.*]] = arm_sme.get_tile_id : i8 +// CHECK-DAG: %[[EXT_TILE_ID:.*]] = arith.extui %[[TILE_ID]] : i8 to i32 +// CHECK-DAG: %[[TILE_MASK:.*]] = arith.shli %[[C255]], %[[EXT_TILE_ID]] : i32 +// CHECK-DAG: "arm_sme.intr.zero"(%[[TILE_MASK]]) : (i32) -> () // CHECK-DAG: %[[VSCALE:.*]] = "llvm.intr.vscale"() : () -> i64 // CHECK-NEXT: %[[VSCALE_IDX:.*]] = builtin.unrealized_conversion_cast %[[VSCALE]] : i64 to index // CHECK-NEXT: %[[SVL_B:.*]] = arith.muli
%[[VSCALE_IDX]], %[[MIN_SVL_B]] : index