diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -442,11 +442,19 @@ filters.push_back(f); return *this; } + template LinalgTransformationFilter &addOpFilter() { return addFilter( [](Operation *op) { return success(isa(op)); }); } + /// Match only ops named `opName`; the name is copied into the filter. + LinalgTransformationFilter &addOpNameFilter(StringRef opName) { + return addFilter([opName = opName.str()](Operation *op) { + return success(op->getName().getStringRef() == opName); + }); + } + LinalgTransformationFilter &setMatchByDefault() { matchByDefault = true; return *this; @@ -607,7 +615,7 @@ /// See `tiling` for more details. // TODO: TiledOpInterface struct LinalgTilingPattern : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all LinalgOp that verify `f`. + /// Construct a generic pattern applied to all LinalgOp that verify `filter`. LinalgTilingPattern( MLIRContext *context, LinalgTilingOptions options, LinalgTransformationFilter f = LinalgTransformationFilter(), @@ -643,20 +651,29 @@ /// `filter` controls LinalgTransformMarker matching and update when specified. /// See `padding` for more details. struct LinalgPaddingPattern : public OpInterfaceRewritePattern { - // Entry point to match any LinalgOp OpInterface. + /// Construct a generic pattern applied to all LinalgOp that verify `filter`. LinalgPaddingPattern( MLIRContext *context, LinalgPaddingOptions options = LinalgPaddingOptions(), - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1); - // Entry point to match a specific LinalgOp. + + /// Construct a pattern specifically applied to `opName`.
LinalgPaddingPattern( StringRef opName, MLIRContext *context, LinalgPaddingOptions options = LinalgPaddingOptions(), - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1); - LogicalResult matchAndRewrite(LinalgOp, - PatternRewriter &rewriter) const override; + + /// `matchAndRewrite` implementation that returns the significant transformed + /// pieces of IR. + FailureOr returningMatchAndRewrite(LinalgOp op, + PatternRewriter &rewriter) const; + + LogicalResult matchAndRewrite(LinalgOp op, + PatternRewriter &rewriter) const override { + return returningMatchAndRewrite(op, rewriter); + } private: /// LinalgTransformMarker handles special attribute manipulations. @@ -679,7 +696,7 @@ StringRef opName, MLIRContext *context, const LinalgDependenceGraph &dependenceGraph, LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(), LinalgTransformationFilter originalOpMarker = LinalgTransformationFilter(), @@ -711,14 +728,14 @@ LinalgTileAndFusePattern( MLIRContext *context, const LinalgDependenceGraph &dependenceGraph, LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(), LinalgTransformationFilter originalOpMarker = LinalgTransformationFilter(), PatternBenefit benefit = 1) : LinalgBaseTileAndFusePattern( OpTy::getOperationName(), context, dependenceGraph, tilingOptions, - fusionOptions, filter, fusedOpMarker, originalOpMarker, benefit) {} + fusionOptions, f, fusedOpMarker, originalOpMarker, benefit) {} }; /// @@ -731,13 +748,13 @@ // Entry
point to match any LinalgOp. LinalgTileAndFuseTensorOpsPattern( MLIRContext *context, LinalgTilingAndFusionOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1); // Entry point to match a specific LinalgOp. LinalgTileAndFuseTensorOpsPattern( StringRef opName, MLIRContext *context, LinalgTilingAndFusionOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1); LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override; @@ -757,12 +774,22 @@ /// See `interchange` for more details. struct GenericOpInterchangePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; + + /// GenericOp-specific constructor with an optional `filter`. GenericOpInterchangePattern( MLIRContext *context, ArrayRef interchangeVector, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1); - LogicalResult matchAndRewrite(GenericOp genericOp, - PatternRewriter &rewriter) const override; + + /// `matchAndRewrite` implementation that returns the significant transformed + /// pieces of IR. + FailureOr + returningMatchAndRewrite(GenericOp op, PatternRewriter &rewriter) const; + + LogicalResult matchAndRewrite(GenericOp op, + PatternRewriter &rewriter) const override { + return returningMatchAndRewrite(op, rewriter); + } private: /// LinalgTransformMarker handles special attribute manipulations. @@ -777,19 +804,29 @@ /// Apply the `generalization` transformation as a pattern. /// `filter` controls LinalgTransformMarker matching and update when specified. /// See `generalization` for more details. -struct LinalgGeneralizationPattern : public RewritePattern { - // Entry point to match any LinalgOp OpInterface.
+struct LinalgGeneralizationPattern + : public OpInterfaceRewritePattern { + /// Construct a generic pattern applied to all LinalgOp that verify `filter`. LinalgGeneralizationPattern( MLIRContext *context, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1); - // Entry point to match a specific Linalg op. + + /// Construct a pattern specifically applied to `opName`. LinalgGeneralizationPattern( StringRef opName, MLIRContext *context, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1); - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; + + /// `matchAndRewrite` implementation that returns the significant transformed + /// pieces of IR. + FailureOr + returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const; + + LogicalResult matchAndRewrite(LinalgOp op, + PatternRewriter &rewriter) const override { + return returningMatchAndRewrite(op, rewriter); + } private: /// LinalgTransformMarker handles special attribute manipulations. @@ -806,13 +843,13 @@ /// Entry point to match any LinalgOp OpInterface. /// MatchAnyOpTag-based constructor with a mandatory `filter`. LinalgBasePromotionPattern( - MLIRContext *context, LinalgTransformationFilter filter, + MLIRContext *context, LinalgTransformationFilter f, LinalgPromotionOptions options = LinalgPromotionOptions(), PatternBenefit benefit = 1); /// Entry point to match a specific Linalg op.
LinalgBasePromotionPattern( StringRef opName, MLIRContext *context, LinalgPromotionOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1); LogicalResult matchAndRewrite(Operation *op, @@ -832,16 +869,16 @@ template LinalgPromotionPattern( MLIRContext *context, LinalgPromotionOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1) : LinalgBasePromotionPattern(OpTy::getOperationName(), context, options, - filter, benefit) {} + f, benefit) {} /// This constructor is available to anyone. LinalgPromotionPattern( StringRef opName, MLIRContext *context, LinalgPromotionOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1) - : LinalgBasePromotionPattern(opName, context, options, filter, benefit) {} + : LinalgBasePromotionPattern(opName, context, options, f, benefit) {} }; /// @@ -852,39 +889,28 @@ /// `filter` controls LinalgTransformMarker matching and update when specified. /// See `vectorizeLinalgOp` for more details. -struct LinalgBaseVectorizationPattern : public RewritePattern { - /// MatchAnyOpTag-based constructor with a mandatory `filter`. - LinalgBaseVectorizationPattern(MLIRContext *context, - LinalgTransformationFilter filter, - PatternBenefit benefit = 1); - /// Name-based constructor with an optional `filter`. - LinalgBaseVectorizationPattern( +struct LinalgVectorizationPattern : public OpInterfaceRewritePattern { + /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
+ LinalgVectorizationPattern( + MLIRContext *context, + LinalgTransformationFilter f = LinalgTransformationFilter(), + LinalgVectorizationOptions options = LinalgVectorizationOptions(), + PatternBenefit benefit = 1); + + /// Construct a pattern specifically applied to `opName`. + LinalgVectorizationPattern( StringRef opName, MLIRContext *context, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgVectorizationOptions options = LinalgVectorizationOptions(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1); - LogicalResult matchAndRewrite(Operation *op, + + LogicalResult matchAndRewrite(LinalgOp linalgOp, PatternRewriter &rewriter) const override; private: /// LinalgTransformMarker handles special attribute manipulations. LinalgTransformationFilter filter; -}; - -struct LinalgVectorizationPattern : public LinalgBaseVectorizationPattern { - /// These constructors are available to anyone. - /// MatchAnyOpTag-based constructor with a mandatory `filter`. - LinalgVectorizationPattern( - MLIRContext *context, LinalgTransformationFilter filter, - LinalgVectorizationOptions options = LinalgVectorizationOptions(), - PatternBenefit benefit = 1) - : LinalgBaseVectorizationPattern(context, filter, benefit) {} - /// Name-based constructor with an optional `filter`. - LinalgVectorizationPattern( - StringRef opName, MLIRContext *context, - LinalgVectorizationOptions options = LinalgVectorizationOptions(), - LinalgTransformationFilter filter = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : LinalgBaseVectorizationPattern(opName, context, filter, benefit) {} + LinalgVectorizationOptions options; }; //===----------------------------------------------------------------------===// @@ -1008,48 +1034,6 @@ //===----------------------------------------------------------------------===// // Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===// -/// Trait to check if T provides a `getOperationName` method. -template -using has_get_operation_name = decltype(T::getOperationName()); -template -using detect_has_get_operation_name = - llvm::is_detected; - -/// SFINAE helper for single C++ op with a `getOperationName` method. -template < - typename OpType, - typename = std::enable_if_t::value>, - typename = void> -void insertVectorizationPatternImpl(RewritePatternSet &patternList, - linalg::LinalgVectorizationOptions options, - linalg::LinalgTransformationFilter f) { - patternList.add( - OpType::getOperationName(), patternList.getContext(), options, f); -} - -/// SFINAE helper for single C++ class without a `getOperationName` method (e.g. -/// an OpInterface). -template ::value>> -void insertVectorizationPatternImpl(RewritePatternSet &patternList, - linalg::LinalgVectorizationOptions options, - linalg::LinalgTransformationFilter f) { - patternList.add( - patternList.getContext(), f.addOpFilter(), options); -} - -/// Variadic helper function to insert vectorization patterns for C++ ops. -template -void insertVectorizationPatterns(RewritePatternSet &patternList, - linalg::LinalgVectorizationOptions options, - linalg::LinalgTransformationFilter f = - linalg::LinalgTransformationFilter()) { - // FIXME: In c++17 this can be simplified by using 'fold expressions'. - (void)std::initializer_list{ - 0, - (insertVectorizationPatternImpl(patternList, options, f), 0)...}; -} - /// /// Linalg lowering patterns.
/// @@ -1067,10 +1051,10 @@ struct LinalgLoweringPattern : public RewritePattern { LinalgLoweringPattern( MLIRContext *context, LinalgLoweringType loweringType, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1) : RewritePattern(OpTy::getOperationName(), benefit, context), - filter(filter), loweringType(loweringType) {} + filter(std::move(f)), loweringType(loweringType) {} // TODO: Move implementation to .cpp once named ops are auto-generated. LogicalResult matchAndRewrite(Operation *op, @@ -1352,6 +1336,29 @@ //===----------------------------------------------------------------------===// // Helper classes for type list expansion. //===----------------------------------------------------------------------===// +template +class VectorizationPatterns; + +template <> +class VectorizationPatterns<> { +public: + static void insert(RewritePatternSet &patterns, + const LinalgVectorizationOptions &options, + const LinalgTransformationFilter &f) {} +}; + +template +class VectorizationPatterns { +public: + static void insert(RewritePatternSet &patterns, + const LinalgVectorizationOptions &options, + const LinalgTransformationFilter &f) { + patterns.add(OpTy::getOperationName(), + patterns.getContext(), options, f); + VectorizationPatterns::insert(patterns, options, f); + } +}; + template class TilingPatterns; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -357,11 +357,11 @@ StringRef opName, MLIRContext *context, const LinalgDependenceGraph &dependenceGraph, LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions, - LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker, + LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
LinalgTransformationFilter originalOpMarker, PatternBenefit benefit) : RewritePattern(opName, benefit, context, {}), dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)), - fusionOptions(std::move(fusionOptions)), filter(std::move(filter)), + fusionOptions(std::move(fusionOptions)), filter(std::move(f)), fusedOpMarker(std::move(fusedOpMarker)), originalOpMarker(std::move(originalOpMarker)) {} @@ -462,11 +462,7 @@ StringRef opName, MLIRContext *context, LinalgTilingOptions options, LinalgTransformationFilter f, PatternBenefit benefit) : OpInterfaceRewritePattern(context, benefit), - filter(std::move(f)), options(std::move(options)) { - this->filter.addFilter([opName](Operation *op) { - return success(op->getName().getStringRef() == opName); - }); -} + filter(f.addOpNameFilter(opName)), options(std::move(options)) {} FailureOr mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite( @@ -496,21 +492,18 @@ /// Linalg padding pattern. mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern( MLIRContext *context, LinalgPaddingOptions options, - LinalgTransformationFilter filter, PatternBenefit benefit) + LinalgTransformationFilter f, PatternBenefit benefit) : OpInterfaceRewritePattern(context, benefit), - filter(std::move(filter)), options(std::move(options)) {} + filter(std::move(f)), options(std::move(options)) {} mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern( StringRef opName, MLIRContext *context, LinalgPaddingOptions options, - LinalgTransformationFilter filter, PatternBenefit benefit) + LinalgTransformationFilter f, PatternBenefit benefit) : OpInterfaceRewritePattern(context, benefit), - filter(std::move(filter)), options(std::move(options)) { - this->filter.addFilter([opName](Operation *op) { - return success(op->getName().getStringRef() == opName); - }); -} + filter(f.addOpNameFilter(opName)), options(std::move(options)) {} -LogicalResult mlir::linalg::LinalgPaddingPattern::matchAndRewrite( +FailureOr
+mlir::linalg::LinalgPaddingPattern::returningMatchAndRewrite( LinalgOp linalgOp, PatternRewriter &rewriter) const { if (!linalgOp.hasTensorSemantics()) return failure(); @@ -549,24 +542,24 @@ // Replace the original operation to pad. rewriter.replaceOp(linalgOp, newResults.getValue()); filter.replaceLinalgTransformationFilter(rewriter, paddedOp); - return success(); + return paddedOp; } /// Linalg tile and fuse tensor ops pattern. mlir::linalg::LinalgTileAndFuseTensorOpsPattern:: LinalgTileAndFuseTensorOpsPattern(MLIRContext *context, LinalgTilingAndFusionOptions options, - LinalgTransformationFilter filter, + LinalgTransformationFilter f, PatternBenefit benefit) : RewritePattern(MatchAnyOpTypeTag(), benefit, context), - filter(std::move(filter)), options(std::move(options)) {} + filter(std::move(f)), options(std::move(options)) {} mlir::linalg::LinalgTileAndFuseTensorOpsPattern:: LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context, LinalgTilingAndFusionOptions options, - LinalgTransformationFilter filter, + LinalgTransformationFilter f, PatternBenefit benefit) - : RewritePattern(opName, benefit, context), filter(std::move(filter)), + : RewritePattern(opName, benefit, context), filter(std::move(f)), options(std::move(options)) {} LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite( @@ -624,11 +617,12 @@ /// Linalg generic interchange pattern.
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern( MLIRContext *context, ArrayRef interchangeVector, - LinalgTransformationFilter filter, PatternBenefit benefit) - : OpRewritePattern(context, benefit), filter(std::move(filter)), + LinalgTransformationFilter f, PatternBenefit benefit) + : OpRewritePattern(context, benefit), filter(std::move(f)), interchangeVector(interchangeVector.begin(), interchangeVector.end()) {} -LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite( +FailureOr +mlir::linalg::GenericOpInterchangePattern::returningMatchAndRewrite( GenericOp genericOp, PatternRewriter &rewriter) const { if (failed(filter.checkAndNotify(rewriter, genericOp))) return failure(); @@ -645,41 +639,38 @@ /// Linalg generalization pattern. mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern( - MLIRContext *context, LinalgTransformationFilter filter, - PatternBenefit benefit) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context), - filter(std::move(filter)) {} + MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + filter(std::move(f)) {} mlir::linalg::LinalgGeneralizationPattern::LinalgGeneralizationPattern( - StringRef opName, MLIRContext *context, LinalgTransformationFilter filter, + StringRef opName, MLIRContext *context, LinalgTransformationFilter f, PatternBenefit benefit) - : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)) {} + : OpInterfaceRewritePattern(context, benefit), + filter(f.addOpNameFilter(opName)) {} -LogicalResult mlir::linalg::LinalgGeneralizationPattern::matchAndRewrite( - Operation *op, PatternRewriter &rewriter) const { - // TODO: Interface pattern.
- LinalgOp linalgOp = dyn_cast(op); - if (!linalgOp) - return failure(); - if (failed(filter.checkAndNotify(rewriter, op))) +FailureOr +mlir::linalg::LinalgGeneralizationPattern::returningMatchAndRewrite( + LinalgOp linalgOp, PatternRewriter &rewriter) const { + if (failed(filter.checkAndNotify(rewriter, linalgOp))) return failure(); FailureOr genericOp = generalizeNamedOp(rewriter, linalgOp); if (failed(genericOp)) return failure(); filter.replaceLinalgTransformationFilter(rewriter, *genericOp); - return success(); + return genericOp; } mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( - MLIRContext *context, LinalgTransformationFilter filter, + MLIRContext *context, LinalgTransformationFilter f, LinalgPromotionOptions options, PatternBenefit benefit) : RewritePattern(MatchAnyOpTypeTag(), benefit, context), - filter(std::move(filter)), options(std::move(options)) {} + filter(std::move(f)), options(std::move(options)) {} mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern( StringRef opName, MLIRContext *context, LinalgPromotionOptions options, - LinalgTransformationFilter filter, PatternBenefit benefit) - : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)), + LinalgTransformationFilter f, PatternBenefit benefit) + : RewritePattern(opName, benefit, context, {}), filter(std::move(f)), options(std::move(options)) {} LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite( @@ -704,24 +695,21 @@ return success(); } -mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( - MLIRContext *context, LinalgTransformationFilter filter, - PatternBenefit benefit) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context), - filter(std::move(filter)) {} +mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern( + MLIRContext *context, LinalgTransformationFilter f, + LinalgVectorizationOptions options, PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit),
+ filter(std::move(f)), options(std::move(options)) {} -mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern( - StringRef opName, MLIRContext *context, LinalgTransformationFilter filter, - PatternBenefit benefit) - : RewritePattern(opName, benefit, context, {}), filter(std::move(filter)) {} +mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern( + StringRef opName, MLIRContext *context, LinalgVectorizationOptions options, + LinalgTransformationFilter f, PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + filter(f.addOpNameFilter(opName)), options(std::move(options)) {} -LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite( - Operation *op, PatternRewriter &rewriter) const { - // TODO: Interface-based rewrite. - LinalgOp linalgOp = dyn_cast(op); - if (!linalgOp) - return failure(); - if (failed(filter.checkAndNotify(rewriter, op))) +LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite( + LinalgOp linalgOp, PatternRewriter &rewriter) const { + if (failed(filter.checkAndNotify(rewriter, linalgOp))) return failure(); return vectorize(rewriter, linalgOp); } @@ -947,10 +935,10 @@ : public OpRewritePattern { DownscaleSizeOneWindowed2DConvolution( MLIRContext *context, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), - filter(std::move(filter)) {} + filter(std::move(f)) {} LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const override { @@ -1033,10 +1021,10 @@ : public OpRewritePattern { DownscaleDepthwiseConv2DNhwcHwcOp( MLIRContext *context, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), - filter(std::move(filter)) {} + filter(std::move(f)) {} LogicalResult
matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const override { diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -300,8 +300,7 @@ MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(), LinalgTransformationFilter(StringAttr::get(ctx, "VEC")))); patternsVector.back().add( - ctx, LinalgTransformationFilter().addFilter( - [](Operation *op) { return success(isa(op)); })); + ctx, LinalgTransformationFilter().addOpFilter()); } //===----------------------------------------------------------------------===//