This is an archive of the discontinued LLVM Phabricator instance.

[mlir][memref] Make result normalization aware of the number of symbols
ClosedPublic

Authored by Lewuathe on May 9 2023, 11:58 PM.

Details

Summary

Memref normalization fails to account for the non-zero number of symbols that the memref type itself uses when its layout carries strided and offset information. This causes a crash on types such as memref<128x512xf32, strided<[?, ?], offset: ?>>. The original issue is here: https://github.com/llvm/llvm-project/issues/61345
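For readers less familiar with strided layouts: as far as I understand MLIR, each dynamic (?) entry of a strided layout, whether a stride or the offset, becomes a symbol of the equivalent affine layout map, while static entries become constants. So the type above implies three symbols on its own. The C++ sketch below models only that count; StridedLayoutModel and countLayoutSymbols are hypothetical names for illustration, not MLIR APIs.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical, simplified model of a strided layout such as
// strided<[?, ?], offset: ?>. std::nullopt stands for a dynamic value ("?").
struct StridedLayoutModel {
  std::vector<std::optional<std::int64_t>> strides;
  std::optional<std::int64_t> offset;
};

// Count the symbols the equivalent affine layout map would need (assumption:
// one symbol per dynamic offset or stride; static values become constants).
unsigned countLayoutSymbols(const StridedLayoutModel &layout) {
  unsigned numSymbols = layout.offset.has_value() ? 0 : 1;
  for (const auto &stride : layout.strides)
    if (!stride.has_value())
      ++numSymbols;
  return numSymbols;
}
```

With both strides and the offset dynamic, as in the type from the issue, the model yields 3; with strides [512, 1] and a static offset it yields 0, which is the case where a missing symbol count goes unnoticed.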

Diff Detail

Event Timeline

Lewuathe created this revision. May 9 2023, 11:58 PM
Herald added a project: Restricted Project. May 9 2023, 11:58 PM
Lewuathe requested review of this revision. May 9 2023, 11:58 PM
Lewuathe updated this revision to Diff 520920. May 9 2023, 11:58 PM

Apply format.

Lewuathe updated this revision to Diff 520921. May 10 2023, 12:01 AM

Check the normalization after a cast operation.

@aartbik @rriddle @bondhugula Sorry to bother you, but could you take a look at this one when you get a chance?

@nicolasvasilache @springerm @dcaballe Sorry to bother you from time to time, but could you review this one when you get a chance?

springerm added a comment. Edited Jun 12 2023, 11:35 PM

Sorry, I'm not familiar with this part of the codebase. If we can just take the number of symbols from the memref type, why does this function even have the numSymbolicOperands parameter?

@springerm I'm not 100% sure, but normalizeMemRefType does not always take the number of symbols from the memref type itself. In the following case, the number of symbols is provided from the allocOp, so it is the caller that decides where the count comes from.

MemRefType newMemRefType =
    normalizeMemRefType(memrefType, allocOp->getSymbolOperands().size());

https://github.com/llvm/llvm-project/blob/d8562e27e05b90d8957d20444c724293ddf1ba0c/mlir/lib/Dialect/Affine/Utils/Utils.cpp#L1724
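To illustrate why a caller-supplied count is fragile, here is a hypothetical C++ model (not MLIR code) of a fully dynamic 2-D strided layout map, assumed to be (d0, d1)[s0, s1, s2] -> s0 + d0 * s1 + d1 * s2 with s0 as the offset. An allocOp that binds all three symbols supplies a matching count, but a memref produced by a cast has no symbol operands of its own, so a count taken from the caller can disagree with what the type needs. The linearize helper and kLayoutSymbols constant are assumed names for this sketch only.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model of the layout map (d0, d1)[s0, s1, s2] -> s0 + d0 * s1 + d1 * s2:
// s0 is the offset, s1 and s2 are the strides, so the map needs 3 symbol values.
constexpr std::size_t kLayoutSymbols = 3;

// Linearize the index (d0, d1) using symbol values supplied by the caller
// (standing in for e.g. an allocOp's symbol operands). If the supplied count
// disagrees with the map, this model returns std::nullopt instead of reading
// symbols that were never provided.
std::optional<std::int64_t> linearize(std::int64_t d0, std::int64_t d1,
                                      const std::vector<std::int64_t> &symbols) {
  if (symbols.size() != kLayoutSymbols)
    return std::nullopt;
  return symbols[0] + d0 * symbols[1] + d1 * symbols[2];
}
```

For example, linearize(1, 2, {10, 512, 1}) gives 10 + 512 + 2 = 524, while linearize(1, 2, {}) (no symbol operands, as after a cast) is rejected. Deriving the count from the type's own layout removes this class of mismatch.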

Lewuathe added a comment. Edited Jun 14 2023, 9:25 PM

Can someone take a look at this one?

Looks like this issue is still present.
https://discourse.llvm.org/t/assertion-failed-when-lowering-simple-tensor-dilect-example/71336/3

Could someone review this patch?

It would be nice to get an answer about the intent behind the numSymbolicOperands parameter.

I added assert(allocOp->getSymbolOperands().size() == memrefType.getLayout().getAffineMap().getNumSymbols()); at that call site and did not hit any test failures.

The function was moved from the transform dialect by this patch.

The original patch to create the function is here.

https://github.com/llvm/llvm-project/commit/76d07503f0c69f6632e6d8d4736e2a4cb4055a92

There is still no clue about the rationale behind the function interface.

https://reviews.llvm.org/D84490

@avarmapml @bondhugula Do you remember why the normalizeMemRefType function takes a second argument for the number of symbols when that information can be extracted from the given memrefType?

If there is no specific reason, I think we can remove the numSymbolicOperands parameter.

Lewuathe updated this revision to Diff 535248. Jun 27 2023, 11:48 PM

Omit the number of symbol operands from normalizeMemRefType.

mehdi_amini accepted this revision. Jun 28 2023, 8:18 AM

LGTM, since we didn't find a justification to keep it. It can always be reintroduced if someone needs this.

This revision is now accepted and ready to land. Jun 28 2023, 8:18 AM

@mehdi_amini Thank you so much for reviewing!