This is an archive of the discontinued LLVM Phabricator instance.

[mlir][linalg] Add pattern to convert batch_matmul to matmul
AbandonedPublic

Authored by ThomasRaoux on Feb 17 2023, 12:57 PM.

Details

Summary

If the batch dimension of a batch matmul is 1, it can be converted to a
linalg.matmul. This avoids leftover reshape ops when converting a mix of
generic ops with batch matmul.

Diff Detail

Event Timeline

ThomasRaoux created this revision.Feb 17 2023, 12:57 PM
ThomasRaoux requested review of this revision.Feb 17 2023, 12:57 PM
mravishankar requested changes to this revision.Feb 17 2023, 7:36 PM

Probably needs to be into NamedOpConversion pass

This revision now requires changes to proceed.Feb 17 2023, 7:36 PM
nicolasvasilache requested changes to this revision.Feb 20 2023, 1:21 AM

I am unclear why this special casing is needed...
I routinely convert higher-D ops into lower-D generic ops with rank-reducing patterns.
In my use cases, I am not relying on the named ops for further transformations / optimizations.

IMO this just points at the need to finally start that inverse generalization transform: linalg.generic -> linalg.named op.
I was still dragging my feet because I am still hoping we can do significantly better than special casing, but when the special casing leaks into N^2 land, it is time to just do it.

So bottom line, could you start a new file called Specialization.cpp with a functional transform API FailureOr<LinalgOp> specialize(linalg::GenericOp genericOp).

For this use case, you should only need:

if (isaContractionOpInterface(genericOp) && /*check indexing maps*/ && /*check reduction type*/)
  return rewriter.replaceOpWithNewOp<linalg::MatmulOp>(...);

That switch can grow over time and is the only place in the compiler where we need such logic.
This will compose with rank-reducing patterns that we have refined over time to properly use rank-reducing slices.


We are seeing quite a few problems related to named ops indeed. See the analysis made by Sean here:
https://github.com/iree-org/iree/issues/12214#issuecomment-1437481667

I'm not sure I understand what you are suggesting for this case, though. The problem is that when a matmul is surrounded by generic ops, the generic ops get reshaped but the named op does not, which leaves reshape ops in the graph that block further optimizations.
Are you saying the right fix is to change the existing pattern on linalg.generic to work on the LinalgOp interface, and therefore convert linalg.batch_matmul into linalg.generic?

That sounds like a good direction, but it means generalization would apply even when the user doesn't want it. The op may get specialized again, but in some cases it won't, which makes it harder for the user to handle all the cases.
Could you give a bit more detail on how you think the flow should work?

ThomasRaoux abandoned this revision.Feb 23 2023, 1:54 PM