diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1109,10 +1109,14 @@
   if (!srcType.hasStaticShape() || !dstType.hasStaticShape())
     return failure();
 
-  auto readType =
-      VectorType::get(srcType.getShape(), getElementTypeOrSelf(srcType));
-  auto writeType =
-      VectorType::get(dstType.getShape(), getElementTypeOrSelf(dstType));
+  auto srcElementType = getElementTypeOrSelf(srcType);
+  auto dstElementType = getElementTypeOrSelf(dstType);
+  // Bail out (instead of asserting inside VectorType::get) when the element
+  // type cannot form a vector, e.g. complex<f32>.
+  if (!VectorType::isValidElementType(srcElementType) ||
+      !VectorType::isValidElementType(dstElementType))
+    return failure();
+
+  auto readType = VectorType::get(srcType.getShape(), srcElementType);
+  auto writeType = VectorType::get(dstType.getShape(), dstElementType);
 
   Location loc = copyOp->getLoc();
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
@@ -1173,6 +1177,8 @@
                                      tensor::PadOp padOp, Value dest) {
   auto sourceType = padOp.getSourceType();
   auto resultType = padOp.getResultType();
+  if (!VectorType::isValidElementType(sourceType.getElementType()))
+    return failure();
 
   // Copy cannot be vectorized if pad value is non-constant and source shape
   // is dynamic. In case of a dynamic source shape, padding must be appended
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -359,6 +359,23 @@
   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
   %2 = transform.structured.vectorize %1
 }
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_copy_complex
+// CHECK-NOT: vector<
+func.func @test_vectorize_copy_complex(%A : memref<8x16xcomplex<f32>>, %B : memref<8x16xcomplex<f32>>) {
+  memref.copy %A, %B : memref<8x16xcomplex<f32>> to memref<8x16xcomplex<f32>>
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["memref.copy"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_trailing_index
@@ -806,6 +823,25 @@
   %2 = transform.structured.vectorize %1 { vectorize_padding }
 }
 
+// -----
+
+// CHECK-LABEL: func @pad_static_complex(
+// CHECK-NOT: vector<
+func.func @pad_static_complex(%arg0: tensor<2x5x2xcomplex<f32>>, %pad_value: complex<f32>) -> tensor<2x6x4xcomplex<f32>> {
+  %0 = tensor.pad %arg0 low[0, 0, 2] high[0, 1, 0] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index):
+      tensor.yield %pad_value : complex<f32>
+  } : tensor<2x5x2xcomplex<f32>> to tensor<2x6x4xcomplex<f32>>
+  return %0 : tensor<2x6x4xcomplex<f32>>
+}
+
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1 { vectorize_padding }
+}
 
 // -----