This is an archive of the discontinued LLVM Phabricator instance.

[mlir] GreedyPatternRewriteDriver: Enqueue ancestors in MultiOpPatternRewriteDriver
ClosedPublic

Authored by springerm on Jan 17 2023, 8:41 AM.

Details

Summary

The GreedyPatternRewriteDriver was extended to enqueue ancestors in D140304. With this change, MultiOpPatternRewriteDriver behaves the same way.

Note: MultiOpPatternRewriteDriver now also has a scope that limits how far we go when checking ancestors. By default, this is the first common region of all given ops.
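The ancestor-enqueueing behavior and the role of the scope can be pictured with a small self-contained model. This is a sketch under stated assumptions: Op, Region and addToWorklistWithAncestors are hypothetical stand-ins, not MLIR's Operation/Region classes or the driver's actual code, and the real driver's worklist handling (e.g., deduplication) is omitted. The sketch enqueues the op and each enclosing op, and stops at the op that sits directly in the scope region. With a null scope, the walk runs all the way up to the top-level op.

```cpp
#include <cassert>
#include <vector>

// Toy IR model (hypothetical, not MLIR's classes): every op lives in a region,
// and every region except the top-level one is owned by a parent op.
struct Op;
struct Region {
  Op *parentOp = nullptr;  // nullptr for the top-level region
};
struct Op {
  Region *parentRegion = nullptr;  // nullptr for a detached top-level op
};

// Enqueue `op` and its ancestor ops, stopping once the walk reaches `scope`.
// With scope == nullptr the walk continues up to the top-level op.
void addToWorklistWithAncestors(Op *op, const Region *scope,
                                std::vector<Op *> &worklist) {
  assert(op && "expected a non-null op");
  while (op) {
    worklist.push_back(op);
    Region *region = op->parentRegion;
    if (!region || region == scope)
      break;  // reached the top of the hierarchy or the scope region
    op = region->parentOp;  // nullptr once we leave the top-level region
  }
}
```

For example, for an op nested two levels below the scope region, the worklist receives the op itself followed by its enclosing op, and nothing above the scope.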

Depends On: D141921


Event Timeline

springerm created this revision. Jan 17 2023, 8:41 AM
springerm requested review of this revision. Jan 17 2023, 8:41 AM
Herald added a project: Restricted Project. Jan 17 2023, 8:41 AM
springerm edited the summary of this revision. Jan 17 2023, 8:42 AM
mehdi_amini added inline comments. Jan 19 2023, 9:01 AM
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
104

The comment should be updated; the invariant is now less clear. What is the use case when this is not set?
Also, when is the state reset?

336

It is not clear to me that it is safe to do this without a scope. Doesn't that basically mean going all the way up to the module every time?

springerm updated this revision to Diff 491389. Jan 23 2023, 8:17 AM
springerm marked an inline comment as done.

address comments

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
104

Changed it to how it was before.

336

Actually, good point. I think it would make sense to add a scope parameter to the MultiOpPatternRewriteDriver. That way, the two drivers (region-based and multi-op-based) behave exactly the same.

springerm edited the summary of this revision. Jan 23 2023, 8:18 AM
mehdi_amini added inline comments. Jan 24 2023, 10:31 AM
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
648

What if scope is nullptr here?

754

I can't understand the condition here; are you sure it is correct? Shouldn't it instead be while (llvm::any_of(...))?

757

I'm not sure about the complexity of this.

Could there be some memoization?

mehdi_amini added inline comments. Jan 24 2023, 10:33 AM
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
104

scope also seems to be set in MultiOpPatternRewriteDriver::simplifyLocally?

springerm updated this revision to Diff 492029. Jan 25 2023, 1:48 AM
springerm marked 4 inline comments as done.

address comments

mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
104

Done.

Note: It's getting difficult to understand when certain fields are set. This is addressed in D141949.

648

It cannot. The caller (applyOpPatternsAndFold) must provide a valid scope. Added an assertion.

754

You are right. Rewrote the entire thing with memoization.

757

Ops are no longer checked multiple times.

mehdi_amini added inline comments. Jan 25 2023, 9:42 AM
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
753

You should be able to write this as SmallVector<Operation *> remainingOps = to_vector(ops.drop_front()).

766

Interesting memoization!

I was thinking about memoizing the region traversal instead. Wouldn't the following be close to optimal?

DenseMap<Region *, unsigned> visitedRegions;
for (Operation *op : ops) {
  Region *parent = op->getParentRegion();
  while (parent) {
    auto it = visitedRegions.insert({parent, 1});
    if (!it.second) {
      // This parent region was already visited, bump counter and stop traversal
      ++it.first->second;
      break;
    }
    parent = parent->getParentRegion();
  }
}
// From any op, walk up the parent regions again: the enclosing region is the
// one whose count equals the number of ops.
Region *region = ops.front()->getParentRegion();
while (region) {
  assert(visitedRegions.count(region));
  if (visitedRegions[region] == ops.size())
    break;
  region = region->getParentRegion();
}
return region;
springerm marked an inline comment as done. Jan 26 2023, 1:33 AM
springerm added inline comments.
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
766
// This parent region was already visited, bump counter and stop traversal

I think the traversal has to continue here. With that, the original algorithm may be faster in practice. What do you think?

The previous algorithm stops walking up the parent regions as soon as it finds a region that is an ancestor of all ops, so I expect it to visit one, maybe two, parent regions in most cases (assuming that the ops being rewritten are often in the same "local area"). In your algorithm, we walk all the way up to the top-level region for each op.

The previous algorithm was O(num_ops^2 + num_ops * region_depth), but the num_ops^2 part is just a lookup in an array. We can optimize it by using a bit vector.

Your algorithm is O(num_ops * region_depth), so asymptotically it is better.

My previous code had a bug where the region was advanced to its parent one time too many. I fixed that and now use a bit vector.
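For comparison, here is a sketch of the counting idea from the suggestion above with the correction applied: each op's walk continues to the top-level region instead of stopping at the first already-visited region, so every region's count is exact, at O(num_ops * region_depth) cost. It is a toy model, not the code that landed; Op, Region, parentRegionOf and findCommonAncestorByCounting are hypothetical names, and std::unordered_map stands in for DenseMap. It assumes the ops are distinct.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Toy IR model (hypothetical, not MLIR's classes): op -> parent region ->
// parent op -> ... up to a top-level region whose parentOp is nullptr.
struct Op;
struct Region {
  Op *parentOp = nullptr;
};
struct Op {
  Region *parentRegion = nullptr;
};

static Region *parentRegionOf(Region *region) {
  return region->parentOp ? region->parentOp->parentRegion : nullptr;
}

// Count, for every region, how many of `ops` are nested in it. Every walk
// runs to the top (no early exit), so each count is exact. Assumes the ops
// are distinct; duplicates would be counted twice.
Region *findCommonAncestorByCounting(const std::vector<Op *> &ops) {
  assert(!ops.empty() && "expected at least one op");
  std::unordered_map<Region *, std::size_t> numOpsBelow;
  for (Op *op : ops)
    for (Region *r = op->parentRegion; r; r = parentRegionOf(r))
      ++numOpsBelow[r];
  // The innermost region on the first op's chain that contains all ops.
  for (Region *r = ops.front()->parentRegion; r; r = parentRegionOf(r))
    if (numOpsBelow[r] == ops.size())
      return r;
  return nullptr;  // the ops do not share a region
}
```

Why the early exit fails: take ops a and b in region r1 and op c in region r2, where r1 and r2 both sit under region r0. With the early exit, b stops at r1 and c stops at r0, so r0 ends up with a count of 2 rather than 3, and the final walk finds no region whose count equals 3.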

springerm updated this revision to Diff 492357. Jan 26 2023, 1:33 AM

address comments

mehdi_amini added inline comments. Jan 26 2023, 4:48 PM
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
756

What about ops = ops.drop_front() after that (and sz = ops.size())?

766

I think the traversal has to continue here.

Mmmm right, I was thinking about a "union-find" kind of algorithm but I over-simplified the data structure here...
Having to traverse to the top for each op isn't what I intended, that defeats the memoization!

In terms of complexity, I didn't follow how you computed it. I was seeing it more as O(num_ops^2 * region_depth), wouldn't it be?
The findAncestorOpInRegion call is itself O(region_depth) and is in the inner loop, and it'll traverse the entire hierarchy on every miss.
(this call in the inner loop is what bothered me a bit and what I was trying to avoid)

We can optimize it by using a bit vector.

Does the bit vector change the complexity? Also, I'm not sure a small bit vector is actually more efficient than a simple vector; do you know?

mehdi_amini accepted this revision. Jan 26 2023, 5:00 PM

I'm not convinced we have the optimal findCommonAncestor, but I guess we can revisit this later if/when it shows up in a profile!

This revision is now accepted and ready to land. Jan 26 2023, 5:00 PM
springerm marked 2 inline comments as done. Jan 27 2023, 1:39 AM
springerm added inline comments.
mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
766

Right. With the current loop structure, the outer loop runs region_depth (or num_ops) iterations, the inner loop runs num_ops iterations, and each findAncestorOpInRegion call costs region_depth.
The bit vector doesn't change the complexity, but it might help with cache behavior: marking an op as "remaining" takes one bit instead of 64. Or maybe it makes no difference...
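The loop structure described above (walk up from the first op's region; in each region, check every op not yet known to be inside it) can be sketched as follows. This is a hedged toy model, not the landed findCommonAncestor: Op, Region, parentRegionOf and findCommonAncestorByScanning are hypothetical, findAncestorOpInRegion here is a local stand-in for MLIR's Region::findAncestorOpInRegion, and std::vector<bool> plays the role of llvm::BitVector. It also folds in the drop-front idea by tracking only ops[1..].

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy IR model (hypothetical, not MLIR's classes).
struct Op;
struct Region {
  Op *parentOp = nullptr;  // nullptr for the top-level region
};
struct Op {
  Region *parentRegion = nullptr;
};

// Stand-in for Region::findAncestorOpInRegion: walk up from `op` until an
// ancestor (or `op` itself) sits directly in `region`. O(region_depth).
static Op *findAncestorOpInRegion(Region *region, Op *op) {
  while (op) {
    if (op->parentRegion == region)
      return op;
    Region *r = op->parentRegion;
    op = r ? r->parentOp : nullptr;
  }
  return nullptr;
}

static Region *parentRegionOf(Region *region) {
  return region->parentOp ? region->parentOp->parentRegion : nullptr;
}

Region *findCommonAncestorByScanning(const std::vector<Op *> &ops) {
  assert(!ops.empty() && "expected at least one op");
  Region *region = ops.front()->parentRegion;
  // Bit i stays set while ops[i + 1] is not yet known to be inside `region`.
  std::vector<bool> remaining(ops.size() - 1, true);
  std::size_t numRemaining = remaining.size();
  while (region) {  // outer loop: at most region_depth iterations
    for (std::size_t i = 0; i < remaining.size(); ++i) {  // inner: num_ops - 1
      if (remaining[i] && findAncestorOpInRegion(region, ops[i + 1])) {
        remaining[i] = false;
        --numRemaining;
      }
    }
    if (numRemaining == 0)
      break;  // `region` contains every op; do not advance past it
    region = parentRegionOf(region);
  }
  return region;  // nullptr if the ops do not share a region
}
```

Checking numRemaining before moving to the parent is what keeps the walk from going one region too far. Since an op nested in a region is also nested in every enclosing region, a cleared bit never needs to be set again.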

This revision was landed with ongoing or failed builds. Jan 27 2023, 1:39 AM
This revision was automatically updated to reflect the committed changes.
springerm marked an inline comment as done.