diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
--- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
+++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
@@ -215,4 +215,38 @@
   }];
 }
 
+
+// NOTE(review): the `<...>` template arguments in this hunk were lost in the
+// patch text and have been reconstructed — confirm traits and attribute types
+// against the surrounding file.
+def LoopCoalesceParallelOp : Op<Transform_Dialect, "loop.coalesce_parallel",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let summary = "Coalesces the dimensions of an scf.forall loop";
+  let description = [{
+    Given an `scf.forall` op, coalesces each group of loop dimensions named by
+    the `collapsed_dim0`, `collapsed_dim1` and `collapsed_dim2` index arrays
+    into a single dimension. Together the groups must cover every dimension of
+    the loop exactly once. The optional `mapping` attribute is attached to the
+    coalesced loop.
+
+    #### Return modes
+
+    The return handle points to the coalesced loop if coalescing happens, or
+    the given input loop if coalescing does not happen.
+  }];
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$collapsed_dim0,
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$collapsed_dim1,
+                   DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$collapsed_dim2,
+                   OptionalAttr<DeviceMappingArrayAttr>:$mapping
+  );
+
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat = [{
+    $target attr-dict `:` functional-type($target, $transformed)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::Operation *target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // SCF_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -141,6 +141,14 @@
 void collapseParallelLoops(scf::ParallelOp loops,
                            ArrayRef<std::vector<unsigned>> combinedDimensions);
 
+/// Take the scf.forall `loops` and, for each set of dimension indices in
+/// `combinedDimensions`, combine them into a single dimension.
+/// `combinedDimensions` must contain each index into `loops` exactly once.
+/// The optional `mapping` is attached to the newly created loop. On success
+/// the original loop is erased and replaced by the coalesced one.
+LogicalResult
+collapseForallLoops(scf::ForallOp loops,
+                    ArrayRef<std::vector<unsigned>> combinedDimensions,
+                    std::optional<ArrayAttr> mapping);
+
 /// Promotes the loop body of a scf::ForOp to its containing block if the loop
 /// was known to have a single iteration.
LogicalResult promoteIfSingleIteration(scf::ForOp forOp); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -245,6 +245,42 @@ return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// LoopCoalesceParallelOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::LoopCoalesceParallelOp::applyToOne( + Operation *op, transform::ApplyToEachResultList &results, + transform::TransformState &state) { + LogicalResult result(failure()); + if (scf::ForallOp scfForallOp = dyn_cast(op)) { + llvm::SmallVector, 3> combinedLoops; + if (!getCollapsedDim0().empty()) { + auto vec = getCollapsedDim0().vec(); + combinedLoops.push_back( + std::vector(vec.begin(), vec.end())); + } + if (!getCollapsedDim1().empty()) { + auto vec = getCollapsedDim1().vec(); + combinedLoops.push_back( + std::vector(vec.begin(), vec.end())); + } + if (!getCollapsedDim2().empty()) { + auto vec = getCollapsedDim2().vec(); + combinedLoops.push_back( + std::vector(vec.begin(), vec.end())); + } + result = collapseForallLoops(scfForallOp, combinedLoops, getMapping()); + } + + results.push_back(op); + if (failed(result)) { + DiagnosedSilenceableFailure diag = emitSilenceableError() + << "failed to coalesce"; + return diag; + } + return DiagnosedSilenceableFailure::success(); +} //===----------------------------------------------------------------------===// // Transform op registration //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -707,6 +707,110 @@ return success(); } 
+/// Coalesce the dimensions of `loops` according to `combinedDimensions`:
+/// each group of dimension indices becomes one dimension of a new scf.forall
+/// that replaces `loops`. Fails (without touching the IR) when
+/// `combinedDimensions` does not cover each loop dimension exactly once.
+LogicalResult
+mlir::collapseForallLoops(scf::ForallOp loops,
+                          ArrayRef<std::vector<unsigned>> combinedDimensions,
+                          std::optional<ArrayAttr> mapping) {
+  // Verify the documented precondition up front, before any IR is mutated:
+  // every index in [0, rank) must appear exactly once across all groups.
+  {
+    SmallVector<bool> seen(loops.getRank(), false);
+    for (const std::vector<unsigned> &dims : combinedDimensions) {
+      for (unsigned idx : dims) {
+        if (idx >= seen.size() || seen[idx])
+          return failure();
+        seen[idx] = true;
+      }
+    }
+    if (!llvm::all_of(seen, [](bool covered) { return covered; }))
+      return failure();
+  }
+
+  OpBuilder outsideBuilder(loops);
+  Location loc = loops.getLoc();
+
+  // Presort combined dimensions so upper-bound products are built in index
+  // order.
+  auto sortedDimensions = llvm::to_vector<3>(combinedDimensions);
+  for (auto &dims : sortedDimensions)
+    llvm::sort(dims);
+
+  // Normalize every dimension's iteration space to [0, ub) with step 1. Only
+  // the normalized upper bounds are needed afterwards; lower bound 0 and step
+  // 1 are implied by the scf.forall form created below.
+  SmallVector<Value, 4> normalizedUpperBounds;
+  for (unsigned i = 0, e = loops.getRank(); i < e; ++i) {
+    OpBuilder insideLoopBuilder = OpBuilder::atBlockBegin(loops.getBody());
+    auto resultBounds = normalizeLoop(outsideBuilder, insideLoopBuilder, loc,
+                                      loops.getLowerBound(outsideBuilder)[i],
+                                      loops.getUpperBound(outsideBuilder)[i],
+                                      loops.getStep(outsideBuilder)[i],
+                                      loops.getBody()->getArgument(i));
+    normalizedUpperBounds.push_back(resultBounds.upperBound);
+  }
+
+  // Combine iteration spaces: the upper bound of each coalesced dimension is
+  // the product of its members' normalized upper bounds.
+  SmallVector<Value, 4> upperBounds;
+  for (unsigned i = 0, e = sortedDimensions.size(); i < e; ++i) {
+    Value newUpperBound =
+        outsideBuilder.createOrFold<arith::ConstantIndexOp>(loc, 1);
+    for (auto idx : sortedDimensions[i]) {
+      newUpperBound = outsideBuilder.createOrFold<arith::MulIOp>(
+          loc, newUpperBound, normalizedUpperBounds[idx]);
+    }
+    upperBounds.push_back(newUpperBound);
+  }
+
+  // Create the new scf.forall with conversions to the original induction
+  // values. The loop below uses divisions to get the relevant range of values
+  // in the new induction value that represent each range of the original
+  // induction value. The remainders then determine, based on that range,
+  // which iteration of the original induction value this represents. This is
+  // a normalized value that is un-normalized already by the previous logic.
+  auto newPloop = outsideBuilder.create<scf::ForallOp>(
+      loc, getAsOpFoldResult(upperBounds), loops.getOutputs(), mapping,
+      [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
+        for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
+          Value previous = ploopIVs[i];
+          unsigned numberCombinedDimensions = combinedDimensions[i].size();
+          // Iterate over all except the last induction value.
+          for (unsigned j = numberCombinedDimensions - 1; j > 0; --j) {
+            unsigned idx = combinedDimensions[i][j];
+
+            // Determine the current induction value's current loop iteration.
+            Value iv = insideBuilder.createOrFold<arith::RemSIOp>(
+                loc, previous, normalizedUpperBounds[idx]);
+            replaceAllUsesInRegionWith(loops.getInductionVar(idx), iv,
+                                       loops.getRegion());
+
+            // Remove the effect of the current induction value to prepare for
+            // the next value.
+            previous = insideBuilder.createOrFold<arith::DivSIOp>(
+                loc, previous, normalizedUpperBounds[idx]);
+          }
+
+          // The final induction value is just the remaining value.
+          unsigned idx = combinedDimensions[i][0];
+          replaceAllUsesInRegionWith(loops.getInductionVar(idx), previous,
+                                     loops.getRegion());
+        }
+        // Create an empty in_parallel terminator; the original terminator's
+        // contents are cloned below.
+        insideBuilder.create<scf::InParallelOp>(loc);
+      });
+
+  // Map the old shared-output block arguments to the new ones when cloning.
+  IRMapping irMapping;
+  irMapping.map(loops.getOutputBlockArguments(),
+                newPloop.getOutputBlockArguments());
+
+  // Clone the body of the original forall into the new one (the induction-var
+  // uses were already redirected to the rem/div values above).
+  outsideBuilder.setInsertionPoint(newPloop.getTerminator());
+  for (auto &op : loops.getBody()->without_terminator())
+    outsideBuilder.clone(op, irMapping);
+
+  // Clone the contents of the original terminator into the new, still empty,
+  // scf.forall.in_parallel region.
+  outsideBuilder.setInsertionPointToStart(newPloop.getTerminator().getBody());
+  auto forallTerminator = loops.getTerminator();
+  for (auto &bodyOp : forallTerminator.getYieldingOps())
+    outsideBuilder.clone(bodyOp, irMapping);
+
+  // Replace the old loop with the new loop.
+  loops.replaceAllUsesWith(newPloop);
+  loops.erase();
+  return success();
+}
+
 void mlir::collapseParallelLoops(
     scf::ParallelOp loops,
     ArrayRef<std::vector<unsigned>> combinedDimensions) {
   OpBuilder outsideBuilder(loops);
diff --git a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
--- a/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
+++ b/mlir/test/Dialect/SCF/transform-op-coalesce.mlir
@@ -90,3 +90,36 @@
   %2 = transform.loop.coalesce %1 : (!transform.op<"scf.for">) -> (!transform.op<"scf.for">)
   transform.loop.unroll %2 {factor = 3} : !transform.op<"scf.for">
 }
+
+
+// -----
+// NOTE(review): the `array<...>` and `#gpu.block<...>` attribute payloads
+// below were reconstructed from the CHECK lines — confirm against the
+// original patch.
+#map = affine_map<(d0) -> (d0 * 2)>
+func.func @conv(%0: tensor<32x230x230x32xf32>, %1: tensor<3x3x32x64xf32>, %2: tensor<32x228x228x64xf32>) -> tensor<32x228x228x64xf32> {
+  %c0 = arith.constant 0 : index
+  %6 = scf.forall (%arg0, %arg1, %arg2, %arg3) in (16, 114, 114, 32) shared_outs(%arg4 = %2) -> (tensor<32x228x228x64xf32>) {
+    // CHECK: scf.forall (%[[IV1:.+]], %[[IV2:.+]]) in
+    // CHECK: %[[IDX0:.+]] = arith.remsi %[[IV1]]
+    // CHECK: %[[IDX1:.+]] = arith.divsi %[[IV1]]
+    // CHECK: %[[IDX2:.+]] = arith.remsi %[[IV2]]
+    // CHECK: %[[IDX3:.+]] = arith.divsi %[[IV2]]
+    %7 = affine.apply #map(%arg0)
+    %8 = affine.apply #map(%arg1)
+    %9 = affine.apply #map(%arg2)
+    %10 = affine.apply #map(%arg3)
+    %extracted_slice = tensor.extract_slice %0[%7, %8, %9, 0] [2, 4, 4, 32] [1, 1, 1, 1] : tensor<32x230x230x32xf32> to tensor<2x4x4x32xf32>
+    %extracted_slice_0 = tensor.extract_slice %1[0, 0, 0, %10] [3, 3, 32, 2] [1, 1, 1, 1] : tensor<3x3x32x64xf32> to tensor<3x3x32x2xf32>
+    %extracted_slice_1 = tensor.extract_slice %arg4[%7, %8, %9, %10] [2, 2, 2, 2] [1, 1, 1, 1] : tensor<32x228x228x64xf32> to tensor<2x2x2x2xf32>
+    %11 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%extracted_slice, %extracted_slice_0 : tensor<2x4x4x32xf32>, tensor<3x3x32x2xf32>) outs(%extracted_slice_1 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
+    // CHECK: scf.forall.in_parallel
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %11 into %arg4[%7, %8, %9, %10] [2, 2, 2, 2] [1, 1, 1, 1] : tensor<2x2x2x2xf32> into tensor<32x228x228x64xf32>
+    }
+  }
+  return %6 : tensor<32x228x228x64xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+  %2 = transform.loop.coalesce_parallel %0 {collapsed_dim0 = array<i64: 0, 1>, collapsed_dim1 = array<i64: 2, 3>, mapping = [#gpu.block<x>, #gpu.block<y>]} : (!pdl.operation) -> !pdl.operation
+}