diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -157,22 +157,23 @@
+  // Steal the functions of the other module before checking for patterns, so
+  // that functions registered ahead of time in a module without patterns are
+  // not dropped. Names already registered here keep their existing function.
+  for (auto &it : other.constraintFunctions)
+    registerConstraintFunction(it.first(), std::move(it.second));
+  for (auto &it : other.rewriteFunctions)
+    registerRewriteFunction(it.first(), std::move(it.second));
+
   // Ignore the other module if it has no patterns.
   if (!other.pdlModule)
     return;
+
   // Steal the other state if we have no patterns.
   if (!pdlModule) {
-    constraintFunctions = std::move(other.constraintFunctions);
-    rewriteFunctions = std::move(other.rewriteFunctions);
     pdlModule = std::move(other.pdlModule);
     return;
   }
-  // Steal the functions of the other module.
-  for (auto &it : constraintFunctions)
-    registerConstraintFunction(it.first(), std::move(it.second));
-  for (auto &it : rewriteFunctions)
-    registerRewriteFunction(it.first(), std::move(it.second));
 
   // Merge the pattern operations from the other module into this one.
   Block *block = pdlModule->getBody();
-  block->getTerminator()->erase();
   block->getOperations().splice(block->end(),
                                 other.pdlModule->getBody()->getOperations());
 }
@@ -182,18 +183,21 @@
 
 void PDLPatternModule::registerConstraintFunction(
     StringRef name, PDLConstraintFunction constraintFn) {
-  auto it = constraintFunctions.try_emplace(name, std::move(constraintFn));
-  (void)it;
-  assert(it.second &&
-         "constraint with the given name has already been registered");
+  // Allow existing mappings in case multiple patterns depend on the same
+  // constraint: `try_emplace` keeps the existing mapping and ignores
+  // `constraintFn`.
+  // TODO: Is it possible to diagnose when `name` is already registered to
+  // a function that is not equivalent to `constraintFn`?
+  constraintFunctions.try_emplace(name, std::move(constraintFn));
 }
 
 void PDLPatternModule::registerRewriteFunction(StringRef name,
                                                PDLRewriteFunction rewriteFn) {
-  auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn));
-  (void)it;
-  assert(it.second && "native rewrite function with the given name has "
-                      "already been registered");
+  // Allow existing mappings in case multiple patterns depend on the same
+  // rewrite: `try_emplace` keeps the existing mapping and ignores `rewriteFn`.
+  // TODO: Is it possible to diagnose when `name` is already registered to
+  // a function that is not equivalent to `rewriteFn`?
+  rewriteFunctions.try_emplace(name, std::move(rewriteFn));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
--- a/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
+++ b/mlir/test/lib/Rewrite/TestPDLByteCode.cpp
@@ -87,13 +87,27 @@
     if (!patternModule || !irModule)
       return;
 
+    RewritePatternSet patternList(module->getContext());
+
+    // Register ahead of time to test when functions are registered without a
+    // pattern.
+    patternList.getPDLPatterns().registerConstraintFunction(
+        "multi_entity_constraint", customMultiEntityConstraint);
+    patternList.getPDLPatterns().registerConstraintFunction(
+        "single_entity_constraint", customSingleEntityConstraint);
+
     // Process the pattern module.
     patternModule.getOperation()->remove();
     PDLPatternModule pdlPattern(patternModule);
+
+    // Note: This constraint was already registered, but we re-register here to
+    // ensure that duplicate registration is allowed (the duplicate mapping
+    // will be ignored). This tests that we support separating the registration
+    // of library functions from the construction of patterns, and also that we
+    // allow multiple patterns to depend on the same library functions (without
+    // asserting/crashing).
     pdlPattern.registerConstraintFunction("multi_entity_constraint",
                                           customMultiEntityConstraint);
-    pdlPattern.registerConstraintFunction("single_entity_constraint",
-                                          customSingleEntityConstraint);
     pdlPattern.registerConstraintFunction("multi_entity_var_constraint",
                                           customMultiEntityVariadicConstraint);
     pdlPattern.registerRewriteFunction("creator", customCreate);
@@ -101,8 +115,7 @@
                                        customVariadicResultCreate);
     pdlPattern.registerRewriteFunction("type_creator", customCreateType);
     pdlPattern.registerRewriteFunction("rewriter", customRewriter);
-
-    RewritePatternSet patternList(std::move(pdlPattern));
+    patternList.add(std::move(pdlPattern));
 
     // Invoke the pattern driver with the provided patterns.
     (void)applyPatternsAndFoldGreedily(irModule.getBodyRegion(),