diff --git a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/DenseAnalysis.h @@ -17,12 +17,18 @@ #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" namespace mlir { +namespace dataflow { -class RegionBranchOpInterface; +//===----------------------------------------------------------------------===// +// CallControlFlowAction +//===----------------------------------------------------------------------===// -namespace dataflow { +/// Indicates whether the control enters or exits the callee. +enum class CallControlFlowAction { EnterCallee, ExitCallee }; //===----------------------------------------------------------------------===// // AbstractDenseLattice @@ -109,6 +115,32 @@ /// operation transfer function. virtual void processOperation(Operation *op); + /// Propagate the dense lattice forward along the control flow edge from + /// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt` + /// values correspond to control flow branches originating at or targeting the + /// `branch` operation itself. Default implementation just joins the states, + /// meaning that operations implementing `RegionBranchOpInterface` don't have + /// any effect on the lattice that isn't already expressed by the interface + /// itself. + virtual void visitRegionBranchControlFlowTransfer( + RegionBranchOpInterface branch, std::optional regionFrom, + std::optional regionTo, const AbstractDenseLattice &before, + AbstractDenseLattice *after) { + join(after, before); + } + + /// Propagate the dense lattice forward along the call control flow edge, + /// which can be either entering or exiting the callee.
Default implementation + /// just joins the states, meaning that operations implementing + /// `CallOpInterface` don't have any effect on the lattice that isn't already + /// expressed by the interface itself. + virtual void visitCallControlFlowTransfer(CallOpInterface call, + CallControlFlowAction action, + const AbstractDenseLattice &before, + AbstractDenseLattice *after) { + join(after, before); + } + /// Visit a program point within a region branch operation with predecessors /// in it. This can either be an entry block of one of the regions of the /// parent operation itself. @@ -120,6 +152,10 @@ /// Visit a block. The state at the start of the block is propagated from /// control-flow predecessors or callsites. void visitBlock(Block *block); + + /// Visit an operation for which the data flow is described by the + /// `CallOpInterface`. + void visitCallOperation(CallOpInterface call, AbstractDenseLattice *after); }; //===----------------------------------------------------------------------===// @@ -146,6 +182,60 @@ virtual void visitOperation(Operation *op, const LatticeT &before, LatticeT *after) = 0; + /// Hook for customizing the behavior of lattice propagation along the call + /// control flow edges. Two types of (forward) propagation are possible here: + /// - `action == CallControlFlowAction::EnterCallee` indicates that: + /// - `before` is the state before the call operation; + /// - `after` is the state at the beginning of the callee entry block; + /// - `action == CallControlFlowAction::ExitCallee` indicates that: + /// - `before` is the state at the end of a callee exit block; + /// - `after` is the state after the call operation. + /// By default, the `after` state is simply joined with the `before` state. + /// Concrete analyses can override this behavior or delegate to the parent + /// call for the default behavior.
Specifically, if the `call` op may affect + /// the lattice prior to entering the callee, the custom behavior can be added + /// for `action == CallControlFlowAction::EnterCallee`. If the `call` op may affect + /// the lattice post exiting the callee, the custom behavior can be added for + /// `action == CallControlFlowAction::ExitCallee`. + virtual void visitCallControlFlowTransfer(CallOpInterface call, + CallControlFlowAction action, + const LatticeT &before, + LatticeT *after) { + AbstractDenseDataFlowAnalysis::visitCallControlFlowTransfer(call, action, + before, after); + } + + /// Hook for customizing the behavior of lattice propagation along the control + /// flow edges between regions and their parent op. The control flows from + /// `regionFrom` to `regionTo`, both of which may be `nullopt` to indicate the + /// parent op. The lattice is propagated forward along this edge. The lattices + /// are as follows: + /// - `before`: + /// - if `regionFrom` is a region, this is the lattice at the end of the + /// block that exits the region; note that for multi-exit regions, the + /// lattices are equal at the end of all exiting blocks, but they are + /// associated with different program points. + /// - otherwise, this is the lattice before the parent op. + /// - `after`: + /// - if `regionTo` is a region, this is the lattice at the beginning of + /// the entry block of that region; + /// - otherwise, this is the lattice after the parent op. + /// By default, the `after` state is simply joined with the `before` state. + /// Concrete analyses can override this behavior or delegate to the parent + /// call for the default behavior. Specifically, if the `branch` op may affect + /// the lattice before entering any region, the custom behavior can be added + /// for `regionFrom == nullopt`. If the `branch` op may affect the lattice + /// after all regions terminated, the custom behavior can be added for `regionTo == + /// nullopt`.
The behavior can be further refined for specific pairs of "from" + /// and "to" regions. + virtual void visitRegionBranchControlFlowTransfer( + RegionBranchOpInterface branch, std::optional regionFrom, + std::optional regionTo, const LatticeT &before, + LatticeT *after) { + AbstractDenseDataFlowAnalysis::visitRegionBranchControlFlowTransfer( + branch, regionFrom, regionTo, before, after); + } + protected: /// Get the dense lattice after this program point. LatticeT *getLattice(ProgramPoint point) override { @@ -162,10 +252,27 @@ /// Type-erased wrappers that convert the abstract dense lattice to a derived /// lattice and invoke the virtual hooks operating on the derived lattice. void visitOperationImpl(Operation *op, const AbstractDenseLattice &before, - AbstractDenseLattice *after) override { + AbstractDenseLattice *after) final { visitOperation(op, static_cast(before), static_cast(after)); } + void visitCallControlFlowTransfer(CallOpInterface call, + CallControlFlowAction action, + const AbstractDenseLattice &before, + AbstractDenseLattice *after) final { + visitCallControlFlowTransfer(call, action, + static_cast(before), + static_cast(after)); + } + void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, + std::optional regionFrom, + std::optional regionTo, + const AbstractDenseLattice &before, + AbstractDenseLattice *after) final { + visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo, + static_cast(before), + static_cast(after)); + } }; //===----------------------------------------------------------------------===// @@ -231,12 +338,42 @@ propagateIfChanged(lhs, lhs->meet(rhs)); } - /// Visit an operation. If this is a call operation or region control-flow - /// operation, then the state after the execution of the operation is set by - /// control-flow or the callgraph. Otherwise, this function invokes the - /// operation transfer function. + /// Visit an operation.
Dispatches to specialized methods for call or region + /// control-flow operations. Otherwise, this function invokes the operation + /// transfer function. virtual void processOperation(Operation *op); + /// Propagate the dense lattice backwards along the control flow edge from + /// `regionFrom` to `regionTo` regions of the `branch` operation. `nullopt` + /// values correspond to control flow branches originating at or targeting the + /// `branch` operation itself. Default implementation just meets the states, + /// meaning that operations implementing `RegionBranchOpInterface` don't have + /// any effect on the lattice that isn't already expressed by the interface + /// itself. + virtual void visitRegionBranchControlFlowTransfer( + RegionBranchOpInterface branch, std::optional regionFrom, + std::optional regionTo, const AbstractDenseLattice &after, + AbstractDenseLattice *before) { + meet(before, after); + } + + /// Propagate the dense lattice backwards along the call control flow edge, + /// which can be either entering or exiting the callee. Default implementation + /// just meets the states, meaning that operations implementing + /// `CallOpInterface` don't have any effect on the lattice that isn't already + /// expressed by the interface itself. + virtual void visitCallControlFlowTransfer(CallOpInterface call, + CallControlFlowAction action, + const AbstractDenseLattice &after, + AbstractDenseLattice *before) { + meet(before, after); + } + +private: + /// Visit a block. The state at the end of the block is propagated from + /// control-flow successors of the block or callsites. + void visitBlock(Block *block); + /// Visit a program point within a region branch operation with successors /// (from which the state is propagated) in or after it. `regionNo` indicates /// the region that contains the successor, `nullopt` indicating the successor @@ -246,10 +383,16 @@ std::optional regionNo, AbstractDenseLattice *before); -private: - /// VIsit a block.
The state and the end of the block is propagated from - /// control-flow successors of the block or callsites. - void visitBlock(Block *block); + /// Visit an operation for which the data flow is described by the + /// `CallOpInterface`. Performs inter-procedural data flow as follows: + /// + /// - find the callable (resolve via the symbol table), + /// - get the entry block of the callable region, + /// - take the state before the first operation if present or at block end + /// otherwise, + /// - meet that state with the state before the call-like op, or use the + /// custom logic if overridden by concrete analyses. + void visitCallOperation(CallOpInterface call, AbstractDenseLattice *before); /// Symbol table for call-level control flow. SymbolTableCollection &symbolTable; @@ -280,6 +423,60 @@ virtual void visitOperation(Operation *op, const LatticeT &after, LatticeT *before) = 0; + /// Hook for customizing the behavior of lattice propagation along the call + /// control flow edges. Two types of (back) propagation are possible here: + /// - `action == CallControlFlowAction::EnterCallee` indicates that: + /// - `after` is the state at the top of the callee entry block; + /// - `before` is the state before the call operation; + /// - `action == CallControlFlowAction::ExitCallee` indicates that: + /// - `after` is the state after the call operation; + /// - `before` is the state of exit blocks of the callee. + /// By default, the `before` state is simply met with the `after` state. + /// Concrete analyses can override this behavior or delegate to the parent + /// call for the default behavior. Specifically, if the `call` op may affect + /// the lattice prior to entering the callee, the custom behavior can be added + /// for `action == CallControlFlowAction::EnterCallee`. If the `call` op may affect + /// the lattice post exiting the callee, the custom behavior can be added for + /// `action == CallControlFlowAction::ExitCallee`.
+ virtual void visitCallControlFlowTransfer(CallOpInterface call, + CallControlFlowAction action, + const LatticeT &after, + LatticeT *before) { + AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer( + call, action, after, before); + } + + /// Hook for customizing the behavior of lattice propagation along the control + /// flow edges between regions and their parent op. The control flows from + /// `regionFrom` to `regionTo`, both of which may be `nullopt` to indicate the + /// parent op. The lattice is propagated back along this edge. The lattices + /// are as follows: + /// - `after`: + /// - if `regionTo` is a region, this is the lattice at the beginning of + /// the entry block of that region; + /// - otherwise, this is the lattice after the parent op. + /// - `before`: + /// - if `regionFrom` is a region, this is the lattice at the end of the + /// block that exits the region; note that for multi-exit regions, the + /// lattices are equal at the end of all exiting blocks, but they are + /// associated with different program points. + /// - otherwise, this is the lattice before the parent op. + /// By default, the `before` state is simply met with the `after` state. + /// Concrete analyses can override this behavior or delegate to the parent + /// call for the default behavior. Specifically, if the `branch` op may affect + /// the lattice before entering any region, the custom behavior can be added + /// for `regionFrom == nullopt`. If the `branch` op may affect the lattice + /// after all regions terminated, the custom behavior can be added for `regionTo == + /// nullopt`. The behavior can be further refined for specific pairs of "from" + /// and "to" regions.
+ virtual void visitRegionBranchControlFlowTransfer( + RegionBranchOpInterface branch, std::optional regionFrom, + std::optional regionTo, const LatticeT &after, + LatticeT *before) { + AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchControlFlowTransfer( + branch, regionFrom, regionTo, after, before); + } + protected: /// Get the dense lattice at the given program point. LatticeT *getLattice(ProgramPoint point) override { @@ -289,17 +486,33 @@ /// Set the dense lattice at control flow exit point (after the terminator) /// and propagate an update if it changed. virtual void setToExitState(LatticeT *lattice) = 0; - void setToExitState(AbstractDenseLattice *lattice) override { + void setToExitState(AbstractDenseLattice *lattice) final { setToExitState(static_cast(lattice)); } - /// Type-erased wrapper that convert the abstract dense lattice to a derived + /// Type-erased wrappers that convert the abstract dense lattice to a derived /// lattice and invoke the virtual hooks operating on the derived lattice.
void visitOperationImpl(Operation *op, const AbstractDenseLattice &after, - AbstractDenseLattice *before) override { + AbstractDenseLattice *before) final { visitOperation(op, static_cast(after), static_cast(before)); } + void visitCallControlFlowTransfer(CallOpInterface call, + CallControlFlowAction action, + const AbstractDenseLattice &after, + AbstractDenseLattice *before) final { + visitCallControlFlowTransfer(call, action, + static_cast(after), + static_cast(before)); + } + void visitRegionBranchControlFlowTransfer( + RegionBranchOpInterface branch, std::optional regionFrom, + std::optional regionTo, const AbstractDenseLattice &after, + AbstractDenseLattice *before) final { + visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo, + static_cast(after), + static_cast(before)); + } }; } // end namespace dataflow diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -42,6 +42,38 @@ return success(); } +void AbstractDenseDataFlowAnalysis::visitCallOperation( + CallOpInterface call, AbstractDenseLattice *after) { + + const auto *predecessors = + getOrCreateFor(call.getOperation(), call); + // If not all return sites are known, then conservatively assume we can't + // reason about the data-flow. + if (!predecessors->allPredecessorsKnown()) + return setToEntryState(after); + + for (Operation *predecessor : predecessors->getKnownPredecessors()) { + // Get the lattices at callee return: + // + // func.func @callee() { + // ... + // return // predecessor + // // latticeAtCalleeReturn + // } + // func.func @caller() { + // ... + // call @callee + // // latticeAfterCall + // ...
+ // } + AbstractDenseLattice *latticeAfterCall = after; + const AbstractDenseLattice *latticeAtCalleeReturn = + getLatticeFor(call.getOperation(), predecessor); + visitCallControlFlowTransfer(call, CallControlFlowAction::ExitCallee, + *latticeAtCalleeReturn, latticeAfterCall); + } +} + void AbstractDenseDataFlowAnalysis::processOperation(Operation *op) { // If the containing block is not executable, bail out. if (!getOrCreateFor(op, op->getBlock())->isLive()) @@ -50,6 +82,13 @@ // Get the dense lattice to update. AbstractDenseLattice *after = getLattice(op); + // Get the dense state before the execution of the op. + const AbstractDenseLattice *before; + if (Operation *prev = op->getPrevNode()) + before = getLatticeFor(op, prev); + else + before = getLatticeFor(op, op->getBlock()); + // If this op implements region control-flow, then control-flow dictates its // transfer function. if (auto branch = dyn_cast(op)) @@ -57,23 +96,8 @@ // If this is a call operation, then join its lattices across known return // sites. - if (auto call = dyn_cast(op)) { - const auto *predecessors = getOrCreateFor(op, call); - // If not all return sites are known, then conservatively assume we can't - // reason about the data-flow. - if (!predecessors->allPredecessorsKnown()) - return setToEntryState(after); - for (Operation *predecessor : predecessors->getKnownPredecessors()) - join(after, *getLatticeFor(op, predecessor)); - return; - } - - // Get the dense state before the execution of the op. - const AbstractDenseLattice *before; - if (Operation *prev = op->getPrevNode()) - before = getLatticeFor(op, prev); - else - before = getLatticeFor(op, op->getBlock()); + if (auto call = dyn_cast(op)) + return visitCallOperation(call, after); // Invoke the operation transfer function. visitOperationImpl(op, *before, after); @@ -100,10 +124,15 @@ return setToEntryState(after); for (Operation *callsite : callsites->getKnownPredecessors()) { // Get the dense lattice before the callsite.
+ const AbstractDenseLattice *before; if (Operation *prev = callsite->getPrevNode()) - join(after, *getLatticeFor(block, prev)); + before = getLatticeFor(block, prev); else - join(after, *getLatticeFor(block, callsite->getBlock())); + before = getLatticeFor(block, callsite->getBlock()); + + visitCallControlFlowTransfer(cast(callsite), + CallControlFlowAction::EnterCallee, + *before, after); } return; } @@ -152,7 +181,41 @@ } else { before = getLatticeFor(point, op); } - join(after, *before); + + // This function is called in two cases: + // 1. when visiting the block (point = block); + // 2. when visiting the parent operation (point = parent op). + // In both cases, we are looking for predecessor operations of the point, + // 1. predecessor may be the terminator of another block from another + // region (assuming that the block does belong to another region via an + // assertion) or the parent (when parent can transfer control to this + // region); + // 2. predecessor may be the terminator of a block that exits the + // region (when region transfers control to the parent) or the operation + // before the parent. + // In the latter case, just perform the join as it isn't the control flow + // affected by the region. + std::optional regionFrom = + op == branch ? std::optional() + : op->getBlock()->getParent()->getRegionNumber(); + if (auto *toBlock = point.dyn_cast()) { + assert(op == branch || + toBlock->getParent() != op->getBlock()->getParent()); + unsigned regionTo = toBlock->getParent()->getRegionNumber(); + visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo, + *before, after); + } else { + assert(point.get() == branch && + "expected to be visiting the branch itself"); + // Only need to call the arc transfer when the predecessor is the region + // or the op itself, not the previous op.
+ if (op->getParentOp() == branch || op == branch) { + visitRegionBranchControlFlowTransfer( + branch, regionFrom, /*regionTo=*/std::nullopt, *before, after); + } else { + join(after, *before); + } + } } } @@ -194,6 +257,44 @@ return success(); } +void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( + CallOpInterface call, AbstractDenseLattice *before) { + // Find the callee. + Operation *callee = call.resolveCallable(&symbolTable); + auto callable = dyn_cast_or_null(callee); + if (!callable) + return setToExitState(before); + + // No region means the callee is only declared in this module and we shouldn't + // assume anything about it. + Region *region = callable.getCallableRegion(); + if (!region || region->empty()) + return setToExitState(before); + + // Call-level control flow specifies the data flow here. + // + // func.func @callee() { + // ^calleeEntryBlock: + // // latticeAtCalleeEntry + // ... + // } + // func.func @caller() { + // ... + // // latticeBeforeCall + // call @callee + // ... + // } + Block *calleeEntryBlock = &region->front(); + ProgramPoint calleeEntry = calleeEntryBlock->empty() + ? ProgramPoint(calleeEntryBlock) + : &calleeEntryBlock->front(); + const AbstractDenseLattice &latticeAtCalleeEntry = + *getLatticeFor(call.getOperation(), calleeEntry); + AbstractDenseLattice *latticeBeforeCall = before; + visitCallControlFlowTransfer(call, CallControlFlowAction::EnterCallee, + latticeAtCalleeEntry, latticeBeforeCall); +} + void AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) { // If the containing block is not executable, bail out. if (!getOrCreateFor(op, op->getBlock())->isLive()) @@ -202,39 +303,6 @@ // Get the dense lattice to update. AbstractDenseLattice *before = getLattice(op); - // If the op implements region control flow, then the interface specifies the - // control function. - // TODO: this is not always true, e.g.
linalg.generic, but is implement this - // way for consistency with the dense forward analysis. - if (auto branch = dyn_cast(op)) - return visitRegionBranchOperation(op, branch, std::nullopt, before); - - // If the op is a call-like, do inter-procedural data flow as follows: - // - // - find the callable (resolve via the symbol table), - // - get the entry block of the callable region, - // - take the state before the first operation if present or at block end - // otherwise, - // - meet that state with the state before the call-like op. - if (auto call = dyn_cast(op)) { - Operation *callee = call.resolveCallable(&symbolTable); - if (auto callable = dyn_cast(callee)) { - Region *region = callable.getCallableRegion(); - if (region && !region->empty()) { - Block *entryBlock = &region->front(); - if (entryBlock->empty()) - meet(before, *getLatticeFor(op, entryBlock)); - else - meet(before, *getLatticeFor(op, &entryBlock->front())); - } else { - setToExitState(before); - } - } else { - setToExitState(before); - } - return; - } - // Get the dense state after execution of this op. const AbstractDenseLattice *after; if (Operation *next = op->getNextNode()) @@ -242,6 +310,12 @@ else after = getLatticeFor(op, op->getBlock()); + // Special cases where control flow may dictate data flow. + if (auto branch = dyn_cast(op)) + return visitRegionBranchOperation(op, branch, std::nullopt, before); + if (auto call = dyn_cast(op)) + return visitCallOperation(call, before); + // Invoke the operation transfer function.
visitOperationImpl(op, *after, before); } @@ -280,16 +354,20 @@ return setToExitState(before); for (Operation *callsite : callsites->getKnownPredecessors()) { + const AbstractDenseLattice *after; if (Operation *next = callsite->getNextNode()) - meet(before, *getLatticeFor(block, next)); + after = getLatticeFor(block, next); else - meet(before, *getLatticeFor(block, callsite->getBlock())); + after = getLatticeFor(block, callsite->getBlock()); + visitCallControlFlowTransfer(cast(callsite), + CallControlFlowAction::ExitCallee, *after, + before); } return; } // If this block is exiting from an operation with region-based control - // flow, follow that flow. + // flow, propagate the lattice back along the control flow edge. if (auto branch = dyn_cast(block->getParentOp())) { visitRegionBranchOperation(block, branch, block->getParent()->getRegionNumber(), before); @@ -346,7 +424,11 @@ else after = getLatticeFor(point, &successorBlock->front()); } - meet(before, *after); + std::optional successorNo = + successor.isParent() ?
std::optional() + : successor.getSuccessor()->getRegionNumber(); + visitRegionBranchControlFlowTransfer(branch, regionNo, successorNo, *after, + before); } } diff --git a/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir b/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir --- a/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir +++ b/mlir/test/Analysis/DataFlow/test-last-modified-callgraph.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-last-modified %s 2>&1 | FileCheck %s +// RUN: mlir-opt -test-last-modified --split-input-file %s 2>&1 | FileCheck %s // CHECK-LABEL: test_tag: test_callsite // CHECK: operand #0 @@ -64,4 +64,84 @@ func.func @test_multiple_return_sites(%cond: i1, %a: i32, %ptr: memref) -> memref { %0 = func.call @multiple_return_site_fn(%cond, %a, %ptr) : (i1, i32, memref) -> memref return {tag = "test_multiple_return_sites"} %0 : memref -} \ No newline at end of file +} + +// ----- + + +func.func private @callee(%arg0: memref) -> memref { + %2 = arith.constant 2.0 : f32 + memref.load %arg0[] {tag = "call_and_store_before::enter_callee"} : memref + memref.store %2, %arg0[] {tag_name = "callee"} : memref + memref.load %arg0[] {tag = "exit_callee"} : memref + return %arg0 : memref +} +// In this test, the "call" operation also stores to %arg0 itself before +// transferring control flow to the callee.
Therefore, the order of accesses is +// "pre" -> "call" -> "callee" -> "post" + +// CHECK-LABEL: test_tag: call_and_store_before::enter_callee: +// CHECK: operand #0 +// CHECK: - call +// CHECK: test_tag: exit_callee: +// CHECK: operand #0 +// CHECK: - callee +// CHECK: test_tag: before_call: +// CHECK: operand #0 +// CHECK: - pre +// CHECK: test_tag: after_call: +// CHECK: operand #0 +// CHECK: - callee +// CHECK: test_tag: return: +// CHECK: operand #0 +// CHECK: - post +func.func @call_and_store_before(%arg0: memref) -> memref { + %0 = arith.constant 0.0 : f32 + %1 = arith.constant 1.0 : f32 + memref.store %0, %arg0[] {tag_name = "pre"} : memref + memref.load %arg0[] {tag = "before_call"} : memref + test.call_and_store @callee(%arg0), %arg0 {tag_name = "call", store_before_call = true} : (memref, memref) -> () + memref.load %arg0[] {tag = "after_call"} : memref + memref.store %1, %arg0[] {tag_name = "post"} : memref + return {tag = "return"} %arg0 : memref +} + +// ----- + +func.func private @callee(%arg0: memref) -> memref { + %2 = arith.constant 2.0 : f32 + memref.load %arg0[] {tag = "call_and_store_after::enter_callee"} : memref + memref.store %2, %arg0[] {tag_name = "callee"} : memref + memref.load %arg0[] {tag = "exit_callee"} : memref + return %arg0 : memref +} + +// In this test, the "call" operation also stores to %arg0 itself after getting +// control flow back from the callee.
Therefore, the order of accesses is +// "pre" -> "callee" -> "call" -> "post" + +// CHECK-LABEL: test_tag: call_and_store_after::enter_callee: +// CHECK: operand #0 +// CHECK: - pre +// CHECK: test_tag: exit_callee: +// CHECK: operand #0 +// CHECK: - callee +// CHECK: test_tag: before_call: +// CHECK: operand #0 +// CHECK: - pre +// CHECK: test_tag: after_call: +// CHECK: operand #0 +// CHECK: - call +// CHECK: test_tag: return: +// CHECK: operand #0 +// CHECK: - post +func.func @call_and_store_after(%arg0: memref) -> memref { + %0 = arith.constant 0.0 : f32 + %1 = arith.constant 1.0 : f32 + memref.store %0, %arg0[] {tag_name = "pre"} : memref + memref.load %arg0[] {tag = "before_call"} : memref + test.call_and_store @callee(%arg0), %arg0 {tag_name = "call", store_before_call = false} : (memref, memref) -> () + memref.load %arg0[] {tag = "after_call"} : memref + memref.store %1, %arg0[] {tag_name = "post"} : memref + return {tag = "return"} %arg0 : memref +} diff --git a/mlir/test/Analysis/DataFlow/test-last-modified.mlir b/mlir/test/Analysis/DataFlow/test-last-modified.mlir --- a/mlir/test/Analysis/DataFlow/test-last-modified.mlir +++ b/mlir/test/Analysis/DataFlow/test-last-modified.mlir @@ -113,3 +113,119 @@ "test.unknown_effects"() : () -> () return {tag = "unknown_memory_effects_b"} %ptr : memref } + +// CHECK-LABEL: test_tag: store_with_a_region_before::before: +// CHECK: operand #0 +// CHECK: - pre +// CHECK: test_tag: inside_region: +// CHECK: operand #0 +// CHECK: - region +// CHECK: test_tag: after: +// CHECK: operand #0 +// CHECK: - region +// CHECK: test_tag: return: +// CHECK: operand #0 +// CHECK: - post +func.func @store_with_a_region_before(%arg0: memref) -> memref { + %0 = arith.constant 0.0 : f32 + %1 = arith.constant 1.0 : f32 + memref.store %0, %arg0[] {tag_name = "pre"} : memref + memref.load %arg0[] {tag = "store_with_a_region_before::before"} : memref + test.store_with_a_region %arg0 attributes { tag_name = "region", store_before_region = true
} { + memref.load %arg0[] {tag = "inside_region"} : memref + test.store_with_a_region_terminator + } : memref + memref.load %arg0[] {tag = "after"} : memref + memref.store %1, %arg0[] {tag_name = "post"} : memref + return {tag = "return"} %arg0 : memref +} + +// CHECK-LABEL: test_tag: store_with_a_region_after::before: +// CHECK: operand #0 +// CHECK: - pre +// CHECK: test_tag: inside_region: +// CHECK: operand #0 +// CHECK: - pre +// CHECK: test_tag: after: +// CHECK: operand #0 +// CHECK: - region +// CHECK: test_tag: return: +// CHECK: operand #0 +// CHECK: - post +func.func @store_with_a_region_after(%arg0: memref) -> memref { + %0 = arith.constant 0.0 : f32 + %1 = arith.constant 1.0 : f32 + memref.store %0, %arg0[] {tag_name = "pre"} : memref + memref.load %arg0[] {tag = "store_with_a_region_after::before"} : memref + test.store_with_a_region %arg0 attributes { tag_name = "region", store_before_region = false } { + memref.load %arg0[] {tag = "inside_region"} : memref + test.store_with_a_region_terminator + } : memref + memref.load %arg0[] {tag = "after"} : memref + memref.store %1, %arg0[] {tag_name = "post"} : memref + return {tag = "return"} %arg0 : memref +} + +// CHECK-LABEL: test_tag: store_with_a_region_before_containing_a_store::before: +// CHECK: operand #0 +// CHECK: - pre +// CHECK: test_tag: enter_region: +// CHECK: operand #0 +// CHECK: - region +// CHECK: test_tag: exit_region: +// CHECK: operand #0 +// CHECK: - inner +// CHECK: test_tag: after: +// CHECK: operand #0 +// CHECK: - inner +// CHECK: test_tag: return: +// CHECK: operand #0 +// CHECK: - post +func.func @store_with_a_region_before_containing_a_store(%arg0: memref) -> memref { + %0 = arith.constant 0.0 : f32 + %1 = arith.constant 1.0 : f32 + memref.store %0, %arg0[] {tag_name = "pre"} : memref + memref.load %arg0[] {tag = "store_with_a_region_before_containing_a_store::before"} : memref + test.store_with_a_region %arg0 attributes { tag_name = "region", store_before_region = true } { +
memref.load %arg0[] {tag = "enter_region"} : memref + %2 = arith.constant 2.0 : f32 + memref.store %2, %arg0[] {tag_name = "inner"} : memref + memref.load %arg0[] {tag = "exit_region"} : memref + test.store_with_a_region_terminator + } : memref + memref.load %arg0[] {tag = "after"} : memref + memref.store %1, %arg0[] {tag_name = "post"} : memref + return {tag = "return"} %arg0 : memref +} + +// CHECK-LABEL: test_tag: store_with_a_region_after_containing_a_store::before: +// CHECK: operand #0 +// CHECK: - pre +// CHECK: test_tag: enter_region: +// CHECK: operand #0 +// CHECK: - pre +// CHECK: test_tag: exit_region: +// CHECK: operand #0 +// CHECK: - inner +// CHECK: test_tag: after: +// CHECK: operand #0 +// CHECK: - region +// CHECK: test_tag: return: +// CHECK: operand #0 +// CHECK: - post +func.func @store_with_a_region_after_containing_a_store(%arg0: memref) -> memref { + %0 = arith.constant 0.0 : f32 + %1 = arith.constant 1.0 : f32 + memref.store %0, %arg0[] {tag_name = "pre"} : memref + memref.load %arg0[] {tag = "store_with_a_region_after_containing_a_store::before"} : memref + test.store_with_a_region %arg0 attributes { tag_name = "region", store_before_region = false } { + memref.load %arg0[] {tag = "enter_region"} : memref + %2 = arith.constant 2.0 : f32 + memref.store %2, %arg0[] {tag_name = "inner"} : memref + memref.load %arg0[] {tag = "exit_region"} : memref + test.store_with_a_region_terminator + } : memref + memref.load %arg0[] {tag = "after"} : memref + memref.store %1, %arg0[] {tag_name = "post"} : memref + return {tag = "return"} %arg0 : memref +} diff --git a/mlir/test/Analysis/DataFlow/test-next-access.mlir b/mlir/test/Analysis/DataFlow/test-next-access.mlir --- a/mlir/test/Analysis/DataFlow/test-next-access.mlir +++ b/mlir/test/Analysis/DataFlow/test-next-access.mlir @@ -357,3 +357,157 @@ memref.load %arg0[] {name = "post"} : memref return } + +// ----- + + +// In this test, the "call" operation also accesses %arg0 itself before +//
transferring control flow to the callee. Therefore, the order of accesses is +// "caller" -> "call" -> "callee" -> "post" + +func.func private @callee(%arg0: memref) { + // CHECK: name = "callee" + // CHECK-SAME-LITERAL: next_access = [["post"]] + memref.load %arg0[] {name = "callee"} : memref + return +} + +// CHECK-LABEL: @call_and_store_before +func.func @call_and_store_before(%arg0: memref) { + // CHECK: name = "caller" + // CHECK-SAME-LITERAL: next_access = [["call"]] + memref.load %arg0[] {name = "caller"} : memref + // Note that the access after the entire call is "post". + // CHECK: name = "call" + // CHECK-SAME-LITERAL: next_access = [["post"], ["post"]] + test.call_and_store @callee(%arg0), %arg0 {name = "call", store_before_call = true} : (memref, memref) -> () + // CHECK: name = "post" + // CHECK-SAME-LITERAL: next_access = ["unknown"] + memref.load %arg0[] {name = "post"} : memref + return +} + +// ----- + +// In this test, the "call" operation also accesses %arg0 itself after getting +// control flow back from the callee. Therefore, the order of accesses is +// "caller" -> "callee" -> "call" -> "post" + +func.func private @callee(%arg0: memref) { + // CHECK: name = "callee" + // CHECK-SAME-LITERAL: next_access = [["call"]] + memref.load %arg0[] {name = "callee"} : memref + return +} + +// CHECK-LABEL: @call_and_store_after +func.func @call_and_store_after(%arg0: memref) { + // CHECK: name = "caller" + // CHECK-SAME-LITERAL: next_access = [["callee"]] + memref.load %arg0[] {name = "caller"} : memref + // CHECK: name = "call" + // CHECK-SAME-LITERAL: next_access = [["post"], ["post"]] + test.call_and_store @callee(%arg0), %arg0 {name = "call", store_before_call = false} : (memref, memref) -> () + // CHECK: name = "post" + // CHECK-SAME-LITERAL: next_access = ["unknown"] + memref.load %arg0[] {name = "post"} : memref + return +} + +// ----- + +// In this test, the "region" operation also accesses %arg0 itself before +// entering the region.
Therefore: +// - the next access of "pre" is the "region" operation itself; +// - at the entry of the block, the next access is "post". +// CHECK-LABEL: @store_with_a_region +func.func @store_with_a_region_before(%arg0: memref) { + // CHECK: name = "pre" + // CHECK-SAME-LITERAL: next_access = [["region"]] + memref.load %arg0[] {name = "pre"} : memref + // CHECK: name = "region" + // CHECK-SAME-LITERAL: next_access = [["post"]] + // CHECK-SAME-LITERAL: next_at_entry_point = [[["post"]]] + test.store_with_a_region %arg0 attributes { name = "region", store_before_region = true } { + test.store_with_a_region_terminator + } : memref + memref.load %arg0[] {name = "post"} : memref + return +} + +// In this test, the "region" operation also accesses %arg0 itself after +// exiting from the region. Therefore: +// - the next access of "pre" is the "region" operation itself; +// - at the entry of the block, the next access is "region". +// CHECK-LABEL: @store_with_a_region +func.func @store_with_a_region_after(%arg0: memref) { + // CHECK: name = "pre" + // CHECK-SAME-LITERAL: next_access = [["region"]] + memref.load %arg0[] {name = "pre"} : memref + // CHECK: name = "region" + // CHECK-SAME-LITERAL: next_access = [["post"]] + // CHECK-SAME-LITERAL: next_at_entry_point = [[["region"]]] + test.store_with_a_region %arg0 attributes { name = "region", store_before_region = false } { + test.store_with_a_region_terminator + } : memref + memref.load %arg0[] {name = "post"} : memref + return +} + +// In this test, the operation with a region stores to %arg0 before going to the +// region. Therefore: +// - the next access of "pre" is the "region" operation itself; +// - the next access of the "region" operation (computed as the next access +// *after* said operation) is the "post" operation; +// - the next access of the "inner" operation is also "post"; +// - the next access at the entry point of the region of the "region" operation +// is the "inner" operation. 
+// That is, the order of access is: "pre" -> "region" -> "inner" -> "post". +// CHECK-LABEL: @store_with_a_region_before_containing_a_load +func.func @store_with_a_region_before_containing_a_load(%arg0: memref) { + // CHECK: name = "pre" + // CHECK-SAME-LITERAL: next_access = [["region"]] + memref.load %arg0[] {name = "pre"} : memref + // CHECK: name = "region" + // CHECK-SAME-LITERAL: next_access = [["post"]] + // CHECK-SAME-LITERAL: next_at_entry_point = [[["inner"]]] + test.store_with_a_region %arg0 attributes { name = "region", store_before_region = true } { + // CHECK: name = "inner" + // CHECK-SAME-LITERAL: next_access = [["post"]] + memref.load %arg0[] {name = "inner"} : memref + test.store_with_a_region_terminator + } : memref + // CHECK: name = "post" + // CHECK-SAME-LITERAL: next_access = ["unknown"] + memref.load %arg0[] {name = "post"} : memref + return +} + +// In this test, the operation with a region stores to %arg0 after exiting from +// the region. Therefore: +// - the next access of "pre" is "inner"; +// - the next access of the "region" operation (computed as the next access +// *after* said operation) is the "post" operation; +// - the next access at the entry point of the region of the "region" operation +// is the "inner" operation; +// - the next access of the "inner" operation is the "region" operation itself. +// That is, the order of access is "pre" -> "inner" -> "region" -> "post".
+// CHECK-LABEL: @store_with_a_region_after_containing_a_load +func.func @store_with_a_region_after_containing_a_load(%arg0: memref) { + // CHECK: name = "pre" + // CHECK-SAME-LITERAL: next_access = [["inner"]] + memref.load %arg0[] {name = "pre"} : memref + // CHECK: name = "region" + // CHECK-SAME-LITERAL: next_access = [["post"]] + // CHECK-SAME-LITERAL: next_at_entry_point = [[["inner"]]] + test.store_with_a_region %arg0 attributes { name = "region", store_before_region = false } { + // CHECK: name = "inner" + // CHECK-SAME-LITERAL: next_access = [["region"]] + memref.load %arg0[] {name = "inner"} : memref + test.store_with_a_region_terminator + } : memref + // CHECK: name = "post" + // CHECK-SAME-LITERAL: next_access = ["unknown"] + memref.load %arg0[] {name = "post"} : memref + return +} diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp --- a/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseBackwardDataFlowAnalysis.cpp @@ -11,11 +11,15 @@ //===----------------------------------------------------------------------===// #include "TestDenseDataFlowAnalysis.h" +#include "TestDialect.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/DenseAnalysis.h" #include "mlir/Analysis/DataFlowFramework.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/TypeID.h" @@ -27,14 +31,14 @@ namespace { -class NextAccess : public AbstractDenseLattice, public test::AccessLatticeBase { +class NextAccess : public AbstractDenseLattice, public AccessLatticeBase { public: 
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NextAccess) using dataflow::AbstractDenseLattice::AbstractDenseLattice; ChangeResult meet(const AbstractDenseLattice &lattice) override { - return AccessLatticeBase::merge(static_cast( + return AccessLatticeBase::merge(static_cast( static_cast(lattice))); } @@ -50,6 +54,17 @@ void visitOperation(Operation *op, const NextAccess &after, NextAccess *before) override; + void visitCallControlFlowTransfer(CallOpInterface call, + CallControlFlowAction action, + const NextAccess &after, + NextAccess *before) override; + + void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, + std::optional regionFrom, + std::optional regionTo, + const NextAccess &after, + NextAccess *before) override; + // TODO: this isn't ideal for the analysis. When there is no next access, it // means "we don't know what the next access is" rather than "there is no next // access". But it's unclear how to differentiate the two cases... @@ -78,18 +93,53 @@ if (!value) return setToExitState(before); + // If we cannot find the most underlying value, we cannot assume anything + // about the next accesses.
value = UnderlyingValueAnalysis::getMostUnderlyingValue( value, [&](Value value) { return getOrCreateFor(op, value); }); if (!value) - return; + return setToExitState(before); result |= before->set(value, op); } propagateIfChanged(before, result); } +void NextAccessAnalysis::visitCallControlFlowTransfer( + CallOpInterface call, CallControlFlowAction action, const NextAccess &after, + NextAccess *before) { + auto testCallAndStore = + dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); + if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && + testCallAndStore.getStoreBeforeCall()) || + (action == CallControlFlowAction::ExitCallee && + !testCallAndStore.getStoreBeforeCall()))) { + visitOperation(call, after, before); + } else { + AbstractDenseBackwardDataFlowAnalysis::visitCallControlFlowTransfer( + call, action, after, before); + } +} + +void NextAccessAnalysis::visitRegionBranchControlFlowTransfer( + RegionBranchOpInterface branch, std::optional regionFrom, + std::optional regionTo, const NextAccess &after, + NextAccess *before) { + auto testStoreWithARegion = + dyn_cast<::test::TestStoreWithARegion>(branch.getOperation()); + + if (testStoreWithARegion && + ((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) || + (!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) { + visitOperation(branch, static_cast(after), + static_cast(before)); + } else { + propagateIfChanged(before, before->meet(after)); + } +} + namespace { struct TestNextAccessPass : public PassWrapper> { @@ -99,6 +149,45 @@ static constexpr llvm::StringLiteral kTagAttrName = "name"; static constexpr llvm::StringLiteral kNextAccessAttrName = "next_access"; + static constexpr llvm::StringLiteral kAtEntryPointAttrName = + "next_at_entry_point"; + + static Attribute makeNextAccessAttribute(Operation *op, + const DataFlowSolver &solver, + const NextAccess *nextAccess) { + if (!nextAccess) + return StringAttr::get(op->getContext(), "not computed"); + + SmallVector 
attrs; + for (Value operand : op->getOperands()) { + Value value = UnderlyingValueAnalysis::getMostUnderlyingValue( + operand, [&](Value value) { + return solver.lookupState(value); + }); + std::optional> nextAcc = + nextAccess->getAdjacentAccess(value); + if (!nextAcc) { + attrs.push_back(StringAttr::get(op->getContext(), "unknown")); + continue; + } + + SmallVector innerAttrs; + innerAttrs.reserve(nextAcc->size()); + for (Operation *nextAccOp : *nextAcc) { + if (auto nextAccTag = + nextAccOp->getAttrOfType(kTagAttrName)) { + innerAttrs.push_back(nextAccTag); + continue; + } + std::string repr; + llvm::raw_string_ostream os(repr); + nextAccOp->print(os); + innerAttrs.push_back(StringAttr::get(op->getContext(), os.str())); + } + attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs)); + } + return ArrayAttr::get(op->getContext(), attrs); + } void runOnOperation() override { Operation *op = getOperation(); @@ -113,7 +202,6 @@ emitError(op->getLoc(), "dataflow solver failed"); return signalPassFailure(); } - op->walk([&](Operation *op) { auto tag = op->getAttrOfType(kTagAttrName); if (!tag) @@ -122,42 +210,28 @@ const NextAccess *nextAccess = solver.lookupState( op->getNextNode() == nullptr ? 
ProgramPoint(op->getBlock()) : op->getNextNode()); - if (!nextAccess) { - op->setAttr(kNextAccessAttrName, - StringAttr::get(op->getContext(), "not computed")); + op->setAttr(kNextAccessAttrName, + makeNextAccessAttribute(op, solver, nextAccess)); + + auto iface = dyn_cast(op); + if (!iface) return; - } - SmallVector attrs; - for (Value operand : op->getOperands()) { - Value value = UnderlyingValueAnalysis::getMostUnderlyingValue( - operand, [&](Value value) { - return solver.lookupState(value); - }); - std::optional> nextAcc = - nextAccess->getAdjacentAccess(value); - if (!nextAcc) { - attrs.push_back(StringAttr::get(op->getContext(), "unknown")); + SmallVector entryPointNextAccess; + SmallVector regionSuccessors; + iface.getSuccessorRegions(std::nullopt, regionSuccessors); + for (const RegionSuccessor &successor : regionSuccessors) { + if (!successor.getSuccessor() || successor.getSuccessor()->empty()) continue; - } - - SmallVector innerAttrs; - innerAttrs.reserve(nextAcc->size()); - for (Operation *nextAccOp : *nextAcc) { - if (auto nextAccTag = - nextAccOp->getAttrOfType(kTagAttrName)) { - innerAttrs.push_back(nextAccTag); - continue; - } - std::string repr; - llvm::raw_string_ostream os(repr); - nextAccOp->print(os); - innerAttrs.push_back(StringAttr::get(op->getContext(), os.str())); - } - attrs.push_back(ArrayAttr::get(op->getContext(), innerAttrs)); + Block &successorBlock = successor.getSuccessor()->front(); + ProgramPoint successorPoint = successorBlock.empty() + ? 
ProgramPoint(&successorBlock) + : &successorBlock.front(); + entryPointNextAccess.push_back(makeNextAccessAttribute( + op, solver, solver.lookupState(successorPoint))); } - - op->setAttr(kNextAccessAttrName, ArrayAttr::get(op->getContext(), attrs)); + op->setAttr(kAtEntryPointAttrName, + ArrayAttr::get(op->getContext(), entryPointNextAccess)); }); } }; diff --git a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp --- a/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp +++ b/mlir/test/lib/Analysis/DataFlow/TestDenseDataFlowAnalysis.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDenseDataFlowAnalysis.h" +#include "TestDialect.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/DenseAnalysis.h" @@ -22,8 +23,7 @@ /// This lattice represents, for a given memory resource, the potential last /// operations that modified the resource. -class LastModification : public AbstractDenseLattice, - public test::AccessLatticeBase { +class LastModification : public AbstractDenseLattice, public AccessLatticeBase { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LastModification) @@ -31,7 +31,7 @@ /// Join the last modifications. 
ChangeResult join(const AbstractDenseLattice &lattice) override { - return AccessLatticeBase::merge(static_cast( + return AccessLatticeBase::merge(static_cast( static_cast(lattice))); } @@ -51,6 +51,17 @@ void visitOperation(Operation *op, const LastModification &before, LastModification *after) override; + void visitCallControlFlowTransfer(CallOpInterface call, + CallControlFlowAction action, + const LastModification &before, + LastModification *after) override; + + void visitRegionBranchControlFlowTransfer(RegionBranchOpInterface branch, + std::optional regionFrom, + std::optional regionTo, + const LastModification &before, + LastModification *after) override; + /// At an entry point, the last modifications of all memory resources are /// unknown. void setToEntryState(LastModification *lattice) override { @@ -80,12 +91,14 @@ if (!value) return setToEntryState(after); + // If we cannot find the underlying value, we shouldn't just propagate the + // effects through, return the pessimistic state. value = UnderlyingValueAnalysis::getMostUnderlyingValue( value, [&](Value value) { return getOrCreateFor(op, value); }); if (!value) - return; + return setToEntryState(after); // Nothing to do for reads. 
if (isa(effect.getEffect())) @@ -96,6 +109,36 @@ propagateIfChanged(after, result); } +void LastModifiedAnalysis::visitCallControlFlowTransfer( + CallOpInterface call, CallControlFlowAction action, + const LastModification &before, LastModification *after) { + auto testCallAndStore = + dyn_cast<::test::TestCallAndStoreOp>(call.getOperation()); + if (testCallAndStore && ((action == CallControlFlowAction::EnterCallee && + testCallAndStore.getStoreBeforeCall()) || + (action == CallControlFlowAction::ExitCallee && + !testCallAndStore.getStoreBeforeCall()))) { + return visitOperation(call, before, after); + } + AbstractDenseDataFlowAnalysis::visitCallControlFlowTransfer(call, action, + before, after); +} + +void LastModifiedAnalysis::visitRegionBranchControlFlowTransfer( + RegionBranchOpInterface branch, std::optional regionFrom, + std::optional regionTo, const LastModification &before, + LastModification *after) { + auto testStoreWithARegion = + dyn_cast<::test::TestStoreWithARegion>(branch.getOperation()); + if (testStoreWithARegion && + ((!regionTo && !testStoreWithARegion.getStoreBeforeRegion()) || + (!regionFrom && testStoreWithARegion.getStoreBeforeRegion()))) { + return visitOperation(branch, before, after); + } + AbstractDenseDataFlowAnalysis::visitRegionBranchControlFlowTransfer( + branch, regionFrom, regionTo, before, after); +} + namespace { struct TestLastModifiedPass : public PassWrapper> { diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -26,6 +26,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Support/LogicalResult.h" @@ -2014,6 +2015,37 @@ printer << second << " = " << (second + first); } 
+//===----------------------------------------------------------------------===// +// Test Dataflow +//===----------------------------------------------------------------------===// + +CallInterfaceCallable TestCallAndStoreOp::getCallableForCallee() { + return getCallee(); +} + +void TestCallAndStoreOp::setCalleeFromCallable(CallInterfaceCallable callee) { + setCalleeAttr(callee.get()); +} + +Operation::operand_range TestCallAndStoreOp::getArgOperands() { + return getCalleeOperands(); +} + +void TestStoreWithARegion::getSuccessorRegions( + std::optional index, ArrayRef operands, + SmallVectorImpl ®ions) { + if (!index) { + regions.emplace_back(&getBody(), getBody().front().getArguments()); + } else { + regions.emplace_back(); + } +} + +MutableOperandRange TestStoreWithARegionTerminator::getMutableSuccessorOperands( + std::optional index) { + return MutableOperandRange(getOperation()); +} + #include "TestOpEnums.cpp.inc" #include "TestOpInterfaces.cpp.inc" #include "TestTypeInterfaces.cpp.inc" diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -3443,6 +3443,41 @@ }]; } +//===----------------------------------------------------------------------===// +// Test Dataflow +//===----------------------------------------------------------------------===// + +def TestCallAndStoreOp : TEST_Op<"call_and_store", + [DeclareOpInterfaceMethods]> { + let arguments = (ins + SymbolRefAttr:$callee, + Arg:$address, + Variadic:$callee_operands, + BoolAttr:$store_before_call + ); + let results = (outs + Variadic:$results + ); + let assemblyFormat = + "$callee `(` $callee_operands `)` `,` $address attr-dict " + "`:` functional-type(operands, results)"; +} +def TestStoreWithARegion : TEST_Op<"store_with_a_region", + [DeclareOpInterfaceMethods, + SingleBlock]> { + let arguments = (ins + Arg:$address, + BoolAttr:$store_before_region + ); + let regions = (region 
AnyRegion:$body); + let assemblyFormat = + "$address attr-dict-with-keyword regions `:` type($address)"; +} + +def TestStoreWithARegionTerminator : TEST_Op<"store_with_a_region_terminator", + [DeclareOpInterfaceMethods, Terminator, NoMemoryEffect]> { + let assemblyFormat = "attr-dict"; +} #endif // TEST_OPS