This is an archive of the discontinued LLVM Phabricator instance.

[linalg][fusion] Disallow fusion when it would create an invalid expand_shape
ClosedPublic

Authored by bkramer on Jan 4 2022, 8:10 AM.

Details

Summary

The input type of a linalg.generic can be less dynamic than its output
type. If this is the case moving a reshape across the generic op would
create invalid IR, as expand_shape cannot expand arbitrary dynamic
dimensions.

Check that the reshape is actually valid before creating the
expand_shape. This exposes the existing verification logic in reshape
utils and removes the incomplete custom implementation in fusion.

Event Timeline

bkramer created this revision.Jan 4 2022, 8:10 AM
bkramer requested review of this revision.Jan 4 2022, 8:10 AM
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
654

Should this be a static helper on tensor::CollapseOp ?
It feels like it could be reused in a bunch of places.

mravishankar requested changes to this revision.Jan 4 2022, 9:54 AM
mravishankar added inline comments.
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
795

I am not able to follow why we need this. The source of the expand_shape should be the same as the original operand, so we should not need an additional cast here. I am also not able to follow the example below.

This revision now requires changes to proceed.Jan 4 2022, 9:54 AM
bkramer added inline comments.Jan 4 2022, 10:54 AM
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
654

ExpansionInfo is local to the fusion pass though.

795

The example fuses tensor.collapse_shape %arg3 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64> into the input of a generic op. The generic op also takes a tensor<?xi64> input. This means it has to expand the tensor<?xi64> into tensor<1x1xi64>, which can't be done in a single expand_shape.

This is the result of incomplete shape information. The original program would verify the runtime shape of the input to be tensor<1xi64> and then run the linalg ops.
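
The situation described above can be modelled outside of MLIR. The following is a minimal standalone sketch, not the helper from reshape utils: the names `kDynamic`, `Shape`, `Reassociation` and `canExpandShape` are hypothetical, and the rules encoded are an assumption based on this discussion (a group with one dynamic dimension needs a dynamic source dimension, a fully static group must multiply out to the static source dimension, and one source dimension cannot be split into several dynamic ones).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sentinel for a dynamic ('?') dimension; not MLIR's constant.
constexpr int64_t kDynamic = -1;

using Shape = std::vector<int64_t>;
// One group of expanded-dimension indices per collapsed dimension.
using Reassociation = std::vector<std::vector<int>>;

// Returns true if `collapsed` could be turned into `expanded` by a single
// expand_shape-like op using the grouping in `groups`.
bool canExpandShape(const Shape &collapsed, const Shape &expanded,
                    const Reassociation &groups) {
  if (groups.size() != collapsed.size())
    return false;
  for (std::size_t g = 0; g < groups.size(); ++g) {
    int numDynamic = 0;
    int64_t staticProduct = 1;
    for (int idx : groups[g]) {
      if (idx < 0 || static_cast<std::size_t>(idx) >= expanded.size())
        return false;
      if (expanded[idx] == kDynamic)
        ++numDynamic;
      else
        staticProduct *= expanded[idx];
    }
    // A single source dimension cannot be split into several unknown sizes.
    if (numDynamic > 1)
      return false;
    if (numDynamic == 1) {
      // A dynamic expanded dimension requires a dynamic source dimension.
      if (collapsed[g] != kDynamic)
        return false;
    } else if (collapsed[g] != staticProduct) {
      // Fully static group: the source must be exactly the static product.
      // A dynamic source (tensor<?xi64> into tensor<1x1xi64>) fails here.
      return false;
    }
  }
  return true;
}
```

With this model, `canExpandShape({kDynamic}, {1, 1}, {{0, 1}})` is rejected, while the same call with a collapsed shape of `{1}` is accepted, which matches the two operands of the generic op in the example.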

I'll reduce the test case a bit more to make it easier to understand.

bkramer updated this revision to Diff 397340.Jan 4 2022, 10:54 AM

Reduce test case more.

mravishankar added inline comments.Jan 4 2022, 1:22 PM
mlir/test/Dialect/Linalg/reshape_fusion.mlir
520

OK, the issue being fixed here is not really an issue. This op as written is incorrect (the Linalg verifier does not check that).
Based on the indexing maps used here, %0 and %arg1 have to be the same shape; otherwise the op is undefined. The transformation only looks at the indexing maps and therefore assumes that %0 and %arg1 have the same shape, so it is trying to insert a reshape operation.

bkramer added inline comments.Jan 4 2022, 1:27 PM
mlir/test/Dialect/Linalg/reshape_fusion.mlir
520

What does that mean? The shapes are the same at runtime.

mravishankar added inline comments.Jan 4 2022, 1:37 PM
mlir/test/Dialect/Linalg/reshape_fusion.mlir
520

Yeah, that's why it isn't checked as part of the Linalg verifier. Would it be possible to insert a cast for the input itself, i.e. make the op

%1 = tensor.cast %arg1 : tensor<?xi64> to tensor<1xi64>
%2 = linalg.generic
   {indexing_maps = [affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>,
                     affine_map<(d0) -> (d0)>],
    iterator_types = ["parallel"]}
   ins(%0, %1 : tensor<1xi64>, tensor<1xi64>) {
     ...
   } -> tensor<1xi64>

If it is known to always be true at runtime, then it would be preferable to add a cast so that all the shapes are consistent.
Though I see your point about the cast. I am just wary of adding casts ad hoc during transformations instead of adding them while creating the operation itself, where presumably you have more immediate information about the shapes being the same.

mravishankar requested changes to this revision.Jan 5 2022, 2:21 PM

(just marking this as needs revision to get it off my must-review list)

This revision now requires changes to proceed.Jan 5 2022, 2:21 PM
herhut added a subscriber: herhut.Jan 7 2022, 6:46 AM
herhut added inline comments.
mlir/test/Dialect/Linalg/reshape_fusion.mlir
520

So would the contract be that all operands to linalg.generic, or to linalg operations in general, need to have the same static shape annotated? If that is the contract, then we should add checking for it to linalg's verifier. Otherwise we will find these issues only when transformation patterns expect this but it is not true.

Generally, patterns can only rely on static properties that are verified during verification.

I don't see a particular semantic problem with any of these cases (fully static, partially static, fully dynamic): all are valid and are subject to the same UB as a wrong cast would be.
I think Mahesh's suggestion is better because there is no guarantee in shape reification that you'd return 1 and you may well end up with propagating dynamic information further down.
I wouldn't change the verifier behavior here as it could limit the ability to have partial foldings that I think are valid.

I think Mahesh's suggestion is better because there is no guarantee in shape reification that you'd return 1 and you may well end up with propagating dynamic information further down.

There could well be a separate pass that inserts these casts to avoid the behavior in shape reification if that is critical to your use cases. I think it would still be better to fix shape reification to return a static value where possible. After all, if this is valid linalg IR, reification should do as well as it can with it.

Or we declare linalg operations whose operands all carry the same static knowledge to be the canonical form and add cast insertion as a canonicalization pattern. Not ideal, because we would need to run the canonicalizer before fusion, but at least less surprising.

I wouldn't change the verifier behavior here as it could limit the ability to have partial foldings that I think are valid.

If the verifier allows this, then I think fusion should be able to handle it. We should be able to transform all valid IR, otherwise this will become frustrating to debug.

We had a deeper offline discussion with @bkramer.
There is a bug that needs to be fixed.
In the same way that we have an "isCastCompatible" check, we should have an "isShapeCollapse/Expand compatible" check.
In this case, the result of that check is false and the pattern/transformation should fail to apply.

Separately, there is an enabling pattern that would make the type more static by casting the inputs in accordance with the semantics of the generic op.
It is unclear whether this should be a blanket canonicalization or just an enabling pattern that needs to be applied selectively to make this case fuse.
This separate enabling pattern is more generally applicable and follows similar orthogonalization principles between legality and profitability that we have been applying everywhere.
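
As a rough illustration of what such an enabling step would compute, the sketch below joins the static shape information of operands that are accessed through the same indexing map, so that each operand could then be cast to the joined type. It is a standalone model under stated assumptions, not MLIR code: `kDynamic`, `Shape` and `joinStaticInfo` are hypothetical names, and it only covers operands that are required to have identical shapes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sentinel for a dynamic ('?') dimension; not MLIR's constant.
constexpr int64_t kDynamic = -1;

using Shape = std::vector<int64_t>;

// Joins the shapes of operands that must agree at runtime, keeping the most
// static extent known for each dimension. Returns false if the ranks differ
// or two operands disagree on a static extent, in which case no cast can
// make the types consistent. On success `joined` holds the refined shape.
bool joinStaticInfo(const std::vector<Shape> &operandShapes, Shape &joined) {
  if (operandShapes.empty())
    return false;
  Shape result = operandShapes.front();
  for (const Shape &shape : operandShapes) {
    if (shape.size() != result.size())
      return false;
    for (std::size_t d = 0; d < shape.size(); ++d) {
      if (shape[d] == kDynamic)
        continue;               // Nothing to learn from this operand.
      if (result[d] == kDynamic)
        result[d] = shape[d];   // Adopt the static extent.
      else if (result[d] != shape[d])
        return false;           // Conflicting static extents.
    }
  }
  joined = result;
  return true;
}
```

For the operands in this review, joining `{1}` and `{kDynamic}` yields `{1}`, which corresponds to casting the `tensor<?xi64>` input to `tensor<1xi64>` before fusion is attempted. Keeping this separate from the legality check means fusion can simply refuse to apply when the expand_shape would be invalid.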

bkramer updated this revision to Diff 400558.Jan 17 2022, 8:14 AM
  • Expose the shape verification logic (interface is still a bit ugly)
  • Use shared logic to stop fusion when the expand_shape would be invalid
  • Remove now-redundant checking in elementwise fusion

Can we make proper use of this new function also in the reshape op verifiers?
We should have only one logic used everywhere.
Thanks for pushing on this cleanup!

nicolasvasilache accepted this revision.Jan 17 2022, 8:37 AM

Ah, this already happens transitively. LGTM, thanks @bkramer!

Please update the commit message to reflect your latest changes, thanks!

bkramer retitled this revision from [linalg][fusion] Cast reshape inputs to a known good type to [linalg][fusion] Disallow fusion when it would create an invalid expand_shape.Jan 17 2022, 8:43 AM
bkramer edited the summary of this revision.
mravishankar accepted this revision.Jan 18 2022, 1:41 PM

This makes sense to me. Thanks!

This revision is now accepted and ready to land.Jan 18 2022, 1:41 PM