diff --git a/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h b/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/SCF/EDSC/Builders.h @@ -82,6 +82,17 @@ Value lb, Value ub, Value step, ValueRange iterArgInitValues, function_ref fun = nullptr); +/// Adapters for building if conditions using the builder and the location +/// stored in ScopedContext. 'thenBody' is mandatory, 'elseBody' can be omitted +/// if the condition should not have an 'else' part. An 'if' that produces +/// results must have both bodies. +ValueRange +conditionBuilder(TypeRange results, Value condition, + function_ref thenBody, + function_ref elseBody = nullptr); +ValueRange conditionBuilder(Value condition, function_ref thenBody, + function_ref elseBody = nullptr); + } // namespace edsc } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SCF/SCF.h b/mlir/include/mlir/Dialect/SCF/SCF.h --- a/mlir/include/mlir/Dialect/SCF/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/SCF.h @@ -24,6 +24,8 @@ namespace mlir { namespace scf { +void buildTerminatedBody(OpBuilder &builder, Location loc); + #include "mlir/Dialect/SCF/SCFOpsDialect.h.inc" #define GET_OP_CLASSES diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -238,7 +238,18 @@ OpBuilder<"OpBuilder &builder, OperationState &result, " "Value cond, bool withElseRegion">, OpBuilder<"OpBuilder &builder, OperationState &result, " - "TypeRange resultTypes, Value cond, bool withElseRegion"> + "TypeRange resultTypes, Value cond, bool withElseRegion">, + OpBuilder< + "OpBuilder &builder, OperationState &result, TypeRange resultTypes, " + "Value cond, " + "function_ref thenBuilder " + " = buildTerminatedBody, " + "function_ref elseBuilder = nullptr">, + OpBuilder< + "OpBuilder &builder, OperationState &result, Value cond, " + "function_ref thenBuilder " + " = buildTerminatedBody, " + "function_ref 
elseBuilder = nullptr"> ]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -235,39 +235,38 @@ SmallVector resultType; if (options.unroll) resultType.push_back(vectorType); - auto ifOp = ScopedContext::getBuilderRef().create( - ScopedContext::getLocation(), resultType, inBoundsCondition, - /*withElseRegion=*/true); - - // 3.a. If in-bounds, progressively lower to a 1-D transfer read. - BlockBuilder(&ifOp.thenRegion().front(), Append())([&] { - Value vector = load1DVector(majorIvsPlusOffsets); - // 3.a.i. If `options.unroll` is true, insert the 1-D vector in the - // aggregate. We must yield and merge with the `else` branch. - if (options.unroll) { - vector = vector_insert(vector, result, majorIvs); - (loop_yield(vector)); - return; - } - // 3.a.ii. Otherwise, just go through the temporary `alloc`. - std_store(vector, alloc, majorIvs); - }); - - // 3.b. If not in-bounds, splat a 1-D vector. - BlockBuilder(&ifOp.elseRegion().front(), Append())([&] { - Value vector = std_splat(minorVectorType, xferOp.padding()); - // 3.a.i. If `options.unroll` is true, insert the 1-D vector in the - // aggregate. We must yield and merge with the `then` branch. - if (options.unroll) { - vector = vector_insert(vector, result, majorIvs); - (loop_yield(vector)); - return; - } - // 3.b.ii. Otherwise, just go through the temporary `alloc`. - std_store(vector, alloc, majorIvs); - }); + + // 3. If in-bounds, progressively lower to a 1-D transfer read, otherwise + // splat a 1-D vector. + ValueRange ifResults = conditionBuilder( + resultType, inBoundsCondition, + [&]() -> scf::ValueVector { + Value vector = load1DVector(majorIvsPlusOffsets); + // 3.a. If `options.unroll` is true, insert the 1-D vector in the + // aggregate. We must yield and merge with the `else` branch.
+ if (options.unroll) { + vector = vector_insert(vector, result, majorIvs); + return {vector}; + } + // 3.b. Otherwise, just go through the temporary `alloc`. + std_store(vector, alloc, majorIvs); + return {}; + }, + [&]() -> scf::ValueVector { + Value vector = std_splat(minorVectorType, xferOp.padding()); + // 3.c. If `options.unroll` is true, insert the 1-D vector in the + // aggregate. We must yield and merge with the `then` branch. + if (options.unroll) { + vector = vector_insert(vector, result, majorIvs); + return {vector}; + } + // 3.d. Otherwise, just go through the temporary `alloc`. + std_store(vector, alloc, majorIvs); + return {}; + }); + if (!resultType.empty()) - result = *ifOp.results().begin(); + result = *ifResults.begin(); } else { // 4. Guaranteed in-bounds, progressively lower to a 1-D transfer read. Value loaded1D = load1DVector(majorIvsPlusOffsets); @@ -335,11 +334,8 @@ if (inBoundsCondition) { // 2.a. If the condition is not null, we need an IfOp, to write // conditionally. Progressively lower to a 1-D transfer write. - auto ifOp = ScopedContext::getBuilderRef().create( - ScopedContext::getLocation(), TypeRange{}, inBoundsCondition, - /*withElseRegion=*/false); - BlockBuilder(&ifOp.thenRegion().front(), - Append())([&] { emitTransferWrite(majorIvsPlusOffsets); }); + conditionBuilder(inBoundsCondition, + [&] { emitTransferWrite(majorIvsPlusOffsets); }); } else { // 2.b. Guaranteed in-bounds. Progressively lower to a 1-D transfer write.
emitTransferWrite(majorIvsPlusOffsets); diff --git a/mlir/lib/Dialect/SCF/EDSC/Builders.cpp b/mlir/lib/Dialect/SCF/EDSC/Builders.cpp --- a/mlir/lib/Dialect/SCF/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/SCF/EDSC/Builders.cpp @@ -159,3 +159,64 @@ iterArgInitValues.end()); }); } + +/// Adapts an EDSC-style `body` into a region builder for scf::IfOp: `body` runs +/// under a ScopedContext set to the region's builder and location, and the +/// values it returns are yielded from the region. +static std::function +wrapIfBody(function_ref body, TypeRange expectedTypes) { + (void)expectedTypes; + return [=](OpBuilder &builder, Location loc) { + ScopedContext context(builder, loc); + scf::ValueVector returned = body(); + assert(ValueRange(returned).getTypes() == expectedTypes && + "'if' body builder returned values of unexpected type"); + builder.create(loc, returned); + }; +} + +ValueRange +mlir::edsc::conditionBuilder(TypeRange results, Value condition, + function_ref thenBody, + function_ref elseBody) { + assert(ScopedContext::getContext() && "EDSC ScopedContext not set up"); + assert(thenBody && "thenBody is mandatory"); + assert((elseBody || results.empty()) && + "elseBody is mandatory for an 'if' that produces results"); + + // Only create the 'else' region when a body is provided: wrapping a null + // 'elseBody' would produce a non-null region builder that calls an empty + // function_ref. + auto ifOp = ScopedContext::getBuilderRef().create( + ScopedContext::getLocation(), results, condition, + wrapIfBody(thenBody, results), + elseBody ? llvm::function_ref<void(OpBuilder &, Location)>( + wrapIfBody(elseBody, results)) + : llvm::function_ref<void(OpBuilder &, Location)>(nullptr)); + return ifOp.getResults(); +} + +/// Same as wrapIfBody, for bodies that return no values: the region is +/// terminated by a yield without operands. +static std::function +wrapZeroResultIfBody(function_ref body) { + return [=](OpBuilder &builder, Location loc) { + ScopedContext context(builder, loc); + body(); + builder.create(loc); + }; +} + +ValueRange 
mlir::edsc::conditionBuilder(Value condition, + function_ref thenBody, + function_ref elseBody) { + assert(ScopedContext::getContext() && "EDSC ScopedContext not set up"); + assert(thenBody && "thenBody is mandatory"); + + ScopedContext::getBuilderRef().create( + ScopedContext::getLocation(), condition, wrapZeroResultIfBody(thenBody), + elseBody ? llvm::function_ref( + wrapZeroResultIfBody(elseBody)) + : llvm::function_ref(nullptr)); + return {}; +} diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -35,6 +35,11 @@ >(); } +/// Default callback for IfOp builders.
Inserts a yield without arguments. +void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) { + builder.create(loc); +} + //===----------------------------------------------------------------------===// // ForOp //===----------------------------------------------------------------------===// @@ -338,20 +343,43 @@ void IfOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, Value cond, bool withElseRegion) { + auto addTerminator = [&](OpBuilder &nested, Location loc) { + if (resultTypes.empty()) + IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested, + loc); + }; + + build(builder, result, resultTypes, cond, addTerminator, + withElseRegion ? addTerminator + : function_ref()); +} + +void IfOp::build(OpBuilder &builder, OperationState &result, + TypeRange resultTypes, Value cond, + function_ref thenBuilder, + function_ref elseBuilder) { + assert(thenBuilder && "the builder callback for 'then' must be present"); + result.addOperands(cond); result.addTypes(resultTypes); + OpBuilder::InsertionGuard guard(builder); Region *thenRegion = result.addRegion(); - thenRegion->push_back(new Block()); - if (resultTypes.empty()) - IfOp::ensureTerminator(*thenRegion, builder, result.location); + builder.createBlock(thenRegion); + thenBuilder(builder, result.location); Region *elseRegion = result.addRegion(); - if (withElseRegion) { - elseRegion->push_back(new Block()); - if (resultTypes.empty()) - IfOp::ensureTerminator(*elseRegion, builder, result.location); - } + if (!elseBuilder) + return; + + builder.createBlock(elseRegion); + elseBuilder(builder, result.location); +} + +void IfOp::build(OpBuilder &builder, OperationState &result, Value cond, + function_ref thenBuilder, + function_ref elseBuilder) { + build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder); } static LogicalResult verify(IfOp op) {