Memref normalization fails to account for the symbols introduced by the memref type's own layout when that layout carries strided/offset information. This causes a crash for types such as memref<128x512xf32, strided<[?, ?], offset: ?>>. The original issue is here: https://github.com/llvm/llvm-project/issues/61345
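For context, a strided layout with dynamic entries turns into an affine map that uses one symbol per dynamic value. The standalone model below is only an illustration: StridedLayoutModel and stridedToAffineMap are hypothetical names, not MLIR API. It assumes that the offset symbol comes first, followed by one fresh symbol per dynamic stride, which appears to match how MLIR linearizes strided layouts. Under that assumption the type from the issue carries three layout symbols, so a caller that reports zero symbolic operands disagrees with the type itself.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical, simplified model of a `strided<[...], offset: ...>` layout.
// std::nullopt stands for a dynamic value ('?'). This is NOT the MLIR API.
struct StridedLayoutModel {
  std::vector<std::optional<int64_t>> strides;
  std::optional<int64_t> offset;
};

// Builds a textual affine map for the layout: a dynamic offset becomes s0,
// and every dynamic stride gets the next fresh symbol. `numSymbols` receives
// the number of symbols the map uses. The expression is not simplified, so
// it will not match MLIR's printed form for static entries.
std::string stridedToAffineMap(const StridedLayoutModel &layout,
                               unsigned &numSymbols) {
  numSymbols = 0;
  std::string expr = layout.offset ? std::to_string(*layout.offset)
                                   : "s" + std::to_string(numSymbols++);
  std::string dims;
  for (size_t d = 0; d < layout.strides.size(); ++d) {
    const std::string dim = "d" + std::to_string(d);
    dims += (d == 0 ? "" : ", ") + dim;
    const std::optional<int64_t> &stride = layout.strides[d];
    const std::string mult = stride ? std::to_string(*stride)
                                    : "s" + std::to_string(numSymbols++);
    expr += " + " + dim + " * " + mult;
  }
  std::string syms;
  for (unsigned s = 0; s < numSymbols; ++s)
    syms += (s == 0 ? "s" : ", s") + std::to_string(s);
  return "(" + dims + ")[" + syms + "] -> (" + expr + ")";
}
```

For strided<[?, ?], offset: ?> the model yields (d0, d1)[s0, s1, s2] -> (s0 + d0 * s1 + d1 * s2), i.e. three symbols, while a fully static layout yields none.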
Repository: rG LLVM Github Monorepo
Event Timeline
@aartbik @rriddle @bondhugula Sorry to bother you, but could you take a look at this one when you get a chance?
@nicolasvasilache @springerm @dcaballe Sorry to keep bothering you, but could you review this one when you get a chance?
Sorry, I'm not familiar with this part of the codebase. If we can just take the number of symbols from the memref type, why does this function even have the numSymbolicOperands parameter?
@springerm I'm not 100% sure, but normalizeMemRefType does not always get the number of symbols from the memref type itself. In the following case, the count comes from the allocOp's symbol operands, so the interface lets the caller decide where the number of symbols comes from:
MemRefType newMemRefType = normalizeMemRefType(memrefType, allocOp->getSymbolOperands().size());
Can someone take a look at this one?
Looks like this issue is still present.
https://discourse.llvm.org/t/assertion-failed-when-lowering-simple-tensor-dilect-example/71336/3
@nicolasvasilache @springerm @dcaballe @aartbik @mehdi_amini @rriddle @bondhugula @antiagainst @stephenneuendorffer
Sorry for the repeated pings. Could someone review this patch when they have time?
It would be nice to get an answer about the intent behind this parameter.
I added assert(allocOp->getSymbolOperands().size() == memrefType.getLayout().getAffineMap().getNumSymbols()); at that call site and didn't hit any test failures.
The function was moved from the transform dialect by this patch.
The original patch that introduced the function is here:
https://github.com/llvm/llvm-project/commit/76d07503f0c69f6632e6d8d4736e2a4cb4055a92
It still gives no hint about why the function interface was designed this way.
https://reviews.llvm.org/D84490
@avarmapml @bondhugula Do you remember why the normalizeMemRefType function takes a second argument for the number of symbols, when we can extract that information from the given memrefType?
If there is no specific reason, I think we can drop the numSymbolicOperands parameter.
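To make the proposed interface change concrete, here is a standalone model. All names (LayoutMapModel, ConstraintsModel, normalizeWithCallerCount, normalizeFromType) are hypothetical, and it does not reproduce FlatAffineValueConstraints or the real normalizeMemRefType. It assumes that normalization builds a constraint system with one variable per memref dimension plus one per symbolic operand, then composes the layout map into it, so the two symbol counts must agree. The real code fails an assertion on a mismatch; the model simply returns false.

```cpp
#include <cassert>

// Hypothetical model: only the parts of a layout map that matter here.
struct LayoutMapModel {
  unsigned numDims;
  unsigned numSymbols;
};

// Hypothetical model of the index-space constraint system built during
// normalization. Composing the layout map only works if the counts agree.
struct ConstraintsModel {
  unsigned numDims;
  unsigned numSymbols;
  bool canCompose(const LayoutMapModel &map) const {
    return map.numDims == numDims && map.numSymbols == numSymbols;
  }
};

// Old shape of the interface: the caller supplies the symbol count, so a
// caller passing the wrong value (e.g. 0 for a layout with dynamic
// offset/strides) builds a system that does not match the map.
bool normalizeWithCallerCount(const LayoutMapModel &map, unsigned rank,
                              unsigned numSymbolicOperands) {
  ConstraintsModel cst{rank, numSymbolicOperands};
  return cst.canCompose(map);
}

// Proposed shape: derive the symbol count from the layout map itself, so
// the counts agree by construction.
bool normalizeFromType(const LayoutMapModel &map, unsigned rank) {
  ConstraintsModel cst{rank, map.numSymbols};
  return cst.canCompose(map);
}
```

Deriving the count from the type removes a parameter that, per the assertion experiment above, never differed from the layout's own symbol count in the existing tests.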
LGTM, since we didn't find a justification to keep it. It can always be reintroduced if someone needs this.