diff --git a/mlir/include/mlir/Dialect/SCF/SCFOps.td b/mlir/include/mlir/Dialect/SCF/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/SCFOps.td @@ -328,7 +328,15 @@ let builders = [ OpBuilder<"OpBuilder &builder, OperationState &result, " "ValueRange lowerBounds, ValueRange upperBounds, " - "ValueRange steps, ValueRange initVals = {}">, + "ValueRange steps, ValueRange initVals, " + "function_ref" + " bodyBuilderFn = nullptr">, + OpBuilder<"OpBuilder &builder, OperationState &result, " + "ValueRange lowerBounds, ValueRange upperBounds, " + "ValueRange steps, " + "function_ref" + " bodyBuilderFn = nullptr">, ]; let extraClassDeclaration = [{ @@ -380,7 +388,9 @@ let skipDefaultBuilders = 1; let builders = [ OpBuilder<"OpBuilder &builder, OperationState &result, " - "Value operand"> + "Value operand, " + "function_ref" + " bodyBuilderFn = nullptr"> ]; let arguments = (ins AnyType:$operand); diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp --- a/mlir/lib/Dialect/SCF/SCF.cpp +++ b/mlir/lib/Dialect/SCF/SCF.cpp @@ -495,25 +495,56 @@ // ParallelOp //===----------------------------------------------------------------------===// -void ParallelOp::build(OpBuilder &builder, OperationState &result, - ValueRange lbs, ValueRange ubs, ValueRange steps, - ValueRange initVals) { - result.addOperands(lbs); - result.addOperands(ubs); +void ParallelOp::build( + OpBuilder &builder, OperationState &result, ValueRange lowerBounds, + ValueRange upperBounds, ValueRange steps, ValueRange initVals, + function_ref + bodyBuilderFn) { + result.addOperands(lowerBounds); + result.addOperands(upperBounds); result.addOperands(steps); result.addOperands(initVals); result.addAttribute( ParallelOp::getOperandSegmentSizeAttr(), - builder.getI32VectorAttr({static_cast(lbs.size()), - static_cast(ubs.size()), + builder.getI32VectorAttr({static_cast(lowerBounds.size()), + static_cast(upperBounds.size()), static_cast(steps.size()), static_cast(initVals.size())})); + result.addTypes(initVals.getTypes()); + + OpBuilder::InsertionGuard guard(builder); + unsigned numIVs = steps.size(); + SmallVector argTypes(numIVs, builder.getIndexType()); Region *bodyRegion = result.addRegion(); + Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes); + + if (bodyBuilderFn) { + builder.setInsertionPointToStart(bodyBlock); + bodyBuilderFn(builder, result.location, + bodyBlock->getArguments().take_front(numIVs), + bodyBlock->getArguments().drop_front(numIVs)); + } ParallelOp::ensureTerminator(*bodyRegion, builder, result.location); - for (size_t i = 0, e = steps.size(); i < e; ++i) - bodyRegion->front().addArgument(builder.getIndexType()); - for (Value init : initVals) - result.addTypes(init.getType()); +} + +void ParallelOp::build( + OpBuilder &builder, OperationState &result, ValueRange lowerBounds, + ValueRange upperBounds, ValueRange steps, + function_ref bodyBuilderFn) { + // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure + // we don't capture a reference to a temporary by constructing the lambda at + // function level. + auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder, + Location nestedLoc, ValueRange ivs, + ValueRange) { + bodyBuilderFn(nestedBuilder, nestedLoc, ivs); + }; + function_ref wrapper; + if (bodyBuilderFn) + wrapper = wrappedBuilderFn; + + build(builder, result, lowerBounds, upperBounds, steps, ValueRange(), + wrapper); } static LogicalResult verify(ParallelOp op) { @@ -679,15 +710,18 @@ // ReduceOp //===----------------------------------------------------------------------===// -void ReduceOp::build(OpBuilder &builder, OperationState &result, - Value operand) { +void ReduceOp::build( + OpBuilder &builder, OperationState &result, Value operand, + function_ref bodyBuilderFn) { auto type = operand.getType(); result.addOperands(operand); - Region *bodyRegion = result.addRegion(); - Block *b = new Block(); - b->addArguments(ArrayRef{type, type}); - bodyRegion->getBlocks().insert(bodyRegion->end(), b); + OpBuilder::InsertionGuard guard(builder); + Region *bodyRegion = result.addRegion(); + Block *body = builder.createBlock(bodyRegion, {}, ArrayRef{type, type}); + if (bodyBuilderFn) + bodyBuilderFn(builder, result.location, body->getArgument(0), + body->getArgument(1)); } static LogicalResult verify(ReduceOp op) { diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp --- a/mlir/lib/Transforms/Utils/LoopUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp @@ -1472,33 +1472,34 @@ // value. The remainders then determine based on that range, which iteration // of the original induction value this represents. This is a normalized value // that is un-normalized already by the previous logic. - auto newPloop = outsideBuilder.create(loc, lowerBounds, - upperBounds, steps); - OpBuilder insideBuilder(newPloop.region()); - for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) { - Value previous = newPloop.getBody()->getArgument(i); - unsigned numberCombinedDimensions = combinedDimensions[i].size(); - // Iterate over all except the last induction value. - for (unsigned j = 0, e = numberCombinedDimensions - 1; j < e; ++j) { - unsigned idx = combinedDimensions[i][j]; - - // Determine the current induction value's current loop iteration - Value iv = insideBuilder.create(loc, previous, - normalizedUpperBounds[idx]); - replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv, - loops.region()); - - // Remove the effect of the current induction value to prepare for the - // next value. - previous = insideBuilder.create( - loc, previous, normalizedUpperBounds[idx + 1]); - } + auto newPloop = outsideBuilder.create( + loc, lowerBounds, upperBounds, steps, + [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) { + for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) { + Value previous = ploopIVs[i]; + unsigned numberCombinedDimensions = combinedDimensions[i].size(); + // Iterate over all except the last induction value. + for (unsigned j = 0, e = numberCombinedDimensions - 1; j < e; ++j) { + unsigned idx = combinedDimensions[i][j]; + + // Determine the current induction value's current loop iteration + Value iv = insideBuilder.create( + loc, previous, normalizedUpperBounds[idx]); + replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv, + loops.region()); + + // Remove the effect of the current induction value to prepare for + // the next value. + previous = insideBuilder.create( + loc, previous, normalizedUpperBounds[idx + 1]); + } - // The final induction value is just the remaining value. - unsigned idx = combinedDimensions[i][numberCombinedDimensions - 1]; - replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), previous, - loops.region()); - } + // The final induction value is just the remaining value. + unsigned idx = combinedDimensions[i][numberCombinedDimensions - 1]; + replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), + previous, loops.region()); + } + }); // Replace the old loop with the new loop. loops.getBody()->back().erase();