diff --git a/mlir/include/mlir/Transforms/TopologicalSortUtils.h b/mlir/include/mlir/Transforms/TopologicalSortUtils.h --- a/mlir/include/mlir/Transforms/TopologicalSortUtils.h +++ b/mlir/include/mlir/Transforms/TopologicalSortUtils.h @@ -95,16 +95,13 @@ Block *block, function_ref isOperandReady = nullptr); -/// Compute a topological ordering of the given ops. All ops must belong to the -/// specified block. -/// -/// This sort is not stable. +/// Compute a topological ordering of the given ops. This sort is not stable. /// /// Note: If the specified ops contain incomplete/interrupted SSA use-def /// chains, the result may not actually be a topological sorting with respect to /// the entire program. bool computeTopologicalSorting( - Block *block, MutableArrayRef ops, + MutableArrayRef ops, function_ref isOperandReady = nullptr); } // end namespace mlir diff --git a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp b/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp --- a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp +++ b/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp @@ -12,12 +12,11 @@ using namespace mlir; /// Return `true` if the given operation is ready to be scheduled. -static bool isOpReady(Block *block, Operation *op, - DenseSet &unscheduledOps, +static bool isOpReady(Operation *op, DenseSet &unscheduledOps, function_ref isOperandReady) { // An operation is ready to be scheduled if all its operands are ready. An // operation is ready if: - const auto isReady = [&](Value value, Operation *top) { + const auto isReady = [&](Value value) { // - the user-provided callback marks it as ready, if (isOperandReady && isOperandReady(value, op)) return true; @@ -25,22 +24,24 @@ // - it is a block argument, if (!parent) return true; - Operation *ancestor = block->findAncestorOpInBlock(*parent); - // - it is an implicit capture, - if (!ancestor) - return true; - // - it is defined in a nested region, or - if (ancestor == op) - return true; - // - its ancestor in the block is scheduled. - return !unscheduledOps.contains(ancestor); + // - or it is not defined by an unscheduled op (and also not nested within + // an unscheduled op). + do { + // Stop traversal when op under examination is reached. + if (parent == op) + return true; + if (unscheduledOps.contains(parent)) + return false; + } while ((parent = parent->getParentOp())); + // No unscheduled op found. + return true; }; // An operation is recursively ready to be scheduled of it and its nested // operations are ready. WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) { return llvm::all_of(nestedOp->getOperands(), - [&](Value operand) { return isReady(operand, op); }) + [&](Value operand) { return isReady(operand); }) ? WalkResult::advance() : WalkResult::interrupt(); }); @@ -71,7 +72,7 @@ // set, and "schedule" it (move it before the `nextScheduledOp`). for (Operation &op : llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) { - if (!isOpReady(block, &op, unscheduledOps, isOperandReady)) + if (!isOpReady(&op, unscheduledOps, isOperandReady)) continue; // Schedule the operation by moving it to the start. @@ -104,7 +105,7 @@ } bool mlir::computeTopologicalSorting( - Block *block, MutableArrayRef ops, + MutableArrayRef ops, function_ref isOperandReady) { if (ops.empty()) return true; @@ -113,10 +114,8 @@ DenseSet unscheduledOps; // Mark all operations as unscheduled. - for (Operation *op : ops) { - assert(op->getBlock() == block && "op must belong to block"); + for (Operation *op : ops) unscheduledOps.insert(op); - } unsigned nextScheduledOp = 0; @@ -128,7 +127,7 @@ // i.e. the ones for which there aren't any operand produced by an op in the // set, and "schedule" it (swap it with the op at `nextScheduledOp`). for (unsigned i = nextScheduledOp; i < ops.size(); ++i) { - if (!isOpReady(block, ops[i], unscheduledOps, isOperandReady)) + if (!isOpReady(ops[i], unscheduledOps, isOperandReady)) continue; // Schedule the operation by moving it to the start. diff --git a/mlir/test/Transforms/test-toposort.mlir b/mlir/test/Transforms/test-toposort.mlir --- a/mlir/test/Transforms/test-toposort.mlir +++ b/mlir/test/Transforms/test-toposort.mlir @@ -36,6 +36,26 @@ %3 = "test.d"() {selected} : () -> i32 } +// CHECK-LABEL: func @test_multiple_blocks +// CHECK-ANALYSIS-LABEL: func @test_multiple_blocks +func.func @test_multiple_blocks() -> (i32) attributes{"root"} { + // CHECK-ANALYSIS-NEXT: test.foo{{.*}} {pos = 0 + %0 = "test.foo"() {selected} : () -> (i32) + // CHECK-ANALYSIS-NEXT: test.foo + %1 = "test.foo"() : () -> (i32) + cf.br ^bb0 +^bb0: + // CHECK-ANALYSIS: test.foo{{.*}} {pos = 1 + %2 = "test.foo"() {selected} : () -> (i32) + // CHECK-ANALYSIS-NEXT: test.bar{{.*}} {pos = 2 + %3 = "test.bar"(%0, %1, %2) {selected} : (i32, i32, i32) -> (i32) + cf.br ^bb1 (%2 : i32) +^bb1(%arg0: i32): + // CHECK-ANALYSIS: test.qux{{.*}} {pos = 3 + %4 = "test.qux"(%arg0, %0) {selected} : (i32, i32) -> (i32) + return %4 : i32 +} + // Test block arguments. // CHECK-LABEL: test.graph_region test.graph_region { diff --git a/mlir/test/lib/Transforms/TestTopologicalSort.cpp b/mlir/test/lib/Transforms/TestTopologicalSort.cpp --- a/mlir/test/lib/Transforms/TestTopologicalSort.cpp +++ b/mlir/test/lib/Transforms/TestTopologicalSort.cpp @@ -30,25 +30,29 @@ Operation *op = getOperation(); OpBuilder builder(op->getContext()); - op->walk([&](Operation *root) { + WalkResult result = op->walk([&](Operation *root) { if (!root->hasAttr("root")) return WalkResult::advance(); - assert(root->getNumRegions() == 1 && root->getRegion(0).hasOneBlock() && - "expected one block"); - Block *block = &root->getRegion(0).front(); SmallVector selectedOps; - block->walk([&](Operation *op) { - if (op->hasAttr("selected")) - selectedOps.push_back(op); + root->walk([&](Operation *selected) { + if (selected->hasAttr("selected")) + selectedOps.push_back(selected); }); - computeTopologicalSorting(block, selectedOps); + if (!computeTopologicalSorting(selectedOps)) { + root->emitError("could not schedule all ops"); + return WalkResult::skip(); + } + for (const auto &it : llvm::enumerate(selectedOps)) it.value()->setAttr("pos", builder.getIndexAttr(it.index())); return WalkResult::advance(); }); + + if (result.wasSkipped()) + signalPassFailure(); } }; } // namespace