diff --git a/mlir/include/mlir/Analysis/DenseDataFlowAnalysis.h b/mlir/include/mlir/Analysis/DenseDataFlowAnalysis.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Analysis/DenseDataFlowAnalysis.h @@ -0,0 +1,168 @@ +//===- DenseDataFlowAnalysis.h - Dense data-flow analysis -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements dense data-flow analysis using the data-flow analysis +// framework. The analysis is forward and conditional and uses the results of +// dead code analysis to prune dead code during the analysis. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H +#define MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H + +#include "mlir/Analysis/SparseDataFlowAnalysis.h" + +namespace mlir { + +//===----------------------------------------------------------------------===// +// AbstractDenseLattice +//===----------------------------------------------------------------------===// + +/// This class represents a dense lattice. A dense lattice is attached to +/// operations to represent the program state after their execution or to blocks +/// to represent the program state at the beginning of the block. A dense +/// lattice is propagated through the IR by dense data-flow analysis. +class AbstractDenseLattice : public AnalysisState { +public: + /// A dense lattice can only be created for operations and blocks. + using AnalysisState::AnalysisState; + + /// Join the lattice across control-flow or callgraph edges. + virtual ChangeResult join(const AbstractDenseLattice &rhs) = 0; + + /// Reset the dense lattice to a pessimistic value.
This occurs when the + /// analysis cannot reason about the data-flow. + virtual ChangeResult reset() = 0; + + /// Returns true if the lattice state has reached a pessimistic fixpoint. That + /// is, no further modifications to the lattice can occur. + virtual bool isAtFixpoint() const = 0; +}; + +//===----------------------------------------------------------------------===// +// AbstractDenseDataFlowAnalysis +//===----------------------------------------------------------------------===// + +/// Base class for dense data-flow analyses. Dense data-flow analysis attaches a +/// lattice between the execution of operations and implements a transfer +/// function from the lattice before each operation to the lattice after. The +/// lattice contains information about the state of the program at that point. +/// +/// In this implementation, a lattice attached to an operation represents the +/// state of the program after its execution, and a lattice attached to a block +/// represents the state of the program right before it starts executing its +/// body. +class AbstractDenseDataFlowAnalysis : public DataFlowAnalysis { +public: + using DataFlowAnalysis::DataFlowAnalysis; + + /// Initialize the analysis by visiting every program point whose execution + /// may modify the program state; that is, every operation and block. + LogicalResult initialize(Operation *top) override; + + /// Visit a program point that modifies the state of the program. If this is a + /// block, then the state is propagated from control-flow predecessors or + /// callsites. If this is a call operation or region control-flow operation, + /// then the state after the execution of the operation is set by control-flow + /// or the callgraph. Otherwise, this function invokes the operation transfer + /// function. + LogicalResult visit(ProgramPoint point) override; + +protected: + /// Propagate the dense lattice before the execution of an operation to the + /// lattice after its execution.
+ virtual void visitOperationImpl(Operation *op, + const AbstractDenseLattice &before, + AbstractDenseLattice *after) = 0; + + /// Get the dense lattice after the execution of the given program point. + virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0; + + /// Get the dense lattice after the execution of the given program point and + /// register `dependee` as depending on it. + const AbstractDenseLattice *getLatticeFor(ProgramPoint dependee, + ProgramPoint point); + + /// Mark the dense lattice as having reached its pessimistic fixpoint and + /// propagate an update if it changed. + void reset(AbstractDenseLattice *lattice) { + propagateIfChanged(lattice, lattice->reset()); + } + + /// Join a lattice with another and propagate an update if it changed. + void join(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs) { + propagateIfChanged(lhs, lhs->join(rhs)); + } + +private: + /// Visit an operation. If this is a call operation or region control-flow + /// operation, then the state after the execution of the operation is set by + /// control-flow or the callgraph. Otherwise, this function invokes the + /// operation transfer function. + void visitOperation(Operation *op); + + /// Visit a block. The state at the start of the block is propagated from + /// control-flow predecessors or callsites. + void visitBlock(Block *block); + + /// Visit a program point within a region branch operation with predecessors + /// in it. This can either be an entry block of one of the regions or the + /// parent operation itself.
+ void visitRegionBranchOperation(ProgramPoint point, + RegionBranchOpInterface branch, + AbstractDenseLattice *after); +}; + +//===----------------------------------------------------------------------===// +// DenseDataFlowAnalysis +//===----------------------------------------------------------------------===// + +/// A dense (forward) data-flow analysis for propagating lattices before and +/// after the execution of every operation across the IR by implementing +/// transfer functions for operations. +/// +/// `LatticeT` is expected to be a subclass of `AbstractDenseLattice`. +template +class DenseDataFlowAnalysis : public AbstractDenseDataFlowAnalysis { +public: + using AbstractDenseDataFlowAnalysis::AbstractDenseDataFlowAnalysis; + + /// Visit an operation with the dense lattice before its execution. This + /// function is expected to set the dense lattice after its execution. + virtual void visitOperation(Operation *op, const LatticeT &before, + LatticeT *after) = 0; + +protected: + /// Get the dense lattice after this program point. + LatticeT *getLattice(ProgramPoint point) override { + return getOrCreate(point); + } + +private: + /// Type-erased wrappers that convert the abstract dense lattice to a derived + /// lattice and invoke the virtual hooks operating on the derived lattice.
+ void visitOperationImpl(Operation *op, const AbstractDenseLattice &before, + AbstractDenseLattice *after) override { + visitOperation(op, static_cast(before), + static_cast(after)); + } +}; + +//===----------------------------------------------------------------------===// +// DenseLattice +//===----------------------------------------------------------------------===// + +/// NOTE(review): empty placeholder; adds nothing over AbstractDenseLattice. +template +class DenseLattice : public AbstractDenseLattice { +public: +}; + +} // end namespace mlir + +#endif // MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H diff --git a/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h --- a/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h +++ b/mlir/include/mlir/Analysis/SparseDataFlowAnalysis.h @@ -84,11 +84,10 @@ /// template class Lattice : public AbstractLattice { public: - using AbstractLattice::AbstractLattice; - - /// Get a lattice element with a known value. - Lattice(const ValueT &knownValue = ValueT()) - : AbstractLattice(Value()), knownValue(knownValue) {} + /// Construct a lattice for `value` in its pessimistic state. + explicit Lattice(Value value) + : AbstractLattice(value), + knownValue(ValueT::getPessimisticValueState(value)) {} /// Return the value held by this lattice. This requires that the value is /// initialized. @@ -244,6 +244,9 @@ /// This lattice value represents a known constant value of a lattice. class ConstantValue { public: + /// The pessimistic value state of the constant value is unknown. + static ConstantValue getPessimisticValueState(Value value) { return {}; } + + /// Construct a constant value with a known constant.
ConstantValue(Attribute knownValue = {}, Dialect *dialect = nullptr) : constant(knownValue), dialect(dialect) {} diff --git a/mlir/lib/Analysis/DenseDataFlowAnalysis.cpp b/mlir/lib/Analysis/DenseDataFlowAnalysis.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Analysis/DenseDataFlowAnalysis.cpp @@ -0,0 +1,164 @@ +//===- DenseDataFlowAnalysis.cpp - Dense data-flow analysis ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DenseDataFlowAnalysis.h" + +using namespace mlir; + +LogicalResult AbstractDenseDataFlowAnalysis::initialize(Operation *top) { + // Visit every operation and block. + visitOperation(top); + for (Region &region : top->getRegions()) { + for (Block &block : region) { + visitBlock(&block); + for (Operation &op : block) + if (failed(initialize(&op))) + return failure(); + } + } + return success(); +} + +LogicalResult AbstractDenseDataFlowAnalysis::visit(ProgramPoint point) { + if (auto *op = point.dyn_cast()) + visitOperation(op); + else if (auto *block = point.dyn_cast()) + visitBlock(block); + else + return failure(); + return success(); +} + +void AbstractDenseDataFlowAnalysis::visitOperation(Operation *op) { + // If the containing block is not executable, bail out. + if (!getOrCreateFor(op, op->getBlock())->isLive()) + return; + + // Get the dense lattice to update. + AbstractDenseLattice *after = getLattice(op); + if (after->isAtFixpoint()) + return; + + // If this op implements region control-flow, then control-flow dictates its + // transfer function. + if (auto branch = dyn_cast(op)) + return visitRegionBranchOperation(op, branch, after); + + // If this is a call operation, then join its lattices across known return + // sites.
+ if (auto call = dyn_cast(op)) { + const auto *predecessors = getOrCreateFor(op, call); + // If not all return sites are known, then conservatively assume we can't + // reason about the data-flow. + if (!predecessors->allPredecessorsKnown()) + return reset(after); + for (Operation *predecessor : predecessors->getKnownPredecessors()) + join(after, *getLatticeFor(op, predecessor)); + return; + } + + // Get the dense state before the execution of the op. + const AbstractDenseLattice *before; + if (Operation *prev = op->getPrevNode()) + before = getLatticeFor(op, prev); + else + before = getLatticeFor(op, op->getBlock()); + // If the incoming lattice is uninitialized, bail out. + if (before->isUninitialized()) + return; + + // Invoke the operation transfer function. + visitOperationImpl(op, *before, after); +} + +void AbstractDenseDataFlowAnalysis::visitBlock(Block *block) { + // If the block is not executable, bail out. + if (!getOrCreateFor(block, block)->isLive()) + return; + + // Get the dense lattice to update. + AbstractDenseLattice *after = getLattice(block); + if (after->isAtFixpoint()) + return; + + // The dense lattices of entry blocks are set by region control-flow or the + // callgraph. + if (block->isEntryBlock()) { + // Check if this block is the entry block of a callable region. + auto callable = dyn_cast(block->getParentOp()); + if (callable && callable.getCallableRegion() == block->getParent()) { + const auto *callsites = getOrCreateFor(block, callable); + // If not all callsites are known, conservatively reset the lattice to its + // pessimistic state. + if (!callsites->allPredecessorsKnown()) + return reset(after); + for (Operation *callsite : callsites->getKnownPredecessors()) { + // Get the dense lattice before the callsite.
+ if (Operation *prev = callsite->getPrevNode()) + join(after, *getLatticeFor(block, prev)); + else + join(after, *getLatticeFor(block, callsite->getBlock())); + } + return; + } + + // Check if we can reason about the control-flow. + if (auto branch = dyn_cast(block->getParentOp())) + return visitRegionBranchOperation(block, branch, after); + + // Otherwise, we can't reason about the data-flow. + return reset(after); + } + + // Join the state with the state after the block's predecessors. + for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); + it != e; ++it) { + // Skip control edges that aren't executable. + Block *predecessor = *it; + if (!getOrCreateFor( + block, getProgramPoint(predecessor, block)) + ->isLive()) + continue; + + // Merge in the state from the predecessor's terminator. + join(after, *getLatticeFor(block, predecessor->getTerminator())); + } +} + +void AbstractDenseDataFlowAnalysis::visitRegionBranchOperation( + ProgramPoint point, RegionBranchOpInterface branch, + AbstractDenseLattice *after) { + // Get the terminator predecessors. + const auto *predecessors = getOrCreateFor(point, point); + assert(predecessors->allPredecessorsKnown() && + "unexpected unresolved region successors"); + + for (Operation *op : predecessors->getKnownPredecessors()) { + const AbstractDenseLattice *before; + // If the predecessor is the parent, get the state before the parent. + if (op == branch) { + if (Operation *prev = op->getPrevNode()) + before = getLatticeFor(point, prev); + else + before = getLatticeFor(point, op->getBlock()); + + // Otherwise, get the state after the terminator.
+ } else { + before = getLatticeFor(point, op); + } + join(after, *before); + } +} + +const AbstractDenseLattice * +AbstractDenseDataFlowAnalysis::getLatticeFor(ProgramPoint dependee, + ProgramPoint point) { + AbstractDenseLattice *state = getLattice(point); + addDependency(state, dependee); + return state; +} diff --git a/mlir/test/Analysis/test-last-modified-callgraph.mlir b/mlir/test/Analysis/test-last-modified-callgraph.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Analysis/test-last-modified-callgraph.mlir @@ -0,0 +1,67 @@ +// RUN: mlir-opt -test-last-modified %s 2>&1 | FileCheck %s + +// CHECK-LABEL: test_tag: test_callsite +// CHECK: operand #0 +// CHECK-NEXT: - a +func.func private @single_callsite_fn(%ptr: memref) -> memref { + return {tag = "test_callsite"} %ptr : memref +} + +func.func @test_callsite() { + %ptr = memref.alloc() : memref + %c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag_name = "a"} : memref + %0 = func.call @single_callsite_fn(%ptr) : (memref) -> memref + return +} + +// CHECK-LABEL: test_tag: test_return_site +// CHECK: operand #0 +// CHECK-NEXT: - b +func.func private @single_return_site_fn(%ptr: memref) -> memref { + %c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag_name = "b"} : memref + return %ptr : memref +} + +func.func @test_return_site(%ptr: memref) -> memref { + %0 = func.call @single_return_site_fn(%ptr) : (memref) -> memref + return {tag = "test_return_site"} %0 : memref +} + +// CHECK-LABEL: test_tag: test_multiple_callsites +// CHECK: operand #0 +// CHECK-NEXT: - write0 +// CHECK-NEXT: - write1 +func.func private @multiple_callsite_fn(%ptr: memref) -> memref { + return {tag = "test_multiple_callsites"} %ptr : memref +} + +func.func @test_multiple_callsites(%a: i32, %ptr: memref) -> memref { + memref.store %a, %ptr[] {tag_name = "write0"} : memref + %0 = func.call @multiple_callsite_fn(%ptr) : (memref) -> memref + memref.store %a, %ptr[] {tag_name = "write1"} : memref + %1 = func.call
@multiple_callsite_fn(%ptr) : (memref) -> memref + return %ptr : memref +} + +// CHECK-LABEL: test_tag: test_multiple_return_sites +// CHECK: operand #0 +// CHECK-NEXT: - return0 +// CHECK-NEXT: - return1 +func.func private @multiple_return_site_fn(%cond: i1, %a: i32, %ptr: memref) -> memref { + cf.cond_br %cond, ^a, ^b + +^a: + memref.store %a, %ptr[] {tag_name = "return0"} : memref + return %ptr : memref + +^b: + memref.store %a, %ptr[] {tag_name = "return1"} : memref + return %ptr : memref +} + +func.func @test_multiple_return_sites(%cond: i1, %a: i32, %ptr: memref) -> memref { + %0 = func.call @multiple_return_site_fn(%cond, %a, %ptr) : (i1, i32, memref) -> memref + return {tag = "test_multiple_return_sites"} %0 : memref +} \ No newline at end of file diff --git a/mlir/test/Analysis/test-last-modified.mlir b/mlir/test/Analysis/test-last-modified.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Analysis/test-last-modified.mlir @@ -0,0 +1,115 @@ +// RUN: mlir-opt -test-last-modified %s 2>&1 | FileCheck %s + +// CHECK-LABEL: test_tag: test_simple_mod +// CHECK: operand #0 +// CHECK-NEXT: - a +// CHECK: operand #1 +// CHECK-NEXT: - b +func.func @test_simple_mod(%arg0: memref, %arg1: memref) -> (memref, memref) { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + memref.store %c0, %arg0[] {tag_name = "a"} : memref + memref.store %c1, %arg1[] {tag_name = "b"} : memref + return {tag = "test_simple_mod"} %arg0, %arg1 : memref, memref +} + +// CHECK-LABEL: test_tag: test_simple_mod_overwrite_a +// CHECK: operand #1 +// CHECK-NEXT: - a +// CHECK-LABEL: test_tag: test_simple_mod_overwrite_b +// CHECK: operand #0 +// CHECK-NEXT: - b +func.func @test_simple_mod_overwrite(%arg0: memref) -> memref { + %c0 = arith.constant 0 : i32 + memref.store %c0, %arg0[] {tag = "test_simple_mod_overwrite_a", tag_name = "a"} : memref + %c1 = arith.constant 1 : i32 + memref.store %c1, %arg0[] {tag_name = "b"} : memref + return {tag = "test_simple_mod_overwrite_b"} %arg0 :
memref +} + +// CHECK-LABEL: test_tag: test_mod_control_flow +// CHECK: operand #0 +// CHECK-NEXT: - b +// CHECK-NEXT: - a +func.func @test_mod_control_flow(%cond: i1, %ptr: memref) -> memref { + cf.cond_br %cond, ^a, ^b + +^a: + %c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag_name = "a"} : memref + cf.br ^c + +^b: + %c1 = arith.constant 1 : i32 + memref.store %c1, %ptr[] {tag_name = "b"} : memref + cf.br ^c + +^c: + return {tag = "test_mod_control_flow"} %ptr : memref +} + +// CHECK-LABEL: test_tag: test_mod_dead_branch +// CHECK: operand #0 +// CHECK-NEXT: - a +func.func @test_mod_dead_branch(%arg: i32, %ptr: memref) -> memref { + %0 = arith.subi %arg, %arg : i32 + %1 = arith.constant -1 : i32 + %2 = arith.cmpi sgt, %0, %1 : i32 + cf.cond_br %2, ^a, ^b + +^a: + %c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag_name = "a"} : memref + cf.br ^c + +^b: + %c1 = arith.constant 1 : i32 + memref.store %c1, %ptr[] {tag_name = "b"} : memref + cf.br ^c + +^c: + return {tag = "test_mod_dead_branch"} %ptr : memref +} + +// CHECK-LABEL: test_tag: test_mod_region_control_flow +// CHECK: operand #0 +// CHECK-NEXT: - then +// CHECK-NEXT: - else +func.func @test_mod_region_control_flow(%cond: i1, %ptr: memref) -> memref { + scf.if %cond { + %c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag_name = "then"}: memref + } else { + %c1 = arith.constant 1 : i32 + memref.store %c1, %ptr[] {tag_name = "else"} : memref + } + return {tag = "test_mod_region_control_flow"} %ptr : memref +} + +// CHECK-LABEL: test_tag: test_mod_dead_region +// CHECK: operand #0 +// CHECK-NEXT: - else +func.func @test_mod_dead_region(%ptr: memref) -> memref { + %false = arith.constant false + scf.if %false { + %c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag_name = "then"}: memref + } else { + %c1 = arith.constant 1 : i32 + memref.store %c1, %ptr[] {tag_name = "else"} : memref + } + return {tag = "test_mod_dead_region"} %ptr : memref +} + +// CHECK-LABEL: test_tag:
unknown_memory_effects_a +// CHECK: operand #1 +// CHECK-NEXT: - a +// CHECK-LABEL: test_tag: unknown_memory_effects_b +// CHECK: operand #0 +// CHECK-NEXT: - +func.func @unknown_memory_effects(%ptr: memref) -> memref { + %c0 = arith.constant 0 : i32 + memref.store %c0, %ptr[] {tag = "unknown_memory_effects_a", tag_name = "a"} : memref + "test.unknown_effects"() : () -> () + return {tag = "unknown_memory_effects_b"} %ptr : memref +} diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt --- a/mlir/test/lib/Analysis/CMakeLists.txt +++ b/mlir/test/lib/Analysis/CMakeLists.txt @@ -5,6 +5,7 @@ TestDataFlow.cpp TestDataFlowFramework.cpp TestDeadCodeAnalysis.cpp + TestDenseDataFlowAnalysis.cpp TestLiveness.cpp TestMatchReduction.cpp TestMemRefBoundCheck.cpp diff --git a/mlir/test/lib/Analysis/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/TestDenseDataFlowAnalysis.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Analysis/TestDenseDataFlowAnalysis.cpp @@ -0,0 +1,274 @@ +//===- TestDenseDataFlowAnalysis.cpp - Test dense data-flow analysis ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DenseDataFlowAnalysis.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +/// This lattice represents a single underlying value for an SSA value. +class UnderlyingValue { +public: + /// The pessimistic underlying value of a value is itself. + static UnderlyingValue getPessimisticValueState(Value value) { + return {value}; + } + + /// Create an underlying value state with a known underlying value.
+ UnderlyingValue(Value underlyingValue = {}) + : underlyingValue(underlyingValue) {} + + /// Returns the underlying value. + Value getUnderlyingValue() const { return underlyingValue; } + + /// Join two underlying values. If there are conflicting underlying values, + /// the result holds a null underlying value. + static UnderlyingValue join(const UnderlyingValue &lhs, + const UnderlyingValue &rhs) { + return lhs.underlyingValue == rhs.underlyingValue ? lhs : UnderlyingValue(); + } + + /// Compare underlying values. + bool operator==(const UnderlyingValue &rhs) const { + return underlyingValue == rhs.underlyingValue; + } + + void print(raw_ostream &os) const { os << underlyingValue; } + +private: + Value underlyingValue; +}; + +/// This lattice represents, for a given memory resource, the potential last +/// operations that modified the resource. +class LastModification : public AbstractDenseLattice { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification) + + using AbstractDenseLattice::AbstractDenseLattice; + + /// The lattice is always initialized. + bool isUninitialized() const override { return false; } + + /// Initialize the lattice. Does nothing. + ChangeResult defaultInitialize() override { return ChangeResult::NoChange; } + + /// Mark the lattice as having reached its pessimistic fixpoint. That is, the + /// last modifications of all memory resources are unknown. + ChangeResult reset() override { + if (lastMods.empty()) + return ChangeResult::NoChange; + lastMods.clear(); + return ChangeResult::Change; + } + + /// The lattice is never at a fixpoint. + bool isAtFixpoint() const override { return false; } + + /// Join the last modifications.
+ ChangeResult join(const AbstractDenseLattice &lattice) override { + const auto &rhs = static_cast(lattice); + ChangeResult result = ChangeResult::NoChange; + for (const auto &mod : rhs.lastMods) { + auto &lhsMod = lastMods[mod.first]; + // Only report a change if a modifier was actually added. + for (Operation *op : mod.second) + if (lhsMod.insert(op)) + result |= ChangeResult::Change; + } + return result; + } + + /// Set the last modification of a value. + ChangeResult set(Value value, Operation *op) { + auto &lastMod = lastMods[value]; + ChangeResult result = ChangeResult::NoChange; + if (lastMod.size() != 1 || *lastMod.begin() != op) { + result = ChangeResult::Change; + lastMod.clear(); + lastMod.insert(op); + } + return result; + } + + /// Get the last modifications of a value. Returns none if the last + /// modifications are not known. + Optional> getLastModifiers(Value value) const { + auto it = lastMods.find(value); + if (it == lastMods.end()) + return {}; + return it->second.getArrayRef(); + } + + void print(raw_ostream &os) const override { + for (const auto &lastMod : lastMods) { + os << lastMod.first << ":\n"; + for (Operation *op : lastMod.second) + os << " " << *op << "\n"; + } + } + +private: + /// The potential last modifications of a memory resource. Use a set vector to + /// keep the results deterministic. + DenseMap, + SmallPtrSet>> + lastMods; +}; + +class LastModifiedAnalysis : public DenseDataFlowAnalysis { +public: + using DenseDataFlowAnalysis::DenseDataFlowAnalysis; + + /// Visit an operation. If the operation has no memory effects, then the state + /// is propagated with no change. Each non-read effect on a value makes this + /// operation the sole last modifier of that value's underlying value. Unknown + /// effects, or effects not on a value, reset the state. + void visitOperation(Operation *op, const LastModification &before, + LastModification *after) override; +}; + +/// Define the lattice class explicitly to provide a type ID.
+struct UnderlyingValueLattice : public Lattice { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice) + using Lattice::Lattice; +}; + +/// An analysis that uses forwarding of values along control-flow and callgraph +/// edges to determine single underlying values for block arguments. This +/// analysis exists so that the test analysis and pass can test the behaviour of +/// the dense data-flow analysis on the callgraph. +class UnderlyingValueAnalysis + : public SparseDataFlowAnalysis { +public: + using SparseDataFlowAnalysis::SparseDataFlowAnalysis; + + /// The underlying values of an operation's results are not known. + void visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) override { + markAllPessimisticFixpoint(results); + } +}; +} // end anonymous namespace + +/// Follow underlying values to the root; returns null if any is unknown. +static Value getMostUnderlyingValue( + Value value, + function_ref getUnderlyingValueFn) { + const UnderlyingValueLattice *underlying; + do { + underlying = getUnderlyingValueFn(value); + if (!underlying || underlying->isUninitialized()) + return {}; + Value underlyingValue = underlying->getValue().getUnderlyingValue(); + if (underlyingValue == value) + break; + value = underlyingValue; + } while (true); + return value; +} + +void LastModifiedAnalysis::visitOperation(Operation *op, + const LastModification &before, + LastModification *after) { + auto memory = dyn_cast(op); + // If we can't reason about the memory effects, then conservatively assume we + // can't deduce anything about the last modifications. + if (!memory) + return reset(after); + + SmallVector effects; + memory.getEffects(effects); + + ChangeResult result = after->join(before); + for (const auto &effect : effects) { + Value value = effect.getValue(); + + // If we see an effect on anything other than a value, assume we can't + // deduce anything about the last modifications.
+ if (!value) + return reset(after); + + value = getMostUnderlyingValue(value, [&](Value value) { + return getOrCreateFor(op, value); + }); + if (!value) + return propagateIfChanged(after, result); + + // Nothing to do for reads. + if (isa(effect.getEffect())) + continue; + + result |= after->set(value, op); + } + propagateIfChanged(after, result); +} + +namespace { +struct TestLastModifiedPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLastModifiedPass) + + StringRef getArgument() const override { return "test-last-modified"; } + + void runOnOperation() override { + Operation *op = getOperation(); + + DataFlowSolver solver; + solver.load(); + solver.load(); + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) + return signalPassFailure(); + + raw_ostream &os = llvm::errs(); + + op->walk([&](Operation *op) { + auto tag = op->getAttrOfType("tag"); + if (!tag) + return; + os << "test_tag: " << tag.getValue() << ":\n"; + const LastModification *lastMods = + solver.lookupState(op); + assert(lastMods && "expected a dense lattice"); + for (auto it : llvm::enumerate(op->getOperands())) { + os << " operand #" << it.index() << "\n"; + Value value = getMostUnderlyingValue(it.value(), [&](Value value) { + return solver.lookupState(value); + }); + assert(value && "expected an underlying value"); + if (Optional> lastMod = + lastMods->getLastModifiers(value)) { + for (Operation *lastModifier : *lastMod) { + if (auto tagName = + lastModifier->getAttrOfType("tag_name")) { + os << " - " << tagName.getValue() << "\n"; + } else { + os << " - " << lastModifier->getName() << "\n"; + } + } + } else { + os << " - \n"; + } + } + }); + } +}; +} // end anonymous namespace + +namespace mlir { +namespace test { +void registerTestLastModifiedPass() { + PassRegistration(); +} +} // end namespace test +} // end namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++
b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -85,6 +85,7 @@ void registerTestGenericIRVisitorsPass(); void registerTestGenericIRVisitorsInterruptPass(); void registerTestInterfaces(); +void registerTestLastModifiedPass(); void registerTestLinalgCodegenStrategy(); void registerTestLinalgElementwiseFusion(); void registerTestLinalgFusionTransforms(); @@ -182,6 +183,7 @@ mlir::test::registerTestIRVisitorsPass(); mlir::test::registerTestGenericIRVisitorsPass(); mlir::test::registerTestInterfaces(); + mlir::test::registerTestLastModifiedPass(); mlir::test::registerTestLinalgCodegenStrategy(); mlir::test::registerTestLinalgElementwiseFusion(); mlir::test::registerTestLinalgFusionTransforms();