For 0-D as well as 1-D vectors, both of these patterns should
return failure, as there is no need to collapse the shape
of the source. Previously, only 1-D vectors were handled. This
patch handles the 0-D case as well.
Details
Diff Detail
- Repository
- rG LLVM Github Monorepo
Event Timeline
@Benoit, could you take a look, as I believe you added those patterns? Why doesn't the memref have to be rank 1 for the pattern to skip? Do you know why the unit tests above are disabled?
Better ask @nicolasvasilache ! In https://reviews.llvm.org/D114993 I had originally written a larger number of tests. Nicolas took over and disabled/removed tests.
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | ||
---|---|---|
376–379 | If the vector rank is 0 and the source rank is 1, then the baseline code (before this patch) works. This is only an issue when the source rank is also 0, in which case the following assertion fails: mlir::AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t>, ArrayRef<mlir::AffineExpr>, mlir::MLIRContext *): Assertion `!sizes.empty() && !exprs.empty() && "expected non-empty sizes and exprs"' failed. | |
mlir/test/Dialect/Vector/vector-transfer-flatten.mlir | ||
60 ↗ | (On Diff #406648) | Yes. For example, if the memref is 1-D (memref<8xi8> rather than memref<i8>), then it works, even though the vector is rank 0 in both cases. |
mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | ||
---|---|---|
376–379 | If the source rank is 0, I guess we don't need to check the vector rank, since vector.transfer_read/write semantics require the vector to be 0-D as well. So do you want to just check for sourceType.getRank() == 0? Or did you have another idea? |
Good question, I think I just wrote this condition without really thinking about it.
I made the edit, rebuilt, and reran, so I can confirm that it does not make a difference for my own usage patterns. Specifically, this diff does not matter to me:
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -373,7 +373,7 @@ class FlattenContiguousRowMajorTransferReadPattern
     // Contiguity check is valid on tensors only.
     if (!sourceType)
       return failure();
-    if (vectorType.getRank() == 1 && sourceType.getRank() == 1)
+    if (vectorType.getRank() == 1)
OK, then I think we should change the condition to vectorType.getRank() <= 1 for simplicity.
Shouldn't the condition be vectorType.getRank() <= 1?
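To compare the rank conditions discussed in this thread side by side, here is a minimal standalone sketch. It does not use MLIR: `TransferRanks`, `skipsBaseline`, and `skipsProposed` are hypothetical names introduced only for illustration, and the two predicates merely mirror the conditions quoted above (`vectorType.getRank() == 1 && sourceType.getRank() == 1` versus `vectorType.getRank() <= 1`). It assumes the early return is the only place where rank matters.

```cpp
#include <cstdint>

// Hypothetical stand-in for the two ranks the patterns inspect. The real
// patterns query the vector type and the memref source type instead.
struct TransferRanks {
  int64_t vectorRank;
  int64_t sourceRank;
};

// Baseline condition quoted in the thread: skip flattening only when both
// the vector and the source are 1-D.
constexpr bool skipsBaseline(TransferRanks r) {
  return r.vectorRank == 1 && r.sourceRank == 1;
}

// Condition suggested at the end of the thread: skip flattening for any
// 0-D or 1-D vector, regardless of the source rank.
constexpr bool skipsProposed(TransferRanks r) { return r.vectorRank <= 1; }

// 0-D vector on a 0-D source (e.g. memref<i8>): the baseline condition does
// not skip, while the suggested condition does.
static_assert(!skipsBaseline(TransferRanks{0, 0}), "baseline does not skip 0-D");
static_assert(skipsProposed(TransferRanks{0, 0}), "suggested check skips 0-D");

// 1-D vector on a 1-D source: both conditions skip.
static_assert(skipsBaseline(TransferRanks{1, 1}), "baseline skips 1-D on 1-D");
static_assert(skipsProposed(TransferRanks{1, 1}), "suggested check skips 1-D");

// Higher-rank vectors are still candidates for flattening under both.
static_assert(!skipsBaseline(TransferRanks{2, 2}), "baseline keeps 2-D");
static_assert(!skipsProposed(TransferRanks{2, 2}), "suggested check keeps 2-D");
```

The `static_assert` lines double as a small truth table: the two conditions differ only for vectors of rank 0, and for rank-1 vectors whose source is not rank 1.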