diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -92,6 +92,11 @@ std::unique_ptr createSymbolPrivatizePass(ArrayRef excludeSymbols = {}); +/// Creates a pass that recursively sorts nested regions without SSA dominance +/// topologically such that, as much as possible, users of values appear after +/// their producers. +std::unique_ptr createTopologicalSortPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -260,4 +260,17 @@ let constructor = "mlir::createPrintOpGraphPass()"; } +def TopologicalSort : Pass<"topological-sort"> { + let summary = "Sort regions without SSA dominance in topological order"; + let description = [{ + Recursively sorts all nested regions without SSA dominance in topological + order. The main purpose is readability, as well as potentially faster + processing of certain transformations and analyses. The function sorts the + operations in all nested regions such that, as much as possible, all users + appear after their producers. + }]; + + let constructor = "mlir::createTopologicalSortPass()"; +} + #endif // MLIR_TRANSFORMS_PASSES diff --git a/mlir/include/mlir/Transforms/TopologicalSortUtils.h b/mlir/include/mlir/Transforms/TopologicalSortUtils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Transforms/TopologicalSortUtils.h @@ -0,0 +1,79 @@ +//===- TopologicalSortUtils.h - Topological sort utilities ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H +#define MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H + +#include "mlir/IR/Block.h" + +namespace mlir { + +/// Given a block without SSA dominance, sort the given range of operations in +/// topological order. The main purpose is readability, as well as potentially +/// faster processing of certain transformations and analyses. The function +/// sorts the given operations such that, as much as possible, all users appear +/// after their producers. +/// +/// For example: +/// +/// ```mlir +/// %0 = test.foo +/// %1 = test.bar %0, %2 +/// %2 = test.baz +/// ``` +/// +/// Will become: +/// +/// ```mlir +/// %0 = test.foo +/// %1 = test.baz +/// %2 = test.bar %0, %1 +/// ``` +/// +/// The sort also works on operations with regions and implicit captures. For +/// example: +/// +/// ```mlir +/// %0 = test.foo { +/// test.baz %1 +/// %1 = test.bar %2 +/// } +/// %2 = test.foo +/// ``` +/// +/// Will become: +/// +/// ```mlir +/// %0 = test.foo +/// %1 = test.foo { +/// test.baz %2 +/// %2 = test.bar %0 +/// } +/// ``` +/// +/// Note that the sort is not recursive on nested regions. +/// +/// If the sort is left with only operations that form a cycle, it breaks the +/// cycle by marking the first encountered operation as ready and moving on. +/// +/// The function optionally accepts a callback that can be provided by users to +/// virtually break cycles early. +bool sortTopologically( + Block *block, iterator_range ops, + function_ref isOperandReady = + [](Value, Operation *) { return false; }); + +/// Given a block without SSA dominance, sort its operations in topological +/// order, excluding its terminator if it has one. 
+bool sortTopologically( + Block *block, function_ref isOperandReady = + [](Value, Operation *) { return false; }); + +} // end namespace mlir + +#endif // MLIR_TRANSFORMS_TOPOLOGICALSORTUTILS_H diff --git a/mlir/lib/Transforms/TopologicalSort.cpp b/mlir/lib/Transforms/TopologicalSort.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/TopologicalSort.cpp @@ -0,0 +1,38 @@ +//===- TopologicalSort.cpp - Topological sort pass ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/IR/RegionKindInterface.h" +#include "mlir/Transforms/TopologicalSortUtils.h" + +using namespace mlir; + +/// Topologically sort the blocks of `op`'s regions that lack SSA dominance. +static void sortRegionsTopologically(Operation *op) { + auto regionInterface = dyn_cast(op); + if (!regionInterface) + return; + for (auto &it : llvm::enumerate(op->getRegions())) { + if (regionInterface.hasSSADominance(it.index())) + continue; + for (Block &block : it.value()) + sortTopologically(&block); + } +} + +namespace { +struct TopologicalSortPass : public TopologicalSortBase { + void runOnOperation() override { + getOperation()->walk(sortRegionsTopologically); + } +}; +} // end anonymous namespace + +std::unique_ptr mlir::createTopologicalSortPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp b/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/Utils/TopologicalSortUtils.cpp @@ -0,0 +1,98 @@ +//===- TopologicalSortUtils.cpp - Topological sort utilities --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/TopologicalSortUtils.h" +#include "mlir/IR/OpDefinition.h" + +using namespace mlir; + +bool mlir::sortTopologically( + Block *block, llvm::iterator_range ops, + function_ref isOperandReady) { + if (ops.empty()) + return true; + + // The set of operations that have not yet been scheduled. + DenseSet unscheduledOps; + // Mark all operations as unscheduled. + for (Operation &op : ops) + unscheduledOps.insert(&op); + + Block::iterator nextScheduledOp = ops.begin(); + Block::iterator end = ops.end(); + + // An operation is ready to be scheduled if all its operands are ready. An + // operand is ready if: + const auto isReady = [&](Value value, Operation *top) { + // - the user-provided callback marks it as ready, + if (isOperandReady(value, top)) + return true; + Operation *parent = value.getDefiningOp(); + // - it is a block argument, + if (!parent) + return true; + Operation *ancestor = block->findAncestorOpInBlock(*parent); + // - it is an implicit capture, + if (!ancestor) + return true; + // - it is defined in a nested region, or + if (ancestor == top) + return true; + // - its ancestor in the block is scheduled. + return !unscheduledOps.contains(ancestor); + }; + + bool allOpsScheduled = true; + while (!unscheduledOps.empty()) { + bool scheduledAtLeastOnce = false; + + // Loop over the ops that are not sorted yet, try to find the ones "ready", + // i.e. the ones for which there aren't any operand produced by an op in the + // set, and "schedule" it (move it before the `nextScheduledOp`). + for (Operation &op : + llvm::make_early_inc_range(llvm::make_range(nextScheduledOp, end))) { + // An operation is recursively ready to be scheduled if it and its nested + // operations are ready. 
+ WalkResult readyToSchedule = op.walk([&](Operation *nestedOp) { + return llvm::all_of( + nestedOp->getOperands(), + [&](Value operand) { return isReady(operand, &op); }) + ? WalkResult::advance() + : WalkResult::interrupt(); + }); + if (readyToSchedule.wasInterrupted()) + continue; + + // Schedule the operation by moving it to the start. + unscheduledOps.erase(&op); + op.moveBefore(block, nextScheduledOp); + scheduledAtLeastOnce = true; + // Move the iterator forward if we schedule the operation at the front. + if (&op == &*nextScheduledOp) + ++nextScheduledOp; + } + // If no operations were scheduled, give up and advance the iterator. + if (!scheduledAtLeastOnce) { + allOpsScheduled = false; + unscheduledOps.erase(&*nextScheduledOp); + ++nextScheduledOp; + } + } + + return allOpsScheduled; +} + +bool mlir::sortTopologically( + Block *block, function_ref isOperandReady) { + if (block->empty()) + return true; + if (block->back().hasTrait()) + return sortTopologically(block, block->without_terminator(), + isOperandReady); + return sortTopologically(block, *block, isOperandReady); +} diff --git a/mlir/test/Transforms/test-toposort.mlir b/mlir/test/Transforms/test-toposort.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/test-toposort.mlir @@ -0,0 +1,75 @@ +// RUN: mlir-opt -topological-sort %s | FileCheck %s + +// Test producer is after user. +// CHECK-LABEL: test.graph_region +test.graph_region { + // CHECK-NEXT: test.foo + // CHECK-NEXT: test.baz + // CHECK-NEXT: test.bar + %0 = "test.foo"() : () -> i32 + "test.bar"(%1, %0) : (i32, i32) -> () + %1 = "test.baz"() : () -> i32 +} + +// Test cycles. +// CHECK-LABEL: test.graph_region +test.graph_region { + // CHECK-NEXT: test.d + // CHECK-NEXT: test.a + // CHECK-NEXT: test.c + // CHECK-NEXT: test.b + %2 = "test.c"(%1) : (i32) -> i32 + %1 = "test.b"(%0, %2) : (i32, i32) -> i32 + %0 = "test.a"(%3) : (i32) -> i32 + %3 = "test.d"() : () -> i32 +} + +// Test block arguments. 
+// CHECK-LABEL: test.graph_region +test.graph_region { +// CHECK-NEXT: (%{{.*}}: +^entry(%arg0: i32): + // CHECK-NEXT: test.foo + // CHECK-NEXT: test.baz + // CHECK-NEXT: test.bar + %0 = "test.foo"(%arg0) : (i32) -> i32 + "test.bar"(%1, %0) : (i32, i32) -> () + %1 = "test.baz"(%arg0) : (i32) -> i32 +} + +// Test implicit block capture (and sort nested region). +// CHECK-LABEL: test.graph_region +func.func @test_graph_cfg() -> () { + %0 = "test.foo"() : () -> i32 + cf.br ^next(%0 : i32) + +^next(%1: i32): + test.graph_region { + // CHECK-NEXT: test.foo + // CHECK-NEXT: test.baz + // CHECK-NEXT: test.bar + %2 = "test.foo"(%1) : (i32) -> i32 + "test.bar"(%3, %2) : (i32, i32) -> () + %3 = "test.baz"(%0) : (i32) -> i32 + } + return +} + +// Test region ops (and recursive sort). +// CHECK-LABEL: test.graph_region +test.graph_region { + // CHECK-NEXT: test.baz + // CHECK-NEXT: test.graph_region attributes {a} { + // CHECK-NEXT: test.b + // CHECK-NEXT: test.a + // CHECK-NEXT: } + // CHECK-NEXT: test.bar + // CHECK-NEXT: test.foo + %0 = "test.foo"(%1) : (i32) -> i32 + test.graph_region attributes {a} { + %a = "test.a"(%b) : (i32) -> i32 + %b = "test.b"(%2) : (i32) -> i32 + } + %1 = "test.bar"(%2) : (i32) -> i32 + %2 = "test.baz"() : () -> i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -723,17 +723,6 @@ // Test GraphRegionOp //===----------------------------------------------------------------------===// -ParseResult GraphRegionOp::parse(OpAsmParser &parser, OperationState &result) { - // Parse the body region, and reuse the operand info as the argument info. 
- Region *body = result.addRegion(); - return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}); -} - -void GraphRegionOp::print(OpAsmPrinter &p) { - p << "test.graph_region "; - p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); -} - RegionKind GraphRegionOp::getRegionKind(unsigned index) { return RegionKind::Graph; } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1790,7 +1790,7 @@ }]; let regions = (region AnyRegion:$region); - let hasCustomAssemblyFormat = 1; + let assemblyFormat = "attr-dict-with-keyword $region"; } def AffineScopeOp : TEST_Op<"affine_scope", [AffineScope]> {