Support inserts involving size-1 vectors in the VectorToSPIRV conversion.

In SPIR-V the converted form of vector<1xT> is the scalar T (see the
builtin.unrealized_conversion_cast checks in the tests below), so:

  * vector.insert of a scalar into a size-1 vector now simply forwards the
    converted scalar source instead of being rejected.
  * vector.insert_strided_slice whose source is a size-1 vector (a scalar
    after conversion) previously failed to convert; it now lowers to a
    single spv.CompositeInsert at the slice offset.

diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -157,6 +157,13 @@
   LogicalResult
   matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // Special case for inserting scalar values into size-1 vectors.
+    if (insertOp.getSourceType().isIntOrFloat() &&
+        insertOp.getDestVectorType().getNumElements() == 1) {
+      rewriter.replaceOp(insertOp, adaptor.source());
+      return success();
+    }
+
     if (insertOp.getSourceType().isa<VectorType>() ||
         !spirv::CompositeType::isValid(insertOp.getDestVectorType()))
       return failure();
@@ -209,20 +216,23 @@
     Value srcVector = adaptor.getOperands().front();
     Value dstVector = adaptor.getOperands().back();
 
-    // Insert scalar values not supported yet.
-    if (srcVector.getType().isa<spirv::ScalarType>() ||
-        dstVector.getType().isa<spirv::ScalarType>())
-      return failure();
-
     uint64_t stride = getFirstIntValue(insertOp.strides());
     if (stride != 1)
       return failure();
+    uint64_t offset = getFirstIntValue(insertOp.offsets());
+
+    if (srcVector.getType().isa<spirv::ScalarType>()) {
+      assert(!dstVector.getType().isa<spirv::ScalarType>());
+      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+          insertOp, dstVector.getType(), srcVector, dstVector,
+          rewriter.getI32ArrayAttr(offset));
+      return success();
+    }
 
     uint64_t totalSize =
         dstVector.getType().cast<VectorType>().getNumElements();
     uint64_t insertSize =
         srcVector.getType().cast<VectorType>().getNumElements();
-    uint64_t offset = getFirstIntValue(insertOp.offsets());
 
     SmallVector<int32_t, 2> indices(totalSize);
     std::iota(indices.begin(), indices.end(), 0);
diff --git a/mlir/test/Conversion/VectorToSPIRV/simple.mlir b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/simple.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/simple.mlir
@@ -61,6 +61,17 @@
 
 // -----
 
+// CHECK-LABEL: @insert_size1_vector
+// CHECK-SAME: %[[V:.*]]: vector<1xf32>, %[[S:.*]]: f32
+// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[S]]
+// CHECK: return %[[R]]
+func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf32> {
+  %1 = vector.insert %arg1, %arg0[0] : f32 into vector<1xf32>
+  return %1 : vector<1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @extract_element
 // CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
 // CHECK: spv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32
@@ -139,6 +150,17 @@
 
 // -----
 
+// CHECK-LABEL: @insert_strided_slice_size1_vector
+// CHECK-SAME: %[[SUB:.*]]: vector<1xf32>, %[[FULL:.*]]: vector<3xf32>
+// CHECK: %[[S:.+]] = builtin.unrealized_conversion_cast %[[SUB]]
+// CHECK: spv.CompositeInsert %[[S]], %[[FULL]][2 : i32] : f32 into vector<3xf32>
+func @insert_strided_slice_size1_vector(%arg0 : vector<1xf32>, %arg1: vector<3xf32>) -> vector<3xf32> {
+  %1 = vector.insert_strided_slice %arg0, %arg1 {offsets = [2], strides = [1]} : vector<1xf32> into vector<3xf32>
+  return %1 : vector<3xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fma
 // CHECK-SAME: %[[A:.*]]: vector<4xf32>, %[[B:.*]]: vector<4xf32>, %[[C:.*]]: vector<4xf32>
 // CHECK: spv.GLSL.Fma %[[A]], %[[B]], %[[C]] : vector<4xf32>