diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -1638,12 +1638,17 @@
 
   // Add a matchAndRewrite style pattern represented as a C function pointer.
+  // `benefit` and `generatedNames` are forwarded unchanged to the
+  // OpRewritePattern base of the pattern created here.
   template <typename OpType>
-  RewritePatternSet &add(LogicalResult (*implFn)(OpType,
-                                                 PatternRewriter &rewriter)) {
+  RewritePatternSet &
+  add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
+      PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
     struct FnPattern final : public OpRewritePattern<OpType> {
       FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
-                MLIRContext *context)
-          : OpRewritePattern<OpType>(context), implFn(implFn) {}
+                MLIRContext *context, PatternBenefit benefit,
+                ArrayRef<StringRef> generatedNames)
+          : OpRewritePattern<OpType>(context, benefit, generatedNames),
+            implFn(implFn) {}
 
       LogicalResult
       matchAndRewrite(OpType op, PatternRewriter &rewriter) const override {
@@ -1653,7 +1658,8 @@
     private:
       LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
     };
-    add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
+    add(std::make_unique<FnPattern>(implFn, getContext(), benefit,
+                                    generatedNames));
     return *this;
   }
 
diff --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp
--- a/mlir/unittests/IR/PatternMatchTest.cpp
+++ b/mlir/unittests/IR/PatternMatchTest.cpp
@@ -28,3 +28,27 @@
   ASSERT_EQ(ops.front().getStringRef(), test::OpB::getOperationName());
 }
 } // end anonymous namespace
+
+namespace {
+// A pattern function that never matches: the test below only inspects the
+// metadata of the pattern created from it, not its rewrite behavior.
+LogicalResult anOpRewritePatternFunc(test::OpA op, PatternRewriter &rewriter) {
+  return failure();
+}
+
+// The benefit and generated op names passed to the function-pointer overload
+// of RewritePatternSet::add must be reflected on the created pattern.
+TEST(AnOpRewritePatternTest, PatternFuncAttributes) {
+  MLIRContext context;
+  RewritePatternSet patterns(&context);
+
+  patterns.add(anOpRewritePatternFunc, /*benefit=*/3,
+               /*generatedNames=*/{test::OpB::getOperationName()});
+  ASSERT_EQ(patterns.getNativePatterns().size(), 1u);
+  auto &pattern = patterns.getNativePatterns().front();
+  ASSERT_EQ(pattern->getBenefit(), 3);
+  ASSERT_EQ(pattern->getGeneratedOps().size(), 1u);
+  ASSERT_EQ(pattern->getGeneratedOps().front().getStringRef(),
+            test::OpB::getOperationName());
+}
+} // end anonymous namespace