Remove LinalgTransformationFilter from LinalgGeneralizationPattern.

LinalgGeneralizationPattern no longer carries a LinalgTransformationFilter:
  * The two filter-taking constructors and the private `filter` member are
    removed; the pattern inherits the OpInterfaceRewritePattern<LinalgOp>
    constructors instead.
  * returningMatchAndRewrite is now defined inline in Transforms.h and simply
    returns generalizeNamedOp(rewriter, op). It no longer calls
    filter.checkAndNotify before rewriting, nor
    filter.replaceLinalgTransformationFilter on the resulting GenericOp.
  * populateLinalgNamedOpsGeneralizationPatterns drops its filter parameter.
  * The out-of-line definitions in Transforms.cpp are deleted.

Callers that relied on the filter (op-name restriction via addOpNameFilter,
marker matching and marker update) no longer get that behavior from this
pattern.

NOTE(review): this patch was recovered from a whitespace-collapsed copy in
which every angle-bracketed template argument had been stripped. Line breaks,
indentation, whitespace-only context lines and template arguments (e.g.
OpInterfaceRewritePattern<LinalgOp>, FailureOr<GenericOp>,
patterns.add<LinalgGeneralizationPattern>,
std::unique_ptr<OperationPass<func::FuncOp>>) were restored so that every hunk
agrees with its header's line counts; confirm with `git apply --check` against
the target tree before applying.

diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -772,35 +772,25 @@
 /// Linalg generalization pattern.
 ///
 /// Apply the `generalization` transformation as a pattern.
-/// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `generalization` for more details.
+//
+// TODO: Automatic default pattern class that just unwraps a function returning
+// FailureOr<GenericOp>.
 struct LinalgGeneralizationPattern
     : public OpInterfaceRewritePattern<LinalgOp> {
-  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
-  LinalgGeneralizationPattern(
-      MLIRContext *context,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1);
-
-  /// Construct a pattern specifically applied to `opName`.
-  LinalgGeneralizationPattern(
-      StringRef opName, MLIRContext *context,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1);
+  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
 
   /// `matchAndRewrite` implementation that returns the significant transformed
   /// pieces of IR.
   FailureOr<GenericOp>
-  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const;
+  returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const {
+    return generalizeNamedOp(rewriter, op);
+  }
 
   LogicalResult matchAndRewrite(LinalgOp op,
                                 PatternRewriter &rewriter) const override {
     return returningMatchAndRewrite(op, rewriter);
   }
-
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter filter;
 };
 
 ///
@@ -917,9 +907,7 @@
 
 /// Populates `patterns` with patterns to convert spec-generated named ops to
 /// linalg.generic ops.
-void populateLinalgNamedOpsGeneralizationPatterns(
-    RewritePatternSet &patterns,
-    const LinalgTransformationFilter &filter = LinalgTransformationFilter());
+void populateLinalgNamedOpsGeneralizationPatterns(RewritePatternSet &patterns);
 
 /// Linalg decompose convolutions patterns
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -86,8 +86,8 @@
 }
 
 void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
-    RewritePatternSet &patterns, const LinalgTransformationFilter &marker) {
-  patterns.add<LinalgGeneralizationPattern>(patterns.getContext(), marker);
+    RewritePatternSet &patterns) {
+  patterns.add<LinalgGeneralizationPattern>(patterns.getContext());
 }
 
 std::unique_ptr<OperationPass<func::FuncOp>>
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -427,30 +427,6 @@
   return paddedOp;
 }
 
-/// Linalg generalization pattern.
-mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
-    MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit)
-    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
-      filter(std::move(f)) {}
-
-mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern(
-    StringRef opName, MLIRContext *context, LinalgTransformationFilter f,
-    PatternBenefit benefit)
-    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
-      filter(f.addOpNameFilter(opName)) {}
-
-FailureOr<GenericOp>
-mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite(
-    LinalgOp linalgOp, PatternRewriter &rewriter) const {
-  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
-    return failure();
-  FailureOr<GenericOp> genericOp = generalizeNamedOp(rewriter, linalgOp);
-  if (failed(genericOp))
-    return failure();
-  filter.replaceLinalgTransformationFilter(rewriter, *genericOp);
-  return genericOp;
-}
-
 LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
     memref::CopyOp copyOp, PatternRewriter &rewriter) const {
   return vectorizeCopy(rewriter, copyOp);