diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h --- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h +++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h @@ -62,13 +62,8 @@ // TODO: Consider moving this functionality to RegionBranchOpInterface. bool insideMutuallyExclusiveBranches(Operation *a, Operation *b); -/// Promotes the loop body of a scf::ForallOp to its containing block if the -/// loop was known to have a single iteration. -LogicalResult promoteIfSingleIteration(PatternRewriter &rewriter, - scf::ForallOp forallOp); - /// Promotes the loop body of a scf::ForallOp to its containing block. -void promote(PatternRewriter &rewriter, scf::ForallOp forallOp); +void promote(RewriterBase &rewriter, scf::ForallOp forallOp); /// An owning vector of values, handy to return from functions. using ValueVector = SmallVector; diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -121,7 +121,7 @@ def ForOp : SCF_Op<"for", [AutomaticAllocationScope, DeclareOpInterfaceMethods, + "getSingleUpperBound", "promoteIfSingleIteration"]>, AllTypesMatch<["lowerBound", "upperBound", "step"]>, ConditionallySpeculatable, DeclareOpInterfaceMethods, @@ -361,6 +361,8 @@ def ForallOp : SCF_Op<"forall", [ AttrSizedOperandSegments, AutomaticAllocationScope, + DeclareOpInterfaceMethods, RecursiveMemoryEffects, SingleBlockImplicitTerminator<"scf::InParallelOp">, DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -141,10 +141,6 @@ void collapseParallelLoops(scf::ParallelOp loops, ArrayRef> combinedDimensions); -/// Promotes the loop body of a scf::ForOp to its containing block if the loop -/// was known to have a single iteration. 
-LogicalResult promoteIfSingleIteration(scf::ForOp forOp); - /// Unrolls this for operation by the specified unroll factor. Returns failure /// if the loop cannot be unrolled either due to restrictions or due to invalid /// unroll factors. Requires positive loop bounds and step. If specified, diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h @@ -15,6 +15,10 @@ #include "mlir/IR/OpDefinition.h" +namespace mlir { +class RewriterBase; +} // namespace mlir + /// Include the generated interface declarations. #include "mlir/Interfaces/LoopLikeInterface.h.inc" diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td --- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td +++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td @@ -48,6 +48,19 @@ op->moveBefore($_op); }] >, + InterfaceMethod<[{ + Promotes the loop body to its containing block if the loop is known to + have a single iteration. Returns "success" if the promotion was + successful. + }], + /*retTy=*/"::mlir::LogicalResult", + /*methodName=*/"promoteIfSingleIteration", + /*args=*/(ins "::mlir::RewriterBase &":$rewriter), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return ::mlir::failure(); + }] + >, InterfaceMethod<[{ If there is a single induction variable return it, otherwise return std::nullopt. diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -385,6 +385,35 @@ return OpFoldResult(getUpperBound()); } +/// Promotes the loop body of a forOp to its containing block if it can be +/// determined that the loop has a single iteration. 
+LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) { + std::optional<int64_t> tripCount = + constantTripCount(getLowerBound(), getUpperBound(), getStep()); + if (!tripCount.has_value() || *tripCount != 1) + return failure(); + + // Replace all results with the yielded values. + auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator()); + rewriter.replaceAllUsesWith(getResults(), yieldOp.getOperands()); + + // Replace block arguments with lower bound (replacement for IV) and + // iter_args. + SmallVector<Value> bbArgReplacements; + bbArgReplacements.push_back(getLowerBound()); + bbArgReplacements.append(getIterOperands().begin(), getIterOperands().end()); + + // Move the loop body operations to the loop's containing block. + rewriter.inlineBlockBefore(getBody(), getOperation()->getBlock(), + getOperation()->getIterator(), bbArgReplacements); + + // Erase the old terminator and the loop. + rewriter.eraseOp(yieldOp); + rewriter.eraseOp(*this); + + return success(); +} + /// Prints the initialization list in the form of /// (%inner = %outer, %inner2 = %outer2, <...>) /// where 'inner' values are assumed to be region arguments and 'outer' values @@ -536,59 +565,64 @@ regions.push_back(RegionSuccessor(getResults())); } +Region &ForallOp::getLoopBody() { return getRegion(); } + /// Promotes the loop body of a forallOp to its containing block if it can be /// determined that the loop has a single iteration. 
-LogicalResult mlir::scf::promoteIfSingleIteration(PatternRewriter &rewriter, - scf::ForallOp forallOp) { +LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) { for (auto [lb, ub, step] : - llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), - forallOp.getMixedStep())) { + llvm::zip(getMixedLowerBound(), getMixedUpperBound(), getMixedStep())) { auto tripCount = constantTripCount(lb, ub, step); if (!tripCount.has_value() || *tripCount != 1) return failure(); } - promote(rewriter, forallOp); + promote(rewriter, *this); return success(); } /// Promotes the loop body of a scf::ForallOp to its containing block. -void mlir::scf::promote(PatternRewriter &rewriter, scf::ForallOp forallOp) { - IRMapping mapping; - mapping.map(forallOp.getInductionVars(), forallOp.getLowerBound(rewriter)); - mapping.map(forallOp.getOutputBlockArguments(), forallOp.getOutputs()); - for (auto &bodyOp : forallOp.getBody()->without_terminator()) - rewriter.clone(bodyOp, mapping); +void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) { + OpBuilder::InsertionGuard g(rewriter); + scf::InParallelOp terminator = forallOp.getTerminator(); + + // Replace block arguments with lower bounds (replacements for IVs) and + // outputs. + SmallVector<Value> bbArgReplacements = forallOp.getLowerBound(rewriter); + bbArgReplacements.append(forallOp.getOutputs().begin(), + forallOp.getOutputs().end()); + // Move the loop body operations to the loop's containing block. + rewriter.inlineBlockBefore(forallOp.getBody(), forallOp->getBlock(), + forallOp->getIterator(), bbArgReplacements); + + // Replace the terminator with tensor.insert_slice ops. 
+ rewriter.setInsertionPointAfter(forallOp); SmallVector<Value> results; results.reserve(forallOp.getResults().size()); - scf::InParallelOp terminator = forallOp.getTerminator(); for (auto &yieldingOp : terminator.getYieldingOps()) { auto parallelInsertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp); Value dst = parallelInsertSliceOp.getDest(); Value src = parallelInsertSliceOp.getSource(); - - auto getMappedValues = [&](ValueRange values) { - return llvm::to_vector(llvm::map_range( - values, [&](Value value) { return mapping.lookupOrDefault(value); })); - }; - - Value srcVal = mapping.lookupOrDefault(src); - if (llvm::isa<TensorType>(srcVal.getType())) { + if (llvm::isa<TensorType>(src.getType())) { results.push_back(rewriter.create<tensor::InsertSliceOp>( - forallOp.getLoc(), dst.getType(), srcVal, - mapping.lookupOrDefault(dst), - getMappedValues(parallelInsertSliceOp.getOffsets()), - getMappedValues(parallelInsertSliceOp.getSizes()), - getMappedValues(parallelInsertSliceOp.getStrides()), + forallOp.getLoc(), dst.getType(), src, dst, + parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(), + parallelInsertSliceOp.getStrides(), parallelInsertSliceOp.getStaticOffsets(), parallelInsertSliceOp.getStaticSizes(), parallelInsertSliceOp.getStaticStrides())); + } else { + llvm_unreachable("unsupported terminator"); } } - rewriter.replaceOp(forallOp, results); + rewriter.replaceAllUsesWith(forallOp.getResults(), results); + + // Erase the old terminator and the loop. + rewriter.eraseOp(terminator); + rewriter.eraseOp(forallOp); } LoopNest mlir::scf::buildLoopNest( diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -362,44 +362,6 @@ return builder.create(loc, sum, divisor); } -/// Helper to replace uses of loop carried values (iter_args) and loop -/// yield values while promoting single iteration scf.for ops. 
-static void replaceIterArgsAndYieldResults(scf::ForOp forOp) { - // Replace uses of iter arguments with iter operands (initial values). - auto iterOperands = forOp.getIterOperands(); - auto iterArgs = forOp.getRegionIterArgs(); - for (auto e : llvm::zip(iterOperands, iterArgs)) - std::get<1>(e).replaceAllUsesWith(std::get<0>(e)); - - // Replace uses of loop results with the values yielded by the loop. - auto outerResults = forOp.getResults(); - auto innerResults = forOp.getBody()->getTerminator()->getOperands(); - for (auto e : llvm::zip(outerResults, innerResults)) - std::get<0>(e).replaceAllUsesWith(std::get<1>(e)); -} - -/// Promotes the loop body of a forOp to its containing block if the forOp -/// it can be determined that the loop has a single iteration. -LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) { - std::optional<int64_t> tripCount = constantTripCount( - forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep()); - if (!tripCount.has_value() || tripCount != 1) - return failure(); - auto iv = forOp.getInductionVar(); - iv.replaceAllUsesWith(forOp.getLowerBound()); - - replaceIterArgsAndYieldResults(forOp); - - // Move the loop body operations, except for its terminator, to the loop's - // containing block. - auto *parentBlock = forOp->getBlock(); - forOp.getBody()->getTerminator()->erase(); - parentBlock->getOperations().splice(Block::iterator(forOp), - forOp.getBody()->getOperations()); - forOp.erase(); - return success(); -} - /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each @@ -469,6 +431,7 @@ // Compute tripCount = ceilDiv((upperBound - lowerBound), step) and populate // 'upperBoundUnrolled' and 'stepUnrolled' for static and dynamic cases. 
OpBuilder boundsBuilder(forOp); + IRRewriter rewriter(forOp.getContext()); auto loc = forOp.getLoc(); Value step = forOp.getStep(); Value upperBoundUnrolled; @@ -488,7 +451,7 @@ int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst); if (unrollFactor == 1) { - if (tripCount == 1 && failed(promoteIfSingleIteration(forOp))) + if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter))) return failure(); return success(); } @@ -553,7 +516,7 @@ } epilogueForOp->setOperands(epilogueForOp.getNumControlOperands(), epilogueForOp.getNumIterOperands(), results); - (void)promoteIfSingleIteration(epilogueForOp); + (void)epilogueForOp.promoteIfSingleIteration(rewriter); } // Create unrolled loop. @@ -573,7 +536,7 @@ }, annotateFn, iterArgs, yieldedValues); // Promote the loop body up if this has turned into a single iteration loop. - (void)promoteIfSingleIteration(forOp); + (void)forOp.promoteIfSingleIteration(rewriter); return success(); } diff --git a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir --- a/mlir/test/Dialect/SCF/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/SCF/one-shot-bufferize.mlir @@ -545,9 +545,9 @@ %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index + // CHECK: %[[subview:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1] // CHECK: scf.forall (%[[tidx:.*]]) in (%[[idx2]]) %2 = scf.forall (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor) { - // CHECK: %[[subview:.*]] = memref.subview %[[arg2]][5] [%[[idx]]] [1] %6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor to tensor // CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview]] : memref) -> tensor @@ -591,9 +591,9 @@ // CHECK: %[[alloc1:.*]] = memref.alloc // CHECK: memref.copy %[[arg2]], %[[alloc1]] + // CHECK: %[[subview1:.*]] = memref.subview %[[alloc1]][5] [%[[idx]]] [1] // CHECK: scf.forall (%[[tidx:.*]]) in (%[[idx2]]) %2 = scf.forall (%arg3) in (%idx2) shared_outs(%o = %arg2) -> (tensor) { - // CHECK: %[[subview1:.*]] = 
memref.subview %[[alloc1]][5] [%[[idx]]] [1] %6 = tensor.extract_slice %o[5] [%idx] [%c1] : tensor to tensor // CHECK: linalg.fill ins(%{{.*}}) outs(%[[subview1]] : memref