The current implementation in SPIRVTypeConverter unconditionally converts a type to 32 bits if it does not meet the requirements of the available extensions or capabilities. In that case, we can load a 32-bit value and then extract the bits that hold the original value.
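A minimal sketch of the idea in plain C++ rather than as MLIR/SPIR-V ops. It assumes that a buffer of 8-bit values is emulated as an array of 32-bit words, with element i stored in bits [(i % 4) * 8, (i % 4) * 8 + 8) of word i / 4. That packing (element 0 in the lowest byte) and the name `loadEmulatedI8` are assumptions for illustration. Loading element i means loading the containing word and then extracting its bits with a shift and a mask.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: reads the 8-bit element at `index` from a buffer in
// which every 32-bit word packs four 8-bit elements (element 0 in the low byte).
uint8_t loadEmulatedI8(const std::vector<uint32_t> &words, std::size_t index) {
  constexpr unsigned kSourceBits = 8;
  constexpr unsigned kTargetBits = 32;
  constexpr unsigned kPerWord = kTargetBits / kSourceBits; // 4 elements per word
  assert(index / kPerWord < words.size() && "index out of bounds");

  uint32_t word = words[index / kPerWord];            // load the whole 32-bit word
  unsigned offset = (index % kPerWord) * kSourceBits; // bit offset inside the word
  uint32_t mask = (1u << kSourceBits) - 1;            // 0xFF
  return static_cast<uint8_t>((word >> offset) & mask);
}
```

For example, with the words {0x44332211, 0x88776655}, element 1 is 0x22 and element 4 is 0x55.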
Repository: rG LLVM Github Monorepo
I found a bug in the patch; please wait for me to fix it before reviewing. Thanks!
| mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp | ||
|---|---|---|
| 103 | Couple of things here | |
| 123 | This too could be generalized to handle any target integer width. | |
| 653 | This will assert if this is not an integer. So it might be better to have a different pattern for load stores when the memref is integer type. So one pattern will implement this logic for integer type load/stores. Another pattern will be generic that will be type agnostic (and will return failure for integer types to not intersect with the other pattern) | |
| 667 | Could we add smarter logic here? We could try to find the next highest power of 2 that is legal and use that instead. | |
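The reviewer's "next highest power of 2" suggestion could look roughly like the following sketch. Instead of always falling back to 32 bits, it picks the smallest power-of-two width that is at least the requested width and is legal for the target. The set of legal widths is passed in explicitly here; in the real conversion it would come from the extensions and capabilities known to SPIRVTypeConverter. `pickEmulationWidth` is a hypothetical name.

```cpp
#include <cassert>
#include <optional>
#include <set>

// Hypothetical helper: returns the smallest legal power-of-two bitwidth that
// can hold a `bits`-wide integer, or std::nullopt if no such width is legal.
std::optional<unsigned> pickEmulationWidth(unsigned bits,
                                           const std::set<unsigned> &legalWidths) {
  assert(bits > 0 && bits <= 64);
  // Round up to the next power of two (1, 2, 4, ..., 64).
  unsigned width = 1;
  while (width < bits)
    width *= 2;
  // Walk upwards through power-of-two widths until a legal one is found.
  for (; width <= 64; width *= 2)
    if (legalWidths.count(width) != 0)
      return width;
  return std::nullopt;
}
```

With 16 and 32 legal, an 8-bit integer would be emulated with 16 bits rather than 32.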
Isn't this the kind of legalization that can be done on the std dialect itself, as a pre-pass before the conversion to SPIR-V? That would make all of this logic reusable.
Awesome, thanks Hanhan for taking this on! Sorry for the many comments, but type availability in SPIR-V is quite nuanced. :)
| mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp | ||
|---|---|---|
| 103 | Nit: s/bits/elementBits/ | |
| 103 | Do we need to pass in the op here? I think we just need the location and the last index? That way this function can be clearer that it is just adjusting an index into 32-bit arrays into another index into bits-bit arrays. | |
| 104 | Just use normal OpBuilder? | |
| 121 | Assert in the function regarding 1-D array? | |
| 124 | Nit: s/bits/elementBits/? | |
| 125 | Nit: this can just be normal OpBuilder? | |
| 129 | assert llvm::isPowerOf2_32(bits)? | |
| 131 | auto indices = llvm::to_vector<4>(op.indices())? | |
| 653 | +1 to having separate patterns and reject not-handled cases early. It's okay to just implement integer for now and add others gradually. | |
| 654 | The type conversion must factor in the storage class, which is carried as the memref memory space. This affects the converted element type. For example, if StorageBuffer16BitAccess is available, then 16-bit integers in the storage buffer class (currently mapped to memory space 0) do not need conversion. If we only consider the element type here, the result can be wrong: as long as Int16 is not available, we will convert 16-bit integers to 32-bit. So here we should convert the whole memref type and then get the element type. | |
| 662 | Just directly update result instead of creating this local variable? | |
| mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir | ||
| 625 | This needs to be updated: // Check that access chain indices are properly adjusted if non-32-bit types are emulated via 32-bit types. | |
| 635 | What about creating separate functions for each type so that we have more focused and easier-to-read tests? | |
| 635 | We will need tests with the StorageBuffer16BitAccess/etc. capabilities. | |
| 637 | I think we want to check the index calculation in detail for at least one of the cases here, given it's the crucial part of the adjustment. For the others, we might be able to just check the op name. | |
Good question! But whether to do a specific type conversion is determined by the SPIR-V target environment, and it can be quite nuanced. For example, if we only have the StorageBuffer16BitAccess capability, then memref<i16, 0> will be fine but memref<i8, 0>/memref<i16, 4>/etc. need to be adjusted. There are many other similar capabilities like UniformAndStorageBuffer16BitAccess, *8BitAccess, {Int|Float}{8|16|64}, etc. This kind of information is only available when converting to SPIR-V and is hidden behind SPIRVTypeConverter. If this were implemented as a pre-pass operating on standard types, it's not quite clear to me how to solve the phase-ordering issue and rope in the configuration there.
But regarding code reuse, we might be able to extract some of the index-adjusting logic and turn it into templates, so that std and other dialect ops can plug in and reuse it.
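To make the nuance concrete, here is a deliberately simplified model of the decision described above. It only considers 8- and 16-bit integers, two storage classes, and the capabilities named in this thread; the real rules in SPIRVTypeConverter cover more cases (other storage classes, the Float* capabilities, etc.) and may differ in detail. All types and function names below are hypothetical.

```cpp
#include <cassert>
#include <set>

// Hypothetical, simplified model; not the actual SPIRVTypeConverter logic.
enum class Capability {
  Int8,
  Int16,
  StorageBuffer8BitAccess,
  UniformAndStorageBuffer8BitAccess,
  StorageBuffer16BitAccess,
  UniformAndStorageBuffer16BitAccess,
};
enum class StorageClass { StorageBuffer, Function };

// Returns true if an integer of `bits` width stored in `sc` must be emulated
// with a wider type under the given capabilities.
bool needsEmulation(unsigned bits, StorageClass sc,
                    const std::set<Capability> &caps) {
  assert(bits == 8 || bits == 16 || bits == 32);
  auto has = [&](Capability c) { return caps.count(c) != 0; };
  if (bits == 32)
    return false; // 32-bit integers are always supported
  if (sc == StorageClass::StorageBuffer) {
    // Storage buffers: the *BitAccess capabilities decide.
    if (bits == 16)
      return !(has(Capability::StorageBuffer16BitAccess) ||
               has(Capability::UniformAndStorageBuffer16BitAccess));
    return !(has(Capability::StorageBuffer8BitAccess) ||
             has(Capability::UniformAndStorageBuffer8BitAccess));
  }
  // Other storage classes in this model: the IntN capabilities decide.
  return bits == 16 ? !has(Capability::Int16) : !has(Capability::Int8);
}
```

With only StorageBuffer16BitAccess available, this model keeps i16 in a storage buffer as-is but still emulates i8 there and i16 in other storage classes, matching the memref<i16, 0> vs. memref<i8, 0> example above.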
| mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp | ||
|---|---|---|
| 129 | I made it handle any target integer width and added assert(targetBits % elementBits == 0); at the beginning. | |
| 667 | This depends on how the type converter handles it. Here I followed your suggestion to generalize it, replacing every hard-coded 32-bit assumption with convertedBit. Thus, if the type converter does try to find the "next highest power of 2", this will still work. | |
| mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir | ||
| 635 | That's what I thought after sending it out for review... I planned to fix it in a later revision. | |
| 635 | I added one more test, although it requires more extensions and capabilities than I expected. Please let me know if this isn't the case you'd like me to add. Thanks! | |
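With the target width generalized, the index into the emulated 1-D array no longer assumes 32-bit words: each target element holds targetBits / sourceBits source elements, which is why targetBits must be a multiple of sourceBits. A plain C++ sketch of that arithmetic; `adjustIndexForBitwidth` is a hypothetical name, and the patch builds the equivalent computation out of SPIR-V ops rather than host code.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: maps an index into an array of sourceBits-wide elements
// to the index of the targetBits-wide element that contains it.
int64_t adjustIndexForBitwidth(int64_t srcIdx, unsigned sourceBits,
                               unsigned targetBits) {
  assert(sourceBits > 0 && targetBits % sourceBits == 0);
  int64_t ratio = targetBits / sourceBits; // source elements per target element
  return srcIdx / ratio;
}
```

For instance, 8-bit element 5 lives in 32-bit word 1, and 16-bit element 5 also lives in 64-bit word 1.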
Yes, as Lei said, the information about available integer widths is hidden in SPIRVTypeConverter. I think one form of code reuse could be a method loadAndCast that rewrites a std load into loading an elementBits element and applying a shift and an AND mask. If other targets need this too, we can think more about how to reuse it.
| mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp | ||
|---|---|---|
| 112 | idx and elemBitsValue seem to be the same... | |
| 117 | This comment is hard to parse; a more descriptive one would help. Something along the lines of: Based on the extensions/capabilities, certain integer bitwidths (targetBits) might not be supported. During conversion, if a memref of an unsupported type is used, loads/stores to this memref need to be modified to use a supported higher bitwidth (elementBits) and extract the required bits. For accessing a 1-D array (spv.array or spv.rt_array), the last index is modified to load the bits needed. The extraction of the actual bits needed is handled separately. | |
| 135 | This probably needs some explanation. If the access chain is created while lowering a zero-rank memref, there is only one element in indices, and you are just changing the element type here. This is still valid because the host side would have to use the same bitwidth to store the scalar (even though it needs a smaller bitwidth). | |
| 139 | use builder.replaceOpWithNewOp. I am assuming the older accesschain operation is dead and needs to be deleted. | |
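Only the last access-chain index needs adjusting because the memref access is linearized into a single index into a 1-D array. A sketch of row-major linearization, assuming a static shape with an identity layout; `linearizeIndex` is a hypothetical helper, not an MLIR API. A zero-rank memref simply yields index 0, so it is also treated as a 1-D array.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: row-major linear index of `indices` in a buffer with
// the given static `shape` (identity layout assumed).
int64_t linearizeIndex(const std::vector<int64_t> &indices,
                       const std::vector<int64_t> &shape) {
  assert(indices.size() == shape.size());
  int64_t linear = 0;
  for (std::size_t i = 0; i < shape.size(); ++i) {
    assert(indices[i] >= 0 && indices[i] < shape[i]);
    linear = linear * shape[i] + indices[i]; // Horner-style accumulation
  }
  return linear; // rank 0: the loop does not run and the result is 0
}
```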
Nice! I just have a few more nits.
| mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp | ||
|---|---|---|
| 100 | What about: Assuming index is an index into a 1-D array with each element having sourceBits, returns the adjusted index by treating the 1-D array as having elements of targetBits? This means renaming lastDim to index. | |
| 104 | Sorry for nitpicking again, but with targetBits, it's better to call elementBits sourceBits then. ;) Similarly for the next function. | |
| 104 | What about naming it adjust1DArrayIndexForBitwidth? It's no longer specific to integers. | |
| 122 | What about naming it as adjustAccessChainForBitwidth? | |
| mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir | ||
| 654 | It would be nice to test a 1-D memref here, with an index coming from a function parameter. | |
Addressed comments. Also found that the scalar case cannot occur, so we can remove some checks and make the logic simpler.
| mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp | ||
|---|---|---|
| 100 | I don't think the method adjusts the index. Instead, it calculates the offset of the value within the loaded value. When accessing the target 1-D array, multiple values are loaded at the same time; in this context, the method returns the bit offset at which srcIdx is located within the loaded value. In the example, it's (x % 4) * 8, not (x % 4). I added more comments here, please take a look. | |
| 104 | Yes, sourceBits is better. thanks! | |
| 112 | Good catch, thanks! | |
| 135 | I just found that there is no scalar case here, because getElementPtr() always linearizes the buffer. Even a scalar is turned into a 1-D array. | |
| 139 | To me the method looks more like it returns an adjusted ptr, so it can focus on how to build the ptr. I think keeping the replacement logic in the matchAndRewrite method is better. In this use case, we don't want the op destroyed immediately because we still need some information from it later. | |
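The distinction made in the reply on line 100 can be spelled out: the word index comes from integer division, while the value this method returns is the bit offset of the element inside the loaded word. For 8-bit elements in 32-bit words that offset is (x % 4) * 8, not x % 4. A plain C++ sketch, assuming element 0 occupies the lowest bits; `bitOffsetInWord` is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: bit offset of source element `srcIdx` inside the
// targetBits-wide word that contains it (element 0 in the lowest bits).
int64_t bitOffsetInWord(int64_t srcIdx, unsigned sourceBits,
                        unsigned targetBits) {
  assert(sourceBits > 0 && targetBits % sourceBits == 0);
  int64_t ratio = targetBits / sourceBits; // source elements per word
  return (srcIdx % ratio) * sourceBits;    // position within the word, in bits
}
```

Shifting the loaded word right by this offset and masking with (1 << sourceBits) - 1 then yields the element.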
Awesome, thanks Hanhan!
| mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp | ||
|---|---|---|
| 100 | Oh yeah good point. :) | |