diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -510,10 +510,40 @@
   ///     });
   template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
             typename RetT = detail::walkResultType<FnT>>
-  RetT walk(FnT &&callback) {
+  typename std::enable_if<
+      llvm::function_traits<std::decay_t<FnT>>::num_args == 1, RetT>::type
+  walk(FnT &&callback) {
     return detail::walk<Order>(this, std::forward<FnT>(callback));
   }
 
+  /// Generic walker with a stage aware callback. Walk the operation by calling
+  /// the callback for each nested operation (including this one) N+1 times,
+  /// where N is the number of regions attached to that operation.
+  ///
+  /// The callback method can take any of the following forms:
+  ///   void(Operation *, const WalkStage &) : Walk all operations opaquely
+  ///     * op->walk([](Operation *nestedOp, const WalkStage &stage) { ...});
+  ///   void(OpT, const WalkStage &) : Walk all operations of the given derived
+  ///                                  type.
+  ///     * op->walk([](ReturnOp returnOp, const WalkStage &stage) { ...});
+  ///   WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations,
+  ///          but allow for interruption/skipping.
+  ///     * op->walk([](... op, const WalkStage &stage) {
+  ///         // Skip the walk of this op based on some invariant.
+  ///         if (some_invariant)
+  ///           return WalkResult::skip();
+  ///         // Interrupt, i.e. cancel, the walk based on some invariant.
+  ///         if (another_invariant)
+  ///           return WalkResult::interrupt();
+  ///         return WalkResult::advance();
+  ///       });
+  template <typename FnT, typename RetT = detail::walkResultType<FnT>>
+  typename std::enable_if<
+      llvm::function_traits<std::decay_t<FnT>>::num_args == 2, RetT>::type
+  walk(FnT &&callback) {
+    return detail::walk(this, std::forward<FnT>(callback));
+  }
+
   //===--------------------------------------------------------------------===//
   // Uses
   //===--------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/Visitors.h b/mlir/include/mlir/IR/Visitors.h
--- a/mlir/include/mlir/IR/Visitors.h
+++ b/mlir/include/mlir/IR/Visitors.h
@@ -61,13 +61,49 @@
 /// Traversal order for region, block and operation walk utilities.
 enum class WalkOrder { PreOrder, PostOrder };
 
+/// A utility class to encode the current walk stage for "generic" walkers.
+/// When walking an operation, we can either choose a Pre/Post order walker
+/// which invokes the callback on an operation before/after all its attached
+/// regions have been visited, or choose a "generic" walker where the callback
+/// is invoked on the operation N+1 times where N is the number of regions
+/// attached to that operation. The `WalkStage` class below encodes the current
+/// stage of the walk, i.e., which regions have already been visited, and the
+/// callback accepts an additional argument for the current stage. Such
+/// generic walkers that accept stage-aware callbacks are only applicable when
+/// the callback operates on an operation (i.e., not applicable for callbacks
+/// on Blocks or Regions).
+class WalkStage {
+public:
+  explicit WalkStage(Operation *op);
+
+  /// Return true if parent operation is being visited before all regions.
+  bool isBeforeAllRegions() const { return nextRegion == 0; }
+  /// Returns true if parent operation is being visited just before visiting
+  /// region number `region`.
+  bool isBeforeRegion(int region) const { return nextRegion == region; }
+  /// Returns true if parent operation is being visited just after visiting
+  /// region number `region`.
+  bool isAfterRegion(int region) const { return nextRegion == region + 1; }
+  /// Return true if parent operation is being visited after all regions.
+  bool isAfterAllRegions() const { return nextRegion == numRegions; }
+  /// Advance the walk stage.
+  void advance() { nextRegion++; }
+  /// Returns the next region that will be visited.
+  int getNextRegion() const { return nextRegion; }
+
+private:
+  const int numRegions;
+  int nextRegion;
+};
+
 namespace detail {
 /// Helper templates to deduce the first argument of a callback parameter.
-template <typename Ret, typename Arg> Arg first_argument_type(Ret (*)(Arg));
-template <typename Ret, typename F, typename Arg>
-Arg first_argument_type(Ret (F::*)(Arg));
-template <typename Ret, typename F, typename Arg>
-Arg first_argument_type(Ret (F::*)(Arg) const);
+template <typename Ret, typename Arg, typename... Rest>
+Arg first_argument_type(Ret (*)(Arg, Rest...));
+template <typename Ret, typename F, typename Arg, typename... Rest>
+Arg first_argument_type(Ret (F::*)(Arg, Rest...));
+template <typename Ret, typename F, typename Arg, typename... Rest>
+Arg first_argument_type(Ret (F::*)(Arg, Rest...) const);
 template <typename F>
 decltype(first_argument_type(&F::operator())) first_argument_type(F);
 
@@ -197,6 +233,87 @@
   return detail::walk(op, function_ref<RetT(Operation *)>(wrapperFn), Order);
 }
 
+/// Generic walkers with stage aware callbacks.
+
+/// Walk all the operations nested under (and including) the given operation,
+/// with the callback being invoked on each operation N+1 times, where N is the
+/// number of regions attached to the operation. The `stage` input to the
+/// callback indicates the current walk stage. This method is invoked for void
+/// returning callbacks.
+void walk(Operation *op,
+          function_ref<void(Operation *, const WalkStage &)> callback);
+
+/// Walk all the operations nested under (and including) the given operation,
+/// with the callback being invoked on each operation N+1 times, where N is the
+/// number of regions attached to the operation. The `stage` input to the
+/// callback indicates the current walk stage. This method is invoked for
+/// skippable or interruptible callbacks.
+WalkResult
+walk(Operation *op,
+     function_ref<WalkResult(Operation *, const WalkStage &)> callback);
+
+/// Walk all of the operations nested under and including the given operation.
+/// This method is selected for stage-aware callbacks that operate on
+/// Operation*.
+///
+/// Example:
+///   op->walk([](Operation *op, const WalkStage &stage) { ... });
+template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+          typename RetT = decltype(std::declval<FuncTy>()(
+              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
+typename std::enable_if<std::is_same<ArgT, Operation *>::value, RetT>::type
+walk(Operation *op, FuncTy &&callback) {
+  return detail::walk(op,
+                      function_ref<RetT(ArgT, const WalkStage &)>(callback));
+}
+
+/// Walk all of the operations of type 'ArgT' nested under and including the
+/// given operation. This method is selected for void returning callbacks that
+/// operate on a specific derived operation type.
+///
+/// Example:
+///   op->walk([](ReturnOp op, const WalkStage &stage) { ... });
+template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+          typename RetT = decltype(std::declval<FuncTy>()(
+              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
+typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
+                            std::is_same<RetT, void>::value,
+                        RetT>::type
+walk(Operation *op, FuncTy &&callback) {
+  auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
+    if (auto derivedOp = dyn_cast<ArgT>(op))
+      callback(derivedOp, stage);
+  };
+  return detail::walk(
+      op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
+}
+
+/// Walk all of the operations of type 'ArgT' nested under and including the
+/// given operation. This method is selected for WalkResult returning
+/// interruptible callbacks that operate on a specific derived operation type.
+///
+/// Example:
+///   op->walk([](ReturnOp op, const WalkStage &stage) {
+///     if (some_invariant)
+///       return WalkResult::interrupt();
+///     return WalkResult::advance();
+///   });
+template <typename FuncTy, typename ArgT = detail::first_argument<FuncTy>,
+          typename RetT = decltype(std::declval<FuncTy>()(
+              std::declval<ArgT>(), std::declval<const WalkStage &>()))>
+typename std::enable_if<!std::is_same<ArgT, Operation *>::value &&
+                            std::is_same<RetT, WalkResult>::value,
+                        RetT>::type
+walk(Operation *op, FuncTy &&callback) {
+  auto wrapperFn = [&](Operation *op, const WalkStage &stage) {
+    if (auto derivedOp = dyn_cast<ArgT>(op))
+      return callback(derivedOp, stage);
+    return WalkResult::advance();
+  };
+  return detail::walk(
+      op, function_ref<RetT(Operation *, const WalkStage &)>(wrapperFn));
+}
+
 /// Utility to provide the return type of a templated walk method.
 template <typename FnT>
 using walkResultType = decltype(walk(nullptr, std::declval<FnT>()));
diff --git a/mlir/lib/IR/Visitors.cpp b/mlir/lib/IR/Visitors.cpp
--- a/mlir/lib/IR/Visitors.cpp
+++ b/mlir/lib/IR/Visitors.cpp
@@ -11,6 +11,9 @@
 
 using namespace mlir;
 
+WalkStage::WalkStage(Operation *op)
+    : numRegions(op->getNumRegions()), nextRegion(0) {}
+
 /// Walk all of the regions/blocks/operations nested under and including the
 /// given operation. Regions, blocks and operations at the same nesting level
 /// are visited in lexicographical order. The walk order for enclosing regions,
@@ -67,6 +70,26 @@
     callback(op);
 }
 
+void detail::walk(Operation *op,
+                  function_ref<void(Operation *, const WalkStage &)> callback) {
+  WalkStage stage(op);
+
+  for (Region &region : op->getRegions()) {
+    // Invoke callback on the parent op before visiting each child region.
+    callback(op, stage);
+    stage.advance();
+
+    for (Block &block : region) {
+      // Early increment here in the case where the operation is erased.
+      for (Operation &nestedOp : llvm::make_early_inc_range(block))
+        walk(&nestedOp, callback);
+    }
+  }
+
+  // Invoke callback after all regions have been visited.
+  callback(op, stage);
+}
+
 /// Walk all of the regions/blocks/operations nested under and including the
 /// given operation. These functions walk operations until an interrupt result
 /// is returned by the callback. Walks on regions, blocks and operations may
@@ -157,3 +180,29 @@
     return callback(op);
   return WalkResult::advance();
 }
+
+WalkResult detail::walk(
+    Operation *op,
+    function_ref<WalkResult(Operation *, const WalkStage &)> callback) {
+  WalkStage stage(op);
+
+  for (Region &region : op->getRegions()) {
+    // Invoke callback on the parent op before visiting each child region.
+    WalkResult result = callback(op, stage);
+
+    if (result.wasSkipped())
+      return WalkResult::advance();
+    if (result.wasInterrupted())
+      return WalkResult::interrupt();
+
+    stage.advance();
+
+    for (Block &block : region) {
+      // Early increment here in the case where the operation is erased.
+      for (Operation &nestedOp : llvm::make_early_inc_range(block))
+        if (walk(&nestedOp, callback).wasInterrupted())
+          return WalkResult::interrupt();
+    }
+  }
+  return callback(op, stage);
+}
diff --git a/mlir/test/IR/generic-visitors-interrupt.mlir b/mlir/test/IR/generic-visitors-interrupt.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/generic-visitors-interrupt.mlir
@@ -0,0 +1,157 @@
+// RUN: mlir-opt -test-generic-ir-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// Walk is interrupted before visiting "foo"
+func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() {interrupt_before_all = true} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 walk was interrupted
+
+// -----
+
+// Walk is interrupted after visiting all regions of "foo" (a single region holding "bar")
+func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() ({ "bar"() : ()-> () }) {interrupt_after_all = true} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'bar' before all regions
+// CHECK: step 4 walk was interrupted
+
+// -----
+
+// Walk is interrupted after visiting "foo"'s 1st region.
+func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() ({
+    "bar0"() : () -> ()
+  }, {
+    "bar1"() : () -> ()
+  }) {interrupt_after_region = 0} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'bar0' before all regions
+// CHECK: step 4 walk was interrupted
+
+
+// -----
+
+// Test static filtering with an interrupt after all regions.
+func @main() {
+  "foo"() : () -> ()
+  "test.two_region_op"()(
+    {"work"() : () -> ()},
+    {"work"() : () -> ()}
+  ) {interrupt_after_all = true} : () -> ()
+  return
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'test.two_region_op' before all regions
+// CHECK: step 4 op 'work' before all regions
+// CHECK: step 5 op 'test.two_region_op' before region #1
+// CHECK: step 6 op 'work' before all regions
+// CHECK: step 7 walk was interrupted
+// CHECK: step 8 op 'test.two_region_op' before all regions
+// CHECK: step 9 op 'test.two_region_op' before region #1
+// CHECK: step 10 walk was interrupted
+
+// -----
+
+// Test static filtering with an interrupt after the first region.
+func @main() {
+  "foo"() : () -> ()
+  "test.two_region_op"()(
+    {"work"() : () -> ()},
+    {"work"() : () -> ()}
+  ) {interrupt_after_region = 0} : () -> ()
+  return
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'test.two_region_op' before all regions
+// CHECK: step 4 op 'work' before all regions
+// CHECK: step 5 walk was interrupted
+// CHECK: step 6 op 'test.two_region_op' before all regions
+// CHECK: step 7 walk was interrupted
+
+// -----
+// Test skipping.
+
+// Walk is skipped before visiting "foo".
+func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() ({
+    "bar0"() : () -> ()
+  }, {
+    "bar1"() : () -> ()
+  }) {skip_before_all = true} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'arith.addf' before all regions
+// CHECK: step 3 op 'std.return' before all regions
+// CHECK: step 4 op 'builtin.func' after all regions
+// CHECK: step 5 op 'builtin.module' after all regions
+
+// -----
+// Walk is skipped after visiting all regions of "foo".
+func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() ({
+    "bar0"() : () -> ()
+  }, {
+    "bar1"() : () -> ()
+  }) {skip_after_all = true} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'bar0' before all regions
+// CHECK: step 4 op 'foo' before region #1
+// CHECK: step 5 op 'bar1' before all regions
+// CHECK: step 6 op 'arith.addf' before all regions
+// CHECK: step 7 op 'std.return' before all regions
+// CHECK: step 8 op 'builtin.func' after all regions
+// CHECK: step 9 op 'builtin.module' after all regions
+
+// -----
+// Walk is skipped after visiting first region of "foo".
+func @main(%arg0: f32) -> f32 {
+  %v1 = "foo"() ({
+    "bar0"() : () -> ()
+  }, {
+    "bar1"() : () -> ()
+  }) {skip_after_region = 0} : () -> f32
+  %v2 = arith.addf %v1, %arg0 : f32
+  return %v2 : f32
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'foo' before all regions
+// CHECK: step 3 op 'bar0' before all regions
+// CHECK: step 4 op 'arith.addf' before all regions
+// CHECK: step 5 op 'std.return' before all regions
+// CHECK: step 6 op 'builtin.func' after all regions
+// CHECK: step 7 op 'builtin.module' after all regions
diff --git a/mlir/test/IR/generic-visitors.mlir b/mlir/test/IR/generic-visitors.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/generic-visitors.mlir
@@ -0,0 +1,63 @@
+// RUN: mlir-opt -test-generic-ir-visitors -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -test-generic-ir-visitors-interrupt -allow-unregistered-dialect -split-input-file %s | FileCheck %s
+
+// Verify the different configurations of generic IR visitors.
+
+func @structured_cfg() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  scf.for %i = %c1 to %c10 step %c1 {
+    %cond = "use0"(%i) : (index) -> (i1)
+    scf.if %cond {
+      "use1"(%i) : (index) -> ()
+    } else {
+      "use2"(%i) : (index) -> ()
+    }
+    "use3"(%i) : (index) -> ()
+  }
+  return
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 1 op 'builtin.func' before all regions
+// CHECK: step 2 op 'arith.constant' before all regions
+// CHECK: step 3 op 'arith.constant' before all regions
+// CHECK: step 4 op 'arith.constant' before all regions
+// CHECK: step 5 op 'scf.for' before all regions
+// CHECK: step 6 op 'use0' before all regions
+// CHECK: step 7 op 'scf.if' before all regions
+// CHECK: step 8 op 'use1' before all regions
+// CHECK: step 9 op 'scf.yield' before all regions
+// CHECK: step 10 op 'scf.if' before region #1
+// CHECK: step 11 op 'use2' before all regions
+// CHECK: step 12 op 'scf.yield' before all regions
+// CHECK: step 13 op 'scf.if' after all regions
+// CHECK: step 14 op 'use3' before all regions
+// CHECK: step 15 op 'scf.yield' before all regions
+// CHECK: step 16 op 'scf.for' after all regions
+// CHECK: step 17 op 'std.return' before all regions
+// CHECK: step 18 op 'builtin.func' after all regions
+// CHECK: step 19 op 'builtin.module' after all regions
+
+// -----
+// Test the specific operation type visitor.
+
+func @correct_number_of_regions() {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 10 : index
+  scf.for %i = %c1 to %c10 step %c1 {
+    "test.two_region_op"()(
+      {"work"() : () -> ()},
+      {"work"() : () -> ()}
+    ) : () -> ()
+  }
+  return
+}
+
+// CHECK: step 0 op 'builtin.module' before all regions
+// CHECK: step 15 op 'builtin.module' after all regions
+// CHECK: step 16 op 'test.two_region_op' before all regions
+// CHECK: step 17 op 'test.two_region_op' before region #1
+// CHECK: step 18 op 'test.two_region_op' after all regions
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -15,6 +15,7 @@
   TestSymbolUses.cpp
   TestTypes.cpp
   TestVisitors.cpp
+  TestVisitorsGeneric.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/IR/TestVisitorsGeneric.cpp b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/IR/TestVisitorsGeneric.cpp
@@ -0,0 +1,123 @@
+//===- TestVisitorsGeneric.cpp - Pass to test the generic IR visitors -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+static std::string getStageDescription(const WalkStage &stage) {
+  if (stage.isBeforeAllRegions())
+    return "before all regions";
+  if (stage.isAfterAllRegions())
+    return "after all regions";
+  return "before region #" + std::to_string(stage.getNextRegion());
+}
+
+namespace {
+/// This pass exercises the generic visitor with void callbacks and prints the
+/// order and stage in which operations are visited.
+class TestGenericIRVisitorPass
+    : public PassWrapper<TestGenericIRVisitorPass, OperationPass<>> {
+public:
+  StringRef getArgument() const final { return "test-generic-ir-visitors"; }
+  StringRef getDescription() const final { return "Test generic IR visitors."; }
+  void runOnOperation() override {
+    Operation *outerOp = getOperation();
+    int stepNo = 0;
+    outerOp->walk([&](Operation *op, const WalkStage &stage) {
+      llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
+                   << getStageDescription(stage) << "\n";
+    });
+
+    // Exercise static inference of operation type.
+    outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
+      llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
+                   << getStageDescription(stage) << "\n";
+    });
+  }
+};
+
+/// This pass exercises the generic visitor with non-void callbacks and prints
+/// the order and stage in which operations are visited. It will interrupt the
+/// walk based on attributes present in the IR.
+class TestGenericIRVisitorInterruptPass
+    : public PassWrapper<TestGenericIRVisitorInterruptPass, OperationPass<>> {
+public:
+  StringRef getArgument() const final {
+    return "test-generic-ir-visitors-interrupt";
+  }
+  StringRef getDescription() const final {
+    return "Test generic IR visitors with interrupts.";
+  }
+  void runOnOperation() override {
+    Operation *outerOp = getOperation();
+    int stepNo = 0;
+
+    auto walker = [&](Operation *op, const WalkStage &stage) {
+      if (auto interruptBeforeAll =
+              op->getAttrOfType<BoolAttr>("interrupt_before_all"))
+        if (interruptBeforeAll.getValue() && stage.isBeforeAllRegions())
+          return WalkResult::interrupt();
+
+      if (auto interruptAfterAll =
+              op->getAttrOfType<BoolAttr>("interrupt_after_all"))
+        if (interruptAfterAll.getValue() && stage.isAfterAllRegions())
+          return WalkResult::interrupt();
+
+      if (auto interruptAfterRegion =
+              op->getAttrOfType<IntegerAttr>("interrupt_after_region"))
+        if (stage.isAfterRegion(
+                static_cast<int>(interruptAfterRegion.getInt())))
+          return WalkResult::interrupt();
+
+      if (auto skipBeforeAll = op->getAttrOfType<BoolAttr>("skip_before_all"))
+        if (skipBeforeAll.getValue() && stage.isBeforeAllRegions())
+          return WalkResult::skip();
+
+      if (auto skipAfterAll = op->getAttrOfType<BoolAttr>("skip_after_all"))
+        if (skipAfterAll.getValue() && stage.isAfterAllRegions())
+          return WalkResult::skip();
+
+      if (auto skipAfterRegion =
+              op->getAttrOfType<IntegerAttr>("skip_after_region"))
+        if (stage.isAfterRegion(static_cast<int>(skipAfterRegion.getInt())))
+          return WalkResult::skip();
+
+      llvm::outs() << "step " << stepNo++ << " op '" << op->getName() << "' "
+                   << getStageDescription(stage) << "\n";
+      return WalkResult::advance();
+    };
+
+    // Interrupt the walk based on attributes on the operation.
+    auto result = outerOp->walk(walker);
+
+    if (result.wasInterrupted())
+      llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
+
+    // Exercise static inference of operation type.
+    result = outerOp->walk([&](test::TwoRegionOp op, const WalkStage &stage) {
+      return walker(op, stage);
+    });
+
+    if (result.wasInterrupted())
+      llvm::outs() << "step " << stepNo++ << " walk was interrupted\n";
+  }
+};
+
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestGenericIRVisitorsPass() {
+  PassRegistration<TestGenericIRVisitorPass>();
+  PassRegistration<TestGenericIRVisitorInterruptPass>();
+}
+
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -78,6 +78,7 @@
 void registerTestComposeSubView();
 void registerTestGpuParallelLoopMappingPass();
 void registerTestIRVisitorsPass();
+void registerTestGenericIRVisitorsPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
 void registerTestLinalgControlFuseByExpansion();
@@ -171,6 +172,7 @@
   mlir::test::registerTestComposeSubView();
   mlir::test::registerTestGpuParallelLoopMappingPass();
   mlir::test::registerTestIRVisitorsPass();
+  mlir::test::registerTestGenericIRVisitorsPass();
   mlir::test::registerTestInterfaces();
   mlir::test::registerTestLinalgCodegenStrategy();
   mlir::test::registerTestLinalgControlFuseByExpansion();