This patch adds support for narrow-bitwidth storage emulation. The goal is to support sub-byte type
codegen for the LLVM CPU backend. Specifically, a type converter is added to convert memrefs of a narrow
bitwidth (e.g., i4) into a supported wider bitwidth (e.g., i8). Another focus of this patch is to populate the
conversion pattern for i4 memref.load; the memref.store pattern should be added in a separate patch.
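For readers new to the approach, the following is a minimal sketch, assuming MLIR's TypeConverter API, of how the memref element-type widening could be expressed. The helper name addNarrowMemRefConversion and the parameter loadStoreBitwidth are illustrative only, not the patch's actual code:

    #include <optional>

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/Transforms/DialectConversion.h"

    using namespace mlir;

    // Sketch only: widen sub-byte integer element types of memrefs to
    // `loadStoreBitwidth` (e.g. i4 -> i8).
    static void addNarrowMemRefConversion(TypeConverter &converter,
                                          unsigned loadStoreBitwidth) {
      converter.addConversion(
          [loadStoreBitwidth](MemRefType type) -> std::optional<Type> {
            auto intTy = llvm::dyn_cast<IntegerType>(type.getElementType());
            if (!intTy || intTy.getWidth() >= loadStoreBitwidth)
              return type; // Already wide enough; leave it unchanged.
            auto wideTy = IntegerType::get(type.getContext(), loadStoreBitwidth);
            // Keep shape/layout/memory space; a real implementation also has to
            // rescale sizes and strides, which this sketch omits.
            return MemRefType::get(type.getShape(), wideTy, type.getLayout(),
                                   type.getMemorySpace());
          });
    }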
Event Timeline
The bigger issue is:
https://llvm.org/docs/CodingStandards.html#anonymous-namespaces
mlir/include/mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h

Line 27: I have a personal issue with the missing new line between the class and the namespace.

mlir/test/Dialect/MemRef/emulate-narrow-int.mlir

Line 82: Newline please.
Nice work! I dropped a few comments inline. :)
Inspired by a patch in the IREE project, I found that some people would like to emulate f16 computation in the f32 domain: https://github.com/openxla/iree/pull/13808
I think we can rename EmulateNarrowInt to EmulateNarrowType or EmulateNarrowNumerics, and they can add such patterns to it later.
mlir/lib/Dialect/Arith/Transforms/EmulateNarrowInt.cpp

Lines 40–68: Can we move the conversions out of the constructor? Users want to control the conversions themselves. I think we can move them to TestEmulateNarrowInt.cpp.

Lines 46–49: We don't tie the type conversion to targetWideInt. Instead, we leave the decision to users. The targetWideInt controls how we load an i4, but not the computation domain. Consider the case where we want to use a byte load for i4 but leave the computation on i4. We can force the test to convert i4 to i8, or control it with a flag (like int4-arith-bitwidth=8).

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp

Lines 90–94: I think we should check whether the converted memref type is the same as the original type. The emulation is needed only if the memref types mismatch. The newResTy can be the same in the scenario of using byte loads but operating in the i4 domain. IMO, the workflow is:

Line 100: Can we add a comment to elaborate the core idea? And some comments about why/how we compute linearized_size, linearized_offset, and %stride#1.

    %1 = memref.load %0[%v0, %v1] : memref<?x?xi4, strided<[?, ?], offset: ?>>

can be replaced with

    %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
    %linearized_offset = ...  // %v0 * %stride#0 + %v1 * %stride#1
    %linearized_size = ...    // %size0 * %size1
    %linearized = memref.reinterpret_cast %b to offset: [%offset], sizes: [%linearized_size], strides: [%stride#1]
    %load = memref.load %linearized[%linearized_offset] : memref<?xi4, strided<[?], offset: ?>>

(A standalone sketch of the byte-index and bit-offset arithmetic this lowers to follows these comments.)

Lines 131–132: Style nit: names should be in camelCase, i.e., we should name them linearizedOffset and linearizedSize. https://llvm.org/docs/CodingStandards.html#name-types-functions-variables-and-enumerators-properly

Lines 141–142: Style nit: use auto because XXX::get already spells out the type. https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable

Lines 200–202: IMO, this should be derived from targetBitwidth. If the bitwidth is less than targetBitwidth, we use targetBitwidth.

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowInt.cpp

Line 67: I can't map the comment to the populate functions; maybe we can just remove the comment. It's pretty clear (from the function name) that we populate the patterns that do narrow type emulation.
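To make the load-emulation arithmetic concrete, here is a small self-contained C++ sketch (independent of MLIR; all names and the packing order are illustrative, not taken from the patch) of what the widened load boils down to: linearize the indices with the strides, split the result into a byte index and a bit offset, then shift and mask.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Sketch: emulate loading the i4 element at (row, col) of a strided 2-D
    // i4 buffer whose storage is packed into bytes (two i4 values per i8).
    // Strides are expressed in i4 elements.
    uint8_t loadI4(const std::vector<uint8_t> &bytes, int64_t row, int64_t col,
                   int64_t stride0, int64_t stride1) {
      constexpr int64_t srcBits = 4, dstBits = 8;
      constexpr int64_t perByte = dstBits / srcBits;       // 2 elements per byte
      int64_t flatIndex = row * stride0 + col * stride1;   // linearized offset
      int64_t byteIndex = flatIndex / perByte;             // which i8 to load
      int64_t bitOffset = (flatIndex % perByte) * srcBits; // where the i4 sits
      uint8_t wide = bytes[byteIndex];                     // the emulated load
      return static_cast<uint8_t>((wide >> bitOffset) & 0xF); // shift + mask
    }

    int main() {
      // A 2x4 row-major i4 buffer {1,2,3,4, 5,6,7,8}, packed low-nibble first
      // (element 0 at bit offset 0), i.e. bytes 0x21, 0x43, 0x65, 0x87.
      std::vector<uint8_t> bytes = {0x21, 0x43, 0x65, 0x87};
      assert(loadI4(bytes, 0, 2, /*stride0=*/4, /*stride1=*/1) == 3);
      assert(loadI4(bytes, 1, 1, /*stride0=*/4, /*stride1=*/1) == 6);
      return 0;
    }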
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp

Lines 90–94: Thanks for your review and suggestions! I have modified the code accordingly.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp

Lines 90–95: Same comment as in https://reviews.llvm.org/D151827
Overall looks good, just some nits. Please also consider renaming the pass and the file, thanks!
mlir/include/mlir/Dialect/Arith/Transforms/NarrowIntEmulationConverter.h

Lines 16–18: Please update the comment as well.

Line 22: There are two target bitwidths: one is for load/store emulation, and the other is for the arith computation domain. This is the one related to load/store emulation; please give the variable a more concrete name.

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h

Lines 39–40: Please add a comment saying that users need to add the conversions for the computation domain of narrow types themselves (see the sketch after these comments).

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowInt.cpp

Lines 90–94: I think we should check whether op.getMemRefType() is the same as typeConverter->convertType(op.getMemRefType()). If they are not the same, we need the emulation.

Lines 109–119: Can you format this in a better way? Adding spaces and newlines could help, IMO. Maybe something like: https://github.com/llvm/llvm-project/blob/7f374b6902fad9caed41284a57d573abe9ada9d1/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h#L450-L476

Lines 180–182: It should be typeConverter->convertType(oldElementType).

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowInt.cpp

Lines 103–105: The naming is ambiguous (and there is a mismatch between the name and the flag); can we rename it to something like loadEmulationBitwidth?
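As a hypothetical illustration of the compute-domain conversion that a user or test would register themselves (not part of this patch; the helper name and parameter are made up), it could map narrow integer types to the chosen arith bitwidth:

    #include <optional>

    #include "mlir/IR/BuiltinTypes.h"
    #include "mlir/Transforms/DialectConversion.h"

    // Sketch: the test/user decides the computation domain separately from
    // the load/store bitwidth, e.g. compute i4 arithmetic in i8.
    static void addNarrowArithConversion(mlir::TypeConverter &converter,
                                         unsigned arithComputeBitwidth) {
      converter.addConversion(
          [arithComputeBitwidth](mlir::IntegerType type)
              -> std::optional<mlir::Type> {
            if (type.getWidth() >= arithComputeBitwidth)
              return type; // Leave wide-enough integers alone.
            return mlir::IntegerType::get(type.getContext(),
                                          arithComputeBitwidth);
          });
    }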
LGTM if the comments are addressed.
mlir/lib/Dialect/Arith/Transforms/EmulateNarrowType.cpp

Lines 10–24 (on Diff #527745): I think we can remove VectorOps.h from the includes.

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines 98–99 (on Diff #527745): We can remove the declaration of source and write it as MemRefType sourceType = adaptor.getMemRefType();

Lines 103–105 (on Diff #527745): This can be simplified to op.getMemRefType().getElementType().getIntOrFloatBitWidth()

Line 131 (on Diff #527745): I would just use adaptor.getMemref() here.

Lines 189–191 (on Diff #527745): The trunci op is not needed if they have the same number of bits.

Lines 224–226 (on Diff #527745): Can this be auto?

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp

Lines 57–62 (on Diff #527745): The if-else already covers all the cases; this can be simplified.

Lines 71–79 (on Diff #527745): Ditto, this can be simplified.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines 98–99 (on Diff #527745): It looks like there's no member named 'getMemRefType' in 'mlir::memref::LoadOpAdaptor', so I'll keep what I have.
Just two nits.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines 230–239 (on Diff #527998): I think we don't need the else; this can save us one level of indentation.

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp

Lines 71–79 (on Diff #527745): Can we remove the else keyword? That would save us a level of indentation. Same for the one above.
Nice! Mostly looks good. Just a few comments.
mlir/include/mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h

Line 23 (on Diff #528020): Make this private and add a getLoadStoreBitwidth method.

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines 29–39 (on Diff #528020): Language nit:

    /// When data is loaded/stored in `targetBits` granularity, but is used in
    /// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits`
    /// value is treated as an array of elements of width `sourceBits`. Return
    /// the bit offset of the value at position `srcIdx`. For example, if
    /// `sourceBits` equals 4 and `targetBits` equals 8, the x-th element is
    /// located at (x % 2) * 4, because there are two elements in one i8 and
    /// one element has 4 bits.

Line 90 (on Diff #528020): Nit: for statements spanning multiple lines, it is still recommended to use braces.

Line 105 (on Diff #528020): Instead of an assert, just return a failure: return rewriter.notifyMatchFailure(op, "only dstBits % srcBits == 0 supported");

Line 139 (on Diff #528020): I think this needs to happen on the linearizedOffset. Basically:

Line 184 (on Diff #528020): Note: this is only relevant for big-endian. Maybe add a comment somewhere that this is the only mode supported for now. Another, more robust option is to allow setting this in the TypeConverter and asserting that the endianness is the expected one. Without that it can lead to subtle bugs.

Line 192 (on Diff #528020): I am trying to understand when this case happens. The resultType
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Line 192 (on Diff #528020): This looks like a premature optimization to me. If the result width and the compute width are the same, then there should not be a need to do this. If the compute width is higher, then the trunc and ext should be folded away as a canonicalization. In any case, I actually don't see a test with --test-emulate-narrow-int="arith-compute-bitwidth=8 memref-load-bitwidth=8". Maybe we just do:

    if (resultTy != srcElementType) {
      result = rewriter.create<arith::TruncIOp>(loc, resultTy, bitsLoad);
    }

Line 106 (on Diff #528616): Nit: this statement spans two lines; please use braces.

Line 153 (on Diff #528616): Two things here:

    linearizedOffset = (adjustedOffset[0] + adjustedOffset[1] + ...) * srcBits / dstBits
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Line 192 (on Diff #528020): My idea behind this is that if the emulated memref load bitwidth and the computation bitwidth are the same, e.g., 8 bits, while the actual load is 4 bits as in the example, we'll need to return an 8-bit value in which only the last 4 bits are the data we need. That's why I added a mask to zero out the first 4 bits so that only the last 4 bits are valid (a small worked example follows these comments). I also added a test for the "arith-compute-bitwidth=8 memref-load-bitwidth=8" case. We can chat in detail if this doesn't make sense to you.

Line 153 (on Diff #528616): My bad. Thanks for pointing this out.
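A tiny worked example of that masking step, with made-up values rather than the patch's test data:

    #include <cassert>
    #include <cstdint>

    int main() {
      // One i8 holding two packed i4 values: low nibble 0xB, high nibble 0xA.
      uint8_t wide = 0xAB;
      // Reading element 0 (bit offset 0): without the mask we would return
      // 0xAB, an 8-bit value whose upper nibble is the neighboring element.
      uint8_t element0 = (wide >> 0) & 0xF; // == 0x0B
      // Reading element 1 (bit offset 4): the shift brings it down first.
      uint8_t element1 = (wide >> 4) & 0xF; // == 0x0A
      assert(element0 == 0x0B && element1 == 0x0A);
      return 0;
    }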
Thanks Vivian, there are a couple more bugs in this patch. I also left a suggestion to use makeComposedAffineApplyOp, which will make both the code and the IR more readable.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Line 152 (on Diff #529694): Thanks for the changes. Now I understand better. I think I found an issue (sorry if it was triggered by a suggestion from me): this will segfault for zero-rank memrefs. So this has to be

    Value linearizedOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value linearizedSize = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    for (int i = 0; i < sourceRank; ++i) {
      linearizedOffset = rewriter.create<arith::AddIOp>(loc, linearizedOffset, adjustedOffsets[i]);
      linearizedSize = rewriter.create<arith::MulIOp>(loc, linearizedSize, baseSizes[i]);
    }

Better yet, instead of creating all these ops we can use makeComposedAffineApplyOp:

    OpFoldResult linearizedOffset = rewriter.getIndexAttr(0);
    OpFoldResult linearizedSize = rewriter.getIndexAttr(1);
    AffineExpr s0, s1, s2;
    bindSymbols(rewriter.getContext(), s0, s1, s2);
    for (auto i : llvm::seq<int>(0, sourceRank)) {
      linearizedOffset = makeComposedAffineApplyOp(
          rewriter, loc, s0 + s1 * s2,
          {linearizedOffset, indices[i], baseStrides[i]});
      linearizedSize = makeComposedAffineApplyOp(
          rewriter, loc, s0 * s1, {linearizedSize, baseSizes[i]});
    }
    OpFoldResult scaler = rewriter.getIndexAttr(dstBits / srcBits);
    linearizedOffset = makeComposedAffineApplyOp(
        rewriter, loc, s0 floorDiv s1, {linearizedOffset, scaler});

Then you can get the Value for linearizedOffset/linearizedSize using getOrCreateConstantIndexOp. This will fold away any statically known values, make both the code and the IR easier to read, and reduce the index arithmetic overhead.

Line 172 (on Diff #529694): If I am not mistaken, baseOffset also needs to be scaled.

Line 192 (on Diff #528020): Maybe... but I can't think of a valid program where the load is in 4 bits and the use of it is directly in 8 bits.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Line 152 (on Diff #529694): Thanks for pointing out the zero-rank problem, and I appreciate your suggestions. I tried what you suggested with AffineApplyOp, but kept getting "error: failed to legalize operation 'memref.load' that was explicitly marked illegal %1 = memref.load %0[%arg0] : memref<4xi4>" even on the simplest test. I'm not sure what caused the error, but if you know of any potential issue and a way to fix it, please let me know. For now I refactored the code and added the case for sourceRank == 0. I think we probably want to treat these cases separately, because when sourceRank == 0 we don't need to do the linearization with the memref.reinterpret_cast op.
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Line 152 (on Diff #529694): In the latest revision, I refactored the linearization code with AffineApplyOp as suggested. I also added a conversion pattern for memref::AssumeAlignmentOp, as it is required for the e2e test. The tests were updated accordingly.
Thanks! This revision looks good!
mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Line 61 (on Diff #533814): Nit: avoid using auto here. In LLVM it is only used when the type is obvious from the context, and here it is not.

Line 85 (on Diff #533814): Nit: avoid using the same variable in two contexts.

Line 263 (on Diff #533814): I am still not sure about this one. I am not really sure it actually happens in practice, but it is harmless enough. Could you just leave a comment explaining that it isn't clear this is needed, or something else to record this discussion?