This is an archive of the discontinued LLVM Phabricator instance.

[mlir] Revamp implementation of sub-byte load/store emulation.
ClosedPublic

Authored by mravishankar on Aug 16 2023, 3:23 PM.

Details

Summary

When handling sub-byte emulation, the sizes of the converted memrefs
also need to be updated (this was not done in the current
implementation). This adds the additional complexity of having to
linearize the memrefs as well. Consider a memref<3x3xi4> where the
i4 elements are packed. This has an overall size of 5 bytes (36 bits,
rounded up to a whole number of bytes). This can only be represented by a
memref<5xi8>. A memref<3x2xi8> would imply an implicit padding of
4 bits at the end of each row. So incorporate linearization into the
sub-byte load-store emulation.

This patch also updates some of the utility functions to make better
use of statically available information using OpFoldResult and
makeComposedFoldedAffineApplyOps.

Event Timeline

mravishankar created this revision. Aug 16 2023, 3:23 PM
Herald added a project: Restricted Project.
mravishankar requested review of this revision. Aug 16 2023, 3:23 PM

Add patterns to resolve extract_strided_metadata to the emulation patterns.

yzhang93 accepted this revision. Aug 16 2023, 4:40 PM

Thanks for refactoring the code and fixing the bug! Overall looks good to me.

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
242

Nit: Can you add some comments for the logic behind this?

This revision is now accepted and ready to land. Aug 16 2023, 4:40 PM
hanchung requested changes to this revision. Aug 16 2023, 5:19 PM
hanchung added inline comments.
mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
31–62

Can we just break it into two functions? Looking through the comment and usage, they look like two methods with the same interface to me. Breaking it into two functions and using them correctly helps readability a lot, and people won't be confused about why std::ignore is used; they don't have to go back to the comments. The method name should just tell them what they get.

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
231–233

mlir style nit: do not add braces for single statement.

236–238

ditto

265–267

nit: remove braces

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
56

hmm, should we pass the type to SmallVector? I'm surprised that it's working.

57–59

nit: remove braces

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
87–90

I don't follow this. I thought that we should do vector.load on a linearized base pointer (i.e., void*) with a linearized index?

This revision now requires changes to proceed. Aug 16 2023, 5:19 PM
mravishankar marked 5 inline comments as done. Aug 16 2023, 8:13 PM
mravishankar added inline comments.
mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
31–62

I think the std::ignore is unrelated. It's about whether the caller needs LinearizedMemRefInfo or linearizedIndices. Depending on the use case you need one or the other, and the logic to compute them is pretty much the same. So I'd rather not have a combinatorial explosion in the number of API entries, with different callers having to issue pretty much the same sequence of calls. There is no IR overhead since all of this is done through makeFoldedComposedAffineApplyOps.

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
56

I don't think we need to. Not sure I see the issue here. It's creating a new vector from indices, which is resized if it is empty.

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
87–90

Yes, it is using the linearizedIndices to do the vector.load. Not sure I follow the question.

Address comments.

hanchung accepted this revision. Aug 17 2023, 11:16 AM
hanchung added inline comments.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
64–68

style nit: use auto. dyn_cast and cast already spell out the type.

139–140

Don't we need a check to avoid an infinite loop? It will probably converge after applying the pattern 10 times, but I think we still want to bail out if it's already converted?

if (op.getMemRefType() == convertedType)
  return failure();
mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
87–90

I think I follow the logic now... My question was why it is adaptor.getBase(), and not something related to stridedMetadata.getBaseBuffer(). It looks like you assume that all the sources should be flattened to a 1D memref in the pass as well?

This revision is now accepted and ready to land. Aug 17 2023, 11:16 AM
mravishankar marked an inline comment as done. Aug 17 2023, 12:02 PM
mravishankar added inline comments.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
139–140

I don't think we need to do that. If the memRefType is already correct, then the dialect conversion framework will not see the op as illegal and will not even call the conversion pattern.

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
87–90

This is the type of the operand after the producer has been modified. The TypeConverter ensures that it is a linearized type (see the associated change in the type converter). So at this point the adaptor already has a linearized memref. The base buffer of the strided metadata does not have the offset of the memref included. We should be using base + offset + linearizedIndices. The adaptor.getBase() is already at base + offset.

Address minor comments, and add 0d memref test.

hanchung accepted this revision. Aug 17 2023, 12:54 PM
hanchung added inline comments.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
139–140

Good point, I missed the mechanism about TypeConversion. Thanks for the explanation!

This revision was landed with ongoing or failed builds. Aug 17 2023, 1:28 PM
This revision was automatically updated to reflect the committed changes.