diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -344,6 +344,10 @@
 LogicalResult promoteSubviewsPrecondition(Operation *op,
                                           LinalgPromotionOptions options);
 
+/// Return success if `linalgOp` satisfies the vectorization preconditions
+/// (for example, all of its shapes must be static).
+LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp);
+
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -566,8 +566,16 @@
       return failure();
     }
   }
-  if (isElementwise(op))
+  if (isElementwise(op)) {
+    // Reject payload ops that cannot be vectorized (`tensor.extract`).
+    for (Operation &payloadOp : op.getBlock()->getOperations()) {
+      if (isa<tensor::ExtractOp>(payloadOp)) {
+        LDBG("precondition failed: `tensor.extract` not vectorizable");
+        return failure();
+      }
+    }
     return success();
+  }
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
   // reverse-engineer...
@@ -587,7 +595,7 @@
   return success();
 }
 
-static LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
+LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp) {
   // All types must be static shape to go to vector.
   if (linalgOp.hasDynamicShape()) {
     LDBG("precondition failed: dynamic shape");