diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -977,6 +977,10 @@
 mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
                                             ArrayRef<int64_t> inputVectorSizes,
                                             bool vectorizeNDExtract) {
+  // tensor with dimension of 0 cannot be vectorized.
+  if (llvm::any_of(linalgOp.getStaticShape(),
+                   [](int64_t dim) { return dim == 0; }))
+    return failure();
   // Check API contract for input vector sizes.
   if (!inputVectorSizes.empty()) {
     assert(inputVectorSizes.size() == linalgOp.getNumLoops() &&
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -2004,3 +2004,29 @@
 // CHECK-LABEL: @wrong_reduction_detection
 // CHECK: vector.broadcast
 // CHECK: vector.transfer_write
+
+// -----
+
+// Don't vectorize tensor<0xf32>
+// CHECK-LABEL: @tensor_size0
+// CHECK: linalg.generic
+func.func @tensor_size0(%arg0: tensor<0xf32>,
+                        %arg1: tensor<f32>) -> tensor<f32> {
+  %0 = linalg.generic
+    {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
+     iterator_types = ["reduction"]}
+    ins(%arg0 : tensor<0xf32>) outs(%arg1 : tensor<f32>) {
+  ^bb0(%in: f32, %out: f32):
+    %12 = arith.addf %out, %in : f32
+    linalg.yield %12 : f32
+  } -> tensor<f32>
+  return %0 : tensor<f32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.structured.vectorize %1
+}
+