diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -237,6 +237,40 @@
   return success();
 }
 
+/// Replaces `spv.CompositeExtract`/`spv.CompositeInsert` on a vector with
+/// `llvm.extractelement`/`llvm.insertelement` using an i32 constant index.
+static LogicalResult replaceWithExtractOrInsertElement(
+    Operation *op, Type dstType, ArrayAttr indices, PatternRewriter &rewriter) {
+  IntegerAttr value = indices[0].cast<IntegerAttr>();
+  Value index = createI32ConstantOf(op->getLoc(), rewriter, value.getInt());
+  if (auto extractOp = dyn_cast<spirv::CompositeExtractOp>(op)) {
+    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
+        extractOp, dstType, extractOp.composite(), index);
+    return success();
+  }
+  auto insertOp = cast<spirv::CompositeInsertOp>(op);
+  rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+      insertOp, dstType, insertOp.composite(), insertOp.object(), index);
+  return success();
+}
+
+/// Replaces `spv.CompositeExtract`/`spv.CompositeInsert` on a struct or array
+/// with `llvm.extractvalue`/`llvm.insertvalue`, forwarding the index list.
+static LogicalResult
+replaceWithExtractOrInsertValue(Operation *op, Type dstType,
+                                PatternRewriter &rewriter) {
+  if (auto extractOp = dyn_cast<spirv::CompositeExtractOp>(op)) {
+    rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
+        extractOp, dstType, extractOp.composite(), extractOp.indices());
+    return success();
+  }
+  auto insertOp = cast<spirv::CompositeInsertOp>(op);
+  rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
+      insertOp, dstType, insertOp.composite(), insertOp.object(),
+      insertOp.indices());
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Type conversion
 //===----------------------------------------------------------------------===//
@@ -555,6 +589,32 @@
   }
 };
 
+/// Converts `spv.CompositeExtract`/`spv.CompositeInsert` to `llvm.extractvalue`
+/// or `llvm.insertvalue` if the container type is an aggregate type (struct or
+/// array). Otherwise, converts to `llvm.extractelement` or `llvm.insertelement`
+/// that operate on vectors.
+template <typename SPIRVOp>
+class CompositePattern : public SPIRVToLLVMConversion<SPIRVOp> {
+public:
+  using SPIRVToLLVMConversion<SPIRVOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(SPIRVOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = this->typeConverter.convertType(op.getType());
+    if (!dstType)
+      return failure();
+
+    Type containerType = op.composite().getType();
+    // Vectors lower to `llvm.extractelement`/`llvm.insertelement`.
+    if (containerType.isa<VectorType>())
+      return replaceWithExtractOrInsertElement(op, dstType, op.indices(),
+                                               rewriter);
+    // Structs and arrays lower to `llvm.extractvalue`/`llvm.insertvalue`.
+    return replaceWithExtractOrInsertValue(op, dstType, rewriter);
+  }
+};
+
 /// Converts SPIR-V operations that have straightforward LLVM equivalent
 /// into LLVM dialect operations.
 template <typename SPIRVOp, typename LLVMOp>
@@ -1360,6 +1420,8 @@
       VariablePattern,
 
       // Miscellaneous ops
+      CompositePattern<spirv::CompositeExtractOp>,
+      CompositePattern<spirv::CompositeInsertOp>,
       DirectConversionPattern<spirv::SelectOp, LLVM::SelectOp>,
       DirectConversionPattern<spirv::UndefOp, LLVM::UndefOp>,
 
diff --git a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
@@ -1,5 +1,43 @@
 // RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// spv.CompositeExtract
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @composite_extract_array
+spv.func @composite_extract_array(%arg: !spv.array<4x!spv.array<4xf32>>) "None" {
+  // CHECK: llvm.extractvalue %{{.*}}[1 : i32, 3 : i32] : !llvm.array<4 x array<4 x float>>
+  %0 = spv.CompositeExtract %arg[1 : i32, 3 : i32] : !spv.array<4x!spv.array<4xf32>>
+  spv.Return
+}
+
+// CHECK-LABEL: @composite_extract_vector
+spv.func @composite_extract_vector(%arg: vector<3xf32>) "None" {
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+  // CHECK: llvm.extractelement %{{.*}}[%[[ZERO]] : !llvm.i32] : !llvm.vec<3 x float>
+  %0 = spv.CompositeExtract %arg[0 : i32] : vector<3xf32>
+  spv.Return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.CompositeInsert
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @composite_insert_struct
+spv.func @composite_insert_struct(%arg0: i32, %arg1: !spv.struct<f32, !spv.array<4xi32>>) "None" {
+  // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1 : i32, 3 : i32] : !llvm.struct<(float, array<4 x i32>)>
+  %0 = spv.CompositeInsert %arg0, %arg1[1 : i32, 3 : i32] : i32 into !spv.struct<f32, !spv.array<4xi32>>
+  spv.Return
+}
+
+// CHECK-LABEL: @composite_insert_vector
+spv.func @composite_insert_vector(%arg0: vector<3xf32>, %arg1: f32) "None" {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+  // CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%[[ONE]] : !llvm.i32] : !llvm.vec<3 x float>
+  %0 = spv.CompositeInsert %arg1, %arg0[1 : i32] : f32 into vector<3xf32>
+  spv.Return
+}
+
 //===----------------------------------------------------------------------===//
 // spv.Select
 //===----------------------------------------------------------------------===//