diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp @@ -22,15 +22,15 @@ /// Tile a parallel loop of the form /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) -/// step (%arg4, %arg5) +/// step (%arg4, %arg5) /// /// into /// scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) -/// step (%arg4*tileSize[0], -/// %arg5*tileSize[1]) +/// step (%arg4*tileSize[0], +/// %arg5*tileSize[1]) /// scf.parallel (%j0, %j1) = (0, 0) to (min(%arg4*tileSize[0], %arg2-%i0) -/// min(%arg5*tileSize[1], %arg3-%i1)) -/// step (%arg4, %arg5) +/// min(%arg5*tileSize[1], %arg3-%i1)) +/// step (%arg4, %arg5) /// /// where the uses of %i0 and %i1 in the loop body are replaced by -/// %i0 + j0 and %i1 + %j1. +/// %i0 + %j0 and %i1 + %j1. @@ -126,17 +126,28 @@ op.erase(); } -/// Get a list of most nested parallel loops. Assumes that ParallelOps are -/// only directly nested. -static bool getInnermostNestedLoops(Block *block, - SmallVectorImpl<ParallelOp> &loops) { - bool hasInnerLoop = false; - for (auto parallelOp : block->getOps<ParallelOp>()) { - hasInnerLoop = true; - if (!getInnermostNestedLoops(parallelOp.getBody(), loops)) - loops.push_back(parallelOp); +/// Collect into `result` the innermost scf.parallel ops nested under `rootOp`. +/// Returns true iff `rootOp` encloses an scf.parallel op at any depth. +static bool getInnermostPloops(Operation *rootOp, + SmallVectorImpl<ParallelOp> &result) { + assert(rootOp != nullptr && "Root operation must not be a nullptr."); + bool rootEnclosesPloops = false; + for (Region &region : rootOp->getRegions()) { + for (Block &block : region.getBlocks()) { + for (Operation &op : block) { + bool enclosesPloops = getInnermostPloops(&op, result); + rootEnclosesPloops |= enclosesPloops; + if (auto ploop = dyn_cast<ParallelOp>(op)) { + rootEnclosesPloops = true; + + // Collect the loop only if it encloses no other scf.parallel op.
+ if (!enclosesPloops) + result.push_back(ploop); + } + } + } } - return hasInnerLoop; + return rootEnclosesPloops; } namespace { @@ -148,14 +159,12 @@ } void runOnFunction() override { - SmallVector<ParallelOp, 2> mostNestedParallelOps; - for (Block &block : getFunction()) { - getInnermostNestedLoops(&block, mostNestedParallelOps); - } - for (ParallelOp pLoop : mostNestedParallelOps) { + SmallVector<ParallelOp, 2> innermostPloops; + getInnermostPloops(getFunction().getOperation(), innermostPloops); + for (ParallelOp ploop : innermostPloops) { // FIXME: Add reduction support. - if (pLoop.getNumReductions() == 0) - tileParallelLoop(pLoop, tileSizes); + if (ploop.getNumReductions() == 0) + tileParallelLoop(ploop, tileSizes); } } }; diff --git a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir --- a/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir +++ b/mlir/test/Dialect/SCF/parallel-loop-tiling.mlir @@ -112,3 +112,29 @@ // CHECK: } // CHECK: return // CHECK: } + +// ----- + +func @tile_nested_in_non_ploop() { + %c0 = constant 0 : index + %c1 = constant 1 : index + %c2 = constant 2 : index + scf.for %i = %c0 to %c2 step %c1 { + scf.for %j = %c0 to %c2 step %c1 { + scf.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) { + } + } + } + return +} + +// CHECK-LABEL: func @tile_nested_in_non_ploop +// CHECK: scf.for +// CHECK: scf.for +// CHECK: scf.parallel +// CHECK: scf.parallel +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: }