diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -964,7 +964,7 @@ let hasCanonicalizer = 1; let hasCustomAssemblyFormat = 1; - let hasRegionVerifier = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -207,10 +207,7 @@ let extraClassDeclaration = [{ /// Convenience helper in case none of the operands is known. void getSuccessorRegions(Optional index, - SmallVectorImpl ®ions) { - SmallVector nullAttrs(getOperation()->getNumOperands()); - getSuccessorRegions(index, nullAttrs, regions); - } + SmallVectorImpl ®ions); /// Return `true` if control flow originating from the given region may /// eventually branch back to the same region. (Maybe after passing through diff --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp --- a/mlir/lib/Analysis/DataFlowAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp @@ -576,45 +576,12 @@ if (!regionInterface || !isBlockExecutable(parentOp->getBlock())) return; - // If the branch is a RegionBranchTerminatorOpInterface, - // construct the set of operand lattices as the set of non control-flow - // arguments of the parent and the values this op returns. This allows - // for the correct lattices to be passed to getSuccessorsForOperands() - // in cases such as scf.while. - ArrayRef branchOpLattices = operandLattices; - SmallVector parentLattices; - if (auto regionTerminator = - dyn_cast(op)) { - parentLattices.reserve(regionInterface->getNumOperands()); - for (Value parentOperand : regionInterface->getOperands()) { - AbstractLatticeElement *operandLattice = - analysis.lookupLatticeElement(parentOperand); - if (!operandLattice || operandLattice->isUninitialized()) - return; - parentLattices.push_back(operandLattice); - } - unsigned regionNumber = parentRegion->getRegionNumber(); - OperandRange iterArgs = - regionInterface.getSuccessorEntryOperands(regionNumber); - OperandRange terminatorArgs = - regionTerminator.getSuccessorOperands(regionNumber); - assert(iterArgs.size() == terminatorArgs.size() && - "Number of iteration arguments for region should equal number of " - "those arguments defined by terminator"); - if (!iterArgs.empty()) { - unsigned iterStart = iterArgs.getBeginOperandIndex(); - unsigned terminatorStart = terminatorArgs.getBeginOperandIndex(); - for (unsigned i = 0, e = iterArgs.size(); i < e; ++i) - parentLattices[iterStart + i] = operandLattices[terminatorStart + i]; - } - branchOpLattices = parentLattices; - } // Query the set of successors of the current region using the current // optimistic lattice state. SmallVector regionSuccessors; analysis.getSuccessorsForOperands(regionInterface, parentRegion->getRegionNumber(), - branchOpLattices, regionSuccessors); + operandLattices, regionSuccessors); if (regionSuccessors.empty()) return; @@ -622,11 +589,11 @@ // propagate the operand states to the successors. if (isRegionReturnLike(op)) { auto getOperands = [&](Optional regionIndex) { - // Determine the individual region successor operands for the given + // Determine the individual region successor operands for the given // region index (if any). return *getRegionBranchSuccessorOperands(op, regionIndex); }; - return visitRegionSuccessors(parentOp, regionSuccessors, branchOpLattices, + return visitRegionSuccessors(parentOp, regionSuccessors, operandLattices, getOperands); } diff --git a/mlir/lib/Analysis/IntRangeAnalysis.cpp b/mlir/lib/Analysis/IntRangeAnalysis.cpp --- a/mlir/lib/Analysis/IntRangeAnalysis.cpp +++ b/mlir/lib/Analysis/IntRangeAnalysis.cpp @@ -214,12 +214,24 @@ RegionBranchOpInterface branch, Optional sourceIndex, ArrayRef *> operands, SmallVectorImpl &successors) { - auto toConstantAttr = [&branch](auto enumPair) -> Attribute { - Optional maybeConstValue = - enumPair.value()->getValue().value.getConstantValue(); + // Get a type with which to construct a constant. + auto getOperandType = [branch, sourceIndex](unsigned index) { + // The types of all return-like operations are the same. + if (!sourceIndex) + return branch->getOperand(index).getType(); + + for (Block &block : branch->getRegion(*sourceIndex)) { + Operation *terminator = block.getTerminator(); + if (getRegionBranchSuccessorOperands(terminator, *sourceIndex)) + return terminator->getOperand(index).getType(); + } + return Type(); + }; - if (maybeConstValue) { - return IntegerAttr::get(branch->getOperand(enumPair.index()).getType(), + auto toConstantAttr = [&getOperandType](auto enumPair) -> Attribute { + if (Optional maybeConstValue = + enumPair.value()->getValue().value.getConstantValue()) { + return IntegerAttr::get(getOperandType(enumPair.index()), *maybeConstValue); } return {}; diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -2631,21 +2631,26 @@ void WhileOp::getSuccessorRegions(Optional index, ArrayRef operands, SmallVectorImpl ®ions) { - (void)operands; - + // The parent op always branches to the condition region. if (!index.hasValue()) { regions.emplace_back(&getBefore(), getBefore().getArguments()); return; } assert(*index < 2 && "there are only two regions in a WhileOp"); - if (*index == 0) { - regions.emplace_back(&getAfter(), getAfter().getArguments()); - regions.emplace_back(getResults()); + // The body region always branches back to the condition region. + if (*index == 1) { + regions.emplace_back(&getBefore(), getBefore().getArguments()); return; } - regions.emplace_back(&getBefore(), getBefore().getArguments()); + // Try to narrow the successor to the condition region. + assert(!operands.empty() && "expected at least one operand"); + auto cond = operands[0].dyn_cast_or_null(); + if (!cond || !cond.getValue()) + regions.emplace_back(getResults()); + if (!cond || cond.getValue()) + regions.emplace_back(&getAfter(), getAfter().getArguments()); } /// Parses a `while` op. @@ -2745,7 +2750,7 @@ return nullptr; } -LogicalResult scf::WhileOp::verifyRegions() { +LogicalResult scf::WhileOp::verify() { auto beforeTerminator = verifyAndGetTerminator( *this, getBefore(), "expects the 'before' region to terminate with 'scf.condition'"); diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -9,7 +9,6 @@ #include #include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/SmallPtrSet.h" using namespace mlir; @@ -97,15 +96,7 @@ auto regionInterface = cast(op); SmallVector successors; - unsigned numInputs; - if (sourceNo) { - Region &srcRegion = op->getRegion(sourceNo.getValue()); - numInputs = srcRegion.getNumArguments(); - } else { - numInputs = op->getNumOperands(); - } - SmallVector operands(numInputs, nullptr); - regionInterface.getSuccessorRegions(sourceNo, operands, successors); + regionInterface.getSuccessorRegions(sourceNo, successors); for (RegionSuccessor &succ : successors) { Optional succRegionNo; @@ -327,6 +318,27 @@ return isRegionReachable(region, region); } +void RegionBranchOpInterface::getSuccessorRegions( + Optional index, SmallVectorImpl ®ions) { + unsigned numInputs = 0; + if (index) { + // If the predecessor is a region, get the number of operands from an + // exiting terminator in the region. + for (Block &block : getOperation()->getRegion(*index)) { + Operation *terminator = block.getTerminator(); + if (getRegionBranchSuccessorOperands(terminator, *index)) { + numInputs = terminator->getNumOperands(); + break; + } + } + } else { + // Otherwise, use the number of parent operation operands. + numInputs = getOperation()->getNumOperands(); + } + SmallVector operands(numInputs, nullptr); + getSuccessorRegions(index, operands, regions); +} + Region *mlir::getEnclosingRepetitiveRegion(Operation *op) { while (Region *region = op->getParentRegion()) { op = region->getParentOp(); diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -476,7 +476,7 @@ func.func @while_cross_region_type_mismatch() { %true = arith.constant true // expected-error@+1 {{'scf.while' op along control flow edge from Region #0 to Region #1: source type #0 'i1' should match input type #0 'i32'}} - scf.while : () -> () { + %0 = scf.while : () -> (i1) { scf.condition(%true) %true : i1 } do { ^bb0(%arg0: i32): diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir --- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir +++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir @@ -100,3 +100,22 @@ %0 = test.reflect_bounds %arg0 func.return %0 : index } + +// CHECK-LABEL: func @propagate_across_while_loop() +func.func @propagate_across_while_loop() -> index { + // CHECK-DAG: %[[C0:.*]] = "test.constant"() {value = 0 + // CHECK-DAG: %[[C1:.*]] = "test.constant"() {value = 1 + %0 = test.with_bounds { umin = 0 : index, umax = 0 : index, + smin = 0 : index, smax = 0 : index } + %1 = scf.while : () -> index { + %true = arith.constant true + // CHECK: scf.condition(%{{.*}}) %[[C0]] + scf.condition(%true) %0 : index + } do { + ^bb0(%i1: index): + scf.yield + } + // CHECK: return %[[C1]] + %2 = test.increment %1 + return %2 : index +} diff --git a/mlir/test/Transforms/sccp-structured.mlir b/mlir/test/Transforms/sccp-structured.mlir --- a/mlir/test/Transforms/sccp-structured.mlir +++ b/mlir/test/Transforms/sccp-structured.mlir @@ -179,3 +179,43 @@ // CHECK: return %[[C0]] : i32 return %s0 : i32 } + +// CHECK-LABEL: func @while_loop_different_arg_count +func.func @while_loop_different_arg_count() -> index { + // CHECK-DAG: %[[TRUE:.*]] = arith.constant true + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 + // CHECK-DAG: %[[C1:.*]] = arith.constant 1 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // CHECK: %[[WHILE:.*]] = scf.while + %0 = scf.while (%arg3 = %c0, %arg4 = %c1) : (index, index) -> index { + %1 = arith.cmpi slt, %arg3, %c1 : index + // CHECK: scf.condition(%[[TRUE]]) %[[C1]] + scf.condition(%1) %arg4 : index + } do { + ^bb0(%arg3: index): + %1 = arith.muli %arg3, %c1 : index + // CHECK: scf.yield %[[C0]], %[[C1]] + scf.yield %c0, %1 : index, index + } + // CHECK: return %[[WHILE]] + return %0 : index +} + +// CHECK-LABEL: func @while_loop_false_condition +func.func @while_loop_false_condition(%arg0 : index) -> index { + // CHECK: %[[C0:.*]] = arith.constant 0 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %0 = arith.muli %arg0, %c0 : index + %1 = scf.while (%arg1 = %0) : (index) -> index { + %2 = arith.cmpi slt, %arg1, %c0 : index + scf.condition(%2) %arg1 : index + } do { + ^bb0(%arg2 : index): + %3 = arith.addi %arg2, %c1 : index + scf.yield %3 : index + } + // CHECK: return %[[C0]] + func.return %1 : index +}