diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -309,7 +309,7 @@ - /// correspond to the loop iterator operands, i.e., those exclusing the + /// correspond to the loop iterator operands, i.e., those excluding the /// induction variable. LoopOp only has one region, so 0 is the only valid /// value for `index`. - OperandRange getSuccessorEntryOperands(unsigned index); + OperandRange getSuccessorEntryOperands(Optional<unsigned> index); }]; let hasCanonicalizer = 1; @@ -955,7 +955,7 @@ let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after); let extraClassDeclaration = [{ - OperandRange getSuccessorEntryOperands(unsigned index); + OperandRange getSuccessorEntryOperands(Optional<unsigned> index); ConditionOp getConditionOp(); YieldOp getYieldOp(); Block::BlockArgListType getBeforeArguments(); diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -134,12 +134,14 @@ InterfaceMethod<[{ Returns the operands of this operation used as the entry arguments when entering the region at `index`, which was specified as a successor of - this operation by `getSuccessorRegions`. These operands should - correspond 1-1 with the successor inputs specified in + this operation by `getSuccessorRegions`, or the operands forwarded to + the operation's results when it branches back to itself. These operands + should correspond 1-1 with the successor inputs specified in `getSuccessorRegions`.
}], "::mlir::OperandRange", "getSuccessorEntryOperands", - (ins "unsigned":$index), [{}], /*defaultImplementation=*/[{ + (ins "::llvm::Optional<unsigned>":$index), [{}], + /*defaultImplementation=*/[{ auto operandEnd = this->getOperation()->operand_end(); return ::mlir::OperandRange(operandEnd, operandEnd); }] diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -78,12 +78,12 @@ if (region) { // Determine the actual region number from the passed region. regionIndex = region->getRegionNumber(); - if (Optional<unsigned> operandIndex = - getOperandIndexIfPred(/*predIndex=*/llvm::None)) { - collectUnderlyingAddressValues( - branch.getSuccessorEntryOperands(*regionIndex)[*operandIndex], - maxDepth, visited, output); - } + } + if (Optional<unsigned> operandIndex = + getOperandIndexIfPred(/*predIndex=*/llvm::None)) { + collectUnderlyingAddressValues( + branch.getSuccessorEntryOperands(regionIndex)[*operandIndex], maxDepth, + visited, output); } // Check branches from each child region. Operation *op = branch.getOperation(); diff --git a/mlir/lib/Analysis/DataFlowAnalysis.cpp b/mlir/lib/Analysis/DataFlowAnalysis.cpp --- a/mlir/lib/Analysis/DataFlowAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlowAnalysis.cpp @@ -470,11 +470,10 @@ // also allow for the parent operation to have itself as a region successor.
if (successors.empty()) return markAllPessimisticFixpoint(branch, branch->getResults()); - return visitRegionSuccessors( - branch, successors, operandLattices, [&](Optional<unsigned> index) { - assert(index && "expected valid region index"); - return branch.getSuccessorEntryOperands(*index); - }); + return visitRegionSuccessors(branch, successors, operandLattices, + [&](Optional<unsigned> index) { + return branch.getSuccessorEntryOperands(index); + }); } void ForwardDataFlowSolver::visitRegionSuccessors( diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1731,11 +1731,11 @@ /// correspond to the loop iterator operands, i.e., those excluding the -/// induction variable. AffineForOp only has one region, so zero is the only -/// valid value for `index`. +/// induction variable. AffineForOp only has one region, so `index` is either +/// 0 (entering the body) or None (operands forwarded directly to the results). -OperandRange AffineForOp::getSuccessorEntryOperands(unsigned index) { - assert(index == 0 && "invalid region index"); +OperandRange AffineForOp::getSuccessorEntryOperands(Optional<unsigned> index) { + assert((!index || *index == 0) && "invalid region index"); // The initial operands map to the loop arguments after the induction - // variable. + // variable or are forwarded to the results when the trip count is zero.
return getIterOperands(); } diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -59,8 +59,8 @@ constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; -OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) { - assert(index == 0 && "invalid region index"); +OperandRange ExecuteOp::getSuccessorEntryOperands(Optional<unsigned> index) { + assert(index && *index == 0 && "invalid region index"); return operands(); } diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -473,8 +473,8 @@ /// correspond to the loop iterator operands, i.e., those excluding the /// induction variable. LoopOp only has one region, so 0 is the only valid value /// for `index`. -OperandRange ForOp::getSuccessorEntryOperands(unsigned index) { - assert(index == 0 && "invalid region index"); +OperandRange ForOp::getSuccessorEntryOperands(Optional<unsigned> index) { + assert(index && *index == 0 && "invalid region index"); // The initial operands map to the loop arguments after the induction // variable.
@@ -2605,8 +2605,8 @@ // WhileOp //===----------------------------------------------------------------------===// -OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) { - assert(index == 0 && +OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) { + assert(index && *index == 0 && "WhileOp is expected to branch only to the first region"); return getInits(); diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -312,8 +312,9 @@ } } -OperandRange transform::SequenceOp::getSuccessorEntryOperands(unsigned index) { - assert(index == 0 && "unexpected region index"); +OperandRange +transform::SequenceOp::getSuccessorEntryOperands(Optional<unsigned> index) { + assert(index && *index == 0 && "unexpected region index"); if (getOperation()->getNumOperands() == 1) return getOperation()->getOperands(); return OperandRange(getOperation()->operand_end(), diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -8,6 +8,7 @@ #include <utility> +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/ADT/SmallPtrSet.h" @@ -151,16 +152,7 @@ auto regionInterface = cast<RegionBranchOpInterface>(op); auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange { - if (regionNo.hasValue()) { - return regionInterface.getSuccessorEntryOperands(regionNo.getValue()) - .getTypes(); - } - - // If the successor of a parent op is the parent itself - // RegionBranchOpInterface does not have an API to query what the entry - // operands will be in that case. Vend out the result types of the op in - // that case so that type checking succeeds for this case.
- return op->getResultTypes(); + return regionInterface.getSuccessorEntryOperands(regionNo).getTypes(); }; // Verify types along control flow edges originating from the parent. diff --git a/mlir/test/Analysis/test-alias-analysis.mlir b/mlir/test/Analysis/test-alias-analysis.mlir --- a/mlir/test/Analysis/test-alias-analysis.mlir +++ b/mlir/test/Analysis/test-alias-analysis.mlir @@ -191,6 +191,31 @@ // ----- +// CHECK-LABEL: Testing : "region_loop_zero_trip_count" +// CHECK-DAG: alloca_1#0 <-> alloca_2#0: NoAlias +// CHECK-DAG: alloca_1#0 <-> for_alloca#0: MustAlias +// CHECK-DAG: alloca_1#0 <-> for_alloca.region0#0: MayAlias +// CHECK-DAG: alloca_1#0 <-> for_alloca.region0#1: MayAlias + +// CHECK-DAG: alloca_2#0 <-> for_alloca#0: NoAlias +// CHECK-DAG: alloca_2#0 <-> for_alloca.region0#0: MayAlias +// CHECK-DAG: alloca_2#0 <-> for_alloca.region0#1: MayAlias + +// CHECK-DAG: for_alloca#0 <-> for_alloca.region0#0: MayAlias +// CHECK-DAG: for_alloca#0 <-> for_alloca.region0#1: MayAlias + +// CHECK-DAG: for_alloca.region0#0 <-> for_alloca.region0#1: MayAlias +func.func @region_loop_zero_trip_count() attributes {test.ptr = "func"} { + %0 = memref.alloca() {test.ptr = "alloca_1"} : memref<i32> + %1 = memref.alloca() {test.ptr = "alloca_2"} : memref<i32> + %result = affine.for %i = 0 to 0 iter_args(%si = %0) -> (memref<i32>) { + affine.yield %si : memref<i32> + } {test.ptr = "for_alloca"} + return +} + +// ----- + // CHECK-LABEL: Testing : "view_like" // CHECK-DAG: alloc_1#0 <-> view#0: NoAlias diff --git a/mlir/test/Transforms/sccp-structured.mlir b/mlir/test/Transforms/sccp-structured.mlir --- a/mlir/test/Transforms/sccp-structured.mlir +++ b/mlir/test/Transforms/sccp-structured.mlir @@ -154,7 +154,7 @@ /// interface as well.
// CHECK-LABEL: func @affine_loop_one_iter( -func.func @affine_loop_one_iter(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 { +func.func @affine_loop_one_iter() -> i32 { // CHECK: %[[C1:.*]] = arith.constant 1 : i32 %s0 = arith.constant 0 : i32 %s1 = arith.constant 1 : i32 @@ -167,17 +167,27 @@ } // CHECK-LABEL: func @affine_loop_zero_iter( -func.func @affine_loop_zero_iter(%arg0 : index, %arg1 : index, %arg2 : index) -> i32 { - // This exposes a crash in sccp/forward data flow analysis: https://github.com/llvm/llvm-project/issues/54928 +func.func @affine_loop_zero_iter() -> i32 { + // CHECK: %[[C1:.*]] = arith.constant 1 : i32 + %s1 = arith.constant 1 : i32 + %result = affine.for %i = 0 to 0 iter_args(%si = %s1) -> (i32) { + %sn = arith.addi %si, %si : i32 + affine.yield %sn : i32 + } + // CHECK: return %[[C1]] : i32 + return %result : i32 +} + +// CHECK-LABEL: func @affine_loop_unknown_trip_count( +func.func @affine_loop_unknown_trip_count(%ub: index) -> i32 { // CHECK: %[[C0:.*]] = arith.constant 0 : i32 %s0 = arith.constant 0 : i32 - // %result = affine.for %i = 0 to 0 iter_args(%si = %s0) -> (i32) { - // %sn = arith.addi %si, %si : i32 - // affine.yield %sn : i32 - // } - // return %result : i32 + %result = affine.for %i = 0 to %ub iter_args(%si = %s0) -> (i32) { + %sn = arith.addi %si, %si : i32 + affine.yield %sn : i32 + } // CHECK: return %[[C0]] : i32 - return %s0 : i32 + return %result : i32 } // CHECK-LABEL: func @while_loop_different_arg_count diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -1301,8 +1301,8 @@ parser.getCurrentLocation(), result.operands); } -OperandRange RegionIfOp::getSuccessorEntryOperands(unsigned index) { - assert(index < 2 && "invalid region index"); +OperandRange RegionIfOp::getSuccessorEntryOperands(Optional<unsigned> index) { + assert(index && *index < 2 && "invalid region index");
return getOperands(); } @@ -1339,7 +1339,7 @@ SmallVectorImpl<RegionSuccessor> &regions) { // The parent op branches into the only region, and the region branches back // to the parent op. - if (index) + if (!index) regions.emplace_back(&getRegion()); else regions.emplace_back(getResults()); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2549,7 +2549,8 @@ ::mlir::Block::BlockArgListType getJoinArgs() { return getBody(2)->getArguments(); } - ::mlir::OperandRange getSuccessorEntryOperands(unsigned index); + ::mlir::OperandRange getSuccessorEntryOperands( + ::llvm::Optional<unsigned> index); }]; let hasCustomAssemblyFormat = 1; }