This is an archive of the discontinued LLVM Phabricator instance.

[MLIR][SCF] Remove loop invariant arguments of scf.while
ClosedPublic

Authored by avarmapml on Jan 10 2022, 2:31 AM.

Details

Summary
  • This commit adds a canonicalization pattern on scf.while to remove the loop invariant arguments.
  • An argument is considered loop invariant if the iteration argument value is the same as the corresponding value being yielded (at the same position) in both the before and after blocks of scf.while.
  • For each removed argument, its uses within scf.while and the corresponding scf.while result are replaced with its initial value.

Signed-off-by: Abhishek Varma <abhishek.varma@polymagelabs.com>


Event Timeline

avarmapml created this revision. Jan 10 2022, 2:31 AM
avarmapml requested review of this revision. Jan 10 2022, 2:31 AM
ftynse requested changes to this revision. Jan 10 2022, 3:04 AM
ftynse added a subscriber: ftynse.

This seems to be based on a wrong assumption about the iteration arguments being always passed from the before to the after block, and in the same order. This is not guaranteed by the op. I'm okay with a pattern that only works in that specific case, but the pattern must check for it explicitly.

mlir/lib/Dialect/SCF/SCF.cpp
2372–2374

There is no guarantee that beforeBlockArgs and afterBlockArgs have the same length. IIRC, llvm::zip is actually zip-shortest, but make sure this assumption is used consciously.

2385

Couldn't this directly take the (index + 1)-th operand of the condition op and check it, instead of iterating over the potentially large number of the value's users? Same below.

2387

There's no relation between the position of the value in the condition op and iter args prescribed by the op semantics.

For example, the following is perfectly valid and is likely the case we want to simplify here. I just swapped the arguments when passing them from the "before" to the "after" block.

scf.while <...> iter_args(%arg0 = ..., %arg1 = ...) {
  scf.condition %cond, %arg1, %arg0
} do {
^bb0(%arg2, %arg3):
  scf.yield %arg3, %arg2
}

And here we may want to remove %arg1.

scf.while <...> iter_args(%arg0 = ..., %arg1 = ...) {
  ...
  scf.condition %cond, %arg1
} do {
^bb0(%arg2):
  %non_invariant = ...
  scf.yield %non_invariant, %arg2
}
2406–2408

RAUW in patterns is usually a bad idea. It will work in this case because it's a simple rewrite pattern, but it would be better to find a workaround.

2423

Nit: I got confused by "creating a new block" that refers to a compound statement in the C++ code rather than to the IR block. Using compound statements with RAII is a pretty common idiom in MLIR that doesn't need a special comment.

2441–2443

Again, there is no guarantee that the argument types for before and after blocks match. The pattern didn't check it either, so it feels like this may end up creating invalid IR.

2450–2451

We should rather do rewriter.replaceOp(op, ...).

2452–2453

This will leave some of the vector entries default-initialized (to nullptr, I suppose). I am surprised that mergeBlocks below is fine with that.

2460–2461

The code should rather populate newBefore/AfterBlockArgs properly and avoid doing RAUW above.

This revision now requires changes to proceed. Jan 10 2022, 3:04 AM
mehdi_amini added inline comments. Jan 10 2022, 12:18 PM
mlir/lib/Dialect/SCF/SCF.cpp
2387

It's interesting: I implemented this canonicalization on MHLO WhileOp a couple of weeks ago, and it was very trivial (50 lines: https://github.com/tensorflow/tensorflow/commit/710788431c91b23b5aef8fbfcac1f964c724c9b0#diff-1b007bb7b9fd6d952764a854d28c0e7386595284233dfcd35dc897e0217fffaeR5547-R5594 ) because the condition only takes a single i1.

2406–2408

I also didn't directly see how to avoid this in MHLO, but is it a problem for canonicalization?

rriddle added inline comments. Jan 10 2022, 12:21 PM
mlir/lib/Dialect/SCF/SCF.cpp
2406–2408

but is it a problem for canonicalization

Yes. Any mutation in a pattern outside of the rewriter is essentially undefined behavior.

mehdi_amini added inline comments. Jan 10 2022, 12:23 PM
mlir/lib/Dialect/SCF/SCF.cpp
2406–2408

Can you remind me why this is a problem outside of dialect conversion?

2406–2408

(also it'd be great to be able to assert on this)

rriddle added inline comments. Jan 10 2022, 12:27 PM
mlir/lib/Dialect/SCF/SCF.cpp
2406–2408

a) We often use canonicalization patterns in pattern drivers other than the greedy rewriter.
b) We've had quite a few crashes in the past with mutations not being updated in the internal maps of the greedy rewriter (often nowadays it doesn't create a crash/just a bit of a different work order).

Those are the two major ones, but the mentality of "it works" is really dangerous for this because it isn't guaranteed to work. We could update the driver in some seemingly legal way (which we have before), and end up crashing existing code. It also breaks composability, because now patterns can only be used in very specific use cases/drivers. None of those are desirable to me, and I'd rather expand on the PatternRewriter API. All or nothing for these types of API contracts is the only way to prevent confusion/headaches/crashes down the line.

avarmapml updated this revision to Diff 398913. Jan 11 2022, 4:57 AM
  • Added two patterns to take care of the various cases that may arise.
  • No longer uses RAUW; relies on the rewrite driver instead.
avarmapml marked 9 inline comments as done. Jan 11 2022, 5:05 AM

Hi @ftynse
Thank you for your input. I had clearly misunderstood the scf.while construct earlier.
I've tried to address this in the new update.
I've added two patterns (each demonstrated with test cases):

  1. Pattern 1: It aims to figure out which of the before block's arguments are loop invariant.
  2. Pattern 2: It aims to remove those values from scf.condition which are defined outside the block.

Together, these two patterns aim to handle the non-trivial case you listed.

This brings up another opportunity: a variant of Pattern 2 that would do the same analysis for scf.yield.

I believe this would then cover many possible scenarios.

I added two separate patterns because it made sense to separate the two logically, which in turn makes them extensible.

You may re-review now.

avarmapml marked 6 inline comments as done. Jan 11 2022, 5:06 AM
mehdi_amini added inline comments. Jan 11 2022, 11:09 AM
mlir/lib/Dialect/SCF/SCF.cpp
2361

Isn't arg1 intended to be in first position here?
Or maybe the example is missing an %arg2 = %a in the iter_args?

I think it'd help if you named the variables differently and matched the names of the "after" block arguments with the names of the condition operands.

2422

I'm not sure why condOpIndices is needed here. Can't you just fuse this loop with the previous one?

(also in general I prefer to stick with signed int when possible)

2518
2533

afterBlockArgs is unused in this loop.

2545–2546

We should be able to rewrite the while in-place.

I chatted with River offline yesterday, if we're missing APIs on the rewriter, then let's take revisions such as this one as opportunities to add the missing APIs, even if their implementation is "unsafe" (don't allow roll-back). At least the patterns will be written in an optimal way.

ftynse requested changes to this revision. Jan 12 2022, 1:12 AM

Thanks! There seems to be a simpler way to express the same transformation.

mlir/lib/Dialect/SCF/SCF.cpp
2361

Isn't arg1 intended to be in first position here?

%arg3_after is always %arg0_before. This way, %arg0_before is always %a. So this looks okay to me.

%arg1_after would also work though, which deserves a comment.

Renaming would definitely help.

2403

Nit: drop the number of stack elements from SmallVector unless you have a strong reason why a specific number is used.

2406–2408

Re how to avoid RAUW: if the value being replaced is a block argument, create a new block without this argument, then merge the old block into the new block and provide the replacement value to the merge call. This is messy and should be better supported by the rewriter.
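To make the suggested workaround concrete, here is a minimal sketch on a hypothetical toy IR, assuming values are plain strings; `ToyOp`, `ToyBlock` and `mergeBlockWithReplacements` are illustrative names, not MLIR APIs. It shows the key idea only: instead of RAUW-ing the old block argument, the old block's operations are moved into a new block, and every use of an old argument is rewritten to a caller-supplied replacement value, in the spirit of passing replacement values to the merge call.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR (not MLIR): values are names, an op has a name and operands.
struct ToyOp {
  std::string name;
  std::vector<std::string> operands;
};

struct ToyBlock {
  std::vector<std::string> args;
  std::vector<ToyOp> ops;
};

// Moves all ops of `source` to the end of `dest`, rewriting each use of the
// i-th argument of `source` to `argValues[i]`. `source` is left empty.
void mergeBlockWithReplacements(ToyBlock &source, ToyBlock &dest,
                                const std::vector<std::string> &argValues) {
  assert(argValues.size() == source.args.size() &&
         "need one replacement value per source block argument");
  std::map<std::string, std::string> replacement;
  for (std::size_t i = 0; i < source.args.size(); ++i)
    replacement[source.args[i]] = argValues[i];

  for (ToyOp &op : source.ops) {
    // Rewrite uses of old block arguments as the op is moved, instead of RAUW.
    for (std::string &operand : op.operands) {
      auto it = replacement.find(operand);
      if (it != replacement.end())
        operand = it->second;
    }
    dest.ops.push_back(std::move(op));
  }
  source.ops.clear();
  source.args.clear();
}
```

For example, to drop the second argument of a block with arguments `%arg0, %arg1` whose value is known to be `%init1`, one would create a new block with a single argument `%new0` and merge the old block into it with the replacements `%new0, %init1`.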

2422

+1.

I would just walk the use-def chain backwards; this is the natural way to express such things. We know that the i-th before argument comes from either the i-th iter_arg or the i-th yield operand. So take the i-th yield operand and see if it is either equal to the i-th iter_arg directly or if it is some k-th argument of the "after" block. In the former case, we have already proven that the i-th iter_arg is loop invariant. In the latter case, since we know that the k-th argument of the "after" block comes from the (k+1)-th operand of the condition, we can further check if the (k+1)-th operand of the condition is equal to the i-th argument of the "before" block. If so, the i-th iter_arg is also invariant.

The only iteration we need is over the operands of yield. I would also be tempted to express this as a pattern that removes one argument at a time, since the greedy rewriter will keep applying it as long as it matches, but that is more expensive and there is a limit on the number of repeated applications.
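The backward use-def walk described above can be sketched on a hypothetical toy model of scf.while; `Kind`, `ValueRef`, `ToyWhile` and `findInvariantBeforeArgs` are illustrative names, not MLIR APIs and not the code in this revision. It assumes each value is identified by its origin (the i-th iter_args init value, an argument of the "before" or "after" block, or anything else) and that the condition operand list excludes the leading i1, so `condArgs[k]` is the (k+1)-th condition operand. Only the yield operands are iterated.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy model of scf.while values; not MLIR's API.
enum class Kind {
  Init,      // the i-th iter_args initial value (defined above the loop)
  BeforeArg, // i-th argument of the "before" block
  AfterArg,  // k-th argument of the "after" block
  Other      // any other value (index is an opaque id)
};

struct ValueRef {
  Kind kind;
  int index;
  bool operator==(const ValueRef &other) const {
    return kind == other.kind && index == other.index;
  }
};

struct ToyWhile {
  int numIterArgs;                 // == number of "before" block arguments
  std::vector<ValueRef> condArgs;  // scf.condition operands, excluding the i1
  std::vector<ValueRef> yieldArgs; // scf.yield operands, one per iter_arg
};

// invariant[i] is true if the i-th "before" argument always equals the i-th
// init value, using only the two direct checks described in the comment.
std::vector<bool> findInvariantBeforeArgs(const ToyWhile &loop) {
  assert(static_cast<int>(loop.yieldArgs.size()) == loop.numIterArgs);
  std::vector<bool> invariant(loop.numIterArgs, false);
  for (int i = 0; i < loop.numIterArgs; ++i) {
    const ValueRef &yielded = loop.yieldArgs[i];
    // Case 1: the i-th init value is yielded straight back.
    if (yielded == ValueRef{Kind::Init, i}) {
      invariant[i] = true;
      continue;
    }
    // Case 2: the k-th "after" argument is yielded; it is fed by condArgs[k]
    // (the (k+1)-th condition operand), so check that it is "before" arg i.
    if (yielded.kind == Kind::AfterArg) {
      int k = yielded.index;
      if (k < static_cast<int>(loop.condArgs.size()) &&
          loop.condArgs[k] == ValueRef{Kind::BeforeArg, i})
        invariant[i] = true;
    }
  }
  return invariant;
}
```

Encoding the two examples from the first review round: for the loop that swaps its arguments twice (`scf.condition %cond, %arg1, %arg0` and `scf.yield %arg3, %arg2`), both before arguments come out invariant; for the second loop, only index 1 (`%arg1`) does.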

2545–2546

Unless I am missing something, we can't rewrite in-place if the number of returned values changes (it does here).

This revision now requires changes to proceed. Jan 12 2022, 1:12 AM
bondhugula added a comment. Edited Jan 12 2022, 7:15 AM

I think it's a big missing piece that there isn't a way to replace uses of a block argument in a rewrite pattern other than by RAUW. There are many rewrite patterns in the tree that use RAUW because there isn't another way -- the approach is completely safe now in the greedy rewrite driver as it would only potentially cause the rewrite driver to pick up the updated operations in its next outer loop iteration instead of the same one -- so there is less progress and slower convergence FWIW. (One would be in trouble however if a populate... method is exposed and the pattern gets added in something that's not the pattern rewrite driver.)

if we're missing APIs on the rewriter, then let's take revisions such as this one as opportunities to add the missing APIs, even if their implementation is "unsafe" (don't allow roll-back).
At least the patterns will be written in an optimal way.

I couldn't parse your last clause. Is there a typo?

mehdi_amini added inline comments. Jan 12 2022, 12:13 PM
mlir/lib/Dialect/SCF/SCF.cpp
2545–2546

Right: I think I was thinking about the first pattern above

if we're missing APIs on the rewriter, then let's take revisions such as this one as opportunities to add the missing APIs, even if their implementation is "unsafe" (don't allow roll-back).
At least the patterns will be written in an optimal way.

I couldn't parse your last clause. Is there a typo?

I meant that we should write patterns using only rewrite APIs, but without sacrificing the performance of the rewrite: so as much "in-place" as possible. When we're limited by the rewriter API surface we likely should add more instead of writing patterns in a convoluted way to fit the rewriter limitations.

if we're missing APIs on the rewriter, then let's take revisions such as this one as opportunities to add the missing APIs, even if their implementation is "unsafe" (don't allow roll-back).
At least the patterns will be written in an optimal way.

I couldn't parse your last clause. Is there a typo?

I meant that we should write patterns using only rewrite APIs, but without sacrificing the performance of the rewrite: so as much "in-place" as possible. When we're limited by the rewriter API surface we likely should add more instead of writing patterns in a convoluted way to fit the rewriter limitations.

Right - I agree.

avarmapml marked 8 inline comments as done.

Addressed review comments.
The major change is in the way loop invariants are determined.
This update traverses each scf.yield operand and traces back to determine whether the corresponding before block argument should be removed.

avarmapml marked 2 inline comments as done. Jan 19 2022, 1:27 AM
ftynse accepted this revision. Jan 20 2022, 3:07 AM

LGTM with comments addressed. Thanks for iterating!

mlir/lib/Dialect/SCF/SCF.cpp
2431

You do like these loops... There is rarely a need to iterate over a list of block arguments or op results to find the right value. Here, this does the same without loops.

auto yieldOpBlockArg = yieldOpArg.dyn_cast<BlockArgument>();
if (yieldOpBlockArg && yieldOpBlockArg.getOwner() == &afterBlock) {
  Value condOpArg = condOpArgs[yieldOpBlockArg.getArgNumber()];
  // ...
}
2464

Nit: I suppose you could do something like ValueRange(newYieldOpArgs).getTypes() and avoid having an extra vector for newBeforeBlockType, but it's not a big deal.

2468–2475

Nit: you can write directly to newBeforeBlockArgs[i] instead of creating a variable

mlir/test/Dialect/SCF/canonicalize.mlir
869–871

Why isn't the order deterministic here?

872–882

I'd just use plain CHECK: we still see that the transformation is correct even if the operations are not on immediately consecutive lines (or if the syntax changes to span multiple lines). This makes tests less brittle.

877

Drop the comment here (there is actually a patch that removes it, which would break this test). We shouldn't test for irrelevant things.
Similarly, don't match for a specific block name, just ^{{.*}}( should be fine.

This revision is now accepted and ready to land. Jan 20 2022, 3:07 AM
mehdi_amini added inline comments. Jan 20 2022, 4:33 PM
mlir/lib/Dialect/SCF/SCF.cpp
2375

When does it make sense to keep implicitly captured values (%a and %b here) as operands of scf.condition?

2467

This loop isn't trivial and likely deserves some documentation.

mlir/test/Dialect/SCF/canonicalize.mlir
913

You don't need to check for things like that (types); they are already covered by the verifier.

In general, try to check for the strict minimum.

avarmapml marked 9 inline comments as done.

Addressed review comments.
The two major changes to note are:

  1. The code now uses the argument number of the yield operand's block argument instead of resorting to loops.
  2. The test cases' CHECK lines were made more compact by no longer matching types.
mlir/lib/Dialect/SCF/SCF.cpp
2375

You're correct, but this pattern only removes the invariant arguments from the before block and replaces their uses within it.

That is where the second pattern in this PR (RemoveLoopInvariantValueYielded) comes into play: it eliminates such scf.condition operands and the corresponding after block arguments.
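The idea behind the second pattern, as described in this reply, can be sketched on a hypothetical toy model; `CondOperand`, `CondRewritePlan` and `planCondOperandRemoval` are illustrative names, not the code in this revision. It assumes that an scf.condition operand defined outside the scf.while (such as the implicitly captured `%a` and `%b` discussed here) is forwarded unchanged both to the "after" block argument and to the scf.while result at the same position, so both can be replaced by that value and the operand dropped. It also returns early, before building any vectors, when there is nothing to remove.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical toy model: an scf.condition operand (excluding the i1) is a
// named value plus a flag saying whether it is defined above the scf.while.
struct CondOperand {
  std::string name;
  bool definedOutsideLoop;
};

struct CondRewritePlan {
  // Operands the new scf.condition keeps, in their original order.
  std::vector<CondOperand> keptOperands;
  // One entry per original position k: the value that replaces both the k-th
  // "after" block argument and the k-th scf.while result, or std::nullopt if
  // position k is kept.
  std::vector<std::optional<std::string>> replacements;
};

// Returns std::nullopt when no operand is defined outside the loop, so the
// do-nothing path exits before any vector is built.
std::optional<CondRewritePlan>
planCondOperandRemoval(const std::vector<CondOperand> &operands) {
  bool anyOutside = false;
  for (const CondOperand &operand : operands)
    anyOutside = anyOutside || operand.definedOutsideLoop;
  if (!anyOutside)
    return std::nullopt;

  CondRewritePlan plan;
  for (const CondOperand &operand : operands) {
    if (operand.definedOutsideLoop) {
      // Same value on every iteration: replace its uses directly.
      plan.replacements.push_back(operand.name);
    } else {
      plan.replacements.push_back(std::nullopt);
      plan.keptOperands.push_back(operand);
    }
  }
  return plan;
}
```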

2431

Haha.
This was really useful. It never crossed my mind to check for such available APIs.
Thank you!

2464

This was helpful too - thanks.

mlir/test/Dialect/SCF/canonicalize.mlir
869–871

They're indeed deterministic. I'm using CHECK now for all the lines as you mentioned in another comment that it'd make the test cases less brittle.

913

I understand. I've made the changes for other test lines too. You may check.

Hi @mehdi_amini @ftynse
The review comments were useful and I've made the said updates in the revision.
You may re-review/accept/land this revision.

Hi @ftynse @mehdi_amini @bondhugula,

I've addressed all the useful review comments made so far.

If there are no other review comments, can this revision be landed?

mehdi_amini added inline comments. Jan 24 2022, 8:09 PM
mlir/lib/Dialect/SCF/SCF.cpp
2375

It seemed to me that we could get the final form "for free" in the first pattern, instead of doing "half the work" and having to redo more expensive work afterward?

2449

I'd like that we first check if there is anything to do before starting to build vectors/maps: the common path (do nothing) should be as fast as possible.

avarmapml updated this revision to Diff 403880. Jan 27 2022, 9:35 PM
avarmapml marked 2 inline comments as done.

Moved the early exit as early as possible, before starting to build vectors/maps.

avarmapml added inline comments. Jan 27 2022, 9:36 PM
mlir/lib/Dialect/SCF/SCF.cpp
2375

The idea was to logically separate the task into two kinds of patterns that we'd like to canonicalize:

  1. Identify and remove loop invariant before block arguments.
  2. Identify and remove loop invariant return operands which in turn affects after block arguments.

Early exit from both the patterns would be beneficial in this case as you rightly pointed out.

Good enough for me right now! I'll let Alex have another look.

mehdi_amini added inline comments. Jan 28 2022, 12:00 AM
mlir/lib/Dialect/SCF/SCF.cpp
2400–2406

Nit: move all the declarations closer to their uses, that is after the early exit line 2422 (for those that aren't necessary before)

2418

Nit: no else after break.

avarmapml updated this revision to Diff 403959. Jan 28 2022, 3:59 AM
avarmapml marked 2 inline comments as done.

Handled review comments.

Hi @mehdi_amini @ftynse @bondhugula

If there are no more comments, can this revision be landed soon?

Please rebase on HEAD and reupload. It looks like there is a non-trivial assertion failure.

avarmapml updated this revision to Diff 404847. Feb 1 2022, 2:10 AM

Rebased to latest HEAD.
The createBlock API requires argument locations too, which led to the assertion failure.
Fixed that and moved a couple more declarations closer to their uses.

Hi @ftynse

I've rebased the current revision on HEAD and resolved the assert failures.

This revision was automatically updated to reflect the committed changes.