diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -36,8 +36,12 @@
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-def ConditionOp : SCF_Op<"condition",
-    [HasParent<"WhileOp">, NoSideEffect, Terminator]> {
+def ConditionOp : SCF_Op<"condition", [
+  HasParent<"WhileOp">,
+  DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
+  NoSideEffect,
+  Terminator
+]> {
   let summary = "loop continuation condition";
   let description = [{
     This operation accepts the continuation (i.e., inverse of exit) condition
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -86,6 +86,32 @@
   ValueRange inputs;
 };
 
+//===----------------------------------------------------------------------===//
+// RegionBranchTerminatorOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the given operation is either annotated with the
+/// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
+bool isRegionReturnLike(Operation *operation);
+
+/// Returns the mutable operands that are passed to the region with the given
+/// `regionIndex`. If the operation does not implement the
+/// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
+/// result will be `llvm::None`. In all other cases, the resulting
+/// `MutableOperandRange` represents all operands passed to the specified
+/// successor region. If `regionIndex` is `llvm::None`, all operands that are
+/// passed to the parent operation will be returned.
+Optional<MutableOperandRange>
+getMutableRegionBranchSuccessorOperands(Operation *operation,
+                                        Optional<unsigned> regionIndex);
+
+/// Returns the read only operands that are passed to the region with the given
+/// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
+/// information.
+Optional<OperandRange>
+getRegionBranchSuccessorOperands(Operation *operation,
+                                 Optional<unsigned> regionIndex);
+
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -174,6 +174,54 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// RegionBranchTerminatorOpInterface
+//===----------------------------------------------------------------------===//
+
+def RegionBranchTerminatorOpInterface :
+  OpInterface<"RegionBranchTerminatorOpInterface"> {
+  let description = [{
+    This interface provides information for branching terminator operations
+    in the presence of a parent RegionBranchOpInterface implementation. It
+    specifies which operands are passed to which successor region.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Returns a mutable range of operands that are semantically "returned" by
+        passing them to the region successor given by `index`. If `index` is
+        None, this function returns the operands that are passed as a result to
+        the parent operation.
+      }],
+      "MutableOperandRange", "getMutableSuccessorOperands",
+      (ins "Optional<unsigned>":$index)
+    >,
+    InterfaceMethod<[{
+        Returns a range of operands that are semantically "returned" by passing
+        them to the region successor given by `index`. If `index` is None, this
+        function returns the operands that are passed as a result to the parent
+        operation.
+      }],
+      "OperandRange", "getSuccessorOperands",
+      (ins "Optional<unsigned>":$index), [{}], [{
+        ConcreteOp *op = static_cast<ConcreteOp *>(this);
+        return op->getMutableSuccessorOperands(index);
+      }]
+    >
+  ];
+
+  let verify = [{
+    static_assert(ConcreteOp::template hasTrait<OpTrait::IsTerminator>(),
+                  "expected operation to be a terminator");
+    static_assert(ConcreteOp::template hasTrait<OpTrait::ZeroResult>(),
+                  "expected operation to have zero results");
+    static_assert(ConcreteOp::template hasTrait<OpTrait::ZeroSuccessor>(),
+                  "expected operation to have zero successors");
+    return success();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
--- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
+++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp
@@ -74,12 +74,14 @@
   };
 
   // Check branches from the parent operation.
+  Optional<unsigned> regionIndex;
   if (region) {
+    // Determine the actual region number from the passed region.
+    regionIndex = region->getRegionNumber();
     if (Optional<unsigned> operandIndex =
             getOperandIndexIfPred(/*predIndex=*/llvm::None)) {
       collectUnderlyingAddressValues(
-          branch.getSuccessorEntryOperands(
-              region->getRegionNumber())[*operandIndex],
+          branch.getSuccessorEntryOperands(*regionIndex)[*operandIndex],
           maxDepth, visited, output);
     }
   }
@@ -89,9 +91,14 @@
     if (Optional<unsigned> operandIndex = getOperandIndexIfPred(i)) {
       for (Block &block : op->getRegion(i)) {
         Operation *term = block.getTerminator();
-        if (term->hasTrait<OpTrait::ReturnLike>()) {
-          collectUnderlyingAddressValues(term->getOperand(*operandIndex),
-                                         maxDepth, visited, output);
+        // Try to determine possible region-branch successor operands for the
+        // current region.
+        auto successorOperands =
+            getRegionBranchSuccessorOperands(term, regionIndex);
+        if (successorOperands.hasValue()) {
+          collectUnderlyingAddressValues(
+              successorOperands.getValue()[*operandIndex], maxDepth, visited,
+              output);
         } else if (term->getNumSuccessors()) {
           // Otherwise, if this terminator may exit the region we can't make
           // any assumptions about which values get passed.
diff --git a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
--- a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
@@ -101,12 +101,20 @@
       regionInterface.getSuccessorRegions(region.getRegionNumber(),
                                           successorRegions);
       for (RegionSuccessor &successorRegion : successorRegions) {
+        // Determine the current region index (if any).
+        Optional<unsigned> regionIndex;
+        if (successorRegion.getSuccessor())
+          regionIndex = successorRegion.getSuccessor()->getRegionNumber();
         // Iterate over all immediate terminator operations and wire the
         // successor inputs with the operands of each terminator.
         for (Block &block : region) {
           for (Operation &operation : block) {
-            if (operation.hasTrait<OpTrait::ReturnLike>())
-              registerDependencies(operation.getOperands(),
+            // Try to get all region branch successor operands and wire them
+            // with the successor inputs.
+            auto successorOperands =
+                getRegionBranchSuccessorOperands(&operation, regionIndex);
+            if (successorOperands.hasValue())
+              registerDependencies(successorOperands.getValue(),
                                    successorRegion.getSuccessorInputs());
           }
         }
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -162,6 +162,16 @@
   results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
 }
 
+//===----------------------------------------------------------------------===//
+// ConditionOp
+//===----------------------------------------------------------------------===//
+
+MutableOperandRange
+ConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
+  // Pass all operands except the condition to the successor region.
+  return argsMutable();
+}
+
 //===----------------------------------------------------------------------===//
 // ForOp
 //===----------------------------------------------------------------------===//
@@ -2100,18 +2110,6 @@
   if (!beforeTerminator)
     return failure();
 
-  TypeRange trailingTerminatorOperands = beforeTerminator.args().getTypes();
-  if (failed(verifyTypeRangesMatch(op, trailingTerminatorOperands,
-                                   op.after().getArgumentTypes(),
-                                   "trailing operands of the 'before' block "
-                                   "terminator and 'after' region arguments")))
-    return failure();
-
-  if (failed(verifyTypeRangesMatch(
-          op, trailingTerminatorOperands, op.getResultTypes(),
-          "trailing operands of the 'before' block terminator and op results")))
-    return failure();
-
   auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
       op, op.after(),
       "expects the 'after' region to terminate with 'scf.yield'");
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -176,25 +176,27 @@
   for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
     Region &region = op->getRegion(regionNo);
 
-    // Since the interface cannot distinguish between different ReturnLike
-    // ops within the region branching to different successors, all ReturnLike
-    // ops in this region should have the same operand types. We will then use
-    // one of them as the representative for type matching.
+    // Since there can be multiple `ReturnLike` terminators or others
+    // implementing the `RegionBranchTerminatorOpInterface`, all should have the
+    // same operand types when passing them to the same region.
 
-    Operation *regionReturn = nullptr;
+    Optional<OperandRange> regionReturnOperands;
     for (Block &block : region) {
       Operation *terminator = block.getTerminator();
-      if (!terminator->hasTrait<OpTrait::ReturnLike>())
+      auto terminatorOperands =
+          getRegionBranchSuccessorOperands(terminator, regionNo);
+      if (!terminatorOperands)
         continue;
 
-      if (!regionReturn) {
-        regionReturn = terminator;
+      if (!regionReturnOperands) {
+        regionReturnOperands = terminatorOperands;
         continue;
       }
 
       // Found more than one ReturnLike terminator. Make sure the operand types
       // match with the first one.
-      if (regionReturn->getOperandTypes() != terminator->getOperandTypes())
+      if (regionReturnOperands.getValue().getTypes() !=
+          terminatorOperands.getValue().getTypes())
         return op->emitOpError("Region #")
                << regionNo
                << " operands mismatch between return-like terminators";
@@ -204,11 +206,11 @@
         [&](Optional<unsigned> regionNo) -> Optional<TypeRange> {
       // If there is no return-like terminator, the op itself should verify
       // type consistency.
-      if (!regionReturn)
+      if (!regionReturnOperands)
         return llvm::None;
 
-      // All successors get the same set of operands.
-      return TypeRange(regionReturn->getOperands().getTypes());
+      // All successors get the same set of operand types.
+      return TypeRange(regionReturnOperands.getValue().getTypes());
     };
 
     if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
@@ -217,3 +219,46 @@
 
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// RegionBranchTerminatorOpInterface
+//===----------------------------------------------------------------------===//
+
+/// Returns true if the given operation is either annotated with the
+/// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`.
+bool mlir::isRegionReturnLike(Operation *operation) {
+  return operation && (isa<RegionBranchTerminatorOpInterface>(operation) ||
+                       operation->hasTrait<OpTrait::ReturnLike>());
+}
+
+/// Returns the mutable operands that are passed to the region with the given
+/// `regionIndex`. If the operation does not implement the
+/// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the
+/// result will be `llvm::None`. In all other cases, the resulting
+/// `MutableOperandRange` represents all operands passed to the specified
+/// successor region. If `regionIndex` is `llvm::None`, all operands that are
+/// passed to the parent operation will be returned.
+Optional<MutableOperandRange>
+mlir::getMutableRegionBranchSuccessorOperands(Operation *operation,
+                                              Optional<unsigned> regionIndex) {
+  // Try to query a RegionBranchTerminatorOpInterface to determine
+  // all successor operands that will be passed to the successor
+  // input arguments.
+  if (auto regionTerminatorInterface =
+          dyn_cast<RegionBranchTerminatorOpInterface>(operation))
+    return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex);
+  if (operation->hasTrait<OpTrait::ReturnLike>())
+    return MutableOperandRange(operation);
+
+  return llvm::None;
+}
+
+/// Returns the read only operands that are passed to the region with the given
+/// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more
+/// information.
+Optional<OperandRange>
+mlir::getRegionBranchSuccessorOperands(Operation *operation,
+                                       Optional<unsigned> regionIndex) {
+  auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex);
+  return range ? Optional<OperandRange>(*range) : llvm::None;
+}
diff --git a/mlir/lib/Transforms/BufferDeallocation.cpp b/mlir/lib/Transforms/BufferDeallocation.cpp
--- a/mlir/lib/Transforms/BufferDeallocation.cpp
+++ b/mlir/lib/Transforms/BufferDeallocation.cpp
@@ -68,8 +68,8 @@
 static void walkReturnOperations(Region *region, const FuncT &func) {
   for (Block &block : *region)
     for (Operation &operation : block) {
-      // Skip non-return-like terminators.
-      if (operation.hasTrait<OpTrait::ReturnLike>())
+      // Skip non region-return-like terminators.
+      if (isRegionReturnLike(&operation))
         func(&operation);
     }
 }
@@ -390,12 +390,15 @@
       // new buffer allocations. Thereby, the appropriate terminator operand
       // will be adjusted to point to the newly allocated buffer instead.
       walkReturnOperations(&region, [&](Operation *terminator) {
+        // Always set: walkReturnOperations only visits region-return-like ops.
+        auto terminatorOperands = *getMutableRegionBranchSuccessorOperands(
+            terminator, region.getRegionNumber());
         // Extract the source value from the current terminator.
-        Value sourceValue = terminator->getOperand(operandIndex);
+        Value sourceValue = OperandRange(terminatorOperands)[operandIndex];
         // Create a new clone at the current location of the terminator.
         Value clone = introduceCloneBuffers(sourceValue, terminator);
         // Wire clone and terminator operand.
-        terminator->setOperand(operandIndex, clone);
+        terminatorOperands.slice(operandIndex, 1).assign(clone);
       });
     }
   }
diff --git a/mlir/lib/Transforms/BufferOptimizations.cpp b/mlir/lib/Transforms/BufferOptimizations.cpp
--- a/mlir/lib/Transforms/BufferOptimizations.cpp
+++ b/mlir/lib/Transforms/BufferOptimizations.cpp
@@ -64,8 +64,7 @@
       // If there is at least one alias that leaves the parent region, we know
       // that this alias escapes the whole region and hence the associated
       // allocation leaves allocation scope.
-      if (use->hasTrait<OpTrait::ReturnLike>() &&
-          use->getParentRegion() == parentRegion)
+      if (isRegionReturnLike(use) && use->getParentRegion() == parentRegion)
         return true;
     }
   }
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -462,7 +462,7 @@
 
 func @while_cross_region_type_mismatch() {
   %true = constant true
-  // expected-error@+1 {{expects the same number of trailing operands of the 'before' block terminator and 'after' region arguments}}
+  // expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to Region #1: source has 0 operands, but target successor needs 1}}
   scf.while : () -> () {
     scf.condition(%true)
   } do {
@@ -475,8 +475,7 @@
 
 func @while_cross_region_type_mismatch() {
   %true = constant true
-  // expected-error@+2 {{expects the same types for trailing operands of the 'before' block terminator and 'after' region arguments}}
-  // expected-note@+1 {{for argument 0, found 'i1' and 'i32}}
+  // expected-error@+1 {{'scf.while' op along control flow edge from Region #0 to Region #1: source type #0 'i1' should match input type #0 'i32'}}
   scf.while : () -> () {
     scf.condition(%true) %true : i1
   } do {
@@ -489,7 +488,7 @@
 
 func @while_result_type_mismatch() {
   %true = constant true
-  // expected-error@+1 {{expects the same number of trailing operands of the 'before' block terminator and op results}}
+  // expected-error@+1 {{'scf.while' op region control flow edge from Region #0 to parent results: source has 1 operands, but target successor needs 0}}
   scf.while : () -> () {
     scf.condition(%true) %true : i1
   } do {