diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -50,6 +50,33 @@
 
 namespace {
 
+struct VectorShapeCast final : public OpConversionPattern<vector::ShapeCastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ShapeCastOp shapeCastOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Type dstType = getTypeConverter()->convertType(shapeCastOp.getType());
+    if (!dstType)
+      return rewriter.notifyMatchFailure(shapeCastOp,
+                                         "unsupported result vector type");
+
+    // The cast is a no-op, and the source can be forwarded directly, when the
+    // converted source already has the converted result type, or when the
+    // result holds a single element (every dimension is then 1, so the cast
+    // only adds or drops unit dimensions).
+    if (dstType == adaptor.getSource().getType() ||
+        shapeCastOp.getResultVectorType().getNumElements() == 1) {
+      rewriter.replaceOp(shapeCastOp, adaptor.getSource());
+      return success();
+    }
+
+    // TODO: Lower shape casts of multi-element vectors.
+    return rewriter.notifyMatchFailure(shapeCastOp,
+                                       "unimplemented vector shape cast");
+  }
+};
+
 struct VectorBitcastConvert final
     : public OpConversionPattern<vector::BitCastOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -551,15 +578,15 @@
 
 void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
                                          RewritePatternSet &patterns) {
-  patterns.add<
-      VectorBitcastConvert, VectorBroadcastConvert,
-      VectorExtractElementOpConvert, VectorExtractOpConvert,
-      VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
-      VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
-      VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
-      VectorReductionPattern<CL_MAX_MIN_OPS>, VectorInsertStridedSliceOpConvert,
-      VectorShuffleOpConvert, VectorSplatPattern>(typeConverter,
-                                                  patterns.getContext());
+  patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
+               VectorExtractElementOpConvert, VectorExtractOpConvert,
+               VectorExtractStridedSliceOpConvert,
+               VectorFmaOpConvert<spirv::GLFmaOp>,
+               VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
+               VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
+               VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
+               VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+               VectorSplatPattern>(typeConverter, patterns.getContext());
 }
 
 void mlir::populateVectorReductionToSPIRVDotProductPatterns(
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -568,3 +568,15 @@
   %reduce = vector.reduction <minui>, %v, %s : vector<3xi32> into i32
   return %reduce : i32
 }
+
+// -----
+
+// CHECK-LABEL: @shape_cast_size1_vector
+//  CHECK-SAME: (%[[ARG0:.*]]: vector<f32>)
+//       CHECK:   %[[R0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<f32> to f32
+//       CHECK:   %[[R1:.+]] = builtin.unrealized_conversion_cast %[[R0]] : f32 to vector<1xf32>
+//       CHECK:   return %[[R1]]
+func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
+  %1 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
+  return %1 : vector<1xf32>
+}