
[mlir][Linalg] Allow `tileConsumerAndFuseProducers` to return the values of the fused operations.
Needs Review · Public

Authored by mravishankar on May 9 2022, 3:11 PM.

Details

Reviewers
nicolasvasilache
Summary

The current `tileConsumerAndFuseProducers` method only allows replacing
the uses of the tiled consumer, not the uses of the fused producers.
Given how fusion currently works, it is not always possible (or easy) to
return a tensor value from the tiled loop nest to replace the uses of
the fused producers, because the same tile of a fused producer might be
recomputed while computing multiple tiles of the consumer op. This
patch adds an option that allows the transformation to return the
tensor replacements for the fused producer ops. It expects the caller
to choose the tile sizes, and therefore which loops are tiled, so that
each tile of a fused producer is computed only once. An analysis might
be able to determine this, but for now it is left to the caller.

Depends on D125147


Event Timeline

mravishankar created this revision. May 9 2022, 3:11 PM
Herald added a project: Restricted Project. May 9 2022, 3:11 PM
mravishankar requested review of this revision. May 9 2022, 3:11 PM

Add patch dependency.

A first round of comments to improve the documentation and help digest the complex implementation of the function.
I suspect it will need to be broken into smaller, more composable pieces.

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
550

typo: Interchange

568

typo: construct

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
350–371

Please rephrase, clarify and use better variable names.

Assuming `tiledOp` is the tiled version of `originalOp` and `mixedXXX` are YYY.
Assuming `tiledOp` is fused (or do we want to fuse it at this point?) into the current loop nest.
Update `tileLoopOps` with (one?) additional iter_args and (one?) additional yield operand such that ZZZ.

Example:
IR
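The loop-nest update requested in the documentation above (one additional iter_arg and one additional yield operand per loop) can be approximated with a minimal C++ model. This is a hypothetical sketch under stated assumptions: `LoopModel` and `appendIterArgThroughNest` are illustrative names, values are represented by SSA-name strings, and the utility under review operates on real `scf.for` ops and may differ in detail. The returned value plays the role of "the OpResult of the outermost loop that corresponds to the original operation".

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, heavily simplified model of an scf.for nest on tensors
// (not the MLIR API): values are represented by their SSA names.
struct LoopModel {
  std::vector<std::string> inits;     // init operands of the iter_args
  std::vector<std::string> iterArgs;  // region iter_args
  std::vector<std::string> yields;    // scf.yield operands
  std::vector<std::string> results;   // loop results
};

// Threads one extra tensor through `loops` (outermost first):
//  - the outermost loop's new iter_arg is initialized with `originalInit`
//    (e.g. the destination tensor of the untiled op),
//  - each inner loop's new iter_arg is initialized with its parent's,
//  - each loop yields the new result of the loop nested inside it, and the
//    innermost loop yields `tiledInsert` (the tiled result inserted into the
//    innermost iter_arg).
// Returns the new result of the outermost loop, i.e. the value that can
// replace the uses of the original, untiled op.
std::string appendIterArgThroughNest(std::vector<LoopModel> &loops,
                                     const std::string &originalInit,
                                     const std::string &tiledInsert) {
  assert(!loops.empty() && "expected at least one loop");
  std::string init = originalInit;
  for (std::size_t d = 0; d < loops.size(); ++d) {
    LoopModel &loop = loops[d];
    const std::string suffix =
        std::to_string(d) + "_" + std::to_string(loop.iterArgs.size());
    loop.inits.push_back(init);
    loop.iterArgs.push_back("%arg" + suffix);
    loop.results.push_back("%res" + suffix);
    init = loop.iterArgs.back();  // the next inner loop starts from this iter_arg
  }
  for (std::size_t d = 0; d < loops.size(); ++d)
    loops[d].yields.push_back(d + 1 < loops.size() ? loops[d + 1].results.back()
                                                   : tiledInsert);
  return loops.front().results.back();
}
```

Calling the function once per result of the original op adds one iter_arg per result, so the model does not need a single-result assumption.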
353

The name isn't descriptive enough for such a seemingly complicated process.

358

Better description, please.

Returns the OpResult of the outermost loop that corresponds to the original operation before it has been tiled for fusion.
(Is there a single-result assumption here?)
417

You have access to: 1. the top-level loop, 2. fusedOps and 3. replacements.
IIUC 3. can be recovered from 1+2 and 2 can be recovered from 1+3.

Can we make one or the other an accessor that computes the information rather than storing it?
This would reduce the chance of errors from the two going out of sync.
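The accessor idea can be sketched with a hypothetical result struct (not the MLIR API; `TileAndFuseResultSketch` and `FusedResultId` are illustrative names) that stores a single mapping from each fused op result to the outermost-loop result replacing it, and derives the list of fused ops from that mapping. One caveat, in line with the reply later in this thread: a fused op that has no returnable replacement would not appear in such a mapping, so it would still need to be recorded separately.

```cpp
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical model (not the MLIR API): a fused op result is identified by
// (op name, result number); a replacement by its result index on the outermost loop.
using FusedResultId = std::pair<std::string, unsigned>;

class TileAndFuseResultSketch {
public:
  // Records that outermost-loop result `loopResult` replaces `fused`.
  void addReplacement(const FusedResultId &fused, unsigned loopResult) {
    bool inserted = replacementIndex.emplace(fused, loopResult).second;
    assert(inserted && "fused op result already has a replacement");
    (void)inserted;
  }

  // Accessor: the fused ops are derived from the mapping instead of being
  // stored separately, so the two views cannot go out of sync.
  // (Ordered by op name here, not by fusion order.)
  std::vector<std::string> fusedOps() const {
    std::vector<std::string> ops;
    for (const auto &entry : replacementIndex)
      if (ops.empty() || ops.back() != entry.first.first)
        ops.push_back(entry.first.first);
    return ops;
  }

  // Accessor: the loop result replacing a given fused op result, if any.
  std::optional<unsigned> replacementFor(const std::string &op,
                                         unsigned resultNumber) const {
    auto it = replacementIndex.find({op, resultNumber});
    if (it == replacementIndex.end())
      return std::nullopt;
    return it->second;
  }

private:
  // Single source of truth: fused op result -> outermost loop result index.
  std::map<FusedResultId, unsigned> replacementIndex;
};
```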

423

I am concerned that we are making this already very complex function even more complex...

Side note: do you still need `tileDistribution` in this function, or could we get rid of it (in a follow-up)?

mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
612–617

This can be rewritten as a one-liner: `fuseProducersGreedily(..., returnFusedOpValues)`

mravishankar marked 3 inline comments as done.

Rebase and address comments.

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
350–371

Added IR example near the implementation.

353

Happy to rename it to anything more descriptive. This is essentially what the method does. (Also, it's the bare minimum I could think of to make fusion + scf.for + tensors work.)

358

There is no single-result assumption here.

417

I don't think so. You need to know both 2 and 3 to map each result to the replacement for the corresponding fused op.

423

For a fusion that generates scf.for with tensors, I can't see a simpler implementation than this. If we move away from using scf.for in general, this would be much easier.

I do need `tileDistribution` here. There is a custom `tileConsumerAndFuseProducer` method in IREE that I am trying to replace, and it needs distribution as well, along with replacements for the fused ops where possible. Without those replacements fusion is useless, since the original op ends up in the generated executable anyway. When IREE moves away from using scf.for + tensors, we can drop the distribution aspect. I still think the change in this patch is required for completeness (unless we drop support for fusion + scf.for + tensors completely).
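The distribution scheme used with `tileDistribution` is not shown in this review. The sketch below assumes a cyclic distribution of a tiled loop, where the lower bound becomes `lb + procId * step` and the step becomes `step * nprocs`; `LoopBounds`, `distributeCyclically`, and `tileOffsets` are hypothetical names. It shows that cyclic distribution partitions the tiles among processors rather than duplicating them, so distributing the tiled loops does not by itself add recomputation of fused producer tiles.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical sketch (not the MLIR API) of cyclic distribution of one tiled
// loop over `nprocs` processors: lb' = lb + procId * step, step' = step * nprocs.
struct LoopBounds {
  int64_t lb, ub, step;
};

LoopBounds distributeCyclically(LoopBounds loop, int64_t procId, int64_t nprocs) {
  assert(nprocs > 0 && procId >= 0 && procId < nprocs && "invalid processor id");
  return {loop.lb + procId * loop.step, loop.ub, loop.step * nprocs};
}

// Tile offsets (induction variable values) visited by a loop.
std::vector<int64_t> tileOffsets(const LoopBounds &loop) {
  assert(loop.step > 0 && "expected a positive step");
  std::vector<int64_t> offsets;
  for (int64_t iv = loop.lb; iv < loop.ub; iv += loop.step)
    offsets.push_back(iv);
  return offsets;
}
```

For a loop `{0, 64, 8}` distributed over 3 processors, processor 1 visits offsets 8, 32, and 56, and every one of the 8 tile offsets is visited by exactly one processor.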