diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h
--- a/mlir/include/mlir/IR/Block.h
+++ b/mlir/include/mlir/IR/Block.h
@@ -17,6 +17,8 @@
 #include "mlir/IR/Visitors.h"
 
 namespace mlir {
+struct LogicalResult;
+
 /// `Block` represents an ordered list of `Operation`s.
 class Block : public IRObjectWithUseList,
               public llvm::ilist_node_with_parent<Block, Region> {
@@ -129,6 +131,11 @@
   /// probably in Dominance.cpp.
   Operation *findAncestorOpInBlock(Operation &op);
 
+  /// Sort the Operations in the provided range to enforce dominance.
+  /// This is useful after fusing / reorganizing Operations in a block and later
+  /// needing to readjust the ordering to ensure dominance.
+  static LogicalResult sortTopologically(iterator first_op, iterator last_op);
+
   /// This drops all operand uses from operations within this block, which is
   /// an essential step in breaking cyclic dependences between references when
   /// they are to be deleted.
diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp
--- a/mlir/lib/IR/Block.cpp
+++ b/mlir/lib/IR/Block.cpp
@@ -7,8 +7,11 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/IR/Block.h"
+
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Operation.h"
+#include "llvm/ADT/SmallPtrSet.h"
+
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -81,6 +84,52 @@
   return currOp;
 }
 
+LogicalResult Block::sortTopologically(iterator first_op, iterator last_op) {
+  Block *block = first_op->getParent();
+  assert(block == last_op->getParent() && "ops must be in the same block");
+
+  // Track the ops that still need to be scheduled in a set.
+  // NOTE(review): template arguments were lost in extraction; inline size 16
+  // is a conventional choice — confirm against the original patch.
+  SmallPtrSet<Operation *, 16> unscheduled_ops;
+  for (Operation &op : llvm::make_range(first_op, last_op))
+    unscheduled_ops.insert(&op);
+
+  iterator last_scheduled_op = first_op;
+  while (!unscheduled_ops.empty()) {
+    bool scheduled_at_least_once = false;
+    // Loop over the ops that are not sorted yet, try to find the ones "ready",
+    // i.e. the ones for which there aren't any operand produced by an op in the
+    // set, and "schedule" it (move it before the last_scheduled_op).
+    for (Operation &op : llvm::make_range(last_scheduled_op, last_op)) {
+      WalkResult ready_to_schedule = op.walk([&](Operation *nested_op) {
+        if (llvm::all_of(nested_op->getOperands(), [&](Value operand) {
+              Operation *defining_op = operand.getDefiningOp();
+              if (!defining_op)
+                return true;
+              Operation *producer_in_block =
+                  block->findAncestorOpInBlock(*defining_op);
+              if (producer_in_block && producer_in_block != &op &&
+                  unscheduled_ops.count(producer_in_block))
+                return false;
+              return true;
+            }))
+          return WalkResult::advance();
+        return WalkResult::interrupt();
+      });
+      if (ready_to_schedule.wasInterrupted())
+        continue;
+      unscheduled_ops.erase(&op);
+      if (iterator(op) != last_scheduled_op)
+        op.moveBefore(block, last_scheduled_op);
+      else
+        ++last_scheduled_op;
+      scheduled_at_least_once = true;
+    }
+    if (!scheduled_at_least_once)
+      return failure();
+  }
+  return success();
+}
+
 /// This drops all operand uses from operations within this block, which is
 /// an essential step in breaking cyclic dependences between references when
 /// they are to be deleted.
diff --git a/mlir/test/IR/block-topological-sort.mlir b/mlir/test/IR/block-topological-sort.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/block-topological-sort.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt %s -test-block-toposort | FileCheck %s -dump-input-on-failure
+
+// CHECK-LABEL: no_data_dep
+func @no_data_dep(%a : i32) {
+  %1 = "foo.foo"(%a) : (i32) -> i32
+// CHECK: _pre_order = 3
+// CHECK: _pre_order = 5
+// CHECK: _pre_order = 6
+  %2 = "foo.foo"(%a) { _pre_order = 5 : i32 } : (i32) -> i32
+  %3 = "foo.foo"(%a) { _pre_order = 6 : i32 } : (i32) -> i32
+  %4 = "foo.foo"(%a) { _pre_order = 3 : i32 } : (i32) -> i32
+  %5 = "foo.foo"(%a) : (i32) -> i32
+// CHECK: _pre_order = 0
+// CHECK: _pre_order = 1
+  %6 = "foo.foo"(%a) { _pre_order = 1 : i32 } : (i32) -> i32
+  %7 = "foo.foo"(%a) { _pre_order = 0 : i32 } : (i32) -> i32
+  return
+}
+
+// CHECK-LABEL: chained_op
+func @chained_op(%a : i32) {
+  %1 = "foo.foo"(%a) : (i32) -> i32
+// Data dependency should prevent re-ordering here.
+  %2 = "foo.foo"(%a) { _pre_order = 5 : i32 } : (i32) -> i32
+  %3 = "foo.foo"(%2) { _pre_order = 6 : i32 } : (i32) -> i32
+  %4 = "foo.foo"(%3) { _pre_order = 3 : i32 } : (i32) -> i32
+// CHECK: _pre_order = 5
+// CHECK: _pre_order = 6
+// CHECK: _pre_order = 3
+  return
+}
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_llvm_library(MLIRTestIR
+  TestBlock.cpp
   TestFunc.cpp
   TestMatchers.cpp
   TestSymbolUses.cpp
diff --git a/mlir/test/lib/IR/TestBlock.cpp b/mlir/test/lib/IR/TestBlock.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/IR/TestBlock.cpp
@@ -0,0 +1,66 @@
+//===- TestBlock.cpp - Pass to test helpers on Blocks ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Block.h"
+
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "test-block-toposort"
+
+namespace {
+/// This is a test pass for verifying Block's topologicalSort method.
+struct TestTopologicalSort : public ModulePass<TestTopologicalSort> {
+  void runOnModule() override {
+    auto module = getModule();
+    module.walk([](Operation *op) {
+      for (Region &region : op->getRegions()) {
+        for (Block &block : region.getBlocks()) {
+          Block::iterator op_iterator = block.begin();
+          while (op_iterator != block.end()) {
+            Operation *current_op = &*op_iterator;
+            ++op_iterator;
+            if (!current_op->getAttrOfType<IntegerAttr>("_pre_order"))
+              continue;
+            LLVM_DEBUG(llvm::dbgs()
+                       << "Found range beginning at: " << *current_op << "\n");
+            SmallVector<Operation *, 8> ops_to_sort{current_op};
+            while (op_iterator != block.end() &&
+                   op_iterator->getAttrOfType<IntegerAttr>("_pre_order")) {
+              LLVM_DEBUG(llvm::dbgs() << "- Adding : " << *op_iterator << "\n");
+              ops_to_sort.push_back(&*op_iterator);
+              ++op_iterator;
+            }
+            llvm::stable_sort(ops_to_sort, [](Operation *lhs, Operation *rhs) {
+              int lhs_order =
+                  lhs->getAttrOfType<IntegerAttr>("_pre_order").getInt();
+              int rhs_order =
+                  rhs->getAttrOfType<IntegerAttr>("_pre_order").getInt();
+              return lhs_order < rhs_order;
+            });
+            for (Operation *sorted_op : ops_to_sort)
+              sorted_op->moveBefore(&*op_iterator);
+
+            if (failed(Block::sortTopologically(
+                    Block::iterator(ops_to_sort.front()), op_iterator))) {
+              current_op->emitOpError("Cycle encountered");
+              return WalkResult::interrupt();
+            }
+          }
+        }
+      }
+      return WalkResult::advance();
+    });
+  }
+};
+} // namespace
+
+static PassRegistration<TestTopologicalSort>
+    pass("test-block-toposort",
+         "Test topological sorting of ops in a block.");