diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -18,7 +18,6 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Interfaces/TilingInterface.h" -#include "mlir/Parser/Parser.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/StringSet.h" @@ -226,78 +225,167 @@ // FuseIntoContainingOp //===----------------------------------------------------------------------===// -static FailureOr> tileAndFuse(Operation *producerOp, - Operation *containingOp, - RewriterBase &rewriter) { +/// Find the first "extract" user of `producerOp` and tile it right before its +/// use. The tiled op is fused under the `containingOp`. +/// Return this fused op on success or nullptr if anything fails. +static Operation *tileAndFuseFirstExtractUse(Operation *producerOp, + Operation *containingOp, + RewriterBase &rewriter) { auto tileableProducer = dyn_cast(producerOp); if (!tileableProducer) - return failure(); + return nullptr; // Search the producer slices accessed within the containing operation. - // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe + // TODO: Generalize to more extract/insert/parallel_insert triples, maybe // evolve into an interface. - SmallVector sliceOps; - for (Operation *user : tileableProducer->getUsers()) { + auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) { auto sliceOp = dyn_cast(user); - if (!sliceOp) - continue; - if (!containingOp->isProperAncestor(sliceOp)) - continue; - sliceOps.push_back(sliceOp); - } + return sliceOp && containingOp->isProperAncestor(sliceOp); + }); - // Check for a non-empty list of fusion opportunities. 
- if (sliceOps.empty()) - return failure(); + // Find a fusion opportunity. + if (it == tileableProducer->getUsers().end()) + return nullptr; + auto sliceOpToTile = cast(*it); // Try to fuse the producer in-place. - SmallVector fusedOps; - for (tensor::ExtractSliceOp sliceOp : sliceOps) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(sliceOp); - - // Tile the producer. - FailureOr tiledProducer = tileableProducer.generateResultTileValue( - rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(), - sliceOp.getMixedSizes()); - if (failed(tiledProducer)) - return failure(); - fusedOps.push_back(tiledProducer->getDefiningOp()); - } + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(sliceOpToTile); + + // Tile the producer. + FailureOr tiledProducer = tileableProducer.generateResultTileValue( + rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(), + sliceOpToTile.getMixedSizes()); + if (failed(tiledProducer)) + return nullptr; + + // Replace the extract op. + Operation *fusedOp = tiledProducer->getDefiningOp(); + rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0)); + return fusedOp; +} + +/// First, find the first "scf::ForeachThreadOp" user of `producerOp` and ensure +/// it is exactly the `containingOp`, otherwise bail. +/// Then, find the first "extract" user of the tied block argument and tile it +/// right before its "extract" use. The tiled op is fused under the +/// `containingOp`. +/// Return this fused op on success or nullptr if anything fails. +static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( + Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) { + + auto tileableProducer = dyn_cast(producerOp); + if (!tileableProducer) + return nullptr; + + // Search the first use by a "scf::ForeachThreadOp" user. 
+ scf::ForeachThreadOp foreachThreadOp; + auto itProducerUses = + llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) { + foreachThreadOp = dyn_cast(use.getOwner()); + return foreachThreadOp; + }); + // If it's not from the containing op, return. + if (!foreachThreadOp || foreachThreadOp != containingOp) + return nullptr; + + // Retrieve the block argument of the containing op that is tied to the + // producer use found above. Inside the containing op, the producer is + // accessed through this block argument rather than through the producer + // result itself. + OpOperand *pUse = &(*itProducerUses); + BlockArgument bbArg = foreachThreadOp.getTiedBlockArgument(pUse); + + // Search the producer slices accessed within the containing operation. + // TODO: Generalize to more extract/insert/parallel_insert triples, maybe + // evolve into an interface. + auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) { + auto sliceOp = dyn_cast(user); + return sliceOp && containingOp->isProperAncestor(sliceOp); + }); + + // Find a fusion opportunity. + if (itBBArgUsers == bbArg.getUsers().end()) + return nullptr; + auto sliceOpToTile = cast(*itBBArgUsers); + + // Ensure `tileableProducer` has exactly one destination operand that we can + // replace the ForeachThreadOp bbArg with. + auto destinationOperands = tileableProducer.getDestinationOperands(rewriter); + if (destinationOperands.size() != 1) + return nullptr; + + // Try to fuse the producer in-place. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(sliceOpToTile); + + // Replace the use in the tileableProducer before tiling: clone, replace and + // then tile. + BlockAndValueMapping bvm; + bvm.map(destinationOperands.front(), bbArg); + auto tileableProducerClone = + cast(rewriter.clone(*tileableProducer, bvm)); + auto scopeGuard = + llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); }); + + // Tile the producer. 
+ FailureOr tiledProducer = + tileableProducerClone.generateResultTileValue( + rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(), + sliceOpToTile.getMixedSizes()); + if (failed(tiledProducer)) + return nullptr; // Replace the extract op. - for (const auto &en : enumerate(sliceOps)) - rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0)); - return fusedOps; + Operation *fusedOp = tiledProducer->getDefiningOp(); + rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0)); + + // Replace the use in containingOp. + rewriter.updateRootInPlace(containingOp, [&]() { + containingOp->setOperand(pUse->getOperandNumber(), + destinationOperands.front()); + }); + + return fusedOp; } -static FailureOr> -cloneAndFuse(Operation *producerOp, Operation *containingOp, - RewriterBase &rewriter) { +static Operation *cloneAndFuseFirstUse(Operation *producerOp, + Operation *containingOp, + RewriterBase &rewriter) { // Gather all uses inside the containing op. SmallVector uses; - for (OpResult result : producerOp->getOpResults()) - for (OpOperand &use : result.getUses()) - if (containingOp->isProperAncestor(use.getOwner())) + for (OpResult result : producerOp->getOpResults()) { + for (OpOperand &use : result.getUses()) { + if (containingOp->isProperAncestor(use.getOwner())) { uses.push_back(&use); + continue; + } + // Cannot clone and fuse if the use is by the containing op itself: fail + // immediately. + if (containingOp == use.getOwner()) + return nullptr; + } + } // Check for a non-empty list of fusion opportunities. if (uses.empty()) - return failure(); + return nullptr; // Clone and fuse inside the containing op. 
- SmallVector fusedOps; - for (OpOperand *use : uses) { - unsigned resultNumber = use->get().cast().getResultNumber(); - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(use->getOwner()); - Operation *cloned = rewriter.clone(*producerOp); - rewriter.updateRootInPlace( - use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); }); - fusedOps.push_back(cloned); - } - - return fusedOps; + Operation *fusedOp = nullptr; + OpOperand *use = uses.front(); + // Parallel insert slice is not a valid clone destination. + // TODO: Generalize to other type of ops. + assert(!isa(use->getOwner()) && + "Parallel insert slice is not a valid clone destination"); + unsigned resultNumber = use->get().cast().getResultNumber(); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(use->getOwner()); + fusedOp = rewriter.clone(*producerOp); + rewriter.updateRootInPlace( + use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); }); + + return fusedOp; } DiagnosedSilenceableFailure @@ -312,7 +400,7 @@ } for (Operation *producerOp : producerOps) { if (producerOp->getNumResults() != 1) { - Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note); + Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark); diag << "op with != 1 results not supported"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } @@ -331,15 +419,17 @@ auto getNextProducer = [&]() -> FailureOr { for (const auto &it : enumerate(remainingProducers)) { Operation *producerOp = it.value(); - bool hasUseInContainingOp = - any_of(producerOp->getUsers(), [&](Operation *op) { - return containingOp->isProperAncestor(op); + // The containing op may be a user of producerOp: use isAncestor. 
+ int64_t numUsesInContainingOp = + llvm::count_if(producerOp->getUsers(), [&](Operation *op) { + return containingOp->isAncestor(op); }); - // TODO: When resolving the TODO below (no duplicate ops), take an op that - // has no use among the remaining producers. This is a topological + // TODO: When resolving the TODO below (no duplicate ops), take an op + // that has no use among the remaining producers. This is a topological // sorting. - if (hasUseInContainingOp) { - remainingProducers.erase(remainingProducers.begin() + it.index()); + if (numUsesInContainingOp > 0) { + if (numUsesInContainingOp == 1) + remainingProducers.erase(remainingProducers.begin() + it.index()); return producerOp; } } @@ -350,29 +440,42 @@ while (!remainingProducers.empty()) { auto nextProducer = getNextProducer(); if (failed(nextProducer)) { - Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note); + Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark); diag << "could not fuse ops into container"; return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } Operation *producerOp = *nextProducer; - // TODO: If there are multiple uses of the producer in the containing op, we - // currently tile/clone the op multiple times (once per use). In some cases, - // we can tile/clone once and reuse the value for each use. Futhermore, - // producers should then be traversed according to a topological sorting. 
- auto tiled = tileAndFuse(producerOp, containingOp, rewriter); - if (succeeded(tiled)) - fusedOps.append(*tiled); - - auto cloned = cloneAndFuse(producerOp, containingOp, rewriter); - if (succeeded(cloned)) - fusedOps.append(*cloned); - - if (failed(tiled) && failed(cloned)) { - Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note); - diag << "could not fuse into containing op"; - return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + // TODO: If there are multiple uses of the producer in the containing op, + // we currently tile/clone the op multiple times (once per use). In some + // cases, we can tile/clone once and reuse the value for each use. + // Furthermore, producers should then be traversed according to a + // topological sorting. + Operation *tiled = + tileAndFuseFirstExtractUse(producerOp, containingOp, rewriter); + if (tiled) { + fusedOps.push_back(tiled); + continue; + } + + Operation *tiledContainingOpOperand = + tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( + producerOp, containingOp, rewriter); + if (tiledContainingOpOperand) { + fusedOps.push_back(tiledContainingOpOperand); + continue; } + + Operation *cloned = + cloneAndFuseFirstUse(producerOp, containingOp, rewriter); + if (cloned) { + fusedOps.push_back(cloned); + continue; + } + + Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark); + diag << "could not fuse " << *producerOp << " into " << *containingOp; + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); } results.set(getFusedOp().cast(), fusedOps); @@ -626,9 +729,9 @@ extractFromI64ArrayAttr(getPaddingDimensions()); if (any_of(paddingDimensions, [](int64_t paddingDimension) { return paddingDimension < 0; })) { - return emitOpError() - << "expects padding_dimensions to contain positive integers, found " - << getPaddingDimensions(); + return emitOpError() << "expects padding_dimensions to contain positive " + "integers, found " + << getPaddingDimensions(); } 
SmallVector hoistPaddings = @@ -699,8 +802,8 @@ transform::TransformState &state) { LinalgTilingOptions tilingOptions; tilingOptions.scalarizeDynamicDims(); - // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile - // sizes and asserts that it is not already set. + // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the + // tile sizes and asserts that it is not already set. SmallVector emptyTileSizes; LinalgTilingPattern pattern(getContext(), tilingOptions); SimpleRewriter rewriter(getContext()); @@ -847,8 +950,8 @@ if ((static_cast(getStaticSplitPoint()) != ShapedType::kDynamicSize) ^ (getDynamicSplitPoint() == nullptr)) { - return emitOpError() - << "expects either a dynamic or a static split point to be provided"; + return emitOpError() << "expects either a dynamic or a static split " + "point to be provided"; } return success(); } @@ -1225,8 +1328,8 @@ //===----------------------------------------------------------------------===// namespace { -/// Registers new ops and declares PDL as dependent dialect since the additional -/// ops are using PDL types for operands and results. +/// Registers new ops and declares PDL as dependent dialect since the +/// additional ops are using PDL types for operands and results. 
class LinalgTransformDialectExtension : public transform::TransformDialectExtension< LinalgTransformDialectExtension> { diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -99,3 +99,51 @@ } } } + +// ----- + +#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)> +#map1 = affine_map<(d0)[s0] -> (d0 * s0)> +#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)> + +module { + // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg + // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index + // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor + // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor + func.func @fuse_tileable_op_through_bbarg(%arg0: index, %arg1: tensor, %arg2: tensor) -> tensor { + %cst = arith.constant 4.200000e+01 : f32 + %c0 = arith.constant 0 : index + %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor) -> tensor + %d0 = tensor.dim %arg1, %c0 : tensor + %1 = affine.apply #map0()[%d0, %arg0] + + // CHECK: scf.foreach_thread {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[OUT]]) -> (tensor) { + %2 = scf.foreach_thread (%arg3) in (%1) shared_outs(%o = %0) -> (tensor) { + %3 = affine.apply #map1(%arg3)[%arg0] + %4 = affine.min #map2(%arg3)[%d0, %arg0] + %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor to tensor + + // CHECK: %[[T0:.*]] = tensor.extract_slice %[[BBARGOUT]][%{{.*}}] [%{{.*}}] [{{.*}}] + // CHECK: %[[T1:.*]] = linalg.fill {{.*}} outs(%[[T0]] + %6 = tensor.extract_slice %arg1[%3] [%4] [1] : tensor to tensor + + // CHECK: %[[T2:.*]] = linalg.elemwise_unary {{.*}} outs(%[[T1]] + %7 = linalg.elemwise_unary ins(%6 : tensor) outs(%5 : tensor) -> tensor + scf.foreach_thread.perform_concurrently { + tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor into tensor + } + } + // CHECK: } + func.return %2 : tensor + } + + transform.sequence 
failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 + %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1 + + // linalg.fill is tileable. The op is tiled and fused. + transform.structured.fuse_into_containing_op %0 into %1 + } +}