diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -3066,9 +3066,6 @@
                     ::mlir::OperationState &odsState, TypeRange resultTypes,
                     ValueRange operands, BodyBuilderFn beforeBuilder,
                     BodyBuilderFn afterBuilder) {
-  assert(beforeBuilder && "the builder callback for 'before' must be present");
-  assert(afterBuilder && "the builder callback for 'after' must be present");
-
   odsState.addOperands(operands);
   odsState.addTypes(resultTypes);
 
@@ -3084,7 +3081,8 @@
   Region *beforeRegion = odsState.addRegion();
   Block *beforeBlock = odsBuilder.createBlock(
       beforeRegion, /*insertPt=*/{}, operands.getTypes(), beforeArgLocs);
-  beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
+  if (beforeBuilder)
+    beforeBuilder(odsBuilder, odsState.location, beforeBlock->getArguments());
 
   // Build after region.
   SmallVector afterArgLocs(resultTypes.size(), odsState.location);
@@ -3092,7 +3090,9 @@
   Region *afterRegion = odsState.addRegion();
   Block *afterBlock = odsBuilder.createBlock(afterRegion, /*insertPt=*/{},
                                              resultTypes, afterArgLocs);
-  afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
+
+  if (afterBuilder)
+    afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
 }
 
 OperandRange WhileOp::getSuccessorEntryOperands(std::optional index) {
@@ -3811,13 +3811,11 @@
       return rewriter.notifyMatchFailure(op, "No results to remove");
 
     ValueRange argsRange(newArgs);
-    auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {
-      // Nothing
-    };
 
     Location loc = op.getLoc();
     auto newWhileOp = rewriter.create(
-        loc, argsRange.getTypes(), op.getInits(), emptyBuilder, emptyBuilder);
+        loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr,
+        /*afterBody*/ nullptr);
     Block &newBeforeBlock = newWhileOp.getBefore().front();
     Block &newAfterBlock = newWhileOp.getAfter().front();
 
@@ -3878,13 +3876,10 @@
 
     beforeBlock.eraseArguments(argsToRemove);
 
-    auto emptyBuilder = [](OpBuilder &, Location, ValueRange) {
-      // Nothing
-    };
-
     Location loc = op.getLoc();
-    auto newWhileOp = rewriter.create(
-        loc, op->getResultTypes(), newInits, emptyBuilder, emptyBuilder);
+    auto newWhileOp =
+        rewriter.create(loc, op->getResultTypes(), newInits,
+                        /*beforeBody*/ nullptr, /*afterBody*/ nullptr);
     Block &newBeforeBlock = newWhileOp.getBefore().front();
     Block &newAfterBlock = newWhileOp.getAfter().front();
 