diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -935,7 +935,7 @@
 
     Note that the types of region arguments need not to match with each other.
     The op expects the operand types to match with argument types of the
-    "before" region"; the result types to match with the trailing operand types
+    "before" region; the result types to match with the trailing operand types
     of the terminator of the "before" region, and with the argument types of the
     "after" region. The following scheme can be used to share the results of
     some operations executed in the "before" region with the "after" region,
@@ -983,7 +983,16 @@
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
 
+  let builders = [
+    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$operands,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$beforeBuilder,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>
+  ];
+
   let extraClassDeclaration = [{
+    using BodyBuilderFn =
+      function_ref<void(OpBuilder &, Location, ValueRange)>;
+
     OperandRange getSuccessorEntryOperands(Optional<unsigned> index);
     ConditionOp getConditionOp();
     YieldOp getYieldOp();
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -71,40 +71,31 @@
   SmallVector<Type> types = {elementTy, elementTy, elementTy};
-  SmallVector<Location> locations = {loc, loc, loc};
 
-  auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
-  Block *before =
-      rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
-  Block *after =
-      rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
-
-  // The conditional block of the while loop.
-  {
-    rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
-    Value input = before->getArgument(0);
-    Value zero = before->getArgument(2);
-
-    Value inputNotZero = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ne, input, zero);
-    rewriter.create<scf::ConditionOp>(loc, inputNotZero,
-                                      before->getArguments());
-  }
-
-  // The body of the while loop: shift right until reaching a value of 0.
-  {
-    rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
-    Value input = after->getArgument(0);
-    Value leadingZeros = after->getArgument(1);
-
-    auto one =
-        rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
-    auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
-    auto leadingZerosMinusOne =
-        rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one);
-
-    rewriter.create<scf::YieldOp>(
-        loc,
-        ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
-  }
+  auto whileOp = rewriter.create<scf::WhileOp>(
+      loc, types, operands,
+      [&](OpBuilder &beforeBuilder, Location beforeLoc, ValueRange args) {
+        // The conditional block of the while loop.
+        Value input = args[0];
+        Value zero = args[2];
+
+        Value inputNotZero = beforeBuilder.create<arith::CmpIOp>(
+            beforeLoc, arith::CmpIPredicate::ne, input, zero);
+        beforeBuilder.create<scf::ConditionOp>(beforeLoc, inputNotZero, args);
+      },
+      [&](OpBuilder &afterBuilder, Location afterLoc, ValueRange args) {
+        // The body of the while loop: shift right until reaching a value of 0.
+        Value input = args[0];
+        Value leadingZeros = args[1];
+
+        auto one = afterBuilder.create<arith::ConstantOp>(
+            afterLoc, IntegerAttr::get(elementTy, 1));
+        auto shifted =
+            afterBuilder.create<arith::ShRUIOp>(afterLoc, resultTy, input, one);
+        auto leadingZerosMinusOne = afterBuilder.create<arith::SubIOp>(
+            afterLoc, resultTy, leadingZeros, one);
+
+        afterBuilder.create<scf::YieldOp>(
+            afterLoc, ValueRange({shifted, leadingZerosMinusOne, args[2]}));
+      });
 
   rewriter.setInsertionPointAfter(whileOp);
   rewriter.replaceOp(op, whileOp->getResult(1));
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -2669,6 +2669,47 @@
 // WhileOp
 //===----------------------------------------------------------------------===//
 
+/// Builds an scf.while op with the given result types and loop-carried
+/// operands, creating one block in each region and populating it through the
+/// provided callbacks. The "before" block arguments mirror the operands (types
+/// and locations); the "after" block arguments mirror the results, since they
+/// receive the trailing operands of the "before" region's terminator.
+void WhileOp::build(::mlir::OpBuilder &odsBuilder,
+                    ::mlir::OperationState &odsState, TypeRange resultTypes,
+                    ValueRange operands, BodyBuilderFn beforeBuilder,
+                    BodyBuilderFn afterBuilder) {
+  assert(beforeBuilder && "the builder callback for 'before' must be present");
+  assert(afterBuilder && "the builder callback for 'after' must be present");
+
+  odsState.addOperands(operands);
+  odsState.addTypes(resultTypes);
+
+  OpBuilder::InsertionGuard guard(odsBuilder);
+
+  // The "before" block receives the loop-carried values, so its arguments
+  // take the operand types and locations (not the result types, which may
+  // differ in number and kind).
+  SmallVector<Location, 4> beforeArgLocs;
+  beforeArgLocs.reserve(operands.size());
+  for (Value operand : operands)
+    beforeArgLocs.push_back(operand.getLoc());
+
+  Region *beforeRegion = odsState.addRegion();
+  Block *beforeBlock = odsBuilder.createBlock(
+      beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs);
+  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
+
+  // The "after" block arguments match the op result types. There is no
+  // operand to borrow a location from, so use the op's own location for each
+  // argument.
+  SmallVector<Location, 4> afterArgLocs(resultTypes.size(), odsState.location);
+
+  Region *afterRegion = odsState.addRegion();
+  Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
+                                             resultTypes, afterArgLocs);
+  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
+}
+
 OperandRange WhileOp::getSuccessorEntryOperands(Optional<unsigned> index) {
   assert(index && *index == 0 &&
          "WhileOp is expected to branch only to the first region");