This is an archive of the discontinued LLVM Phabricator instance.

[mlir] Gradually lower vector to SCF
ClosedPublic

Authored by springerm on Apr 16 2021, 12:14 AM.

Details

Summary

Add a new GradualVectorToSCF pass that lowers vector transfer ops to SCF by gradually unpacking one dimension at a time. Unpacking stops at 1-D, but can be configured to stop earlier, should the hardware support (N>1)-D vectors.
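A toy, self-contained C++ model of the resulting loop structure may help here. It is not MLIR code. Each recursion level peels the leading vector dimension off into a loop (the role scf.for plays in the pass) until at most targetRank dimensions remain, and those are copied by one "innermost transfer". Buffer, iotaBuffer, progressiveRead and targetRank are hypothetical names. The model assumes the read is in bounds, unmasked, and uses an identity permutation map with vector rank equal to buffer rank. It also skips the temporary buffers the real pass allocates.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy row-major N-D buffer (hypothetical, not an MLIR type).
struct Buffer {
  std::vector<std::size_t> shape;
  std::vector<float> data;
};

// Fills a buffer of the given shape with 0, 1, 2, ...
Buffer iotaBuffer(const std::vector<std::size_t> &shape) {
  std::size_t n = 1;
  for (std::size_t s : shape)
    n *= s;
  Buffer buf{shape, std::vector<float>(n)};
  for (std::size_t i = 0; i < n; ++i)
    buf.data[i] = static_cast<float>(i);
  return buf;
}

struct ReadResult {
  std::vector<float> values;  // elements in row-major order
  int numInnerTransfers = 0;  // transfers of rank <= targetRank
};

// Element-wise copy of the trailing block vecShape[dim..]; stands in for a
// single transfer that the hardware supports natively.
static void copyBlock(const Buffer &buf, std::vector<std::size_t> &idx,
                      const std::vector<std::size_t> &vecShape,
                      std::size_t dim, std::vector<float> &out) {
  if (dim == vecShape.size()) {
    std::size_t flat = 0;
    for (std::size_t d = 0; d < idx.size(); ++d)
      flat = flat * buf.shape[d] + idx[d];
    out.push_back(buf.data[flat]);
    return;
  }
  const std::size_t base = idx[dim];
  for (std::size_t i = 0; i < vecShape[dim]; ++i) {
    idx[dim] = base + i;
    copyBlock(buf, idx, vecShape, dim + 1, out);
  }
  idx[dim] = base;
}

// Peels the leading vector dimension into a loop until at most `targetRank`
// dimensions remain, then emits one innermost transfer.
static void unpack(const Buffer &buf, std::vector<std::size_t> &idx,
                   const std::vector<std::size_t> &vecShape, std::size_t dim,
                   std::size_t targetRank, ReadResult &res) {
  if (vecShape.size() - dim <= targetRank) {
    copyBlock(buf, idx, vecShape, dim, res.values);
    ++res.numInnerTransfers;
    return;
  }
  const std::size_t base = idx[dim];
  for (std::size_t i = 0; i < vecShape[dim]; ++i) {
    idx[dim] = base + i;
    unpack(buf, idx, vecShape, dim + 1, targetRank, res);
  }
  idx[dim] = base;
}

ReadResult progressiveRead(const Buffer &buf, std::vector<std::size_t> start,
                           const std::vector<std::size_t> &vecShape,
                           std::size_t targetRank) {
  // The toy model only handles vectors whose rank equals the buffer rank.
  assert(start.size() == buf.shape.size() &&
         vecShape.size() == buf.shape.size());
  ReadResult res;
  unpack(buf, start, vecShape, 0, targetRank, res);
  return res;
}
```

With a 2x3x4 read, targetRank = 1 yields 6 one-dimensional innermost transfers (the 3-D -> 1-D case), while targetRank = 2 yields 2, which corresponds to the "stop earlier" configuration.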

The current implementation cannot handle permutation maps, masks, tensor types and unrolling yet. These will be added in subsequent commits. Once features are on par with VectorToSCF, this implementation will replace VectorToSCF.

Event Timeline

springerm created this revision.Apr 16 2021, 12:14 AM
springerm requested review of this revision.Apr 16 2021, 12:14 AM
Herald added a project: Restricted Project. Apr 16 2021, 12:14 AM
nicolasvasilache accepted this revision.Apr 19 2021, 4:35 AM

Nice!

In a follow-up we'll probably also want a 3-D -> 1-D test to check that things connect properly?

mlir/lib/Conversion/VectorToSCF/GradualVectorToSCF.cpp
39

Maybe __vector_to_scf_lowering__? I imagine the current name could collide with another attribute at some future point in time.
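The label acts as a re-entry guard. Once an op has been prepared it is tagged, and the pattern's hasAttr(kPassLabel) check keeps it from firing on that op again. The toy sketch below uses hypothetical names (LabeledOp, applyRepeatedly) and is not MLIR code. It shows why a distinctive name matters: any op that already carries the same name, for whatever reason, is silently skipped.

```cpp
#include <set>
#include <string>

// Marker name as suggested in the review; a distinctive name makes it
// unlikely that another pass attaches the same one.
const std::string kPassLabel = "__vector_to_scf_lowering__";

// Toy op (hypothetical) that carries only a set of attribute names.
struct LabeledOp {
  std::set<std::string> attrs;
  bool hasAttr(const std::string &name) const { return attrs.count(name) != 0; }
};

// Mirrors the guard in the pattern: an already-labeled op is skipped,
// otherwise it is "prepared" and labeled. Returns how many times the
// pattern applied over `rounds` attempts.
int applyRepeatedly(LabeledOp op, int rounds) {
  int applied = 0;
  for (int i = 0; i < rounds; ++i) {
    if (op.hasAttr(kPassLabel))
      continue;
    op.attrs.insert(kPassLabel);
    ++applied;
  }
  return applied;
}
```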

78

Typo: indices.

93

you could just do:

using mlir::edsc::op::operator+;

right above this line.
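The using-declaration is needed because argument-dependent lookup only searches the namespace that encloses the argument's type. An operator defined in a nested namespace, as mlir::edsc::op::operator+ is relative to mlir::Value, is therefore not found for an infix a + b. A toy sketch with hypothetical namespaces standing in for the MLIR ones:

```cpp
namespace toy {
struct Value {
  int v;
};
namespace edsc {
namespace op {
// Lives in a nested namespace, so ADL on toy::Value does not find it.
inline Value operator+(Value a, Value b) { return Value{a.v + b.v}; }
} // namespace op
} // namespace edsc
} // namespace toy

int addWithInfix(toy::Value a, toy::Value b) {
  // Without this block-scope using-declaration, `a + b` does not compile.
  using toy::edsc::op::operator+;
  return (a + b).v;
}
```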

131

In MLIR we omit trivial braces.
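MLIR follows the LLVM Coding Standards, which omit braces around simple single-statement bodies of if/else/loop statements, including when such statements are nested. A small illustration (countPositive is a made-up helper):

```cpp
#include <vector>

// Single-statement bodies of `for` and `if` take no braces, even nested.
int countPositive(const std::vector<int> &xs) {
  int n = 0;
  for (int x : xs)
    if (x > 0)
      ++n;
  return n;
}
```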

347

Same comment regarding in-place op update.

399

This is usually considered dangerous and we prefer rewriter.clone to keep track of stuff.
But this goes away if you use root updates.

400

If types don't change, we usually prefer:

rewriter.startRootUpdate(xferOp);
... // updates
rewriter.finalizeRootUpdate(xferOp);

and so there is no need to clone or erase ops, etc.
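A toy model of the bracket follows. ToyOp and ToyRewriter are hypothetical stand-ins, not the MLIR classes; only the method names mirror the snippet above. The op is mutated between the two calls, so it keeps its identity, and the driver only learns that it changed. No clone or erase is involved.

```cpp
#include <string>
#include <vector>

// Toy op with one mutable property (hypothetical).
struct ToyOp {
  std::string name;
  int value = 0;
};

// Toy rewriter that records the notifications a pattern driver would see.
struct ToyRewriter {
  std::vector<std::string> events;
  void startRootUpdate(ToyOp &op) { events.push_back("start " + op.name); }
  void finalizeRootUpdate(ToyOp &op) { events.push_back("updated " + op.name); }
};

// Updates `op` in place: the same object is modified between the two
// notifications, so nothing is created, cloned, or erased.
void setValueInPlace(ToyRewriter &rewriter, ToyOp &op, int newValue) {
  rewriter.startRootUpdate(op);
  op.value = newValue;
  rewriter.finalizeRootUpdate(op);
}
```

At the time, PatternRewriter also offered updateRootInPlace(op, callable), which wraps the same start/finalize pair around a lambda.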

402

This is usually considered dangerous and we prefer rewriter.clone to keep track of stuff.
But this goes away if you use root updates.

405

The 2 paths are almost identical code, can we create a single templated pattern ?

mlir/test/lib/Transforms/TestVectorTransforms.cpp
12

In MLIR we usually use the term "progressive" for this: "progressive lowering" (not necessarily in code, but definitely in other communications like presentations, posts, etc.).
Could you please update everywhere for consistency?

This revision is now accepted and ready to land.Apr 19 2021, 4:35 AM
springerm marked 8 inline comments as done.

Address review comments.

Update commit message.

Change commit message.

Harbormaster completed remote builds in B99616: Diff 338713.

Add cmake file.

springerm updated this revision to Diff 338746.Apr 20 2021, 1:05 AM

Added unit test.

springerm updated this revision to Diff 338750.Apr 20 2021, 1:18 AM

Update test case.

springerm added inline comments.Apr 20 2021, 1:19 AM
mlir/lib/Conversion/VectorToSCF/GradualVectorToSCF.cpp
347

There is no rewriter.replaceAllUsesWith. I couldn't find a way to replace all uses without cloning/creating a new op. Value.replaceAllUsesWith seems dangerous because it does not go through the rewriter.

405

You mean PrepareTransferWriteConversion and PrepareTransferReadConversion? What's identical is mostly the checks at the beginning of the function. I put those in a separate function.
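One way to share those checks between the two patterns is a function template over the op type. The sketch below uses hypothetical toy structs in place of TransferReadOp and TransferWriteOp, and a made-up helper name, passesCommonChecks. The conditions mirror the ones in the reviewer's sketch further down: pass label, target rank, mask, and permutation map.

```cpp
#include <cstdint>

// Toy stand-ins (hypothetical) exposing only what the checks inspect.
struct ToyTransferRead {
  std::int64_t rank;
  bool hasMask;
  bool identityMap;
  bool labeled;
};
struct ToyTransferWrite {
  std::int64_t rank;
  bool hasMask;
  bool identityMap;
  bool labeled;
};

// One template serves both op types, so the checks are written once.
template <typename OpTy>
bool passesCommonChecks(const OpTy &op, std::int64_t targetRank) {
  if (op.labeled)
    return false; // already prepared
  if (op.rank <= targetRank)
    return false; // nothing left to unpack
  if (op.hasMask)
    return false; // masks not supported yet
  if (!op.identityMap)
    return false; // permutation maps not supported yet
  return true;
}
```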

mlir/lib/Conversion/VectorToSCF/GradualVectorToSCF.cpp
347

If you do an in-place update (i.e. to a first good approximation: if your return types don't change), then you just update in place and there is no need to replace anything: the use-def chains are already connected and don't change.

So it's both shorter code and more efficient :)
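To see why no replacement is needed, model a use-def edge as a raw pointer from a user to the op that defines its operand. This is a toy model, not how MLIR stores uses, and Def, User, updateInPlace and replaceAllUses are hypothetical names. An in-place change leaves every edge intact, whereas replacing the op means visiting and redirecting each use.

```cpp
#include <cstddef>
#include <vector>

// Toy IR (hypothetical): a user points at the op defining its operand.
struct Def {
  int attr = 0;
};
struct User {
  Def *operand = nullptr;
};

// In-place update: only the defining op changes; every edge stays valid.
std::size_t updateInPlace(Def &def, int newAttr) {
  def.attr = newAttr;
  return 0; // no uses had to be touched
}

// Replacement: every use of `oldDef` must be redirected to `newDef` before
// `oldDef` can be erased. Returns the number of uses rewired.
std::size_t replaceAllUses(std::vector<User> &users, const Def *oldDef,
                           Def *newDef) {
  std::size_t rewired = 0;
  for (User &u : users) {
    if (u.operand == oldDef) {
      u.operand = newDef;
      ++rewired;
    }
  }
  return rewired;
}
```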

mlir/lib/Conversion/VectorToSCF/GradualVectorToSCF.cpp
405

I was also thinking about some of the helper stuff:

ScopedContext scope(rewriter, xferOp.getLoc());
auto allocType = MemRefType::get({}, xferOp.getVectorType());
auto buffer = setAllocAtFunctionEntry(allocType, xferOp);

Seems like we could have something like:

template <typename OpType>
struct PrepareTransferConversion : public OpRewritePattern<OpType> {
  // Callback for the op-specific part of the rewrite (read vs. write).
  using DoitFn =
      std::function<LogicalResult(OpType, Value, PatternRewriter &)>;

  PrepareTransferConversion(MLIRContext *context, DoitFn doit)
      : OpRewritePattern<OpType>(context), doit(std::move(doit)) {}

  LogicalResult matchAndRewrite(OpType xferOp,
                                PatternRewriter &rewriter) const override {
    // Checks shared by transfer_read and transfer_write.
    if (xferOp->hasAttr(kPassLabel))
      return failure();
    if (xferOp.getVectorType().getRank() <= kTargetRank)
      return failure();
    if (xferOp.mask())
      return failure();
    if (!xferOp.permutation_map().isIdentity())
      return failure();

    // Setup shared by both ops: a 0-d buffer holding the whole vector.
    ScopedContext scope(rewriter, xferOp.getLoc());
    auto allocType = MemRefType::get({}, xferOp.getVectorType());
    auto buffer = setAllocAtFunctionEntry(allocType, xferOp);
    // ... other common things

    return doit(xferOp, buffer, rewriter);
  }

private:
  DoitFn doit;
};

If that sounds too convoluted feel free to ignore.

tschuett added inline comments.
mlir/lib/Conversion/VectorToSCF/ProgressiveVectorToSCF.cpp
42 (On Diff #338750)

LLVM has TargetTransformInfo, which you can query at runtime for such values.

Harbormaster completed remote builds in B99638: Diff 338750.
This revision was automatically updated to reflect the committed changes.