diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertSPIRVToLLVM.cpp
@@ -555,6 +555,66 @@
   }
 };
 
+/// Converts `spv.CompositeExtract` to `llvm.extractvalue` if the container type
+/// is an aggregate type (struct or array). Otherwise, converts to
+/// `llvm.extractelement` that operates on vectors.
+class CompositeExtractPattern
+    : public SPIRVToLLVMConversion<spirv::CompositeExtractOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::CompositeExtractOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::CompositeExtractOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = this->typeConverter.convertType(op.getType());
+    if (!dstType)
+      return failure();
+
+    // Use the remapped operand so the new op consumes the converted value.
+    Value composite = operands[0];
+    if (op.composite().getType().isa<VectorType>()) {
+      IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
+      Value index = createI32ConstantOf(op.getLoc(), rewriter, value.getInt());
+      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
+          op, dstType, composite, index);
+      return success();
+    }
+    rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
+        op, dstType, composite, op.indices());
+    return success();
+  }
+};
+
+/// Converts `spv.CompositeInsert` to `llvm.insertvalue` if the container type
+/// is an aggregate type (struct or array). Otherwise, converts to
+/// `llvm.insertelement` that operates on vectors.
+class CompositeInsertPattern
+    : public SPIRVToLLVMConversion<spirv::CompositeInsertOp> {
+public:
+  using SPIRVToLLVMConversion<spirv::CompositeInsertOp>::SPIRVToLLVMConversion;
+
+  LogicalResult
+  matchAndRewrite(spirv::CompositeInsertOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto dstType = this->typeConverter.convertType(op.getType());
+    if (!dstType)
+      return failure();
+    // Remapped operands arrive in operand order: object first, then composite.
+    Value object = operands[0];
+    Value composite = operands[1];
+    if (op.composite().getType().isa<VectorType>()) {
+      IntegerAttr value = op.indices()[0].cast<IntegerAttr>();
+      Value index = createI32ConstantOf(op.getLoc(), rewriter, value.getInt());
+      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
+          op, dstType, composite, object, index);
+      return success();
+    }
+    rewriter.replaceOpWithNewOp<LLVM::InsertValueOp>(
+        op, dstType, composite, object, op.indices());
+    return success();
+  }
+};
+
 /// Converts SPIR-V operations that have straightforward LLVM equivalent
 /// into LLVM dialect operations.
 template
@@ -1360,6 +1420,7 @@
       VariablePattern,
 
       // Miscellaneous ops
+      CompositeExtractPattern, CompositeInsertPattern,
       DirectConversionPattern,
       DirectConversionPattern,
 
diff --git a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
--- a/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
+++ b/mlir/test/Conversion/SPIRVToLLVM/misc-ops-to-llvm.mlir
@@ -1,5 +1,43 @@
 // RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
 
+//===----------------------------------------------------------------------===//
+// spv.CompositeExtract
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @composite_extract_array
+spv.func @composite_extract_array(%arg: !spv.array<4x!spv.array<4xf32>>) "None" {
+  // CHECK: llvm.extractvalue %{{.*}}[1 : i32, 3 : i32] : !llvm.array<4 x array<4 x float>>
+  %0 = spv.CompositeExtract %arg[1 : i32, 3 : i32] : !spv.array<4x!spv.array<4xf32>>
+  spv.Return
+}
+
+// CHECK-LABEL: @composite_extract_vector
+spv.func @composite_extract_vector(%arg: vector<3xf32>) "None" {
+  // CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
+  // CHECK: llvm.extractelement %{{.*}}[%[[ZERO]] : !llvm.i32] : !llvm.vec<3 x float>
+  %0 = spv.CompositeExtract %arg[0 : i32] : vector<3xf32>
+  spv.Return
+}
+
+//===----------------------------------------------------------------------===//
+// spv.CompositeInsert
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: @composite_insert_struct
+spv.func @composite_insert_struct(%arg0: i32, %arg1: !spv.struct<f32, !spv.array<4xi32>>) "None" {
+  // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[1 : i32, 3 : i32] : !llvm.struct<(float, array<4 x i32>)>
+  %0 = spv.CompositeInsert %arg0, %arg1[1 : i32, 3 : i32] : i32 into !spv.struct<f32, !spv.array<4xi32>>
+  spv.Return
+}
+
+// CHECK-LABEL: @composite_insert_vector
+spv.func @composite_insert_vector(%arg0: vector<3xf32>, %arg1: f32) "None" {
+  // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
+  // CHECK: llvm.insertelement %{{.*}}, %{{.*}}[%[[ONE]] : !llvm.i32] : !llvm.vec<3 x float>
+  %0 = spv.CompositeInsert %arg1, %arg0[1 : i32] : f32 into vector<3xf32>
+  spv.Return
+}
+
 //===----------------------------------------------------------------------===//
 // spv.Select
 //===----------------------------------------------------------------------===//