diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -596,6 +596,10 @@
 
     /// Erase the operand at 'index' from the operand list.
     void eraseOperand(unsigned index);
+
+    /// Returns the successor that would be chosen with the given constant
+    /// operands. Returns nullptr if a single successor could not be chosen.
+    Block *getSuccessorForOperands(ArrayRef<Attribute>);
   }];
 
   let hasCanonicalizer = 1;
@@ -1092,6 +1096,10 @@
       eraseSuccessorOperand(falseIndex, index);
     }
 
+    /// Returns the successor that would be chosen with the given constant
+    /// operands. Returns nullptr if a single successor could not be chosen.
+    Block *getSuccessorForOperands(ArrayRef<Attribute> operands);
+
   private:
     /// Get the index of the first true destination operand.
     unsigned getTrueDestOperandIndex() { return 1; }
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -68,6 +68,14 @@
         }
         return llvm::None;
       }]
+    >,
+    InterfaceMethod<[{
+        Returns the successor that would be chosen with the given constant
+        operands. Returns nullptr if a single successor could not be chosen.
+      }],
+      "Block *", "getSuccessorForOperands",
+      (ins "ArrayRef<Attribute>":$operands), [{}],
+      /*defaultImplementation=*/[{ return nullptr; }]
     >
   ];
 
diff --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -119,6 +119,11 @@
   /// Clear out any constants cached inside of the folder.
   void clear();
 
+  /// Get or create a constant using the given builder. On success this returns
+  /// the result of the constant operation, or a null Value otherwise.
+  Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
+                            Attribute value, Type type, Location loc);
+
 private:
   /// This map keeps track of uniqued constants by dialect, attribute, and type.
   /// A constant operation materializes an attribute with a type. Dialects may
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -76,6 +76,10 @@
 /// the CallGraph.
 std::unique_ptr<Pass> createInlinerPass();
 
+/// Creates a pass which performs sparse conditional constant propagation over
+/// nested operations.
+std::unique_ptr<Pass> createSCCPPass();
+
 /// Creates a pass which delete symbol operations that are unreachable. This
 /// pass may *only* be scheduled on an operation that defines a SymbolTable.
 std::unique_ptr<Pass> createSymbolDCEPass();
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -273,6 +273,20 @@
   let constructor = "mlir::createPrintOpGraphPass()";
 }
 
+def SCCP : Pass<"sccp"> {
+  let summary = "Sparse Conditional Constant Propagation";
+  let description = [{
+    This pass implements a general algorithm for sparse conditional constant
+    propagation. This algorithm detects values that are known to be constant and
+    optimistically propagates this throughout the IR. Any values proven to be
+    constant are replaced, and removed if possible.
+
+    This implementation is based on the algorithm described by Wegman and Zadeck
+    in [“Constant Propagation with Conditional Branches”](https://dl.acm.org/doi/10.1145/103135.103136) (1991).
+  }];
+  let constructor = "mlir::createSCCPPass()";
+}
+
 def StripDebugInfo : Pass<"strip-debuginfo"> {
   let summary = "Strip debug info from all operations";
   let description = [{
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -597,6 +597,8 @@
 
 bool BranchOp::canEraseSuccessorOperand() { return true; }
 
+Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
+
 //===----------------------------------------------------------------------===//
 // CallOp
 //===----------------------------------------------------------------------===//
@@ -863,6 +865,14 @@
 
 bool CondBranchOp::canEraseSuccessorOperand() { return true; }
 
+Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
+  if (BoolAttr condAttr = operands.front().dyn_cast_or_null<BoolAttr>())
+    return condAttr.getValue() ? trueDest() : falseDest();
+  if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
+    return condAttr.getValue().isOneValue() ? trueDest() : falseDest();
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // Constant*Op
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@
   OpStats.cpp
   ParallelLoopCollapsing.cpp
   PipelineDataTransfer.cpp
+  SCCP.cpp
   StripDebugInfo.cpp
   SymbolDCE.cpp
   ViewOpGraph.cpp
diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -0,0 +1,539 @@
+//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation pass performs a sparse conditional constant propagation
+// in MLIR. It identifies values known to be constant, propagates that
+// information throughout the IR, and replaces them. This is done with an
+// optimistic dataflow analysis that assumes that all values are constant until
+// proven otherwise.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/SideEffects.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/FoldUtils.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+
+namespace {
+/// This class represents a single lattice value. A lattice value corresponds to
+/// the various different states that a value in the SCCP dataflow analysis can
+/// take. See 'Kind' below for more details on the different states a value can
+/// take.
+class LatticeValue {
+  enum Kind {
+    /// A value with a yet to be determined value. This state may be changed to
+    /// anything.
+    Unknown,
+
+    /// A value that is known to be a constant. This state may be changed to
+    /// overdefined.
+    Constant,
+
+    /// A value that cannot statically be determined to be a constant. This
+    /// state cannot be changed.
+    Overdefined
+  };
+
+public:
+  /// Initialize a lattice value with "Unknown".
+  LatticeValue()
+      : constantAndTag(nullptr, Kind::Unknown), constantDialect(nullptr) {}
+  /// Initialize a lattice value with a constant.
+  LatticeValue(Attribute attr, Dialect *dialect)
+      : constantAndTag(attr, Kind::Constant), constantDialect(dialect) {}
+
+  /// Returns true if this lattice value is unknown.
+  bool isUnknown() const { return constantAndTag.getInt() == Kind::Unknown; }
+
+  /// Mark the lattice value as overdefined.
+  void markOverdefined() {
+    constantAndTag.setPointerAndInt(nullptr, Kind::Overdefined);
+    constantDialect = nullptr;
+  }
+
+  /// Returns true if the lattice is overdefined.
+  bool isOverdefined() const {
+    return constantAndTag.getInt() == Kind::Overdefined;
+  }
+
+  /// Mark the lattice value as constant.
+  void markConstant(Attribute value, Dialect *dialect) {
+    constantAndTag.setPointerAndInt(value, Kind::Constant);
+    constantDialect = dialect;
+  }
+
+  /// If this lattice is constant, return the constant. Returns nullptr
+  /// otherwise.
+  Attribute getConstant() const { return constantAndTag.getPointer(); }
+
+  /// If this lattice is constant, return the dialect to use when materializing
+  /// the constant.
+  Dialect *getConstantDialect() const {
+    assert(getConstant() && "expected valid constant");
+    return constantDialect;
+  }
+
+  /// Merge in the value of the 'rhs' lattice into this one. Returns true if the
+  /// lattice value changed.
+  bool meet(const LatticeValue &rhs) {
+    // If we are already overdefined, or rhs is unknown, there is nothing to do.
+    if (isOverdefined() || rhs.isUnknown())
+      return false;
+    // If we are unknown, just take the value of rhs.
+    if (isUnknown()) {
+      constantAndTag = rhs.constantAndTag;
+      constantDialect = rhs.constantDialect;
+      return true;
+    }
+
+    // Otherwise, if this value doesn't match rhs go straight to overdefined.
+    if (constantAndTag != rhs.constantAndTag) {
+      markOverdefined();
+      return true;
+    }
+    return false;
+  }
+
+private:
+  /// The attribute value if this is a constant and the tag for the element
+  /// kind.
+  llvm::PointerIntPair<Attribute, 2, Kind> constantAndTag;
+
+  /// The dialect the constant originated from. This is only valid if the
+  /// lattice is a constant. This is not used as part of the key, and is only
+  /// needed to materialize the held constant if necessary.
+  Dialect *constantDialect;
+};
+
+/// This class represents the solver for the SCCP analysis. This class acts as
+/// the propagation engine for computing which values form constants.
+class SCCPSolver {
+public:
+  /// Initialize the solver with a given set of regions.
+  SCCPSolver(MutableArrayRef<Region> regions);
+
+  /// Run the solver until it converges.
+  void solve();
+
+  /// Rewrite the given regions using the computed analysis. This replaces the
+  /// uses of all values that have been computed to be constant, and erases as
+  /// many newly dead operations as possible.
+  void rewrite(MLIRContext *context, MutableArrayRef<Region> regions);
+
+private:
+  /// Replace the given value with a constant if the corresponding lattice
+  /// represents a constant. Returns success if the value was replaced, failure
+  /// otherwise.
+  LogicalResult replaceWithConstant(OpBuilder &builder, OperationFolder &folder,
+                                    Value value);
+
+  /// Visit the given operation and compute any necessary lattice state.
+  void visitOperation(Operation *op);
+
+  /// Visit the given operation, which defines regions, and compute any
+  /// necessary lattice state. This also resolves the lattice state of both the
+  /// operation results and any nested regions.
+  void visitRegionOperation(Operation *op);
+
+  /// Visit the given terminator operation and compute any necessary lattice
+  /// state.
+  void visitTerminatorOperation(Operation *op,
+                                ArrayRef<Attribute> constantOperands);
+
+  /// Visit the given block and compute any necessary lattice state.
+  void visitBlock(Block *block);
+
+  /// Visit argument #'i' of the given block and compute any necessary lattice
+  /// state.
+  void visitBlockArgument(Block *block, int i);
+
+  /// Mark the given block as executable. Returns false if the block was already
+  /// marked executable.
+  bool markBlockExecutable(Block *block);
+
+  /// Returns true if the given block is executable.
+  bool isBlockExecutable(Block *block) const;
+
+  /// Mark the edge between 'from' and 'to' as executable.
+  void markEdgeExecutable(Block *from, Block *to);
+
+  /// Return true if the edge between 'from' and 'to' is executable.
+  bool isEdgeExecutable(Block *from, Block *to) const;
+
+  /// Mark the given value as overdefined. This means that we cannot refine a
+  /// specific constant for this value.
+  void markOverdefined(Value value);
+
+  /// Mark all of the given values as overdefined.
+  template <typename ValuesT>
+  void markAllOverdefined(ValuesT values) {
+    for (auto value : values)
+      markOverdefined(value);
+  }
+  template <typename ValuesT>
+  void markAllOverdefined(Operation *op, ValuesT values) {
+    markAllOverdefined(values);
+    opWorklist.push_back(op);
+  }
+
+  /// Returns true if the given value was marked as overdefined.
+  bool isOverdefined(Value value) const;
+
+  /// Merge in the given lattice 'from' into the lattice 'to'. 'owner'
+  /// corresponds to the parent operation of 'to'.
+  void meet(Operation *owner, LatticeValue &to, const LatticeValue &from);
+
+  /// The lattice for each SSA value.
+  DenseMap<Value, LatticeValue> latticeValues;
+
+  /// The set of blocks that are known to execute, or are intrinsically live.
+  SmallPtrSet<Block *, 16> executableBlocks;
+
+  /// The set of control flow edges that are known to execute.
+  DenseSet<std::pair<Block *, Block *>> executableEdges;
+
+  /// A worklist containing blocks that need to be processed.
+  SmallVector<Block *, 64> blockWorklist;
+
+  /// A worklist of operations that need to be processed.
+  SmallVector<Operation *, 64> opWorklist;
+};
+} // end anonymous namespace
+
+SCCPSolver::SCCPSolver(MutableArrayRef<Region> regions) {
+  for (Region &region : regions) {
+    if (region.empty())
+      continue;
+    Block *entryBlock = &region.front();
+
+    // Mark the entry block as executable.
+    markBlockExecutable(entryBlock);
+
+    // The values passed to these regions are invisible, so mark any arguments
+    // as overdefined.
+    markAllOverdefined(entryBlock->getArguments());
+  }
+}
+
+void SCCPSolver::solve() {
+  while (!blockWorklist.empty() || !opWorklist.empty()) {
+    // Process any operations in the op worklist.
+    while (!opWorklist.empty()) {
+      Operation *op = opWorklist.pop_back_val();
+
+      // Visit all of the live users to propagate changes to this operation.
+      for (Operation *user : op->getUsers()) {
+        if (isBlockExecutable(user->getBlock()))
+          visitOperation(user);
+      }
+    }
+
+    // Process any blocks in the block worklist.
+    while (!blockWorklist.empty())
+      visitBlock(blockWorklist.pop_back_val());
+  }
+}
+
+void SCCPSolver::rewrite(MLIRContext *context,
+                         MutableArrayRef<Region> initialRegions) {
+  SmallVector<Block *, 8> worklist;
+  auto addToWorklist = [&](MutableArrayRef<Region> regions) {
+    for (Region &region : regions)
+      for (Block &block : region)
+        if (isBlockExecutable(&block))
+          worklist.push_back(&block);
+  };
+
+  // An operation folder used to create and unique constants.
+  OperationFolder folder(context);
+  OpBuilder builder(context);
+
+  addToWorklist(initialRegions);
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+
+    // Replace any block arguments with constants.
+    builder.setInsertionPointToStart(block);
+    for (BlockArgument arg : block->getArguments())
+      replaceWithConstant(builder, folder, arg);
+
+    for (Operation &op : llvm::make_early_inc_range(*block)) {
+      builder.setInsertionPoint(&op);
+
+      // Replace any result with constants.
+      bool replacedAll = op.getNumResults() != 0;
+      for (Value res : op.getResults())
+        replacedAll &= succeeded(replaceWithConstant(builder, folder, res));
+
+      // If all of the results of the operation were replaced, try to erase
+      // the operation completely.
+      if (replacedAll && wouldOpBeTriviallyDead(&op)) {
+        assert(op.use_empty() && "expected all uses to be replaced");
+        op.erase();
+        continue;
+      }
+
+      // Add any regions of this operation to the worklist.
+      addToWorklist(op.getRegions());
+    }
+  }
+}
+
+LogicalResult SCCPSolver::replaceWithConstant(OpBuilder &builder,
+                                              OperationFolder &folder,
+                                              Value value) {
+  auto it = latticeValues.find(value);
+  auto attr = it == latticeValues.end() ? nullptr : it->second.getConstant();
+  if (!attr)
+    return failure();
+
+  // Attempt to materialize a constant for the given value.
+  Dialect *dialect = it->second.getConstantDialect();
+  Value constant = folder.getOrCreateConstant(builder, dialect, attr,
+                                              value.getType(), value.getLoc());
+  if (!constant)
+    return failure();
+
+  value.replaceAllUsesWith(constant);
+  latticeValues.erase(it);
+  return success();
+}
+
+void SCCPSolver::visitOperation(Operation *op) {
+  // Collect all of the constant operands feeding into this operation. If any
+  // are not ready to be resolved, bail out and wait for them to resolve.
+  SmallVector<Attribute, 8> operandConstants;
+  operandConstants.reserve(op->getNumOperands());
+  for (Value operand : op->getOperands()) {
+    // Make sure all of the operands are resolved first.
+    auto &operandLattice = latticeValues[operand];
+    if (operandLattice.isUnknown())
+      return;
+    operandConstants.push_back(operandLattice.getConstant());
+  }
+
+  // If this is a terminator operation, process any control flow lattice state.
+  if (op->isKnownTerminator())
+    visitTerminatorOperation(op, operandConstants);
+
+  // Process region holding operations. The region visitor processes result
+  // values, so we can exit afterwards.
+  if (op->getNumRegions())
+    return visitRegionOperation(op);
+
+  // If this op produces no results, it can't produce any constants.
+  if (op->getNumResults() == 0)
+    return;
+
+  // If all of the results of this operation are already overdefined, bail out
+  // early.
+  auto isOverdefinedFn = [&](Value value) { return isOverdefined(value); };
+  if (llvm::all_of(op->getResults(), isOverdefinedFn))
+    return;
+
+  // Save the original operands and attributes just in case the operation folds
+  // in-place. The constant passed in may not correspond to the real runtime
+  // value, so in-place updates are not allowed.
+  SmallVector<Value, 8> originalOperands(op->getOperands());
+  NamedAttributeList originalAttrs = op->getAttrList();
+
+  // Simulate the result of folding this operation to a constant. If folding
+  // fails or was an in-place fold, mark the results as overdefined.
+  SmallVector<OpFoldResult, 8> foldResults;
+  foldResults.reserve(op->getNumResults());
+  if (failed(op->fold(operandConstants, foldResults)))
+    return markAllOverdefined(op, op->getResults());
+
+  // If the folding was in-place, mark the results as overdefined and reset the
+  // operation. We don't allow in-place folds as the desire here is for
+  // simulated execution, and not general folding.
+  if (foldResults.empty()) {
+    op->setOperands(originalOperands);
+    op->setAttrs(originalAttrs);
+    return markAllOverdefined(op, op->getResults());
+  }
+
+  // Merge the fold results into the lattice for this operation.
+  assert(foldResults.size() == op->getNumResults() && "invalid result size");
+  Dialect *opDialect = op->getDialect();
+  for (unsigned i = 0, e = foldResults.size(); i != e; ++i) {
+    // Merge in the result of the fold, either a constant or a value. Compute it
+    // first: inserting into 'latticeValues' can invalidate map references.
+    OpFoldResult foldResult = foldResults[i];
+    Attribute foldAttr = foldResult.dyn_cast<Attribute>();
+    LatticeValue foldLattice =
+        foldAttr ? LatticeValue(foldAttr, opDialect)
+                 : latticeValues.lookup(foldResult.get<Value>());
+    meet(op, latticeValues[op->getResult(i)], foldLattice);
+  }
+}
+
+void SCCPSolver::visitRegionOperation(Operation *op) {
+  for (Region &region : op->getRegions()) {
+    if (region.empty())
+      continue;
+    Block *entryBlock = &region.front();
+    markBlockExecutable(entryBlock);
+    markAllOverdefined(entryBlock->getArguments());
+  }
+
+  // Don't try to simulate the results of a region operation as we can't
+  // guarantee that folding will be out-of-place. We don't allow in-place folds
+  // as the desire here is for simulated execution, and not general folding.
+  return markAllOverdefined(op, op->getResults());
+}
+
+void SCCPSolver::visitTerminatorOperation(
+    Operation *op, ArrayRef<Attribute> constantOperands) {
+  if (op->getNumSuccessors() == 0)
+    return;
+
+  // Try to resolve to a specific successor with the constant operands.
+  if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+    if (Block *singleSucc = branch.getSuccessorForOperands(constantOperands)) {
+      markEdgeExecutable(op->getBlock(), singleSucc);
+      return;
+    }
+  }
+
+  // Otherwise, conservatively treat all edges as executable.
+  Block *block = op->getBlock();
+  for (Block *succ : op->getSuccessors())
+    markEdgeExecutable(block, succ);
+}
+
+void SCCPSolver::visitBlock(Block *block) {
+  // If the block is not the entry block we need to compute the lattice state
+  // for the block arguments. Entry block argument lattices are computed
+  // elsewhere, such as when visiting the parent operation.
+  if (!block->isEntryBlock()) {
+    for (int i : llvm::seq<int>(0, block->getNumArguments()))
+      visitBlockArgument(block, i);
+  }
+
+  // Visit all of the operations within the block.
+  for (Operation &op : *block)
+    visitOperation(&op);
+}
+
+void SCCPSolver::visitBlockArgument(Block *block, int i) {
+  BlockArgument arg = block->getArgument(i);
+  LatticeValue &argLattice = latticeValues[arg];
+  if (argLattice.isOverdefined())
+    return;
+
+  bool updatedLattice = false;
+  for (auto it = block->pred_begin(), e = block->pred_end(); it != e; ++it) {
+    Block *pred = *it;
+
+    // We only care about this predecessor if it is going to execute.
+    if (!isEdgeExecutable(pred, block))
+      continue;
+
+    // Try to get the operand forwarded by the predecessor. If we can't reason
+    // about the terminator of the predecessor, mark overdefined.
+    Optional<OperandRange> branchOperands;
+    if (auto branch = dyn_cast<BranchOpInterface>(pred->getTerminator()))
+      branchOperands = branch.getSuccessorOperands(it.getSuccessorIndex());
+    if (!branchOperands) {
+      updatedLattice = true;
+      argLattice.markOverdefined();
+      break;
+    }
+
+    // If the operand hasn't been resolved, it is unknown which can merge with
+    // anything.
+    auto operandLattice = latticeValues.find((*branchOperands)[i]);
+    if (operandLattice == latticeValues.end())
+      continue;
+
+    // Otherwise, meet the two lattice values.
+    updatedLattice |= argLattice.meet(operandLattice->second);
+    if (argLattice.isOverdefined())
+      break;
+  }
+
+  // If the lattice was updated, visit any executable users of the argument.
+  if (updatedLattice) {
+    for (Operation *user : arg.getUsers())
+      if (isBlockExecutable(user->getBlock()))
+        visitOperation(user);
+  }
+}
+
+bool SCCPSolver::markBlockExecutable(Block *block) {
+  bool marked = executableBlocks.insert(block).second;
+  if (marked)
+    blockWorklist.push_back(block);
+  return marked;
+}
+
+bool SCCPSolver::isBlockExecutable(Block *block) const {
+  return executableBlocks.count(block);
+}
+
+void SCCPSolver::markEdgeExecutable(Block *from, Block *to) {
+  if (!executableEdges.insert(std::make_pair(from, to)).second)
+    return;
+  // Mark the destination as executable, and reprocess its arguments if it was
+  // already executable.
+  if (!markBlockExecutable(to)) {
+    for (int i : llvm::seq<int>(0, to->getNumArguments()))
+      visitBlockArgument(to, i);
+  }
+}
+
+bool SCCPSolver::isEdgeExecutable(Block *from, Block *to) const {
+  return executableEdges.count(std::make_pair(from, to));
+}
+
+void SCCPSolver::markOverdefined(Value value) {
+  latticeValues[value].markOverdefined();
+}
+
+bool SCCPSolver::isOverdefined(Value value) const {
+  auto it = latticeValues.find(value);
+  return it != latticeValues.end() && it->second.isOverdefined();
+}
+
+void SCCPSolver::meet(Operation *owner, LatticeValue &to,
+                      const LatticeValue &from) {
+  if (to.meet(from))
+    opWorklist.push_back(owner);
+}
+
+//===----------------------------------------------------------------------===//
+// SCCP Pass
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct SCCP : public SCCPBase<SCCP> {
+  void runOnOperation() override;
+};
+} // end anonymous namespace
+
+void SCCP::runOnOperation() {
+  Operation *op = getOperation();
+
+  // Solve for SCCP constraints within nested regions.
+  SCCPSolver solver(op->getRegions());
+  solver.solve();
+
+  // Cleanup any operations using the solver analysis.
+  solver.rewrite(&getContext(), op->getRegions());
+}
+
+std::unique_ptr<Pass> mlir::createSCCPPass() {
+  return std::make_unique<SCCP>();
+}
diff --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -140,6 +140,27 @@
   referencedDialects.clear();
 }
 
+/// Get or create a constant using the given builder. On success this returns
+/// the result of the constant operation, or a null Value otherwise.
+Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
+                                           Attribute value, Type type,
+                                           Location loc) {
+  OpBuilder::InsertionGuard foldGuard(builder);
+
+  // Use the builder insertion block to find an insertion point for the
+  // constant.
+  auto *insertRegion =
+      getInsertionRegion(interfaces, builder.getInsertionBlock());
+  auto &entry = insertRegion->front();
+  builder.setInsertionPoint(&entry, entry.begin());
+
+  // Get the constant map for the insertion region of this operation.
+  auto &uniquedConstants = foldScopes[insertRegion];
+  Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
+                                              builder, value, type, loc);
+  return constOp ? constOp->getResult(0) : Value();
+}
+
 /// Tries to perform folding on the given `op`. If successful, populates
 /// `results` with the results of the folding.
 LogicalResult OperationFolder::tryToFold(
diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/sccp.mlir
@@ -0,0 +1,180 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="func(sccp)" -split-input-file | FileCheck %s
+
+/// Check simple forward constant propagation without any control flow.
+
+// CHECK-LABEL: func @no_control_flow
+func @no_control_flow(%arg0: i32) -> i32 {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+  // CHECK: return %[[CST]] : i32
+
+  %cond = constant 1 : i1
+  %cst_1 = constant 1 : i32
+  %select = select %cond, %cst_1, %arg0 : i32
+  return %select : i32
+}
+
+/// Check that a constant is properly propagated when only one edge of a branch
+/// is taken.
+
+// CHECK-LABEL: func @simple_control_flow
+func @simple_control_flow(%arg0 : i32) -> i32 {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+
+  %cond = constant true
+  %1 = constant 1 : i32
+  cond_br %cond, ^bb1, ^bb2(%arg0 : i32)
+
+^bb1:
+  br ^bb2(%1 : i32)
+
+^bb2(%arg : i32):
+  // CHECK: ^bb2(%{{.*}}: i32):
+  // CHECK: return %[[CST]] : i32
+
+  return %arg : i32
+}
+
+/// Check that the arguments go to overdefined if the branch cannot detect when
+/// a specific successor is taken.
+
+// CHECK-LABEL: func @simple_control_flow_overdefined
+func @simple_control_flow_overdefined(%arg0 : i32, %arg1 : i1) -> i32 {
+  %1 = constant 1 : i32
+  cond_br %arg1, ^bb1, ^bb2(%arg0 : i32)
+
+^bb1:
+  br ^bb2(%1 : i32)
+
+^bb2(%arg : i32):
+  // CHECK: ^bb2(%[[ARG:.*]]: i32):
+  // CHECK: return %[[ARG]] : i32
+
+  return %arg : i32
+}
+
+/// Check that the arguments go to overdefined if the predecessors forward
+/// conflicting constants.
+
+// CHECK-LABEL: func @simple_control_flow_constant_overdefined
+func @simple_control_flow_constant_overdefined(%arg0 : i32, %arg1 : i1) -> i32 {
+  %1 = constant 1 : i32
+  %2 = constant 2 : i32
+  cond_br %arg1, ^bb1, ^bb2(%1 : i32)
+
+^bb1:
+  br ^bb2(%2 : i32)
+
+^bb2(%arg : i32):
+  // CHECK: ^bb2(%[[ARG:.*]]: i32):
+  // CHECK: return %[[ARG]] : i32
+
+  return %arg : i32
+}
+
+/// Check that the arguments go to overdefined if the branch is unknown.
+
+// CHECK-LABEL: func @unknown_terminator
+func @unknown_terminator(%arg0 : i32, %arg1 : i1) -> i32 {
+  %1 = constant 1 : i32
+  "foo.cond_br"() [^bb1, ^bb2] : () -> ()
+
+^bb1:
+  br ^bb2(%1 : i32)
+
+^bb2(%arg : i32):
+  // CHECK: ^bb2(%[[ARG:.*]]: i32):
+  // CHECK: return %[[ARG]] : i32
+
+  return %arg : i32
+}
+
+/// Check that arguments are properly merged across loop-like control flow.
+
+func @ext_cond_fn() -> i1
+
+// CHECK-LABEL: func @simple_loop
+func @simple_loop(%arg0 : i32, %cond1 : i1) -> i32 {
+  // CHECK: %[[CST:.*]] = constant 1 : i32
+
+  %c1 = constant 1 : i32
+  cond_br %cond1, ^bb1(%c1 : i32), ^bb2(%c1 : i32)
+
+^bb1(%i: i32):
+  // CHECK: ^bb1(%{{.*}}: i32):
+  // CHECK-NEXT: %[[COND:.*]] = call @ext_cond_fn()
+  // CHECK-NEXT: cond_br %[[COND]], ^bb1(%[[CST]] : i32), ^bb2(%[[CST]] : i32)
+
+  %c0 = constant 0 : i32
+  %sum = addi %i, %c0 : i32
+  %keep_going = call @ext_cond_fn() : () -> i1
+  cond_br %keep_going, ^bb1(%sum : i32), ^bb2(%sum : i32)
+
+^bb2(%exit_val : i32):
+  // CHECK: ^bb2(%{{.*}}: i32):
+  // CHECK: return %[[CST]] : i32
+
+  return %exit_val : i32
+}
+
+/// Test that constants are propagated through inner control flow, including
+/// cases where the executable edges of the CFG depend on the current state of
+/// the analysis.
+
+// CHECK-LABEL: func @simple_loop_inner_control_flow
+func @simple_loop_inner_control_flow(%arg0 : i32) -> i32 {
+  // CHECK-DAG: %[[CST:.*]] = constant 1 : i32
+  // CHECK-DAG: %[[TRUE:.*]] = constant 1 : i1
+
+  %c1 = constant 1 : i32
+  br ^bb1(%c1 : i32)
+
+^bb1(%i: i32):
+  %keep_going = call @ext_cond_fn() : () -> i1
+  cond_br %keep_going, ^bb5(%i : i32), ^bb2
+
+^bb2:
+  // CHECK: ^bb2:
+  // CHECK: cond_br %[[TRUE]], ^bb3, ^bb4
+
+  %c20 = constant 20 : i32
+  %in_range = cmpi "ult", %i, %c20 : i32
+  cond_br %in_range, ^bb3, ^bb4
+
+^bb3:
+  // CHECK: ^bb3:
+  // CHECK: br ^bb1(%[[CST]] : i32)
+
+  %c1_dup = constant 1 : i32
+  br ^bb1(%c1_dup : i32)
+
+^bb4:
+  %i_next = addi %i, %c1 : i32
+  br ^bb1(%i_next : i32)
+
+^bb5(%out: i32):
+  // CHECK: ^bb5(%{{.*}}: i32):
+  // CHECK: return %[[CST]] : i32
+
+  return %out : i32
+}
+
+/// Check that block arguments become overdefined when a loop backedge
+/// forwards a value that is not a known constant.
+
+func @ext_cond_and_value_fn() -> (i1, i32)
+
+// CHECK-LABEL: func @simple_loop_overdefined
+func @simple_loop_overdefined(%arg0 : i32, %cond1 : i1) -> i32 {
+  %c1 = constant 1 : i32
+  cond_br %cond1, ^bb1(%c1 : i32), ^bb2(%c1 : i32)
+
+^bb1(%i: i32):
+  %keep_going, %next = call @ext_cond_and_value_fn() : () -> (i1, i32)
+  cond_br %keep_going, ^bb1(%next : i32), ^bb2(%next : i32)
+
+^bb2(%exit_val : i32):
+  // CHECK: ^bb2(%[[ARG:.*]]: i32):
+  // CHECK: return %[[ARG]] : i32
+
+  return %exit_val : i32
+}