This is an archive of the discontinued LLVM Phabricator instance.

[mlir][MemRef] Simplify extract_strided_metadata(expand_shape)
ClosedPublic

Authored by qcolombet on Sep 9 2022, 5:43 PM.

Details

Summary

Add a pattern to the pass that simplifies extract_strided_metadata(other_op(memref)).

The new pattern gets rid of the expand_shape operation while materializing its effects on the sizes and the strides of the base object.

In other words, this simplification replaces:

baseBuffer, offset, sizes, strides =
    extract_strided_metadata(expand_shape(memref))

With

baseBuffer, baseOffset, baseSizes, baseStrides =
    extract_strided_metadata(memref)
sizes#reassIdx =
    baseSizes#reassDim / product(expandShapeSizes#j,
                                 for j in group excluding reassIdx)
strides#reassIdx =
    baseStrides#reassDim * product(expandShapeSizes#j,
                                   for j in reassIdx+1..
                                            reassIdx+group.size-1)

Where reassIdx is a reassociation index for the group at reassDim and expandShapeSizes#j is either:

  • The constant size at dimension j, derived directly from the result type of the expand_shape op, or
  • An affine expression: baseSizes#reassDim / product of all the constant sizes in expandShapeSizes (there is at most one dynamic size per group).
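
As a concrete illustration (a hypothetical static case, not taken from the patch's test file), the simplification would rewrite:

%exp = memref.expand_shape %arg [[0, 1], [2]]
    : memref<6x4xf32> into memref<2x3x4xf32>
%base, %offset, %sizes:3, %strides:3 =
    memref.extract_strided_metadata %exp : memref<2x3x4xf32>
    -> memref<f32>, index, index, index, index, index, index, index

into something along the lines of:

%base, %offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %arg : memref<6x4xf32>
    -> memref<f32>, index, index, index, index, index
// sizes of the expanded view:   [6 / 3, 3, 4] = [2, 3, 4]
// strides of the expanded view: [4 * 3, 4, 1] = [12, 4, 1]
// (materialized as constants / affine.apply by the pattern; the base buffer
// and the offset are unchanged)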

Diff Detail

Event Timeline

qcolombet created this revision.Sep 9 2022, 5:43 PM
Herald added a project: Restricted Project. Sep 9 2022, 5:43 PM
qcolombet requested review of this revision.Sep 9 2022, 5:43 PM
chelini added inline comments.Sep 12 2022, 12:55 AM
mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
316
369

repeated the

chelini accepted this revision.Sep 12 2022, 2:24 AM
chelini added inline comments.
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
211
231

I would mention explicitly that this pattern does not affect base buffer and offset, as mentioned in the test cases.

This revision is now accepted and ready to land.Sep 12 2022, 2:24 AM

Thanks for the quick review @chelini !

mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
231

Good point!
It is not that clear in the high-level comment of that pattern/commit message either, since I changed the name of the offset value (from offset to baseOffset).
I'll fix that and also call it out.

mlir/test/Dialect/MemRef/simplify-extract-strided-metadata.mlir
316

Good catch x)

398

BTW something I haven't done yet is making sure this gets simplified into baseSizes#1 / 10 since at this point we know that baseSizes#1 is divisible by both 10 and 3.
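
(Hypothetical illustration of the missed simplification, with made-up factors rather than the ones from the test: an expression of the first form below is what gets produced, and the second form is only valid because of the divisibility knowledge that plain affine simplification cannot assume.)

// what the pattern would emit for a value in that group
%0 = affine.apply affine_map<()[s0] -> ((s0 floordiv 30) * 3)>()[%baseSizes#1]
// what it could be simplified to, since s0 is known to be a multiple of 30:
// %0 = affine.apply affine_map<()[s0] -> (s0 floordiv 10)>()[%baseSizes#1]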

qcolombet updated this revision to Diff 459933.Sep 13 2022, 5:44 PM
  • Fix typos
  • Add a few comments
  • Make sure the constant propagation happens for dynSize.floorDiv(sizes) * stride, since we know that dynSize is a multiple of sizes
qcolombet marked 3 inline comments as done.Sep 13 2022, 5:56 PM
qcolombet added inline comments.
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
211

I'm actually with you @chelini on that one, but following the LLVM coding standards, I think this is one of the mandatory uses of auto. (I.e., clang-tidy, I think, would replace the type with auto here.)

More specifically this sentence "[...]Don’t “almost always” use auto, but do use auto with initializers like cast<Foo>(...)[...]" in the coding guide.
https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable

I'm happy to change it though.
Just let me know.

qcolombet updated this revision to Diff 459954.Sep 13 2022, 7:03 PM
  • Add the handling of 0-D
  • Add a test case for 0-D
qcolombet updated this revision to Diff 459955.Sep 13 2022, 7:05 PM

Fix typo in comment

mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
194

nit: affine.apply

268

Can anything be refactored and reused from MemRefOps.cpp::namespace saturated_arith { ?

qcolombet added inline comments.Sep 14 2022, 9:28 AM
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
268

Thanks for the pointer, I didn't know about these utilities.

At first glance, it doesn't look like it would fit here.

In particular, here we don't saturate when we see a dynamic stride/size, we replace it by 1 and patch up the expression with the proper dynamic value later.

The suffix product that we do here is similar to what computeExpandedLayoutMap does, except in our case we don't want to saturate.

I'm guessing the suffix product is used in a lot of places given that's how strides are computed, maybe we can refactor on that front. Though how we deal with dynamic dimensions is where things diverge and I don't know how to reconcile that at the moment.

Put differently, I don't know yet what the refactoring should look like.

Do you want me to spend more time looking into this or should we move on for now and wait for more use cases to show up?

mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
268

I was hoping some refactoring into simple utils would be possible because the code below feels like it's been rewritten a bunch of times in different ways.

For instance, I find the logic of the hasDynamicSize + productOfKnownStaticSize to be quite tricky to follow with further uses of hasDynamicSize.
I would hope that we can express this much more like functional-style applications without all these loops and levels of nesting.

For instance, l269 - 300 should look like:

SmallVector<int64_t> suffixProductOfSizes = 
  suffixProduct(expandShapeType.getShape(), begin, end, [](int64_t val, int64_t acc){ return saturated_arith(...); })

I made similar usability comments in https://reviews.llvm.org/D128986 but not exactly the same.

TL;DR I'd really prefer we avoid complex loopy logic for things that can be functional-style and should be reusable.

360

I'm sorry, this is really hard to follow and I doubt this code will be maintainable.

qcolombet added inline comments.Sep 14 2022, 11:23 AM
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
268

Okay, let me take a closer look.

Again, unless I missed something, the saturated_arith constructs are not good in that case because IIUC it saturates on the first dynamic value, whereas I need to get the full product without the dynamic dimension.

E.g., for 2x?x3, I need 6, not saturated, because this number will let us compute the actual value of ? from the original dimension for that group.

I feel that saturated_arith is well suited to construct the type of a shape (e.g., a stride is dynamic as soon as one of the dimensions is dynamic), but not so much to compute actual values.

360

What do you think would make the code more understandable / maintainable?

Essentially for every reassociation group we have at most one dynamic size.
What this code does is compute this dynamic size as: origSize / product(all the other sizes in that group).

E.g., for ?x4 -> [0,1,2][3] 2x?x3x4
This will compute the ? of the resulting shape as an affine.apply of ()[s0] -> (s0 / (2 * 3)), where s0 is the original dynamic dimension for that group.

The productOfKnownStaticSize holds the product of the static sizes for the related group (2 * 3 in the example), and dynSizeInput holds the original size for that group.

Maybe I could split the loop between the dimensions that come before the dynamic dimension and the ones that come after. That way I would not have to "propagate" a fake dynamic size expr to keep the final strideExpr stable between the dynamic and non-dynamic case, which should help readability.
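
As a rough IR sketch of that example (hypothetical SSA names, not copied from the test file):

%exp = memref.expand_shape %arg [[0, 1, 2], [3]]
    : memref<?x4xf32> into memref<2x?x3x4xf32>
// The simplification recovers the dynamic expanded size from the original
// dynamic size of that group:
%base, %offset, %sizes:2, %strides:2 =
    memref.extract_strided_metadata %arg : memref<?x4xf32>
    -> memref<f32>, index, index, index, index, index
%dynSize = affine.apply affine_map<()[s0] -> (s0 floordiv 6)>()[%sizes#0]
// sizes of the expanded view: [2, %dynSize, 3, 4]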

mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
268

Yes, I see that the saturated_arith itself is no use when you want to construct the AffineExpr, I just wanted to give you a feel of an API that avoids deep nesting and that will be easier to maintain longer term.

360
Essentially for every reassociation group we have at most one dynamic size.

No need to worry about this immediately but this will change (e.g. https://github.com/llvm/llvm-project/issues/51564)

Yes, less deep nesting would definitely be a good start, then we can iterate.

nicolasvasilache added a comment.EditedSep 15 2022, 12:00 AM

Surfacing a general comment that came up offline:

I think we can always consider the general case with all ? and let the folders do their job on OpFoldResult:

sizes<?>, strides<?> expand to sizes<?x..x?>, strides<?x..x?> by expansionFactors(?x...x?)

If we reduce to this, this is all about determining SmallVector<OpFoldResult> expansionFactors(...).
In the most general case (NYI), these will be SSA values in the op (see e.g. https://github.com/llvm/llvm-project/issues/51564).

But even in the absence of the general case, I believe we greatly simplify the problem by thinking in those terms.

I'd add a helper SmallVector<OpFoldResult> ExpandShapeOp::buildExpansionFactors(b, loc, groupId) to the op itself that constructs expansionFactors from the source memref and the result type (actually it may need to live in some utils because of Affine cyclic dependencies... sigh).

We'd likely want a counterpart for CollapseShape when we get there.

I think these suggestions will make the impl. both future-proof and readable.

Still I have not tried so I may be missing something.

  • Rework the code to make it more approachable:
    • Compute all the static sizes and strides, first
    • Adjust the strides impacted by the dynamic size, second

Next: Will look into the buildExpansionFactors idea.
I wanted to share this version since it was what was driving our offline conversation with Nicolas.

qcolombet marked an inline comment as done.Sep 15 2022, 10:31 AM
qcolombet marked an inline comment as not done.Sep 15 2022, 12:45 PM
qcolombet updated this revision to Diff 461427.Sep 19 2022, 5:25 PM
  • Move the sizes and strides computations in their own helper functions

The buildExpansionFactors idea didn't pan out as well as I would have hoped because the expressions are too different once you factor dynamic sizes and/or strides into the mix.

qcolombet marked an inline comment as done.Sep 19 2022, 5:26 PM
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
170

nit: typo

177

I would make this a first-class citizen of expand: SmallVector<OpFoldResult> memref::ExpandShapeOp::buildExpandedSizes(...).

Then as the op semantics evolve to take the result shape operands, we can update this.

178

ArrayRef ?

226

same for making this a first-class citizen of expand

227

ArrayRef ?

qcolombet added inline comments.Sep 21 2022, 1:06 PM
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
177

That's not possible because the MemRef dialect would have to depend on the Affine dialect if we were to do that.

qcolombet updated this revision to Diff 461997.Sep 21 2022, 1:33 PM
  • Use ArrayRef instead of SmallVectorImpl
  • Fix typo
qcolombet marked 2 inline comments as done.Sep 21 2022, 1:35 PM
qcolombet added inline comments.
mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
177

For the record, my plan was to move that into a common place (e.g., ReshapeOpsUtils), when we think it looks good.

178

Good point!

Thanks for splitting it up this way, let's go with this for now.

I will try to simplify some of the impl (at the tradeoff of adding a bit more IR that folds) as a separate PR, but this is a great step towards removing logic from LLVM while also expanding the cases we can support.

mlir/lib/Dialect/MemRef/Transforms/SimplifyExtractStridedMetadata.cpp
177

Ah yes, I keep on forgetting this unfortunate layering... could you please add a TODO?

chelini resigned from this revision.Sep 22 2022, 5:11 AM

Please go ahead with Nicolas. He has better ideas on how to bring this up.

qcolombet marked an inline comment as done.
  • Add a TODO for moving the utility functions to ExpandShapeOp
qcolombet marked 2 inline comments as done.Sep 22 2022, 10:59 AM
qcolombet added a subscriber: chelini.

Thanks again @chelini and @nicolasvasilache for the reviews!