diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -670,6 +670,8 @@
     OpBuilder<(ins "Value":$cond, "bool":$withElseRegion)>,
     OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond,
       "bool":$withElseRegion)>,
+    // TODO: Remove builder when it is no longer used to create invalid `if` ops
+    // (with a type mismatch between the op and its inner `yield` op).
     OpBuilder<(ins "TypeRange":$resultTypes, "Value":$cond,
       CArg<"function_ref<void(OpBuilder &, Location)>",
            "buildTerminatedBody">:$thenBuilder,
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1490,19 +1490,19 @@
                  function_ref<void(OpBuilder &, Location)> thenBuilder,
                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
   assert(thenBuilder && "the builder callback for 'then' must be present");
-
   result.addOperands(cond);
   result.addTypes(resultTypes);
 
+  // Build then region.
   OpBuilder::InsertionGuard guard(builder);
   Region *thenRegion = result.addRegion();
   builder.createBlock(thenRegion);
   thenBuilder(builder, result.location);
 
+  // Build else region.
   Region *elseRegion = result.addRegion();
   if (!elseBuilder)
     return;
-
   builder.createBlock(elseRegion);
   elseBuilder(builder, result.location);
 }
@@ -1510,7 +1510,29 @@
 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
                  function_ref<void(OpBuilder &, Location)> thenBuilder,
                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
-  build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
+  assert(thenBuilder && "the builder callback for 'then' must be present");
+  result.addOperands(cond);
+
+  // Build then region.
+  OpBuilder::InsertionGuard guard(builder);
+  Region *thenRegion = result.addRegion();
+  Block *thenBlock = builder.createBlock(thenRegion);
+  thenBuilder(builder, result.location);
+
+  // Infer the result types from the `scf.yield` created by `thenBuilder`, if
+  // any. Inspect the block directly instead of calling `getTerminator()`,
+  // which asserts when the block is empty or not (yet) terminated.
+  if (!thenBlock->empty()) {
+    if (auto yieldOp = llvm::dyn_cast<YieldOp>(thenBlock->back()))
+      result.addTypes(yieldOp.getOperandTypes());
+  }
+
+  // Build else region.
+  Region *elseRegion = result.addRegion();
+  if (!elseBuilder)
+    return;
+  builder.createBlock(elseRegion);
+  elseBuilder(builder, result.location);
 }
 
 LogicalResult IfOp::verify() {