diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1364,6 +1364,78 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// VectorShuffleOp conversion
+//===----------------------------------------------------------------------===//
+
+/// Converts `spv.VectorShuffle` to the LLVM dialect.
+///
+/// When both source vectors have the same number of elements, the op is
+/// replaced by a single `llvm.shufflevector`. Otherwise the result is built
+/// element by element: starting from an undef vector, each selected component
+/// is extracted from the source vector that holds it and inserted at its
+/// destination index. A component of -1 (0xffffffff) is skipped, which leaves
+/// the corresponding result element undefined.
+class VectorShufflePattern
+    : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
+  LogicalResult
+  matchAndRewrite(spirv::VectorShuffleOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto components = adaptor.components();
+    auto vector1 = adaptor.vector1();
+    auto vector2 = adaptor.vector2();
+    int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
+    int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
+    // Equal operand sizes: forward the component list to the shuffle op.
+    if (vector1Size == vector2Size) {
+      rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, vector1, vector2,
+                                                         components);
+      return success();
+    }
+
+    auto dstType = typeConverter.convertType(op.getType());
+    auto scalarType = dstType.cast<VectorType>().getElementType();
+    auto componentsArray = components.getValue();
+    auto context = rewriter.getContext();
+    auto llvmI32Type = IntegerType::get(context, 32);
+    Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
+    for (unsigned i = 0; i < componentsArray.size(); i++) {
+      // Reject anything that is not an integer attribute before casting it;
+      // returning the diagnostic fails the pattern instead of carrying on.
+      if (!componentsArray[i].isa<IntegerAttr>())
+        return op.emitError("unable to support non-constant component");
+
+      int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
+      // -1 selects no source element; this result element stays undefined.
+      if (indexVal == -1)
+        continue;
+
+      // Indices at or beyond vector1Size refer to vector2, rebased to zero.
+      int offsetVal = 0;
+      Value baseVector = vector1;
+      if (indexVal >= vector1Size) {
+        offsetVal = vector1Size;
+        baseVector = vector2;
+      }
+
+      Value dstIndex = rewriter.create<LLVM::ConstantOp>(
+          loc, llvmI32Type, rewriter.getIntegerAttr(rewriter.getI32Type(), i));
+      Value index = rewriter.create<LLVM::ConstantOp>(
+          loc, llvmI32Type,
+          rewriter.getIntegerAttr(rewriter.getI32Type(), indexVal - offsetVal));
+
+      auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
+          loc, scalarType, baseVector, index);
+      targetOp = rewriter.create<LLVM::InsertElementOp>(loc, dstType, targetOp,
+                                                        extractOp, dstIndex);
+    }
+    rewriter.replaceOp(op, targetOp);
+    return success();
+  }
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1489,6 +1561,7 @@
       CompositeExtractPattern, CompositeInsertPattern,
       DirectConversionPattern,
       DirectConversionPattern,
+      VectorShufflePattern,
 
       // Shift ops
       ShiftPattern,
diff --git a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
@@ -58,6 +58,32 @@
   spv.Return
 }
 
+//===----------------------------------------------------------------------===//
+// spv.VectorShuffle
+//===----------------------------------------------------------------------===//
+
+spv.func @vector_shuffle_same_size(%vector1: vector<2xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+  // CHECK: %[[res:.*]] = llvm.shufflevector {{.*}} [0 : i32, 2 : i32, -1 : i32] : vector<2xf32>, vector<2xf32>
+  // CHECK-NEXT: return %[[res]] : vector<3xf32>
+  %0 = spv.VectorShuffle [0: i32, 2: i32, 0xffffffff: i32] %vector1: vector<2xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  spv.ReturnValue %0: vector<3xf32>
+}
+
+spv.func @vector_shuffle_different_size(%vector1: vector<3xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : vector<3xf32>
+  // CHECK-NEXT: %[[C0_0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-NEXT: %[[C0_1:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-NEXT: %[[EXT0:.*]] = llvm.extractelement %arg0[%[[C0_1]] : i32] : vector<3xf32>
+  // CHECK-NEXT: %[[INSERT0:.*]] = llvm.insertelement %[[EXT0]], %[[UNDEF]][%[[C0_0]] : i32] : vector<3xf32>
+  // CHECK-NEXT: %[[C1_0:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NEXT: %[[C1_1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NEXT: %[[EXT1:.*]] = llvm.extractelement {{.*}}[%[[C1_1]] : i32] : vector<2xf32>
+  // CHECK-NEXT: %[[RES:.*]] = llvm.insertelement %[[EXT1]], %[[INSERT0]][%[[C1_0]] : i32] : vector<3xf32>
+  // CHECK-NEXT: llvm.return %[[RES]] : vector<3xf32>
+  %0 = spv.VectorShuffle [0: i32, 4: i32, 0xffffffff: i32] %vector1: vector<3xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  spv.ReturnValue %0: vector<3xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // spv.EntryPoint and spv.ExecutionMode
 //===----------------------------------------------------------------------===//