This is an archive of the discontinued LLVM Phabricator instance.

[mlir] Narrow bitwidth emulation for MemRef load
ClosedPublic

Authored by yzhang93 on May 25 2023, 4:41 PM.

Details

Summary

This patch adds support for narrow bitwidth storage emulation. The goal is to support sub-byte type
codegen for the LLVM CPU backend. Specifically, a type converter is added to convert memrefs of narrow
bitwidths (e.g., i4) into supported wider bitwidths (e.g., i8). Another focus of this patch is to populate
the conversion pattern for i4 memref.load. The memref.store pattern will be added in a separate patch.

Diff Detail

Event Timeline

yzhang93 created this revision. May 25 2023, 4:41 PM
Herald added a project: Restricted Project. May 25 2023, 4:41 PM
yzhang93 requested review of this revision. May 25 2023, 4:41 PM
hanchung retitled this revision from [mlir] Narrow bitwidth emulation for MemRef load r=hanchung to [mlir] Narrow bitwidth emulation for MemRef load. May 26 2023, 11:02 AM
mlir/include/mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h
27

I have a personal issue with the missing newline between the class and the namespace.

mlir/test/Dialect/MemRef/emulate-narrow-int.mlir
82

Newline please.

hanchung requested changes to this revision. Edited May 26 2023, 11:45 AM

Nice work! I dropped a few comments inline. :)

Inspired by an IREE patch, I found that some people would like to emulate f16 computation in the f32 domain: https://github.com/openxla/iree/pull/13808

I think we can rename EmulateNarrowInt to EmulateNarrowType or EmulateNarrowNumerics, and they can add such patterns to it later.

mlir/lib/Dialect/Arith/Transforms/EmulateNarrowInt.cpp
40–68

Can we move the conversions out of the constructor? Users will want to control the conversions themselves. I think we can move them to TestEmulateNarrowInt.cpp.

46–49

We shouldn't tie the type conversion to targetWideInt; instead, we should leave the decision to users. The targetWideInt controls how we load an int4, but not the computation domain. Consider the case where we want to use a byte load for int4 but leave the computation on int4. We can force the test to convert int4 to int8, or control it with a flag (like int4-arith-bitwidth=8).

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp
90–94

I think we should check whether the converted MemRef type is the same as the original type. The emulation is needed only if the memref types mismatch.

The newResTy can be the same in the scenario of using a byte load but operating in the int4 domain. IMO, the workflow is (see the sketch below):

  1. Check whether the original memref type matches the converted memref type. If they mismatch, we need the emulation and apply the pattern.
  2. Load the value using the converted memref type (which is already implemented below).
  3. Cast the loaded value to newResTy if the types mismatch. In that scenario, an int8 value is loaded and we need to cast it to int4 (i.e., newResTy).
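
A minimal C++ sketch of that flow (illustrative only, not the patch's exact code; it assumes a conversion pattern with access to the type converter, the rewriter, the narrow result type newResTy, and the already-loaded wide value bitsLoad):

Type oldMemRefTy = op.getMemRefType();
Type newMemRefTy = getTypeConverter()->convertType(oldMemRefTy);
// 1. Emulation is only needed when the converted memref type differs.
if (newMemRefTy == oldMemRefTy)
  return rewriter.notifyMatchFailure(op, "no narrow type emulation needed");
// 2. ... load the value through the converted (wider) memref here ...
// 3. Cast the loaded value back to the narrow result type if it differs,
//    e.g. truncate the loaded i8 to i4 (newResTy).
Value result = bitsLoad;
if (bitsLoad.getType() != newResTy)
  result = rewriter.create<arith::TruncIOp>(loc, newResTy, bitsLoad);
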
100

Can we add a comment to elaborate the core idea, and some comments about why/how we compute linearized_size, linearized_offset, and %strides#1?

%1 = memref.load %0[%v0, %v1] : memref<?x?xi4, strided<[?, ?], offset: ?>>

can be replaced with

%base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
%linearized_offset = /// %v0 * %strides#0 + %v1 * %strides#1
%linearized_size = /// %sizes#0 * %sizes#1
%linearized = memref.reinterpret_cast %base, offset = [%offset], sizes = [%linearized_size], strides = [%strides#1]
%load = memref.load %linearized[%linearized_offset] : memref<?xi4, strided<[?], offset: ?>>
131–132

style nit: names should be in camelCase, i.e., we should rename them to linearizedOffset and linearizedSize.

https://llvm.org/docs/CodingStandards.html#name-types-functions-variables-and-enumerators-properly

141–142

style nit: use auto because XXX::get already spells the type.

https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable

200–202

IMO, this should be derived from targetBitwidth. If the bitwidth is less than targetBitwidth, we use targetBitwidth.

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowInt.cpp
67

I can't map the comment to the populate functions; maybe we can just remove the comment. It's pretty clear from the function names that we populate the patterns to do narrow type emulation.

This revision now requires changes to proceed. May 26 2023, 11:45 AM
yzhang93 requested review of this revision. May 31 2023, 12:45 PM
yzhang93 updated this revision to Diff 527169.
yzhang93 marked 9 inline comments as done. May 31 2023, 1:05 PM

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp
90–94

Thanks for your review and suggestions! I have modified the code accordingly.

kuhar added inline comments. May 31 2023, 1:08 PM
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp
89–94
yzhang93 updated this revision to Diff 527268. May 31 2023, 8:32 PM
yzhang93 marked an inline comment as not done.
kuhar added inline comments. Jun 1 2023, 7:44 AM
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp
89–94

Also, we should return failure() when there was no rewrite

212–214

nit: when the else body is wrapped with braces, the then body should be too.

yzhang93 updated this revision to Diff 527488. Jun 1 2023, 10:26 AM
yzhang93 marked 3 inline comments as done.
hanchung requested changes to this revision. Jun 1 2023, 3:05 PM

Overall looks good, just some nits. Please also consider renaming the pass name and file name, thanks!

mlir/include/mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h
17–19

please update the comment as well.

23

There are two target bitwidths: one is for load/store emulation, and the other is for the arith computation domain. This is the one related to load/store emulation; please make the variable name more concrete.

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
39–40

Please add a comment about "users need to add conversions about the computation domain of narrow types".

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp
90–94

I think we should check if op.getMemRefType() is the same as typeConverter->convertType(op.getMemRefType()). If they are not the same, we need the emulation.

110–120

Can you format it in a better way? Adding spaces and new lines could help, IMO. Maybe something like: https://github.com/llvm/llvm-project/blob/7f374b6902fad9caed41284a57d573abe9ada9d1/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h#L450-L476

181–183

It should be typeConverter->convertType(oldElementType).

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowInt.cpp
104–106

The naming is ambiguous (and there's a mismatch between the name and the flag); can we rename it to something like loadEmulationBitwidth?

This revision now requires changes to proceed. Jun 1 2023, 3:05 PM
yzhang93 requested review of this revision. Jun 1 2023, 10:59 PM
yzhang93 updated this revision to Diff 527745.
yzhang93 marked 8 inline comments as done.
hanchung accepted this revision. Jun 2 2023, 12:24 PM

LGTM if the comments are addressed.

mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp
10–24 ↗(On Diff #527745)

I think we can remove VectorOps.h from the includes.

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
98–99 ↗(On Diff #527745)

we can remove the declaration of source and write it like

MemRefType sourceType = adaptor.getMemRefType();
103–105 ↗(On Diff #527745)

This can be simplified to op.getMemRefType().getElementType().getIntOrFloatBitWidth()

131 ↗(On Diff #527745)

I would just use adaptor.getMemref() here

189–191 ↗(On Diff #527745)

The trunci op is not needed if they have the same number of bits.

224–226 ↗(On Diff #527745)

can this be auto?

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
57–62 ↗(On Diff #527745)

the if-else already covers all the cases, this can be simplified.

71–79 ↗(On Diff #527745)

ditto, this can be simplified

This revision is now accepted and ready to land. Jun 2 2023, 12:24 PM
yzhang93 updated this revision to Diff 527998. Jun 2 2023, 2:50 PM
yzhang93 marked 8 inline comments as done.
yzhang93 added inline comments. Jun 2 2023, 2:52 PM
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
98–99 ↗(On Diff #527745)

Looks like there's no member named 'getMemRefType' in 'mlir::memref::LoadOpAdaptor'. I'll keep what I have.

just two nits

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
230–239 ↗(On Diff #527998)

I think we don't need the else; this saves us one level of indentation.

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp
71–79 ↗(On Diff #527745)

Can we remove the else keyword? That would save us a level of indentation. Same for the one above.

yzhang93 updated this revision to Diff 528020. Jun 2 2023, 3:25 PM
mravishankar requested changes to this revision. Jun 5 2023, 10:56 AM

Nice! Mostly looks good. Just a few comments.

mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h
23 ↗(On Diff #528020)

Make this private and add a getLoadStoreBitwidth method.

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
29–39 ↗(On Diff #528020)

Language nit:

/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits`
/// value is treated as an array of elements of width `sourceBits`. Return the
/// bit offset of the value at position `srcIdx`. For example, if `sourceBits`
/// equals 4 and `targetBits` equals 8, the x-th element is located at
/// (x % 2) * 4 bits, because there are two elements in one i8 and one element
/// takes 4 bits.
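
A possible sketch of that bit-offset computation using a folded affine.apply (illustrative; `builder`, `loc`, and the linearized element index `srcIdx` as an OpFoldResult are assumed here, not taken from the patch):

int64_t scale = targetBits / sourceBits; // e.g. 8 / 4 = 2 elements per byte.
AffineExpr s0;
bindSymbols(builder.getContext(), s0);
// (srcIdx % scale) * sourceBits, folded to a constant when srcIdx is static.
OpFoldResult bitOffset = affine::makeComposedFoldedAffineApply(
    builder, loc, (s0 % scale) * sourceBits, {srcIdx});
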
90 ↗(On Diff #528020)

Nit: For statements spanning multiple lines, it is still recommended to use braces.

105 ↗(On Diff #528020)

Instead of an assert, just return a failure:

return rewriter.notifyMatchFailure(op, "only dstBits % srcBits == 0 supported");
139 ↗(On Diff #528020)

I think this needs to happen on the linearizedOffset. Basically (see the sketch below):

  1. Find the linearizedOffset.
  2. Divide it by the scaling factor (which is dstBits / srcBits).
  3. Load the value.
  4. Get the offset in bits.
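
A hedged sketch of those steps (names such as linearizedOffset, linearizedMemref, dstBits, and srcBits are illustrative, not the patch's actual variables):

// 2. Divide the linearized offset by the scaling factor (dstBits / srcBits).
Value scale = rewriter.create<arith::ConstantIndexOp>(loc, dstBits / srcBits);
Value scaledOffset =
    rewriter.create<arith::DivUIOp>(loc, linearizedOffset, scale);
// 3. Load the wide (e.g. i8) value from the linearized memref.
Value wideLoad = rewriter.create<memref::LoadOp>(
    loc, linearizedMemref, ValueRange{scaledOffset});
// 4. The remainder identifies which narrow element inside the wide value we
//    want; multiplying it by srcBits gives the offset in bits.
Value rem = rewriter.create<arith::RemUIOp>(loc, linearizedOffset, scale);
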
184 ↗(On Diff #528020)

Note: This is only relevant for big-endian... Maybe add a comment somewhere that this is the only mode supported for now. Another, more robust, option is to allow setting this in the TypeConverter and assert that the endianness is the expected one. Without that it can lead to subtle bugs.

192 ↗(On Diff #528020)

I am trying to understand when this case happens. The resultType

This revision now requires changes to proceed. Jun 5 2023, 10:56 AM
yzhang93 requested review of this revision. Jun 5 2023, 4:05 PM
yzhang93 updated this revision to Diff 528616.
yzhang93 marked 8 inline comments as done.
yzhang93 added inline comments. Jun 5 2023, 4:06 PM
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
139 ↗(On Diff #528020)

@hanchung and I discussed this before, and we thought only the last index needed to be modified. However, I have rethought this and I agree with you that the scaling needs to happen after the offset is linearized. @hanchung let me know if this makes sense to you.

192 ↗(On Diff #528020)

This happens when the load bitwidth and computation bitwidth are the same, e.g., when we specify --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8"

mravishankar requested changes to this revision. Jun 7 2023, 9:40 PM
mravishankar added inline comments.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
192 ↗(On Diff #528020)

This looks like a premature optimization to me. If the result width and compute width are the same, then there should not be a need to do this. If the compute width is higher, then the trunc and ext should be folded away as a canonicalization. In any case, I actually don't see a test with --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8". Maybe we just do:

if (resultTy != srcElementType) {
  result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
}
106 ↗(On Diff #528616)

Nit: This statement spans two lines. Please use braces.

153 ↗(On Diff #528616)

Two things here:

  1. First, I'm not sure why you need to special-case sourceRank == 1?
  2. I think this computation is very different from what was there before. The linearized offset should be

linearizedOffset = (adjustedOffset[0] + adjustedOffset[1] + ...) * srcBits / dstBits

Here you seem to be dividing by the scale factor (= dstBits / srcBits) as many times as the sourceRank, which seems off.

This revision now requires changes to proceed. Jun 7 2023, 9:40 PM
yzhang93 requested review of this revision. Jun 8 2023, 12:20 PM
yzhang93 updated this revision to Diff 529694.
yzhang93 marked an inline comment as done.
yzhang93 added inline comments. Jun 8 2023, 12:23 PM
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
192 ↗(On Diff #528020)

My idea behind this is: if the emulated memref load bits and the computation bits are the same, e.g., 8 bits, while the actual data is 4 bits as in the example, we need to return 8 bits of data where only the last 4 bits are meaningful. That's why I added a mask to zero out the first 4 bits so that only the last 4 bits are valid (see the sketch below). I also added a test for the "arith-compute-bitwidth=8 memref-load-bitwidth=8" case. We can chat in detail if this doesn't make sense to you.
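
For reference, a sketch of that masking (illustrative names: dstType is the wide integer type, e.g. i8, and shiftedValue is the loaded byte already shifted so the wanted bits sit at the bottom):

// Zero out everything above the low srcBits, e.g. mask = 0x0F for srcBits == 4.
Value mask = rewriter.create<arith::ConstantOp>(
    loc, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
Value masked = rewriter.create<arith::AndIOp>(loc, shiftedValue, mask);
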

153 ↗(On Diff #528616)

My bad. Thanks for pointing this out.

mravishankar requested changes to this revision. Jun 8 2023, 2:54 PM

Thanks Vivian, there are a couple more bugs in this patch... I also left a suggestion to use makeComposedFoldedAffineApply, which will make the code and IR more readable.

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
152 ↗(On Diff #529694)

Thanks for the changes. Now I understand better. I think I found an issue (sorry if it was triggered by a suggestion from me). This will segfault for zero-rank memrefs.

So this has to be

Value linearizedOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value linearizedSize = rewriter.create<arith::ConstantIndexOp>(loc, 1);
for (int i = 0; i < sourceRank; ++i) {
  linearizedOffset = rewriter.create<arith::AddIOp>(loc, linearizedOffset, adjustedOffsets[i]);
  linearizedSize = rewriter.create<arith::MulIOp>(loc, linearizedSize, baseSizes[i]);
}

Better yet, instead of creating all these ops we can use makeComposedFoldedAffineApply:

OpFoldResult linearizedOffset = rewriter.getIndexAttr(0);
OpFoldResult linearizedSize = rewriter.getIndexAttr(1);
AffineExpr s0, s1, s2;
bindSymbols(rewriter.getContext(), s0, s1, s2);
for (auto i : llvm::seq<int>(0, sourceRank)) {
  linearizedOffset = makeComposedFoldedAffineApply(rewriter, loc, s0 + s1 * s2, {linearizedOffset, indices[i], baseStrides[i]});
  linearizedSize = makeComposedFoldedAffineApply(rewriter, loc, s0 * s1, {linearizedSize, baseSizes[i]});
}
OpFoldResult scaler = rewriter.getIndexAttr(dstBits / srcBits);
linearizedOffset = makeComposedFoldedAffineApply(rewriter, loc, s0.floorDiv(s1), {linearizedOffset, scaler});

Then you can get the Value for linearizedOffset/linearizedSize using getValueOrCreateConstantIndexOp.

This will fold away any statically known values, and will also make both the code and the IR easier to read, while reducing index arithmetic overhead.
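
For example, the final index Values could be materialized like this (a sketch, assuming the OpFoldResults computed above):

Value offsetValue =
    getValueOrCreateConstantIndexOp(rewriter, loc, linearizedOffset);
Value sizeValue =
    getValueOrCreateConstantIndexOp(rewriter, loc, linearizedSize);
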

172 ↗(On Diff #529694)

If I am not mistaken, baseOffset also needs to be scaled.

192 ↗(On Diff #528020)

Maybe... But I can't think of a valid program where the load is in 4 bits but the use of it is directly in 8 bits.

This revision now requires changes to proceed. Jun 8 2023, 2:54 PM
yzhang93 requested review of this revision. Jun 13 2023, 10:30 AM
yzhang93 updated this revision to Diff 530979.
yzhang93 marked an inline comment as done. Jun 13 2023, 10:36 AM
yzhang93 added inline comments.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
152 ↗(On Diff #529694)

Thanks for pointing out the zero-rank problem and I appreciate your suggestions.

I tried what you suggested with AffineApplyOp, but I kept getting "error: failed to legalize operation 'memref.load' that was explicitly marked illegal: %1 = memref.load %0[%arg0] : memref<4xi4>" even with the simplest test. I'm not sure what caused the error, but if you know of any potential issue and how to fix it, please let me know.

For now, I refactored the code and added the case for sourceRank == 0. I think we probably want to treat these cases separately, because when sourceRank == 0 we don't need to do the linearization with a memref.reinterpret_cast op.
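
A rough sketch of that zero-rank special case (illustrative, not the committed code):

if (sourceRank == 0) {
  // A 0-d memref holds exactly one element, so no linearization or
  // memref.reinterpret_cast is needed; load it directly.
  Value wideLoad = rewriter.create<memref::LoadOp>(
      loc, adaptor.getMemref(), ValueRange{});
  // The narrow value always starts at bit offset 0 of the loaded word.
}
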

yzhang93 updated this revision to Diff 533814. Jun 22 2023, 4:38 PM
yzhang93 marked 2 inline comments as done. Jun 22 2023, 4:58 PM
yzhang93 added inline comments.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
152 ↗(On Diff #529694)

In the latest revision, I refactored the linearization code to use AffineApplyOp as suggested. I also added a conversion pattern for memref::AssumeAlignmentOp, as it is required for the e2e test. The tests were updated accordingly.

mravishankar accepted this revision. Jun 23 2023, 6:35 PM

Thanks! This revision looks good!

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
61 ↗(On Diff #533814)

Nit: Avoid using auto here. It is only used in LLVM when the type is obvious from the context, and here it is not.

85 ↗(On Diff #533814)

Nit: avoid using the same variable in two contexts.

263 ↗(On Diff #533814)

I am still not sure about this one... Not really sure this actually happens in practice, but it's harmless enough. (Could you just leave a comment explaining that it isn't clear this is needed, or something to record this discussion?)

This revision is now accepted and ready to land. Jun 23 2023, 6:35 PM
yzhang93 updated this revision to Diff 534651. Jun 26 2023, 10:56 AM
yzhang93 marked 6 inline comments as done.
hanchung accepted this revision. Jun 26 2023, 2:15 PM
This revision was automatically updated to reflect the committed changes.