diff --git a/mlir/include/mlir/Analysis/DataFlowAnalysis.h b/mlir/include/mlir/Analysis/DataFlowAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlowAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlowAnalysis.h @@ -22,6 +22,7 @@ #ifndef MLIR_ANALYSIS_DATAFLOWANALYSIS_H #define MLIR_ANALYSIS_DATAFLOWANALYSIS_H +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/ADT/DenseMap.h" @@ -250,6 +251,18 @@ ArrayRef<AbstractLatticeElement *> operands, SmallVectorImpl<RegionSuccessor> &successors) = 0; + /// Given an operation with successor regions, one of those regions, and the + /// lattice elements of the operation's operands, compute the lattice values + /// of the region's block arguments that are not accounted for by the + /// branching control flow (e.g. the bounds of loops), returning + /// ChangeResult::Change if any of those lattice elements changed. Typed + /// analyses inherit a default that marks all such lattice elements as having + /// reached a pessimistic fixpoint. The region in the RegionSuccessor and the + /// operand lattice elements are guaranteed to be non-null. + virtual ChangeResult + visitNonControlFlowArguments(Operation *op, const RegionSuccessor &region, + ArrayRef<AbstractLatticeElement *> operands) = 0; + /// Create a new uninitialized lattice element. An optional value is provided /// which, if valid, should be used to initialize the known conservative state /// of the lattice. @@ -347,6 +360,33 @@ branch.getSuccessorRegions(sourceIndex, constantOperands, successors); } + /// Given an operation with successor regions, one of those regions, and the + /// lattice elements of the operation's operands, compute the lattice values + /// of the region's entry block arguments that are not accounted for by the + /// branching control flow (e.g. the bounds of loops), returning + /// ChangeResult::Change if any of those lattice elements changed. By default, + /// all such lattice elements are marked as having reached a pessimistic + /// fixpoint. The region in the RegionSuccessor and the operand lattice + /// elements are guaranteed to be non-null.
+ virtual ChangeResult + visitNonControlFlowArguments(Operation *op, const RegionSuccessor &region, + ArrayRef<LatticeElement<ValueT> *> operands) { + ChangeResult ret = ChangeResult::NoChange; + Region *regionPtr = region.getSuccessor(); + ValueRange succArgs = region.getSuccessorInputs(); + Block *block = &regionPtr->front(); // Requires a non-empty region. + Block::BlockArgListType arguments = block->getArguments(); + if (arguments.size() != succArgs.size()) { + unsigned firstArgIdx = // Inputs are assumed to be one contiguous run. + succArgs.empty() ? 0 + : succArgs[0].cast<BlockArgument>().getArgNumber(); + ret |= markAllPessimisticFixpoint(arguments.take_front(firstArgIdx)); // Before the run. + ret |= markAllPessimisticFixpoint( + arguments.drop_front(firstArgIdx + succArgs.size())); // After the run. + } + return ret; + } + private: /// Type-erased wrappers that convert the abstract lattice operands to derived /// lattices and invoke the virtual hooks operating on the derived lattices. @@ -379,6 +419,14 @@ branch, sourceIndex, llvm::makeArrayRef(derivedOperandBase, operands.size()), successors); } + ChangeResult visitNonControlFlowArguments( + Operation *op, const RegionSuccessor &region, + ArrayRef<detail::AbstractLatticeElement *> operands) final { + LatticeElement<ValueT> *const *derivedOperandBase = + reinterpret_cast<LatticeElement<ValueT> *const *>(operands.data()); + return visitNonControlFlowArguments( + op, region, llvm::makeArrayRef(derivedOperandBase, operands.size())); + } /// Create a new uninitialized lattice element. An optional value is provided, /// which if valid, should be used to initialize the known conservative state diff --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp --- a/mlir/lib/Analysis/DataFlowAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/Operation.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallPtrSet.h" #include @@ -113,6 +114,7 @@ /// the parent operation results.
void visitRegionSuccessors( Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors, + ArrayRef<AbstractLatticeElement *> operandLattices, function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion); /// Visit the given terminator operation and compute any necessary lattice @@ -460,7 +462,7 @@ if (successors.empty()) return markAllPessimisticFixpoint(branch, branch->getResults()); return visitRegionSuccessors( - branch, successors, [&](Optional<unsigned> index) { + branch, successors, operandLattices, [&](Optional<unsigned> index) { assert(index && "expected valid region index"); return branch.getSuccessorEntryOperands(*index); }); @@ -468,6 +470,7 @@ void ForwardDataFlowSolver::visitRegionSuccessors( Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors, + ArrayRef<AbstractLatticeElement *> operandLattices, function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) { for (const RegionSuccessor &it : regionSuccessors) { Region *region = it.getSuccessor(); @@ -514,22 +517,17 @@ if (llvm::all_of(arguments, [&](Value arg) { return isAtFixpoint(arg); })) continue; - // Mark any arguments that do not receive inputs as having reached a - // pessimistic fixpoint, we won't be able to discern if they are constant. - // TODO: This isn't exactly ideal. There may be situations in which a - // region operation can provide information for certain results that - // aren't part of the control flow. if (succArgs.size() != arguments.size()) { - if (succArgs.empty()) { - markAllPessimisticFixpoint(arguments); - continue; + if (analysis.visitNonControlFlowArguments( + parentOp, it, operandLattices) == ChangeResult::Change) { // Revisit users only if an argument lattice changed. + unsigned firstArgIdx = // Inputs are assumed to be one contiguous run. + succArgs.empty() ?
0 + : succArgs[0].cast<BlockArgument>().getArgNumber(); + for (Value v : arguments.take_front(firstArgIdx)) + visitUsers(v); + for (Value v : arguments.drop_front(firstArgIdx + succArgs.size())) + visitUsers(v); } - - unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber(); - markAllPessimisticFixpointAndVisitUsers( - arguments.take_front(firstArgIdx)); - markAllPessimisticFixpointAndVisitUsers( - arguments.drop_front(firstArgIdx + succArgs.size())); } // Update the lattice of arguments that have inputs from the predecessor. @@ -573,12 +571,14 @@ // Try to get "region-like" successor operands if possible in order to // propagate the operand states to the successors. if (isRegionReturnLike(op)) { - return visitRegionSuccessors( - parentOp, regionSuccessors, [&](Optional<unsigned> regionIndex) { - // Determine the individual region successor operands for the given - // region index (if any). - return *getRegionBranchSuccessorOperands(op, regionIndex); - }); + return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices, // NOTE(review): presumably these are the terminator's operand lattices, not parentOp's; confirm this matches the visitNonControlFlowArguments contract. + [&](Optional<unsigned> regionIndex) { + // Determine the individual region + // successor operands for the given region + // index (if any). + return *getRegionBranchSuccessorOperands( + op, regionIndex); + }); } // If this terminator is not "region-like", conservatively mark all of the