diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -74,9 +74,9 @@
   /// condition predicates.
   PatternBenefit getBenefit() const { return benefit; }
 
-  /// Return the root node that this pattern matches. Patterns that can
-  /// match multiple root types are instantiated once per root.
-  OperationName getRootKind() const { return rootKind; }
+  /// Return the root node that this pattern matches. Patterns that can match
+  /// multiple root types return None.
+  Optional<OperationName> getRootKind() const { return rootKind; }
 
   //===--------------------------------------------------------------------===//
   // Implementation hooks for patterns to implement.
@@ -89,12 +89,30 @@
   virtual ~Pattern() {}
 
 protected:
-  /// Patterns must specify the root operation name they match against, and can
-  /// also specify the benefit of the pattern matching.
+  /// This class acts as a special tag that makes the desire to match "any"
+  /// operation type explicit. This helps to avoid unnecessary usages of this
+  /// feature, and ensures that the user is making a conscious decision.
+  struct MatchAnyOpTypeTag {};
+
+  /// This constructor is used for patterns that match against a specific
+  /// operation type. The `benefit` is the expected benefit of matching this
+  /// pattern.
   Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context);
 
+  /// This constructor is used when a pattern may match against multiple
+  /// different types of operations. The `benefit` is the expected benefit of
+  /// matching this pattern. `MatchAnyOpTypeTag` is just a tag to ensure that
+  /// the "match any" behavior is what the user actually desired,
+  /// `MatchAnyOpTypeTag()` should always be supplied here.
+  Pattern(PatternBenefit benefit, MatchAnyOpTypeTag);
+
 private:
-  const OperationName rootKind;
+  /// The root operation of the pattern. If the pattern matches a specific
+  /// operation, this contains the name of that operation. Contains None
+  /// otherwise.
+  Optional<OperationName> rootKind;
+
+  /// The expected benefit of matching this pattern.
   const PatternBenefit benefit;
 
   virtual void anchor();
@@ -151,10 +169,24 @@
                  MLIRContext *context)
       : Pattern(rootName, benefit, context) {}
+  /// This constructor is used for patterns that may match any operation type.
+  /// The `benefit` is the expected benefit of matching this pattern, and
+  /// `MatchAnyOpTypeTag()` should always be supplied to make the "match any"
+  /// behavior an explicit choice.
+  RewritePattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
+      : Pattern(benefit, tag) {}
   /// Patterns must specify the root operation name they match against, and can
   /// also specify the benefit of the pattern matching. They can also specify
   /// the names of operations that may be generated during a successful rewrite.
   RewritePattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
                  PatternBenefit benefit, MLIRContext *context);
+  /// This constructor is used for patterns that may match any operation type.
+  /// The `benefit` is the expected benefit of matching this pattern, and
+  /// `generatedNames` lists the names of operations that may be generated
+  /// during a successful rewrite. `MatchAnyOpTypeTag` is just a tag to ensure
+  /// that the "match any" behavior is what the user actually desired;
+  /// `MatchAnyOpTypeTag()` should always be supplied here.
+  RewritePattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
+                 MLIRContext *context, MatchAnyOpTypeTag tag);
 
   /// A list of the potential operations that may be generated when rewriting
   /// an op with this pattern.
@@ -431,6 +463,14 @@
     return *this;
   }
 
+  /// Add an instance of each of the pattern types 'Ts'. Return a reference to
+  /// `this` for chaining insertions.
+  template <typename... Ts> OwningRewritePatternList &insert() {
+    (void)std::initializer_list<int>{
+        0, (patterns.emplace_back(std::make_unique<Ts>()), 0)...};
+    return *this;
+  }
+
   /// Add the given pattern to the pattern list.
   void insert(std::unique_ptr<RewritePattern> pattern) {
     patterns.emplace_back(std::move(pattern));
@@ -485,11 +525,23 @@
   void walkAllPatterns(function_ref<void(const RewritePattern &)> walk);
 
 private:
+  /// Attempt to match and rewrite the given op with the given pattern, allowing
+  /// a predicate to decide if a pattern can be applied or not, and hooks for if
+  /// the pattern match was a success or failure.
+  LogicalResult matchAndRewrite(
+      Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
+      function_ref<bool(const RewritePattern &)> canApply,
+      function_ref<void(const RewritePattern &)> onFailure,
+      function_ref<LogicalResult(const RewritePattern &)> onSuccess);
+
   /// The list that owns the patterns used within this applicator.
   const OwningRewritePatternList &owningPatternList;
 
   /// The set of patterns to match for each operation, stable sorted by benefit.
   DenseMap<OperationName, SmallVector<RewritePattern *, 2>> patterns;
+  /// The set of patterns that may match against any operation type, stable
+  /// sorted by benefit.
+  SmallVector<RewritePattern *, 1> anyOpPatterns;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -274,12 +274,6 @@
 /// below.
 class ConversionPattern : public RewritePattern {
 public:
-  /// Construct an ConversionPattern. `rootName` must correspond to the
-  /// canonical name of the first operation matched by the pattern.
-  ConversionPattern(StringRef rootName, PatternBenefit benefit,
-                    MLIRContext *ctx)
-      : RewritePattern(rootName, benefit, ctx) {}
-
   /// Hook for derived classes to implement rewriting. `op` is the (first)
   /// operation matched by the pattern, `operands` is a list of rewritten values
   /// that are passed to this operation, `rewriter` can be used to emit the new
@@ -304,6 +298,9 @@
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const final;
 
+protected:
+  using RewritePattern::RewritePattern;
+
 private:
   using RewritePattern::rewrite;
 };
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -30,6 +30,8 @@
 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
                  MLIRContext *context)
     : rootKind(OperationName(rootName, context)), benefit(benefit) {}
+Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag)
+    : benefit(benefit) {}
 
 // Out-of-line vtable anchor.
 void Pattern::anchor() {}
@@ -47,9 +49,6 @@
   llvm_unreachable("need to implement either match or matchAndRewrite!");
 }
 
-/// Patterns must specify the root operation name they match against, and can
-/// also specify the benefit of the pattern matching. They can also specify the
-/// names of operations that may be generated during a successful rewrite.
 RewritePattern::RewritePattern(StringRef rootName,
                                ArrayRef<StringRef> generatedNames,
                                PatternBenefit benefit, MLIRContext *context)
@@ -60,6 +59,16 @@
                    return OperationName(name, context);
                  });
 }
+RewritePattern::RewritePattern(ArrayRef<StringRef> generatedNames,
+                               PatternBenefit benefit, MLIRContext *context,
+                               MatchAnyOpTypeTag tag)
+    : Pattern(benefit, tag) {
+  generatedOps.reserve(generatedNames.size());
+  std::transform(generatedNames.begin(), generatedNames.end(),
+                 std::back_inserter(generatedOps), [context](StringRef name) {
+                   return OperationName(name, context);
+                 });
+}
 
 PatternRewriter::~PatternRewriter() {
   // Out of line to provide a vtable anchor for the class.
@@ -173,22 +182,28 @@
 void PatternApplicator::applyCostModel(CostModel model) {
   // Separate patterns by root kind to simplify lookup later on.
   patterns.clear();
-  for (const auto &pat : owningPatternList)
-    patterns[pat->getRootKind()].push_back(pat.get());
+  anyOpPatterns.clear();
+  for (const auto &pat : owningPatternList) {
+    // If the pattern is always impossible to match, just ignore it.
+    if (pat->getBenefit().isImpossibleToMatch())
+      continue;
+    if (Optional<OperationName> opName = pat->getRootKind())
+      patterns[*opName].push_back(pat.get());
+    else
+      anyOpPatterns.push_back(pat.get());
+  }
 
   // Sort the patterns using the provided cost model.
   llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
   auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
     return benefits[lhs] > benefits[rhs];
   };
-  for (auto &it : patterns) {
-    SmallVectorImpl<RewritePattern *> &list = it.second;
-
+  auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
     // Special case for one pattern in the list, which is the most common case.
     if (list.size() == 1) {
       if (model(*list.front()).isImpossibleToMatch())
         list.clear();
-      continue;
+      return;
     }
 
     // Collect the dynamic benefits for the current pattern list.
@@ -201,7 +216,10 @@
     std::stable_sort(list.begin(), list.end(), cmp);
     while (!list.empty() && benefits[list.back()].isImpossibleToMatch())
       list.pop_back();
-  }
+  };
+  for (auto &it : patterns)
+    processPatternList(it.second);
+  processPatternList(anyOpPatterns);
 }
 
 void PatternApplicator::walkAllPatterns(
@@ -210,32 +228,64 @@
     walk(*it);
 }
 
-/// Try to match the given operation to a pattern and rewrite it.
 LogicalResult PatternApplicator::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter,
     function_ref<bool(const RewritePattern &)> canApply,
     function_ref<void(const RewritePattern &)> onFailure,
     function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
+  // Check to see if there are patterns matching this specific operation type.
+  MutableArrayRef<RewritePattern *> opPatterns;
   auto patternIt = patterns.find(op->getName());
-  if (patternIt == patterns.end())
-    return failure();
+  if (patternIt != patterns.end())
+    opPatterns = patternIt->second;
+
+  // Process the patterns that match the specific operation type, and any
+  // operation type in an interleaved fashion.
+  // FIXME: It'd be nice to just write an llvm::make_merge_range utility
+  // and pass in a comparison function. That would make this code trivial.
+  auto opIt = opPatterns.begin(), opE = opPatterns.end();
+  auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
+  while (opIt != opE && anyIt != anyE) {
+    // Pick the candidate with the highest static benefit from the two lists.
+    RewritePattern *pattern;
+    if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
+      pattern = *(opIt++);
+    else
+      pattern = *(anyIt++);
+
+    // Try to match and rewrite with the selected pattern.
+    if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
+                                  onSuccess)))
+      return success();
+  }
+  // If we break from the loop, then only one of the ranges can still have
+  // elements. Loop over both without checking given that we don't need to
+  // interleave anymore.
+  for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
+           llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
+    if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
+                                  onSuccess)))
+      return success();
+  }
+  return failure();
+}
 
-  for (auto *pattern : patternIt->second) {
-    // Check that the pattern can be applied.
-    if (canApply && !canApply(*pattern))
-      continue;
+LogicalResult PatternApplicator::matchAndRewrite(
+    Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
+    function_ref<bool(const RewritePattern &)> canApply,
+    function_ref<void(const RewritePattern &)> onFailure,
+    function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
+  // Check that the pattern can be applied.
+  if (canApply && !canApply(pattern))
+    return failure();
 
-    // Try to match and rewrite this pattern. The patterns are sorted by
-    // benefit, so if we match we can immediately rewrite.
-    rewriter.setInsertionPoint(op);
-    if (succeeded(pattern->matchAndRewrite(op, rewriter))) {
-      if (!onSuccess || succeeded(onSuccess(*pattern)))
-        return success();
-      continue;
-    }
+  // Try to match and rewrite this pattern. The patterns are sorted by
+  // benefit, so if we match we can immediately rewrite.
+  rewriter.setInsertionPoint(op);
+  if (succeeded(pattern.matchAndRewrite(op, rewriter)))
+    return success(!onSuccess || succeeded(onSuccess(pattern)));
 
-    if (onFailure)
-      onFailure(*pattern);
-  }
+  if (onFailure)
+    onFailure(pattern);
   return failure();
 }
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -1164,10 +1164,11 @@
                                       RewriterState &curState);
 
   /// Build an optimistic legalization graph given the provided patterns. This
-  /// function populates 'legalizerPatterns' with the operations that are not
-  /// directly legal, but may be transitively legal for the current target given
-  /// the provided patterns.
+  /// function populates 'anyOpLegalizerPatterns' and 'legalizerPatterns' with
+  /// patterns for operations that are not directly legal, but may be
+  /// transitively legal for the current target given the provided patterns.
   void buildLegalizationGraph(
+      LegalizationPatterns &anyOpLegalizerPatterns,
       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
 
   /// Compute the benefit of each node within the computed legalization graph.
@@ -1179,6 +1180,21 @@
   /// pattern with the highest PatternBenefit. This allows for users to
   /// prefer specific legalizations over others.
   void computeLegalizationGraphBenefit(
+      LegalizationPatterns &anyOpLegalizerPatterns,
+      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
+
+  /// Compute the legalization depth when legalizing an operation of the given
+  /// type.
+  unsigned computeOpLegalizationDepth(
+      OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
+      DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
+
+  /// Apply the conversion cost model to the given set of patterns, and return
+  /// the smallest legalization depth of any of the patterns. See
+  /// `computeLegalizationGraphBenefit` for the breakdown of the cost model.
+  unsigned applyCostModelToPatterns(
+      LegalizationPatterns &patterns,
+      DenseMap<OperationName, unsigned> &minOpPatternDepth,
       DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns);
 
   /// The current set of patterns that have been applied.
@@ -1195,12 +1211,13 @@
 OperationLegalizer::OperationLegalizer(ConversionTarget &targetInfo,
                                        const OwningRewritePatternList &patterns)
     : target(targetInfo), applicator(patterns) {
-  // The set of legality information for operations transitively supported by
-  // the target.
+  // The set of patterns that can be applied to illegal operations to transform
+  // them into legal ones.
   DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
+  LegalizationPatterns anyOpLegalizerPatterns;
 
-  buildLegalizationGraph(legalizerPatterns);
-  computeLegalizationGraphBenefit(legalizerPatterns);
+  buildLegalizationGraph(anyOpLegalizerPatterns, legalizerPatterns);
+  computeLegalizationGraphBenefit(anyOpLegalizerPatterns, legalizerPatterns);
 }
 
 bool OperationLegalizer::isIllegal(Operation *op) const {
@@ -1365,7 +1382,7 @@
   LLVM_DEBUG({
     auto &os = rewriter.getImpl().logger;
     os.getOStream() << "\n";
-    os.startLine() << "* Pattern : '" << pattern.getRootKind() << " -> (";
+    os.startLine() << "* Pattern : '" << op->getName() << " -> (";
     llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs());
     os.getOStream() << ")' {\n";
     os.indent();
@@ -1466,6 +1483,7 @@
 }
 
 void OperationLegalizer::buildLegalizationGraph(
+    LegalizationPatterns &anyOpLegalizerPatterns,
     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
   // A mapping between an operation and a set of operations that can be used to
   // generate it.
@@ -1478,22 +1496,41 @@
 
   // Build the mapping from operations to the parent ops that may generate them.
   applicator.walkAllPatterns([&](const RewritePattern &pattern) {
-    OperationName root = pattern.getRootKind();
+    Optional<OperationName> root = pattern.getRootKind();
+
+    // If the pattern has no specific root, we can't analyze the relationship
+    // between the root op and generated operations. Given that, add all such
+    // patterns to the legalization set.
+    if (!root) {
+      anyOpLegalizerPatterns.push_back(&pattern);
+      return;
+    }
 
     // Skip operations that are always known to be legal.
-    if (target.getOpAction(root) == LegalizationAction::Legal)
+    if (target.getOpAction(*root) == LegalizationAction::Legal)
       return;
 
     // Add this pattern to the invalid set for the root op and record this root
     // as a parent for any generated operations.
-    invalidPatterns[root].insert(&pattern);
+    invalidPatterns[*root].insert(&pattern);
     for (auto op : pattern.getGeneratedOps())
-      parentOps[op].insert(root);
+      parentOps[op].insert(*root);
 
     // Add this pattern to the worklist.
     patternWorklist.insert(&pattern);
   });
 
+  // If there are any patterns that don't have a specific root kind, we can't
+  // make direct assumptions about what operations will never be legalized.
+  // Note: Technically we could, but it would require an analysis that may
+  // recurse into itself. It would be better to perform this kind of filtering
+  // at a higher level than here anyway.
+  if (!anyOpLegalizerPatterns.empty()) {
+    for (const RewritePattern *pattern : patternWorklist)
+      legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
+    return;
+  }
+
   while (!patternWorklist.empty()) {
     auto *pattern = patternWorklist.pop_back_val();
 
@@ -1507,108 +1544,132 @@
 
     // Otherwise, if all of the generated operation are valid, this op is now
     // legal so add all of the child patterns to the worklist.
-    legalizerPatterns[pattern->getRootKind()].push_back(pattern);
-    invalidPatterns[pattern->getRootKind()].erase(pattern);
+    legalizerPatterns[*pattern->getRootKind()].push_back(pattern);
+    invalidPatterns[*pattern->getRootKind()].erase(pattern);
 
     // Add any invalid patterns of the parent operations to see if they have now
     // become legal.
-    for (auto op : parentOps[pattern->getRootKind()])
+    for (auto op : parentOps[*pattern->getRootKind()])
       patternWorklist.set_union(invalidPatterns[op]);
   }
 }
 
 void OperationLegalizer::computeLegalizationGraphBenefit(
+    LegalizationPatterns &anyOpLegalizerPatterns,
     DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
   // The smallest pattern depth, when legalizing an operation.
-  DenseMap<OperationName, unsigned> minPatternDepth;
-
-  // Compute the minimum legalization depth for a given operation.
-  std::function<unsigned(OperationName)> computeDepth = [&](OperationName op) {
-    // Check for existing depth.
-    auto depthIt = minPatternDepth.find(op);
-    if (depthIt != minPatternDepth.end())
-      return depthIt->second;
-
-    // If a mapping for this operation does not exist, then this operation
-    // is always legal. Return 0 as the depth for a directly legal operation.
-    auto opPatternsIt = legalizerPatterns.find(op);
-    if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
-      return 0u;
-
-    // Initialize the depth to the maximum value.
-    unsigned minDepth = std::numeric_limits<unsigned>::max();
-
-    // Record this initial depth in case we encounter this op again when
-    // recursively computing the depth.
-    minPatternDepth.try_emplace(op, minDepth);
-
-    // Compute the depth for each pattern used to legalize this operation.
-    SmallVector<std::pair<const RewritePattern *, unsigned>, 4> patternsByDepth;
-    patternsByDepth.reserve(opPatternsIt->second.size());
-    for (const RewritePattern *pattern : opPatternsIt->second) {
-      unsigned depth = 0;
-      for (auto generatedOp : pattern->getGeneratedOps())
-        depth = std::max(depth, computeDepth(generatedOp) + 1);
-      patternsByDepth.emplace_back(pattern, depth);
-
-      // Update the min depth for this operation.
-      minDepth = std::min(minDepth, depth);
-    }
-
-    // Update the pattern depth.
-    minPatternDepth[op] = minDepth;
-
-    // If the operation only has one legalization pattern, there is no need to
-    // sort them.
-    if (patternsByDepth.size() == 1)
-      return minDepth;
-
-    // Sort the patterns by those likely to be the most beneficial.
-    llvm::array_pod_sort(
-        patternsByDepth.begin(), patternsByDepth.end(),
-        [](const std::pair<const RewritePattern *, unsigned> *lhs,
-           const std::pair<const RewritePattern *, unsigned> *rhs) {
-          // First sort by the smaller pattern legalization depth.
-          if (lhs->second != rhs->second)
-            return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
-                                                             &rhs->second);
-
-          // Then sort by the larger pattern benefit.
-          auto lhsBenefit = lhs->first->getBenefit();
-          auto rhsBenefit = rhs->first->getBenefit();
-          return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit,
-                                                                 &lhsBenefit);
-        });
-
-    // Update the legalization pattern to use the new sorted list.
-    opPatternsIt->second.clear();
-    for (auto &patternIt : patternsByDepth)
-      opPatternsIt->second.push_back(patternIt.first);
-
-    return minDepth;
-  };
+  DenseMap<OperationName, unsigned> minOpPatternDepth;
 
   // For each operation that is transitively legal, compute a cost for it.
   for (auto &opIt : legalizerPatterns)
-    if (!minPatternDepth.count(opIt.first))
-      computeDepth(opIt.first);
+    if (!minOpPatternDepth.count(opIt.first))
+      computeOpLegalizationDepth(opIt.first, minOpPatternDepth,
+                                 legalizerPatterns);
+
+  // Apply the cost model to the patterns that can match any operation. Those
+  // with a specific operation type are already resolved when computing the op
+  // legalization depth.
+  if (!anyOpLegalizerPatterns.empty())
+    applyCostModelToPatterns(anyOpLegalizerPatterns, minOpPatternDepth,
+                             legalizerPatterns);
 
   // Apply a cost model to the pattern applicator. We order patterns first by
   // depth then benefit. `legalizerPatterns` contains per-op patterns by
   // decreasing benefit.
   applicator.applyCostModel([&](const RewritePattern &p) {
-    auto &list = legalizerPatterns[p.getRootKind()];
+    ArrayRef<const RewritePattern *> orderedPatternList;
+    if (Optional<OperationName> rootName = p.getRootKind())
+      orderedPatternList = legalizerPatterns[*rootName];
+    else
+      orderedPatternList = anyOpLegalizerPatterns;
 
     // If the pattern is not found, then it was removed and cannot be matched.
-    LegalizationPatterns::iterator it = llvm::find(list, &p);
-    if (it == list.end())
+    auto it = llvm::find(orderedPatternList, &p);
+    if (it == orderedPatternList.end())
       return PatternBenefit::impossibleToMatch();
 
     // Patterns found earlier in the list have higher benefit.
-    return PatternBenefit(std::distance(it, list.end()));
+    return PatternBenefit(std::distance(it, orderedPatternList.end()));
   });
 }
 
+unsigned OperationLegalizer::computeOpLegalizationDepth(
+    OperationName op, DenseMap<OperationName, unsigned> &minOpPatternDepth,
+    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
+  // Check for existing depth.
+  auto depthIt = minOpPatternDepth.find(op);
+  if (depthIt != minOpPatternDepth.end())
+    return depthIt->second;
+
+  // If a mapping for this operation does not exist, then this operation
+  // is always legal. Return 0 as the depth for a directly legal operation.
+  auto opPatternsIt = legalizerPatterns.find(op);
+  if (opPatternsIt == legalizerPatterns.end() || opPatternsIt->second.empty())
+    return 0u;
+
+  // Record this initial depth in case we encounter this op again when
+  // recursively computing the depth.
+  minOpPatternDepth.try_emplace(op, std::numeric_limits<unsigned>::max());
+
+  // Apply the cost model to the operation patterns, and update the minimum
+  // depth.
+  unsigned minDepth = applyCostModelToPatterns(
+      opPatternsIt->second, minOpPatternDepth, legalizerPatterns);
+  minOpPatternDepth[op] = minDepth;
+  return minDepth;
+}
+
+unsigned OperationLegalizer::applyCostModelToPatterns(
+    LegalizationPatterns &patterns,
+    DenseMap<OperationName, unsigned> &minOpPatternDepth,
+    DenseMap<OperationName, LegalizationPatterns> &legalizerPatterns) {
+  unsigned minDepth = std::numeric_limits<unsigned>::max();
+
+  // Compute the depth for each pattern within the set.
+  SmallVector<std::pair<const RewritePattern *, unsigned>, 4> patternsByDepth;
+  patternsByDepth.reserve(patterns.size());
+  for (const RewritePattern *pattern : patterns) {
+    unsigned depth = 0;
+    for (auto generatedOp : pattern->getGeneratedOps()) {
+      unsigned generatedOpDepth = computeOpLegalizationDepth(
+          generatedOp, minOpPatternDepth, legalizerPatterns);
+      depth = std::max(depth, generatedOpDepth + 1);
+    }
+    patternsByDepth.emplace_back(pattern, depth);
+
+    // Update the minimum depth of the pattern list.
+    minDepth = std::min(minDepth, depth);
+  }
+
+  // If there is only one legalization pattern in the set, there is no need to
+  // sort them.
+  if (patternsByDepth.size() == 1)
+    return minDepth;
+
+  // Sort the patterns by those likely to be the most beneficial.
+  llvm::array_pod_sort(
+      patternsByDepth.begin(), patternsByDepth.end(),
+      [](const std::pair<const RewritePattern *, unsigned> *lhs,
+         const std::pair<const RewritePattern *, unsigned> *rhs) {
+        // First sort by the smaller pattern legalization depth.
+        if (lhs->second != rhs->second)
+          return llvm::array_pod_sort_comparator<unsigned>(&lhs->second,
+                                                           &rhs->second);
+
+        // Then sort by the larger pattern benefit.
+        auto lhsBenefit = lhs->first->getBenefit();
+        auto rhsBenefit = rhs->first->getBenefit();
+        return llvm::array_pod_sort_comparator<PatternBenefit>(&rhsBenefit,
+                                                               &lhsBenefit);
+      });
+
+  // Update the legalization pattern to use the new sorted list.
+  patterns.clear();
+  for (auto &patternIt : patternsByDepth)
+    patterns.push_back(patternIt.first);
+  return minDepth;
+}
+
 //===----------------------------------------------------------------------===//
 // OperationConverter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/test-legalize-unknown-root.mlir b/mlir/test/Transforms/test-legalize-unknown-root.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/test-legalize-unknown-root.mlir
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s -test-legalize-unknown-root-patterns | FileCheck %s
+
+// Test that all `test` dialect operations are removed.
+// CHECK-LABEL: func @remove_all_ops
+func @remove_all_ops(%arg0: i32) {
+  // CHECK-NEXT: return
+  %0 = "test.illegal_op_a"() : () -> i32
+  %1 = "test.illegal_op_b"() : () -> i32
+  %2 = "test.illegal_op_c"() : () -> i32
+  %3 = "test.illegal_op_d"() : () -> i32
+  %4 = "test.illegal_op_e"() : () -> i32
+  return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -701,18 +701,50 @@
 };
 } // end anonymous namespace
 
+//===----------------------------------------------------------------------===//
+// Test patterns without a specific root operation kind
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// This pattern matches and removes any operation in the test dialect.
+struct RemoveTestDialectOps : public RewritePattern {
+  RemoveTestDialectOps() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    if (!llvm::isa_and_nonnull<TestDialect>(op->getDialect()))
+      return failure();
+    rewriter.eraseOp(op);
+    return success();
+  }
+};
+
+struct TestUnknownRootOpDriver
+    : public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
+  void runOnFunction() override {
+    mlir::OwningRewritePatternList patterns;
+    patterns.insert<RemoveTestDialectOps>();
+
+    mlir::ConversionTarget target(getContext());
+    target.addIllegalDialect<TestDialect>();
+    if (failed(applyPartialConversion(getFunction(), target, patterns)))
+      signalPassFailure();
+  }
+};
+} // end anonymous namespace
+
 namespace mlir {
 void registerPatternsTestPass() {
-  mlir::PassRegistration<TestReturnTypeDriver>("test-return-type",
-                                               "Run return type functions");
+  PassRegistration<TestReturnTypeDriver>("test-return-type",
+                                         "Run return type functions");
 
-  mlir::PassRegistration<TestDerivedAttributeDriver>(
-      "test-derived-attr", "Run test derived attributes");
+  PassRegistration<TestDerivedAttributeDriver>("test-derived-attr",
+                                               "Run test derived attributes");
 
-  mlir::PassRegistration<TestPatternDriver>("test-patterns",
-                                            "Run test dialect patterns");
+  PassRegistration<TestPatternDriver>("test-patterns",
+                                      "Run test dialect patterns");
 
-  mlir::PassRegistration<TestLegalizePatternDriver>(
+  PassRegistration<TestLegalizePatternDriver>(
       "test-legalize-patterns", "Run test dialect legalization patterns", [] {
         return std::make_unique<TestLegalizePatternDriver>(
             legalizerConversionMode);
@@ -721,5 +753,9 @@
   PassRegistration<TestRemappedValue>(
       "test-remapped-value",
       "Test public remapped value mechanism in ConversionPatternRewriter");
+
+  PassRegistration<TestUnknownRootOpDriver>(
+      "test-legalize-unknown-root-patterns",
+      "Test patterns without a specific root operation kind");
 }
 } // namespace mlir