diff --git a/mlir/include/mlir/Analysis/NumberOfExecutions.h b/mlir/include/mlir/Analysis/NumberOfExecutions.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Analysis/NumberOfExecutions.h
@@ -0,0 +1,86 @@
+//===- NumberOfExecutions.h - Number of executions analysis -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains an analysis for computing how many times a block within a
+// region is executed. The analysis iterates over all associated regions that
+// are attached to the given top-level operation.
+//
+// It is possible to query number of executions information on block level.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_NUMBER_OF_EXECUTIONS_H
+#define MLIR_ANALYSIS_NUMBER_OF_EXECUTIONS_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+
+namespace mlir {
+
+class Block;
+class Operation;
+
+/// Number of executions information for a single block.
+class BlockNumberOfExecutionsInfo {
+public:
+  /// Returns the block this information describes.
+  Block *getBlock() const { return block; }
+
+  /// Returns how many times the block is executed, or None if that number is
+  /// not known.
+  Optional<int64_t> getNumberOfExecutions() const;
+
+  /// Returns information for a block that is executed exactly once.
+  static BlockNumberOfExecutionsInfo once(Block *block);
+  /// Returns information for a block with an unknown number of executions.
+  static BlockNumberOfExecutionsInfo unknown(Block *block);
+
+private:
+  BlockNumberOfExecutionsInfo(Block *block,
+                              Optional<int64_t> numberOfExecutions);
+
+  /// The block this information describes.
+  Block *block;
+
+  /// The number of executions of `block`, or None if it is unknown.
+  Optional<int64_t> numberOfExecutions;
+};
+
+/// Represents an analysis for computing how many times a block within a region
+/// is executed. The analysis iterates over all associated regions that are
+/// attached to the given top-level operation.
+class NumberOfExecutions {
+public:
+  /// Creates a new NumberOfExecutions analysis that computes how many times a
+  /// block within a region is executed for all associated regions.
+  NumberOfExecutions(Operation *op);
+
+  /// Returns how many times `block` is executed, or None if that number is
+  /// unknown or `block` is not nested under the analyzed operation.
+  Optional<int64_t> getNumberOfExecutions(Block *block) const {
+    auto it = blockNumbersOfExecution.find(block);
+    if (it == blockNumbersOfExecution.end())
+      return None;
+    return it->second.getNumberOfExecutions();
+  }
+
+  /// Dumps the number of executions information to the given stream.
+  void print(raw_ostream &os) const;
+
+private:
+  /// The operation this analysis was constructed from.
+  Operation *operation;
+
+  /// A mapping from blocks to number of executions information.
+  DenseMap<Block *, BlockNumberOfExecutionsInfo> blockNumbersOfExecution;
+};
+
+} // end namespace mlir
+
+#endif // MLIR_ANALYSIS_NUMBER_OF_EXECUTIONS_H
diff --git a/mlir/include/mlir/Dialect/Async/IR/Async.h b/mlir/include/mlir/Dialect/Async/IR/Async.h
--- a/mlir/include/mlir/Dialect/Async/IR/Async.h
+++ b/mlir/include/mlir/Dialect/Async/IR/Async.h
@@ -19,6 +19,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 namespace mlir {
diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
--- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
+++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
@@ -14,6 +14,7 @@
 #define ASYNC_OPS
 
 include "mlir/Dialect/Async/IR/AsyncBase.td"
+include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
 //===----------------------------------------------------------------------===//
@@ -26,6 +27,7 @@
 
 def Async_ExecuteOp :
   Async_Op<"execute", [SingleBlockImplicitTerminator<"YieldOp">,
+                       DeclareOpInterfaceMethods<RegionBranchOpInterface>,
                        AttrSizedOperandSegments]> {
   let summary = "Asynchronous execute operation";
   let description = [{
@@ -78,6 +80,12 @@
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
   let verifier = [{ return ::verify(*this); }];
+
+  let extraClassDeclaration = [{
+    // RegionBranchOpInterface declarations.
+    void getNumRegionInvocations(ArrayRef<Attribute> operands,
+                                 SmallVectorImpl<int64_t> &countPerRegion);
+  }];
 }
 
 def Async_YieldOp :
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -130,6 +130,26 @@
       "void", "getSuccessorRegions",
       (ins "Optional<unsigned>":$index, "ArrayRef<Attribute>":$operands,
            "SmallVectorImpl<RegionSuccessor> &":$regions)
+    >,
+    InterfaceMethod<[{
+        Returns the number of times this operation will invoke the attached
+        regions. If the number of invocations of a region is not known
+        statically, `-1` is returned for that region.
+
+        `operands` is a set of optional attributes that either correspond to a
+        constant value for each operand of this operation, or null if that
+        operand is not a constant.
+      }],
+      "void", "getNumRegionInvocations",
+      (ins "ArrayRef<Attribute>":$operands,
+           "SmallVectorImpl<int64_t> &":$countPerRegion), [{}],
+      /*defaultImplementation=*/[{
+        unsigned numRegions = this->getOperation()->getNumRegions();
+        assert(countPerRegion.empty());
+        countPerRegion.resize(numRegions);
+        for (unsigned i = 0; i < numRegions; ++i)
+          countPerRegion[i] = -1;
+      }]
     >
   ];
 
diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt
--- a/mlir/lib/Analysis/CMakeLists.txt
+++ b/mlir/lib/Analysis/CMakeLists.txt
@@ -6,6 +6,7 @@
   Liveness.cpp
   LoopAnalysis.cpp
   NestedMatcher.cpp
+  NumberOfExecutions.cpp
   PresburgerSet.cpp
   SliceAnalysis.cpp
   Utils.cpp
@@ -15,6 +16,7 @@
   BufferAliasAnalysis.cpp
   CallGraph.cpp
   Liveness.cpp
+  NumberOfExecutions.cpp
   SliceAnalysis.cpp
 
   ADDITIONAL_HEADER_DIRS
diff --git a/mlir/lib/Analysis/NumberOfExecutions.cpp b/mlir/lib/Analysis/NumberOfExecutions.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Analysis/NumberOfExecutions.cpp
@@ -0,0 +1,170 @@
+//===- NumberOfExecutions.cpp - Number of executions analysis -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implementation of the number of executions analysis.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "mlir/Analysis/NumberOfExecutions.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+
+#define DEBUG_TYPE "number-of-executions-analysis"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// NumberOfExecutions
+//===----------------------------------------------------------------------===//
+
+/// Computes blocks number of executions information for the given region and
+/// records it in `blockInfo`.
+///
+/// A block is recorded as executed exactly once only if the parent operation
+/// reports that it invokes `region` exactly once and the block lies on the
+/// chain of single-successor blocks that starts at the region entry, before
+/// the first block that starts a loop. Every other block is recorded as
+/// unknown.
+static void computeRegionBlockNumberOfExecutions(
+    Region &region, DenseMap<Block *, BlockNumberOfExecutionsInfo> &blockInfo) {
+  // An empty region (e.g. the body of an external function declaration) has
+  // no blocks to record, and `region.front()` below would be invalid.
+  if (region.empty())
+    return;
+
+  // Check if we know how many times parent operation will invoke this region.
+  // TODO: Pass constant operands to RegionBranchOpInterface.
+  Operation *parentOp = region.getParentOp();
+  SmallVector<int64_t, 4> numRegionInvocations;
+  if (auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp)) {
+    SmallVector<Attribute, 4> operands(parentOp->getNumOperands());
+    regionInterface.getNumRegionInvocations(operands, numRegionInvocations);
+  }
+
+  // Functions will always execute attached region once.
+  if (isa<FuncOp>(parentOp))
+    numRegionInvocations.push_back(1);
+
+  // Blocks can only be proven to execute once if the region itself is invoked
+  // exactly once. An unknown count (-1), a count of 0 or more than 1, or a
+  // missing entry for this region leaves every block unknown.
+  unsigned regionId = region.getRegionNumber();
+  if (regionId >= numRegionInvocations.size() ||
+      numRegionInvocations[regionId] != 1) {
+    for (Block &block : region)
+      blockInfo.insert({&block, BlockNumberOfExecutionsInfo::unknown(&block)});
+    return;
+  }
+
+  // Find the blocks that start a loop with an iterative depth-first traversal
+  // from the entry block: an edge into a block that is still on the DFS stack
+  // closes a cycle, and its target starts a loop. Every block and edge is
+  // processed once (enumerating all simple paths is exponential on chains of
+  // diamonds), and the explicit stack avoids deep recursion on large CFGs.
+  Block *entry = &region.front();
+  llvm::SmallSet<Block *, 4> loopStart;
+  llvm::SmallSet<Block *, 4> visited;
+  llvm::SmallSet<Block *, 4> onStack;
+  SmallVector<std::pair<Block *, unsigned>, 8> stack;
+  visited.insert(entry);
+  onStack.insert(entry);
+  stack.push_back({entry, 0});
+  while (!stack.empty()) {
+    Block *current = stack.back().first;
+    unsigned successorIndex = stack.back().second;
+
+    // All successors were explored: `current` leaves the DFS stack.
+    if (successorIndex == current->getNumSuccessors()) {
+      onStack.erase(current);
+      stack.pop_back();
+      continue;
+    }
+    ++stack.back().second;
+
+    // Found a loop in the CFG that starts at `successor`.
+    Block *successor = current->getSuccessor(successorIndex);
+    if (onStack.contains(successor)) {
+      loopStart.insert(successor);
+      continue;
+    }
+
+    // Continue DFS traversal into blocks that were not reached before.
+    if (!visited.insert(successor).second)
+      continue;
+    onStack.insert(successor);
+    stack.push_back({successor, 0});
+  }
+
+  // Start from the entry block and follow only blocks with a single successor.
+  Block *cursor = entry;
+  while (!loopStart.contains(cursor)) {
+    // Block will be executed exactly once.
+    blockInfo.insert({cursor, BlockNumberOfExecutionsInfo::once(cursor)});
+
+    // We reached the exit block or a block with multiple successors.
+    if (cursor->getNumSuccessors() != 1)
+      break;
+
+    // Continue traversal.
+    cursor = cursor->getSuccessor(0);
+  }
+
+  // For all blocks that we did not visit set the executions number to unknown;
+  // `insert` keeps the entries already recorded by the walk above.
+  for (Block &block : region)
+    blockInfo.insert({&block, BlockNumberOfExecutionsInfo::unknown(&block)});
+}
+
+/// Creates a new NumberOfExecutions analysis that computes how many times a
+/// block within a region is executed for all associated regions.
+NumberOfExecutions::NumberOfExecutions(Operation *op) : operation(op) {
+  operation->walk([&](Region *region) {
+    computeRegionBlockNumberOfExecutions(*region, blockNumbersOfExecution);
+  });
+}
+
+void NumberOfExecutions::print(raw_ostream &os) const {
+  // Blocks are numbered in the order in which `walk` visits them.
+  unsigned blockId = 0;
+  operation->walk([&](Block *block) {
+    auto it = blockNumbersOfExecution.find(block);
+    assert(it != blockNumbersOfExecution.end() && "block was not analyzed");
+
+    os << "Block: " << blockId++ << "\n";
+    os << "Number of executions: ";
+    if (auto n = it->getSecond().getNumberOfExecutions())
+      os << *n << "\n";
+    else
+      os << "<unknown>\n";
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// BlockNumberOfExecutionsInfo
+//===----------------------------------------------------------------------===//
+
+BlockNumberOfExecutionsInfo::BlockNumberOfExecutionsInfo(
+    Block *block, Optional<int64_t> numberOfExecutions)
+    : block(block), numberOfExecutions(numberOfExecutions) {}
+
+Optional<int64_t> BlockNumberOfExecutionsInfo::getNumberOfExecutions() const {
+  return numberOfExecutions;
+}
+
+BlockNumberOfExecutionsInfo BlockNumberOfExecutionsInfo::once(Block *block) {
+  return {block, 1};
+}
+BlockNumberOfExecutionsInfo BlockNumberOfExecutionsInfo::unknown(Block *block) {
+  return {block, None};
+}
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -118,6 +118,29 @@
 
 constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
 
+void ExecuteOp::getNumRegionInvocations(
+    ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) {
+  // The body region is invoked exactly once, independent of the operands.
+  assert(countPerRegion.empty());
+  countPerRegion.push_back(1);
+}
+
+void ExecuteOp::getSuccessorRegions(Optional<unsigned> index,
+                                    ArrayRef<Attribute> operands,
+                                    SmallVectorImpl<RegionSuccessor> &regions) {
+  // The `body` region branches back to the parent operation. The first result
+  // of the operation is the completion token, which the body does not yield,
+  // so only the remaining results are successor inputs.
+  if (index.hasValue()) {
+    assert(*index == 0 && "invalid region index");
+    regions.push_back(RegionSuccessor(getResults().drop_front()));
+    return;
+  }
+
+  // Otherwise the successor is the body region.
+  regions.push_back(RegionSuccessor(&body()));
+}
+
 static void print(OpAsmPrinter &p, ExecuteOp op) {
   p << op.getOperationName();
 
diff --git a/mlir/test/Analysis/test-number-of-executions.mlir b/mlir/test/Analysis/test-number-of-executions.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Analysis/test-number-of-executions.mlir
@@ -0,0 +1,127 @@
+// RUN: mlir-opt %s -test-print-number-of-executions -split-input-file 2>&1 \
+// RUN:   | FileCheck %s
+
+// CHECK-LABEL: Number of executions: empty
+func @empty() {
+  // CHECK:      Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: sequential
+func @sequential() {
+  // CHECK:      Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  br ^bb1
+^bb1:
+  // CHECK:      Block: 1
+  // CHECK-NEXT: Number of executions: 1
+  br ^bb2
+^bb2:
+  // CHECK:      Block: 2
+  // CHECK-NEXT: Number of executions: 1
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: conditional
+func @conditional(%cond : i1) {
+  // CHECK:      Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  br ^bb1
+^bb1:
+  // CHECK:      Block: 1
+  // CHECK-NEXT: Number of executions: 1
+  cond_br %cond, ^bb2, ^bb3
+^bb2:
+  // CHECK:      Block: 2
+  // CHECK-NEXT: Number of executions: <unknown>
+  br ^bb4
+^bb3:
+  // CHECK:      Block: 3
+  // CHECK-NEXT: Number of executions: <unknown>
+  br ^bb4
+^bb4:
+  // CHECK:      Block: 4
+  // CHECK-NEXT: Number of executions: <unknown>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: loop
+func @loop(%cond : i1) {
+  // CHECK:      Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  br ^bb1
+^bb1:
+  // CHECK:      Block: 1
+  // CHECK-NEXT: Number of executions: <unknown>
+  br ^bb2
+^bb2:
+  // CHECK:      Block: 2
+  // CHECK-NEXT: Number of executions: <unknown>
+  br ^bb3
+^bb3:
+  // CHECK:      Block: 3
+  // CHECK-NEXT: Number of executions: <unknown>
+  cond_br %cond, ^bb1, ^bb4
+^bb4:
+  // CHECK:      Block: 4
+  // CHECK-NEXT: Number of executions: <unknown>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: scf_if_dynamic_branch
+func @scf_if_dynamic_branch(%cond : i1) {
+  // CHECK:      Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  scf.if %cond {
+    // CHECK:      Block: 1
+    // CHECK-NEXT: Number of executions: <unknown>
+  } else {
+    // CHECK:      Block: 2
+    // CHECK-NEXT: Number of executions: <unknown>
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: async_execute
+func @async_execute() {
+  // CHECK:      Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  async.execute {
+    // CHECK:      Block: 1
+    // CHECK-NEXT: Number of executions: 1
+    async.yield
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: Number of executions: async_execute_with_scf_if
+func @async_execute_with_scf_if(%cond : i1) {
+  // CHECK:      Block: 0
+  // CHECK-NEXT: Number of executions: 1
+  async.execute {
+    // CHECK:      Block: 1
+    // CHECK-NEXT: Number of executions: 1
+    scf.if %cond {
+      // CHECK:      Block: 2
+      // CHECK-NEXT: Number of executions: <unknown>
+    } else {
+      // CHECK:      Block: 3
+      // CHECK-NEXT: Number of executions: <unknown>
+    }
+    async.yield
+  }
+  return
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -25,6 +25,7 @@
   TestLoopMapping.cpp
   TestLoopParametricTiling.cpp
   TestLoopUnrolling.cpp
+  TestNumberOfExecutions.cpp
   TestOpaqueLoc.cpp
   TestMemRefBoundCheck.cpp
   TestMemRefDependenceCheck.cpp
diff --git a/mlir/test/lib/Transforms/TestNumberOfExecutions.cpp b/mlir/test/lib/Transforms/TestNumberOfExecutions.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestNumberOfExecutions.cpp
@@ -0,0 +1,37 @@
+//===- TestNumberOfExecutions.cpp - Test number of executions analysis ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a test pass that prints the number of executions analysis
+// computed for each function.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/NumberOfExecutions.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Prints the number of executions analysis of each function to stderr.
+struct TestNumberOfExecutionsPass
+    : public PassWrapper<TestNumberOfExecutionsPass, FunctionPass> {
+  void runOnFunction() override {
+    llvm::errs() << "Number of executions: " << getFunction().getName() << "\n";
+    getAnalysis<NumberOfExecutions>().print(llvm::errs());
+  }
+};
+
+} // end anonymous namespace
+
+namespace mlir {
+void registerTestNumberOfExecutionsPass() {
+  PassRegistration<TestNumberOfExecutionsPass>(
+      "test-print-number-of-executions",
+      "Print the contents of a constructed number of executions analysis.");
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -72,6 +72,7 @@
 void registerTestMatchers();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
+void registerTestNumberOfExecutionsPass();
 void registerTestOpaqueLoc();
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestPrintDefUsePass();
@@ -133,6 +134,7 @@
   registerTestMatchers();
   registerTestMemRefDependenceCheck();
   registerTestMemRefStrideCalculation();
+  registerTestNumberOfExecutionsPass();
   registerTestOpaqueLoc();
   registerTestPreparationPassWithAllowedMemrefResults();
   registerTestPrintDefUsePass();