Let the function-pointer overload of RewritePatternSet::add accept a pattern
benefit and a list of generated operation names, and forward both to the
OpRewritePattern base of the internal FnPattern wrapper. Both parameters are
defaulted (benefit = 1, generatedNames = {}), so existing callers that pass
only the function pointer keep compiling unchanged.

A unit test registers a function-pointer pattern with benefit 3 and one
generated op name, then checks both on the resulting native pattern.

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -1638,12 +1638,15 @@
 
   // Add a matchAndRewrite style pattern represented as a C function pointer.
   template <typename OpType>
-  RewritePatternSet &add(LogicalResult (*implFn)(OpType,
-                                                 PatternRewriter &rewriter)) {
+  RewritePatternSet &
+  add(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
+      PatternBenefit benefit = 1, ArrayRef<StringRef> generatedNames = {}) {
     struct FnPattern final : public OpRewritePattern<OpType> {
       FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
-                MLIRContext *context)
-          : OpRewritePattern<OpType>(context), implFn(implFn) {}
+                MLIRContext *context, PatternBenefit benefit,
+                ArrayRef<StringRef> generatedNames)
+          : OpRewritePattern<OpType>(context, benefit, generatedNames),
+            implFn(implFn) {}
 
       LogicalResult matchAndRewrite(OpType op,
                                     PatternRewriter &rewriter) const override {
@@ -1653,7 +1656,8 @@
     private:
       LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
     };
-    add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
+    add(std::make_unique<FnPattern>(std::move(implFn), getContext(), benefit,
+                                    generatedNames));
     return *this;
   }
 
diff --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp
--- a/mlir/unittests/IR/PatternMatchTest.cpp
+++ b/mlir/unittests/IR/PatternMatchTest.cpp
@@ -14,11 +14,13 @@
 using namespace mlir;
 
 namespace {
+
 struct AnOpRewritePattern : OpRewritePattern<test::OpA> {
   AnOpRewritePattern(MLIRContext *context)
       : OpRewritePattern(context, /*benefit=*/1,
                          /*generatedNames=*/{test::OpB::getOperationName()}) {}
 };
+
 TEST(OpRewritePatternTest, GetGeneratedNames) {
   MLIRContext context;
   AnOpRewritePattern pattern(&context);
@@ -27,4 +29,27 @@
   ASSERT_EQ(ops.size(), 1u);
   ASSERT_EQ(ops.front().getStringRef(), test::OpB::getOperationName());
 }
+
+// Function-pointer pattern that never matches; the test below only inspects
+// the attributes recorded when it is registered.
+LogicalResult anOpRewritePatternFunc(test::OpA op, PatternRewriter &rewriter) {
+  return failure();
+}
+
+// Checks that the benefit and generated op names passed to the
+// function-pointer overload of RewritePatternSet::add reach the pattern.
+TEST(AnOpRewritePatternTest, PatternFuncAttributes) {
+  MLIRContext context;
+  RewritePatternSet patterns(&context);
+
+  patterns.add(anOpRewritePatternFunc, /*benefit=*/3,
+               /*generatedNames=*/{test::OpB::getOperationName()});
+  ASSERT_EQ(patterns.getNativePatterns().size(), 1u);
+  auto &pattern = patterns.getNativePatterns().front();
+  ASSERT_EQ(pattern->getBenefit(), 3);
+  ASSERT_EQ(pattern->getGeneratedOps().size(), 1u);
+  ASSERT_EQ(pattern->getGeneratedOps().front().getStringRef(),
+            test::OpB::getOperationName());
+}
+
 } // end anonymous namespace