diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -18,7 +18,9 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/STLExtras.h"
 #include <numeric>
 
 using namespace mlir;
@@ -264,6 +266,66 @@
   }
 };
 
+/// Converts vector.shuffle to SPIR-V.
+///
+/// The SPIR-V type converter turns single-element vectors into scalars. When
+/// both operands remain vectors this lowers to spv.VectorShuffle; otherwise
+/// each selected element is extracted (or taken directly from a scalar
+/// operand) and the result is assembled with spv.CompositeConstruct.
+struct VectorShuffleOpConvert final
+    : public OpConversionPattern<vector::ShuffleOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShuffleOp shuffleOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto oldResultType = shuffleOp.getVectorType();
+    if (!spirv::CompositeType::isValid(oldResultType))
+      return failure();
+    Type newResultType = getTypeConverter()->convertType(oldResultType);
+    if (!newResultType)
+      return failure();
+
+    SmallVector<int32_t, 4> mask = llvm::to_vector<4>(
+        llvm::map_range(shuffleOp.mask(), [](Attribute attr) -> int32_t {
+          return attr.cast<IntegerAttr>().getValue().getZExtValue();
+        }));
+
+    // spv.VectorShuffle requires both operands to be vectors, which only
+    // holds if neither operand was a single-element vector.
+    VectorType oldV1Type = shuffleOp.getV1VectorType();
+    VectorType oldV2Type = shuffleOp.getV2VectorType();
+    if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1) {
+      rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+          shuffleOp, newResultType, adaptor.v1(), adaptor.v2(),
+          rewriter.getI32ArrayAttr(mask));
+      return success();
+    }
+
+    // At least one operand is now a scalar: gather the selected elements one
+    // by one. Mask indices below numV1Elems select from v1, the rest from v2.
+    int32_t numV1Elems = oldV1Type.getNumElements();
+    SmallVector<Value, 4> newOperands;
+    newOperands.reserve(mask.size());
+    for (int32_t shuffleIdx : mask) {
+      Value source = adaptor.v1();
+      int32_t elementIdx = shuffleIdx;
+      if (elementIdx >= numV1Elems) {
+        source = adaptor.v2();
+        elementIdx -= numV1Elems;
+      }
+      // A scalar operand (converted single-element vector) is used as-is.
+      if (source.getType().isa<VectorType>())
+        source = rewriter.create<spirv::CompositeExtractOp>(
+            shuffleOp.getLoc(), source, elementIdx);
+      newOperands.push_back(source);
+    }
+    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
+        shuffleOp, newResultType, newOperands);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
@@ -272,6 +334,6 @@
                VectorExtractElementOpConvert, VectorExtractOpConvert,
                VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
                VectorInsertElementOpConvert, VectorInsertOpConvert,
-               VectorInsertStridedSliceOpConvert, VectorSplatPattern>(
-      typeConverter, patterns.getContext());
+               VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+               VectorSplatPattern>(typeConverter, patterns.getContext());
 }
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -179,3 +179,47 @@
   %splat = vector.splat %f : vector<4xf32>
   return %splat : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @shuffle
+//  CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>
+//       CHECK:   %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+//       CHECK:   %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]]
+//       CHECK:   spv.CompositeConstruct %[[V0]], %[[V1]], %[[V1]], %[[V0]] : vector<4xf32>
+func @shuffle(%v0 : vector<1xf32>, %v1: vector<1xf32>) -> vector<4xf32> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 1, 1, 0] : vector<1xf32>, vector<1xf32>
+  return %shuffle : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle
+//  CHECK-SAME: %[[V0:.+]]: vector<3xf32>, %[[V1:.+]]: vector<3xf32>
+//       CHECK:   spv.VectorShuffle [3 : i32, 2 : i32, 5 : i32, 1 : i32] %[[V0]] : vector<3xf32>, %[[V1]] : vector<3xf32> -> vector<4xf32>
+func @shuffle(%v0 : vector<3xf32>, %v1: vector<3xf32>) -> vector<4xf32> {
+  %shuffle = vector.shuffle %v0, %v1 [3, 2, 5, 1] : vector<3xf32>, vector<3xf32>
+  return %shuffle : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle
+//  CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<3xf32>
+//       CHECK:   %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
+//       CHECK:   %[[E0:.+]] = spv.CompositeExtract %[[ARG1]][2 : i32] : vector<3xf32>
+//       CHECK:   %[[E1:.+]] = spv.CompositeExtract %[[ARG1]][0 : i32] : vector<3xf32>
+//       CHECK:   spv.CompositeConstruct %[[V0]], %[[E0]], %[[V0]], %[[E1]] : vector<4xf32>
+func @shuffle(%v0 : vector<1xf32>, %v1: vector<3xf32>) -> vector<4xf32> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 3, 0, 1] : vector<1xf32>, vector<3xf32>
+  return %shuffle : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle
+func @shuffle(%v0 : vector<2x16xf32>, %v1: vector<1x16xf32>) -> vector<3x16xf32> {
+  // CHECK: vector.shuffle
+  %shuffle = vector.shuffle %v0, %v1 [0, 1, 2] : vector<2x16xf32>, vector<1x16xf32>
+  return %shuffle : vector<3x16xf32>
+}