diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -157,6 +157,15 @@
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // Special case for inserting a scalar into a size-1 vector. SPIR-V has no
+    // 1-element vector type, so vector<1xT> is converted to the scalar T and
+    // the (already converted) inserted scalar itself becomes the result.
+    if (insertOp.getSourceType().isIntOrFloat() &&
+        insertOp.getDestVectorType().getNumElements() == 1) {
+      rewriter.replaceOp(insertOp, adaptor.source());
+      return success();
+    }
+
     if (insertOp.getSourceType().isa<VectorType>() ||
         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
       return failure();
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -61,6 +61,17 @@
 
 // -----
 
+// CHECK-LABEL: @insert_size1_vector
+// CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
+// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
+// CHECK: return %[[R]]
+func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf32> {
+  %1 = vector.insert %arg1, %arg0[0] : f32 into vector<1xf32>
+  return %1 : vector<1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @extract_element
 // CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
 // CHECK: spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32