diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -525,6 +525,10 @@ /// Erase the operand at 'index' from the operand list. void eraseOperand(unsigned index); + + /// Returns the successor that would be chosen with the given constant + /// operands. Returns nullptr if a single successor could not be chosen. + Block *getSuccessorForOperands(ArrayRef); }]; let hasCanonicalizer = 1; @@ -1021,6 +1025,10 @@ eraseSuccessorOperand(falseIndex, index); } + /// Returns the successor that would be chosen with the given constant + /// operands. Returns nullptr if a single successor could not be chosen. + Block *getSuccessorForOperands(ArrayRef operands); + private: /// Get the index of the first true destination operand. unsigned getTrueDestOperandIndex() { return 1; } diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -68,6 +68,14 @@ } return llvm::None; }] + >, + InterfaceMethod<[{ + Returns the successor that would be chosen with the given constant + operands. Returns nullptr if a single successor could not be chosen. + }], + "Block *", "getSuccessorForOperands", + (ins "ArrayRef":$operands), [{}], + /*defaultImplementation=*/[{ return nullptr; }] > ]; diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h --- a/mlir/include/mlir/Transforms/FoldUtils.h +++ b/mlir/include/mlir/Transforms/FoldUtils.h @@ -119,6 +119,11 @@ /// Clear out any constants cached inside of the folder. void clear(); + /// Get or create a constant using the given builder. On success this returns + /// the result of the constant operation, and a null Value otherwise.
+ Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect, + Attribute value, Type type, Location loc); + private: /// This map keeps track of uniqued constants by dialect, attribute, and type. /// A constant operation materializes an attribute with a type. Dialects may diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -76,6 +76,10 @@ /// the CallGraph. std::unique_ptr createInlinerPass(); +/// Creates a pass which performs sparse conditional constant propagation over +/// nested operations. +std::unique_ptr createSCCPPass(); + /// Creates a pass which delete symbol operations that are unreachable. This /// pass may *only* be scheduled on an operation that defines a SymbolTable. std::unique_ptr createSymbolDCEPass(); diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -273,6 +273,20 @@ let constructor = "mlir::createPrintOpGraphPass()"; } +def SCCP : Pass<"sccp"> { + let summary = "Sparse Conditional Constant Propagation"; + let description = [{ + This pass implements a general algorithm for sparse conditional constant + propagation. This algorithm detects values that are known to be constant and + optimistically propagates this throughout the IR. Any values proven to be + constant are replaced, and removed if possible. + + This implementation is based on the algorithm described by Wegman and Zadeck + in [“Constant Propagation with Conditional Branches”](https://dl.acm.org/doi/10.1145/103135.103136) (1991).
+ }]; + let constructor = "mlir::createSCCPPass()"; +} + def StripDebugInfo : Pass<"strip-debuginfo"> { let summary = "Strip debug info from all operations"; let description = [{ diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -529,6 +529,8 @@ bool BranchOp::canEraseSuccessorOperand() { return true; } +Block *BranchOp::getSuccessorForOperands(ArrayRef) { return dest(); } + //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// @@ -795,6 +797,14 @@ bool CondBranchOp::canEraseSuccessorOperand() { return true; } +Block *CondBranchOp::getSuccessorForOperands(ArrayRef operands) { + // The condition is the first operand. A constant condition may arrive as + // either a BoolAttr or an IntegerAttr; any other (or null) attribute leaves + // the successor undetermined and nullptr is returned. + if (BoolAttr condAttr = operands.front().dyn_cast_or_null()) + return condAttr.getValue() ? trueDest() : falseDest(); + if (IntegerAttr condAttr = operands.front().dyn_cast_or_null()) + return condAttr.getValue().isOneValue() ? trueDest() : falseDest(); + return nullptr; +} + //===----------------------------------------------------------------------===// // Constant*Op //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -13,6 +13,7 @@ OpStats.cpp ParallelLoopCollapsing.cpp PipelineDataTransfer.cpp + SCCP.cpp StripDebugInfo.cpp SymbolDCE.cpp ViewOpGraph.cpp diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Transforms/SCCP.cpp @@ -0,0 +1,531 @@ +//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This transformation pass performs a sparse conditional constant propagation +// in MLIR. It identifies values known to be constant, propagates that +// information throughout the IR, and replaces them. This is done with an +// optimistic dataflow analysis that assumes that all values are constant until +// proven otherwise. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/SideEffects.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/FoldUtils.h" +#include "mlir/Transforms/Passes.h" + +using namespace mlir; + +namespace { +/// This class represents a single lattice value. A lattice value corresponds to +/// the various different states that a value in the SCCP dataflow analysis can +/// take. See 'Kind' below for more details on the different states a value can +/// take. +class LatticeValue { + enum Kind { + /// A value with a yet to be determined value. This state may be changed to + /// anything. + Unknown, + + /// A value that is known to be a constant. This state may be changed to + /// overdefined. + Constant, + + /// A value that cannot statically be determined to be a constant. This + /// state cannot be changed. + Overdefined + }; + +public: + /// Initialize a lattice value with "Unknown". + LatticeValue() + : constantAndTag(nullptr, Kind::Unknown), constantDialect(nullptr) {} + /// Initialize a lattice value with a constant. + LatticeValue(Attribute attr, Dialect *dialect) + : constantAndTag(attr, Kind::Constant), constantDialect(dialect) {} + + /// Returns true if this lattice value is unknown.
+ bool isUnknown() const { return constantAndTag.getInt() == Kind::Unknown; } + + /// Mark the lattice value as overdefined. + void markOverdefined() { + constantAndTag.setPointerAndInt(nullptr, Kind::Overdefined); + constantDialect = nullptr; + } + + /// Returns true if the lattice is overdefined. + bool isOverdefined() const { + return constantAndTag.getInt() == Kind::Overdefined; + } + + /// Mark the lattice value as constant. + void markConstant(Attribute value, Dialect *dialect) { + constantAndTag.setPointerAndInt(value, Kind::Constant); + constantDialect = dialect; + } + + /// If this lattice is constant, return the constant. Returns nullptr + /// otherwise. + Attribute getConstant() const { return constantAndTag.getPointer(); } + + /// If this lattice is constant, return the dialect to use when materializing + /// the constant. + Dialect *getConstantDialect() const { + assert(getConstant() && "expected valid constant"); + return constantDialect; + } + + /// Merge in the value of the 'rhs' lattice into this one. Returns true if the + /// lattice value changed. + bool mergeIn(const LatticeValue &rhs) { + // If we are already overdefined, or rhs is unknown, there is nothing to do. + if (isOverdefined() || rhs.isUnknown()) + return false; + // If we are unknown, just take the value of rhs. + if (isUnknown()) { + constantAndTag = rhs.constantAndTag; + constantDialect = rhs.constantDialect; + return true; + } + + // Otherwise, if this value doesn't match rhs go straight to overdefined. + // Only the attribute and the tag are compared; the originating dialect is + // not part of the comparison. + if (constantAndTag != rhs.constantAndTag) { + markOverdefined(); + return true; + } + return false; + } + +private: + /// The attribute value if this is a constant, together with the tag holding + /// the lattice 'Kind'. + llvm::PointerIntPair constantAndTag; + + /// The dialect the constant originated from. This is only valid if the + /// lattice is a constant. This is not used as part of the key, and is only + /// needed to materialize the held constant if necessary.
+ Dialect *constantDialect; +}; + +/// This class represents the solver for the SCCP analysis. This class acts as +/// the propagation engine for computing which values form constants. +class SCCPSolver { +public: + /// Initialize the solver with a given set of regions. + SCCPSolver(MutableArrayRef regions); + + /// Run the solver until it converges. + void solve(); + + /// Rewrite the given regions using the computed analysis. This replaces the + /// uses of all values that have been computed to be constant, and erases + /// any newly dead operations. + void rewrite(MLIRContext *context, MutableArrayRef regions); + +private: + /// Replace the given value with a constant if the corresponding lattice + /// represents a constant. Returns success if the value was replaced, failure + /// otherwise. + LogicalResult replaceWithConstant(OpBuilder &builder, OperationFolder &folder, + Value value); + + /// Visit the given operation and compute any necessary lattice state. + void visitOperation(Operation *op); + + /// Visit the given block and compute any necessary lattice state. + void visitBlock(Block *block); + + /// Visit argument #'i' of the given block and compute any necessary lattice + /// state. + void visitBlockArgument(Block *block, int i); + + /// Mark the given block as executable. Returns false if the block was already + /// marked executable. + bool markBlockExecutable(Block *block); + + /// Returns true if the given block is executable. + bool isBlockExecutable(Block *block) const; + + /// Mark the edge between 'from' and 'to' as executable. + void markEdgeExecutable(Block *from, Block *to); + + /// Return true if the edge between 'from' and 'to' is executable. + bool isEdgeExecutable(Block *from, Block *to) const; + + /// Mark the given value as overdefined. This means that we cannot refine a + /// specific constant for this value. + void markOverdefined(Value value); + + /// Mark all of the given values as overdefined.
+ template + void markAllOverdefined(ValuesT values) { + for (auto value : values) + markOverdefined(value); + } + + /// Returns true if the given value was marked as overdefined. + bool isOverdefined(Value value) const; + + /// Merge in the given lattice 'from' into the lattice 'to'. 'owner' + /// corresponds to the parent operation of 'to'. + void mergeIn(Operation *owner, LatticeValue &to, const LatticeValue &from); + + /// Merge in the lattice of 'from' into the lattice for 'to'. 'owner' + /// corresponds to the parent operation of 'to'. + /// NOTE(review): no definition of this overload appears in this file, and the + /// solver only ever calls the LatticeValue overload above. + void mergeIn(Operation *owner, Value to, Value from); + + /// The lattice for each SSA value. + DenseMap latticeValues; + + /// The set of blocks that are known to execute, or are intrinsically live. + SmallPtrSet executableBlocks; + + /// The set of control flow edges that are known to execute. + DenseSet> executableEdges; + + /// A worklist containing blocks that need to be processed. + SmallVector blockWorklist; + + /// A worklist of operations that need to be processed. + SmallVector opWorklist; +}; +} // end anonymous namespace + +SCCPSolver::SCCPSolver(MutableArrayRef regions) { + for (Region ®ion : regions) { + if (region.empty()) + continue; + Block *entryBlock = ®ion.front(); + + // Mark the entry block as executable. + markBlockExecutable(entryBlock); + + // The values passed to these regions are invisible, so mark any arguments + // as overdefined. + markAllOverdefined(entryBlock->getArguments()); + } +} + +void SCCPSolver::solve() { + while (!blockWorklist.empty() || !opWorklist.empty()) { + // Process any operations in the op worklist. + while (!opWorklist.empty()) { + Operation *op = opWorklist.pop_back_val(); + + // Visit all of the live users to propagate changes to this operation. + for (Operation *user : op->getUsers()) { + if (isBlockExecutable(user->getBlock())) + visitOperation(user); + } + } + + // Process any blocks in the block worklist.
+ while (!blockWorklist.empty()) + visitBlock(blockWorklist.pop_back_val()); + } +} + +void SCCPSolver::rewrite(MLIRContext *context, + MutableArrayRef initialRegions) { + SmallVector worklist; + auto addToWorklist = [&](MutableArrayRef regions) { + for (Region ®ion : regions) + for (Block &block : region) + if (isBlockExecutable(&block)) + worklist.push_back(&block); + }; + + // An operation folder used to create and unique constants. + OperationFolder folder(context); + OpBuilder builder(context); + + addToWorklist(initialRegions); + while (!worklist.empty()) { + Block *block = worklist.pop_back_val(); + + // Replace any block arguments with constants. + builder.setInsertionPointToStart(block); + for (BlockArgument arg : block->getArguments()) + replaceWithConstant(builder, folder, arg); + + for (Operation &op : llvm::make_early_inc_range(*block)) { + builder.setInsertionPoint(&op); + + // Replace any result with constants. + bool replacedAll = op.getNumResults() != 0; + for (Value res : op.getResults()) + replacedAll &= succeeded(replaceWithConstant(builder, folder, res)); + + // If all of the results of the operation were replaced, try to erase + // the operation completely. + if (replacedAll && wouldOpBeTriviallyDead(&op)) { + assert(op.use_empty() && "expected all uses to be replaced"); + op.erase(); + continue; + } + + // Add any regions of this operation to the worklist. + addToWorklist(op.getRegions()); + } + } +} + +LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder, + OperationFolder &folder, + Value value) { + auto &latticeValue = latticeValues[value]; + Attribute attrResult = latticeValue.getConstant(); + if (!attrResult) + return failure(); + + // Attempt to materialize a constant for the given value.
+ Dialect *dialect = latticeValue.getConstantDialect(); + Value constant = folder.getOrCreateConstant(builder, dialect, attrResult, + value.getType(), value.getLoc()); + if (!constant) + return failure(); + + value.replaceAllUsesWith(constant); + return success(); +} + +void SCCPSolver::visitOperation(Operation *op) { + // Process region-holding operations. + if (op->getNumRegions()) { + for (Region ®ion : op->getRegions()) { + if (region.empty()) + continue; + Block *entryBlock = ®ion.front(); + markBlockExecutable(entryBlock); + + // TODO: Add an interface to map operands to region arguments. After that, + // we can properly map the values for region arguments. + markAllOverdefined(entryBlock->getArguments()); + } + + // Don't try to fold the results as we can't guarantee folds won't be + // in-place. + return markAllOverdefined(op->getResults()); + } + + // If this op produces no results, it can't produce any constants. The + // getNumRegions() check below is redundant: region-holding operations have + // already returned above. + if (op->getNumResults() == 0 || op->getNumRegions() != 0) + return markAllOverdefined(op->getResults()); + + // If all of the results of this operation are already overdefined, bail out + // early. + auto isOverdefinedFn = [&](Value value) { return isOverdefined(value); }; + if (llvm::all_of(op->getResults(), isOverdefinedFn)) + return; + + // Collect the lattice values for the operands. + SmallVector operandLattices; + for (Value operand : op->getOperands()) + operandLattices.push_back(latticeValues[operand]); + + // If any of the operands are still unknown, wait for them to resolve. + auto isUnknownFn = [&](LatticeValue value) { return value.isUnknown(); }; + if (llvm::any_of(operandLattices, isUnknownFn)) + return; + + // Collect all of the constant operands feeding into this operation.
+ SmallVector operandConstants; + operandConstants.reserve(operandLattices.size()); + for (auto &operandLattice : operandLattices) + operandConstants.push_back(operandLattice.getConstant()); + + // Save the original operands and attributes just in case the operation folds + // in-place. The constant passed in may not correspond to the real runtime + // value, so in-place updates are not allowed. + SmallVector originalOperands(op->getOperands()); + NamedAttributeList originalAttrs = op->getAttrList(); + + // Try to fold the result of this operation to a constant. If folding fails or + // was an in-place fold, mark the results as overdefined. + SmallVector foldResults; + foldResults.reserve(op->getNumResults()); + if (failed(op->fold(operandConstants, foldResults))) + return markAllOverdefined(op->getResults()); + + // If the folding was in-place, mark the results as overdefined and reset the + // operation. + if (foldResults.empty()) { + op->setOperands(originalOperands); + op->setAttrs(originalAttrs); + return markAllOverdefined(op->getResults()); + } + + // Merge the fold results into the lattice for this operation. + assert(foldResults.size() == op->getNumResults() && "invalid result size"); + Dialect *opDialect = op->getDialect(); + for (unsigned i = 0, e = foldResults.size(); i != e; ++i) { + LatticeValue &resultLattice = latticeValues[op->getResult(i)]; + + // Merge in the result of the fold, either a constant or a value. + // NOTE(review): 'resultLattice' refers into 'latticeValues'. Indexing the + // map again on the Value path below may insert a new entry and invalidate + // that reference, unless the folded value already has an entry (e.g. it is + // one of the operands looked up above). Confirm this, or copy the source + // lattice by value before binding 'resultLattice'. + OpFoldResult foldResult = foldResults[i]; + if (Attribute foldAttr = foldResult.dyn_cast()) + mergeIn(op, resultLattice, LatticeValue(foldAttr, opDialect)); + else + mergeIn(op, resultLattice, latticeValues[foldResult.get()]); + } +} + +void SCCPSolver::visitBlock(Block *block) { + // If the block is not the entry block we need to compute the lattice state + // for the block arguments. Entry block argument lattices are computed + // elsewhere, such as when visiting the parent operation.
+ if (!block->isEntryBlock()) { + for (int i : llvm::seq(0, block->getNumArguments())) + visitBlockArgument(block, i); + } + + // Visit all of the operations within the block. + for (Operation &op : *block) + visitOperation(&op); + + // Mark the successor edges of this block as executable as necessary. + // NOTE(review): successor edges are only computed here. This function runs + // once per block, when markBlockExecutable() first queues it. Later revisits + // of the terminator go through visitOperation(), which has no successor + // handling. So if a terminator operand is refined afterwards (Unknown -> + // Constant via the early return below, or Constant -> Overdefined through a + // new incoming edge), the newly feasible edges are never marked executable. + // Consider re-running this logic whenever the terminator is revisited. + Operation *terminator = block->getTerminator(); + + // If we can't handle this terminator, assume all edges are executable. + auto branch = dyn_cast(terminator); + if (!branch) { + for (Block *succ : block->getSuccessors()) + markEdgeExecutable(block, succ); + return; + } + + // Resolve any constant operands to the terminator. + SmallVector termOperands; + termOperands.reserve(terminator->getNumOperands()); + for (Value operand : terminator->getOperands()) { + // Make sure all of the operands are resolved first. + auto &operandLattice = latticeValues[operand]; + if (operandLattice.isUnknown()) + return; + termOperands.push_back(operandLattice.getConstant()); + } + + // Try to resolve to a specific successor with the constant operands. + if (Block *singleSucc = branch.getSuccessorForOperands(termOperands)) { + markEdgeExecutable(block, singleSucc); + return; + } + + // Otherwise, conservatively treat all edges as executable. + for (Block *succ : block->getSuccessors()) + markEdgeExecutable(block, succ); +} + +void SCCPSolver::visitBlockArgument(Block *block, int i) { + BlockArgument arg = block->getArgument(i); + if (isOverdefined(arg)) + return; + + bool updatedLattice = false; + for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { + Block *pred = *it; + + // We only care about this predecessor if it is going to execute. + if (!isEdgeExecutable(pred, block)) + continue; + + // Try to get the operand forwarded by the predecessor. If we can't reason + // about the terminator of the predecessor, mark overdefined. + // NOTE(review): below, 'operandLattice' is a reference into 'latticeValues'. + // The following 'latticeValues[arg]' may insert 'arg' for the first time + // (isOverdefined() above uses find() and does not insert). That insertion + // can grow the DenseMap and leave 'operandLattice' dangling. Copying the + // operand lattice by value would avoid this.
+ Optional branchOperands; + if (auto branch = dyn_cast(pred->getTerminator())) + branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex()); + if (!branchOperands) { + updatedLattice = true; + latticeValues[arg].markOverdefined(); + break; + } + + const LatticeValue &operandLattice = latticeValues[(*branchOperands)[i]]; + LatticeValue &argLattice = latticeValues[arg]; + updatedLattice |= argLattice.mergeIn(operandLattice); + if (argLattice.isOverdefined()) + break; + } + + // If the lattice was updated, visit any executable users of the argument. + if (updatedLattice) { + for (Operation *user : arg.getUsers()) + if (isBlockExecutable(user->getBlock())) + visitOperation(user); + } +} + +bool SCCPSolver::markBlockExecutable(Block *block) { + bool marked = executableBlocks.insert(block).second; + if (marked) + blockWorklist.push_back(block); + return marked; +} + +bool SCCPSolver::isBlockExecutable(Block *block) const { + return executableBlocks.count(block); +} + +void SCCPSolver::markEdgeExecutable(Block *from, Block *to) { + if (!executableEdges.insert(std::make_pair(from, to)).second) + return; + // Mark the destination as executable, and reprocess its arguments if it was + // already executable.
+ if (!markBlockExecutable(to)) { + for (int i : llvm::seq(0, to->getNumArguments())) + visitBlockArgument(to, i); + } +} + +bool SCCPSolver::isEdgeExecutable(Block *from, Block *to) const { + return executableEdges.count(std::make_pair(from, to)); +} + +void SCCPSolver::markOverdefined(Value value) { + latticeValues[value].markOverdefined(); +} + +bool SCCPSolver::isOverdefined(Value value) const { + auto it = latticeValues.find(value); + return it != latticeValues.end() && it->second.isOverdefined(); +} + +void SCCPSolver::mergeIn(Operation *owner, LatticeValue &to, + const LatticeValue &from) { + if (to.mergeIn(from)) + opWorklist.push_back(owner); +} + +//===----------------------------------------------------------------------===// +// SCCP Pass +//===----------------------------------------------------------------------===// + +namespace { +struct SCCP : public SCCPBase { + void runOnOperation() override; +}; +} // end anonymous namespace + +void SCCP::runOnOperation() { + Operation *op = getOperation(); + + // Solve for SCCP constraints within nested regions. + SCCPSolver solver(op->getRegions()); + solver.solve(); + + // Clean up any operations using the solver analysis. + solver.rewrite(&getContext(), op->getRegions()); +} + +std::unique_ptr mlir::createSCCPPass() { + return std::make_unique(); +} diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp --- a/mlir/lib/Transforms/Utils/FoldUtils.cpp +++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp @@ -140,6 +140,27 @@ referencedDialects.clear(); } +/// Get or create a constant using the given builder. On success this returns +/// the result of the constant operation, and a null Value otherwise. +Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect, + Attribute value, Type type, + Location loc) { + OpBuilder::InsertionGuard foldGuard(builder); + + // Use the builder insertion block to find an insertion point for the + // constant.
+ auto *insertRegion = + getInsertionRegion(interfaces, builder.getInsertionBlock()); + auto &entry = insertRegion->front(); + builder.setInsertionPoint(&entry, entry.begin()); + + // Get the constant map for the insertion region. + auto &uniquedConstants = foldScopes[insertRegion]; + Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, + builder, value, type, loc); + return constOp ? constOp->getResult(0) : Value(); +} + /// Tries to perform folding on the given `op`. If successful, populates /// `results` with the results of the folding. LogicalResult OperationFolder::tryToFold( diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Transforms/sccp.mlir @@ -0,0 +1,161 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="func(sccp)" -split-input-file | FileCheck %s + +/// Check simple forward constant propagation without any control flow. + +// CHECK-LABEL: func @no_control_flow +func @no_control_flow(%arg0: i32) -> i32 { + // CHECK: %[[CST:.*]] = constant 1 : i32 + // CHECK: return %[[CST]] : i32 + + %cond = constant 1 : i1 + %cst_1 = constant 1 : i32 + %select = select %cond, %cst_1, %arg0 : i32 + return %select : i32 +} + +/// Check that a constant is properly propagated when only one edge of a branch +/// is taken. + +// CHECK-LABEL: func @simple_control_flow +func @simple_control_flow(%arg0 : i32) -> i32 { + // CHECK: %[[CST:.*]] = constant 1 : i32 + + %cond = constant true + %1 = constant 1 : i32 + cond_br %cond, ^bb1, ^bb2(%arg0 : i32) + +^bb1: + br ^bb2(%1 : i32) + +^bb2(%arg : i32): + // CHECK: ^bb2(%{{.*}}: i32): + // CHECK: return %[[CST]] : i32 + + return %arg : i32 +} + +/// Check that the arguments go to overdefined if the branch cannot detect when +/// a specific successor is taken.
+ +// CHECK-LABEL: func @simple_control_flow_overdefined +func @simple_control_flow_overdefined(%arg0 : i32, %arg1 : i1) -> i32 { + %1 = constant 1 : i32 + cond_br %arg1, ^bb1, ^bb2(%arg0 : i32) + +^bb1: + br ^bb2(%1 : i32) + +^bb2(%arg : i32): + // CHECK: ^bb2(%[[ARG:.*]]: i32): + // CHECK: return %[[ARG]] : i32 + + return %arg : i32 +} + +/// Check that the arguments go to overdefined if the branch is unknown (here +/// an unregistered terminator, so no branch interface is available to report +/// the forwarded operands). + +// CHECK-LABEL: func @unknown_terminator +func @unknown_terminator(%arg0 : i32, %arg1 : i1) -> i32 { + %1 = constant 1 : i32 + "foo.cond_br"() [^bb1, ^bb2] : () -> () + +^bb1: + br ^bb2(%1 : i32) + +^bb2(%arg : i32): + // CHECK: ^bb2(%[[ARG:.*]]: i32): + // CHECK: return %[[ARG]] : i32 + + return %arg : i32 +} + +/// Check that arguments are properly merged across loop-like control flow. + +func @ext_cond_fn() -> i1 + +// CHECK-LABEL: func @simple_loop +func @simple_loop(%arg0 : i32, %cond1 : i1) -> i32 { + // CHECK: %[[CST:.*]] = constant 1 : i32 + + %cst_1 = constant 1 : i32 + cond_br %cond1, ^bb1(%cst_1 : i32), ^bb2(%cst_1 : i32) + +^bb1(%iv: i32): + // CHECK: ^bb1(%{{.*}}: i32): + // CHECK-NEXT: %[[COND:.*]] = call @ext_cond_fn() + // CHECK-NEXT: cond_br %[[COND]], ^bb1(%[[CST]] : i32), ^bb2(%[[CST]] : i32) + + %cst_0 = constant 0 : i32 + %res = addi %iv, %cst_0 : i32 + %cond2 = call @ext_cond_fn() : () -> i1 + cond_br %cond2, ^bb1(%res : i32), ^bb2(%res : i32) + +^bb2(%arg : i32): + // CHECK: ^bb2(%{{.*}}: i32): + // CHECK: return %[[CST]] : i32 + + return %arg : i32 +} + +/// Test that we can properly propagate within inner control flow, and in +/// situations where the executable edges within the CFG are sensitive to the +/// current state of the analysis.
+ +// CHECK-LABEL: func @simple_loop_inner_control_flow +func @simple_loop_inner_control_flow(%arg0 : i32) -> i32 { + // CHECK-DAG: %[[CST:.*]] = constant 1 : i32 + // CHECK-DAG: %[[TRUE:.*]] = constant 1 : i1 + + %cst_1 = constant 1 : i32 + br ^bb1(%cst_1 : i32) + +^bb1(%iv: i32): + %cond2 = call @ext_cond_fn() : () -> i1 + cond_br %cond2, ^bb5(%iv : i32), ^bb2 + +^bb2: + // CHECK: ^bb2: + // CHECK: cond_br %[[TRUE]], ^bb3, ^bb4 + + %cst_20 = constant 20 : i32 + %cond = cmpi "ult", %iv, %cst_20 : i32 + cond_br %cond, ^bb3, ^bb4 + +^bb3: + // CHECK: ^bb3: + // CHECK: br ^bb1(%[[CST]] : i32) + + %cst_1_2 = constant 1 : i32 + br ^bb1(%cst_1_2 : i32) + + // ^bb4 is expected to stay unreachable for the analysis: %iv is 1 on every + // other incoming edge of ^bb1, so %cond folds to true and only ^bb3 is taken. +^bb4: + %iv_inc = addi %iv, %cst_1 : i32 + br ^bb1(%iv_inc : i32) + +^bb5(%result: i32): + // CHECK: ^bb5(%{{.*}}: i32): + // CHECK: return %[[CST]] : i32 + + return %result : i32 +} + +/// Check that arguments go to overdefined when loop backedges produce a +/// conflicting value. + +func @ext_cond_and_value_fn() -> (i1, i32) + +// CHECK-LABEL: func @simple_loop_overdefined +func @simple_loop_overdefined(%arg0 : i32, %cond1 : i1) -> i32 { + %cst_1 = constant 1 : i32 + cond_br %cond1, ^bb1(%cst_1 : i32), ^bb2(%cst_1 : i32) + +^bb1(%iv: i32): + %cond2, %res = call @ext_cond_and_value_fn() : () -> (i1, i32) + cond_br %cond2, ^bb1(%res : i32), ^bb2(%res : i32) + +^bb2(%arg : i32): + // CHECK: ^bb2(%[[ARG:.*]]: i32): + // CHECK: return %[[ARG]] : i32 + + return %arg : i32 +}