diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -355,10 +355,12 @@
 struct OpRewritePattern
     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
   /// Patterns must specify the root operation name they match against, and can
-  /// also specify the benefit of the pattern matching.
-  OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
+  /// also specify the benefit of the pattern matching and the names of the
+  /// operations that the pattern generates.
+  OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1,
+                   ArrayRef<StringRef> generatedNames = {})
       : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
-            SourceOp::getOperationName(), benefit, context) {}
+            SourceOp::getOperationName(), benefit, context, generatedNames) {}
 };
 
 /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt
--- a/mlir/unittests/IR/CMakeLists.txt
+++ b/mlir/unittests/IR/CMakeLists.txt
@@ -3,6 +3,7 @@
   DialectTest.cpp
   InterfaceAttachmentTest.cpp
   OperationSupportTest.cpp
+  PatternMatchTest.cpp
   ShapedTypeTest.cpp
   SubElementInterfaceTest.cpp
 
diff --git a/mlir/unittests/IR/PatternMatchTest.cpp b/mlir/unittests/IR/PatternMatchTest.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/unittests/IR/PatternMatchTest.cpp
@@ -0,0 +1,36 @@
+//===- PatternMatchTest.cpp - PatternMatch unit tests ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/PatternMatch.h"
+#include "gtest/gtest.h"
+
+#include "../../test/lib/Dialect/Test/TestDialect.h"
+
+using namespace mlir;
+
+namespace {
+/// Minimal pattern rooted on test::OpA that declares test::OpB as its only
+/// generated op. It overrides no match/rewrite hooks; it exists solely to
+/// exercise OpRewritePattern's `generatedNames` constructor parameter.
+struct AnOpRewritePattern : OpRewritePattern<test::OpA> {
+  explicit AnOpRewritePattern(MLIRContext *context)
+      : OpRewritePattern(context, /*benefit=*/1,
+                         /*generatedNames=*/{test::OpB::getOperationName()}) {}
+};
+
+/// The names passed as `generatedNames` to the constructor must be reported
+/// back by getGeneratedOps().
+TEST(OpRewritePatternTest, GetGeneratedNames) {
+  MLIRContext context;
+  AnOpRewritePattern pattern(&context);
+  ArrayRef<OperationName> ops = pattern.getGeneratedOps();
+
+  ASSERT_EQ(ops.size(), 1u);
+  ASSERT_EQ(ops.front().getStringRef(), test::OpB::getOperationName());
+}
+} // end anonymous namespace