diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
--- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
+++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h
@@ -61,7 +61,9 @@
   static constexpr int64_t kNoLimit = -1;
 
   /// Only ops within the scope are added to the worklist. If no scope is
-  /// specified, the closest enclosing region is used as a scope.
+  /// specified, the closest enclosing region around the initial list of ops
+  /// is used as a scope. If there is a top-level op among the initial list of
+  /// ops, there is no such region and the scope remains `nullptr`.
   Region *scope = nullptr;
 
   /// Strict mode can restrict the ops that are added to the worklist during
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -124,7 +124,6 @@
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config)
     : PatternRewriter(ctx), folder(ctx), config(config), matcher(patterns) {
-  assert(config.scope && "scope is not specified");
   worklist.reserve(64);
 
   // Apply a simple cost model based solely on pattern benefit.
@@ -266,19 +265,15 @@
 void GreedyPatternRewriteDriver::addToWorklist(Operation *op) {
   // Gather potential ancestors while looking for a "scope" parent region.
   SmallVector<Operation *, 8> ancestors;
-  ancestors.push_back(op);
-  while (Region *region = op->getParentRegion()) {
-    if (config.scope == region) {
-      // All gathered ops are in fact ancestors.
-      for (Operation *op : ancestors)
-        addSingleOpToWorklist(op);
-      break;
-    }
-    op = region->getParentOp();
-    if (!op)
-      break;
+  do {
     ancestors.push_back(op);
-  }
+    if (config.scope == op->getParentRegion()) {
+      // Scope (can be `nullptr`) was reached. Stop traversal and enqueue ops.
+      for (Operation *ancestor : ancestors)
+        addSingleOpToWorklist(ancestor);
+      return;
+    }
+  } while ((op = op->getParentOp()));
 }
 
 void GreedyPatternRewriteDriver::addSingleOpToWorklist(Operation *op) {
@@ -556,6 +551,9 @@
 }
 
 /// Find the region that is the closest common ancestor of all given ops.
+///
+/// Note: This function returns `nullptr` if there is a top-level op among the
+/// given list of ops.
 static Region *findCommonAncestor(ArrayRef<Operation *> ops) {
   assert(!ops.empty() && "expected at least one op");
   // Fast path in case there is only one op.
@@ -566,7 +564,7 @@
   ops = ops.drop_front();
   int sz = ops.size();
   llvm::BitVector remainingOps(sz, true);
-  do {
+  while (region) {
     int pos = -1;
     // Iterate over all remaining ops.
     while ((pos = remainingOps.find_first_in(pos + 1, sz)) != -1) {
@@ -576,8 +574,8 @@
     }
     if (remainingOps.none())
       break;
-  } while ((region = region->getParentRegion()));
-  assert(region && "could not find common parent region");
+    region = region->getParentRegion();
+  }
   return region;
 }
 
@@ -594,7 +592,8 @@
 
   // Determine scope of rewrite.
   if (!config.scope) {
-    // Compute scope if none was provided.
+    // Compute scope if none was provided. The scope will remain `nullptr` if
+    // there is a top-level op among `ops`.
     config.scope = findCommonAncestor(ops);
   } else {
     // If a scope was provided, make sure that all ops are in scope.