diff --git a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h --- a/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h @@ -397,13 +397,13 @@ void visitRegionSuccessors(RegionBranchOpInterface branch, ArrayRef operands); - /// Visit a terminator (an op implementing `RegionBranchTerminatorOpInterface` - /// or a return-like op) to compute the lattice values of its operands, given - /// its parent op `branch`. The lattice value of an operand is determined - /// based on the corresponding arguments in `terminator`'s region - /// successor(s). - void visitRegionSuccessorsFromTerminator(Operation *terminator, - RegionBranchOpInterface branch); + /// Visit a `RegionBranchTerminatorOpInterface` to compute the lattice values + /// of its operands, given its parent op `branch`. The lattice value of an + /// operand is determined based on the corresponding arguments in + /// `terminator`'s region successor(s). + void visitRegionSuccessorsFromTerminator( + RegionBranchTerminatorOpInterface terminator, + RegionBranchOpInterface branch); /// Get the lattice element for a value, and also set up /// dependencies so that the analysis on the given ProgramPoint is re-invoked diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h @@ -229,32 +229,6 @@ /// exists. Region *getEnclosingRepetitiveRegion(Value value); -//===----------------------------------------------------------------------===// -// RegionBranchTerminatorOpInterface -//===----------------------------------------------------------------------===// - -/// Returns true if the given operation is either annotated with the -/// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`. -bool isRegionReturnLike(Operation *operation); - -/// Returns the mutable operands that are passed to the region with the given -/// `regionIndex`. If the operation does not implement the -/// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the -/// result will be `std::nullopt`. In all other cases, the resulting -/// `OperandRange` represents all operands that are passed to the specified -/// successor region. If `regionIndex` is `std::nullopt`, all operands that are -/// passed to the parent operation will be returned. -std::optional -getMutableRegionBranchSuccessorOperands(Operation *operation, - std::optional regionIndex); - -/// Returns the read only operands that are passed to the region with the given -/// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more -/// information. -std::optional -getRegionBranchSuccessorOperands(Operation *operation, - std::optional regionIndex); - //===----------------------------------------------------------------------===// // ControlFlow Traits //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td --- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td +++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td @@ -272,6 +272,19 @@ //===----------------------------------------------------------------------===// // Op is "return-like". -def ReturnLike : NativeOpTrait<"ReturnLike">; +def ReturnLike : TraitList<[ + DeclareOpInterfaceMethods, + NativeOpTrait< + /*name=*/"ReturnLike", + /*traits=*/[], + /*extraOpDeclaration=*/"", + /*extraOpDefinition=*/[{ + ::mlir::MutableOperandRange $cppClass::getMutableSuccessorOperands( + ::std::optional index) { + return ::mlir::MutableOperandRange(*this); + } + }] + > +]>; #endif // MLIR_INTERFACES_CONTROLFLOWINTERFACES diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -91,15 +91,14 @@ for (int i = 0, e = op->getNumRegions(); i != e; ++i) { if (std::optional operandIndex = getOperandIndexIfPred(i)) { for (Block &block : op->getRegion(i)) { - Operation *term = block.getTerminator(); // Try to determine possible region-branch successor operands for the // current region. - auto successorOperands = - getRegionBranchSuccessorOperands(term, regionIndex); - if (successorOperands) { - collectUnderlyingAddressValues((*successorOperands)[*operandIndex], - maxDepth, visited, output); - } else if (term->getNumSuccessors()) { + if (auto term = dyn_cast( + block.getTerminator())) { + collectUnderlyingAddressValues( + term.getSuccessorOperands(regionIndex)[*operandIndex], maxDepth, + visited, output); + } else if (block.getNumSuccessors()) { // Otherwise, if this terminator may exit the region we can't make // any assumptions about which values get passed. output.push_back(inputValue); diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -337,9 +337,8 @@ // There may be a weird case where a terminator may be transferring control // either to the parent or to another block, so exit blocks and successors // are not mutually exclusive. - Operation *terminator = b->getTerminator(); - return terminator && (terminator->hasTrait() || - isa(terminator)); + return isa_and_nonnull( + b->getTerminator()); }; if (isExitBlock(block)) { // If this block is exiting from a callable, the successors of exiting from diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -93,11 +93,9 @@ // `BranchOpInterface`, `RegionBranchTerminatorOpInterface` or return-like op. Operation *op = operand.getOwner(); assert((isa(op) || isa(op) || - isa(op) || - op->hasTrait()) && + isa(op)) && "expected the op to be `RegionBranchOpInterface`, " - "`BranchOpInterface`, `RegionBranchTerminatorOpInterface`, or " - "return-like"); + "`BranchOpInterface` or `RegionBranchTerminatorOpInterface`"); // The lattices of the non-forwarded branch operands don't get updated like // the forwarded branch operands or the non-branch operands. Thus they need @@ -161,11 +159,10 @@ visitOperation(op, operandLiveness, resultsLiveness); // We also visit the parent op with the parent's results and this operand if - // `op` is a `RegionBranchTerminatorOpInterface` or return-like because its - // non-forwarded operand depends on not only its memory effects/results but - // also on those of its parent's. - if (!isa(op) && - !op->hasTrait()) + // `op` is a `RegionBranchTerminatorOpInterface` because its non-forwarded + // operand depends on not only its memory effects/results but also on those of + // its parent's. + if (!isa(op)) return; Operation *parentOp = op->getParentOp(); SmallVector parentResultsLiveness; diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -226,9 +226,9 @@ if (op == branch) { operands = branch.getSuccessorEntryOperands(successorIndex); // Otherwise, try to deduce the operands from a region return-like op. - } else { - if (isRegionReturnLike(op)) - operands = getRegionBranchSuccessorOperands(op, successorIndex); + } else if (auto regionTerminator = + dyn_cast(op)) { + operands = regionTerminator.getSuccessorOperands(successorIndex); } if (!operands) { @@ -439,10 +439,9 @@ // successor's input. There are two types of successor operands: the operands // of this op itself and the operands of the terminators of the regions of // this op. - if (isa(op) || - op->hasTrait()) { + if (auto terminator = dyn_cast(op)) { if (auto branch = dyn_cast(op->getParentOp())) { - visitRegionSuccessorsFromTerminator(op, branch); + visitRegionSuccessorsFromTerminator(terminator, branch); return; } } @@ -506,12 +505,11 @@ } void AbstractSparseBackwardDataFlowAnalysis:: - visitRegionSuccessorsFromTerminator(Operation *terminator, - RegionBranchOpInterface branch) { - assert(isa(terminator) || - terminator->hasTrait() && - "expected a `RegionBranchTerminatorOpInterface` op or a " - "return-like op"); + visitRegionSuccessorsFromTerminator( + RegionBranchTerminatorOpInterface terminator, + RegionBranchOpInterface branch) { + assert(isa(terminator) && + "expected a `RegionBranchTerminatorOpInterface` op"); assert(terminator->getParentOp() == branch.getOperation() && "expected `branch` to be the parent op of `terminator`"); @@ -527,10 +525,8 @@ for (const RegionSuccessor &successor : successors) { ValueRange inputs = successor.getSuccessorInputs(); Region *region = successor.getSuccessor(); - OperandRange operands = - region ? *getRegionBranchSuccessorOperands(terminator, - region->getRegionNumber()) - : *getRegionBranchSuccessorOperands(terminator, {}); + OperandRange operands = terminator.getSuccessorOperands( + region ? region->getRegionNumber() : std::optional{}); MutableArrayRef opOperands = operandsToOpOperands(operands); for (auto [opOperand, input] : llvm::zip(opOperands, inputs)) { meet(getLatticeElement(opOperand.get()), diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -690,7 +690,7 @@ return true; // Check if the op is returning/yielding. - if (isRegionReturnLike(op)) + if (isa(op)) return true; // Add all aliasing OpResults to the worklist. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocation.cpp @@ -70,14 +70,15 @@ using namespace mlir::bufferization; /// Walks over all immediate return-like terminators in the given region. -static LogicalResult -walkReturnOperations(Region *region, - llvm::function_ref func) { +static LogicalResult walkReturnOperations( + Region *region, + llvm::function_ref func) { for (Block &block : *region) { Operation *terminator = block.getTerminator(); // Skip non region-return-like terminators. - if (isRegionReturnLike(terminator)) { - if (failed(func(terminator))) + if (auto regionTerminator = + dyn_cast(terminator)) { + if (failed(func(regionTerminator))) return failure(); } } @@ -447,23 +448,25 @@ // Iterate over all immediate terminator operations to introduce // new buffer allocations. Thereby, the appropriate terminator operand // will be adjusted to point to the newly allocated buffer instead. - if (failed(walkReturnOperations(®ion, [&](Operation *terminator) { - // Get the actual mutable operands for this terminator op. - auto terminatorOperands = *getMutableRegionBranchSuccessorOperands( - terminator, region.getRegionNumber()); - // Extract the source value from the current terminator. - // This conversion needs to exist on a separate line due to a bug in - // GCC conversion analysis. - OperandRange immutableTerminatorOperands = terminatorOperands; - Value sourceValue = immutableTerminatorOperands[operandIndex]; - // Create a new clone at the current location of the terminator. - auto clone = introduceCloneBuffers(sourceValue, terminator); - if (failed(clone)) - return failure(); - // Wire clone and terminator operand. - terminatorOperands.slice(operandIndex, 1).assign(*clone); - return success(); - }))) + if (failed(walkReturnOperations( + ®ion, [&](RegionBranchTerminatorOpInterface terminator) { + // Get the actual mutable operands for this terminator op. + auto terminatorOperands = + terminator.getMutableSuccessorOperands( + region.getRegionNumber()); + // Extract the source value from the current terminator. + // This conversion needs to exist on a separate line due to a + // bug in GCC conversion analysis. + OperandRange immutableTerminatorOperands = terminatorOperands; + Value sourceValue = immutableTerminatorOperands[operandIndex]; + // Create a new clone at the current location of the terminator. + auto clone = introduceCloneBuffers(sourceValue, terminator); + if (failed(clone)) + return failure(); + // Wire clone and terminator operand. + terminatorOperands.slice(operandIndex, 1).assign(*clone); + return success(); + }))) return failure(); } return success(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferOptimizations.cpp @@ -75,7 +75,8 @@ // If there is at least one alias that leaves the parent region, we know // that this alias escapes the whole region and hence the associated // allocation leaves allocation scope. - if (isRegionReturnLike(use) && use->getParentRegion() == parentRegion) + if (isa(use) && + use->getParentRegion() == parentRegion) return true; } } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp @@ -128,14 +128,11 @@ regionIndex = regionSuccessor->getRegionNumber(); // Iterate over all immediate terminator operations and wire the // successor inputs with the successor operands of each terminator. - for (Block &block : region) { - auto successorOperands = getRegionBranchSuccessorOperands( - block.getTerminator(), regionIndex); - if (successorOperands) { - registerDependencies(*successorOperands, + for (Block &block : region) + if (auto terminator = dyn_cast( + block.getTerminator())) + registerDependencies(terminator.getSuccessorOperands(regionIndex), successorRegion.getSuccessorInputs()); - } - } } } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -183,7 +183,8 @@ // the IR. void OneShotAnalysisState::gatherYieldedTensors(Operation *op) { op->walk([&](Operation *returnOp) { - if (!isRegionReturnLike(returnOp) || !getOptions().isOpAllowed(returnOp)) + if (!isa(returnOp) || + !getOptions().isOpAllowed(returnOp)) return WalkResult::advance(); for (OpOperand &returnValOperand : returnOp->getOpOperands()) { @@ -1059,7 +1060,7 @@ LogicalResult status = success(); DominanceInfo domInfo(op); op->walk([&](Operation *returnOp) { - if (!isRegionReturnLike(returnOp) || + if (!isa(returnOp) || !state.getOptions().isOpAllowed(returnOp)) return WalkResult::advance(); diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -183,12 +183,13 @@ std::optional regionReturnOperands; for (Block &block : region) { - Operation *terminator = block.getTerminator(); - auto terminatorOperands = - getRegionBranchSuccessorOperands(terminator, regionNo); - if (!terminatorOperands) + auto terminator = + dyn_cast(block.getTerminator()); + if (!terminator) continue; + OperandRange terminatorOperands = + terminator.getSuccessorOperands(regionNo); if (!regionReturnOperands) { regionReturnOperands = terminatorOperands; continue; @@ -197,7 +198,7 @@ // Found more than one ReturnLike terminator. Make sure the operand types // match with the first one. if (!areTypesCompatible(regionReturnOperands->getTypes(), - terminatorOperands->getTypes())) + terminatorOperands.getTypes())) return op->emitOpError("Region #") << regionNo << " operands mismatch between return-like terminators"; @@ -316,7 +317,7 @@ // exiting terminator in the region. for (Block &block : getOperation()->getRegion(*index)) { Operation *terminator = block.getTerminator(); - if (getRegionBranchSuccessorOperands(terminator, *index)) { + if (isa(terminator)) { numInputs = terminator->getNumOperands(); break; } @@ -350,51 +351,3 @@ } return nullptr; } - -//===----------------------------------------------------------------------===// -// RegionBranchTerminatorOpInterface -//===----------------------------------------------------------------------===// - -/// Returns true if the given operation is either annotated with the -/// `ReturnLike` trait or implements the `RegionBranchTerminatorOpInterface`. -bool mlir::isRegionReturnLike(Operation *operation) { - return dyn_cast(operation) || - operation->hasTrait(); -} - -/// Returns the mutable operands that are passed to the region with the given -/// `regionIndex`. If the operation does not implement the -/// `RegionBranchTerminatorOpInterface` and is not marked as `ReturnLike`, the -/// result will be `std::nullopt`. In all other cases, the resulting -/// `OperandRange` represents all operands that are passed to the specified -/// successor region. If `regionIndex` is `std::nullopt`, all operands that are -/// passed to the parent operation will be returned. -std::optional -mlir::getMutableRegionBranchSuccessorOperands( - Operation *operation, std::optional regionIndex) { - // Try to query a RegionBranchTerminatorOpInterface to determine - // all successor operands that will be passed to the successor - // input arguments. - if (auto regionTerminatorInterface = - dyn_cast(operation)) - return regionTerminatorInterface.getMutableSuccessorOperands(regionIndex); - - // TODO: The ReturnLike trait should imply a default implementation of the - // RegionBranchTerminatorOpInterface. This would make this code significantly - // easier. Furthermore, this may even make this function obsolete. - if (operation->hasTrait()) - return MutableOperandRange(operation); - return std::nullopt; -} - -/// Returns the read only operands that are passed to the region with the given -/// `regionIndex`. See `getMutableRegionBranchSuccessorOperands` for more -/// information. -std::optional -mlir::getRegionBranchSuccessorOperands(Operation *operation, - std::optional regionIndex) { - auto range = getMutableRegionBranchSuccessorOperands(operation, regionIndex); - if (range) - return range->operator OperandRange(); - return std::nullopt; -}