Relax the verifier for the transfer_read/transfer_write operations so that they can take a memref with a different element type than the vector being read/written.
This is based on the Discourse discussion:
https://llvm.discourse.group/t/memref-cast/1514
mlir/lib/Dialect/Vector/VectorOps.cpp
Lines 1509–1510: It looks like this TODO can be removed too now, since you removed the suffix part?
Also, with the restrictions out of the way, did you make sure this does not crash the lowering to LLVM (try out a few of your new examples)? You may have to return failure() in the lowering logic itself if it breaks anything.
It looks like it mostly works out of the box. The translation to LLVM does a pointer bitcast, so it doesn't seem to rely on the types matching. I added a test for the conversion to LLVM.
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Line 956: Ah, this seems to be in the old style (note that our LLVM type syntax changed this morning). You will need to rebase this CL, otherwise you will get conflicts when trying to submit it.
rebase
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Line 956: Thanks for the heads up. I rebased.
We probably want to add some sort of sanity checking later to make sure we don't do very strange bitcasts, but this is good for now to get the ball rolling.
This relaxation is too brutal IMO and will likely create multiple issues related to lowering to SCF and LLVM.
To reduce the pain of making this work with transformations, I'd suggest adding restrictions on the type.
Basically you're only allowed to:
- take elements immediately to the left of your element type in the memref and stick these into the vector;
- cast to vector types that have the same number of bits and a power-of-2 size.
Anything else will be a significant amount of pain until we have a solid DataLayout in MLIR.
So basically you can allow the existing semantics plus exactly one vector.type_cast that changes the vector type.
memref<axbxcxdxvector<txuxvxf32>> -> memref<axbx | cxdxvector<txuxvxf32>> -> memref<axbx vector<cxdxtxuxvxf32>> -> memref<axbx vector<whatever>> -> vector<whatever>, where:
- | is a separator that you can put anywhere in the memref shape (just like the current vector.type_cast);
- whatever has the same bit size and alignment as cxdxtxuxvxf32.
To keep the same-bit-size condition reasonably simple, I'd say: if you want to cast the element type, only allow it when the most minor dimensions of whatever and of v x f32 both have a power-of-2 size.
I'd recommend very slowly and very carefully relaxing the constraints to match your needs without stepping out of bounds for now.
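The cast rule suggested above can be sketched as a standalone C++ check. This is a minimal sketch under loudly stated assumptions, not MLIR code: `VecType`, `isAllowedCast`, and the helpers are hypothetical names standing in for MLIR vector types; "power-of-2 size" is read as the bit size of the most minor 1-D vector; and alignment is not modeled at all.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for an MLIR vector type; NOT the MLIR API.
struct VecType {
  std::vector<int64_t> shape;  // e.g. {2, 4} for vector<2x4xf32>
  std::string elt;             // element type name, e.g. "f32"
  int64_t eltBits;             // element bitwidth, e.g. 32
};

bool isPowerOf2(int64_t x) { return x > 0 && (x & (x - 1)) == 0; }

// Total number of bits held by the vector.
int64_t totalBits(const VecType &v) {
  int64_t n = std::accumulate(v.shape.begin(), v.shape.end(), int64_t{1},
                              std::multiplies<int64_t>());
  return n * v.eltBits;
}

// Bits in the most minor 1-D vector (the innermost dimension).
int64_t minorBits(const VecType &v) {
  assert(!v.shape.empty() && "vector types are assumed to have rank >= 1");
  return v.shape.back() * v.eltBits;
}

// Assumed reading of the proposal: the cast must preserve the total bit size,
// and when it changes the element type, both most minor 1-D vectors must have
// a power-of-2 bit size. Alignment is deliberately not checked here.
bool isAllowedCast(const VecType &from, const VecType &to) {
  if (totalBits(from) != totalBits(to)) return false;
  if (from.elt == to.elt) return true;  // same element type: pure reshape
  return isPowerOf2(minorBits(from)) && isPowerOf2(minorBits(to));
}
```

Under this reading, vector<2x4xf32> to vector<32xi8> is accepted (256 bits each, minor vectors of 128 and 256 bits), while vector<3xf32> to vector<12xi8> is rejected because the 96-bit minor vector is not a power of 2.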
Those rules are not enough for the cases I'm trying to solve. For instance, I want to support:
- vector.transfer_read ... : memref<128x2xvector<4xi32>>, vector<16x32xi8>
- vector.transfer_read ... : memref<4096x1024xvector<4xi32>>, vector<16x16xi32>
- vector.transfer_read ... : memref<32x8xvector<4xi32>>, vector<32x16xi8>
What I need is basically to be able to read as if the memref were scalar. One restriction I could add is to make sure that each dimension of the vector is smaller than or equal to the associated dimension of the memref, and that for the last dimension the size in bits is smaller than or equal to the size in bits of the last dimension of the memref.
That only ensures that no dimension of the vector is bigger than the memref, though.
Do you have any suggestions on the rules to pick to support those cases?
I added a restriction to enforce that the inner dimension of the destination vector is a multiple of the inner dimension of the memref element type. This way the transfer_read/transfer_write can always be lowered to several loads of the memref element type.
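The restriction described above can be illustrated with a standalone C++ check. This is a hedged sketch, not the code in VectorOps.cpp: `Shape`, `isTransferCompatible`, and `loadsPerRow` are hypothetical names; the comparison is done on bitwidths, following the error messages suggested later in this review; and a scalar memref element type is modeled as an empty shape.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified stand-in for a vector or scalar type; NOT the MLIR API.
struct Shape {
  std::vector<int64_t> dims;  // empty for a scalar, e.g. {4} for vector<4xi32>
  int64_t eltBits;            // element bitwidth, e.g. 32 for i32
};

// Bitwidth of the most minor 1-D vector, or of the scalar itself.
int64_t minorBits(const Shape &t) {
  assert(t.eltBits > 0 && "element bitwidth must be positive");
  return t.dims.empty() ? t.eltBits : t.dims.back() * t.eltBits;
}

// Accept the transfer when the bitwidth of the vector's minor 1-D vector is an
// integral multiple of the bitwidth of the memref element's minor 1-D vector.
bool isTransferCompatible(const Shape &memrefElt, const Shape &vec) {
  return minorBits(vec) % minorBits(memrefElt) == 0;
}

// How many loads of the memref element type cover one minor row of the vector
// (only meaningful when isTransferCompatible returns true).
int64_t loadsPerRow(const Shape &memrefElt, const Shape &vec) {
  return minorBits(vec) / minorBits(memrefElt);
}
```

Applied to the three cases listed earlier in the thread, where the memref element vector<4xi32> is 128 bits: vector<16x32xi8> has a 256-bit minor vector (2 loads per row), vector<16x16xi32> has 512 bits (4 loads), and vector<32x16xi8> has 128 bits (1 load). A vector<4x8xi8> (64 bits) would be rejected.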
Cool, just as we discussed offline.
Could you please make sure the conversion to LLVM barfs when it sees this? We don't want to miscompile, and we likely would right now.
mlir/lib/Dialect/Vector/VectorOps.cpp
Lines 1503–1504: We need a better error message here and below: "requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the minor 1-D vector of the memref"?
Line 1521: Please turn this into an assert in the ConvertVectorToLLVM pattern that lowers this to LLVM.
mlir/lib/Dialect/Vector/VectorOps.cpp
Lines 1517–1518: We need a better error message here and above: "requires the bitwidth of the minor 1-D vector to be an integral multiple of the bitwidth of the memref element type"?
We need a better error message here and below.
How about:
?