diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp @@ -92,11 +92,12 @@ ConversionPatternRewriter &rewriter) const override { auto loc = zero.getLoc(); + unsigned tileElementWidth = + zero.getVectorType().getElementType().getIntOrFloatBitWidth(); + // Get Tile ID for the `zero` intrinsic. auto tileId = rewriter.create( - loc, zero.getVectorType().getElementType()); - - auto tileElementWidth = tileId.getType().getIntOrFloatBitWidth(); + loc, rewriter.getIntegerType(tileElementWidth)); // Get the base mask for tile based on the element size. // The base mask is just the mask to zero the first tile (of a size). diff --git a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir --- a/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir +++ b/mlir/test/Dialect/ArmSME/tile-zero-masks.mlir @@ -3,6 +3,9 @@ // RUN: -allow-unregistered-dialect \ // RUN: | FileCheck %s +// This test verifies that the tile mask operand of the zero intrinsic zeroes +// the correct tiles. Both integer and floating-point datatypes are checked. 
+ // ----- // CHECK-LABEL: zero_za_b @@ -32,9 +35,9 @@ %zero_za0h = arm_sme.zero : vector<[8]x[8]xi16> "prevent.dce"(%zero_za0h) : (vector<[8]x[8]xi16>) -> () // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA1H]]) : (i32) -> () - // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xi16> - %zero_za1h = arm_sme.zero : vector<[8]x[8]xi16> - "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xi16>) -> () + // CHECK-NEXT: %[[ZERO_ZA1H:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA1H]] : i16 to vector<[8]x[8]xf16> + %zero_za1h = arm_sme.zero : vector<[8]x[8]xf16> + "prevent.dce"(%zero_za1h) : (vector<[8]x[8]xf16>) -> () return } @@ -65,9 +68,9 @@ %zero_za2s = arm_sme.zero : vector<[4]x[4]xi32> "prevent.dce"(%zero_za2s) : (vector<[4]x[4]xi32>) -> () // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA3S]]) : (i32) -> () - // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xi32> - %zero_za3s = arm_sme.zero : vector<[4]x[4]xi32> - "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xi32>) -> () + // CHECK-NEXT: %[[ZERO_ZA3S:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA3S]] : i32 to vector<[4]x[4]xf32> + %zero_za3s = arm_sme.zero : vector<[4]x[4]xf32> + "prevent.dce"(%zero_za3s) : (vector<[4]x[4]xf32>) -> () return } @@ -122,8 +125,8 @@ %zero_za6d = arm_sme.zero : vector<[2]x[2]xi64> "prevent.dce"(%zero_za6d) : (vector<[2]x[2]xi64>) -> () // CHECK: "arm_sme.intr.zero"(%[[ZERO_MASK_ZA7D]]) : (i32) -> () - // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xi64> - %zero_za7d = arm_sme.zero : vector<[2]x[2]xi64> - "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xi64>) -> () + // CHECK-NEXT: %[[ZERO_ZA7D:.*]] = arm_sme.cast_tile_to_vector %[[TILE_ID_ZA7D]] : i64 to vector<[2]x[2]xf64> + %zero_za7d = arm_sme.zero : vector<[2]x[2]xf64> + "prevent.dce"(%zero_za7d) : (vector<[2]x[2]xf64>) -> () return }