diff --git a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td --- a/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td +++ b/mlir/include/mlir/Dialect/Async/IR/AsyncOps.td @@ -29,7 +29,8 @@ def Async_ExecuteOp : Async_Op<"execute", [SingleBlockImplicitTerminator<"YieldOp">, DeclareOpInterfaceMethods, + ["getSuccessorEntryOperands", + "getNumRegionInvocations"]>, AttrSizedOperandSegments]> { let summary = "Asynchronous execute operation"; let description = [{ @@ -99,7 +100,9 @@ } def Async_YieldOp : - Async_Op<"yield", [HasParent<"ExecuteOp">, NoSideEffect, Terminator]> { + Async_Op<"yield", [ + HasParent<"ExecuteOp">, NoSideEffect, Terminator, + DeclareOpInterfaceMethods]> { let summary = "terminator for Async execute operation"; let description = [{ The `async.yield` is a special terminator operation for the block inside diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -48,6 +48,12 @@ return success(); } +MutableOperandRange +YieldOp::getMutableSuccessorOperands(Optional index) { + assert(!index.hasValue()); + return operandsMutable(); +} + //===----------------------------------------------------------------------===// /// ExecuteOp //===----------------------------------------------------------------------===// @@ -55,24 +61,28 @@ constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; void ExecuteOp::getNumRegionInvocations( - ArrayRef operands, SmallVectorImpl &countPerRegion) { - (void)operands; + ArrayRef, SmallVectorImpl &countPerRegion) { assert(countPerRegion.empty()); countPerRegion.push_back(1); } +OperandRange ExecuteOp::getSuccessorEntryOperands(unsigned index) { + assert(index == 0 && "invalid region index"); + return operands(); +} + void ExecuteOp::getSuccessorRegions(Optional index, - ArrayRef operands, + ArrayRef, SmallVectorImpl ®ions) { // The `body` region branch back to the parent operation. if (index.hasValue()) { - assert(*index == 0); - regions.push_back(RegionSuccessor(getResults())); + assert(*index == 0 && "invalid region index"); + regions.push_back(RegionSuccessor(results())); return; } // Otherwise the successor is the body region. - regions.push_back(RegionSuccessor(&body())); + regions.push_back(RegionSuccessor(&body(), body().getArguments())); } void ExecuteOp::build(OpBuilder &builder, OperationState &result,