diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -27,7 +27,10 @@
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/Support/FormatVariadic.h"
+#include <cassert>
+#include <cstdint>
 #include <numeric>
 
 using namespace mlir;
@@ -444,24 +447,62 @@
       return rewriter.notifyMatchFailure(shuffleOp,
                                          "unsupported result vector type");
 
-    auto oldSourceType = shuffleOp.getV1VectorType();
-    if (oldSourceType.getNumElements() > 1) {
-      SmallVector<int32_t, 4> components = llvm::to_vector<4>(
-          llvm::map_range(shuffleOp.getMask(), [](Attribute attr) -> int32_t {
-            return cast<IntegerAttr>(attr).getValue().getZExtValue();
-          }));
+    SmallVector<int32_t, 4> mask = llvm::map_to_vector<4>(
+        shuffleOp.getMask(), [](Attribute attr) -> int32_t {
+          return cast<IntegerAttr>(attr).getValue().getZExtValue();
+        });
+
+    auto oldV1Type = shuffleOp.getV1VectorType();
+    auto oldV2Type = shuffleOp.getV2VectorType();
+
+    // When both operands and the result are SPIR-V vectors, emit a SPIR-V
+    // shuffle. Single-element vectors are converted to scalars, and
+    // spirv.VectorShuffle accepts neither scalar operands nor a scalar result.
+    if (oldV1Type.getNumElements() > 1 && oldV2Type.getNumElements() > 1 &&
+        oldResultType.getNumElements() > 1) {
       rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
           shuffleOp, newResultType, adaptor.getV1(), adaptor.getV2(),
-          rewriter.getI32ArrayAttr(components));
+          rewriter.getI32ArrayAttr(mask));
       return success();
     }
 
-    SmallVector<Value> oldOperands = {adaptor.getV1(), adaptor.getV2()};
-    SmallVector<Value> newOperands;
-    newOperands.reserve(oldResultType.getNumElements());
-    for (const APInt &i : shuffleOp.getMask().getAsValueRange<IntegerAttr>()) {
-      newOperands.push_back(oldOperands[i.getZExtValue()]);
+    // When at least one of the operands or the result becomes a scalar after
+    // type conversion for SPIR-V, extract all the required elements and
+    // assemble the result from them.
+    auto getElementAtIdx = [&rewriter, loc = shuffleOp.getLoc()](
+                               Value scalarOrVec, int32_t idx) -> Value {
+      if (isa<VectorType>(scalarOrVec.getType()))
+        return rewriter.create<spirv::CompositeExtractOp>(loc, scalarOrVec,
+                                                          idx);
+
+      // A single-element vector operand was converted to a scalar, so index 0
+      // is the only valid element index into it.
+      assert(idx == 0 && "Invalid scalar element index");
+      return scalarOrVec;
+    };
+
+    // Mask entries below numV1Elems select from V1; the rest select from V2.
+    int32_t numV1Elems = oldV1Type.getNumElements();
+    SmallVector<Value> newOperands(mask.size());
+    for (auto [shuffleIdx, newOperand] : llvm::zip_equal(mask, newOperands)) {
+      Value vec = adaptor.getV1();
+      int32_t elementIdx = shuffleIdx;
+      if (elementIdx >= numV1Elems) {
+        vec = adaptor.getV2();
+        elementIdx -= numV1Elems;
+      }
+
+      newOperand = getElementAtIdx(vec, elementIdx);
     }
+
+    // A single-element result converts to a scalar; forward the extracted
+    // element directly because spirv.CompositeConstruct must produce a
+    // composite.
+    if (newOperands.size() == 1) {
+      rewriter.replaceOp(shuffleOp, newOperands.front());
+      return success();
+    }
+
     rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
         shuffleOp, newResultType, newOperands);
     return success();
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -442,6 +442,59 @@
 
 // -----
 
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<3xi32>
+// CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xi32> to i32
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<3xi32>
+// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[ARG1]][2 : i32] : vector<3xi32>
+// CHECK: %[[RES:.+]] = spirv.CompositeConstruct %[[V0]], %[[S1]], %[[S2]] : (i32, i32, i32) -> vector<3xi32>
+// CHECK: return %[[RES]]
+func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<3xi32>) -> vector<3xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 2, 3] : vector<1xi32>, vector<3xi32>
+  return %shuffle : vector<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<3xi32>, %[[ARG1:.+]]: vector<1xi32>
+// CHECK: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xi32> to i32
+// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<3xi32>
+// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[ARG0]][2 : i32] : vector<3xi32>
+// CHECK: %[[RES:.+]] = spirv.CompositeConstruct %[[S0]], %[[S1]], %[[V1]] : (i32, i32, i32) -> vector<3xi32>
+// CHECK: return %[[RES]]
+func.func @shuffle(%v0 : vector<3xi32>, %v1: vector<1xi32>) -> vector<3xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 2, 3] : vector<3xi32>, vector<1xi32>
+  return %shuffle : vector<3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<1xi32>, %[[ARG1:.+]]: vector<1xi32>
+// CHECK: %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xi32> to i32
+// CHECK: %[[V1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xi32> to i32
+// CHECK: %[[RES:.+]] = spirv.CompositeConstruct %[[V0]], %[[V1]] : (i32, i32) -> vector<2xi32>
+// CHECK: return %[[RES]]
+func.func @shuffle(%v0 : vector<1xi32>, %v1: vector<1xi32>) -> vector<2xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [0, 1] : vector<1xi32>, vector<1xi32>
+  return %shuffle : vector<2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle
+// CHECK-SAME: %[[ARG0:.+]]: vector<4xi32>, %[[ARG1:.+]]: vector<4xi32>
+// CHECK: %[[EXTR:.+]] = spirv.CompositeExtract %[[ARG1]][1 : i32] : vector<4xi32>
+// CHECK: %[[RES:.+]] = builtin.unrealized_conversion_cast %[[EXTR]] : i32 to vector<1xi32>
+// CHECK: return %[[RES]]
+func.func @shuffle(%v0 : vector<4xi32>, %v1: vector<4xi32>) -> vector<1xi32> {
+  %shuffle = vector.shuffle %v0, %v1 [5] : vector<4xi32>, vector<4xi32>
+  return %shuffle : vector<1xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @reduction_add
 // CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
 // CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>