diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -552,6 +552,20 @@
 }
 
 static LogicalResult vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
+  // All types in the body should be a supported element type for VectorType.
+  // Give up on vectorization otherwise, e.g. when the body uses complex types.
+  for (Operation &innerOp : op->getRegion(0).front()) {
+    if (llvm::any_of(innerOp.getOperandTypes(), [](Type type) {
+          return !VectorType::isValidElementType(type);
+        })) {
+      return failure();
+    }
+    if (llvm::any_of(innerOp.getResultTypes(), [](Type type) {
+          return !VectorType::isValidElementType(type);
+        })) {
+      return failure();
+    }
+  }
   if (isElementwise(op))
     return success();
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -207,6 +207,24 @@
 
 // -----
 
+// complex<f32> is not a valid vector element type, so the op is expected not to be vectorized.
+// CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types
+func.func @test_do_not_vectorize_unsupported_element_types(%A : memref<8x16xcomplex<f32>>, %arg0 : complex<f32>) {
+  // CHECK-NOT: vector.broadcast
+  // CHECK-NOT: vector.transfer_write
+  linalg.generic {
+    indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
+    iterator_types = ["parallel", "parallel"]}
+   ins(%arg0 : complex<f32>)
+  outs(%A: memref<8x16xcomplex<f32>>) {
+    ^bb(%0: complex<f32>, %1: complex<f32>) :
+      linalg.yield %0 : complex<f32>
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @test_vectorize_fill
 func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
   // CHECK: %[[V:.*]] = vector.broadcast {{.*}} : f32 to vector<8x16xf32>