diff --git a/mlir/include/mlir/Transforms/TopologicalSortUtils.h b/mlir/include/mlir/Transforms/TopologicalSortUtils.h --- a/mlir/include/mlir/Transforms/TopologicalSortUtils.h +++ b/mlir/include/mlir/Transforms/TopologicalSortUtils.h @@ -90,11 +90,23 @@ function_ref isOperandReady = nullptr); /// Given a block, sort its operations in topological order, excluding its -/// terminator if it has one. +/// terminator if it has one. This sort is stable. bool sortTopologically( Block *block, function_ref isOperandReady = nullptr); +/// Compute a topological ordering of the given ops. All ops must belong to the +/// specified block. +/// +/// This sort is not stable. +/// +/// Note: If the specified ops contain incomplete/interrupted SSA use-def +/// chains, the result may not actually be a topological sorting with respect to +/// the entire program. +bool computeTopologicalSorting( + Block *block, MutableArrayRef ops, + function_ref isOperandReady = nullptr); + } // end namespace mlir #endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H diff --git a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp b/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp --- a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp +++ b/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp @@ -8,29 +8,19 @@ #include "mlir/Transforms/TopologicalSortUtils.h" #include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/SetVector.h" using namespace mlir; -bool mlir::sortTopologically( - Block *block, llvm::iterator_range ops, - function_ref isOperandReady) { - if (ops.empty()) - return true; - - // The set of operations that have not yet been scheduled. - DenseSet unscheduledOps; - // Mark all operations as unscheduled. - for (Operation &op : ops) - unscheduledOps.insert(&op); - - Block::iterator nextScheduledOp = ops.begin(); - Block::iterator end = ops.end(); - +/// Return `true` if the given operation is ready to be scheduled. 
+static bool isOpReady(Block *block, Operation *op, + DenseSet &unscheduledOps, + function_ref isOperandReady) { // An operation is ready to be scheduled if all its operands are ready. An // operation is ready if: const auto isReady = [&](Value value, Operation *top) { // - the user-provided callback marks it as ready, - if (isOperandReady && isOperandReady(value, top)) + if (isOperandReady && isOperandReady(value, op)) return true; Operation *parent = value.getDefiningOp(); // - it is a block argument, @@ -41,12 +31,38 @@ if (!ancestor) return true; // - it is defined in a nested region, or - if (ancestor == top) + if (ancestor == op) return true; // - its ancestor in the block is scheduled. return !unscheduledOps.contains(ancestor); }; + // An operation is recursively ready to be scheduled if it and its nested + // operations are ready. + WalkResult readyToSchedule = op->walk([&](Operation *nestedOp) { + return llvm::all_of(nestedOp->getOperands(), + [&](Value operand) { return isReady(operand, op); }) + ? WalkResult::advance() + : WalkResult::interrupt(); + }); + return !readyToSchedule.wasInterrupted(); +} + +bool mlir::sortTopologically( + Block *block, llvm::iterator_range ops, + function_ref isOperandReady) { + if (ops.empty()) + return true; + + // The set of operations that have not yet been scheduled. + DenseSet unscheduledOps; + // Mark all operations as unscheduled. + for (Operation &op : ops) + unscheduledOps.insert(&op); + + Block::iterator nextScheduledOp = ops.begin(); + Block::iterator end = ops.end(); + bool allOpsScheduled = true; while (!unscheduledOps.empty()) { bool scheduledAtLeastOnce = false; @@ -56,16 +72,7 @@ // set, and "schedule" it (move it before the `nextScheduledOp`). for (Operation &op : llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) { - // An operation is recursively ready to be scheduled of it and its nested - // operations are ready. 
- WalkResult readyToSchedule = op.walk([&](Operation *nestedOp) { - return llvm::all_of( - nestedOp->getOperands(), - [&](Value operand) { return isReady(operand, &op); }) - ? WalkResult::advance() - : WalkResult::interrupt(); - }); - if (readyToSchedule.wasInterrupted()) + if (!isOpReady(block, &op, unscheduledOps, isOperandReady)) continue; // Schedule the operation by moving it to the start. @@ -96,3 +103,48 @@ isOperandReady); return sortTopologically(block, *block, isOperandReady); } + +bool mlir::computeTopologicalSorting( + Block *block, MutableArrayRef ops, + function_ref isOperandReady) { + if (ops.empty()) + return true; + + // The set of operations that have not yet been scheduled. + DenseSet unscheduledOps; + + // Mark all operations as unscheduled. + for (Operation *op : ops) { + assert(op->getBlock() == block && "op must belong to block"); + unscheduledOps.insert(op); + } + + unsigned nextScheduledOp = 0; + + bool allOpsScheduled = true; + while (!unscheduledOps.empty()) { + bool scheduledAtLeastOnce = false; + + // Loop over the ops that are not sorted yet, try to find the ones "ready", + // i.e. the ones for which there isn't any operand produced by an op in the + // set, and "schedule" it (swap it with the op at `nextScheduledOp`). + for (unsigned i = nextScheduledOp; i < ops.size(); ++i) { + if (!isOpReady(block, ops[i], unscheduledOps, isOperandReady)) + continue; + + // Schedule the operation by swapping it into the `nextScheduledOp` slot. + unscheduledOps.erase(ops[i]); + std::swap(ops[i], ops[nextScheduledOp]); + scheduledAtLeastOnce = true; + ++nextScheduledOp; + } + + // If no operations were scheduled, just schedule the first op and continue. 
+ if (!scheduledAtLeastOnce) { + allOpsScheduled = false; + unscheduledOps.erase(ops[nextScheduledOp++]); + } + } + + return allOpsScheduled; +} diff --git a/mlir/test/Transforms/test-toposort.mlir b/mlir/test/Transforms/test-toposort.mlir --- a/mlir/test/Transforms/test-toposort.mlir +++ b/mlir/test/Transforms/test-toposort.mlir @@ -1,27 +1,39 @@ // RUN: mlir-opt -topological-sort %s | FileCheck %s +// RUN: mlir-opt -test-topological-sort-analysis %s | FileCheck %s -check-prefix=CHECK-ANALYSIS // Test producer is after user. // CHECK-LABEL: test.graph_region -test.graph_region { +// CHECK-ANALYSIS-LABEL: test.graph_region +test.graph_region attributes{"root"} { // CHECK-NEXT: test.foo // CHECK-NEXT: test.baz // CHECK-NEXT: test.bar - %0 = "test.foo"() : () -> i32 - "test.bar"(%1, %0) : (i32, i32) -> () - %1 = "test.baz"() : () -> i32 + + // CHECK-ANALYSIS-NEXT: test.foo{{.*}} {pos = 0 + // CHECK-ANALYSIS-NEXT: test.bar{{.*}} {pos = 2 + // CHECK-ANALYSIS-NEXT: test.baz{{.*}} {pos = 1 + %0 = "test.foo"() {selected} : () -> i32 + "test.bar"(%1, %0) {selected} : (i32, i32) -> () + %1 = "test.baz"() {selected} : () -> i32 } // Test cycles. // CHECK-LABEL: test.graph_region -test.graph_region { +// CHECK-ANALYSIS-LABEL: test.graph_region +test.graph_region attributes{"root"} { // CHECK-NEXT: test.d // CHECK-NEXT: test.a // CHECK-NEXT: test.c // CHECK-NEXT: test.b - %2 = "test.c"(%1) : (i32) -> i32 + + // CHECK-ANALYSIS-NEXT: test.c{{.*}} {pos = 0 + // CHECK-ANALYSIS-NEXT: test.b{{.*}} : ( + // CHECK-ANALYSIS-NEXT: test.a{{.*}} {pos = 2 + // CHECK-ANALYSIS-NEXT: test.d{{.*}} {pos = 1 + %2 = "test.c"(%1) {selected} : (i32) -> i32 %1 = "test.b"(%0, %2) : (i32, i32) -> i32 - %0 = "test.a"(%3) : (i32) -> i32 - %3 = "test.d"() : () -> i32 + %0 = "test.a"(%3) {selected} : (i32) -> i32 + %3 = "test.d"() {selected} : () -> i32 } // Test block arguments. 
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -5,6 +5,7 @@ TestControlFlowSink.cpp TestInlining.cpp TestIntRangeInference.cpp + TestTopologicalSort.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Transforms/TestTopologicalSort.cpp b/mlir/test/lib/Transforms/TestTopologicalSort.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestTopologicalSort.cpp @@ -0,0 +1,62 @@ +//===- TestTopologicalSort.cpp - Pass to test topological sort analysis ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/TopologicalSortUtils.h" + +using namespace mlir; + +namespace { +struct TestTopologicalSortAnalysisPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTopologicalSortAnalysisPass) + + StringRef getArgument() const final { + return "test-topological-sort-analysis"; + } + StringRef getDescription() const final { + return "Test topological sorting of ops"; + } + + void runOnOperation() override { + Operation *op = getOperation(); + OpBuilder builder(op->getContext()); + + op->walk([&](Operation *root) { + if (!root->hasAttr("root")) + return WalkResult::advance(); + + assert(root->getNumRegions() == 1 && root->getRegion(0).hasOneBlock() && + "expected one block"); + Block *block = &root->getRegion(0).front(); + SmallVector selectedOps; + block->walk([&](Operation *op) { + if (op->hasAttr("selected")) + selectedOps.push_back(op); + }); + + computeTopologicalSorting(block, selectedOps); + for (const auto &it : 
llvm::enumerate(selectedOps)) + it.value()->setAttr("pos", builder.getIndexAttr(it.index())); + + return WalkResult::advance(); + }); + } +}; +} // namespace + +namespace mlir { +namespace test { +void registerTestTopologicalSortAnalysisPass() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -111,6 +111,7 @@ void registerTestSliceAnalysisPass(); void registerTestTensorTransforms(); void registerTestTilingInterface(); +void registerTestTopologicalSortAnalysisPass(); void registerTestTransformDialectInterpreterPass(); void registerTestVectorLowerings(); void registerTestNvgpuLowerings(); @@ -207,6 +208,7 @@ mlir::test::registerTestSliceAnalysisPass(); mlir::test::registerTestTensorTransforms(); mlir::test::registerTestTilingInterface(); + mlir::test::registerTestTopologicalSortAnalysisPass(); mlir::test::registerTestTransformDialectInterpreterPass(); mlir::test::registerTestVectorLowerings(); mlir::test::registerTestNvgpuLowerings();