diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/LinalgTransforms.h @@ -63,12 +63,18 @@ ArrayRef sizes, StringRef linalgMarker, ArrayRef permutation); +LogicalResult tileLinalgOpToParallelLoopsAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef sizes, + StringRef linalgMarker, ArrayRef permutation); /// Tiles `op` by `sizes`, fuses the producers of `operandIndicesToFuse` and /// sets the attribute `kLinalgTransformMarker` to `linalgMarker`. LogicalResult tileAndFuseLinalgOpAndSetMarker( PatternRewriter &rewriter, Operation *op, ArrayRef sizes, ArrayRef operandIndicesToFuse, StringRef linalgMarker); +LogicalResult tileAndFuseLinalgOpToParallelLoopsAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef sizes, + ArrayRef operandIndicesToFuse, StringRef linalgMarker); using LinalgLoops = SmallVector; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -40,11 +40,16 @@ const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker = "__internal_linalg_transform__"; -LogicalResult mlir::linalg::tileLinalgOpAndSetMarker( - PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - StringRef linalgMarker, ArrayRef permutation) { +using TileFn = Optional(OpBuilder &, LinalgOp, ArrayRef, + ArrayRef, OperationFolder *); + +static LogicalResult +tileLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter, + Operation *op, ArrayRef sizes, + StringRef linalgMarker, + ArrayRef permutation) { assert(permutation.empty() || permutation.size() == sizes.size()); - auto tileRes = tileLinalgOperation(rewriter, op, sizes, permutation); + auto tileRes = tileFn(rewriter, op, sizes, permutation, /*folder=*/nullptr); if (!tileRes) return failure(); tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker, @@ -52,10 +57,26 @@ return success(); } -LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker( +LogicalResult mlir::linalg::tileLinalgOpAndSetMarker( PatternRewriter &rewriter, Operation *op, ArrayRef sizes, - ArrayRef operandIndicesToFuse, StringRef linalgMarker) { - auto tileRes = tileLinalgOperation(rewriter, op, sizes); + StringRef linalgMarker, ArrayRef permutation) { + return tileLinalgOpAndSetMarkerImpl(tileLinalgOp, rewriter, op, sizes, + linalgMarker, permutation); +} +LogicalResult mlir::linalg::tileLinalgOpToParallelLoopsAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef sizes, + StringRef linalgMarker, ArrayRef permutation) { + return tileLinalgOpAndSetMarkerImpl(tileLinalgOpToParallelLoops, rewriter, op, + sizes, linalgMarker, permutation); +} + +static LogicalResult +tileAndFuseLinalgOpAndSetMarkerImpl(TileFn tileFn, PatternRewriter &rewriter, + Operation *op, ArrayRef sizes, + ArrayRef operandIndicesToFuse, + StringRef linalgMarker) { + auto tileRes = + tileFn(rewriter, op, sizes, /*permutation=*/{}, /*folder=*/nullptr); if (!tileRes) return failure(); tileRes->op.setAttr(LinalgTransforms::kLinalgTransformMarker, @@ -89,6 +110,20 @@ return success(); } +LogicalResult mlir::linalg::tileAndFuseLinalgOpAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef sizes, + ArrayRef operandIndicesToFuse, StringRef linalgMarker) { + return tileAndFuseLinalgOpAndSetMarkerImpl( + tileLinalgOp, rewriter, op, sizes, operandIndicesToFuse, linalgMarker); +} +LogicalResult mlir::linalg::tileAndFuseLinalgOpToParallelLoopsAndSetMarker( + PatternRewriter &rewriter, Operation *op, ArrayRef sizes, + ArrayRef operandIndicesToFuse, StringRef linalgMarker) { + return tileAndFuseLinalgOpAndSetMarkerImpl( + tileLinalgOpToParallelLoops, rewriter, op, sizes, operandIndicesToFuse, + linalgMarker); +} + bool mlir::linalg::detail::isProducedByOpOfTypeImpl( Operation *consumerOp, Value consumedView, function_ref isaOpType) {