diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -982,6 +982,15 @@
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
 
+  // Builds the op and populates both regions through the callbacks: the
+  // "before" block takes the init operands' types, the "after" block takes
+  // the result types.
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$beforeBuilder,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>
+  ];
+
   let extraClassDeclaration = [{
     OperandRange getSuccessorEntryOperands(Optional<unsigned> index);
     ConditionOp getConditionOp();
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2692,6 +2692,50 @@
 // WhileOp
 //===----------------------------------------------------------------------===//
 
+/// Builds an `scf.while` op from explicit result types and init operands.
+///
+/// The "before" block gets one argument per init operand, with the operand's
+/// type and location. The "after" block gets one argument per result type:
+/// it is entered with the values forwarded by `scf.condition`, which have the
+/// op's result types, not necessarily the init operands' types or count. No
+/// terminators are created here; each callback runs with the insertion point
+/// in its freshly created block and must emit that region's terminator.
+void WhileOp::build(
+    ::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState,
+    TypeRange resultTypes, ValueRange operands,
+    function_ref<void(OpBuilder &, Location, ValueRange)> beforeBuilder,
+    function_ref<void(OpBuilder &, Location, ValueRange)> afterBuilder) {
+  assert(beforeBuilder && "the builder callback for 'before' must be present");
+  assert(afterBuilder && "the builder callback for 'after' must be present");
+
+  odsState.addOperands(operands);
+  odsState.addTypes(resultTypes);
+
+  // Restore the caller's insertion point once both regions are built.
+  OpBuilder::InsertionGuard guard(odsBuilder);
+
+  // "before" block: mirrors the init operands (types and locations).
+  SmallVector<Location, 4> beforeArgLocs;
+  beforeArgLocs.reserve(operands.size());
+  for (Value operand : operands)
+    beforeArgLocs.push_back(operand.getLoc());
+
+  Region *beforeRegion = odsState.addRegion();
+  Block *beforeBlock = odsBuilder.createBlock(
+      beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs);
+  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
+
+  // "after" block: mirrors the results. Their count may differ from the
+  // number of init operands, so the operand locations cannot be reused
+  // (createBlock requires one location per argument); use the op location.
+  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
+
+  Region *afterRegion = odsState.addRegion();
+  Block *afterBlock = odsBuilder.createBlock(
+      afterRegion, /*insertPt=*/{}, resultTypes, afterArgLocs);
+  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
+}
+
 OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
   assert(index && *index == 0 &&
          "WhileOp is expected to branch only to the first region");