diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1473,19 +1473,19 @@
                  function_ref<void(OpBuilder &, Location)> thenBuilder,
                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
   assert(thenBuilder && "the builder callback for 'then' must be present");
-
   result.addOperands(cond);
   result.addTypes(resultTypes);
 
+  // Build then region.
   OpBuilder::InsertionGuard guard(builder);
   Region *thenRegion = result.addRegion();
   builder.createBlock(thenRegion);
   thenBuilder(builder, result.location);
 
+  // Build else region.
   Region *elseRegion = result.addRegion();
   if (!elseBuilder)
     return;
-
   builder.createBlock(elseRegion);
   elseBuilder(builder, result.location);
 }
@@ -1493,7 +1493,30 @@
 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                  function_ref<void(OpBuilder &, Location)> thenBuilder,
                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
-  build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
+  assert(thenBuilder && "the builder callback for 'then' must be present");
+  result.addOperands(cond);
+
+  // Build then region.
+  OpBuilder::InsertionGuard guard(builder);
+  Region *thenRegion = result.addRegion();
+  Block *thenBlock = builder.createBlock(thenRegion);
+  thenBuilder(builder, result.location);
+
+  // Infer the result types from the 'then' region's scf.yield, if any. The
+  // callback may leave the block empty or unterminated, so check for an empty
+  // block and use back() instead of Block::getTerminator(), which asserts
+  // that a terminator is present.
+  if (!thenBlock->empty()) {
+    if (auto yieldOp = llvm::dyn_cast<YieldOp>(&thenBlock->back()))
+      result.addTypes(yieldOp.getOperandTypes());
+  }
+
+  // Build else region.
+  Region *elseRegion = result.addRegion();
+  if (!elseBuilder)
+    return;
+  builder.createBlock(elseRegion);
+  elseBuilder(builder, result.location);
 }
 
 LogicalResult IfOp::verify() {