Expose the TileLoopNest produced by LinalgTileAndFuseTensorOpsPattern.

Rename the out-of-line LinalgTileAndFuseTensorOpsPattern::matchAndRewrite to
returningMatchAndRewrite. It now returns FailureOr<TileLoopNest> (the
`tileLoopNest` built by tiling and fusion) instead of success(), so callers can
inspect the generated loop nest and the tiled/fused ops. matchAndRewrite
becomes an inline override that forwards to it. The FailureOr result is
converted to the LogicalResult required by the RewritePattern interface.

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -865,8 +865,16 @@
       LinalgTilingAndFusionOptions options,
       LinalgTransformationFilter f = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
+
+  /// `matchAndRewrite` implementation that returns the significant transformed
+  /// pieces of IR.
+  FailureOr<TileLoopNest>
+  returningMatchAndRewrite(Operation *op, PatternRewriter &rewriter) const;
+
   LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override;
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(op, rewriter);
+  }
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -549,7 +549,8 @@
     : RewritePattern(opName, benefit, context), filter(std::move(f)),
       options(std::move(options)) {}
 
-LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
+FailureOr<TileLoopNest>
+mlir::linalg::LinalgTileAndFuseTensorOpsPattern::returningMatchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   LinalgOp rootOp = dyn_cast<LinalgOp>(op);
   if (!rootOp)
@@ -604,7 +605,7 @@
   // Apply the filter if specified.
   for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
     filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
-  return success();
+  return tileLoopNest;
 }
 
 /// Linalg generic interchange pattern.