diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -0,0 +1,454 @@ +//===- DataFlowFramework.h - A generic framework for data-flow analysis ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a generic framework for writing data-flow analysis in MLIR. +// The framework consists of a solver, which runs the fixed-point iteration and +// manages analysis dependencies, and a data-flow analysis class used to +// implement specific analyses. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H +#define MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H + +#include "mlir/Analysis/DataFlowAnalysis.h" +#include "mlir/IR/Operation.h" +#include "mlir/Support/StorageUniquer.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/Support/TypeName.h" +#include + +namespace mlir { + +/// Forward declare the analysis state class. +class AnalysisState; + +//===----------------------------------------------------------------------===// +// GenericProgramPoint +//===----------------------------------------------------------------------===// + +/// Abstract class for generic program points. In classical data-flow analysis, +/// program points represent positions in a program to which lattice elements +/// are attached. In sparse data-flow analysis, these can be SSA values, and in +/// dense data-flow analysis, these are the program points before and after +/// every operation. +/// +/// In the general MLIR data-flow analysis framework, program points are an +/// extensible concept. 
Program points are uniquely identifiable objects to +/// which analysis states can be attached. The semantics of program points are +/// defined by the analyses that specify their transfer functions. +/// +/// Program points are implemented using MLIR's storage uniquer framework and +/// type ID system to provide RTTI. +class GenericProgramPoint : public StorageUniquer::BaseStorage { +public: + virtual ~GenericProgramPoint(); + + /// Get the abstract program point's type identifier. + TypeID getTypeID() const { return typeID; } + + /// Get a derived source location for the program point. + virtual Location getLoc() const = 0; + + /// Print the program point. + virtual void print(raw_ostream &os) const = 0; + +protected: + /// Create an abstract program point with type identifier. + explicit GenericProgramPoint(TypeID typeID) : typeID(typeID) {} + +private: + /// The type identifier of the program point. + TypeID typeID; +}; + +//===----------------------------------------------------------------------===// +// GenericProgramPointBase +//===----------------------------------------------------------------------===// + +/// Base class for generic program points based on a concrete program point +/// type and a content key. This class defines the common methods required for +/// operability with the storage uniquer framework. +/// +/// The provided key type uniquely identifies the concrete program point +/// instance and are the data members of the class. +template +class GenericProgramPointBase : public GenericProgramPoint { +public: + /// The concrete key type used by the storage uniquer. This class is uniqued + /// by its contents. + using KeyTy = Value; + /// Alias for the base class. + using Base = GenericProgramPointBase; + + /// Construct an instance of the program point using the provided value and + /// the type ID of the concrete type. 
+ template + explicit GenericProgramPointBase(ValueT &&value) + : GenericProgramPoint(TypeID::get()), + value(std::forward(value)) {} + + /// Get a uniqued instance of this program point class with the given + /// arguments. + template + static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) { + return uniquer.get(/*initFn=*/{}, std::forward(args)...); + } + + /// Allocate space for a program point and construct it in-place. + template + static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, + ValueT &&value) { + return new (alloc.allocate()) + ConcreteT(std::forward(value)); + } + + /// Two program points are equal if their values are equal. + bool operator==(const Value &value) const { return this->value == value; } + + /// Provide LLVM-style RTTI using type IDs. + static bool classof(const GenericProgramPoint *point) { + return point->getTypeID() == TypeID::get(); + } + + /// Get the contents of the program point. + const Value &getValue() const { return value; } + +private: + /// The program point value. + Value value; +}; + +//===----------------------------------------------------------------------===// +// ProgramPoint +//===----------------------------------------------------------------------===// + +/// Fundamental IR components are supported as first-class program points. +struct ProgramPoint : public PointerUnion { + using ParentTy = PointerUnion; + /// Inherit constructors. + using ParentTy::PointerUnion; + /// Allow implicit conversion from the parent type. + ProgramPoint(ParentTy point) : ParentTy(point) {} + + /// Print the program point. + void print(raw_ostream &os) const; + + /// Get the source location of the program point. + Location getLoc() const; +}; + +/// Forward declaration of the data-flow analysis class. 
+class DataFlowAnalysis; + +//===----------------------------------------------------------------------===// +// DataFlowSolver +//===----------------------------------------------------------------------===// + +/// The general data-flow analysis solver. This class is responsible for +/// orchestrating child data-flow analyses, running the fixed-point iteration +/// algorithm, managing analysis state and program point memory, and tracking +/// dependencies between analyses, program points, and analysis states. +/// +/// Steps to run a data-flow analysis: +/// +/// 1. Load and initialize children analyses. Children analyses are instantiated +/// in the solver and initialized, building their dependency relations. +/// 2. Configure and run the analysis. The solver invokes the children analyses +/// according to their dependency relations until a fixed point is reached. +/// 3. Query analysis state results from the solver. +/// +/// TODO: Optimize the internal implementation of the solver. +class DataFlowSolver { +public: + /// Load an analysis into the solver. Return the analysis instance. + template + AnalysisT *load(Args &&...args); + + /// Initialize the children analyses starting from the provided top-level + /// operation and run the analysis until fixpoint. + LogicalResult initializeAndRun(Operation *top); + + /// Lookup an analysis state for the given program point. Returns null if one + /// does not exist. + template + const StateT *lookupState(PointT point) const { + auto it = analysisStates.find({point, TypeID::get()}); + if (it == analysisStates.end()) + return nullptr; + return static_cast(it->second.get()); + } + + /// Get a uniqued program point instance. If one is not present, it is + /// created with the provided arguments. + template + PointT *getProgramPoint(Args &&...args) { + return PointT::get(uniquer, std::forward(args)...); + } + + /// A work item on the solver queue is a program point, child analysis pair. 
+ /// Each item is processed by invoking the child analysis at the program + /// point. + using WorkItem = std::pair; + /// Push a work item onto the worklist. + void enqueue(WorkItem item) { worklist.push(std::move(item)); } + +protected: + /// Get the state associated with the given program point. If it does not + /// exist, create an uninitialized state. + template + StateT *getOrCreateState(PointT point); + + /// Propagate an update to an analysis state if it changed by pushing + /// dependent work items to the back of the queue. + void propagateIfChanged(AnalysisState *state, ChangeResult changed); + + /// Add a dependency to an analysis state on a child analysis and program + /// point. If the state is updated, the child analysis must be invoked on the + /// given program point again. + void addDependency(AnalysisState *state, DataFlowAnalysis *analysis, + ProgramPoint point); + +private: + /// The solver's work queue. Work items can be inserted to the front of the + /// queue to be processed greedily, speeding up computations that otherwise + /// quickly degenerate to quadratic due to propagation of state updates. + std::queue worklist; + + /// Type-erased instances of the children analyses. + SmallVector> childAnalyses; + + /// The storage uniquer instance that owns the memory of the allocated program + /// points. + StorageUniquer uniquer; + + /// A type-erased map of program points to associated analysis states for + /// first-class program points. + DenseMap, std::unique_ptr> + analysisStates; + + /// Allow the base child analysis class to access the internals of the solver. + friend class DataFlowAnalysis; +}; + +//===----------------------------------------------------------------------===// +// AnalysisState +//===----------------------------------------------------------------------===// + +/// Base class for generic analysis states. 
Analysis states contain data-flow +/// information that is attached to program points and which evolves as the +/// analysis iterates. +/// +/// This class places no restrictions on the semantics of analysis states beyond +/// these requirements. +/// +/// 1. Querying the state of a program point prior to visiting that point +/// results in uninitialized state. Analyses must be aware of uninitialized +/// states. +/// 2. Analysis states can reach fixpoints, where subsequent updates will never +/// trigger a change in the state. +/// 3. Analysis states that are uninitialized can be forcefully initialized to a +/// default value. +class AnalysisState { +public: + virtual ~AnalysisState(); + + /// Returns true if the analysis state is uninitialized. + virtual bool isUninitialized() const = 0; + + /// Force an uninitialized analysis state to initialize itself with a default + /// value. + virtual ChangeResult defaultInitialize() = 0; + + /// Print the contents of the analysis state. + virtual void print(raw_ostream &os) const = 0; + +protected: + /// Create the analysis state at the given program point. + AnalysisState(ProgramPoint point) : point(point) {} + + /// This function is called by the solver when the analysis state is updated + /// to optionally enqueue more work items. For example, if a state tracks + /// dependents through the IR (e.g. use-def chains), this function can be + /// implemented to push those dependents on the worklist. + virtual void onUpdate(DataFlowSolver *solver) const {} + + /// The dependency relations originating from this analysis state. An entry + /// `state -> (analysis, point)` is created when `analysis` queries `state` + /// when updating `point`. + /// + /// When this state is updated, all dependent child analysis invocations are + /// pushed to the back of the queue. Use a `SetVector` to keep the analysis + /// deterministic. + /// + /// Store the dependents on the analysis state for efficiency. 
+ SetVector dependents; + + /// The program point to which the state belongs. + ProgramPoint point; + +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + /// When compiling with debugging, keep a name for the analysis state. + StringRef debugName; +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + + /// Allow the framework to access the dependents. + friend class DataFlowSolver; +}; + +//===----------------------------------------------------------------------===// +// DataFlowAnalysis +//===----------------------------------------------------------------------===// + +/// Base class for all data-flow analyses. A child analysis is expected to build +/// an initial dependency graph (and optionally provide an initial state) when +/// initialized and define transfer functions when visiting program points. +/// +/// In classical data-flow analysis, the dependency graph is fixed and analyses +/// define explicit transfer functions between input states and output states. +/// In this framework, however, the dependency graph can change during the +/// analysis, and transfer functions are opaque such that the solver doesn't +/// know what states calling `visit` on an analysis will be updated. This allows +/// multiple analyses to plug in and provide values for the same state. +/// +/// Generally, when an analysis queries an uninitialized state, it is expected +/// to "bail out", i.e., not provide any updates. When the value is initialized, +/// the solver will re-invoke the analysis. If the solver exhausts its worklist, +/// however, and there are still uninitialized states, the solver "nudges" the +/// analyses by default-initializing those states. +class DataFlowAnalysis { +public: + virtual ~DataFlowAnalysis(); + + /// Initialize the analysis from the provided top-level operation by building + /// an initial dependency graph between all program points of interest. This + /// can be implemented by calling `visit` on all program points of interest + /// below the top-level operation. 
+ /// + /// An analysis can optionally provide initial values to certain analysis + /// states to influence the evolution of the analysis. + virtual LogicalResult initialize(Operation *top) = 0; + + /// Visit the given program point. This function is invoked by the solver on + /// this analysis with a given program point when a dependent analysis state + /// is updated. The function is similar to a transfer function; it queries + /// certain analysis states and sets other states. + /// + /// The function is expected to create dependencies on queried states and + /// propagate updates on changed states. A dependency can be created by + /// calling `addDependency` between the input state and a program point, + /// indicating that, if the state is updated, the solver should invoke `visit` + /// on the program point. The dependent point does not have to be the same as + /// the provided point. An update to a state is propagated by calling + /// `propagateIfChanged` on the state. If the state has changed, then all its + /// dependents are placed on the worklist. + /// + /// The dependency graph does not need to be static. Each invocation of + /// `visit` can add new dependencies, but these dependencies will not be + /// dynamically added to the worklist because the solver doesn't know what + /// will provide a value for them. + virtual LogicalResult visit(ProgramPoint point) = 0; + +protected: + /// Create an analysis with a reference to the parent solver. + explicit DataFlowAnalysis(DataFlowSolver &solver); + + /// Create a dependency between the given analysis state and program point + /// on this analysis. + void addDependency(AnalysisState *state, ProgramPoint point); + + /// Propagate an update to a state if it changed. + void propagateIfChanged(AnalysisState *state, ChangeResult changed); + + /// Register a custom program point class. 
+ template + void registerPointKind() { + solver.uniquer.registerParametricStorageType(); + } + + /// Get or create a custom program point. + template + PointT *getProgramPoint(Args &&...args) { + return solver.getProgramPoint(std::forward(args)...); + } + + /// Get the analysis state associated with the program point. The returned + /// state is expected to be "write-only", and any updates need to be + /// propagated by `propagateIfChanged`. + template + StateT *getOrCreate(PointT point) { + return solver.getOrCreateState(point); + } + + /// Get a read-only analysis state for the given point and create a dependency + /// on `dependent`. If the return state is updated elsewhere, this analysis is + /// re-invoked on the dependent. + template + const StateT *getOrCreateFor(ProgramPoint dependent, PointT point) { + StateT *state = getOrCreate(point); + addDependency(state, dependent); + return state; + } + +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + /// When compiling with debugging, keep a name for the analysis. + StringRef debugName; +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + +private: + /// The parent data-flow solver. + DataFlowSolver &solver; + + /// Allow the data-flow solver to access the internals of this class. 
+ friend class DataFlowSolver; +}; + +template +AnalysisT *DataFlowSolver::load(Args &&...args) { + childAnalyses.emplace_back(new AnalysisT(*this, std::forward(args)...)); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + childAnalyses.back().get()->debugName = llvm::getTypeName(); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + return static_cast(childAnalyses.back().get()); +} + +template +StateT *DataFlowSolver::getOrCreateState(PointT point) { + std::unique_ptr &state = + analysisStates[{{point}, TypeID::get()}]; + if (!state) { + state = std::unique_ptr(new StateT({point})); +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + state->debugName = llvm::getTypeName(); +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + } + return static_cast(state.get()); +} + +inline raw_ostream &operator<<(raw_ostream &os, const AnalysisState &state) { + state.print(os); + return os; +} + +inline raw_ostream &operator<<(raw_ostream &os, ProgramPoint point) { + point.print(os); + return os; +} + +} // end namespace mlir + +namespace llvm { +/// Allow hashing of program points. +template <> +struct DenseMapInfo + : public DenseMapInfo {}; +} // end namespace llvm + +#endif // MLIR_ANALYSIS_DATAFLOWFRAMEWORK_H diff --git a/mlir/lib/Analysis/CMakeLists.txt b/mlir/lib/Analysis/CMakeLists.txt --- a/mlir/lib/Analysis/CMakeLists.txt +++ b/mlir/lib/Analysis/CMakeLists.txt @@ -16,6 +16,7 @@ BufferViewFlowAnalysis.cpp CallGraph.cpp DataFlowAnalysis.cpp + DataFlowFramework.cpp DataLayoutAnalysis.cpp IntRangeAnalysis.cpp Liveness.cpp diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -0,0 +1,161 @@ +//===- DataFlowFramework.cpp - A generic framework for data-flow analysis -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlowFramework.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "dataflow" +#if LLVM_ENABLE_ABI_BREAKING_CHECKS +#define DATAFLOW_DEBUG(X) LLVM_DEBUG(X) +#else +#define DATAFLOW_DEBUG(X) +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// GenericProgramPoint +//===----------------------------------------------------------------------===// + +GenericProgramPoint::~GenericProgramPoint() = default; + +//===----------------------------------------------------------------------===// +// AnalysisState +//===----------------------------------------------------------------------===// + +AnalysisState::~AnalysisState() = default; + +//===----------------------------------------------------------------------===// +// ProgramPoint +//===----------------------------------------------------------------------===// + +void ProgramPoint::print(raw_ostream &os) const { + if (isNull()) { + os << ""; + return; + } + if (auto *programPoint = dyn_cast()) + return programPoint->print(os); + if (auto *op = dyn_cast()) + return op->print(os); + if (auto value = dyn_cast()) + return value.print(os); + if (auto *block = dyn_cast()) + return block->print(os); + auto *region = get(); + os << "{\n"; + for (Block &block : *region) { + block.print(os); + os << "\n"; + } + os << "}"; +} + +Location ProgramPoint::getLoc() const { + if (auto *programPoint = dyn_cast()) + return programPoint->getLoc(); + if (auto *op = dyn_cast()) + return op->getLoc(); + if (auto value = dyn_cast()) + return value.getLoc(); + if (auto *block = dyn_cast()) + return block->getParent()->getLoc(); + return get()->getLoc(); +} + +//===----------------------------------------------------------------------===// +// DataFlowSolver 
+//===----------------------------------------------------------------------===// + +LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { + // Initialize the analyses. + for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) { + DATAFLOW_DEBUG(llvm::dbgs() + << "Priming analysis: " << analysis.debugName << "\n"); + if (failed(analysis.initialize(top))) + return failure(); + } + + // Run the analysis until fixpoint. + ProgramPoint point; + DataFlowAnalysis *analysis; + + do { + // Exhaust the worklist. + while (!worklist.empty()) { + std::tie(point, analysis) = worklist.front(); + worklist.pop(); + + DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName + << "' on: " << point << "\n"); + if (failed(analysis->visit(point))) + return failure(); + } + + // "Nudge" the state of the analysis by forcefully initializing states that + // are still uninitialized. All uninitialized states in the graph can be + // initialized in any order because the analysis reached fixpoint, meaning + // that there are no work items that would have further nudged the analysis. + for (AnalysisState &state : + llvm::make_pointee_range(llvm::make_second_range(analysisStates))) { + if (!state.isUninitialized()) + continue; + DATAFLOW_DEBUG(llvm::dbgs() << "Default initializing " << state.debugName + << " of " << state.point << "\n"); + propagateIfChanged(&state, state.defaultInitialize()); + } + + // Iterate until all states are in some initialized state and the worklist + // is exhausted. 
+ } while (!worklist.empty()); + + return success(); +} + +void DataFlowSolver::propagateIfChanged(AnalysisState *state, + ChangeResult changed) { + if (changed == ChangeResult::Change) { + DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName + << " of " << state->point << "\n" + << "Value: " << *state << "\n"); + for (const WorkItem &item : state->dependents) + enqueue(item); + state->onUpdate(this); + } +} + +void DataFlowSolver::addDependency(AnalysisState *state, + DataFlowAnalysis *analysis, + ProgramPoint point) { + auto inserted = state->dependents.insert({point, analysis}); + (void)inserted; + DATAFLOW_DEBUG({ + if (inserted) { + llvm::dbgs() << "Creating dependency between " << state->debugName + << " of " << state->point << "\nand " << analysis->debugName + << " on " << point << "\n"; + } + }); +} + +//===----------------------------------------------------------------------===// +// DataFlowAnalysis +//===----------------------------------------------------------------------===// + +DataFlowAnalysis::~DataFlowAnalysis() = default; + +DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {} + +void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) { + solver.addDependency(state, this, point); +} + +void DataFlowAnalysis::propagateIfChanged(AnalysisState *state, + ChangeResult changed) { + solver.propagateIfChanged(state, changed); +} diff --git a/mlir/test/Analysis/test-foo-analysis.mlir b/mlir/test/Analysis/test-foo-analysis.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Analysis/test-foo-analysis.mlir @@ -0,0 +1,95 @@ +// RUN: mlir-opt -split-input-file -pass-pipeline='func.func(test-foo-analysis)' %s 2>&1 | FileCheck %s + +// CHECK-LABEL: function: @test_default_init +func.func @test_default_init() -> () { + // CHECK: a -> 0 + "test.foo"() {tag = "a"} : () -> () + return +} + +// ----- + +// CHECK-LABEL: function: @test_one_join +func.func @test_one_join() -> () { + // CHECK: 
a -> 0 + "test.foo"() {tag = "a"} : () -> () + // CHECK: b -> 1 + "test.foo"() {tag = "b", foo = 1 : ui64} : () -> () + return +} + +// ----- + +// CHECK-LABEL: function: @test_two_join +func.func @test_two_join() -> () { + // CHECK: a -> 0 + "test.foo"() {tag = "a"} : () -> () + // CHECK: b -> 1 + "test.foo"() {tag = "b", foo = 1 : ui64} : () -> () + // CHECK: c -> 0 + "test.foo"() {tag = "c", foo = 1 : ui64} : () -> () + return +} + +// ----- + +// CHECK-LABEL: function: @test_fork +func.func @test_fork() -> () { + // CHECK: init -> 1 + "test.branch"() [^bb0, ^bb1] {tag = "init", foo = 1 : ui64} : () -> () + +^bb0: + // CHECK: a -> 3 + "test.branch"() [^bb2] {tag = "a", foo = 2 : ui64} : () -> () + +^bb1: + // CHECK: b -> 5 + "test.branch"() [^bb2] {tag = "b", foo = 4 : ui64} : () -> () + +^bb2: + // CHECK: end -> 6 + "test.foo"() {tag = "end"} : () -> () + return + +} + +// ----- + +// CHECK-LABEL: function: @test_simple_loop +func.func @test_simple_loop() -> () { + // CHECK: init -> 1 + "test.branch"() [^bb0] {tag = "init", foo = 1 : ui64} : () -> () + +^bb0: + // CHECK: a -> 1 + "test.foo"() {tag = "a", foo = 3 : ui64} : () -> () + "test.branch"() [^bb0, ^bb1] : () -> () + +^bb1: + // CHECK: end -> 3 + "test.foo"() {tag = "end"} : () -> () + return +} + +// ----- + +// CHECK-LABEL: function: @test_double_loop +func.func @test_double_loop() -> () { + // CHECK: init -> 2 + "test.branch"() [^bb0] {tag = "init", foo = 2 : ui64} : () -> () + +^bb0: + // CHECK: a -> 1 + "test.foo"() {tag = "a", foo = 3 : ui64} : () -> () + "test.branch"() [^bb0, ^bb1] : () -> () + +^bb1: + // CHECK: b -> 4 + "test.foo"() {tag = "b", foo = 5 : ui64} : () -> () + "test.branch"() [^bb0, ^bb2] : () -> () + +^bb2: + // CHECK: end -> 4 + "test.foo"() {tag = "end"} : () -> () + return +} diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt --- a/mlir/test/lib/Analysis/CMakeLists.txt +++ b/mlir/test/lib/Analysis/CMakeLists.txt @@ -3,6 +3,7 @@ 
TestAliasAnalysis.cpp TestCallGraph.cpp TestDataFlow.cpp + TestDataFlowFramework.cpp TestLiveness.cpp TestMatchReduction.cpp TestMemRefBoundCheck.cpp diff --git a/mlir/test/lib/Analysis/TestDataFlowFramework.cpp b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Analysis/TestDataFlowFramework.cpp @@ -0,0 +1,188 @@ +//===- TestDataFlowFramework.cpp - Test data-flow analysis framework ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +/// This analysis state represents an integer that is XOR'd with other states. +class FooState : public AnalysisState { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooState) + + using AnalysisState::AnalysisState; + + /// Default-initialize the state to zero. + ChangeResult defaultInitialize() override { return join(0); } + + /// Returns true if the state is uninitialized. + bool isUninitialized() const override { return !state; } + + /// Print the integer value or "none" if uninitialized. + void print(raw_ostream &os) const override { + if (state) + os << *state; + else + os << "none"; + } + + /// Join the state with another. If either is uninitialized, take the + /// initialized value. Otherwise, XOR the integer values. + ChangeResult join(const FooState &rhs) { + if (rhs.isUninitialized()) + return ChangeResult::NoChange; + return join(*rhs.state); + } + ChangeResult join(uint64_t value) { + if (isUninitialized()) { + state = value; + return ChangeResult::Change; + } + uint64_t before = *state; + state = before ^ value; + return before == *state ? 
ChangeResult::NoChange : ChangeResult::Change; + } + + /// Set the value of the state directly. + ChangeResult set(const FooState &rhs) { + if (state == rhs.state) + return ChangeResult::NoChange; + state = rhs.state; + return ChangeResult::Change; + } + + /// Returns the integer value of the state. + uint64_t getValue() const { return *state; } + +private: + /// An optional integer value. + Optional state; +}; + +/// This analysis computes `FooState` across operations and control-flow edges. +/// If an op specifies a `foo` integer attribute, the contained value is XOR'd +/// with the value before the operation. +class FooAnalysis : public DataFlowAnalysis { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FooAnalysis) + + using DataFlowAnalysis::DataFlowAnalysis; + + LogicalResult initialize(Operation *top) override; + LogicalResult visit(ProgramPoint point) override; + +private: + void visitBlock(Block *block); + void visitOperation(Operation *op); +}; + +struct TestFooAnalysisPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestFooAnalysisPass) + + StringRef getArgument() const override { return "test-foo-analysis"; } + + void runOnOperation() override; +}; +} // namespace + +LogicalResult FooAnalysis::initialize(Operation *top) { + if (top->getNumRegions() != 1) + return top->emitError("expected a single region top-level op"); + + // Initialize the top-level state. + getOrCreate(&top->getRegion(0).front())->join(0); + + // Visit all nested blocks and operations. 
+ for (Block &block : top->getRegion(0)) { + visitBlock(&block); + for (Operation &op : block) { + if (op.getNumRegions()) + return op.emitError("unexpected op with regions"); + visitOperation(&op); + } + } + return success(); +} + +LogicalResult FooAnalysis::visit(ProgramPoint point) { + if (auto *op = point.dyn_cast()) { + visitOperation(op); + return success(); + } + if (auto *block = point.dyn_cast()) { + visitBlock(block); + return success(); + } + return emitError(point.getLoc(), "unknown point kind"); +} + +void FooAnalysis::visitBlock(Block *block) { + if (block->isEntryBlock()) { + // This is the initial state. Let the framework default-initialize it. + return; + } + FooState *state = getOrCreate(block); + ChangeResult result = ChangeResult::NoChange; + for (Block *pred : block->getPredecessors()) { + // Join the state at the terminators of all predecessors. + const FooState *predState = + getOrCreateFor(block, pred->getTerminator()); + result |= state->join(*predState); + } + propagateIfChanged(state, result); +} + +void FooAnalysis::visitOperation(Operation *op) { + FooState *state = getOrCreate(op); + ChangeResult result = ChangeResult::NoChange; + + // Copy the state across the operation. + const FooState *prevState; + if (Operation *prev = op->getPrevNode()) + prevState = getOrCreateFor(op, prev); + else + prevState = getOrCreateFor(op, op->getBlock()); + result |= state->set(*prevState); + + // Modify the state with the attribute, if specified. 
+ if (auto attr = op->getAttrOfType("foo")) { + uint64_t value = attr.getUInt(); + result |= state->join(value); + } + propagateIfChanged(state, result); +} + +void TestFooAnalysisPass::runOnOperation() { + func::FuncOp func = getOperation(); + DataFlowSolver solver; + solver.load(); + if (failed(solver.initializeAndRun(func))) + return signalPassFailure(); + + raw_ostream &os = llvm::errs(); + os << "function: @" << func.getSymName() << "\n"; + + func.walk([&](Operation *op) { + auto tag = op->getAttrOfType("tag"); + if (!tag) + return; + const FooState *state = solver.lookupState(op); + assert(state && !state->isUninitialized()); + os << tag.getValue() << " -> " << state->getValue() << "\n"; + }); +} + +namespace mlir { +namespace test { +void registerTestFooAnalysisPass() { PassRegistration(); } +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -77,6 +77,7 @@ void registerTestDominancePass(); void registerTestDynamicPipelinePass(); void registerTestExpandMathPass(); +void registerTestFooAnalysisPass(); void registerTestComposeSubView(); void registerTestMultiBuffering(); void registerTestIntRangeInference(); @@ -174,6 +175,7 @@ mlir::test::registerTestDominancePass(); mlir::test::registerTestDynamicPipelinePass(); mlir::test::registerTestExpandMathPass(); + mlir::test::registerTestFooAnalysisPass(); mlir::test::registerTestComposeSubView(); mlir::test::registerTestMultiBuffering(); mlir::test::registerTestIntRangeInference(); diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5784,17 +5784,12 @@ "lib/Analysis/*/*.cpp", "lib/Analysis/*/*.h", ], - exclude = [ - "lib/Analysis/Vector*.cpp", - "lib/Analysis/Vector*.h", - ], ), hdrs = glob( 
[ "include/mlir/Analysis/*.h", "include/mlir/Analysis/*/*.h", ], - exclude = ["include/mlir/Analysis/Vector*.h"], ), includes = ["include"], deps = [ diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -26,6 +26,7 @@ "//mlir:Affine", "//mlir:AffineAnalysis", "//mlir:Analysis", + "//mlir:FuncDialect", "//mlir:IR", "//mlir:MemRefDialect", "//mlir:Pass",