diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h @@ -16,6 +16,7 @@ #define MLIR_ANALYSIS_DENSEDATAFLOWANALYSIS_H #include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/IR/SymbolTable.h" namespace mlir { @@ -37,17 +38,24 @@ using AnalysisState::AnalysisState; /// Join the lattice across control-flow or callgraph edges. - virtual ChangeResult join(const AbstractDenseLattice &rhs) = 0; + virtual ChangeResult join(const AbstractDenseLattice &rhs) { + return ChangeResult::NoChange; + } + /// Meet counterpart of `join`, used by backward analyses; no-op by default. + virtual ChangeResult meet(const AbstractDenseLattice &rhs) { + return ChangeResult::NoChange; + } }; //===----------------------------------------------------------------------===// // AbstractDenseDataFlowAnalysis //===----------------------------------------------------------------------===// -/// Base class for dense data-flow analyses. Dense data-flow analysis attaches a -/// lattice between the execution of operations and implements a transfer -/// function from the lattice before each operation to the lattice after. The -/// lattice contains information about the state of the program at that point. +/// Base class for dense (forward) data-flow analyses. Dense data-flow analysis +/// attaches a lattice between the execution of operations and implements a +/// transfer function from the lattice before each operation to the lattice +/// after. The lattice contains information about the state of the program at +/// that point. /// /// In this implementation, a lattice attached to an operation represents the /// state of the program after its execution, and a lattice attached to block @@ -80,7 +88,8 @@ virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0; /// Get the dense lattice after the execution of the given program point and - /// add it as a dependency to a program point.
+ /// indicate that the `dependent` program point must be updated every time + /// `point` is. const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent, ProgramPoint point); @@ -108,7 +117,7 @@ private: /// Visit a block. The state at the start of the block is propagated from - /// control-flow predecessors or callsites + /// control-flow predecessors or callsites. void visitBlock(Block *block); }; @@ -120,7 +129,7 @@ /// after the execution of every operation across the IR by implementing /// transfer functions for operations. /// -/// `StateT` is expected to be a subclass of `AbstractDenseLattice`. +/// `LatticeT` is expected to be a subclass of `AbstractDenseLattice`. template class DenseDataFlowAnalysis : public AbstractDenseDataFlowAnalysis { static_assert( @@ -131,7 +140,8 @@ using AbstractDenseDataFlowAnalysis::AbstractDenseDataFlowAnalysis; /// Visit an operation with the dense lattice before its execution. This - /// function is expected to set the dense lattice after its execution. + /// function is expected to set the dense lattice after its execution and + /// trigger change propagation in case of change. virtual void visitOperation(Operation *op, const LatticeT &before, LatticeT *after) = 0; @@ -157,6 +167,140 @@ } }; +//===----------------------------------------------------------------------===// +// AbstractDenseBackwardDataFlowAnalysis +//===----------------------------------------------------------------------===// + +/// Base class for dense backward dataflow analyses. Such analyses attach a +/// lattice between the execution of operations and implement a transfer +/// function from the lattice after the operation to the lattice before it, thus +/// propagating backward. +/// +/// In this implementation, a lattice attached to an operation represents the +/// state of the program before its execution, and a lattice attached to a block +/// represents the state of the program at the end of the block, i.e., after +/// its terminator.
+class AbstractDenseBackwardDataFlowAnalysis : public DataFlowAnalysis { +public: + /// Construct the analysis in the given solver. Takes a symbol table + /// collection that is used to cache symbol resolution in the interprocedural + /// part of the analysis. The symbol table need not be prefilled. + AbstractDenseBackwardDataFlowAnalysis(DataFlowSolver &solver, + SymbolTableCollection &symbolTable) + : DataFlowAnalysis(solver), symbolTable(symbolTable) {} + + /// Initialize the analysis by visiting every program point whose execution + /// may modify the program state; that is, every operation and block. + LogicalResult initialize(Operation *top) override; + + /// Visit a program point that modifies the state of the program. The state is + /// propagated along control flow directions for branch-, region- and + /// call-based control flow using the respective interfaces. For other + /// operations, the state is propagated using the transfer function + /// (visitOperation). + /// + /// Note: the transfer function is currently *not* invoked for operations with + /// region or call interface, but *is* invoked for block terminators. + LogicalResult visit(ProgramPoint point) override; + +protected: + /// Propagate the dense lattice after the execution of an operation to the + /// lattice before its execution. + virtual void visitOperationImpl(Operation *op, + const AbstractDenseLattice &after, + AbstractDenseLattice *before) = 0; + + /// Get the dense lattice before the execution of the program point. That is, + /// before the execution of the given operation or after the execution of the + /// block. + virtual AbstractDenseLattice *getLattice(ProgramPoint point) = 0; + + /// Get the dense lattice before the execution of the program point `point` + /// and declare that the `dependent` program point must be updated every time + /// `point` is.
+ const AbstractDenseLattice *getLatticeFor(ProgramPoint dependent, + ProgramPoint point); + + /// Set the dense lattice at the control flow exit point and propagate + /// the update if it changed. + virtual void setToExitState(AbstractDenseLattice *lattice) = 0; + + /// Meet a lattice with another lattice and propagate an update if it changed. + void meet(AbstractDenseLattice *lhs, const AbstractDenseLattice &rhs) { + propagateIfChanged(lhs, lhs->meet(rhs)); + } + + /// Visit an operation. If this is a call operation or region control-flow + /// operation, then the state before the execution of the operation is set by + /// control-flow or the callgraph. Otherwise, this function invokes the + /// operation transfer function. + virtual void processOperation(Operation *op); + + /// Visit a program point within a region branch operation with successors + /// (from which the state is propagated) in or after it. `regionNo` indicates + /// the region that contains the successor, `nullopt` indicating the successor + /// of the branch operation itself. + void visitRegionBranchOperation(ProgramPoint point, + RegionBranchOpInterface branch, + std::optional regionNo, + AbstractDenseLattice *before); + +private: + /// Visit a block. The state at the end of the block is propagated from + /// control-flow successors of the block or callsites. + void visitBlock(Block *block); + + /// Symbol table for call-level control flow. + SymbolTableCollection &symbolTable; +}; + +//===----------------------------------------------------------------------===// +// DenseBackwardDataFlowAnalysis +//===----------------------------------------------------------------------===// + +/// A dense backward dataflow analysis propagating lattices after and before the +/// execution of every operation across the IR by implementing transfer +/// functions for operations. +/// +/// `LatticeT` is expected to be a subclass of `AbstractDenseLattice`.
+template +class DenseBackwardDataFlowAnalysis + : public AbstractDenseBackwardDataFlowAnalysis { + static_assert(std::is_base_of_v, + "analysis state expected to subclass AbstractDenseLattice"); + +public: + using AbstractDenseBackwardDataFlowAnalysis:: + AbstractDenseBackwardDataFlowAnalysis; + + /// Transfer function. Visits an operation with the dense lattice after its + /// execution. This function is expected to set the dense lattice before its + /// execution and trigger propagation in case of change. + virtual void visitOperation(Operation *op, const LatticeT &after, + LatticeT *before) = 0; + +protected: + /// Get the dense lattice at the given program point. + LatticeT *getLattice(ProgramPoint point) override { + return getOrCreate(point); + } + + /// Set the dense lattice at the control flow exit point (after the + /// terminator) and propagate an update if it changed. + virtual void setToExitState(LatticeT *lattice) = 0; + void setToExitState(AbstractDenseLattice *lattice) override { + setToExitState(static_cast(lattice)); + } + + /// Type-erased wrapper that converts the abstract dense lattice to a derived + /// lattice and invokes the virtual hooks operating on the derived lattice. + void visitOperationImpl(Operation *op, const AbstractDenseLattice &after, + AbstractDenseLattice *before) override { + visitOperation(op, static_cast(after), + static_cast(before)); + } +}; + } // end namespace dataflow } // end namespace mlir diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -19,6 +19,7 @@ #include "mlir/IR/Operation.h" #include "mlir/Support/StorageUniquer.h" #include "llvm/ADT/SetVector.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/TypeName.h" #include @@ -181,7 +182,7 @@ /// The general data-flow analysis solver.
This class is responsible for /// orchestrating child data-flow analyses, running the fixed-point iteration /// algorithm, managing analysis state and program point memory, and tracking -/// dependencies beteen analyses, program points, and analysis states. +/// dependencies between analyses, program points, and analysis states. /// /// Steps to run a data-flow analysis: /// @@ -282,11 +283,12 @@ /// Create the analysis state at the given program point. AnalysisState(ProgramPoint point) : point(point) {} - /// Returns the program point this static is located at. + /// Returns the program point this state is located at. ProgramPoint getPoint() const { return point; } /// Print the contents of the analysis state. virtual void print(raw_ostream &os) const = 0; + LLVM_DUMP_METHOD void dump() const; /// Add a dependency to this analysis state on a program point and an /// analysis. If this state is updated, the analysis will be invoked on the @@ -378,7 +380,7 @@ /// dependents are placed on the worklist. /// /// The dependency graph does not need to be static. Each invocation of - /// `visit` can add new dependencies, but these dependecies will not be + /// `visit` can add new dependencies, but these dependencies will not be /// dynamically added to the worklist because the solver doesn't know what /// will provide a value for then. virtual LogicalResult visit(ProgramPoint point) = 0; @@ -403,7 +405,7 @@ return solver.getProgramPoint(std::forward(args)...); } - /// Get the analysis state assiocated with the program point. The returned + /// Get the analysis state associated with the program point. The returned /// state is expected to be "write-only", and any updates need to be /// propagated by `propagateIfChanged`.
template diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -163,3 +163,199 @@ addDependency(state, dependent); return state; } + +//===----------------------------------------------------------------------===// +// AbstractDenseBackwardDataFlowAnalysis +//===----------------------------------------------------------------------===// + +LogicalResult +AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) { + // Visit every operation and block. + processOperation(top); + for (Region &region : top->getRegions()) { + for (Block &block : region) { + visitBlock(&block); + for (Operation &op : llvm::reverse(block)) { + if (failed(initialize(&op))) + return failure(); + } + } + } + return success(); +} + +LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint point) { + if (auto *op = llvm::dyn_cast_if_present(point)) + processOperation(op); + else if (auto *block = llvm::dyn_cast_if_present(point)) + visitBlock(block); + else + return failure(); + return success(); +} + +void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) { + // If the containing block is not executable, bail out. + if (!getOrCreateFor(op, op->getBlock())->isLive()) + return; + + // Get the dense lattice to update. + AbstractDenseLattice *before = getLattice(op); + + // If the op implements region control flow, then the interface specifies the + // control flow. + // TODO: this is not always true, e.g. linalg.generic, but is implemented this + // way for consistency with the dense forward analysis.
+ if (auto branch = dyn_cast(op)) + return visitRegionBranchOperation(op, branch, std::nullopt, before); + + // If the op is a call-like, do inter-procedural data flow as follows: + // + // - find the callable (resolve via the symbol table), + // - get the entry block of the callable region, + // - take the state before the first operation if present or at block end + // otherwise, + // - meet that state with the state before the call-like op. + if (auto call = dyn_cast(op)) { + Operation *callee = call.resolveCallable(&symbolTable); + // resolveCallable may return null (e.g. unknown symbol or indirect call). + if (auto callable = + llvm::dyn_cast_if_present<CallableOpInterface>(callee)) { + Region *region = callable.getCallableRegion(); + if (region && !region->empty()) { + Block *entryBlock = &region->front(); + if (entryBlock->empty()) + meet(before, *getLatticeFor(op, entryBlock)); + else + meet(before, *getLatticeFor(op, &entryBlock->front())); + } else { + setToExitState(before); + } + } else { + setToExitState(before); + } + return; + } + + // Get the dense state after execution of this op. + const AbstractDenseLattice *after; + if (Operation *next = op->getNextNode()) + after = getLatticeFor(op, next); + else + after = getLatticeFor(op, op->getBlock()); + + // Invoke the operation transfer function. + visitOperationImpl(op, *after, before); +} + +void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { + // If the block is not executable, bail out. + if (!getOrCreateFor(block, block)->isLive()) + return; + + AbstractDenseLattice *before = getLattice(block); + + // We need "exit" blocks, i.e. the blocks that may return control to the + // parent operation. + auto isExitBlock = [](Block *b) { + // Treat empty and terminator-less blocks as exit blocks. + if (b->empty() || !b->back().mightHaveTrait()) + return true; + + // There may be a weird case where a terminator may be transferring control + // either to the parent or to another block, so exit blocks and successors + // are not mutually exclusive.
+ Operation *terminator = b->getTerminator(); + return terminator && (terminator->hasTrait() || + isa(terminator)); + }; + if (isExitBlock(block)) { + // If this block is exiting from a callable, the successors of exiting from + // a callable are the successors of all call sites. And the call sites + // themselves are predecessors of the callable. + auto callable = dyn_cast(block->getParentOp()); + if (callable && callable.getCallableRegion() == block->getParent()) { + const auto *callsites = getOrCreateFor(block, callable); + // If not all call sites are known, conservatively mark all lattices as + // having reached their pessimistic fix points. + if (!callsites->allPredecessorsKnown()) + return setToExitState(before); + + for (Operation *callsite : callsites->getKnownPredecessors()) { + if (Operation *next = callsite->getNextNode()) + meet(before, *getLatticeFor(block, next)); + else + meet(before, *getLatticeFor(block, callsite->getBlock())); + } + return; + } + + // If this block is exiting from an operation with region-based control + // flow, follow that flow. + if (auto branch = dyn_cast(block->getParentOp())) { + visitRegionBranchOperation(block, branch, + block->getParent()->getRegionNumber(), before); + return; + } + + // Cannot reason about successors of an exit block, set the pessimistic + // fixpoint. + return setToExitState(before); + } + + // Meet the state with the state before block's successors. + for (Block *successor : block->getSuccessors()) { + if (!getOrCreateFor(block, + getProgramPoint(block, successor)) + ->isLive()) + continue; + + // Merge in the state from the successor: either the first operation, or the + // block itself when empty.
+ if (successor->empty()) + meet(before, *getLatticeFor(block, successor)); + else + meet(before, *getLatticeFor(block, &successor->front())); + } +} + +void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation( + ProgramPoint point, RegionBranchOpInterface branch, + std::optional regionNo, AbstractDenseLattice *before) { + + // The successors of the operation may be either the first operation of the + // entry block of each possible successor region, or the next operation after + // the branch when the successor is the parent (or an empty region). + SmallVector successors; + branch.getSuccessorRegions(regionNo, successors); + for (const RegionSuccessor &successor : successors) { + const AbstractDenseLattice *after; + if (successor.isParent() || successor.getSuccessor()->empty()) { + if (Operation *next = branch->getNextNode()) + after = getLatticeFor(point, next); + else + after = getLatticeFor(point, branch->getBlock()); + } else { + Region *successorRegion = successor.getSuccessor(); + assert(!successorRegion->empty() && "unexpected empty successor region"); + Block *successorBlock = &successorRegion->front(); + + if (!getOrCreateFor(point, successorBlock)->isLive()) + continue; + + if (successorBlock->empty()) + after = getLatticeFor(point, successorBlock); + else + after = getLatticeFor(point, &successorBlock->front()); + } + meet(before, *after); + } +} + +const AbstractDenseLattice * +AbstractDenseBackwardDataFlowAnalysis::getLatticeFor(ProgramPoint dependent, + ProgramPoint point) { + AbstractDenseLattice *state = getLattice(point); + addDependency(state, dependent); + return state; +} diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -43,6 +43,8 @@ }); } +void AnalysisState::dump() const { print(llvm::errs()); } + //===----------------------------------------------------------------------===// // ProgramPoint
//===----------------------------------------------------------------------===// @@ -55,9 +57,9 @@ if (auto *programPoint = llvm::dyn_cast(*this)) return programPoint->print(os); if (auto *op = llvm::dyn_cast(*this)) - return op->print(os); + return op->print(os, OpPrintingFlags().skipRegions()); if (auto value = llvm::dyn_cast(*this)) - return value.print(os); + return value.print(os, OpPrintingFlags().skipRegions()); return get()->print(os); }
diff --git a/mlir/test/Analysis/DataFlow/test-next-access.mlir b/mlir/test/Analysis/DataFlow/test-next-access.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Analysis/DataFlow/test-next-access.mlir @@ -0,0 +1,359 @@ +// RUN: mlir-opt %s --test-next-access --split-input-file | FileCheck %s + +// CHECK-LABEL: @trivial +func.func @trivial(%arg0: memref, %arg1: f32) -> f32 { + // CHECK: name = "store" + // CHECK-SAME: next_access = ["unknown", ["load"]] + memref.store %arg1, %arg0[] {name = "store"} : memref + // CHECK: name = "load" + // CHECK-SAME: next_access = ["unknown"] + %0 = memref.load %arg0[] {name = "load"} : memref + return %0 : f32 +} + +// CHECK-LABEL: @chain +func.func @chain(%arg0: memref, %arg1: f32) -> f32 { + // CHECK: name = "store" + // CHECK-SAME: next_access = ["unknown", ["load 1"]] + memref.store %arg1, %arg0[] {name = "store"} : memref + // CHECK: name = "load 1" + // CHECK-SAME: next_access = {{\[}}["load 2"]] + %0 = memref.load %arg0[] {name = "load 1"} : memref + // CHECK: name = "load 2" + // CHECK-SAME: next_access = ["unknown"] + %1 = memref.load %arg0[] {name = "load 2"} : memref + %2 = arith.addf %0, %1 : f32 + return %2 : f32 +} + +// CHECK-LABEL: @branch +func.func @branch(%arg0: memref, %arg1: f32, %arg2: i1) -> f32 { + // CHECK: name = "store" + // CHECK-SAME: next_access = ["unknown", ["load 1", "load 2"]] + memref.store %arg1, %arg0[] {name = "store"} : memref + cf.cond_br %arg2, ^bb0, ^bb1 + +^bb0: + %0 = memref.load %arg0[] {name = "load 1"} : memref + cf.br ^bb2(%0 : f32) + +^bb1: + %1 = memref.load %arg0[] {name = "load 2"} : memref + cf.br ^bb2(%1 : f32) + +^bb2(%phi: f32): + return %phi : f32 +} + +// CHECK-LABEL: @dead_branch +func.func @dead_branch(%arg0: memref, %arg1: f32) -> f32 { + // CHECK: name = "store" + // CHECK-SAME: next_access = ["unknown", ["load 2"]] +
memref.store %arg1, %arg0[] {name = "store"} : memref + cf.br ^bb1 + +^bb0: + // CHECK: name = "load 1" + // CHECK-SAME: next_access = "not computed" + %0 = memref.load %arg0[] {name = "load 1"} : memref + cf.br ^bb2(%0 : f32) + +^bb1: + %1 = memref.load %arg0[] {name = "load 2"} : memref + cf.br ^bb2(%1 : f32) + +^bb2(%phi: f32): + return %phi : f32 +} + +// CHECK-LABEL: @loop +func.func @loop(%arg0: memref, %arg1: f32, %arg2: index, %arg3: index, %arg4: index) -> f32 { + %c0 = arith.constant 0.0 : f32 + // CHECK: name = "pre" + // CHECK-SAME: next_access = {{\[}}["loop"], "unknown"] + // The above is not entirely correct when the loop has 0 iterations, but + // the region control flow specification is currently incapable of + // specifying that. + memref.load %arg0[%arg4] {name = "pre"} : memref + %l = scf.for %i = %arg2 to %arg3 step %arg4 iter_args(%ia = %c0) -> (f32) { + // CHECK: name = "loop" + // CHECK-SAME: next_access = {{\[}}["outside", "loop"], "unknown"] + %0 = memref.load %arg0[%i] {name = "loop"} : memref + %1 = arith.addf %ia, %0 : f32 + scf.yield %1 : f32 + } + %v = memref.load %arg0[%arg3] {name = "outside"} : memref + %2 = arith.addf %v, %l : f32 + return %2 : f32 +} + +// CHECK-LABEL: @conditional +func.func @conditional(%cond: i1, %arg0: memref) { + // CHECK: name = "pre" + // CHECK-SAME: next_access = {{\[}}["then"]] + // The above is not entirely correct when the condition is false, but + // the region control flow specification is currently incapable of + // specifying that.
+ memref.load %arg0[] {name = "pre"}: memref + scf.if %cond { + // CHECK: name = "then" + // CHECK-SAME: next_access = {{\[}}["post"]] + memref.load %arg0[] {name = "then"} : memref + } + memref.load %arg0[] {name = "post"} : memref + return +} + +// CHECK-LABEL: @two_sided_conditional +func.func @two_sided_conditional(%cond: i1, %arg0: memref) { + // CHECK: name = "pre" + // CHECK-SAME: next_access = {{\[}}["then", "else"]] + memref.load %arg0[] {name = "pre"}: memref + scf.if %cond { + // CHECK: name = "then" + // CHECK-SAME: next_access = {{\[}}["post"]] + memref.load %arg0[] {name = "then"} : memref + } else { + // CHECK: name = "else" + // CHECK-SAME: next_access = {{\[}}["post"]] + memref.load %arg0[] {name = "else"} : memref + } + memref.load %arg0[] {name = "post"} : memref + return +} + +// CHECK-LABEL: @dead_conditional +func.func @dead_conditional(%arg0: memref) { + %false = arith.constant 0 : i1 + // CHECK: name = "pre" + // CHECK-SAME: next_access = ["unknown"] + // The above is not entirely correct: the condition is always false, so the + // next access is "post", but the region control flow specification is + // currently incapable of specifying that.
+ memref.load %arg0[] {name = "pre"}: memref + scf.if %false { + // CHECK: name = "then" + // CHECK-SAME: next_access = "not computed" + memref.load %arg0[] {name = "then"} : memref + } + memref.load %arg0[] {name = "post"} : memref + return +} + +// CHECK-LABEL: @known_conditional +func.func @known_conditional(%arg0: memref) { + %false = arith.constant 0 : i1 + // CHECK: name = "pre" + // CHECK-SAME: next_access = {{\[}}["else"]] + memref.load %arg0[] {name = "pre"}: memref + scf.if %false { + // CHECK: name = "then" + // CHECK-SAME: next_access = "not computed" + memref.load %arg0[] {name = "then"} : memref + } else { + // CHECK: name = "else" + // CHECK-SAME: next_access = {{\[}}["post"]] + memref.load %arg0[] {name = "else"} : memref + } + memref.load %arg0[] {name = "post"} : memref + return +} + +// CHECK-LABEL: @loop_cf +func.func @loop_cf(%arg0: memref, %arg1: f32, %arg2: index, %arg3: index, %arg4: index) -> f32 { + %cst = arith.constant 0.000000e+00 : f32 + // CHECK: name = "pre" + // CHECK-SAME: next_access = {{\[}}["loop", "outside"], "unknown"] + %0 = memref.load %arg0[%arg4] {name = "pre"} : memref + cf.br ^bb1(%arg2, %cst : index, f32) +^bb1(%1: index, %2: f32): + %3 = arith.cmpi slt, %1, %arg3 : index + cf.cond_br %3, ^bb2, ^bb3 +^bb2: + // CHECK: name = "loop" + // CHECK-SAME: next_access = {{\[}}["loop", "outside"], "unknown"] + %4 = memref.load %arg0[%1] {name = "loop"} : memref + %5 = arith.addf %2, %4 : f32 + %6 = arith.addi %1, %arg4 : index + cf.br ^bb1(%6, %5 : index, f32) +^bb3: + %7 = memref.load %arg0[%arg3] {name = "outside"} : memref + %8 = arith.addf %7, %2 : f32 + return %8 : f32 +} + +// CHECK-LABEL: @conditional_cf +func.func @conditional_cf(%arg0: i1, %arg1: memref) { + // CHECK: name = "pre" + // CHECK-SAME: next_access = {{\[}}["then", "post"]] + %0 = memref.load %arg1[] {name = "pre"} : memref + cf.cond_br %arg0, ^bb1, ^bb2 +^bb1: + // CHECK: name = "then" + // CHECK-SAME: next_access = {{\[}}["post"]] + %1 = memref.load %arg1[]
{name = "then"} : memref + cf.br ^bb2 +^bb2: + %2 = memref.load %arg1[] {name = "post"} : memref + return +} + +// CHECK-LABEL: @two_sided_conditional_cf +func.func @two_sided_conditional_cf(%arg0: i1, %arg1: memref) { + // CHECK: name = "pre" + // CHECK-SAME: next_access = {{\[}}["then", "else"]] + %0 = memref.load %arg1[] {name = "pre"} : memref + cf.cond_br %arg0, ^bb1, ^bb2 +^bb1: + // CHECK: name = "then" + // CHECK-SAME: next_access = {{\[}}["post"]] + %1 = memref.load %arg1[] {name = "then"} : memref + cf.br ^bb3 +^bb2: + // CHECK: name = "else" + // CHECK-SAME: next_access = {{\[}}["post"]] + %2 = memref.load %arg1[] {name = "else"} : memref + cf.br ^bb3 +^bb3: + %3 = memref.load %arg1[] {name = "post"} : memref + return +} + +// CHECK-LABEL: @dead_conditional_cf +func.func @dead_conditional_cf(%arg0: memref) { + %false = arith.constant false + // CHECK: name = "pre" + // CHECK-SAME: next_access = {{\[}}["post"]] + %0 = memref.load %arg0[] {name = "pre"} : memref + cf.cond_br %false, ^bb1, ^bb2 +^bb1: + // CHECK: name = "then" + // CHECK-SAME: next_access = "not computed" + %1 = memref.load %arg0[] {name = "then"} : memref + cf.br ^bb2 +^bb2: + %2 = memref.load %arg0[] {name = "post"} : memref + return +} + +// CHECK-LABEL: @known_conditional_cf +func.func @known_conditional_cf(%arg0: memref) { + %false = arith.constant false + // CHECK: name = "pre" + // CHECK-SAME: next_access = {{\[}}["else"]] + %0 = memref.load %arg0[] {name = "pre"} : memref + cf.cond_br %false, ^bb1, ^bb2 +^bb1: + // CHECK: name = "then" + // CHECK-SAME: next_access = "not computed" + %1 = memref.load %arg0[] {name = "then"} : memref + cf.br ^bb3 +^bb2: + // CHECK: name = "else" + // CHECK-SAME: next_access = {{\[}}["post"]] + %2 = memref.load %arg0[] {name = "else"} : memref + cf.br ^bb3 +^bb3: + %3 = memref.load %arg0[] {name = "post"} : memref + return +} + +// ----- + +func.func private @callee1(%arg0: memref) { + // CHECK: name = "callee1" + // CHECK-SAME: next_access =
{{\[}}["post"]] + memref.load %arg0[] {name = "callee1"} : memref + return +} + +func.func private @callee2(%arg0: memref) { + // CHECK: name = "callee2" + // CHECK-SAME: next_access = "not computed" + memref.load %arg0[] {name = "callee2"} : memref + return +} + +// CHECK-LABEL: @simple_call +func.func @simple_call(%arg0: memref) { + // CHECK: name = "caller" + // CHECK-SAME: next_access = {{\[}}["callee1"]] + memref.load %arg0[] {name = "caller"} : memref + func.call @callee1(%arg0) : (memref) -> () + memref.load %arg0[] {name = "post"} : memref + return +} + +// ----- + +// CHECK-LABEL: @infinite_recursive_call +func.func @infinite_recursive_call(%arg0: memref) { + // CHECK: name = "pre" + // CHECK-SAME: next_access = {{\[}}["pre"]] + memref.load %arg0[] {name = "pre"} : memref + func.call @infinite_recursive_call(%arg0) : (memref) -> () + memref.load %arg0[] {name = "post"} : memref + return +} + +// ----- + +// CHECK-LABEL: @recursive_call +func.func @recursive_call(%arg0: memref, %cond: i1) { + // CHECK: name = "pre" + // CHECK-SAME: next_access = {{\[}}["pre"]] + // The above is not entirely correct when the condition is false, but + // the region control flow specification is currently incapable of + // specifying that.
+ memref.load %arg0[] {name = "pre"} : memref + scf.if %cond { + func.call @recursive_call(%arg0, %cond) : (memref, i1) -> () + } + memref.load %arg0[] {name = "post"} : memref + return +} + +// ----- + +// CHECK-LABEL: @recursive_call_cf +func.func @recursive_call_cf(%arg0: memref, %cond: i1) { + // CHECK: name = "pre" + // CHECK-SAME: next_access = {{\[}}["pre", "post"]] + %0 = memref.load %arg0[] {name = "pre"} : memref + cf.cond_br %cond, ^bb1, ^bb2 +^bb1: + call @recursive_call_cf(%arg0, %cond) : (memref, i1) -> () + cf.br ^bb2 +^bb2: + %2 = memref.load %arg0[] {name = "post"} : memref + return +} + +// ----- + +func.func private @callee1(%arg0: memref) { + // CHECK: name = "callee1" + // CHECK-SAME: next_access = {{\[}}["post"]] + memref.load %arg0[] {name = "callee1"} : memref + return +} + +func.func private @callee2(%arg0: memref) { + // CHECK: name = "callee2" + // CHECK-SAME: next_access = {{\[}}["post"]] + memref.load %arg0[] {name = "callee2"} : memref + return +} + +func.func @conditional_call(%arg0: memref, %cond: i1) { + // CHECK: name = "pre" + // CHECK-SAME: next_access = {{\[}}["callee1", "callee2"]] + memref.load %arg0[] {name = "pre"} : memref + scf.if %cond { + func.call @callee1(%arg0) : (memref) -> () + } else { + func.call @callee2(%arg0) : (memref) -> () + } + memref.load %arg0[] {name = "post"} : memref + return +} diff --git a/mlir/test/lib/Analysis/CMakeLists.txt b/mlir/test/lib/Analysis/CMakeLists.txt --- a/mlir/test/lib/Analysis/CMakeLists.txt +++ b/mlir/test/lib/Analysis/CMakeLists.txt @@ -14,6 +14,7 @@ DataFlow/TestDeadCodeAnalysis.cpp DataFlow/TestDenseDataFlowAnalysis.cpp DataFlow/TestBackwardDataFlowAnalysis.cpp + DataFlow/TestDenseBackwardDataFlowAnalysis.cpp EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp new file mode 100644 --- /dev/null +++
b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp @@ -0,0 +1,180 @@ +//===- TestDenseBackwardDataFlowAnalysis.cpp - Test pass ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Test pass for backward dense dataflow analysis. +// +//===----------------------------------------------------------------------===// + +#include "TestDenseDataFlowAnalysis.h" +#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlow/DenseAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/TypeID.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::dataflow; +using namespace mlir::dataflow::test; + +namespace { + +class NextAccess : public AbstractDenseLattice, public test::AccessLatticeBase { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess) + + using dataflow::AbstractDenseLattice::AbstractDenseLattice; + + ChangeResult meet(const AbstractDenseLattice &lattice) override { + return AccessLatticeBase::merge(static_cast( + static_cast(lattice))); + } + + void print(raw_ostream &os) const override { + return AccessLatticeBase::print(os); + } +}; + +class NextAccessAnalysis : public DenseBackwardDataFlowAnalysis { +public: + using DenseBackwardDataFlowAnalysis::DenseBackwardDataFlowAnalysis; + + void visitOperation(Operation *op, const NextAccess &after, + NextAccess *before) override; + + // TODO: this isn't ideal for the analysis. When there is no next access, it + // means "we don't know what the next access is" rather than "there is no next + // access". But it's unclear how to differentiate the two cases... + void setToExitState(NextAccess *lattice) override { + propagateIfChanged(lattice, lattice->reset()); + } +}; +} // namespace + +void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after, + NextAccess *before) { + auto memory = dyn_cast<MemoryEffectOpInterface>(op); + // If we can't reason about the memory effects, conservatively assume we can't + // say anything about the next access. + if (!memory) + return setToExitState(before); + + SmallVector<MemoryEffects::EffectInstance> effects; + memory.getEffects(effects); + + // Resolve the underlying values of all effects before modifying `before`. If + // one of them is not known yet, return without touching the lattice: this op + // is revisited once the underlying value is computed, whereas a partially + // updated lattice might otherwise never be propagated. + SmallVector<Value> underlyingValues; + underlyingValues.reserve(effects.size()); + for (const MemoryEffects::EffectInstance &effect : effects) { + Value value = effect.getValue(); + + // Effects with unspecified value are treated conservatively and we cannot + // assume anything about the next access. + if (!value) + return setToExitState(before); + + value = UnderlyingValueAnalysis::getMostUnderlyingValue( + value, [&](Value value) { + return getOrCreateFor<UnderlyingValueLattice>(op, value); + }); + if (!value) + return; + + underlyingValues.push_back(value); + } + + // All underlying values are known: merge the state after the op and set this + // op as the next access of every value it touches, then propagate. + ChangeResult result = before->meet(after); + for (Value value : underlyingValues) + result |= before->set(value, op); + propagateIfChanged(before, result); +} + +namespace { +struct TestNextAccessPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass) + + StringRef getArgument() const override { return "test-next-access"; } + + static constexpr llvm::StringLiteral kTagAttrName = "name"; + static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access"; + + void runOnOperation() override { + Operation *op = getOperation(); + SymbolTableCollection symbolTable; + + DataFlowSolver solver; + solver.load(); + solver.load(symbolTable); + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) { + emitError(op->getLoc(), "dataflow solver failed"); + return signalPassFailure(); + } + + op->walk([&](Operation *op) { + auto tag = op->getAttrOfType(kTagAttrName);
When there is no next access, it + // means "we don't know what the next access is" rather than "there is no next + // access". But it's unclear how to differentiate the two cases... + void setToExitState(NextAccess *lattice) override { + propagateIfChanged(lattice, lattice->reset()); + } +}; +} // namespace + +void NextAccessAnalysis::visitOperation(Operation *op, const NextAccess &after, + NextAccess *before) { + auto memory = dyn_cast(op); + // If we can't reason about the memory effects, conservatively assume we can't + // say anything about the next access. + if (!memory) + return setToExitState(before); + + SmallVector effects; + memory.getEffects(effects); + ChangeResult result = before->meet(after); + for (const MemoryEffects::EffectInstance &effect : effects) { + Value value = effect.getValue(); + + // Effects with unspecified value are treated conservatively and we cannot + // assume anything about the next access. + if (!value) + return setToExitState(before); + + value = UnderlyingValueAnalysis::getMostUnderlyingValue( + value, [&](Value value) { + return getOrCreateFor(op, value); + }); + if (!value) + return; + + result |= before->set(value, op); + } + propagateIfChanged(before, result); +} + +namespace { +struct TestNextAccessPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestNextAccessPass) + + StringRef getArgument() const override { return "test-next-access"; } + + static constexpr llvm::StringLiteral kTagAttrName = "name"; + static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access"; + + void runOnOperation() override { + Operation *op = getOperation(); + SymbolTableCollection symbolTable; + + DataFlowSolver solver; + solver.load(); + solver.load(symbolTable); + solver.load(); + solver.load(); + if (failed(solver.initializeAndRun(op))) { + emitError(op->getLoc(), "dataflow solver failed"); + return signalPassFailure(); + } + + op->walk([&](Operation *op) { + auto tag = op->getAttrOfType(kTagAttrName); 
+ if (!tag) + return; + + const NextAccess *nextAccess = solver.lookupState( + op->getNextNode() == nullptr ? ProgramPoint(op->getBlock()) + : op->getNextNode()); + if (!nextAccess) { + op->setAttr(kNextAccessAttrName, + StringAttr::get(op->getContext(), "not computed")); + return; + } + + SmallVector attrs; + for (Value operand : op->getOperands()) { + Value value = UnderlyingValueAnalysis::getMostUnderlyingValue( + operand, [&](Value value) { + return solver.lookupState(value); + }); + std::optional> nextAcc = + nextAccess->getAdjacentAccess(value); + if (!nextAcc) { + attrs.push_back(StringAttr::get(op->getContext(), "unknown")); + continue; + } + + SmallVector innerAttrs; + innerAttrs.reserve(nextAcc->size()); + for (Operation *nextAccOp : *nextAcc) { + if (auto nextAccTag = + nextAccOp->getAttrOfType(kTagAttrName)) { + innerAttrs.push_back(nextAccTag); + continue; + } + std::string repr; + llvm::raw_string_ostream os(repr); + nextAccOp->print(os); + innerAttrs.push_back(StringAttr::get(op->getContext(), os.str())); + } + attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs)); + } + + op->setAttr(kNextAccessAttrName, ArrayAttr::get(op->getContext(), attrs)); + }); + } +}; +} // namespace + +namespace mlir::test { +void registerTestNextAccessPass() { PassRegistration(); } +} // namespace mlir::test diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.h @@ -0,0 +1,171 @@ +//===- TestDenseDataFlowAnalysis.h - Dataflow test utilities ----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Analysis/DataFlow/SparseAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/IR/Value.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/raw_ostream.h" +#include + +namespace mlir { +namespace dataflow { +namespace test { + +/// This lattice represents a single underlying value for an SSA value. +class UnderlyingValue { +public: + /// Create an underlying value state with a known underlying value. + explicit UnderlyingValue(std::optional underlyingValue = std::nullopt) + : underlyingValue(underlyingValue) {} + + /// Whether the state is uninitialized. + bool isUninitialized() const { return !underlyingValue.has_value(); } + + /// Returns the underlying value. + Value getUnderlyingValue() const { + assert(!isUninitialized()); + return *underlyingValue; + } + + /// Join two underlying values. If there are conflicting underlying values, + /// go to the pessimistic value. + static UnderlyingValue join(const UnderlyingValue &lhs, + const UnderlyingValue &rhs) { + if (lhs.isUninitialized()) + return rhs; + if (rhs.isUninitialized()) + return lhs; + return lhs.underlyingValue == rhs.underlyingValue + ? lhs + : UnderlyingValue(Value{}); + } + + /// Compare underlying values. + bool operator==(const UnderlyingValue &rhs) const { + return underlyingValue == rhs.underlyingValue; + } + + void print(raw_ostream &os) const { os << underlyingValue; } + +private: + std::optional underlyingValue; +}; + +/// This lattice represents, for a given memory resource, the potential last +/// operations that modified the resource. +class AccessLatticeBase { +public: + /// Clear all modifications. + ChangeResult reset() { + if (adjAccesses.empty()) + return ChangeResult::NoChange; + adjAccesses.clear(); + return ChangeResult::Change; + } + + /// Join the last modifications. 
+ ChangeResult merge(const AccessLatticeBase &rhs) { + ChangeResult result = ChangeResult::NoChange; + for (const auto &mod : rhs.adjAccesses) { + auto &lhsMod = adjAccesses[mod.first]; + if (lhsMod != mod.second) { + lhsMod.insert(mod.second.begin(), mod.second.end()); + result |= ChangeResult::Change; + } + } + return result; + } + + /// Set the last modification of a value. + ChangeResult set(Value value, Operation *op) { + auto &lastMod = adjAccesses[value]; + ChangeResult result = ChangeResult::NoChange; + if (lastMod.size() != 1 || *lastMod.begin() != op) { + result = ChangeResult::Change; + lastMod.clear(); + lastMod.insert(op); + } + return result; + } + + /// Get the adjacent accesses to a value. Returns std::nullopt if they + /// are not known. + std::optional> getAdjacentAccess(Value value) const { + auto it = adjAccesses.find(value); + if (it == adjAccesses.end()) + return {}; + return it->second.getArrayRef(); + } + + void print(raw_ostream &os) const { + for (const auto &lastMod : adjAccesses) { + os << lastMod.first << ":\n"; + for (Operation *op : lastMod.second) + os << " " << *op << "\n"; + } + } + +private: + /// The potential adjacent accesses to a memory resource. Use a set vector to + /// keep the results deterministic. + DenseMap, + SmallPtrSet>> + adjAccesses; +}; + +/// Define the lattice class explicitly to provide a type ID. +struct UnderlyingValueLattice : public Lattice { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice) + using Lattice::Lattice; +}; + +/// An analysis that uses forwarding of values along control-flow and callgraph +/// edges to determine single underlying values for block arguments. This +/// analysis exists so that the test analysis and pass can test the behaviour of +/// the dense data-flow analysis on the callgraph. 
+class UnderlyingValueAnalysis + : public SparseDataFlowAnalysis { +public: + using SparseDataFlowAnalysis::SparseDataFlowAnalysis; + + /// The underlying value of the results of an operation are not known. + void visitOperation(Operation *op, + ArrayRef operands, + ArrayRef results) override { + setAllToEntryStates(results); + } + + /// At an entry point, the underlying value of a value is itself. + void setToEntryState(UnderlyingValueLattice *lattice) override { + propagateIfChanged(lattice, + lattice->join(UnderlyingValue{lattice->getPoint()})); + } + + /// Look for the most underlying value of a value. + static Value + getMostUnderlyingValue(Value value, + function_ref + getUnderlyingValueFn) { + const UnderlyingValueLattice *underlying; + do { + underlying = getUnderlyingValueFn(value); + if (!underlying || underlying->getValue().isUninitialized()) + return {}; + Value underlyingValue = underlying->getValue().getUnderlyingValue(); + if (underlyingValue == value) + break; + value = underlyingValue; + } while (true); + return value; + } +}; + +} // namespace test +} // namespace dataflow +} // namespace mlir diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp --- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp @@ -1,4 +1,4 @@ -//===- TestDeadCodeAnalysis.cpp - Test dead code analysis -----------------===// +//===- TestDenseDataFlowAnalysis.cpp - Test dense data flow analysis ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,123 +6,39 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "TestDenseDataFlowAnalysis.h"
 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/DenseAnalysis.h"
-#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include <optional>
 
 using namespace mlir;
 using namespace mlir::dataflow;
+using namespace mlir::dataflow::test;
 
 namespace {
-/// This lattice represents a single underlying value for an SSA value.
-class UnderlyingValue {
-public:
-  /// Create an underlying value state with a known underlying value.
-  explicit UnderlyingValue(std::optional<Value> underlyingValue = std::nullopt)
-      : underlyingValue(underlyingValue) {}
-
-  /// Whether the state is uninitialized.
-  bool isUninitialized() const { return !underlyingValue.has_value(); }
-
-  /// Returns the underlying value.
-  Value getUnderlyingValue() const {
-    assert(!isUninitialized());
-    return *underlyingValue;
-  }
-
-  /// Join two underlying values. If there are conflicting underlying values,
-  /// go to the pessimistic value.
-  static UnderlyingValue join(const UnderlyingValue &lhs,
-                              const UnderlyingValue &rhs) {
-    if (lhs.isUninitialized())
-      return rhs;
-    if (rhs.isUninitialized())
-      return lhs;
-    return lhs.underlyingValue == rhs.underlyingValue
-               ? lhs
-               : UnderlyingValue(Value{});
-  }
-
-  /// Compare underlying values.
-  bool operator==(const UnderlyingValue &rhs) const {
-    return underlyingValue == rhs.underlyingValue;
-  }
-
-  void print(raw_ostream &os) const { os << underlyingValue; }
-
-private:
-  std::optional<Value> underlyingValue;
-};
 
 /// This lattice represents, for a given memory resource, the potential last
 /// operations that modified the resource.
-class LastModification : public AbstractDenseLattice {
+/// Access storage, reset, merge and print come from test::AccessLatticeBase.
+class LastModification : public AbstractDenseLattice,
+                         public test::AccessLatticeBase {
 public:
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification)
 
   using AbstractDenseLattice::AbstractDenseLattice;
 
-  /// Clear all modifications.
-  ChangeResult reset() {
-    if (lastMods.empty())
-      return ChangeResult::NoChange;
-    lastMods.clear();
-    return ChangeResult::Change;
-  }
-
   /// Join the last modifications.
   ChangeResult join(const AbstractDenseLattice &lattice) override {
-    const auto &rhs = static_cast<const LastModification &>(lattice);
-    ChangeResult result = ChangeResult::NoChange;
-    for (const auto &mod : rhs.lastMods) {
-      auto &lhsMod = lastMods[mod.first];
-      if (lhsMod != mod.second) {
-        lhsMod.insert(mod.second.begin(), mod.second.end());
-        result |= ChangeResult::Change;
-      }
-    }
-    return result;
-  }
-
-  /// Set the last modification of a value.
-  ChangeResult set(Value value, Operation *op) {
-    auto &lastMod = lastMods[value];
-    ChangeResult result = ChangeResult::NoChange;
-    if (lastMod.size() != 1 || *lastMod.begin() != op) {
-      result = ChangeResult::Change;
-      lastMod.clear();
-      lastMod.insert(op);
-    }
-    return result;
-  }
-
-  /// Get the last modifications of a value. Returns std::nullopt if the last
-  /// modifications are not known.
-  std::optional<ArrayRef<Operation *>> getLastModifiers(Value value) const {
-    auto it = lastMods.find(value);
-    if (it == lastMods.end())
-      return {};
-    return it->second.getArrayRef();
+    return AccessLatticeBase::merge(static_cast<const AccessLatticeBase &>(
+        static_cast<const LastModification &>(lattice)));
   }
 
   void print(raw_ostream &os) const override {
-    for (const auto &lastMod : lastMods) {
-      os << lastMod.first << ":\n";
-      for (Operation *op : lastMod.second)
-        os << " " << *op << "\n";
-    }
+    return AccessLatticeBase::print(os);
   }
-
-private:
-  /// The potential last modifications of a memory resource. Use a set vector to
-  /// keep the results deterministic.
-  DenseMap<Value, SetVector<Operation *, SmallVector<Operation *, 4>,
-                            SmallPtrSet<Operation *, 4>>>
-      lastMods;
 };
 
 class LastModifiedAnalysis : public DenseDataFlowAnalysis<LastModification> {
@@ -142,54 +58,8 @@
     propagateIfChanged(lattice, lattice->reset());
   }
 };
-
-/// Define the lattice class explicitly to provide a type ID.
-struct UnderlyingValueLattice : public Lattice<UnderlyingValue> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnderlyingValueLattice)
-  using Lattice::Lattice;
-};
-
-/// An analysis that uses forwarding of values along control-flow and callgraph
-/// edges to determine single underlying values for block arguments. This
-/// analysis exists so that the test analysis and pass can test the behaviour of
-/// the dense data-flow analysis on the callgraph.
-class UnderlyingValueAnalysis
-    : public SparseDataFlowAnalysis<UnderlyingValueLattice> {
-public:
-  using SparseDataFlowAnalysis::SparseDataFlowAnalysis;
-
-  /// The underlying value of the results of an operation are not known.
-  void visitOperation(Operation *op,
-                      ArrayRef<const UnderlyingValueLattice *> operands,
-                      ArrayRef<UnderlyingValueLattice *> results) override {
-    setAllToEntryStates(results);
-  }
-
-  /// At an entry point, the underlying value of a value is itself.
-  void setToEntryState(UnderlyingValueLattice *lattice) override {
-    propagateIfChanged(lattice,
-                       lattice->join(UnderlyingValue{lattice->getPoint()}));
-  }
-};
 } // end anonymous namespace
 
-/// Look for the most underlying value of a value.
-static Value getMostUnderlyingValue(
-    Value value,
-    function_ref<const UnderlyingValueLattice *(Value)> getUnderlyingValueFn) {
-  const UnderlyingValueLattice *underlying;
-  do {
-    underlying = getUnderlyingValueFn(value);
-    if (!underlying || underlying->getValue().isUninitialized())
-      return {};
-    Value underlyingValue = underlying->getValue().getUnderlyingValue();
-    if (underlyingValue == value)
-      break;
-    value = underlyingValue;
-  } while (true);
-  return value;
-}
-
 void LastModifiedAnalysis::visitOperation(Operation *op,
                                           const LastModification &before,
                                           LastModification *after) {
@@ -211,9 +81,10 @@
     if (!value)
       return setToEntryState(after);
 
-    value = getMostUnderlyingValue(value, [&](Value value) {
-      return getOrCreateFor<UnderlyingValueLattice>(op, value);
-    });
+    value = UnderlyingValueAnalysis::getMostUnderlyingValue(
+        value, [&](Value value) {
+          return getOrCreateFor<UnderlyingValueLattice>(op, value);
+        });
     if (!value)
       return;
 
@@ -256,12 +127,13 @@
       assert(lastMods && "expected a dense lattice");
       for (auto [index, operand] : llvm::enumerate(op->getOperands())) {
         os << " operand #" << index << "\n";
-        Value value = getMostUnderlyingValue(operand, [&](Value value) {
-          return solver.lookupState<UnderlyingValueLattice>(value);
-        });
+        Value value = UnderlyingValueAnalysis::getMostUnderlyingValue(
+            operand, [&](Value value) {
+              return solver.lookupState<UnderlyingValueLattice>(value);
+            });
         assert(value && "expected an underlying value");
         if (std::optional<ArrayRef<Operation *>> lastMod =
-                lastMods->getLastModifiers(value)) {
+                lastMods->getAdjacentAccess(value)) {
           for (Operation *lastModifier : *lastMod) {
             if (auto tagName =
                     lastModifier->getAttrOfType<StringAttr>("tag_name")) {
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -114,6 +114,7 @@
 void registerTestMathPolynomialApproximationPass();
 void registerTestMemRefDependenceCheck();
 void registerTestMemRefStrideCalculation();
+void registerTestNextAccessPass();
 void registerTestOneToNTypeConversionPass();
 void registerTestOpaqueLoc();
 void registerTestPadFusion();
@@ -232,6 +233,7 @@
   mlir::test::registerTestMathPolynomialApproximationPass();
   mlir::test::registerTestMemRefDependenceCheck();
   mlir::test::registerTestMemRefStrideCalculation();
+  mlir::test::registerTestNextAccessPass();
   mlir::test::registerTestOneToNTypeConversionPass();
   mlir::test::registerTestOpaqueLoc();
   mlir::test::registerTestPadFusion();