Index: mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
===================================================================
--- mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -217,6 +217,23 @@
       rewriter.getIntegerAttr(rewriter.getI32Type(), value));
 }
 
+/// Returns the i32 `llvm.mlir.constant` holding `value` if `constantOps`
+/// already has one; otherwise creates it at the current insertion point,
+/// records it in `constantOps`, and returns it.
+static Value getOrCreateI32ConstantOp(Location loc, PatternRewriter &rewriter,
+                                      unsigned value,
+                                      std::map<unsigned, Value> &constantOps) {
+  auto it = constantOps.find(value);
+  if (it != constantOps.end())
+    return it->second;
+
+  Value newOp = rewriter.create<LLVM::ConstantOp>(
+      loc, IntegerType::get(rewriter.getContext(), 32),
+      rewriter.getIntegerAttr(rewriter.getI32Type(), value));
+  constantOps.emplace(value, newOp);
+  return newOp;
+}
+
 /// Utility for `spv.Load` and `spv.Store` conversion.
 static LogicalResult replaceWithLoadOrStore(Operation *op,
                                             ConversionPatternRewriter &rewriter,
@@ -1360,6 +1377,77 @@
   }
 };
 
+//===----------------------------------------------------------------------===//
+// VectorShuffleOp conversion
+//===----------------------------------------------------------------------===//
+
+/// Converts `spv.VectorShuffle` to the LLVM dialect. When both source vectors
+/// have the same number of elements, a single `llvm.shufflevector` is emitted.
+/// Otherwise the result is assembled lane by lane, starting from an
+/// `llvm.mlir.undef` vector, with `llvm.extractelement`/`llvm.insertelement`;
+/// lanes whose component is 0xFFFFFFFF (-1) are left undefined.
+class VectorShufflePattern
+    : public SPIRVToLLVMConversion<spirv::VectorShuffleOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::VectorShuffleOp>::SPIRVToLLVMConversion;
+  LogicalResult
+  matchAndRewrite(spirv::VectorShuffleOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    auto components = op.components();
+    auto vector1 = op.vector1();
+    auto vector2 = op.vector2();
+    int vector1Size = vector1.getType().cast<VectorType>().getNumElements();
+    int vector2Size = vector2.getType().cast<VectorType>().getNumElements();
+
+    // `llvm.shufflevector` takes two operands of the same vector type, so it
+    // can only be used directly when both inputs have the same size.
+    if (vector1Size == vector2Size) {
+      rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, vector1, vector2,
+                                                         components);
+      return success();
+    }
+
+    auto dstType = typeConverter.convertType(op.getType());
+    if (!dstType)
+      return failure();
+    // The converted result type is already an LLVM-compatible vector, so its
+    // element type needs no further conversion.
+    Type scalarType = dstType.cast<VectorType>().getElementType();
+    auto componentsArray = components.getValue();
+    // Index constants are cached for this rewrite only; they are all created
+    // in order at the current insertion point, so they dominate later uses.
+    std::map<unsigned, Value> constantOps;
+    Value targetOp = rewriter.create<LLVM::UndefOp>(loc, dstType);
+    for (unsigned i = 0, e = componentsArray.size(); i < e; ++i) {
+      assert(componentsArray[i].isa<IntegerAttr>() &&
+             "non-constant component is not supported yet");
+      int indexVal = componentsArray[i].cast<IntegerAttr>().getInt();
+      // A component of -1 (0xFFFFFFFF) leaves the result lane undefined.
+      if (indexVal == -1)
+        continue;
+
+      // Components index into the concatenation of vector1 and vector2.
+      int offsetVal = 0;
+      Value baseVector = vector1;
+      if (indexVal >= vector1Size) {
+        offsetVal = vector1Size;
+        baseVector = vector2;
+      }
+
+      Value dstIndex = getOrCreateI32ConstantOp(loc, rewriter, i, constantOps);
+      Value index = getOrCreateI32ConstantOp(
+          loc, rewriter, indexVal - offsetVal, constantOps);
+      auto extractOp = rewriter.create<LLVM::ExtractElementOp>(
+          loc, scalarType, baseVector, index);
+      targetOp = rewriter.create<LLVM::InsertElementOp>(
+          loc, dstType, targetOp, extractOp, dstIndex);
+    }
+    rewriter.replaceOp(op, targetOp);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1485,6 +1573,7 @@
       CompositeExtractPattern, CompositeInsertPattern,
       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
       DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
+      VectorShufflePattern,
 
       // Shift ops
       ShiftPattern<spirv::ShiftRightArithmeticOp, LLVM::AShrOp>,
Index: mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
===================================================================
--- mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
+++ mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
@@ -58,6 +58,30 @@
   spv.Return
 }
 
+//===----------------------------------------------------------------------===//
+// spv.VectorShuffle
+//===----------------------------------------------------------------------===//
+
+spv.func @vector_shuffle_same_size(%vector1: vector<2xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+  // CHECK: %[[RES:.*]] = llvm.shufflevector {{.*}} [0 : i32, 2 : i32, -1 : i32] : vector<2xf32>, vector<2xf32>
+  // CHECK-NEXT: llvm.return %[[RES]] : vector<3xf32>
+  %0 = spv.VectorShuffle [0: i32, 2: i32, 0xffffffff: i32] %vector1: vector<2xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  spv.ReturnValue %0: vector<3xf32>
+}
+
+spv.func @vector_shuffle_different_size(%vector1: vector<3xf32>, %vector2: vector<2xf32>) -> vector<3xf32> "None" {
+  // CHECK: %[[UNDEF:.*]] = llvm.mlir.undef : vector<3xf32>
+  // CHECK-NEXT: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK-NEXT: %[[EXT0:.*]] = llvm.extractelement %arg0[%[[C0]] : i32] : vector<3xf32>
+  // CHECK-NEXT: %[[INSERT0:.*]] = llvm.insertelement %[[EXT0]], %[[UNDEF]][%[[C0]] : i32] : vector<3xf32>
+  // CHECK-NEXT: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK-NEXT: %[[EXT1:.*]] = llvm.extractelement {{.*}}[%[[C1]] : i32] : vector<2xf32>
+  // CHECK-NEXT: %[[RES:.*]] = llvm.insertelement %[[EXT1]], %[[INSERT0]][%[[C1]] : i32] : vector<3xf32>
+  // CHECK-NEXT: llvm.return %[[RES]] : vector<3xf32>
+  %0 = spv.VectorShuffle [0: i32, 4: i32, 0xffffffff: i32] %vector1: vector<3xf32>, %vector2: vector<2xf32> -> vector<3xf32>
+  spv.ReturnValue %0: vector<3xf32>
+}
+
 //===----------------------------------------------------------------------===//
 // spv.EntryPoint and spv.ExecutionMode
 //===----------------------------------------------------------------------===//