diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -688,7 +688,7 @@
 /// Apply the `padding` transformation as a pattern.
 /// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `padding` for more details.
-struct LinalgPaddingPattern : public RewritePattern {
+struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
   // Entry point to match any LinalgOp OpInterface.
   LinalgPaddingPattern(
       MLIRContext *context,
@@ -701,7 +701,7 @@
       LinalgPaddingOptions options = LinalgPaddingOptions(),
       LinalgTransformationFilter filter = LinalgTransformationFilter(),
       PatternBenefit benefit = 1);
-  LogicalResult matchAndRewrite(Operation *op,
+  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                 PatternRewriter &rewriter) const override;
 
 private:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -489,23 +489,26 @@
 mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
     MLIRContext *context, LinalgPaddingOptions options,
     LinalgTransformationFilter filter, PatternBenefit benefit)
-    : RewritePattern(MatchAnyOpTypeTag(), benefit, context),
+    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
       filter(std::move(filter)), options(std::move(options)) {}
 
 mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern(
     StringRef opName, MLIRContext *context, LinalgPaddingOptions options,
     LinalgTransformationFilter filter, PatternBenefit benefit)
-    : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)),
-      options(std::move(options)) {}
+    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
+      filter(std::move(filter)), options(std::move(options)) {
+  // The filter is stored on the pattern and outlives this constructor, so
+  // capture an owning copy of `opName` rather than the non-owning StringRef.
+  this->filter.addFilter([opName = opName.str()](Operation *op) {
+    return success(op->getName().getStringRef() == opName);
+  });
+}
 
 LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  if (!linalgOp)
-    return failure();
+    LinalgOp linalgOp, PatternRewriter &rewriter) const {
   if (!linalgOp.hasTensorSemantics())
     return failure();
-  if (failed(filter.checkAndNotify(rewriter, op)))
+  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
     return failure();
 
   // Pad the operation.
@@ -538,7 +541,7 @@
   }
 
   // Replace the original operation to pad.
-  rewriter.replaceOp(op, newResults.getValue());
+  rewriter.replaceOp(linalgOp, newResults.getValue());
   filter.replaceLinalgTransformationFilter(rewriter, paddedOp);
   return success();
 }