This is an archive of the discontinued LLVM Phabricator instance.

[mlir][ArmSME] Add custom get_tile_id and cast ops
ClosedPublic

Authored by c-rhodes on Jul 11 2023, 3:31 AM.

Details

Summary

This patch adds three new custom ops to the ArmSME dialect:

  • arm_sme.get_tile_id - returns a scalar integer representing an SME "virtual tile" that is not in use.
  • arm_sme.cast_tile_to_vector - casts from a tile id to a 2-d scalable vector type, which represents an SME "virtual tile".
  • arm_sme.cast_vector_to_tile - casts from a 2-d scalable vector type, which represents an SME "virtual tile", to a tile id.

The 'arm_sme.get_tile_id' op currently only supports tile 0; a follow-up
patch will implement proper tile allocation. A further follow-up patch
will demonstrate load/store to/from ZA using these ops.

See the op descriptions for further details and examples.

Thanks to @paulwalker-arm and @awarzynski for helping drive this.

Event Timeline

c-rhodes created this revision. Jul 11 2023, 3:31 AM
Herald added a reviewer: ftynse.
Herald added a reviewer: dcaballe.
Herald added a project: Restricted Project.
c-rhodes requested review of this revision. Jul 11 2023, 3:31 AM
c-rhodes edited the summary of this revision. Jul 11 2023, 4:20 AM
c-rhodes added a subscriber: paulwalker-arm.

LGTM % a few minor details. Great work!

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
54

What about f32 and f64?

188

It may be useful to also indicate what is the intended live range of the allocated tile?

c-rhodes updated this revision to Diff 539100. Jul 11 2023, 7:48 AM

Add FP types.

c-rhodes marked an inline comment as done. Jul 11 2023, 7:50 AM

LGTM % a few minor details. Great work!

thanks for taking a look!

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
54

What about f32 and f64?

Thanks for pointing that out, I've also added F16/BF16.

188

It may be useful to also indicate what is the intended live range of the allocated tile?

Sorry, could you clarify? I'm not sure exactly what you mean.

It may be useful to also indicate what is the intended live range of the allocated tile?

Sorry, could you clarify? I'm not sure exactly what you mean.

I wasn't entirely sure when the tile allocated by this will be "deallocated" since I misread the description of D154955 earlier, but upon a closer look I realized the tile allocated is valid in the scope of the function. I thought it'd be a good idea to make this clear in the description (and also how one probably wouldn't be able to pass this as an argument to another function call)

c-rhodes marked an inline comment as done. Jul 11 2023, 8:16 AM

It may be useful to also indicate what is the intended live range of the allocated tile?

Sorry, could you clarify? I'm not sure exactly what you mean.

I wasn't entirely sure when the tile allocated by this will be "deallocated" since I misread the description of D154955 earlier, but upon a closer look I realized the tile allocated is valid in the scope of the function. I thought it'd be a good idea to make this clear in the description (and also how one probably wouldn't be able to pass this as an argument to another function call)

Yes, that's right: the scope is a function. In the future we could add a deallocation op to free up tiles within functions, and also support spilling/filling to memory instead of throwing a "ran out of tiles" error. I'll fix the description to capture what you raised. Cheers.
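To picture the function-scoped lifetime discussed above, here is a minimal standalone C++ sketch. It is not the allocation pass from D154955: `TileAllocator`, `allocate` and the simple counter policy are hypothetical names and choices made for illustration only. It assumes one allocator per function and per element width, and it ignores the fact that tiles of different element widths overlap in ZA. The one architectural fact it relies on is that SME exposes elementBits / 8 tiles for a given element width (1 for 8-bit, 4 for 32-bit, 16 for 128-bit).

```cpp
#include <cassert>

// Hypothetical model of function-scoped tile allocation for a single
// element width. Overlap between tiles of different widths is ignored.
class TileAllocator {
public:
  explicit TileAllocator(unsigned elementBits)
      : numTiles(static_cast<int>(elementBits / 8)), nextTile(0) {
    // SME element widths are 8, 16, 32, 64 or 128 bits.
    assert(elementBits == 8 || elementBits == 16 || elementBits == 32 ||
           elementBits == 64 || elementBits == 128);
  }

  // Returns the next unused tile id, or -1 once every tile is taken
  // (the "ran out of tiles" case). Nothing is ever freed: an id stays
  // live for as long as the allocator, i.e. the enclosing function.
  int allocate() {
    if (nextTile == numTiles)
      return -1;
    return nextTile++;
  }

private:
  int numTiles; // number of tiles available for this element width
  int nextTile; // first tile id that has not been handed out yet
};
```

With this model, a function using 32-bit elements can obtain ids 0 to 3 and then fails; a deallocation op or spilling, as mentioned above, would be what lifts that limit.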

Great work @c-rhodes , thank you! I've actually immediately rebased https://reviews.llvm.org/D154867 on top of this change and that immediately solved the "data flow" issue 🙏🏻

Overall this looks solid to me. I've left a few minor suggestions - mostly to clarify the documentation.

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
46–47

[nit] Perhaps extract the preconditions to dedicated definition?

Also, when would this be triggered:

CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})"

It feels like a "VectorType verifier" that could be safely skipped (i.e. nothing SME specific).

75–76

Perhaps:

  • "This is used in conjunction with `cast_vector_to_tile`" --> "This would normally be used in conjunction with "virtual tile load" operations to model the output of such Ops. This is required to preserve data-flow as SME intrinsics do not return values."

Basically, this Op and CastVectorToTile complement each other, right? And I guess that's what we want to say here? But IMHO, this description should focus on CastTileToVector.

81–96

This example is a bit busy. I would focus on the Op that's defined here (i.e. CastTileToVector), so that this description is self-contained (try to avoid references to CastVectorToTile). My suggestion:

EXAMPLE:
Input:
```mlir
    %tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
```

After lowering `vector.load`:
```mlir
    %tile_id = arith.constant 0 : i32
    scf.for %vnum = %c0 to %num_vectors step %c1 {
      // ...
      "arm_sme.intr.ld1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
    }
    %tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[4]xi32>
    vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
```

Another question - are vector.load and vector.store the right Ops here? We don't really lower from these ATM.

105–108

This comment refers to CastVectorToTile

110–113

This comment refers to "these casts", but this is just one cast ;-)

130–161

This example is a bit busy. I would focus on the Op that's defined here (i.e. CastVectorToTile), so that this description is self-contained (try to avoid references to CastTileToVector). My suggestion:

EXAMPLE:
Input:
```mlir

%tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
vector.store %tile, %mem2[%c0] : memref<?xi32>, vector<[4]x[4]xi32>
```
 
Output after lowering `vector.store`:
```mlir
%tile = vector.load %mem1[%c0] : memref<?xi32>, vector<[4]x[4]xi32>

%tile_id = arm_sme.cast_vector_to_tile %tile : vector<[4]x[4]xi32> to i32
scf.for %vnum = %c0 to %num_vectors step %c1 {
  // ...
  "arm_sme.intr.st1w.horiz"(%pg, %ptr, %tile_id, %vnum) : (vector<[4]xi1>, !llvm.ptr, i32, i32) -> ()
}
```

Additionally, canonicalization will look through `cast_vector_to_tile` Ops and fold the
cast ops away if they come from `cast_tile_to_vector`.

178

[nit] Is there any "tile allocation" really taking place? Perhaps "Allocate and return a "virtual tile" ID"?

mlir/test/Dialect/ArmSME/canonicalize.mlir
10–11

What about "the other way round"?

%tile_id = arm_sme.cast_vector_to_tile %tile_1 : vector<[16]x[16]xi8> to i8  
%tile_2 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>

mlir/test/Dialect/ArmSME/invalid.mlir
5

How about:

func.func @arm_sme_cast_tile_to_vector__bad_vector_type(%tile_id : i8) -> vector<[16]x16xi8>

and other combinations? For example:

  • vector<[16]x[16]xi4>
  • vector<16x[16]xi8>

mlir/test/Dialect/ArmSME/roundtrip.mlir
8

Could you add one more element type? For example, vector<[1]x[1]xi128> (i.e. the other extreme).

c-rhodes updated this revision to Diff 539587. Jul 12 2023, 8:44 AM

Address comments

c-rhodes marked 10 inline comments as done and an inline comment as not done. Jul 12 2023, 8:49 AM
c-rhodes added inline comments.
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
46–47

Also, when would this be triggered:

CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>({" # !interleave(dims, ", ") # "})"

It feels like a "VectorType verifier" that could be safely skipped (i.e. nothing SME specific).

Please could you clarify, not sure what you mean? This verifies the shape, i.e. vector<[16]x[16]xi8> is (16, 16).

75–76

Perhaps:

  • "This is used in conjunction with `cast_vector_to_tile`" --> "This would normally be used in conjunction with "virtual tile load" operations to model the output of such Ops. This is required to preserve data-flow as SME intrinsics do not return values."

Basically, this Op and CastVectorToTile complement each other, right? And I guess that's what we want to say here? But IMHO, this description should focus on CastTileToVector.

Thanks for the suggestion, this has cleaned it up nicely.

178

[nit] Is there any "tile allocation" really taking place? Perhaps "Allocate and return a "virtual tile" ID"?

There isn't from the perspective of the op I suppose, it's the pass that does that. Updated the comment.

mlir/test/Dialect/ArmSME/canonicalize.mlir
10–11

What about "the other way round"?

%tile_id = arm_sme.cast_vector_to_tile %tile_1 : vector<[16]x[16]xi8> to i8  
%tile_2 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[16]x[16]xi8>

Good spot!
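As a rough illustration of the round-trip folding discussed here, the following standalone C++ sketch models both directions. It does not use the MLIR API: `OpKind`, `Value` and `foldCast` are hypothetical stand-ins, and type matching between the two casts is left out. Whether the landed patch folds both directions exactly like this is an assumption based on the exchange above.

```cpp
#include <cassert>

// Hypothetical stand-in for the op that defines an SSA value.
enum class OpKind { Other, CastTileToVector, CastVectorToTile };

// Hypothetical stand-in for an SSA value.
struct Value {
  OpKind definingOp;    // op producing this value
  const Value *operand; // input of the cast, or nullptr for OpKind::Other
};

// If `v` is a cast whose input comes from the opposite cast, the pair is a
// round trip and folds to the original value; otherwise `v` is returned.
const Value *foldCast(const Value *v) {
  assert(v != nullptr);
  const Value *src = v->operand;
  if (src == nullptr)
    return v;
  bool roundTrip = (v->definingOp == OpKind::CastVectorToTile &&
                    src->definingOp == OpKind::CastTileToVector) ||
                   (v->definingOp == OpKind::CastTileToVector &&
                    src->definingOp == OpKind::CastVectorToTile);
  return roundTrip ? src->operand : v;
}
```

In this model, tile id -> vector -> tile id folds back to the original tile id, and vector -> tile id -> vector folds back to the original vector, which is the "other way round" case raised above.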

mlir/test/Dialect/ArmSME/roundtrip.mlir
8

Could you add one more element type? For example, vector<[1]x[1]xi128> (i.e. the other extreme).

I've added tests for all element types

dcaballe accepted this revision. Jul 12 2023, 11:05 PM

LGTM! I like the abstraction! Awesome to see this moving forward!

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
44

Is there a construct to make this a VectorType instead of a ShapedType? I guess the subsequent predicates constrain the shaped type a bit more but it would be great if this could be a vector type directly.

118

This abstraction sounds really great!

131

I guess you also considered introducing a single cast op that could cast both ways depending on the order of the operand/types. I think having two makes sense since this cast is kind of crossing two domains...

166

nit: do the quotes imply anything on virtual tiles? I think I don't get what it is :)

This revision is now accepted and ready to land. Jul 12 2023, 11:05 PM
c-rhodes marked 4 inline comments as done. Jul 13 2023, 1:41 AM
c-rhodes added inline comments.
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
44

Is there a construct to make this a VectorType instead of a ShapedType? I guess the subsequent predicates constrain the shaped type a bit more but it would be great if this could be a vector type directly.

I don't believe there is an existing one. The vector constraints in mlir/include/mlir/IR/OpBase.td that I used for reference here also use this, but perhaps ShapedContainerType could be copied for a VectorType.

131

I guess you also considered introducing a single cast op that could cast both ways depending on the order of the operand/types. I think having two makes sense since this cast is kind of crossing two domains...

That didn't cross my mind actually, it's a good point. I think of these casts as being similar to builtin.unrealized_conversion_cast, which does something similar but with a single cast op like you say, so perhaps this could be a single cast as well.

166

nit: do the quotes imply anything on virtual tiles? I think I don't get what it is :)

To be honest I would prefer we just use tile, but the rationale is that these are not real tiles but merely “views” into ZA.

awarzynski accepted this revision. Jul 13 2023, 6:23 AM

LGTM, thanks for addressing my comments :)

mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
44

How about ScalableVectorOf? Could that be applicable here?

46–47

I am just thinking that every vector that you create like this will satisfy this condition and to me this check feels redundant. But I am probably just failing to understand the underlying rationale. No harm in keeping this.

166

I've been suggesting "virtual tiles" as different people mean different things when referring to tiles. "SME virtual tiles" is just a way to highlight that:

  • We mean the tiles in the context of the Arm SME extension (as opposed to e.g. tiles when tiling a matmul).
  • These tiles are actually "views" into ZA rather than "tiles". A "tile" to me suggests that it's something "square" and so "ZA tile" could, incorrectly, imply "a square section of ZA".

It's a name that one of our architects at Arm has been using and I feel that's very fitting. Naming is hard!

c-rhodes added inline comments. Jul 13 2023, 6:35 AM
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
44

How about ScalableVectorOf? Could that be applicable here?

All of the existing scalable vector predicates only check that the vector is scalable (i.e. that any dim is scalable), not that all dims are scalable.
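The distinction between "any dim scalable" and "all dims scalable" can be sketched in standalone C++. The names `ScalableDims`, `hasAnyScalableDim` and `hasAllScalableDims` are hypothetical and are not existing MLIR predicates; the per-dimension flags are an assumed representation, e.g. vector<[16]x16xi8> maps to {true, false} and vector<[16]x[16]xi8> maps to {true, true}.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical per-dimension flags: true if that dimension is scalable.
typedef std::vector<bool> ScalableDims;

// "Is scalable" in the loose sense: at least one dimension is scalable.
bool hasAnyScalableDim(const ScalableDims &dims) {
  return std::any_of(dims.begin(), dims.end(), [](bool s) { return s; });
}

// What an SME tile type needs: at least one dimension, and every
// dimension is scalable.
bool hasAllScalableDims(const ScalableDims &dims) {
  return !dims.empty() &&
         std::all_of(dims.begin(), dims.end(), [](bool s) { return s; });
}
```

A type such as vector<[16]x16xi8> passes the first check but not the second, which is why a predicate of the first kind would not be strict enough here.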

46–47

I am just thinking that every vector that you create like this will satisfy this condition and to me this check feels redundant.

And this verifies that :)

166

In the context of SME I think it's clear what a tile is, but I have no strong feelings either way.

c-rhodes added inline comments. Jul 13 2023, 6:47 AM
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
46–47

I am just thinking that every vector that you create like this will satisfy this condition and to me this check feels redundant.

And this verifies that :)

To clarify, without this check:

%tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[8]xi32>

would be valid

c-rhodes updated this revision to Diff 540436. Jul 14 2023, 8:50 AM

Minor update to add scalar int type to name of functions for get_tile_id tests

Matt added a subscriber: Matt. Jul 14 2023, 2:45 PM
awarzynski added inline comments. Jul 17 2023, 12:16 AM
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
46–47

To clarify, without this check:

%tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[8]xi32>

would be valid

Could you double-check? This works fine:

module {
  func.func @arm_sme_cast_tile_to_vector_i8(%arg0: i8) -> vector<[4]x[16]xi8> {
    %0 = arm_sme.cast_tile_to_vector %arg0 : i8 to vector<[4]x[16]xi8>
    return %0 : vector<[4]x[16]xi8>
  }
}

You will need to replace SMETile with AnyVectorOfAnyRank in the definition of CastTileToVector, but IsVectorOfShape should trigger in both cases, right?

c-rhodes added inline comments. Jul 17 2023, 1:30 AM
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
46–47

To clarify, without this check:

%tile = arm_sme.cast_tile_to_vector %tile_id : i32 to vector<[4]x[8]xi32>

would be valid

Could you double-check? This works fine:

module {
  func.func @arm_sme_cast_tile_to_vector_i8(%arg0: i8) -> vector<[4]x[16]xi8> {
    %0 = arm_sme.cast_tile_to_vector %arg0 : i8 to vector<[4]x[16]xi8>
    return %0 : vector<[4]x[16]xi8>
  }
}

this fails for me (as expected):

build/bin/mlir-opt foo.mlir
foo.mlir:4:8: error: 'arm_sme.cast_tile_to_vector' op result #0 must be vector<[16]x[16]xi8> of 8-bit signless integer values or vector<[8]x[8]xi16> of 16-bit signless integer values or vector<[4]x[4]xi32> of 32-bit signless integer values or vector<[2]x[2]xi64> of 64-bit signless integer values or vector<[1]x[1]xi128> of 128-bit signless integer values or vector<[8]x[8]xf16> of 16-bit float values or vector<[8]x[8]xbf16> of bfloat16 type values or vector<[4]x[4]xf32> of 32-bit float values or vector<[2]x[2]xf64> of 64-bit float values, but got 'vector<[4]x[16]xi8>'
  %0 = arm_sme.cast_tile_to_vector %tile_id : i8 to vector<[4]x[16]xi8>
       ^
foo.mlir:4:8: note: see current operation: %0 = "arm_sme.cast_tile_to_vector"(%arg0) : (i8) -> vector<[4]x[16]xi8>

and doesn't if I remove the IsVectorOfShape<dims> check.

You will need to replace SMETile with AnyVectorOfAnyRank in the definition of CastTileToVector, but IsVectorOfShape should trigger in both cases, right?

I'm not sure I follow, please could you clarify?

awarzynski added inline comments. Jul 17 2023, 2:36 AM
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
46–47

You are right and I am wrong, sorry.

I've just checked the generated CPP code and it's this:

((::llvm::cast<::mlir::VectorType>(type).getShape() == ArrayRef<int64_t>({16, 16})))
// other similar checks

So RHS is taken from:

def SMETile : AnyTypeOf<[nxnxv16i8, nxnxv8i16, nxnxv4i32, nxnxv2i64, nxnxv1i128,
                         nxnxv8f16, nxnxv8bf16, nxnxv4f32, nxnxv2f64]>;

I thought that for this example (vector<[4]x[8]xi32>) it would check the following instead:

((::llvm::cast<::mlir::VectorType>(type).getShape() == ArrayRef<int64_t>({4, 8})))

i.e. take the RHS from the input (vector<[4]x[8]xi32>). Hence the confusion.
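The point resolved here, that the expected shape is fixed by the constraint and only the actual shape comes from the type being verified, can be modelled in a few lines of standalone C++. This is a sketch, not the generated verifier: `isVectorOfShape` and `smeTileShape` below are hypothetical free functions, and the expected dimensions are derived as 128 / elementBits, which matches the list in the error message quoted earlier (16 for i8, 4 for i32, 1 for i128).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the generated predicate: the expected shape is
// baked in by the constraint; the actual shape comes from the checked type.
bool isVectorOfShape(const std::vector<int64_t> &actualShape,
                     const std::vector<int64_t> &expectedShape) {
  return actualShape == expectedShape;
}

// Expected SME tile shape for an element width: n x n with
// n = 128 / elementBits, e.g. {4, 4} for 32-bit elements.
std::vector<int64_t> smeTileShape(unsigned elementBits) {
  assert(elementBits == 8 || elementBits == 16 || elementBits == 32 ||
         elementBits == 64 || elementBits == 128);
  int64_t n = 128 / elementBits;
  return std::vector<int64_t>(2, n);
}
```

Under this model a shape of {4, 8} with 32-bit elements is compared against {4, 4} and rejected, rather than being compared against itself.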

c-rhodes added inline comments. Jul 17 2023, 2:56 AM
mlir/include/mlir/Dialect/ArmSME/IR/ArmSME.td
46–47

no worries, thanks for checking

I'll land this tomorrow unless there are any further comments by then.

This revision was landed with ongoing or failed builds. Jul 18 2023, 12:42 AM
This revision was automatically updated to reflect the committed changes.