diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -313,6 +313,13 @@ PatternBenefit benefit = 1) : LinalgBaseTilingPattern(OpTy::getOperationName(), context, options, marker, benefit) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (failed(LinalgBaseTilingPattern::matchAndRewrite(op, rewriter))) + return failure(); + rewriter.eraseOp(op); + return success(); + } }; /// @@ -415,7 +422,8 @@ AffineLoops = 2, ParallelLoops = 3 }; -template struct LinalgLoweringPattern : public RewritePattern { +template +struct LinalgLoweringPattern : public RewritePattern { LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType, LinalgMarker marker = LinalgMarker(), PatternBenefit benefit = 1) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -126,8 +126,6 @@ // New marker if specified. marker.replaceLinalgMarker(rewriter, res->op.getOperation()); - - rewriter.eraseOp(op); return success(); }