diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -1463,11 +1463,122 @@
   }
 };
 
+struct ForallOpSingleOrZeroIterationDimsFolder
+    : public OpRewritePattern<ForallOp> {
+  using OpRewritePattern<ForallOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ForallOp op,
+                                PatternRewriter &rewriter) const override {
+    // Do not fold dimensions if they are mapped to processing units.
+    if (op.getMapping().has_value())
+      return failure();
+    Location loc = op.getLoc();
+
+    // Compute new loop bounds that omit all single-iteration loop dimensions.
+    SmallVector<OpFoldResult> newMixedLowerBounds, newMixedUpperBounds,
+        newMixedSteps;
+    IRMapping mapping;
+    for (auto [lowerBound, upperBound, step, iv] :
+         llvm::zip(op.getMixedLowerBound(), op.getMixedUpperBound(),
+                   op.getMixedStep(), op.getInductionVars())) {
+      // Collect the statically known loop bounds.
+      std::optional<int64_t> lowerBoundConstant =
+          getConstantIntValue(lowerBound);
+      std::optional<int64_t> upperBoundConstant =
+          getConstantIntValue(upperBound);
+      std::optional<int64_t> stepConstant = getConstantIntValue(step);
+      // Remove the loop if it performs zero iterations.
+      if (lowerBoundConstant && upperBoundConstant &&
+          *lowerBoundConstant == *upperBoundConstant) {
+        rewriter.replaceOp(op, op.getOutputs());
+        return success();
+      }
+      // Replace the loop induction variable by the lower bound if the loop
+      // performs a single iteration. Otherwise, copy the loop bounds.
+      if (lowerBoundConstant && upperBoundConstant && stepConstant &&
+          (*upperBoundConstant - *lowerBoundConstant) > 0 &&
+          (*upperBoundConstant - *lowerBoundConstant) <= *stepConstant) {
+        mapping.map(iv,
+                    getValueOrCreateConstantIndexOp(rewriter, loc, lowerBound));
+      } else {
+        newMixedLowerBounds.push_back(lowerBound);
+        newMixedUpperBounds.push_back(upperBound);
+        newMixedSteps.push_back(step);
+      }
+    }
+    // Exit if none of the loop dimensions perform a single iteration.
+    if (newMixedLowerBounds.size() == static_cast<unsigned>(op.getRank())) {
+      return failure();
+    }
+
+    // All of the loop dimensions perform a single iteration. Inline loop body.
+    if (newMixedLowerBounds.empty()) {
+      mapping.map(op.getOutputBlockArguments(), op.getOutputs());
+      for (auto &bodyOp : op.getBody()->without_terminator())
+        rewriter.clone(bodyOp, mapping);
+      SmallVector<Value> results;
+      results.reserve(op.getResults().size());
+      scf::InParallelOp terminator = op.getTerminator();
+      for (auto &yieldingOp : terminator.getYieldingOps()) {
+        auto parallelInsertSliceOp =
+            cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+
+        Value dst = parallelInsertSliceOp.getDest();
+        Value src = parallelInsertSliceOp.getSource();
+
+        auto getMappedValues = [&](ValueRange values) {
+          return llvm::to_vector(llvm::map_range(values, [&](Value value) {
+            return mapping.lookupOrDefault(value);
+          }));
+        };
+
+        Value srcVal = mapping.lookupOrDefault(src);
+        if (srcVal.getType().isa<TensorType>()) {
+          results.push_back(rewriter.create<tensor::InsertSliceOp>(
+              op.getLoc(), dst.getType(), srcVal, mapping.lookupOrDefault(dst),
+              getMappedValues(parallelInsertSliceOp.getOffsets()),
+              getMappedValues(parallelInsertSliceOp.getSizes()),
+              getMappedValues(parallelInsertSliceOp.getStrides()),
+              parallelInsertSliceOp.getStaticOffsets(),
+              parallelInsertSliceOp.getStaticSizes(),
+              parallelInsertSliceOp.getStaticStrides()));
+        }
+      }
+      rewriter.replaceOp(op, results);
+      return success();
+    }
+
+    // Replace the loop by a lower-dimensional loop.
+    ForallOp newOp;
+    newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds,
+                                      newMixedUpperBounds, newMixedSteps,
+                                      op.getOutputs(), std::nullopt, nullptr);
+    newOp.getBodyRegion().getBlocks().clear();
+    // The new loop needs to keep all attributes from the old one, except for
+    // "operand_segment_sizes" and static loop bound attributes which capture
+    // the outdated information of the old iteration domain.
+    SmallVector<StringAttr> elidedAttrs{newOp.getOperandSegmentSizesAttrName(),
+                                        newOp.getStaticLowerBoundAttrName(),
+                                        newOp.getStaticUpperBoundAttrName(),
+                                        newOp.getStaticStepAttrName()};
+    for (const auto &namedAttr : op->getAttrs()) {
+      if (llvm::is_contained(elidedAttrs, namedAttr.getName()))
+        continue;
+      newOp->setAttr(namedAttr.getName(), namedAttr.getValue());
+    }
+    rewriter.cloneRegionBefore(op.getRegion(), newOp.getRegion(),
+                               newOp.getRegion().begin(), mapping);
+    rewriter.replaceOp(op, newOp.getResults());
+    return success();
+  }
+};
+
 } // namespace
 
 void ForallOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<DimOfForallOp, ForallOpControlOperandsFolder>(context);
+  results.add<DimOfForallOp, ForallOpControlOperandsFolder,
+              ForallOpSingleOrZeroIterationDimsFolder>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -1544,3 +1544,110 @@
   return %result : tensor<?x10xf32>
 }
 // CHECK: forall (%{{.*}}, %{{.*}}) in (%{{.*}}, 10)
+
+// -----
+
+func.func @inline_forall_loop(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%i, %j) = (%c0, %c0) to (%c1, %c1)
+          step (%c8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+    %slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
+      : tensor<8x8xf32> to tensor<2x3xf32>
+    %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<2x3xf32>)
+              -> tensor<2x3xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %out_[%i, %j] [2, 3] [1, 1]
+        : tensor<2x3xf32> into tensor<8x8xf32>
+    }
+  }
+  return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @inline_forall_loop
+// CHECK-NOT: scf.forall
+// CHECK: %[[OUT:.*]] = tensor.empty
+
+// CHECK-NEXT: %[[SLICE:.*]] = tensor.extract_slice %[[OUT]]
+// CHECK-SAME: : tensor<8x8xf32> to tensor<2x3xf32>
+
+// CHECK-NEXT: %[[FILL:.*]] = linalg.fill
+// CHECK-SAME: outs(%[[SLICE]]
+
+// CHECK-NEXT: tensor.insert_slice %[[FILL]]
+// CHECK-SAME: : tensor<2x3xf32> into tensor<8x8xf32>
+
+// -----
+
+func.func @do_not_inline_distributed_forall_loop(
+    %in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%i, %j) = (0, 0) to (1, 1) step (8, 8)
+          shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+    %slice = tensor.extract_slice %out_[%i, %j] [2, 3] [1, 1]
+      : tensor<8x8xf32> to tensor<2x3xf32>
+    %fill = linalg.fill ins(%cst : f32) outs(%slice : tensor<2x3xf32>)
+              -> tensor<2x3xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %out_[%i, %j] [2, 3] [1, 1]
+        : tensor<2x3xf32> into tensor<8x8xf32>
+    }
+  }{ mapping = [#gpu.thread<x>, #gpu.thread<y>] }
+  return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @do_not_inline_distributed_forall_loop
+// CHECK: scf.forall
+
+// -----
+
+func.func @collapse_one_dim_parallel(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c16 = arith.constant 16 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%i, %j) = (0, %c0) to (1, %c16)
+          step (8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+    %fill = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>)
+              -> tensor<8x8xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %out_[%i, %j] [8, 8] [1, 1]
+        : tensor<8x8xf32> into tensor<8x8xf32>
+    }
+  }
+  return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @collapse_one_dim_parallel
+// CHECK: scf.forall (%[[ARG:.*]]) = (0) to (16) step (8)
+// CHECK: linalg.fill
+// CHECK: tensor.parallel_insert_slice
+
+// -----
+
+func.func @remove_empty_forall(%in: tensor<8x8xf32>) -> tensor<8x8xf32> {
+  %c8 = arith.constant 8 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c16 = arith.constant 16 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%i, %j) = (%c0, %c16) to (%c1, %c16)
+          step (%c8, %c8) shared_outs (%out_ = %0) -> (tensor<8x8xf32>) {
+    %fill = linalg.fill ins(%cst : f32) outs(%out_ : tensor<8x8xf32>)
+              -> tensor<8x8xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %fill into %out_[%i, %j] [8, 8] [1, 1]
+        : tensor<8x8xf32> into tensor<8x8xf32>
+    }
+  }
+  return %1 : tensor<8x8xf32>
+}
+// CHECK-LABEL: @remove_empty_forall
+// CHECK-NOT: scf.forall
+// CHECK: %[[EMPTY:.*]] = tensor.empty
+// CHECK: return %[[EMPTY]]
+
diff --git a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
--- a/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
+++ b/mlir/test/Dialect/Tensor/fold-consecutive-insert-extract-slice.mlir
@@ -86,7 +86,7 @@
 
 // CHECK-LABEL: func.func @parallel_insert_slice
 // CHECK-NOT: tensor.insert_slice
-// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<1x2xf32>
+// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[0, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<1x2xf32>
 func.func @parallel_insert_slice(%t0: tensor<1x2xf32>, %t1: tensor<f32>, %t2: tensor<1x1xf32>) -> tensor<1x2xf32> {
   %c1 = arith.constant 1 : index
   %c2 = arith.constant 2 : index
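
The new folder is exercised through the regular canonicalizer, e.g. `mlir-opt --canonicalize` on any of the test functions above, and the added tests run as part of check-mlir. The snippet below is a minimal sketch, not part of this patch, of how a downstream pass could apply the same scf.forall canonicalization patterns programmatically; the helper name foldForallIterationDims is made up for illustration, while the MLIR APIs used are existing ones.

// Minimal sketch (illustrative, not part of this patch): collect all
// registered scf.forall canonicalization patterns, including the new
// ForallOpSingleOrZeroIterationDimsFolder, and apply them greedily to a
// module.
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static mlir::LogicalResult foldForallIterationDims(mlir::ModuleOp module) {
  mlir::MLIRContext *ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);
  // The hook updated in this patch; it now also registers the
  // single/zero-iteration-dimension folder.
  mlir::scf::ForallOp::getCanonicalizationPatterns(patterns, ctx);
  // Run the greedy pattern rewrite driver until a fixed point is reached.
  return mlir::applyPatternsAndFoldGreedily(module, std::move(patterns));
}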