diff --git a/mlir/lib/Dialect/SCF/SCF.cpp b/mlir/lib/Dialect/SCF/SCF.cpp
--- a/mlir/lib/Dialect/SCF/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/SCF.cpp
@@ -1705,11 +1705,80 @@
   }
 };
 
+/// Merges a perfectly nested pair of `scf.parallel` ops into a single
+/// multi-dimensional `scf.parallel`.
+///
+/// The pattern fires only when the outer body holds nothing but the inner
+/// `scf.parallel` (besides its terminator), none of the inner bounds or steps
+/// is an outer induction variable, and neither loop carries reductions. The
+/// merged loop lists the outer dimensions first, then the inner ones.
+struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
+  using OpRewritePattern<ParallelOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ParallelOp op,
+                                PatternRewriter &rewriter) const override {
+    Block &outerBody = op.getLoopBody().front();
+    if (!llvm::hasSingleElement(outerBody.without_terminator()))
+      return failure();
+
+    auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
+    if (!innerOp)
+      return failure();
+
+    // The inner bounds and steps become operands of the merged loop, so they
+    // must not be induction variables of the outer loop.
+    for (Value outerIv : outerBody.getArguments())
+      if (llvm::is_contained(innerOp.lowerBound(), outerIv) ||
+          llvm::is_contained(innerOp.upperBound(), outerIv) ||
+          llvm::is_contained(innerOp.step(), outerIv))
+        return failure();
+
+    // Reductions are not supported yet.
+    if (!op.initVals().empty() || !innerOp.initVals().empty())
+      return failure();
+
+    // Populates the merged body: clones the inner body (minus its terminator)
+    // with the block arguments of both old bodies remapped to the new ones.
+    auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
+                           ValueRange iterVals, ValueRange) {
+      Block &innerBody = innerOp.getLoopBody().front();
+      assert(iterVals.size() ==
+             (outerBody.getNumArguments() + innerBody.getNumArguments()));
+      BlockAndValueMapping mapping;
+      mapping.map(outerBody.getArguments(),
+                  iterVals.take_front(outerBody.getNumArguments()));
+      mapping.map(innerBody.getArguments(),
+                  iterVals.take_back(innerBody.getNumArguments()));
+      // Named `nested` to avoid shadowing the matched `op`.
+      for (Operation &nested : innerBody.without_terminator())
+        builder.clone(nested, mapping);
+    };
+
+    // Returns the values of `first` followed by the values of `second`.
+    auto concatValues = [](const auto &first, const auto &second) {
+      SmallVector<Value> ret;
+      ret.reserve(first.size() + second.size());
+      ret.assign(first.begin(), first.end());
+      ret.append(second.begin(), second.end());
+      return ret;
+    };
+
+    auto newLowerBounds = concatValues(op.lowerBound(), innerOp.lowerBound());
+    auto newUpperBounds = concatValues(op.upperBound(), innerOp.upperBound());
+    auto newSteps = concatValues(op.step(), innerOp.step());
+
+    rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
+                                            newSteps, llvm::None, bodyBuilder);
+    return success();
+  }
+};
+
 } // namespace
 
 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops>(context);
+  results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
+              MergeNestedParallelLoops>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -99,6 +99,41 @@
 
 // -----
 
+func @nested_parallel(%0: memref<?x?x?xf64>) -> memref<?x?x?xf64> {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %1 = memref.dim %0, %c0 : memref<?x?x?xf64>
+  %2 = memref.dim %0, %c1 : memref<?x?x?xf64>
+  %3 = memref.dim %0, %c2 : memref<?x?x?xf64>
+  %4 = memref.alloc(%1, %2, %3) : memref<?x?x?xf64>
+  scf.parallel (%arg1) = (%c0) to (%1) step (%c1) {
+    scf.parallel (%arg2) = (%c0) to (%2) step (%c1) {
+      scf.parallel (%arg3) = (%c0) to (%3) step (%c1) {
+        %5 = memref.load %0[%arg1, %arg2, %arg3] : memref<?x?x?xf64>
+        memref.store %5, %4[%arg1, %arg2, %arg3] : memref<?x?x?xf64>
+        scf.yield
+      }
+      scf.yield
+    }
+    scf.yield
+  }
+  return %4 : memref<?x?x?xf64>
+}
+
+// CHECK-LABEL:   func @nested_parallel(
+// CHECK:           [[C0:%.*]] = constant 0 : index
+// CHECK:           [[C1:%.*]] = constant 1 : index
+// CHECK:           [[C2:%.*]] = constant 2 : index
+// CHECK:           [[B0:%.*]] = memref.dim {{.*}}, [[C0]]
+// CHECK:           [[B1:%.*]] = memref.dim {{.*}}, [[C1]]
+// CHECK:           [[B2:%.*]] = memref.dim {{.*}}, [[C2]]
+// CHECK:           scf.parallel ([[V0:%.*]], [[V1:%.*]], [[V2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[B0]], [[B1]], [[B2]]) step ([[C1]], [[C1]], [[C1]])
+// CHECK:           memref.load {{.*}}{{\[}}[[V0]], [[V1]], [[V2]]]
+// CHECK:           memref.store {{.*}}{{\[}}[[V0]], [[V1]], [[V2]]]
+
+// -----
+
 func private @side_effect()
 func @one_unused(%cond: i1) -> (index) {
   %c0 = constant 0 : index
@@ -632,7 +667,7 @@
     } else {
       %v2 = "test.get_some_value"() : () -> i32
       scf.yield %c2 : index
-    } 
+    }
     scf.yield %res1 : index
   } else {
     %res2 = scf.if %arg0 -> index {
@@ -641,7 +676,7 @@
     } else {
       %v4 = "test.get_some_value"() : () -> i32
       scf.yield %c4 : index
-    } 
+    }
     scf.yield %res2 : index
   }
   return %res : index