diff --git a/mlir/lib/Analysis/SliceAnalysis.cpp b/mlir/lib/Analysis/SliceAnalysis.cpp
--- a/mlir/lib/Analysis/SliceAnalysis.cpp
+++ b/mlir/lib/Analysis/SliceAnalysis.cpp
@@ -168,20 +168,48 @@
 };
 } // namespace
 
-static void DFSPostorder(Operation *current, DFSState *state) {
-  for (Value result : current->getResults()) {
-    for (Operation *op : result.getUsers())
-      DFSPostorder(op, state);
-  }
-  bool inserted;
-  using IterTy = decltype(state->seen.begin());
-  IterTy iter;
-  std::tie(iter, inserted) = state->seen.insert(current);
-  if (inserted) {
-    if (state->toSort.count(current) > 0) {
-      state->topologicalCounts.push_back(current);
-    }
-  }
+/// Appends to `state->topologicalCounts`, in post-order, every operation of
+/// `state->toSort` that is reachable from `root` and not yet in `state->seen`.
+/// The successors of an operation are the users of its results followed by
+/// the operations nested in its regions, so (for acyclic input) an operation
+/// is appended only after all of its transitive users and nested operations.
+///
+/// The walk is iterative so that deep use-def chains cannot overflow the
+/// native stack. Each operation is expanded at most once, because
+/// `state->seen` is updated when the operation is first visited: a walk that
+/// re-expands shared users on every path takes time and memory exponential in
+/// the number of "diamonds" in the use-def graph, and never terminates on
+/// use-def cycles (which graph regions permit).
+static void DFSPostorder(Operation *root, DFSState *state) {
+  // Each entry pairs an operation with whether its successors have already
+  // been scheduled; an operation is appended when it is popped a second time.
+  SmallVector<std::pair<Operation *, bool>, 16> worklist;
+  worklist.emplace_back(root, /*expanded=*/false);
+  SmallVector<Operation *, 16> successors;
+  while (!worklist.empty()) {
+    std::pair<Operation *, bool> entry = worklist.pop_back_val();
+    Operation *current = entry.first;
+    if (entry.second) {
+      if (state->toSort.count(current) > 0)
+        state->topologicalCounts.push_back(current);
+      continue;
+    }
+    // Skip operations already expanded by this walk or by a previous root.
+    if (!state->seen.insert(current).second)
+      continue;
+    worklist.emplace_back(current, /*expanded=*/true);
+
+    successors.clear();
+    for (Value result : current->getResults())
+      for (Operation *user : result.getUsers())
+        successors.push_back(user);
+    for (Region &region : current->getRegions())
+      for (Operation &op : region.getOps())
+        successors.push_back(&op);
+    // Push in reverse so successors are visited in the order collected above.
+    for (Operation *succ : llvm::reverse(successors))
+      worklist.emplace_back(succ, /*expanded=*/false);
+  }
 }
 
 SetVector<Operation *>
diff --git a/mlir/test/Analysis/test-topoligical-sort.mlir b/mlir/test/Analysis/test-topoligical-sort.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Analysis/test-topoligical-sort.mlir
@@ -0,0 +1,21 @@
+// RUN: mlir-opt %s -test-print-topological-sort 2>&1 | FileCheck %s
+
+// CHECK-LABEL: Testing : region
+// CHECK: arith.addi {{.*}} : index
+// CHECK-NEXT: scf.for
+// CHECK: } {__test_sort_original_idx__ = 2 : i64}
+// CHECK-NEXT: arith.addi {{.*}} : i32
+// CHECK-NEXT: arith.subi {{.*}} : i32
+func @region(
+  %arg0 : index, %arg1 : index, %arg2 : index, %arg3 : index,
+  %arg4 : i32, %arg5 : i32, %arg6 : i32,
+  %buffer : memref<i32>) {
+  %0 = arith.addi %arg4, %arg5 {__test_sort_original_idx__ = 0} : i32
+  %idx = arith.addi %arg0, %arg1 {__test_sort_original_idx__ = 3} : index
+  scf.for %arg7 = %idx to %arg2 step %arg3 {
+    %2 = arith.addi %0, %arg5 : i32
+    %3 = arith.subi %2, %arg6 {__test_sort_original_idx__ = 1} : i32
+    memref.store %3, %buffer[] : memref<i32>
+  } {__test_sort_original_idx__ = 2}
+  return
+}
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -8,6 +8,7 @@
   TestMemRefDependenceCheck.cpp
   TestMemRefStrideCalculation.cpp
   TestNumberOfExecutions.cpp
+  TestSlice.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Analysis/TestSlice.cpp
@@ -0,0 +1,50 @@
+//===------------- TestSlice.cpp - Test slice related analysis ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+static const StringLiteral kOrderMarker = "__test_sort_original_idx__";
+
+namespace {
+
+struct TestTopologicalSortPass
+    : public PassWrapper<TestTopologicalSortPass, FunctionPass> {
+  StringRef getArgument() const final { return "test-print-topological-sort"; }
+  StringRef getDescription() const final {
+    return "Print operations in topological order";
+  }
+  void runOnFunction() override {
+    std::map<int64_t, Operation *> ops;
+    getFunction().walk([&ops](Operation *op) {
+      if (auto originalOrderAttr = op->getAttrOfType<IntegerAttr>(kOrderMarker))
+        ops[originalOrderAttr.getInt()] = op;
+    });
+    SetVector<Operation *> sortedOp;
+    for (const auto &entry : ops)
+      sortedOp.insert(entry.second);
+    sortedOp = topologicalSort(sortedOp);
+    llvm::errs() << "Testing : " << getFunction().getName() << "\n";
+    for (Operation *op : sortedOp) {
+      op->print(llvm::errs());
+      llvm::errs() << "\n";
+    }
+  }
+};
+
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestSliceAnalysisPass() {
+  PassRegistration<TestTopologicalSortPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -106,6 +106,7 @@
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
+void registerTestSliceAnalysisPass();
 void registerTestVectorConversions();
 } // namespace test
 } // namespace mlir
@@ -195,6 +196,7 @@
   mlir::test::registerTestPDLByteCodePass();
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
+  mlir::test::registerTestSliceAnalysisPass();
   mlir::test::registerTestVectorConversions();
 }
 #endif