diff --git a/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h b/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h @@ -36,12 +36,15 @@ /// Adapters for building if conditions using the builder and the location /// stored in ScopedContext. 'thenBody' is mandatory, 'elseBody' can be omitted /// if the condition should not have an 'else' part. -ValueRange -conditionBuilder(TypeRange results, Value condition, - function_ref thenBody, - function_ref elseBody = nullptr); +/// When `ifOp` is specified, the scf::IfOp is captured. This is particularly +/// convenient for 0-result conditions. +ValueRange conditionBuilder(TypeRange results, Value condition, + function_ref thenBody, + function_ref elseBody = nullptr, + scf::IfOp *ifOp = nullptr); ValueRange conditionBuilder(Value condition, function_ref thenBody, - function_ref elseBody = nullptr); + function_ref elseBody = nullptr, + scf::IfOp *ifOp = nullptr); } // namespace edsc } // namespace mlir diff --git a/mlir/lib/Dialect/SCF/EDSC/Builders.cpp b/mlir/lib/Dialect/SCF/EDSC/Builders.cpp --- a/mlir/lib/Dialect/SCF/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/SCF/EDSC/Builders.cpp @@ -76,14 +76,17 @@ ValueRange mlir::edsc::conditionBuilder(TypeRange results, Value condition, function_ref thenBody, - function_ref elseBody) { + function_ref elseBody, + scf::IfOp *ifOp) { assert(ScopedContext::getContext() && "EDSC ScopedContext not set up"); assert(thenBody && "thenBody is mandatory"); - auto ifOp = ScopedContext::getBuilderRef().create( + auto newOp = ScopedContext::getBuilderRef().create( ScopedContext::getLocation(), results, condition, wrapIfBody(thenBody, results), wrapIfBody(elseBody, results)); - return ifOp.getResults(); + if (ifOp) + *ifOp = newOp; + return newOp.getResults(); } static std::function @@ -97,14 +100,17 @@ ValueRange mlir::edsc::conditionBuilder(Value condition, function_ref thenBody, - function_ref elseBody) { + function_ref elseBody, + scf::IfOp *ifOp) { assert(ScopedContext::getContext() && "EDSC ScopedContext not set up"); assert(thenBody && "thenBody is mandatory"); - ScopedContext::getBuilderRef().create( + auto newOp = ScopedContext::getBuilderRef().create( ScopedContext::getLocation(), condition, wrapZeroResultIfBody(thenBody), elseBody ? llvm::function_ref( wrapZeroResultIfBody(elseBody)) : llvm::function_ref(nullptr)); + if (ifOp) + *ifOp = newOp; return {}; }