[mlir][linalg] Enable parallel partial reduction tiling with multiple dims
Accepted, Public

Authored by qedawkins on Aug 21 2023, 7:56 PM.

Details

Summary

This extends transform.structured.tile_reduction_using_forall to
operations with multiple reduction dimensions as implied by the thread
counts. This enables reduction splitting strategies for operations with
higher dimensionality.

Event Timeline

qedawkins created this revision. Aug 21 2023, 7:56 PM
Herald added a project: Restricted Project. Aug 21 2023, 7:56 PM
qedawkins requested review of this revision. Aug 21 2023, 7:56 PM
Groverkss added inline comments. Aug 23 2023, 10:13 PM
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
663

You can use .empty()

713

Why not use an ArrayRef?

725

nit: Don't use auto here.

731–736

It's not very clear what exactly is happening here. Could you add more explanation?

809

You can use indVars here.

Address comments

qedawkins marked 4 inline comments as done. Aug 23 2023, 10:46 PM
qedawkins added inline comments.
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
731–736

Let me know if the explanation here makes sense.

@Groverkss gentle bump on the review here when you have time.

This revision is now accepted and ready to land. Aug 29 2023, 8:59 AM
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
652

Can we write this with LinalgOp::getReductionDims and a followup filter?
It is unclear to me offhand whether this custom logic has something load-bearing.

qedawkins added inline comments. Aug 29 2023, 9:14 AM
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
652

It's both getting the reduction dims and also identifying which thread counts in the scf.forall correspond to reduction dimensions. I can do this but the logic here will look quite similar. I'll also add a comment.
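For readers following along, here is a minimal standalone sketch of the filtering described above. It is not the code from this patch: `IteratorKind` and `tiledReductionDims` are hypothetical names, plain `std::vector` stands in for the MLIR types, and treating dimensions beyond the end of the thread-count list as untiled is an assumption of the sketch.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the iterator kind of one loop dimension.
enum class IteratorKind { Parallel, Reduction };

// Returns the indices of loop dimensions that are reductions AND have a
// nonzero thread count. Dimensions with no thread-count entry are treated
// as untiled (an assumption of this sketch, not confirmed by the patch).
std::vector<std::size_t>
tiledReductionDims(const std::vector<IteratorKind> &iterators,
                   const std::vector<std::int64_t> &numThreads) {
  std::vector<std::size_t> dims;
  for (std::size_t i = 0; i < iterators.size(); ++i) {
    // A dimension is tiled only if it has an explicit, nonzero thread count.
    bool tiled = i < numThreads.size() && numThreads[i] != 0;
    if (tiled && iterators[i] == IteratorKind::Reduction)
      dims.push_back(i);
  }
  return dims;
}
```

With iterators [parallel, reduction, reduction], thread counts {0, 4, 2} select both reduction dimensions, while {4, 2} tiles the parallel dimension and selects only the first reduction dimension.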

Thanks for generalizing this transform!

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
725

Can we extract this in a meaningfully named helper function?
With one single reduction dimension the intent was clear but now the nesting is deeper than I'd like (in an already too long function).

mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
407

It is weird to me that you only need to specify 2 entries in num_threads here; I would have expected you'd need [0, 4, 2] (like in your second test below).
Are / were we somehow too permissive in the specification?
Would be good to tighten the verifier to force alignment of number of dimensions on the rank of the linalg op when appropriate.

It would be good to also have a [red, par, red] test for the "interleaved parallel" case.

qedawkins added inline comments. Aug 30 2023, 7:26 AM
mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
407

This is intentional, as I'm trying to tile parallel dimensions as well as reductions here. As far as I could tell, this was never explicitly prohibited by the pattern and I find it convenient to be able to tile both at the same time (and otherwise avoid nested foralls which interact poorly with distribution later on). The interleaved parallel case is a good idea though, will add a test for it.

On forcing the rank to align: unlike the scf.for version of this pattern, each additional tile size requires a corresponding entry in the mapping, which restricts the mapping options for distribution. For example, if I want to tile only the parallel and first reduction dimensions, I would now need to distribute explicitly along gpu.thread<x> in addition to gpu.thread<z> and gpu.thread<y>. Adding more dimensions requires going to linearized thread indices, which don't work well when we are intentionally avoiding distribution along a specific dimension (e.g. x, for later use with warp distribution patterns).

mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
407

Re tiling parallel and reduction at once, this is a great idea indeed, thanks for pushing on this.
I was thinking that we could insert a <none> or <seq> mapping kind to allow us to skip or lower to loops but if this is too tedious for now let's table it.