Add a conversion pass from the Vector dialect to the SPIR-V dialect, with simple conversion patterns for vector.broadcast, vector.insert, and vector.extract.
Nice! Thanks, Thomas, for starting this! I've wanted to do this for ages. :)
| File | Line | Comment |
|---|---|---|
| mlir/include/mlir/Conversion/Passes.td | 389 | s/SPIRV/SPIR-V/ |
| mlir/include/mlir/Conversion/VectorToSPIRV/ConvertVectorToSPIRVPass.h | 9 | SPIR-V |
| mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 37 | Should we check that the input is just a scalar and not some n-D vector? Vectors in SPIR-V are constrained, so we could generate invalid SPIR-V ops here. The same applies to the following two patterns. |
| mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 85 | Use empty lines to separate the function and the namespace. (Doesn't clang-format do that, actually?) |
| mlir/test/Conversion/VectorToSPIRV/simple.mlir | 3 | CHECK-LABEL |
| mlir/test/Conversion/VectorToSPIRV/simple.mlir | 15 | CHECK-LABEL |
| File | Line | Comment |
|---|---|---|
| mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 37 | By default, vectors in SPIR-V can have at most 4 elements. If the source broadcasts a scalar into, say, a 128-element vector, this pattern will create invalid SPIR-V. I think we also need to check that the result vector type is valid according to spirv::CompositeType::isValid, and bail out otherwise. We can handle the large-vector case later with type conversions. |
| File | Line | Comment |
|---|---|---|
| mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 37 | I added a check that the result vector has at most 4 elements. Could the composite type still be invalid in this case? |
| File | Line | Comment |
|---|---|---|
| mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 37 | spirv::CompositeType::isValid(VectorType) is a static function that wraps the logic here and checks whether a general VectorType is allowed in SPIR-V: https://github.com/llvm/llvm-project/blob/master/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp#L174-L177. Using it avoids duplicating that logic in multiple places. :) |
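To make the constraint discussed in this thread concrete, here is a minimal standalone sketch of the kind of check involved. It does not use MLIR: `VectorShape`, `isValidSpirvVector`, and `canLowerBroadcast` are hypothetical names that model only the properties the review mentions. It assumes the default SPIR-V rule without the Vector16 capability: a vector is 1-D, has 2, 3, or 4 components, and has a scalar element type. The authoritative logic is spirv::CompositeType::isValid, linked above.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for an MLIR VectorType: only the properties the
// SPIR-V vector check looks at are modeled here.
struct VectorShape {
  std::vector<int64_t> dims;  // shape, e.g. {4} for vector<4xf32>
  bool scalarElementType;     // true for integer/float/bool element types
};

// Assumed default SPIR-V rule (no Vector16 capability): a vector must be
// 1-D with 2, 3, or 4 scalar components.
bool isValidSpirvVector(const VectorShape &type) {
  if (type.dims.size() != 1 || !type.scalarElementType)
    return false;
  int64_t n = type.dims[0];
  return n >= 2 && n <= 4;
}

// Preconditions raised in the review for the broadcast pattern: the source
// must be a scalar (not an n-D vector), and the result must be a vector
// type SPIR-V accepts. A pattern would bail out when this returns false.
bool canLowerBroadcast(bool sourceIsScalar, const VectorShape &result) {
  return sourceIsScalar && isValidSpirvVector(result);
}
```

Under these assumptions, broadcasting a scalar into a 4-element vector passes, while a 128-element result (the case raised above) is rejected. Centralizing the rule in one predicate is the point of the suggestion: every pattern calls the same check instead of repeating the element-count logic.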
Use spirv::CompositeType::isValid
| File | Line | Comment |
|---|---|---|
| mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 37 | My bad, I should have looked at it. Makes sense; I replaced the checks with this function. |
s/SPIRV/SPIR-V/