diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -445,6 +445,12 @@ llvm::find(regionSuccessor->getSuccessorInputs(), argValue) .getIndex(); + // Terminator operands are selected by the successor's region number; + // std::nullopt is used when the successor is the parent op. + std::optional<unsigned> successorRegionNumber; + if (Region *successorRegion = regionSuccessor->getSuccessor()) + successorRegionNumber = successorRegion->getRegionNumber(); + // Iterate over all immediate terminator operations to introduce // new buffer allocations. Thereby, the appropriate terminator operand // will be adjusted to point to the newly allocated buffer instead. @@ -453,7 +459,7 @@ // Get the actual mutable operands for this terminator op. auto terminatorOperands = terminator.getMutableSuccessorOperands( - region.getRegionNumber()); + successorRegionNumber); // Extract the source value from the current terminator. // This conversion needs to exist on a separate line due to a // bug in GCC conversion analysis. diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -283,6 +283,9 @@ MutableOperandRange ConditionOp::getMutableSuccessorOperands(std::optional<unsigned> index) { + assert((!index || index == getParentOp().getAfter().getRegionNumber()) && + "condition op can only exit the loop or branch to the after " + "region"); // Pass all operands except the condition to the successor region.
return getArgsMutable(); } diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -84,6 +84,23 @@ // RegionBranchOpInterface //===----------------------------------------------------------------------===// +static InFlightDiagnostic & +printRegionEdgeName(InFlightDiagnostic &diag, std::optional<unsigned> sourceNo, + std::optional<unsigned> succRegionNo) { + diag << "from "; + if (sourceNo) + diag << "Region #" << sourceNo.value(); + else + diag << "parent operands"; + + diag << " to "; + if (succRegionNo) + diag << "Region #" << succRegionNo.value(); + else + diag << "parent results"; + return diag; +} + /// Verify that types match along all region control flow edges originating from /// `sourceNo` (region # if source is a region, std::nullopt if source is parent /// op). `getInputsTypesForRegion` is a function that returns the types of the @@ -92,7 +109,7 @@ /// the match itself).
static LogicalResult verifyTypesAlongAllEdges( Operation *op, std::optional<unsigned> sourceNo, - function_ref<std::optional<TypeRange>(std::optional<unsigned>)> + function_ref<FailureOr<TypeRange>(std::optional<unsigned>)> getInputsTypesForRegion) { auto regionInterface = cast<RegionBranchOpInterface>(op); @@ -104,32 +121,17 @@ if (!succ.isParent()) succRegionNo = succ.getSuccessor()->getRegionNumber(); - auto printEdgeName = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & { - diag << "from "; - if (sourceNo) - diag << "Region #" << sourceNo.value(); - else - diag << "parent operands"; - - diag << " to "; - if (succRegionNo) - diag << "Region #" << succRegionNo.value(); - else - diag << "parent results"; - return diag; - }; - - std::optional<TypeRange> sourceTypes = - getInputsTypesForRegion(succRegionNo); - if (!sourceTypes.has_value()) - continue; + FailureOr<TypeRange> sourceTypes = getInputsTypesForRegion(succRegionNo); + if (failed(sourceTypes)) + return failure(); TypeRange succInputsTypes = succ.getSuccessorInputs().getTypes(); if (sourceTypes->size() != succInputsTypes.size()) { InFlightDiagnostic diag = op->emitOpError(" region control flow edge "); - return printEdgeName(diag) << ": source has " << sourceTypes->size() - << " operands, but target successor needs " - << succInputsTypes.size(); + return printRegionEdgeName(diag, sourceNo, succRegionNo) + << ": source has " << sourceTypes->size() + << " operands, but target successor needs " + << succInputsTypes.size(); } for (const auto &typesIdx : @@ -138,7 +140,7 @@ Type inputType = std::get<1>(typesIdx.value()); if (!regionInterface.areTypesCompatible(sourceType, inputType)) { InFlightDiagnostic diag = op->emitOpError(" along control flow edge "); - return printEdgeName(diag) + return printRegionEdgeName(diag, sourceNo, succRegionNo) << ": source type #" << typesIdx.index() << " " << sourceType << " should match input type #" << typesIdx.index() << " " << inputType; @@ -177,45 +179,51 @@ for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) { Region &region = op->getRegion(regionNo); - // Since there can be
multiple `ReturnLike` terminators or others - // implementing the `RegionBranchTerminatorOpInterface`, all should have the - // same operand types when passing them to the same region. - - std::optional<OperandRange> regionReturnOperands; - for (Block &block : region) { - auto terminator = - dyn_cast<RegionBranchTerminatorOpInterface>(block.getTerminator()); - if (!terminator) - continue; - - OperandRange terminatorOperands = - terminator.getSuccessorOperands(regionNo); - if (!regionReturnOperands) { - regionReturnOperands = terminatorOperands; - continue; - } + // Since there can be multiple terminators implementing the + // `RegionBranchTerminatorOpInterface`, all should have the same operand + // types when passing them to the same region. - // Found more than one ReturnLike terminator. Make sure the operand types - // match with the first one. - if (!areTypesCompatible(regionReturnOperands->getTypes(), - terminatorOperands.getTypes())) - return op->emitOpError("Region #") - << regionNo - << " operands mismatch between return-like terminators"; - } + SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps; + for (Block &block : region) + if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>( + block.getTerminator())) + regionReturnOps.push_back(terminator); - auto inputTypesFromRegion = - [&](std::optional<unsigned> regionNo) -> std::optional<TypeRange> { - // If there is no return-like terminator, the op itself should verify - // type consistency. - if (!regionReturnOperands) - return std::nullopt; + // If there is no return-like terminator, the op itself should verify + // type consistency. + if (regionReturnOps.empty()) + continue; + + // A terminator may forward different operands to different successors, + // so the operand types are collected separately for each successor. + auto inputTypesForRegion = + [&](std::optional<unsigned> succRegionNo) -> FailureOr<TypeRange> { + std::optional<OperandRange> regionReturnOperands; + for (RegionBranchTerminatorOpInterface regionReturnOp : regionReturnOps) { + auto terminatorOperands = + regionReturnOp.getSuccessorOperands(succRegionNo); + + if (!regionReturnOperands) { + regionReturnOperands = terminatorOperands; + continue; + } + + // Found more than one ReturnLike terminator.
Make sure the operand + // types match with the first one. + if (!areTypesCompatible(regionReturnOperands->getTypes(), + terminatorOperands.getTypes())) { + InFlightDiagnostic diag = + op->emitOpError(" along control flow edge "); + return printRegionEdgeName(diag, regionNo, succRegionNo) + << " operands mismatch between return-like terminators"; + } + } - // All successors get the same set of operand types. + // All terminators forward compatible types along this edge. return TypeRange(regionReturnOperands->getTypes()); }; - if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion))) + if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesForRegion))) return failure(); } diff --git a/mlir/test/IR/test-region-branch-op-verifier.mlir b/mlir/test/IR/test-region-branch-op-verifier.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/test-region-branch-op-verifier.mlir @@ -0,0 +1,12 @@ +// RUN: mlir-opt %s + +// The terminator forwards an i32 back into the body and an f32 to the parent +// results; the verifier must type-check each successor edge separately. +func.func @test_ops_verify(%arg: i32) -> f32 { + %0 = "test.constant"() { value = 5.3 : f32 } : () -> f32 + %1 = test.loop_block %arg : (i32) -> f32 { + ^bb0(%arg1 : i32): + test.loop_block_term iter %arg exit %0 + } + return %1 : f32 +} diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -980,6 +980,37 @@ invocationBounds.emplace_back(1, 1); } +//===----------------------------------------------------------------------===// +// LoopBlockOp +//===----------------------------------------------------------------------===// + +void LoopBlockOp::getSuccessorRegions( + std::optional<unsigned> index, SmallVectorImpl<RegionSuccessor> &regions) { + regions.emplace_back(&getBody(), getBody().getArguments()); + if (!index) + return; + + regions.emplace_back((*this)->getResults()); +} + +OperandRange +LoopBlockOp::getEntrySuccessorOperands(std::optional<unsigned> index) { + assert(index == 0 && "loop_block has a single region"); + return getInitMutable(); +} + +//===----------------------------------------------------------------------===//
+// LoopBlockTerminatorOp +//===----------------------------------------------------------------------===// + +MutableOperandRange LoopBlockTerminatorOp::getMutableSuccessorOperands( + std::optional<unsigned> index) { + assert((!index || index == 0) && "loop_block has a single region"); + if (!index) + return getExitArgMutable(); + return getNextIterArgMutable(); +} + //===----------------------------------------------------------------------===// // SingleNoTerminatorCustomAsmOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2084,6 +2084,33 @@ let regions = (region AnyRegion:$region); } +// Single-block test loop: the terminator either feeds an i32 back into the +// body or exits the op with an f32 result. +def LoopBlockOp : TEST_Op<"loop_block", + [DeclareOpInterfaceMethods<RegionBranchOpInterface, + ["getEntrySuccessorOperands"]>, RecursiveMemoryEffects]> { + + let results = (outs F32:$floatResult); + let arguments = (ins I32:$init); + let regions = (region SizedRegion<1>:$body); + + let assemblyFormat = [{ + $init `:` functional-type($init, $floatResult) $body + attr-dict-with-keyword + }]; +} + +// Forwards `nextIterArg` to the loop body and `exitArg` to the parent results. +def LoopBlockTerminatorOp : TEST_Op<"loop_block_term", + [DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>, Pure, + Terminator]> { + let arguments = (ins I32:$nextIterArg, F32:$exitArg); + + let assemblyFormat = [{ + `iter` $nextIterArg `exit` $exitArg attr-dict + }]; +} + //===----------------------------------------------------------------------===// // Test TableGen generated build() methods //===----------------------------------------------------------------------===//