diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.h +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.h @@ -17,6 +17,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffects.h" diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td @@ -13,6 +13,7 @@ #ifndef LOOP_OPS #define LOOP_OPS +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffects.td" @@ -37,6 +38,7 @@ def ForOp : Loop_Op<"for", [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, SingleBlockImplicitTerminator<"YieldOp">, RecursiveSideEffects]> { let summary = "for operation"; @@ -169,11 +171,18 @@ unsigned getNumIterOperands() { return getOperation()->getNumOperands() - getNumControlOperands(); } + + /// Return the operands used when entering the region at `index`, which + /// was specified as a successor by `getSuccessorRegions`. These operands + /// should correspond 1-1 with the successor inputs specified in + /// `getSuccessorRegions`. 
+ OperandRange getRegionEntryOperands(unsigned index); }]; } def IfOp : Loop_Op<"if", - [SingleBlockImplicitTerminator<"YieldOp">, RecursiveSideEffects]> { + [DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"YieldOp">, RecursiveSideEffects]> { let summary = "if-then-else operation"; let description = [{ The `loop.if` operation represents an if-then-else construct for @@ -385,7 +394,7 @@ let assemblyFormat = "$result attr-dict `:` type($result)"; } -def YieldOp : Loop_Op<"yield", [NoSideEffect, Terminator]> { +def YieldOp : Loop_Op<"yield", [NoSideEffect, ReturnLike, Terminator]> { let summary = "loop yield and termination operation"; let description = [{ "loop.yield" yields an SSA value from a loop dialect op region and diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -19,6 +19,10 @@ namespace mlir { class BranchOpInterface; +//===----------------------------------------------------------------------===// +// BranchOpInterface +//===----------------------------------------------------------------------===// + namespace detail { /// Erase an operand from a branch operation that is used as a successor /// operand. `operandIndex` is the operand within `operands` to be erased. @@ -37,7 +41,67 @@ Optional operands); } // namespace detail +//===----------------------------------------------------------------------===// +// RegionBranchOpInterface +//===----------------------------------------------------------------------===// + +/// This class represents a successor of a region. A region successor can either +/// be another region, or the parent operation. If the successor is a region, +/// this class accepts the destination region, as well as a set of arguments +/// from that region that will be populated by values from the current region. 
+/// If the successor is the parent operation, this class accepts an optional set +/// of results that will be populated by values from the current region. +class RegionSuccessor { +public: + /// Initialize a successor that branches to another region of the parent + /// operation. + RegionSuccessor(Region *region, Block::BlockArgListType regionInputs = {}) + : successor(region), successorInputs(regionInputs) {} + /// Initialize a successor that branches back to/out of the parent operation. + RegionSuccessor(Optional results = {}) + : successor(nullptr), + successorInputs(results ? ValueRange(*results) : ValueRange()) {} + + /// Return the given region successor. Returns nullptr if the successor is the + /// parent operation. + Region *getSuccessor() const { return successor; } + + /// Return the inputs to the successor that are remapped by the exit values of + /// the current region. + ValueRange getSuccessorInputs() const { return successorInputs; } + +private: + Region *successor; + ValueRange successorInputs; +}; + +//===----------------------------------------------------------------------===// +// ControlFlow Interfaces +//===----------------------------------------------------------------------===// + #include "mlir/Interfaces/ControlFlowInterfaces.h.inc" + +//===----------------------------------------------------------------------===// +// ControlFlow Traits +//===----------------------------------------------------------------------===// + +namespace OpTrait { +/// This trait indicates that a terminator operation is "return-like". This +/// means that it exits its current region and forwards its operands as "exit" +/// values to the parent region. Operations with this trait are not permitted to +/// contain successors. 
+template +struct ReturnLike : public TraitBase { + static LogicalResult verifyTrait(Operation *op) { + static_assert(ConcreteType::template hasTrait(), + "expected operation to be a terminator"); + static_assert(ConcreteType::template hasTrait(), + "expected operation to have zero successors"); + return success(); + } +}; +} // namespace OpTrait + } // end namespace mlir #endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES_H diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -90,4 +90,51 @@ }]; } +//===----------------------------------------------------------------------===// +// RegionBranchOpInterface +//===----------------------------------------------------------------------===// + +def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> { + let description = [{ + This interface provides information for region operations that contain + branching behavior between held regions, i.e. this interface allows for + expressing control flow information for region holding operations. + }]; + let methods = [ + InterfaceMethod<[{ + Return the operands used when entering the region at `index`, which was + specified as a successor by `getSuccessorRegions`. These operands + should correspond 1-1 with the successor inputs specified in + `getSuccessorRegions`. The default implementation returns an empty range. + }], + "OperandRange", "getRegionEntryOperands", (ins "unsigned":$index), + [{}], /*defaultImplementation=*/[{ + auto operandEnd = this->getOperation()->operand_end(); + return OperandRange({operandEnd, operandEnd}); + }] + >, + InterfaceMethod<[{ + Given the region at `index`, or the parent operation if `index` is None, + return the successor regions. These are the regions that may be selected + during the flow of control. 
If `index` is None, `operands` is a set of + optional attributes that correspond to a constant value for each + operand of this operation, or null if that operand is not a constant. If + `index` is valid, `operands` corresponds to the exit values of the + region at `index`. Only a region, i.e. a valid `index`, may use the + parent operation as a successor. + }], + "void", "getSuccessorRegions", + (ins "Optional":$index, "ArrayRef":$operands, + "SmallVectorImpl &":$regions) + > + ]; +} + +//===----------------------------------------------------------------------===// +// ControlFlow Traits +//===----------------------------------------------------------------------===// + +// Op is "return-like", i.e. it exits its region and forwards its operands. +def ReturnLike : NativeOpTrait<"ReturnLike">; + #endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -196,6 +196,39 @@ return dyn_cast_or_null(containingOp); } +/// Return the operands used when entering the region at `index`, which was +/// specified as a successor by `getSuccessorRegions`. These operands +/// should correspond 1-1 with the successor inputs specified in +/// `getSuccessorRegions`. +OperandRange ForOp::getRegionEntryOperands(unsigned index) { + assert(index == 0 && "invalid region index"); + + // The initial operands map to the loop arguments after the induction + // variable. + return initArgs(); +} + +/// Given the region at `index`, or the parent operation if `index` is None, +/// return the successor regions. These are the regions that may be selected +/// during the flow of control. `operands` is a set of optional attributes that +/// correspond to a constant value for each operand, or null if that operand is +/// not a constant. 
+void ForOp::getSuccessorRegions(Optional<unsigned> index,
+                                ArrayRef<Attribute> operands,
+                                SmallVectorImpl<RegionSuccessor> &regions) {
+  // If the predecessor is the ForOp, branch into the body using the iterator
+  // arguments.
+  if (!index) {
+    regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
+    return;
+  }
+
+  // Otherwise, the loop may branch back to itself or the parent operation.
+  assert(index.getValue() == 0 && "expected loop region");
+  regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
+  regions.push_back(RegionSuccessor(getResults()));
+}
+
 //===----------------------------------------------------------------------===//
 // IfOp
 //===----------------------------------------------------------------------===//
@@ -298,6 +331,37 @@
   p.printOptionalAttrDict(op.getAttrs());
 }
 
+/// Given the region at `index`, or the parent operation if `index` is None,
+/// return the successor regions. These are the regions that may be selected
+/// during the flow of control. `operands` is a set of optional attributes that
+/// correspond to a constant value for each operand, or null if that operand is
+/// not a constant.
+void IfOp::getSuccessorRegions(Optional<unsigned> index,
+                               ArrayRef<Attribute> operands,
+                               SmallVectorImpl<RegionSuccessor> &regions) {
+  // The `then` and the `else` regions branch back to the parent operation.
+  if (index) {
+    regions.push_back(RegionSuccessor(getResults()));
+    return;
+  }
+
+  // Otherwise, the successor is dependent on the condition.
+  bool condition;
+  if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
+    condition = condAttr.getValue().isOneValue();
+  } else if (auto condAttr = operands.front().dyn_cast_or_null<BoolAttr>()) {
+    condition = condAttr.getValue();
+  } else {
+    // If the condition isn't constant, both regions may be executed.
+    regions.push_back(RegionSuccessor(&thenRegion()));
+    regions.push_back(RegionSuccessor(&elseRegion()));
+    return;
+  }
+
+  // Add the successor regions using the condition.
+  regions.push_back(RegionSuccessor(condition ? &thenRegion() : &elseRegion()));
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/SCCP.cpp b/mlir/lib/Transforms/SCCP.cpp
--- a/mlir/lib/Transforms/SCCP.cpp
+++ b/mlir/lib/Transforms/SCCP.cpp
@@ -138,13 +138,30 @@
   LogicalResult replaceWithConstant(OpBuilder &builder, OperationFolder &folder,
                                     Value value);
 
+  /// Visit each executable user of the given value or operation.
+  template <typename T>
+  void visitUsers(T &value) {
+    for (Operation *user : value.getUsers())
+      if (isBlockExecutable(user->getBlock()))
+        visitOperation(user);
+  }
+
   /// Visit the given operation and compute any necessary lattice state.
   void visitOperation(Operation *op);
 
   /// Visit the given operation, which defines regions, and compute any
   /// necessary lattice state. This also resolves the lattice state of both the
   /// operation results and any nested regions.
-  void visitRegionOperation(Operation *op);
+  void visitRegionOperation(Operation *op,
+                            ArrayRef<Attribute> constantOperands);
+
+  /// Visit the given set of region successors, computing any necessary lattice
+  /// state. The provided function returns the input operands to the region at
+  /// the given index. If the index is 'None', the input operands correspond to
+  /// the parent operation results.
+  void visitRegionSuccessors(
+      Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
+      function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion);
 
   /// Visit the given terminator operation and compute any necessary lattice
   /// state.
@@ -186,6 +203,18 @@
     markAllOverdefined(values);
     opWorklist.push_back(op);
   }
+  /// Mark all of the given values as overdefined, and revisit the executable
+  /// users of any value whose lattice state changed as a result.
+  template <typename ValuesT>
+  void markAllOverdefinedAndVisitUsers(ValuesT values) {
+    for (auto value : values) {
+      auto &lattice = latticeValues[value];
+      if (!lattice.isOverdefined()) {
+        lattice.markOverdefined();
+        visitUsers(value);
+      }
+    }
+  }
 
   /// Returns true if the given value was marked as overdefined.
   bool isOverdefined(Value value) const;
@@ -229,15 +258,8 @@
 void SCCPSolver::solve() {
   while (!blockWorklist.empty() || !opWorklist.empty()) {
     // Process any operations in the op worklist.
-    while (!opWorklist.empty()) {
-      Operation *op = opWorklist.pop_back_val();
-
-      // Visit all of the live users to propagate changes to this operation.
-      for (Operation *user : op->getUsers()) {
-        if (isBlockExecutable(user->getBlock()))
-          visitOperation(user);
-      }
-    }
+    while (!opWorklist.empty())
+      visitUsers(*opWorklist.pop_back_val());
 
     // Process any blocks in the block worklist.
     while (!blockWorklist.empty())
@@ -329,7 +351,7 @@
   // Process region holding operations. The region visitor processes result
   // values, so we can exit afterwards.
   if (op->getNumRegions())
-    return visitRegionOperation(op);
+    return visitRegionOperation(op, operandConstants);
 
   // If this op produces no results, it can't produce any constants.
   if (op->getNumResults() == 0)
@@ -378,25 +400,166 @@
   }
 }
 
-void SCCPSolver::visitRegionOperation(Operation *op) {
-  for (Region &region : op->getRegions()) {
-    if (region.empty())
+void SCCPSolver::visitRegionOperation(Operation *op,
+                                      ArrayRef<Attribute> constantOperands) {
+  // Check to see if we can reason about the internal control flow of this
+  // region operation.
+  auto regionInterface = dyn_cast<RegionBranchOpInterface>(op);
+  if (!regionInterface) {
+    // If we can't, conservatively mark all regions as executable.
+    for (Region &region : op->getRegions()) {
+      if (region.empty())
+        continue;
+      Block *entryBlock = &region.front();
+      markBlockExecutable(entryBlock);
+      markAllOverdefined(entryBlock->getArguments());
+    }
+
+    // Don't try to simulate the results of a region operation as we can't
+    // guarantee that folding will be out-of-place. We don't allow in-place
+    // folds as the desire here is for simulated execution, and not general
+    // folding.
+    return markAllOverdefined(op, op->getResults());
+  }
+
+  // Check to see which regions are executable.
+  SmallVector<RegionSuccessor, 1> successors;
+  regionInterface.getSuccessorRegions(/*index=*/llvm::None, constantOperands,
+                                      successors);
+
+  // If the interface identified that no region will be executed, mark
+  // any results of this operation as overdefined, as we can't reason about
+  // them.
+  // TODO: If we had an interface to detect pass through operands, we could
+  // resolve some results based on the lattice state of the operands. We could
+  // also allow for the parent operation to have itself as a region successor.
+  // NOTE(review): until then, results of an op that may skip its regions
+  // entirely (e.g. a zero-trip `loop.for`, which returns its init operands)
+  // only merge the values yielded by region exits, so they can be folded to
+  // a constant that differs from the skipped-path value.
+  if (successors.empty())
+    return markAllOverdefined(op, op->getResults());
+  return visitRegionSuccessors(op, successors, [&](Optional<unsigned> index) {
+    assert(index && "expected valid region index");
+    return regionInterface.getRegionEntryOperands(*index);
+  });
+}
+
+void SCCPSolver::visitRegionSuccessors(
+    Operation *parentOp, ArrayRef<RegionSuccessor> regionSuccessors,
+    function_ref<OperandRange(Optional<unsigned>)> getInputsForRegion) {
+  for (const RegionSuccessor &it : regionSuccessors) {
+    Region *region = it.getSuccessor();
+    ValueRange succArgs = it.getSuccessorInputs();
+
+    // Check to see if this is the parent operation.
+    if (!region) {
+      ResultRange results = parentOp->getResults();
+      if (llvm::all_of(results, [&](Value res) { return isOverdefined(res); }))
+        continue;
+
+      // Mark the results outside of the input range as overdefined.
+      if (succArgs.size() != results.size()) {
+        opWorklist.push_back(parentOp);
+        if (succArgs.empty()) {
+          markAllOverdefined(results);
+          continue;
+        }
+
+        unsigned firstResIdx = succArgs[0].cast<OpResult>().getResultNumber();
+        markAllOverdefined(results.take_front(firstResIdx));
+        markAllOverdefined(results.drop_front(firstResIdx + succArgs.size()));
+      }
+
+      // Update the lattice for any operation results, then keep processing
+      // the remaining successors instead of returning early.
+      OperandRange operands = getInputsForRegion(/*index=*/llvm::None);
+      for (auto it : llvm::zip(succArgs, operands))
+        meet(parentOp, latticeValues[std::get<0>(it)],
+             latticeValues[std::get<1>(it)]);
+      continue;
+    }
+    if (region->empty())
       continue;
-    Block *entryBlock = &region.front();
+    Block *entryBlock = &region->front();
     markBlockExecutable(entryBlock);
-    markAllOverdefined(entryBlock->getArguments());
-  }
 
-  // Don't try to simulate the results of a region operation as we can't
-  // guarantee that folding will be out-of-place. We don't allow in-place folds
-  // as the desire here is for simulated execution, and not general folding.
-  return markAllOverdefined(op, op->getResults());
+    // If all of the arguments are already overdefined, the arguments have
+    // already been fully resolved.
+    auto arguments = entryBlock->getArguments();
+    if (llvm::all_of(arguments, [&](Value arg) { return isOverdefined(arg); }))
+      continue;
+
+    // Mark any arguments that do not receive inputs as overdefined; we won't be
+    // able to discern if they are constant.
+    if (succArgs.size() != arguments.size()) {
+      if (succArgs.empty()) {
+        markAllOverdefinedAndVisitUsers(arguments);
+        continue;
+      }
+
+      unsigned firstArgIdx = succArgs[0].cast<BlockArgument>().getArgNumber();
+      markAllOverdefinedAndVisitUsers(arguments.take_front(firstArgIdx));
+      markAllOverdefinedAndVisitUsers(
+          arguments.drop_front(firstArgIdx + succArgs.size()));
+    }
+
+    // Update the lattice for arguments that have inputs from the predecessor.
+    OperandRange succOperands = getInputsForRegion(region->getRegionNumber());
+    for (auto it : llvm::zip(succArgs, succOperands)) {
+      LatticeValue &argLattice = latticeValues[std::get<0>(it)];
+      if (argLattice.meet(latticeValues[std::get<1>(it)]))
+        visitUsers(std::get<0>(it));
+    }
+  }
 }
 
 void SCCPSolver::visitTerminatorOperation(
     Operation *op, ArrayRef<Attribute> constantOperands) {
-  if (op->getNumSuccessors() == 0)
-    return;
+  // If this operation has no successors, we treat it as an exiting terminator.
+  if (op->getNumSuccessors() == 0) {
+    // Check to see if the parent tracks region control flow.
+    Region *parentRegion = op->getParentRegion();
+    Operation *parentOp = parentRegion->getParentOp();
+    auto regionInterface = dyn_cast<RegionBranchOpInterface>(parentOp);
+    if (!regionInterface || !isBlockExecutable(parentOp->getBlock()))
+      return;
+
+    // Query the set of successors from the current region.
+    SmallVector<RegionSuccessor, 2> regionSuccessors;
+    regionInterface.getSuccessorRegions(parentRegion->getRegionNumber(),
+                                        constantOperands, regionSuccessors);
+    if (regionSuccessors.empty())
+      return;
+
+    // If this terminator is not "return-like", conservatively mark all of the
+    // successor values as overdefined.
+    if (!op->hasTrait<OpTrait::ReturnLike>()) {
+      for (auto &it : regionSuccessors) {
+        Region *region = it.getSuccessor();
+        if (!region) {
+          // Control may exit to the parent, so none of its results are known.
+          markAllOverdefinedAndVisitUsers(parentOp->getResults());
+          continue;
+        }
+        if (region->empty())
+          continue;
+
+        // The successor region may now be entered: its entry block must be
+        // treated as live, and none of its arguments can be reasoned about.
+        Block *entryBlock = &region->front();
+        markBlockExecutable(entryBlock);
+        markAllOverdefinedAndVisitUsers(entryBlock->getArguments());
+      }
+      return;
+    }
+
+    // Otherwise, propagate the operand lattice states to each of the
+    // successors.
+    OperandRange operands = op->getOperands();
+    return visitRegionSuccessors(parentOp, regionSuccessors,
+                                 [&](Optional<unsigned>) { return operands; });
+  }
 
   // Try to resolve to a specific successor with the constant operands.
   if (auto branch = dyn_cast<BranchOpInterface>(op)) {
@@ -464,11 +627,8 @@
   }
 
   // If the lattice was updated, visit any executable users of the argument.
-  if (updatedLattice) {
-    for (Operation *user : arg.getUsers())
-      if (isBlockExecutable(user->getBlock()))
-        visitOperation(user);
-  }
+  if (updatedLattice)
+    visitUsers(arg);
 }
 
 bool SCCPSolver::markBlockExecutable(Block *block) {
diff --git a/mlir/test/Transforms/sccp-structured.mlir b/mlir/test/Transforms/sccp-structured.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/sccp-structured.mlir
@@ -0,0 +1,132 @@
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline="func(sccp)" -split-input-file | FileCheck %s
+
+/// Check that a constant is properly propagated when only one edge is taken.
+ +// CHECK-LABEL: func @simple( +func @simple(%arg0 : i32) -> i32 { + // CHECK: %[[CST:.*]] = constant 1 : i32 + // CHECK-NOT: loop.if + // CHECK: return %[[CST]] : i32 + + %cond = constant true + %res = loop.if %cond -> (i32) { + %1 = constant 1 : i32 + loop.yield %1 : i32 + } else { + loop.yield %arg0 : i32 + } + return %res : i32 +} + +/// Check that a constant is properly propagated when both edges produce the +/// same value. + +// CHECK-LABEL: func @simple_both_same( +func @simple_both_same(%cond : i1) -> i32 { + // CHECK: %[[CST:.*]] = constant 1 : i32 + // CHECK-NOT: loop.if + // CHECK: return %[[CST]] : i32 + + %res = loop.if %cond -> (i32) { + %1 = constant 1 : i32 + loop.yield %1 : i32 + } else { + %2 = constant 1 : i32 + loop.yield %2 : i32 + } + return %res : i32 +} + +/// Check that the result goes to overdefined if the analysis cannot determine +/// which successor region is taken. + +// CHECK-LABEL: func @overdefined_unknown_condition( +func @overdefined_unknown_condition(%cond : i1, %arg0 : i32) -> i32 { + // CHECK: %[[RES:.*]] = loop.if + // CHECK: return %[[RES]] : i32 + + %res = loop.if %cond -> (i32) { + %1 = constant 1 : i32 + loop.yield %1 : i32 + } else { + loop.yield %arg0 : i32 + } + return %res : i32 +} + +/// Check that the result goes to overdefined if the regions yield conflicting +/// constants. + +// CHECK-LABEL: func @overdefined_different_constants( +func @overdefined_different_constants(%cond : i1) -> i32 { + // CHECK: %[[RES:.*]] = loop.if + // CHECK: return %[[RES]] : i32 + + %res = loop.if %cond -> (i32) { + %1 = constant 1 : i32 + loop.yield %1 : i32 + } else { + %2 = constant 2 : i32 + loop.yield %2 : i32 + } + return %res : i32 +} + +/// Check that arguments are properly merged across loop-like control flow. 
+ +// CHECK-LABEL: func @simple_loop( +func @simple_loop(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 { + // CHECK: %[[CST:.*]] = constant 0 : i32 + // CHECK-NOT: loop.for + // CHECK: return %[[CST]] : i32 + + %s0 = constant 0 : i32 + %result = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0) -> (i32) { + %sn = addi %si, %si : i32 + loop.yield %sn : i32 + } + return %result : i32 +} + +/// Check that arguments go to overdefined when loop backedges produce a +/// conflicting value. + +// CHECK-LABEL: func @loop_overdefined( +func @loop_overdefined(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 { + // CHECK: %[[RES:.*]] = loop.for + // CHECK: return %[[RES]] : i32 + + %s0 = constant 1 : i32 + %result = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %s0) -> (i32) { + %sn = addi %si, %si : i32 + loop.yield %sn : i32 + } + return %result : i32 +} + +/// Test that we can properly propagate within inner control flow, and in cases +/// where the executable edges within the CFG are sensitive to the current state +/// of the analysis. + +// CHECK-LABEL: func @loop_inner_control_flow( +func @loop_inner_control_flow(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 { + // CHECK: %[[CST:.*]] = constant 1 : i32 + // CHECK-NOT: loop.for + // CHECK-NOT: loop.if + // CHECK: return %[[CST]] : i32 + + %cst_1 = constant 1 : i32 + %result = loop.for %i0 = %arg0 to %arg1 step %arg2 iter_args(%si = %cst_1) -> (i32) { + %cst_20 = constant 20 : i32 + %cond = cmpi "ult", %si, %cst_20 : i32 + %inner_res = loop.if %cond -> (i32) { + %1 = constant 1 : i32 + loop.yield %1 : i32 + } else { + %si_inc = addi %si, %cst_1 : i32 + loop.yield %si_inc : i32 + } + loop.yield %inner_res : i32 + } + return %result : i32 +}