diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td
@@ -37,7 +37,10 @@
 }
 
 def ConditionOp : SCF_Op<"condition",
-      [HasParent<"WhileOp">, NoSideEffect, Terminator]> {
+      [HasParent<"WhileOp">,
+       DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>,
+       NoSideEffect,
+       Terminator]> {
   let summary = "loop continuation condition";
   let description = [{
     This operation accepts the continuation (i.e., inverse of exit) condition
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -174,6 +174,37 @@
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// RegionBranchTerminatorOpInterface
+//===----------------------------------------------------------------------===//
+
+def RegionBranchTerminatorOpInterface :
+  OpInterface<"RegionBranchTerminatorOpInterface"> {
+  let description = [{
+    This interface provides information for branching terminator operations
+    in the presence of a parent RegionBranchOpInterface implementation. It
+    specifies which operands are passed to which successor region.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Appends to `operands` the values this terminator forwards to the
+        successor region with number `index`, or to the parent operation if
+        `index` is None.
+      }],
+      "void", "getSuccessorOperands",
+      (ins "Optional<unsigned>":$index, "SmallVectorImpl<Value> &":$operands)
+    >
+  ];
+
+  let verify = [{
+    static_assert(ConcreteOp::template hasTrait<OpTrait::IsTerminator>(),
+                  "expected operation to be a terminator");
+    return success();
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
--- a/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
+++ b/mlir/lib/Analysis/BufferViewFlowAnalysis.cpp
@@ -101,11 +101,27 @@
       regionInterface.getSuccessorRegions(region.getRegionNumber(),
                                           successorRegions);
       for (RegionSuccessor &successorRegion : successorRegions) {
+        // Determine the number of the successor region; it stays None when
+        // the successor is the parent operation rather than one of its regions.
+        Optional<unsigned> regionIndex;
+        if (Region *successor = successorRegion.getSuccessor())
+          regionIndex = successor->getRegionNumber();
         // Iterate over all immediate terminator operations and wire the
         // successor inputs with the operands of each terminator.
         for (Block &block : region) {
           for (Operation &operation : block) {
-            if (operation.hasTrait<OpTrait::ReturnLike>())
+            // Try to query a RegionBranchTerminatorOpInterface to determine
+            // all successor operands that will be passed to the successor
+            // input arguments.
+            if (auto regionTerminatorInterface =
+                    dyn_cast<RegionBranchTerminatorOpInterface>(operation)) {
+              SmallVector<Value, 4> operands;
+              regionTerminatorInterface.getSuccessorOperands(regionIndex,
+                                                             operands);
+              registerDependencies(operands,
+                                   successorRegion.getSuccessorInputs());
+            } else if (operation.hasTrait<OpTrait::ReturnLike>()) {
               registerDependencies(operation.getOperands(),
                                    successorRegion.getSuccessorInputs());
+            }
           }
diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -163,5 +163,16 @@
 }
 
+//===----------------------------------------------------------------------===//
+// ConditionOp
+//===----------------------------------------------------------------------===//
+
+void ConditionOp::getSuccessorOperands(Optional<unsigned> index,
+                                       SmallVectorImpl<Value> &operands) {
+  // The forwarded values do not depend on `index`: every successor receives
+  // all operands except the condition itself.
+  operands.append(args().begin(), args().end());
+}
+
 //===----------------------------------------------------------------------===//
 // ForOp
 //===----------------------------------------------------------------------===//