diff --git a/mlir/include/mlir/Analysis/LoopInfo.h b/mlir/include/mlir/Analysis/LoopInfo.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Analysis/LoopInfo.h
@@ -0,0 +1,87 @@
+//===- LoopInfo.h - LoopInfo analysis for regions ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the LoopInfo analysis for MLIR. The LoopInfo is used to
+// identify natural loops and determine the loop depth of various nodes of a
+// CFG.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_LOOPINFO_H
+#define MLIR_ANALYSIS_LOOPINFO_H
+
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/RegionGraphTraits.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopInfoImpl.h"
+
+namespace llvm {
+
+// llvm::GraphTraits specializations that are required by LLVM's generic
+// LoopInfo. Note that the const_casts are required because MLIR has no constant
+// accessors on IR constructs.
+
+/// Forward traversal of `const mlir::Block *`: children are the successors.
+template <>
+struct GraphTraits<const mlir::Block *> {
+  using ChildIteratorType = mlir::Block::succ_iterator;
+  using Node = const mlir::Block;
+  using NodeRef = Node *;
+
+  static NodeRef getEntryNode(NodeRef node) { return node; }
+
+  static ChildIteratorType child_begin(NodeRef node) {
+    return const_cast<mlir::Block *>(node)->succ_begin();
+  }
+  static ChildIteratorType child_end(NodeRef node) {
+    return const_cast<mlir::Block *>(node)->succ_end();
+  }
+};
+
+/// Inverse traversal of `const mlir::Block *`: children are the predecessors.
+template <>
+struct GraphTraits<Inverse<const mlir::Block *>> {
+  using ChildIteratorType = mlir::Block::pred_iterator;
+  using Node = const mlir::Block;
+  using NodeRef = Node *;
+
+  static NodeRef getEntryNode(Inverse<NodeRef> inverseGraph) {
+    return inverseGraph.Graph;
+  }
+
+  static ChildIteratorType child_begin(NodeRef node) {
+    return const_cast<mlir::Block *>(node)->pred_begin();
+  }
+  static ChildIteratorType child_end(NodeRef node) {
+    return const_cast<mlir::Block *>(node)->pred_end();
+  }
+};
+} // namespace llvm
+
+namespace mlir {
+/// A natural loop in the CFG of a region, whose nodes are mlir::Block.
+/// The constructor is private; instances are allocated by the befriended
+/// llvm::LoopInfoBase / llvm::LoopBase code, e.g. while building a LoopInfo.
+class Loop : public llvm::LoopBase<mlir::Block, mlir::Loop> {
+private:
+  explicit Loop(mlir::Block *block);
+
+  friend class llvm::LoopBase<mlir::Block, mlir::Loop>;
+  friend class llvm::LoopInfoBase<mlir::Block, mlir::Loop>;
+};
+
+/// Instantiate a variant of LLVM LoopInfo that works on mlir::Block.
+class LoopInfo : public llvm::LoopInfoBase<mlir::Block, mlir::Loop> {
+public:
+  /// Computes the loops of a region from its forward (non-post) dominator
+  /// tree `domTree`.
+  explicit LoopInfo(const llvm::DominatorTreeBase<mlir::Block, false> &domTree);
+};
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_LOOPINFO_H
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -4,6 +4,7 @@
   DataLayoutAnalysis.cpp
   FlatLinearValueConstraints.cpp
   Liveness.cpp
+  LoopInfo.cpp
   SliceAnalysis.cpp
 
   AliasAnalysis/LocalAliasAnalysis.cpp
@@ -24,6 +25,7 @@
   DataLayoutAnalysis.cpp
   FlatLinearValueConstraints.cpp
   Liveness.cpp
+  LoopInfo.cpp
   SliceAnalysis.cpp
 
   AliasAnalysis/LocalAliasAnalysis.cpp
diff --git a/mlir/lib/Analysis/LoopInfo.cpp b/mlir/lib/Analysis/LoopInfo.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Analysis/LoopInfo.cpp
@@ -0,0 +1,19 @@
+//===- LoopInfo.cpp - LoopInfo analysis for regions -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/LoopInfo.h"
+
+using namespace mlir;
+
+// Only the LLVM loop classes befriended by Loop may call this constructor.
+Loop::Loop(Block *block) : llvm::LoopBase<Block, Loop>(block) {}
+
+// Discover the natural loops eagerly from the region's dominator tree.
+LoopInfo::LoopInfo(const llvm::DominatorTreeBase<Block, false> &domTree) {
+  analyze(domTree);
+}
diff --git a/mlir/test/Analysis/test-loopinfo.mlir b/mlir/test/Analysis/test-loopinfo.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Analysis/test-loopinfo.mlir
@@ -0,0 +1,115 @@
+// RUN: mlir-opt -pass-pipeline="builtin.module(any(test-loop-info))" --split-input-file %s 2>&1 | FileCheck %s
+
+// CHECK-LABEL: Testing : "no_loop_single_block"
+// CHECK: no loops
+func.func @no_loop_single_block() {
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "no_loop"
+// CHECK: no loops
+func.func @no_loop() {
+  cf.br ^bb1
+^bb1:
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "simple_loop"
+// CHECK-NEXT: Blocks : ^[[BB0:.*]], ^[[BB1:.*]], ^[[BB2:.*]], ^[[BB3:.*]]
+// CHECK: Loop at depth 1 containing:
+// CHECK-SAME: ^[[BB1]]<header><exiting>
+// CHECK-SAME: ^[[BB2]]<latch>
+func.func @simple_loop(%c: i1) {
+  cf.br ^bb1
+^bb1:
+  cf.cond_br %c, ^bb2, ^bb3
+^bb2:
+  cf.br ^bb1
+^bb3:
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "single_block_loop"
+// CHECK-NEXT: Blocks : ^[[BB0:.*]], ^[[BB1:.*]], ^[[BB2:.*]]
+// CHECK: Loop at depth 1 containing:
+// CHECK-SAME: ^[[BB1]]<header><latch><exiting>
+func.func @single_block_loop(%c: i1) {
+  cf.br ^bb1
+^bb1:
+  cf.cond_br %c, ^bb1, ^bb2
+^bb2:
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "nested_loop"
+// CHECK-NEXT: Blocks : ^[[BB0:.*]], ^[[BB1:.*]], ^[[BB2:.*]], ^[[BB3:.*]], ^[[BB4:.*]]
+// CHECK: Loop at depth 1
+// CHECK-SAME: ^[[BB1]]<header><exiting>
+// CHECK-SAME: ^[[BB2]]<latch>
+// CHECK-SAME: ^[[BB3]]
+// CHECK: Loop at depth 2
+// CHECK-SAME: ^[[BB2]]<header><exiting>
+// CHECK-SAME: ^[[BB3]]<latch>
+func.func @nested_loop(%c: i1) {
+  cf.br ^bb1
+^bb1:
+  cf.cond_br %c, ^bb2, ^bb4
+^bb2:
+  cf.cond_br %c, ^bb1, ^bb3
+^bb3:
+  cf.br ^bb2
+^bb4:
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "multi_latch"
+// CHECK-NEXT: Blocks : ^[[BB0:.*]], ^[[BB1:.*]], ^[[BB2:.*]], ^[[BB3:.*]], ^[[BB4:.*]]
+// CHECK: Loop at depth 1
+// CHECK-SAME: ^[[BB1]]<header><exiting>
+// CHECK-SAME: ^[[BB2]]<latch>
+// CHECK-SAME: ^[[BB3]]<latch>
+func.func @multi_latch(%c: i1) {
+  cf.br ^bb1
+^bb1:
+  cf.cond_br %c, ^bb4, ^bb2
+^bb2:
+  cf.cond_br %c, ^bb1, ^bb3
+^bb3:
+  cf.br ^bb1
+^bb4:
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Testing : "multiple_loops"
+// CHECK-NEXT: Blocks : ^[[BB0:.*]], ^[[BB1:.*]], ^[[BB2:.*]], ^[[BB3:.*]], ^[[BB4:.*]], ^[[BB5:.*]]
+// CHECK: Loop at depth 1
+// CHECK-SAME: ^[[BB3]]<header><exiting>
+// CHECK-SAME: ^[[BB4]]<latch>
+// CHECK: Loop at depth 1
+// CHECK-SAME: ^[[BB1]]<header>
+// CHECK-SAME: ^[[BB2]]<latch><exiting>
+func.func @multiple_loops(%c: i1) {
+  cf.br ^bb1
+^bb1:
+  cf.br ^bb2
+^bb2:
+  cf.cond_br %c, ^bb3, ^bb1
+^bb3:
+  cf.cond_br %c, ^bb5, ^bb4
+^bb4:
+  cf.br ^bb3
+^bb5:
+  return
+}
diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt
--- a/mlir/test/lib/Analysis/CMakeLists.txt
+++ b/mlir/test/lib/Analysis/CMakeLists.txt
@@ -4,6 +4,7 @@
   TestCallGraph.cpp
   TestDataFlowFramework.cpp
   TestLiveness.cpp
+  TestLoopInfo.cpp
   TestMatchReduction.cpp
   TestMemRefBoundCheck.cpp
   TestMemRefDependenceCheck.cpp
diff --git a/mlir/test/lib/Analysis/TestLoopInfo.cpp b/mlir/test/lib/Analysis/TestLoopInfo.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Analysis/TestLoopInfo.cpp
@@ -0,0 +1,81 @@
+//===- TestLoopInfo.cpp - Test loop info analysis -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements logic for testing the LoopInfo analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/LoopInfo.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// A testing pass that applies the LoopInfo analysis on a region and prints the
+/// information it collected to llvm::errs().
+struct TestLoopInfo
+    : public PassWrapper<TestLoopInfo, InterfacePass<FunctionOpInterface>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLoopInfo)
+
+  StringRef getArgument() const final { return "test-loop-info"; }
+  StringRef getDescription() const final {
+    return "Test the loop info analysis.";
+  }
+
+  void runOnOperation() override;
+};
+} // namespace
+
+/// Prints the function name, the identifiers of all blocks in the function
+/// body, and then either "no loops" or the loops found by LoopInfo, so that
+/// test-loopinfo.mlir can check the result with FileCheck.
+void TestLoopInfo::runOnOperation() {
+  auto func = getOperation();
+  DominanceInfo &domInfo = getAnalysis<DominanceInfo>();
+  Region &region = func.getFunctionBody();
+  // Function declarations have an empty body: there is nothing to analyze.
+  if (region.empty())
+    return;
+
+  // Prints the label of the test.
+  llvm::errs() << "Testing : " << func.getNameAttr() << "\n";
+
+  // Print all the block identifiers first such that the tests can match them.
+  // The entry block is printed up front, so the loop skips it to avoid listing
+  // it twice.
+  llvm::errs() << "Blocks : ";
+  region.front().printAsOperand(llvm::errs());
+  for (Block &block : llvm::drop_begin(region.getBlocks())) {
+    llvm::errs() << ", ";
+    block.printAsOperand(llvm::errs());
+  }
+  llvm::errs() << "\n";
+
+  // A single block cannot form a loop: in MLIR, the entry block of a region
+  // never has predecessors.
+  if (region.hasOneBlock()) {
+    llvm::errs() << "no loops\n";
+    return;
+  }
+
+  llvm::DominatorTreeBase<mlir::Block, false> &domTree =
+      domInfo.getDomTree(&region);
+  mlir::LoopInfo loopInfo(domTree);
+
+  if (loopInfo.getTopLevelLoops().empty())
+    llvm::errs() << "no loops\n";
+  else
+    loopInfo.print(llvm::errs());
+}
+
+namespace mlir {
+namespace test {
+void registerTestLoopInfoPass() { PassRegistration<TestLoopInfo>(); }
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -99,6 +99,7 @@
 void registerTestLinalgTransforms();
 void registerTestLivenessPass();
 void registerTestLoopFusion();
+void registerTestLoopInfoPass();
 void registerTestLoopMappingPass();
 void registerTestLoopUnrollingPass();
 void registerTestLowerToLLVM();
@@ -211,6 +212,7 @@
   mlir::test::registerTestLinalgTransforms();
   mlir::test::registerTestLivenessPass();
   mlir::test::registerTestLoopFusion();
+  mlir::test::registerTestLoopInfoPass();
   mlir::test::registerTestLoopMappingPass();
   mlir::test::registerTestLoopUnrollingPass();
   mlir::test::registerTestLowerToLLVM();