diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -220,23 +220,29 @@
 
   let skipDefaultBuilders = 1;
   let builders = [
-    OpBuilder<"Builder *builder, OperationState &result, "
+    OpBuilder<"Builder *builder, OperationState &result, "
               "Value cond, bool withElseRegion">,
     OpBuilder<"Builder *builder, OperationState &result, "
-              "TypeRange resultTypes, Value cond, "
-              "bool withElseRegion">
+              "TypeRange resultTypes, Value cond, bool withElseRegion">
   ];
 
   let extraClassDeclaration = [{
+    // Returns a builder that inserts into the 'then' block: before its last
+    // operation when the op has no results, and at the very end of the
+    // block otherwise.
     OpBuilder getThenBodyBuilder() {
       assert(!thenRegion().empty() && "Unexpected empty 'then' region.");
       Block &body = thenRegion().front();
-      return OpBuilder(&body, std::prev(body.end()));
+      return OpBuilder(&body,
+                       results().empty() ? std::prev(body.end()) : body.end());
     }
+    // Returns a builder that inserts into the 'else' block, positioned the
+    // same way as the 'then' builder. The 'else' region must be non-empty.
     OpBuilder getElseBodyBuilder() {
       assert(!elseRegion().empty() && "Unexpected empty 'else' region.");
       Block &body = elseRegion().front();
-      return OpBuilder(&body, std::prev(body.end()));
+      return OpBuilder(&body,
+                       results().empty() ? std::prev(body.end()) : body.end());
     }
   }];
 }
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -201,18 +201,29 @@
 
 void IfOp::build(Builder *builder, OperationState &result, Value cond,
                  bool withElseRegion) {
-  build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion);
+  build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion);
 }
 
 void IfOp::build(Builder *builder, OperationState &result,
                  TypeRange resultTypes, Value cond, bool withElseRegion) {
   result.addOperands(cond);
   result.addTypes(resultTypes);
+
+  // The 'then' block always exists. A terminator is only created here when
+  // the op has no results; otherwise none is inserted by this builder.
   Region *thenRegion = result.addRegion();
-  Region *elseRegion = result.addRegion();
-  IfOp::ensureTerminator(*thenRegion, *builder, result.location);
-  if (withElseRegion)
-    IfOp::ensureTerminator(*elseRegion, *builder, result.location);
+  thenRegion->push_back(new Block());
+  if (resultTypes.empty())
+    IfOp::ensureTerminator(*thenRegion, *builder, result.location);
+
+  // Always register the 'else' region so the op keeps both of its regions;
+  // it merely stays empty when no else branch was requested.
+  Region *elseRegion = result.addRegion();
+  if (withElseRegion) {
+    elseRegion->push_back(new Block());
+    if (resultTypes.empty())
+      IfOp::ensureTerminator(*elseRegion, *builder, result.location);
+  }
 }
 
 static LogicalResult verify(IfOp op) {