diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -401,8 +401,15 @@ LogicalResult promoteSubviewsPrecondition(Operation *op,
                                           LinalgPromotionOptions options);
 
-/// Rewrite a linalg.generic into a suitable vector.contraction op.
+/// Return success if `op` can be vectorized. `op` must be a LinalgOp; ops with
+/// dynamic shapes are always rejected.
 LogicalResult vectorizeLinalgOpPrecondition(Operation *op);
 
+/// Return success if `op` can be vectorized assuming it is static. This allows
+/// checking if an op will be vectorizable once all the dimensions are folded to
+/// static values.
+/// It is the same as `vectorizeLinalgOpPrecondition` for static shapes.
+LogicalResult vectorizeStaticLinalgOpPrecondition(LinalgOp op);
+
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -599,34 +599,39 @@
   return success();
 }
 
-LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
-  auto linalgOp = cast<linalg::LinalgOp>(op);
-  // All types must be static shape to go to vector.
-  if (linalgOp.hasDynamicShape()) {
-    LDBG("precondition failed: dynamic shape");
-    return failure();
-  }
+LogicalResult
+mlir::linalg::vectorizeStaticLinalgOpPrecondition(linalg::LinalgOp op) {
   if (isElementwise(op))
     return success();
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
   // reverse-engineer...
-  if (isa<ConvolutionOpInterface>(op))
+  if (isa<ConvolutionOpInterface>(op.getOperation()))
     return success();
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
   // logic will need to evolve.
-  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
+  if (!allIndexingsAreProjectedPermutation(op)) {
     LDBG("precondition failed: not projected permutations");
     return failure();
   }
-  if (failed(reductionPreconditions(linalgOp))) {
+  if (failed(reductionPreconditions(op))) {
     LDBG("precondition failed: reduction preconditions");
     return failure();
   }
   return success();
 }
 
+LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
+  auto linalgOp = cast<linalg::LinalgOp>(op);
+  // All types must be static shape to go to vector.
+  if (linalgOp.hasDynamicShape()) {
+    LDBG("precondition failed: dynamic shape");
+    return failure();
+  }
+  return vectorizeStaticLinalgOpPrecondition(linalgOp);
+}
+
 LogicalResult
 mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
                                 SmallVectorImpl<Value> &newResults) {