diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -62,6 +62,14 @@
 // TODO: Consider moving this functionality to RegionBranchOpInterface.
 bool insideMutuallyExclusiveBranches(Operation *a, Operation *b);
 
+/// Promotes the loop body of a scf::ForallOp to its containing block if the
+/// loop is known to have a single iteration.
+LogicalResult promoteIfSingleIteration(PatternRewriter &rewriter,
+                                       scf::ForallOp forallOp);
+
+/// Promotes the loop body of a scf::ForallOp to its containing block.
+void promote(PatternRewriter &rewriter, scf::ForallOp forallOp);
+
 /// An owning vector of values, handy to return from functions.
 using ValueVector = SmallVector<Value>;
 using LoopVector = SmallVector<scf::ForOp>;
diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h
@@ -128,6 +128,11 @@
 getValuesSortedByKey(ArrayRef<Attribute> keys, ArrayRef<Value> values,
                      llvm::function_ref<bool(Attribute, Attribute)> compare);
 
+/// Return the number of iterations for a loop with a lower bound `lb`, upper
+/// bound `ub` and step `step`.
+std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
+                                         OpFoldResult step);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_UTILS_STATICVALUEUTILS_H
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -534,6 +534,61 @@
   regions.push_back(RegionSuccessor(getResults()));
 }
 
+/// Promotes the loop body of a forallOp to its containing block if it can be
+/// determined that the loop has a single iteration.
+LogicalResult mlir::scf::promoteIfSingleIteration(PatternRewriter &rewriter,
+                                                  scf::ForallOp forallOp) {
+  for (auto [lb, ub, step] :
+       llvm::zip(forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+                 forallOp.getMixedStep())) {
+    auto tripCount = constantTripCount(lb, ub, step);
+    if (!tripCount.has_value() || *tripCount != 1)
+      return failure();
+  }
+
+  promote(rewriter, forallOp);
+  return success();
+}
+
+/// Promotes the loop body of a scf::ForallOp to its containing block.
+void mlir::scf::promote(PatternRewriter &rewriter, scf::ForallOp forallOp) {
+  IRMapping mapping;
+  mapping.map(forallOp.getInductionVars(), forallOp.getLowerBound(rewriter));
+  mapping.map(forallOp.getOutputBlockArguments(), forallOp.getOutputs());
+  for (auto &bodyOp : forallOp.getBody()->without_terminator())
+    rewriter.clone(bodyOp, mapping);
+
+  SmallVector<Value> results;
+  results.reserve(forallOp.getResults().size());
+  scf::InParallelOp terminator = forallOp.getTerminator();
+  for (auto &yieldingOp : terminator.getYieldingOps()) {
+    auto parallelInsertSliceOp =
+        cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+
+    Value dst = parallelInsertSliceOp.getDest();
+    Value src = parallelInsertSliceOp.getSource();
+
+    auto getMappedValues = [&](ValueRange values) {
+      return llvm::to_vector(llvm::map_range(
+          values, [&](Value value) { return mapping.lookupOrDefault(value); }));
+    };
+
+    Value srcVal = mapping.lookupOrDefault(src);
+    if (srcVal.getType().isa<TensorType>()) {
+      results.push_back(rewriter.create<tensor::InsertSliceOp>(
+          forallOp.getLoc(), dst.getType(), srcVal,
+          mapping.lookupOrDefault(dst),
+          getMappedValues(parallelInsertSliceOp.getOffsets()),
+          getMappedValues(parallelInsertSliceOp.getSizes()),
+          getMappedValues(parallelInsertSliceOp.getStrides()),
+          parallelInsertSliceOp.getStaticOffsets(),
+          parallelInsertSliceOp.getStaticSizes(),
+          parallelInsertSliceOp.getStaticStrides()));
+    }
+  }
+  rewriter.replaceOp(forallOp, results);
+}
+
 LoopNest mlir::scf::buildLoopNest(
     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
     ValueRange steps, ValueRange iterArgs,
@@ -1452,16 +1507,99 @@
       dispatchIndexOpFoldResults(mixedStep, dynamicStep, staticStep);
       op.getDynamicStepMutable().assign(dynamicStep);
       op.setStaticStep(staticStep);
+
+      op->setAttr(ForallOp::getOperandSegmentSizeAttr(),
+                  rewriter.getDenseI32ArrayAttr(
+                      {static_cast<int32_t>(dynamicLowerBound.size()),
+                       static_cast<int32_t>(dynamicUpperBound.size()),
+                       static_cast<int32_t>(dynamicStep.size()),
+                       static_cast<int32_t>(op.getNumResults())}));
     });
     return success();
   }
 };
 
+struct ForallOpSingleOrZeroIterationDimsFolder
+    : public OpRewritePattern<ForallOp> {
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForallOp op,
+                                PatternRewriter &rewriter) const override {
+    // Do not fold dimensions if they are mapped to processing units.
+    if (op.getMapping().has_value())
+      return failure();
+    Location loc = op.getLoc();
+
+    // Compute new loop bounds that omit all single-iteration loop dimensions.
+    SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
+        newMixedSteps;
+    IRMapping mapping;
+    for (auto [lb, ub, step, iv] :
+         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
+                   op.getMixedStep(), op.getInductionVars())) {
+      auto numIterations = constantTripCount(lb, ub, step);
+      if (numIterations.has_value()) {
+        // Remove the loop if it performs zero iterations.
+        if (*numIterations == 0) {
+          rewriter.replaceOp(op, op.getOutputs());
+          return success();
+        }
+        // Replace the loop induction variable by the lower bound if the loop
+        // performs a single iteration. Otherwise, copy the loop bounds.
+        if (*numIterations == 1) {
+          mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+          continue;
+        }
+      }
+      newMixedLowerBounds.push_back(lb);
+      newMixedUpperBounds.push_back(ub);
+      newMixedSteps.push_back(step);
+    }
+    // Exit if none of the loop dimensions perform a single iteration.
+    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
+      return rewriter.notifyMatchFailure(
+          op, "no dimensions have 0 or 1 iterations");
+    }
+
+    // All of the loop dimensions perform a single iteration. Inline loop body.
+    if (newMixedLowerBounds.empty()) {
+      promote(rewriter, op);
+      return success();
+    }
+
+    // Replace the loop by a lower-dimensional loop.
+    ForallOp newOp;
+    newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
+                                      newMixedUpperBounds, newMixedSteps,
+                                      op.getOutputs(), std::nullopt, nullptr);
+    newOp.getBodyRegion().getBlocks().clear();
+    // The new loop needs to keep all attributes from the old one, except for
+    // "operand_segment_sizes" and static loop bound attributes which capture
+    // the outdated information of the old iteration domain.
+    SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
+                                        newOp.getStaticLowerBoundAttrName(),
+                                        newOp.getStaticUpperBoundAttrName(),
+                                        newOp.getStaticStepAttrName()};
+    for (const auto &namedAttr : op->getAttrs()) {
+      if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
+        continue;
+      rewriter.updateRootInPlace(newOp, [&]() {
+        newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+      });
+    }
+    rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
+                               newOp.getRegion().begin(), mapping);
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+
 } // namespace
 
 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
-              ForallOpControlOperandsFolder, ForallOpIterArgsFolder>(context);
+  results.add<DimOfForallOp, FoldTensorCastOfOutputIntoForallOp,
+              ForallOpControlOperandsFolder, ForallOpIterArgsFolder,
+              ForallOpSingleOrZeroIterationDimsFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
@@ -2615,41 +2753,37 @@
 namespace {
 // Collapse loop dimensions that perform a single iteration.
-struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
+struct ParallelOpSingleOrZeroIterationDimsFolder
+    : public OpRewritePattern<ParallelOp> {
   using OpRewritePattern<ParallelOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(ParallelOp op,
                                 PatternRewriter &rewriter) const override {
-    IRMapping mapping;
+    Location loc = op.getLoc();
+
     // Compute new loop bounds that omit all single-iteration loop dimensions.
-    SmallVector<Value> newLowerBounds;
-    SmallVector<Value> newUpperBounds;
-    SmallVector<Value> newSteps;
-    newLowerBounds.reserve(op.getLowerBound().size());
-    newUpperBounds.reserve(op.getUpperBound().size());
-    newSteps.reserve(op.getStep().size());
-    for (auto [lowerBound, upperBound, step, iv] :
+    SmallVector<Value> newLowerBounds, newUpperBounds, newSteps;
+    IRMapping mapping;
+    for (auto [lb, ub, step, iv] :
          llvm::zip(op.getLowerBound(), op.getUpperBound(), op.getStep(),
                    op.getInductionVars())) {
-      // Collect the statically known loop bounds.
-      auto lowerBoundConstant =
-          dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp());
-      auto upperBoundConstant =
-          dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp());
-      auto stepConstant =
-          dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp());
-      // Replace the loop induction variable by the lower bound if the loop
-      // performs a single iteration. Otherwise, copy the loop bounds.
-      if (lowerBoundConstant && upperBoundConstant && stepConstant &&
-          (upperBoundConstant.value() - lowerBoundConstant.value()) > 0 &&
-          (upperBoundConstant.value() - lowerBoundConstant.value()) <=
-              stepConstant.value()) {
-        mapping.map(iv, lowerBound);
-      } else {
-        newLowerBounds.push_back(lowerBound);
-        newUpperBounds.push_back(upperBound);
-        newSteps.push_back(step);
+      auto numIterations = constantTripCount(lb, ub, step);
+      if (numIterations.has_value()) {
+        // Remove the loop if it performs zero iterations.
+        if (*numIterations == 0) {
+          rewriter.replaceOp(op, op.getInitVals());
+          return success();
+        }
+        // Replace the loop induction variable by the lower bound if the loop
+        // performs a single iteration. Otherwise, copy the loop bounds.
+        if (*numIterations == 1) {
+          mapping.map(iv, getValueOrCreateConstantIndexOp(rewriter, loc, lb));
+          continue;
+        }
       }
+      newLowerBounds.push_back(lb);
+      newUpperBounds.push_back(ub);
+      newSteps.push_back(step);
     }
     // Exit if none of the loop dimensions perform a single iteration.
     if (newLowerBounds.size() == op.getLowerBound().size())
@@ -2694,23 +2828,6 @@
   }
 };
 
-/// Removes parallel loops in which at least one lower/upper bound pair consists
-/// of the same values - such loops have an empty iteration domain.
-struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
-  using OpRewritePattern<ParallelOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(ParallelOp op,
-                                PatternRewriter &rewriter) const override {
-    for (auto dim : llvm::zip(op.getLowerBound(), op.getUpperBound())) {
-      if (std::get<0>(dim) == std::get<1>(dim)) {
-        rewriter.replaceOp(op, op.getInitVals());
-        return success();
-      }
-    }
-    return failure();
-  }
-};
-
 struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
   using OpRewritePattern<ParallelOp>::OpRewritePattern;
 
@@ -2773,8 +2890,9 @@
 
 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
-  results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
-              MergeNestedParallelLoops>(context);
+  results
+      .add<ParallelOpSingleOrZeroIterationDimsFolder, MergeNestedParallelLoops>(
+          context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -381,18 +381,12 @@
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// it can be determined that the loop has a single iteration.
 LogicalResult mlir::promoteIfSingleIteration(scf::ForOp forOp) {
-  auto lbCstOp = forOp.getLowerBound().getDefiningOp<arith::ConstantIndexOp>();
-  auto ubCstOp = forOp.getUpperBound().getDefiningOp<arith::ConstantIndexOp>();
-  auto stepCstOp = forOp.getStep().getDefiningOp<arith::ConstantIndexOp>();
-  if (!lbCstOp || !ubCstOp || !stepCstOp || lbCstOp.value() < 0 ||
-      ubCstOp.value() < 0 || stepCstOp.value() < 0)
-    return failure();
-  int64_t tripCount =
-      mlir::ceilDiv(ubCstOp.value() - lbCstOp.value(), stepCstOp.value());
-  if (tripCount != 1)
+  std::optional<int64_t> tripCount = constantTripCount(
+      forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep());
+  if (!tripCount.has_value() || tripCount != 1)
     return failure();
   auto iv = forOp.getInductionVar();
-  iv.replaceAllUsesWith(lbCstOp);
+  iv.replaceAllUsesWith(forOp.getLowerBound());
   replaceIterArgsAndYieldResults(forOp);
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -10,6 +10,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
+#include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/APSInt.h"
 
 namespace mlir {
@@ -228,4 +229,24 @@
   return getValuesSortedByKeyImpl(keys, values, compare);
 }
 
+/// Return the number of iterations for a loop with a lower bound `lb`, upper
+/// bound `ub` and step `step`.
+std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
+                                         OpFoldResult step) {
+  if (lb == ub)
+    return 0;
+
+  std::optional<int64_t> lbConstant = getConstantIntValue(lb);
+  if (!lbConstant)
+    return std::nullopt;
+  std::optional<int64_t> ubConstant = getConstantIntValue(ub);
+  if (!ubConstant)
+    return std::nullopt;
+  std::optional<int64_t> stepConstant = getConstantIntValue(step);
+  if (!stepConstant)
+    return std::nullopt;
+
+  return mlir::ceilDiv(*ubConstant - *lbConstant, *stepConstant);
+}
+
 } // namespace mlir
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1544,3 +1544,110 @@
   return %result : tensor
 }
 // CHECK: forall (%{{.*}}, %{{.*}}) in (%{{.*}}, 10)
+
+// -----
+
+func.func @inline_forall_loop(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%i, %j) = (%c0, %c0) to (%c1, %c1)
+      step (%c8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+    %slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
+      : tensor<8x8xf32> to tensor<2x3xf32>
+    %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<2x3xf32>)
+      -> tensor<2x3xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %out_[%i, %j] [2, 3] [1, 1]
+        : tensor<2x3xf32> into tensor<8x8xf32>
+    }
+  }
+  return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @inline_forall_loop
+// CHECK-NOT: scf.forall
+// CHECK: %[[OUT:.*]] = tensor.empty
+
+// CHECK-NEXT: %[[SLICE:.*]] = tensor.extract_slice %[[OUT]]
+// CHECK-SAME: : tensor<8x8xf32> to tensor<2x3xf32>
+
+// CHECK-NEXT: %[[FILL:.*]] = linalg.fill
+// CHECK-SAME: outs(%[[SLICE]]
+
+// CHECK-NEXT: tensor.insert_slice %[[FILL]]
+// CHECK-SAME: : tensor<2x3xf32> into tensor<8x8xf32>
+
+// -----
+
+func.func @do_not_inline_distributed_forall_loop(
+    %in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%i, %j) = (0, 0) to (1, 1) step (8, 8)
+      shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+    %slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
+      : tensor<8x8xf32> to tensor<2x3xf32>
+    %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<2x3xf32>)
+      -> tensor<2x3xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %out_[%i, %j] [2, 3] [1, 1]
+        : tensor<2x3xf32> into tensor<8x8xf32>
+    }
+  }{ mapping = [#gpu.thread<y>, #gpu.thread<x>] }
+  return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @do_not_inline_distributed_forall_loop
+// CHECK: scf.forall
+
+// -----
+
+func.func @collapse_one_dim_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c16 = arith.constant 16 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%i, %j) = (0, %c0) to (1, %c16)
+      step (8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+    %fill = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>)
+      -> tensor<8x8xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %out_[%i, %j] [8, 8] [1, 1]
+        : tensor<8x8xf32> into tensor<8x8xf32>
+    }
+  }
+  return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @collapse_one_dim_parallel
+// CHECK: scf.forall (%[[ARG:.*]]) = (0) to (16) step (8)
+// CHECK: linalg.fill
+// CHECK: tensor.parallel_insert_slice
+
+// -----
+
+func.func @remove_empty_forall(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c16 = arith.constant 16 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%i, %j) = (%c0, %c16) to (%c1, %c16)
+      step (%c8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+    %fill = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>)
+      -> tensor<8x8xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %out_[%i, %j] [8, 8] [1, 1]
+        : tensor<8x8xf32> into tensor<8x8xf32>
+    }
+  }
+  return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @remove_empty_forall
+// CHECK-NOT: scf.forall
+// CHECK: %[[EMPTY:.*]] = tensor.empty
+// CHECK: return %[[EMPTY]]
+
diff --git a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
--- a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
+++ b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
@@ -86,7 +86,7 @@
 
 // CHECK-LABEL: func.func @parallel_insert_slice
 // CHECK-NOT: tensor.insert_slice
-// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<1x2xf32>
+// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[0, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<1x2xf32>
 func.func @parallel_insert_slice(%t0: tensor<1x2xf32>, %t1: tensor<f32>, %t2: tensor<1x1xf32>) -> tensor<1x2xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
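
Usage sketch (not part of the patch): a downstream pass could reuse the newly exposed scf::promoteIfSingleIteration entry point from an ordinary rewrite pattern. The pattern name InlineSingleIterationForall and its registration are hypothetical; only the helper's signature comes from this change.

#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/PatternMatch.h"

namespace {
// Hypothetical pattern: inline scf.forall loops whose dimensions all have a
// constant trip count of exactly one, using the helper added by this patch.
struct InlineSingleIterationForall
    : public mlir::OpRewritePattern<mlir::scf::ForallOp> {
  using mlir::OpRewritePattern<mlir::scf::ForallOp>::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::scf::ForallOp forallOp,
                  mlir::PatternRewriter &rewriter) const override {
    // Fails, leaving the op untouched, unless every dimension performs a
    // single iteration; otherwise the body is promoted into the parent block.
    return mlir::scf::promoteIfSingleIteration(rewriter, forallOp);
  }
};
} // namespace

The pattern would be registered like any other rewrite, e.g. patterns.add<InlineSingleIterationForall>(context), before running the greedy rewriter.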