[mlir] Narrow bitwidth emulation for MemRef load
ClosedPublic

Authored by yzhang93 on May 25 2023, 4:41 PM.

Details

Summary

This patch adds support for narrow bitwidth storage emulation. The goal is to support sub-byte type
codegen for LLVM CPU. Specifically, a type converter is added to convert memrefs of narrow bitwidths
(e.g., i4) into supported wider bitwidths (e.g., i8). The other focus of this patch is to populate the
conversion pattern for i4 memref.load; the memref.store pattern will be added in a separate patch.

Diff Detail

Event Timeline

Overall looks good, just some nits. Please also consider renaming the pass name and file name, thanks!

mlir/include/mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h
16–18 ↗(On Diff #527488)

please update the comment as well.

22 ↗(On Diff #527488)

There are two target bitwidths: one is for load/store emulation, and the other is for the arith computation domain. This one is related to load/store emulation, so please give the variable a more concrete name.

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
39–40

Please add a comment along the lines of "users need to add conversions for the computation domain of narrow types".

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp
89–93 ↗(On Diff #525876)

I think we should check if op.getMemRefType() is the same as typeConverter->convertType(op.getMemRefType()). If they are not the same, we need the emulation.
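A hedged sketch of the suggested check (variable names and placement inside matchAndRewrite are assumptions, not the patch's exact code):

// Skip the rewrite when the converter leaves the memref type unchanged; only
// emulate when the converted type differs from the original one.
Type convertedType = typeConverter->convertType(op.getMemRefType());
if (convertedType == op.getMemRefType())
  return rewriter.notifyMatchFailure(op, "no narrow type emulation is needed");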

109–119 ↗(On Diff #527488)

Can you format it in a better way? Adding spaces and new lines could help, IMO. Maybe something like: https://github.com/llvm/llvm-project/blob/7f374b6902fad9caed41284a57d573abe9ada9d1/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h#L450-L476

180–182 ↗(On Diff #527488)

It should be typeConverter->convertType(oldElementType).

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowInt.cpp
103–105 ↗(On Diff #527488)

The naming is ambiguous (and there is a mismatch between the name and the flag); can we rename it to something like loadEmulationBitwidth?

This revision now requires changes to proceed. Jun 1 2023, 3:05 PM
yzhang93 requested review of this revision. Jun 1 2023, 10:59 PM
yzhang93 updated this revision to Diff 527745.
yzhang93 marked 8 inline comments as done.
hanchung accepted this revision. Jun 2 2023, 12:24 PM

LGTM if the comments are addressed.

mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
11–25

I think we can remove VectorOps.h from the includes.

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
99–100

we can remove the declaration of source and write it like

MemRefType sourceType = adaptor.getMemRefType();
104–106

This can be simplified to op.getMemRefType().getElementType().getIntOrFloatBitWidth()

132

I would just use adaptor.getMemref() here

190–192

The trunci op is not needed if they have the same number of bits.

225–227

can this be auto?

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
58–63

the if-else already covers all the cases, this can be simplified.

72–80

ditto, this can be simplified

This revision is now accepted and ready to land. Jun 2 2023, 12:24 PM
yzhang93 updated this revision to Diff 527998. Jun 2 2023, 2:50 PM
yzhang93 marked 8 inline comments as done.
yzhang93 added inline comments. Jun 2 2023, 2:52 PM
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
99–100

Looks like there's no member named 'getMemRefType' in 'mlir::memref::LoadOpAdaptor', so I'll keep what I have.

just two nits

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
231–240

I think we don't need the else; that saves us one level of indentation.

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
72–80

Can we remove the else keyword? That would save us a level of indentation. Same for the one above.

yzhang93 updated this revision to Diff 528020. Jun 2 2023, 3:25 PM
mravishankar requested changes to this revision. Jun 5 2023, 10:56 AM

Nice! Mostly looks good. Just a few comments.

mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
24

Make this private and add a getLoadStoreBitwidth method.
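A minimal sketch of what this comment asks for (the surrounding class body is abridged and partly assumed):

#include "mlir/Transforms/DialectConversion.h"
using namespace mlir;

class NarrowTypeEmulationConverter : public TypeConverter {
public:
  // Bitwidth used to emulate narrow loads/stores (e.g., 8 when emulating i4).
  unsigned getLoadStoreBitwidth() const { return loadStoreBitwidth; }

private:
  unsigned loadStoreBitwidth;
};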

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
30–40

Language nit:

/// When data is loaded/stored in `targetBits` granularity, but is used in `sourceBits` granularity
/// (`sourceBits` < `targetBits`), the `targetBits` storage is treated as an array of elements of
/// width `sourceBits`. Return the bit offset of the value at position `srcIdx`. For example, if
/// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is located at (x % 2) * 4,
/// because there are two elements in one i8 and each element has 4 bits.
91

Nit: for statements spanning multiple lines, it is still recommended to use braces.

106

Instead of an assert, just return a failure:

return rewriter.notifyMatchFailure(op, "only dstBits % srcBits == 0 supported");
140

I think this needs to happen on the linearizedOffset. Basically (see the sketch after this list):

  1. Find the linearizedOffset.
  2. Divide by the scaling factor (which is dstBits / srcBits).
  3. Load the value.
  4. Get the offset in bits.
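A plain C++ sketch of these four steps for statically known indices and strides (illustrative only; the pattern builds the equivalent index arithmetic as ops, and all names here are made up):

#include <cstdint>
#include <vector>

// Steps 1-4 above for srcBits = 4 packed into dstBits = 8 storage.
void computeEmulatedAccess(const std::vector<int64_t> &indices,
                           const std::vector<int64_t> &strides,
                           int64_t &byteIndex, int64_t &bitOffset) {
  const int64_t srcBits = 4, dstBits = 8;
  int64_t linearizedOffset = 0;                       // 1. linearize the offset
  for (size_t d = 0; d < indices.size(); ++d)
    linearizedOffset += indices[d] * strides[d];
  const int64_t scale = dstBits / srcBits;            // 2. scaling factor (= 2)
  byteIndex = linearizedOffset / scale;               // 3. index of the i8 load
  bitOffset = (linearizedOffset % scale) * srcBits;   // 4. offset in bits
}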
185

Note: this is only relevant for big-endian... Maybe add a comment somewhere that this is the only mode supported for now. A more robust option is to allow setting this in the TypeConverter and assert that the endianness is the one expected. Without that, this can lead to subtle bugs.

193

I am trying to understand when this case happens. The resultType

This revision now requires changes to proceed. Jun 5 2023, 10:56 AM
yzhang93 requested review of this revision. Jun 5 2023, 4:05 PM
yzhang93 updated this revision to Diff 528616.
yzhang93 marked 8 inline comments as done.
yzhang93 added inline comments. Jun 5 2023, 4:06 PM
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
140

@hanchung and I discussed this before, and we thought only the last index needed to be modified. However, having rethought it, I agree with you that the scaling needs to happen after the offset is linearized. @hanchung, let me know if this makes sense to you.

193

This happens when the load bitwidth and computation bitwidth are the same, e.g., when we specify --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8"

mravishankar requested changes to this revision. Jun 7 2023, 9:40 PM
mravishankar added inline comments.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
107

Nit: This statement spans two lines. Please use braces.

154

Two things here:

  1. First, I am not sure why you need to special-case sourceRank == 1.
  2. I think this computation is very different from what was there before. The expected computation is

linearizedOffset = (adjustedOffset[0] + adjustedOffset[1] + ...) * srcBits / dstBits

Here you seem to be dividing by the scaling factor (= dstBits / srcBits) as many times as the sourceRank, which seems off.

193

This looks like a premature optimization to me. If the result width and the compute width are the same, then there should be no need to do this... If the compute width is higher, then the trunc and ext should be folded away by canonicalization. In any case, I actually don't see a test with --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8". Maybe we just do:

if (resultTy != srcElementType) {
  result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
}
This revision now requires changes to proceed. Jun 7 2023, 9:40 PM
yzhang93 requested review of this revision. Jun 8 2023, 12:20 PM
yzhang93 updated this revision to Diff 529694.
yzhang93 marked an inline comment as done.
yzhang93 added inline comments. Jun 8 2023, 12:23 PM
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
154

My bad. Thanks for pointing this out.

193

My idea behind this is: if the emulated memref load bitwidth and the computation bitwidth are the same, e.g., 8 bits, while the actual load is 4 bits as in the example, we need to return 8-bit data in which only the last 4 bits are the data we need. That's why I added a mask to zero out the first 4 bits so that only the last 4 bits are valid. I also added a test for the "arith-compute-bitwidth=8 memref-load-bitwidth=8" case. We can chat in detail if this doesn't make sense to you.
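A hedged sketch of the masking described here (operand names like shiftedLoad are assumptions):

// The loaded container and the compute type are both i8, but the logical
// element is i4, so clear the upper 4 bits and keep only the narrow value.
Value mask = rewriter.create<arith::ConstantIntOp>(loc, /*value=*/0xF, /*width=*/8);
Value masked = rewriter.create<arith::AndIOp>(loc, shiftedLoad, mask);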

mravishankar requested changes to this revision. Jun 8 2023, 2:54 PM

Thanks Vivian. There are a couple more bugs in this patch... I also left a suggestion to use makeComposedAffineApplyOp, which will make the code and the IR more readable.

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
153

Thanks for the changes. Now I understand better. I think I found an issue (sorry if it was triggered by a suggestion from me): this will segfault for zero-rank memrefs.

So this has to be:

Value linearizedOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value linearizedSize = rewriter.create<arith::ConstantIndexOp>(loc, 1);
for (int i = 0; i < sourceRank; ++i) {
  linearizedOffset = rewriter.create<arith::AddIOp>(loc, linearizedOffset, adjustedOffsets[i]);
  linearizedSize = rewriter.create<arith::MulIOp>(loc, linearizedSize, baseSizes[i]);
}

Better yet, instead of creating all these ops we can use makeComposedAffineApplyOp:

OpFoldResult linearizedOffset = rewriter.getIndexAttr(0);
OpFoldResult linearizedSize = rewriter.getIndexAttr(1);
AffineExpr s0, s1, s2;
bindSymbols(rewriter.getContext(), s0, s1, s2);
for (auto i : llvm::seq<int>(0, sourceRank)) {
  linearizedOffset = makeComposedAffineApplyOp(rewriter, loc, s0 + s1 * s2, {linearizedOffset, indices[i], baseStrides[i]});
  linearizedSize = makeComposedAffineApplyOp(rewriter, loc, s0 * s1, {linearizedSize, baseSizes[i]});
}
OpFoldResult scaler = rewriter.getIndexAttr(dstBits / srcBits);
linearizedOffset = makeComposedAffineApplyOp(rewriter, loc, s0.floorDiv(s1), {linearizedOffset, scaler});

Then you can get the Value for linearizedOffset/linearizedSize using getOrCreateConstantIndexOp.
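For reference, a hedged sketch of that final materialization (the helper spelling here is an assumption based on what I recall being in-tree):

// Turn the folded OpFoldResults back into SSA Values for the subsequent
// reinterpret_cast and load.
Value offsetValue = getValueOrCreateConstantIndexOp(rewriter, loc, linearizedOffset);
Value sizeValue = getValueOrCreateConstantIndexOp(rewriter, loc, linearizedSize);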

This will fold away any statically known values, make both the code and the IR easier to read, and reduce the index arithmetic overhead.

173

If I am not mistaken, baseOffset also needs to be scaled.

193

Maybe... but I can't think of a valid program where the load is in 4 bits but its use is directly in 8 bits.

This revision now requires changes to proceed. Jun 8 2023, 2:54 PM
yzhang93 requested review of this revision. Jun 13 2023, 10:30 AM
yzhang93 updated this revision to Diff 530979.
yzhang93 marked an inline comment as done. Jun 13 2023, 10:36 AM
yzhang93 added inline comments.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
153

Thanks for pointing out the zero-rank problem; I appreciate your suggestions.

I tried what you suggested with AffineApplyOp, but even with the simplest test I kept hitting this error: "error: failed to legalize operation 'memref.load' that was explicitly marked illegal: %1 = memref.load %0[%arg0] : memref<4xi4>". I'm not sure what caused it, but if you know of a potential issue and how to fix it, please let me know.

For now I refactored the code and added a case for sourceRank == 0. I think we probably want to treat these cases separately, because when sourceRank == 0 we don't need to do linearization with the memref.reinterpret_cast op.

yzhang93 updated this revision to Diff 533814. Jun 22 2023, 4:38 PM
yzhang93 marked 2 inline comments as done. Jun 22 2023, 4:58 PM
yzhang93 added inline comments.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
153

In the latest revision, I refactored the linearization code to use AffineApplyOp as suggested. I also added the conversion pattern for memref::AssumeAlignmentOp, as this is required for the e2e test. The tests were updated accordingly.

mravishankar accepted this revision. Jun 23 2023, 6:35 PM

Thanks! This revision looks good!

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
62

Nit: Avoid using auto here. It is only used in LLVM when the type is obvious from the context, and here it is not.

86

Nit: avoid using the same variable in two contexts.

264

I am still not sure about this one... I'm not really sure this actually happens in practice, but it is harmless enough. (Could you just leave a comment explaining that it isn't clear this is needed, or something to record this discussion?)

This revision is now accepted and ready to land. Jun 23 2023, 6:35 PM
yzhang93 updated this revision to Diff 534651. Jun 26 2023, 10:56 AM
yzhang93 marked 6 inline comments as done.
hanchung accepted this revision. Jun 26 2023, 2:15 PM
This revision was automatically updated to reflect the committed changes.