diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -321,13 +321,21 @@
   LogicalResult
   matchAndRewrite(vector::SplatOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    VectorType dstVecType = op.getType();
-    if (!spirv::CompositeType::isValid(dstVecType))
+    // Bail out when the result type has no SPIR-V conversion.
+    Type dstType = getTypeConverter()->convertType(op.getType());
+    if (!dstType)
       return failure();
-    SmallVector<Value, 4> source(dstVecType.getNumElements(),
-                                 adaptor.getInput());
-    rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstVecType,
-                                                             source);
+    if (dstType.isa<spirv::ScalarType>()) {
+      // The result converts to a scalar (e.g. vector<1xf32>, see
+      // @splat_size1_vector), so the splat is just the converted input.
+      rewriter.replaceOp(op, adaptor.getInput());
+    } else {
+      auto dstVecType = dstType.cast<VectorType>();
+      SmallVector<Value, 4> source(dstVecType.getNumElements(),
+                                   adaptor.getInput());
+      rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(op, dstType,
+                                                               source);
+    }
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -182,6 +182,17 @@
 
 // -----
 
+// CHECK-LABEL: func @splat_size1_vector
+//  CHECK-SAME: (%[[A:.+]]: f32)
+//       CHECK:   %[[VAL:.+]] = builtin.unrealized_conversion_cast %[[A]]
+//       CHECK:   return %[[VAL]]
+func.func @splat_size1_vector(%f : f32) -> vector<1xf32> {
+  %splat = vector.splat %f : vector<1xf32>
+  return %splat : vector<1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @shuffle
 //  CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>
 //       CHECK:   %[[V0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]