diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -63,6 +63,9 @@ struct SCFTilingResult { Operation *tiledOp; SmallVector loops; + llvm::SmallBitVector tiledLoops; // Loops whose tile size is not 0. + SmallVector tileOffsets; // Iteration-space offsets of the tile. + SmallVector tileSizes; // Iteration-space sizes of the tile. }; /// Pattern to tile an op that implements the `TilingInterface` using @@ -106,23 +109,92 @@ /// it is legal to fuse the fill with the matmul only if the matmul is tiled /// along the parallel dimensions and not the reduction dimension, i.e. the tile /// size for the reduction dimension should be 0. + +struct SCFTileAndFuseOptions { + SCFTilingOptions tilingOptions; + SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) { + tilingOptions = options; + return *this; + } + + /// When the access pattern between producer and consumer is such that + /// the producer is fused with the consumer without the producer being + /// recomputed, then it is possible to yield the value of the producer from + /// within the loop nest. For example, consider + /// + /// ```mlir + /// %0 = linalg.matmul ins(%lhs0, %rhs0) outs(%init0) + /// %1 = linalg.matmul ins(%0, %rhs1) outs(%init1) + /// ``` + /// + /// If the tile sizes chosen are such that the second `linalg.matmul` + /// is tiled along the outer two dimensions of the op, then fusing + /// the first `linalg.matmul` using tile and fuse results in + /// recomputation of parts of the fused producer during computation + /// of different tiles of the consumer. Instead, if only the outer dimension + /// is chosen, then the producer is not recomputed.
+ /// + /// ```mlir + /// scf.for %iv0 = + /// %lhs0_slice = tensor.extract_slice %lhs0[%iv0, 0] + /// %rhs0_slice = tensor.extract_slice %rhs0[0, 0] + /// %init0_slice = tensor.extract_slice %init0[%iv0, 0] + /// %0 = linalg.matmul ins(%lhs0_slice, %rhs0_slice) outs(%init0_slice) + /// %rhs1_slice = tensor.extract_slice %rhs1[0, 0] + /// %init1_slice = tensor.extract_slice %init1[%iv0, 0] + /// %1 = linalg.matmul ins(%0, %rhs1_slice) outs(%init1_slice) + /// ``` + /// + /// If needed, the value of the untiled first matmul can be reconstructed + /// using the tiled and fused operation (similar to how the replacement of the + /// consumer is done). + /// It is unclear how to automatically determine when the producer is + /// recomputed and when it is not (especially through the `TilingInterface`). + /// So for now this is left as an option for the caller which is expected to + /// set the tile sizes appropriately to ensure the producer is not recomputed. + /// With this, based on the uses of the untiled producer, a replacement value + /// for this is yielded by the tiled loop nest.
+ /// + /// One way to check this when the producer and consumer + /// are LinalgOps is the following (illustrative, untested) sketch: + /// ```cpp + /// bool canProducerBeFusedWithoutRedundantComputation( + /// llvm::SmallBitVector tiledLoops, OpOperand *fusedOperand) { + /// auto consumerOp = cast(fusedOperand->getOwner()); + /// AffineMap consumerIndexingMap = + /// consumerOp.getTiedIndexingMap(fusedOperand); + /// AffineMap projectedConsumerMap = + /// getProjectedMap(consumerIndexingMap, tiledLoops); + /// // Only the consumer's indexing map is checked in this sketch; + /// // confirm the projection semantics before relying on it. + /// return projectedConsumerMap.isIdentity(); + /// } + /// ``` + bool producerCanBeFusedWithoutRedundantComputations = false; + SCFTileAndFuseOptions & + setProducerCanBeFusedWithoutRedundantComputations(bool val) { + producerCanBeFusedWithoutRedundantComputations = val; + return *this; + } +}; struct SCFTileAndFuseResult { + SmallVector fusableProducers; // Untiled transitive producers, defs first. SmallVector tiledAndFusedOps; SmallVector loops; }; -struct TileConsumerAndFuseProducersUsingSCFForOp + +struct TileConsumerAndFuseProducersGreedilyUsingSCFForOp : public OpInterfaceRewritePattern { /// Construct a generic pattern applied to all TilingInterface ops. - TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, - SCFTilingOptions options, - PatternBenefit benefit = 1); + TileConsumerAndFuseProducersGreedilyUsingSCFForOp( + MLIRContext *context, SCFTileAndFuseOptions options, + PatternBenefit benefit = 1); /// Construct a generic pattern applied to `opName`. - TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, - MLIRContext *context, - SCFTilingOptions options, - PatternBenefit benefit = 1); + TileConsumerAndFuseProducersGreedilyUsingSCFForOp( + StringRef opName, MLIRContext *context, SCFTileAndFuseOptions options, + PatternBenefit benefit = 1); /// `matchAndRewrite` implementation that returns the significant transformed /// pieces of IR. @@ -135,10 +207,7 @@ } private: - /// This pattern uses the tiling pattern.
Instead of using inheritance, use - /// the patterns as private object that is instantiated at the same time as - /// this pattern. - TileUsingSCFForOp tilingPattern; + SCFTileAndFuseOptions tileAndFuseOptions; // Tiling and fusion options. }; /// Pattern to lower operations that implement the `TilingInterface` to @@ -158,6 +227,18 @@ } }; +/// Pattern to remove unused yields from `scf.for` operations. This is a cleanup +/// pattern that is to be used in conjunction with the Tiling and Tile+Fuse +/// patterns above. It could be a canonicalization of `scf.for` in general. The +/// pattern accounts for the `tensor.insert_slice` that is generated by the +/// above patterns, which might not be something that holds in general. +struct RemoveDeadYields : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const override; +}; + } // namespace scf } // namespace mlir diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/TilingInterface.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "tile-using-interface" @@ -167,22 +168,182 @@ return loops; } -scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)) {} +/// For a value that is to be `yieldedValue`, get the `initValue` to use +/// after tile + fuse. For example, after +/// +/// ```mlir +/// %0 = linalg.init_tensor ... +/// %1 = linalg.fill ins(...) outs(%0 : ...) +/// %2 = linalg.matmul ins(...) outs(%1 : ...)
+/// ``` +/// +/// is tiled and fused +/// +/// ```mlir +/// %0 = linalg.init_tensor +/// ... = scf.for ... +/// ... = scf.for ... +/// %1 = tensor.extract_slice %0 ... +/// %2 = linalg.fill ins(...) outs(%1 : ...) +/// %3 = linalg.matmul ins(...) outs(%2 : ...) +/// ``` +/// +/// Tracks through operations that are in destination passing style to get to +/// the slice of the value that is not from within the loop. This forms the new +/// destination to use for the yielded value and is used as the `init` of the +/// loop. Returns `llvm::None` if such a value cannot be found. +static Optional getInitValue(OpBuilder &builder, Value yieldedValue, + scf::ForOp outerMost) { + auto getOwner = [](Value v) -> Operation * { + if (auto op = v.getDefiningOp()) + return op; + return v.cast().getOwner()->getParentOp(); + }; + while (outerMost->isAncestor(getOwner(yieldedValue))) { + if (auto blockArg = yieldedValue.dyn_cast()) { + if (auto loop = + dyn_cast(blockArg.getOwner()->getParentOp())) { + yieldedValue = loop.getOpOperandForRegionIterArg(blockArg).get(); + continue; + } + // Unable to find init value. + return llvm::None; + } + Operation *definingOp = yieldedValue.getDefiningOp(); + Optional nextVal = + TypeSwitch>(definingOp) + .Case([&](TilingInterface interfaceOp) -> Value { + return interfaceOp.getDestinationOperands( + builder)[yieldedValue.cast().getResultNumber()]; + }) + .Case( + [&](auto sourceOp) -> Value { return sourceOp.getSource(); }) + .Default([](Operation *) { return llvm::None; }); + if (!nextVal) + return llvm::None; + yieldedValue = nextVal.value(); + } + return yieldedValue; +} -scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, - MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)) {} +/// Given a list of `untiledOps` that are the original untiled operations, and +/// their tiled counterparts (`tiledOps`) yield the values of the results +/// through the generated `tilingLoops`.
The result of the outermost loop forms +/// a replacement for the `untiledOps`. This method performs the replacement as +/// well, except for uses in `tensor.dim` ops and in other `untiledOps`. +static LogicalResult yieldTiledValues(RewriterBase &rewriter, + ArrayRef untiledOps, + ArrayRef tiledOps, + ArrayRef tileOffsets, + ArrayRef tileSizes, + MutableArrayRef tilingLoops) { + // `scf.for` with tensor semantics requires the loop nest to yield the + // replacement values using destructive updates. Use the `TilingInterface` + // to get the position of the result tiles and use that to generate the + // destructive update pattern, i.e., + // + // ```mlir + // scf.for %iv0 = ... { + // %0 = tiled_op + // } + // ``` + // + // is transformed to + // + // ```mlir + // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { + // %0 = tiled_op + // %1 = tensor.insert_slice %0 into %arg[..] [..] [..] + // scf.yield %1 + // } + // ``` + NewYieldValueFn yieldValueFn = + [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) -> SmallVector { + SmallVector yieldedValues; + Attribute one = b.getIndexAttr(1); + Attribute zero = b.getIndexAttr(0); + unsigned bbArgNum = 0; + for (auto it : llvm::enumerate(untiledOps)) { + TilingInterface untiledOp = it.value(); + SmallVector opTileOffsets(tileOffsets), + opTileSizes(tileSizes); + unsigned opRank = untiledOp.getLoopIteratorTypes().size(); + opTileOffsets.resize(opRank, zero); + opTileSizes.resize(opRank, zero); + for (auto source : llvm::enumerate(tiledOps[it.index()]->getResults())) { + SmallVector resultTileOffsets, resultTileSizes; + if (failed(untiledOp.getResultTilePosition( + b, source.index(), opTileOffsets, opTileSizes, + resultTileOffsets, resultTileSizes))) { + return {}; + } + + SmallVector resultTileStrides(resultTileOffsets.size(), + one); + Value yieldedValue = b.create( + loc, source.value(), newBBArgs[bbArgNum++], resultTileOffsets, + resultTileSizes, resultTileStrides); + yieldedValues.push_back(yieldedValue); + } + } + return yieldedValues; + }; -FailureOr
scf::TileUsingSCFForOp::returningMatchAndRewrite( - TilingInterface op, PatternRewriter &rewriter) const { + // Get the init value to use for each of the tiled op results. + SmallVector initValues; + for (auto tiledOp : llvm::enumerate(tiledOps)) { + for (auto result : llvm::enumerate(tiledOp.value()->getResults())) { + Optional initValue = + getInitValue(rewriter, result.value(), tilingLoops.front()); + if (!initValue) { + // Always valid to fall back to use the destination of the untiled ops. + TilingInterface untiledOp = untiledOps[tiledOp.index()]; + initValue = untiledOp.getDestinationOperands(rewriter)[result.index()]; + } + initValues.push_back(initValue.value()); + } + } + + // Modify the loop nest to yield the result values. + SmallVector newLoops = replaceLoopNestWithNewYields( + rewriter, tilingLoops, initValues, yieldValueFn); + for (const auto &loop : llvm::enumerate(tilingLoops)) { + rewriter.eraseOp(loop.value()); + tilingLoops[loop.index()] = newLoops[loop.index()]; + } + scf::ForOp outerMost = tilingLoops.front(); + llvm::SmallDenseSet untiledOpsSet; + untiledOpsSet.insert(untiledOps.begin(), untiledOps.end()); + unsigned resultNum = 0; + for (auto untiledOp : untiledOps) { + if (untiledOp->getNumResults() + resultNum > outerMost.getNumResults()) { + return rewriter.notifyMatchFailure( + untiledOp, "failed to yield results of operation"); + } + // Replace all uses of the untiled op with values returned from the + // loop except for in `tensor.dim` operations or other ops that are fused + // here. Those need to be resolved separately. + // TODO: Find a better way to handle replacements.
+ rewriter.replaceOpWithIf( + untiledOp, + outerMost.getResults().drop_front(resultNum).take_front( + untiledOp->getNumResults()), + [&](OpOperand &use) { + Operation *user = use.getOwner(); + return !untiledOpsSet.count(user) && + !isa(use.getOwner()); + }); + resultNum += untiledOp->getNumResults(); + } + return success(); +} + +/// Implementation of tiling transformation of `op` that implements the +/// `TilingInterface` using `scf.for` to iterate over the tiles. +static FailureOr +tileConsumer(RewriterBase &rewriter, TilingInterface op, + scf::SCFTilingOptions options) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(op); @@ -211,6 +372,11 @@ } scf::SCFTilingResult tilingResult; + tilingResult.tiledLoops.resize(numLoops); + for (auto tileSize : llvm::enumerate(tileSizeVector)) + if (!isConstantIntValue(tileSize.value(), 0)) + tilingResult.tiledLoops.set(tileSize.index()); + SmallVector offsets, sizes; { // If there is an interchange specified, permute the iteration domain and @@ -264,6 +430,8 @@ op, "expected tiled implementation to return a single op"); } tilingResult.tiledOp = tiledImplementation[0]; + std::swap(tilingResult.tileOffsets, offsets); + std::swap(tilingResult.tileSizes, sizes); LLVM_DEBUG({ if (!tilingResult.loops.empty()) { @@ -274,130 +442,216 @@ }); } + return tilingResult; +} + +scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, + scf::SCFTilingOptions options, + PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)) {} + +scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, + MLIRContext *context, + scf::SCFTilingOptions options, + PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)) {} + +FailureOr +scf::TileUsingSCFForOp::returningMatchAndRewrite( + TilingInterface op, PatternRewriter &rewriter) const { + + FailureOr tilingResult = + tileConsumer(rewriter, op, options); + if
(failed(tilingResult)) + return rewriter.notifyMatchFailure(op, "failed to tile operation"); + + // If there are no results (i.e. buffer semantics), there is nothing to do. + // Erase op and return. if (op->getNumResults() == 0) { rewriter.eraseOp(op); return tilingResult; } - // 5. If the original operations has results, modify the loop nest to yield - // the replacement values. - SmallVector replacements; - if (tilingResult.loops.empty()) { - // 5a. If there were no loops, the tiled implementation results are the - // replacements. - rewriter.replaceOp(op, tilingResult.tiledOp->getResults()); + // If there are no loops, there is nothing more to do. + if (tilingResult->loops.empty()) { + // Replace the original op with the tiled op. + rewriter.replaceOp(op, tilingResult->tiledOp->getResults()); return tilingResult; } - // 5b. `scf.for` with tensor semantics requires the loop nest to yield the - // replacement values using destructive updates. Use the `TilingInterface` - // to get the position of the result tiles and use that to generate the - // destructive update pattern, i.e., - // - // ```mlir - // scf.for %iv0 = ... { - // %0 = tiled_op - // } - // ``` - // - // is transformed to - // - // ```mlir - // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { - // %0 = tiled_op - // %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
- // scf.yield %1 - // } - // ``` - NewYieldValueFn yieldValueFn = - [&](OpBuilder &b, Location loc, - ArrayRef newBBArgs) -> SmallVector { - SmallVector yieldedValues; - Attribute one = b.getIndexAttr(1); - for (auto resultNum : llvm::seq(0, op->getNumResults())) { - SmallVector resultTileOffsets, resultTileSizes; - if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, - resultTileOffsets, - resultTileSizes))) { - op.emitOpError("unable to get position of result ") - << resultNum << " of the tiled implementation"; - return {}; - } - SmallVector resultTileStrides(resultTileOffsets.size(), - one); - Value yieldedValue = b.create( - op->getLoc(), tilingResult.tiledOp->getResult(resultNum), - newBBArgs[resultNum], resultTileOffsets, resultTileSizes, - resultTileStrides); - yieldedValues.push_back(yieldedValue); - } - return yieldedValues; - }; - SmallVector newLoops = replaceLoopNestWithNewYields( - rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), - yieldValueFn); - for (const auto &loop : llvm::enumerate(tilingResult.loops)) { - rewriter.eraseOp(loop.value()); - tilingResult.loops[loop.index()] = newLoops[loop.index()]; - } - rewriter.replaceOp(op, tilingResult.loops.front().getResults()); + // From the tiled ops reconstruct the value to replace the result of the + // untiled op using destructive updates. + if (failed(yieldTiledValues(rewriter, op, tilingResult->tiledOp, + tilingResult->tileOffsets, + tilingResult->tileSizes, tilingResult->loops))) + return failure(); + // `yieldTiledValues` leaves `tensor.dim` uses in place and does not erase + // `op`; all users of `op` follow the loop nest, so replace them as well. + rewriter.replaceOp(op, tilingResult->loops.front().getResults()); return tilingResult; } //===----------------------------------------------------------------------===// -// TileConsumerAndFuseProducersUsingSCFForOp pattern implementation.
+// TileConsumerAndFuseProducersGreedilyUsingSCFForOp pattern implementation. //===----------------------------------------------------------------------===// -scf::TileConsumerAndFuseProducersUsingSCFForOp:: - TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) +scf::TileConsumerAndFuseProducersGreedilyUsingSCFForOp:: + TileConsumerAndFuseProducersGreedilyUsingSCFForOp( + MLIRContext *context, scf::SCFTileAndFuseOptions options, + PatternBenefit benefit) : OpInterfaceRewritePattern(context, benefit), - tilingPattern(context, std::move(options)) {} + tileAndFuseOptions(std::move(options)) {} -scf::TileConsumerAndFuseProducersUsingSCFForOp:: - TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, - MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) +scf::TileConsumerAndFuseProducersGreedilyUsingSCFForOp:: + TileConsumerAndFuseProducersGreedilyUsingSCFForOp( + StringRef opName, MLIRContext *context, + scf::SCFTileAndFuseOptions options, PatternBenefit benefit) : OpInterfaceRewritePattern(context, benefit), - tilingPattern(context, std::move(options)) {} - -/// Return the `Value` that is defined by an operation that implements -/// the `TilingInterface`. Looks through `iter_args` of scf.for nest -/// if required. -static Optional getFusableProducer(Value v) { - while (auto blockArg = v.dyn_cast()) { - auto loopOp = dyn_cast(blockArg.getOwner()->getParentOp()); - if (!loopOp) - return llvm::None; - v = loopOp.getOpOperandForRegionIterArg(blockArg).get(); + tileAndFuseOptions(std::move(options)) {} + +/// Collect all transitive producers of `op` that implement `TilingInterface` +/// and return them in sorted order (i.e. def before use); `op` is excluded.
+static SmallVector<Operation *> collectTransitiveProducers(TilingInterface op) { + // Post-order DFS: an op is finished once all of its producers are. A + // producer reached again before it is finished is pushed again, so that it + // is finished before its user. Assumes acyclic use-def chains. + llvm::SmallDenseSet<Operation *> finished; + SmallVector<TilingInterface> worklist; + SmallVector<Operation *> sortedOps; + worklist.push_back(op); + while (!worklist.empty()) { + TilingInterface currOp = worklist.back(); + if (finished.count(currOp)) { + worklist.pop_back(); + continue; + } + bool addedProducer = false; + for (OpOperand &operand : currOp->getOpOperands()) { + auto producerOp = operand.get().getDefiningOp<TilingInterface>(); + if (!producerOp || finished.count(producerOp)) + continue; + addedProducer = true; + worklist.push_back(producerOp); + } + if (!addedProducer) { + finished.insert(currOp); + if (op != currOp) + sortedOps.push_back(currOp); + worklist.pop_back(); + } } - if (!isa_and_nonnull(v.getDefiningOp())) - return llvm::None; - return v.cast(); + return sortedOps; } -// Replace iter args of the outer most loop with region args of the inner most -// one. -static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor, - PatternRewriter &rewriter) { - assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() && - "expect same number of iter args"); - Block *block = &(*innerFor.getRegion().begin()); - for (auto it : - llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) { - Value source = std::get<0>(it); - Value target = std::get<1>(it); - source.replaceUsesWithIf(target, [&](OpOperand &use) { - return use.getOwner()->getBlock() == block; - }); +/// A `tensor.extract_slice` (in a tiled op) of a result of an untiled producer. +/// `collectAllSlicesToProducer` below finds all of them used in `tiledOps`.
+struct CandidateSliceOp { + tensor::ExtractSliceOp sliceOp; + OpResult producerResult; +}; +static SmallVector +collectAllSlicesToProducer(ArrayRef tiledOps, + TilingInterface producer) { + SmallVector slicesOfProducer; + for (auto tiledOp : tiledOps) { + for (OpOperand &operand : tiledOp->getOpOperands()) { + auto sliceOp = operand.get().getDefiningOp(); + if (!sliceOp) + continue; + + Value source = sliceOp.getSource(); + while (auto blockArg = source.dyn_cast()) { + auto loopOp = dyn_cast(blockArg.getOwner()->getParentOp()); + if (!loopOp) + break; + source = loopOp.getOpOperandForRegionIterArg(blockArg).get(); + } + + if (source.getDefiningOp() == producer.getOperation()) { + slicesOfProducer.emplace_back( + CandidateSliceOp{sliceOp, source.cast()}); + } + } } + return slicesOfProducer; +} + +/// Implementation of the steps to fuse an untiled `producer` with +/// all uses of it in `tiledOps`. If `isFusableWithoutRedundantComputation` is +/// - `false` : each slice is replaced with its own tiled instance of the +/// producer (possibly recomputing parts of it). +/// - `true` : it is assumed that a single instance of tiling the producer +/// can be used to replace all (slice) uses of the untiled +/// producer in `tiledOps`. Returns the tiled producer op(s) created. +static FailureOr> +fuseProducer(RewriterBase &rewriter, ArrayRef tiledOps, + TilingInterface producer, + bool isFusableWithoutRedundantComputation = false) { + SmallVector fusedProducers; + SmallVector slicesOfProducer = + collectAllSlicesToProducer(tiledOps, producer); + if (slicesOfProducer.empty()) + return fusedProducers; + + if (!isFusableWithoutRedundantComputation) { + // Simpler usage, with lesser constraints. Just replace each slice + // with tiled implementation of the producer.
+ for (auto candidateSliceOp : slicesOfProducer) { + tensor::ExtractSliceOp sliceOp = candidateSliceOp.sliceOp; + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(sliceOp); + FailureOr fusedProducerValue = + tensor::replaceExtractSliceWithTiledProducer( + rewriter, sliceOp, candidateSliceOp.producerResult); + if (failed(fusedProducerValue)) { + // The fusion failed. Continue fusing other slices. + continue; + } + rewriter.replaceOp(sliceOp, fusedProducerValue.value()); + fusedProducers.push_back(fusedProducerValue->getDefiningOp()); + } + return fusedProducers; + } + + // Assume one instance of the tiled producer can replace all uses in + // `tiledOps`. Take the first slice op and use that to produce the tiled + // implementation. + auto currSliceOp = slicesOfProducer.front().sliceOp; + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(currSliceOp); + FailureOr fusedProducerVal = + tensor::replaceExtractSliceWithTiledProducer( + rewriter, currSliceOp, slicesOfProducer.front().producerResult); + if (failed(fusedProducerVal)) { + // The fusion failed. Continue fusing other ops. + return fusedProducers; + } + TilingInterface tiledProducer = + fusedProducerVal->getDefiningOp(); + if (!tiledProducer || + tiledProducer->getNumResults() != producer->getNumResults()) { + return rewriter.notifyMatchFailure( + producer, + "unhandled case where tiled implementation does not return a single " + "operation with as many results as the untiled operation"); + } + fusedProducers.push_back(tiledProducer); + + // Replace each slice use with the corresponding tiled producer result.
+ for (auto candidateSliceOp : slicesOfProducer) { + unsigned resultNumber = candidateSliceOp.producerResult.getResultNumber(); + rewriter.replaceOp(candidateSliceOp.sliceOp, + tiledProducer->getResult(resultNumber)); + } + return fusedProducers; } FailureOr -scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( - TilingInterface op, PatternRewriter &rewriter) const { +scf::TileConsumerAndFuseProducersGreedilyUsingSCFForOp:: + returningMatchAndRewrite(TilingInterface op, + PatternRewriter &rewriter) const { // This transformation is only valid for ops that return values (i.e. not // valid to use with operations that have memref operands). if (!op->getNumResults()) { @@ -407,111 +661,82 @@ // 1. First tile the consumer. SCFTileAndFuseResult tileAndFuseResult; + SmallVector consumerTileOffsets, consumerTileSizes; { FailureOr tilingResult = - tilingPattern.returningMatchAndRewrite(op, rewriter); - if (failed(tilingResult)) { - return failure(); - } + tileConsumer(rewriter, op, tileAndFuseOptions.tilingOptions); + if (failed(tilingResult)) + return rewriter.notifyMatchFailure(op, "failed to tile consumer"); tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp); tileAndFuseResult.loops = std::move(tilingResult->loops); + std::swap(tilingResult->tileOffsets, consumerTileOffsets); + std::swap(tilingResult->tileSizes, consumerTileSizes); } - // 2. Typically, the operands of the tiled operation are slices of the - // operands of the untiled operation. These are expressed in IR using - // `tensor.extract_slice` operations with source being the operands of the - // untiled operation. Create a worklist of these `tensor.extract_slice` - // operations. If the producers of the source of the `tensor.extract_slice` - // can be tiled such that the tiled value is generated in-place, that - // effectively tiles + fuses the operations.
- auto addCandidateSlices = [](Operation *fusedOp, - std::deque &candidates) { - for (Value operand : fusedOp->getOperands()) - if (auto sliceOp = operand.getDefiningOp()) - candidates.push_back(sliceOp); - }; + // If there are no loops generated, fusion is immaterial. Replace the + // original op with the tiled op. + if (tileAndFuseResult.loops.empty()) { + rewriter.replaceOp( + op, tileAndFuseResult.tiledAndFusedOps.front()->getResults()); + return tileAndFuseResult; + } - std::deque candidates; - addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); - OpBuilder::InsertionGuard g(rewriter); - while (!candidates.empty()) { - // 2a. Traverse the slices in BFS fashion. - tensor::ExtractSliceOp candidateSliceOp = candidates.front(); - candidates.pop_front(); - - // 2b. Get the producer of the source (potentially walking through - // `iter_args` of nested `scf.for`) - Optional fusableProducer = - getFusableProducer(candidateSliceOp.getSource()); - if (!fusableProducer) - continue; + // 2. Collect a list of producers of the original operation that are to be + // tiled and fused. + tileAndFuseResult.fusableProducers = collectTransitiveProducers(op); + SmallVector<TilingInterface> fusedUntiledProducers; + + // 3. Iterate through the producers in reverse and tile and fuse them. + for (Operation *producerOp : + llvm::reverse(tileAndFuseResult.fusableProducers)) { + auto producer = cast(producerOp); - // 2c. Generate the tiled implementation of the producer of the source - rewriter.setInsertionPoint(candidateSliceOp); - FailureOr fusedProducerValue = - tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, - fusableProducer.value()); - if (failed(fusedProducerValue)) + FailureOr> fusedProducers = fuseProducer( + rewriter, tileAndFuseResult.tiledAndFusedOps, producer, + tileAndFuseOptions.producerCanBeFusedWithoutRedundantComputations); + if (failed(fusedProducers)) { + // Fusion failed. Continue with other producers. continue; - rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value()); - - // 2d.
The operands of the fused producer might themselved be slices of - // values produced by operations that implement the `TilingInterface`. - // Add these operations to the worklist. - Operation *fusedProducer = fusedProducerValue->getDefiningOp(); - tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer); - addCandidateSlices(fusedProducer, candidates); - - // 2e. If the operation being fused creates a value that is used as `outs` - // in the tiled operation, the result of the unfused operation will be - // used in the `iter_args` of the tiled loop generated. When the - // operation is fused, this use in `iter_args` needs to be modified to - // use the destination of the fused operation. For example, starting - // with - // - // ```mlir - // %0 = linalg.init_tensor ... - // %1 = linalg.fill ... outs(%0:...)... - // %2 = linalg.matmul ... outs(%1:...).... - // ``` - // - // First the `linalg.matmul` gets tiled - // - // ```mlir - // %0 = linalg.init_tensor - // %1 = linalg.fill - // %2 = scf.for .... iter_args(%arg0 = %1)... - // ... - // ... = linalg.matmul ... - // - // ``` - // - // When the `linalg.fill` gets fused, the `iter_args` needs to be - // modified - // - // ```mlir - // %0 = linalg.init_tensor - // %1 = scf.for ... iter_args(%arg0 = %0)... - // ... - // %2 = linalg.fill ... - // %3 = linalg.matmul ... outs(%2: ...)...
- // ``` - TilingInterface unfusedProducerOp = - cast(fusableProducer->getOwner()); - scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front(); - SmallVector unfusedProducerOpDestValues = - unfusedProducerOp.getDestinationOperands(rewriter); - for (OpOperand &uses : unfusedProducerOp->getUses()) { - if (uses.getOwner() == outerMostTiledLoop.getOperation()) { - unsigned resultNumber = uses.get().cast().getResultNumber(); - unsigned operandNumber = uses.getOperandNumber(); - outerMostTiledLoop->setOperand( - operandNumber, unfusedProducerOpDestValues[resultNumber]); - } } + + tileAndFuseResult.tiledAndFusedOps.append(fusedProducers.value()); + // In the yield mode `fuseProducer` returns at most one op, and none when + // the producer was not fused; only fused producers are yielded below. + if (tileAndFuseOptions.producerCanBeFusedWithoutRedundantComputations && + !fusedProducers->empty()) + fusedUntiledProducers.push_back(producer); + } + + // 4. Finally reconstruct the replacements for the untiled operations + // using destructive updates. If + // `producerCanBeFusedWithoutRedundantComputations` is + // - `true` : Yield the results of all the fused producers. It is assumed + // to be valid. + // - `false` : Yield the results of just the tiled consumer.
+ SmallVector untiledOps; + // A braced `{...}` initializer would bind the `ArrayRef` to a temporary + // `std::initializer_list` that is destroyed at the end of the statement. + ArrayRef<Operation *> tiledOps = + llvm::makeArrayRef(tileAndFuseResult.tiledAndFusedOps).take_front(); + untiledOps.push_back(op); + if (tileAndFuseOptions.producerCanBeFusedWithoutRedundantComputations) { + untiledOps.append(fusedUntiledProducers.begin(), + fusedUntiledProducers.end()); + tiledOps = llvm::makeArrayRef(tileAndFuseResult.tiledAndFusedOps); } + assert(tiledOps.size() == untiledOps.size() && + "expected one tiled op per untiled op"); - replaceIterArgs(tileAndFuseResult.loops.front(), - tileAndFuseResult.loops.back(), rewriter); + if (failed(yieldTiledValues(rewriter, untiledOps, tiledOps, + consumerTileOffsets, consumerTileSizes, + tileAndFuseResult.loops))) { + return rewriter.notifyMatchFailure(op, "failed to yield values"); + } + + // `yieldTiledValues` leaves `tensor.dim` uses of the consumer in place and + // does not erase it; its results lead the outermost loop's results. + scf::ForOp outerMostLoop = tileAndFuseResult.loops.front(); + rewriter.replaceOp( + op, outerMostLoop.getResults().take_front(op->getNumResults())); return tileAndFuseResult; } @@ -552,3 +777,102 @@ rewriter.eraseOp(op); return loops; } + +//===----------------------------------------------------------------------===// +// RemoveDeadYields +//===----------------------------------------------------------------------===// + +LogicalResult +scf::RemoveDeadYields::matchAndRewrite(scf::ForOp forOp, + PatternRewriter &rewriter) const { + // Check for a result with no uses. + llvm::BitVector droppedResultNums(forOp.getNumResults()); + for (auto result : forOp->getResults()) { + if (!result.use_empty()) + continue; + OpOperand &iterArg = forOp.getIterOpOperands()[result.getResultNumber()]; + BlockArgument bbArg = forOp.getRegionIterArgForOpOperand(iterArg); + if (bbArg.use_empty()) { + droppedResultNums.set(result.getResultNumber()); + continue; + } + if (bbArg.hasOneUse()) { + OpOperand &use = *(bbArg.getUses().begin()); + auto insertSliceOp = dyn_cast(use.getOwner()); + if (!insertSliceOp) + continue; + + // The use of the bbArg should be the `dest` of the `tensor.insert_slice`.
+ if (insertSliceOp.getDest() != use.get()) + continue; + + // The insert Slice op must have a single use. + if (!insertSliceOp->hasOneUse()) + continue; + + // The use must be this loop's `scf.yield` (not a nested one), with the + // operand number same as the result number. + OpOperand &insertSliceUse = *(insertSliceOp->use_begin()); + auto yieldUser = dyn_cast(insertSliceUse.getOwner()); + if (yieldUser && yieldUser->getParentOp() == forOp.getOperation() && + insertSliceUse.getOperandNumber() == result.getResultNumber()) + droppedResultNums.set(result.getResultNumber()); + } + } + // Check if there is nothing to do. + if (!droppedResultNums.any()) + return failure(); + + SmallVector newIterArgs; + for (auto iterArg : llvm::enumerate(forOp.getIterOperands())) { + if (droppedResultNums.test(iterArg.index())) + continue; + newIterArgs.push_back(iterArg.value()); + } + auto bodyBuilderFn = [&](OpBuilder &bodyBuilder, Location loc, Value iv, + ValueRange bbArgs) { + Block *body = forOp.getBody(); + // Drop the unnecessary yield values. + Operation *terminator = body->getTerminator(); + terminator->eraseOperands(droppedResultNums); + + // Next remove the `tensor.insert_slices`. + for (auto bbArg : llvm::enumerate(forOp.getRegionIterArgs())) { + if (!droppedResultNums.test(bbArg.index())) + continue; + if (bbArg.value().use_empty()) + continue; + assert(bbArg.value().hasOneUse() && + "expected bbArg to have a single use"); + Operation *user = *bbArg.value().getUsers().begin(); + rewriter.eraseOp(user); + } + + // Merge the blocks from the original scf.for op and the block
+ // created for this op. + SmallVector argValues(body->getNumArguments(), nullptr); + auto inductionVar = forOp.getInductionVar().cast(); + argValues[inductionVar.getArgNumber()] = iv; + unsigned argNum = 0; + for (auto origIterArgs : llvm::enumerate(forOp.getRegionIterArgs())) { + if (droppedResultNums.test(origIterArgs.index())) + continue; + argValues[origIterArgs.value().getArgNumber()] = bbArgs[argNum++]; + } + Block *newBody = bodyBuilder.getInsertionBlock(); + rewriter.mergeBlocks(body, newBody, argValues); + }; + auto newForOp = rewriter.create( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newIterArgs, bodyBuilderFn); + SmallVector replacements(forOp.getNumResults(), nullptr); + auto results = newForOp.getResults(); + unsigned resultNum = 0; + for (auto it : llvm::seq(0, forOp.getNumResults())) { + if (droppedResultNums.test(it)) + continue; + replacements[it] = results[resultNum++]; + } + rewriter.replaceOp(forOp, replacements); + return success(); +} diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-yield-producers.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-yield-producers.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-yield-producers.mlir @@ -0,0 +1,259 @@ +// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-yield-producer-using-scf-for -split-input-file %s | FileCheck %s +// RUN: mlir-opt -pass-pipeline="func.func(test-tiling-interface{tile-consumer-and-fuse-yield-producer-using-scf-for remove-unused-yields})" -split-input-file %s | FileCheck %s --check-prefix=PRUNEYIELDS + +func.func @gemm_generic_fusion(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> (tensor, tensor) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.0 : f32 + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg1, %c1 : tensor + %init0 = linalg.init_tensor [%d0, %d1] : tensor + %fill = linalg.fill
ins(%cst : f32) outs(%init0 : tensor) -> tensor + %gemm = linalg.matmul + ins(%arg0, %arg1 : tensor, tensor) + outs(%fill : tensor) -> tensor + %init1 = linalg.init_tensor [%d0, %d1] : tensor + %generic = linalg.generic { + __internal_linalg_transform__ = "fusion", + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%gemm, %arg2 : tensor, tensor) outs(%init1 : tensor) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %add = arith.addf %b0, %b1 : f32 + linalg.yield %add : f32 + } -> tensor + return %gemm, %generic : tensor, tensor +} +// CHECK: func.func @gemm_generic_fusion( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor, +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor) +// CHECK: %[[INIT0:.+]] = linalg.init_tensor +// CHECK: %[[INIT1:.+]] = linalg.init_tensor +// CHECK: %[[RESULT:[a-zA-Z0-9]+]]:3 = scf.for %[[IV0:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG0:.+]] = %[[INIT1]], %[[ITERARG1:.+]] = %[[INIT0]], %[[ITERARG2:.+]] = %[[INIT0]]) +// CHECK: %[[INNER_RESULT:[a-zA-Z0-9]+]]:3 = scf.for %[[IV1:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[INNER_ITERARG0:[a-zA-Z0-9]+]] = %[[ITERARG0]], %[[INNER_ITERARG1:[a-zA-Z0-9]+]] = %[[ITERARG1]], %[[INNER_ITERARG2:[a-zA-Z0-9]+]] = %[[ITERARG2]]) +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INNER_ITERARG1]][%[[IV0]], %[[IV1]]] +// CHECK: %[[FILL_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[FILL_TILE]] : +// CHECK-DAG: %[[BIAS_TILE:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]]] +// CHECK-DAG: %[[OUTS_TILE:.+]] = tensor.extract_slice 
%[[INNER_ITERARG0]][%[[IV0]], %[[IV1]]] +// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GEMM_TILE]], %[[BIAS_TILE]] : +// CHECK-SAME: outs(%[[OUTS_TILE]] : +// CHECK-DAG: %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[INNER_ITERARG0]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[INSERT2:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[INNER_ITERARG1]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[INSERT3:.+]] = tensor.insert_slice %[[FILL_TILE]] into %[[INNER_ITERARG2]][%[[IV0]], %[[IV1]]] +// CHECK: scf.yield %[[INSERT1]], %[[INSERT2]], %[[INSERT3]] +// CHECK: scf.yield %[[INNER_RESULT]]#0, %[[INNER_RESULT]]#1, %[[INNER_RESULT]]#2 +// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0 + +// PRUNEYIELDS-LABEL: func @gemm_generic_fusion( +// PRUNEYIELDS: %[[INIT0:.+]] = linalg.init_tensor +// PRUNEYIELDS: %[[INIT1:.+]] = linalg.init_tensor +// PRUNEYIELDS: %[[RESULT:[a-zA-Z0-9]+]]:2 = scf.for +// PRUNEYIELDS-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]]) +// PRUNEYIELDS: %[[INNER_RESULT:[a-zA-Z0-9]+]]:2 = scf.for +// PRUNEYIELDS-SAME: iter_args(%[[INNER_ITERARG0:[a-zA-Z0-9]+]] = %[[ITERARG0]], %[[INNER_ITERARG1:[a-zA-Z0-9]+]] = %[[ITERARG1]]) +// PRUNEYIELDS-DAG: %[[INSERT0:.+]] = tensor.insert_slice %{{.+}} into %[[INNER_ITERARG0]] +// PRUNEYIELDS-DAG: %[[INSERT1:.+]] = tensor.insert_slice %{{.+}} into %[[INNER_ITERARG1]] +// PRUNEYIELDS: scf.yield %[[INSERT0]], %[[INSERT1]] +// PRUNEYIELDS: scf.yield %[[INNER_RESULT]]#0, %[[INNER_RESULT]]#1 +// PRUNEYIELDS: return %[[RESULT]]#1, %[[RESULT]]#0 + +// ----- + +func.func @gemm_gemm_fusion(%lhs0 : tensor, %rhs0 : tensor, %rhs1 : tensor) + -> (tensor, tensor) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst = arith.constant 0.0 : f32 + %d0 = tensor.dim %lhs0, %c0 : tensor + %d1 = tensor.dim %rhs0, %c1 : tensor + %init0 = linalg.init_tensor [%d0, %d1] : tensor + %fill0 = linalg.fill ins(%cst : f32) outs(%init0 : tensor) -> 
tensor + %gemm0 = linalg.matmul + ins(%lhs0, %rhs0 : tensor, tensor) outs(%fill0 : tensor) -> tensor + %d2 = tensor.dim %rhs1, %c1 : tensor + %init1 = linalg.init_tensor [%d0, %d2] : tensor + %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor) -> tensor + %gemm1 = linalg.matmul {__internal_linalg_transform__ = "gemm_fusion"} + ins(%gemm0, %rhs1 : tensor, tensor) outs(%fill1 : tensor) -> tensor + return %gemm0, %gemm1 : tensor, tensor +} +// CHECK: func.func @gemm_gemm_fusion( +// CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor, +// CHECK-SAME: %[[RHS1:[a-zA-Z0-9]+]]: tensor) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[LHS0]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[RHS0]], %[[C1]] +// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[RHS1]], %[[C1]] +// CHECK: %[[INIT1:.+]] = linalg.init_tensor [%[[D0]], %[[D2]]] +// CHECK: %[[RESULT:[a-zA-Z0-9]+]]:4 = scf.for %[[IV:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG3:[a-zA-Z0-9]+]] = %[[INIT1]]) +// CHECK-DAG: %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0] +// CHECK-DAG: %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][0, 0] +// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0] +// CHECK: %[[FILL0_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT0_TILE]] : +// CHECK: %[[GEMM0_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS0_TILE]], %[[RHS0_TILE]] : +// CHECK-SAME: outs(%[[FILL0_TILE]] : +// CHECK-DAG: %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][0, 0] +// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0] +// CHECK: %[[FILL1_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT1_TILE]] : 
+// CHECK: %[[GEMM1_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] : +// CHECK-SAME: outs(%[[FILL1_TILE]] : +// CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG0]][%[[IV]], 0] +// CHECK-DAG: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0] +// CHECK-DAG: %[[INSERT2:.+]] = tensor.insert_slice %[[FILL0_TILE]] into %[[ITERARG2]][%[[IV]], 0] +// CHECK-DAG: %[[INSERT3:.+]] = tensor.insert_slice %[[FILL1_TILE]] into %[[ITERARG3]][%[[IV]], 0] +// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]], %[[INSERT3]] +// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0 + +// PRUNEYIELDS-LABEL: func @gemm_gemm_fusion( +// PRUNEYIELDS: %[[INIT0:.+]] = linalg.init_tensor +// PRUNEYIELDS: %[[INIT1:.+]] = linalg.init_tensor +// PRUNEYIELDS: %[[RESULT:[a-zA-Z0-9]+]]:2 = scf.for +// PRUNEYIELDS-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]]) +// PRUNEYIELDS-DAG: %[[INSERT0:.+]] = tensor.insert_slice %{{.+}} into %[[ITERARG0]] +// PRUNEYIELDS-DAG: %[[INSERT1:.+]] = tensor.insert_slice %{{.+}} into %[[ITERARG1]] +// PRUNEYIELDS: scf.yield %[[INSERT0]], %[[INSERT1]] +// PRUNEYIELDS: return %[[RESULT]]#1, %[[RESULT]]#0 + +// ----- + +func.func @interchange_matmul_fusion(%arg0 : tensor, %arg1 : tensor) + -> (tensor, tensor) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %d0 = tensor.dim %arg0, %c0 : tensor + %d1 = tensor.dim %arg1, %c1 : tensor + %cst = arith.constant 0.0 : f32 + %0 = linalg.init_tensor [%d0, %d1] : tensor + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor + %2 = linalg.matmul + ins(%arg0, %arg1 : tensor, tensor) + outs(%1 : tensor) -> tensor + %3 = linalg.init_tensor [%d0, %d1] : tensor + %4 = linalg.generic { + __internal_linalg_transform__ = "gemm_interchange_fusion", + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", 
"parallel"]} + ins(%2 : tensor) outs(%3 : tensor) { + ^bb0(%b0 : f32, %b1 : f32): + %4 = arith.addf %b0, %b0 : f32 + linalg.yield %4 : f32 + } -> tensor + return %2, %4 : tensor, tensor +} +// CHECK: func.func @interchange_matmul_fusion( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) +// CHECK: %[[INIT0:.+]] = linalg.init_tensor +// CHECK: %[[INIT1:.+]] = linalg.init_tensor +// CHECK: %[[RESULT:[a-zA-Z0-9]+]]:3 = scf.for %[[IV0:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT0]]) +// CHECK: %[[INNER_RESULT:[a-zA-Z0-9]+]]:3 = scf.for %[[IV1:[a-zA-Z0-9]+]] = +// CHECK-SAME: iter_args(%[[INNER_ITERARG0:[a-zA-Z0-9]+]] = %[[ITERARG0]], %[[INNER_ITERARG1:[a-zA-Z0-9]+]] = %[[ITERARG1]], %[[INNER_ITERARG2:[a-zA-Z0-9]+]] = %[[ITERARG2]]) +// CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] +// CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INNER_ITERARG1]][%[[IV1]], %[[IV0]]] +// CHECK: %[[FILL_TILE:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : +// CHECK-SAME: outs(%[[FILL_TILE]] : +// CHECK: %[[INIT_TILE_2:.+]] = tensor.extract_slice %[[INNER_ITERARG0]][%[[IV1]], %[[IV0]]] +// CHECK: %[[GENERIC_TILE:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GEMM_TILE]] : +// CHECK-SAME: outs(%[[INIT_TILE_2]] : +// CHECK-DAG: %[[INSERT0:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[INNER_ITERARG0]][%[[IV1]], %[[IV0]]] +// CHECK-DAG: %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[INNER_ITERARG1]][%[[IV1]], %[[IV0]]] +// CHECK-DAG: %[[INSERT2:.+]] = tensor.insert_slice %[[FILL_TILE]] into %[[INNER_ITERARG2]][%[[IV1]], %[[IV0]]] +// CHECK: scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]] +// 
CHECK: scf.yield %[[INNER_RESULT]]#0, %[[INNER_RESULT]]#1, %[[INNER_RESULT]]#2 +// CHECK: return %[[RESULT]]#1, %[[RESULT]]#0 + +// PRUNEYIELDS-LABEL: func @interchange_matmul_fusion( +// PRUNEYIELDS: %[[INIT0:.+]] = linalg.init_tensor +// PRUNEYIELDS: %[[INIT1:.+]] = linalg.init_tensor +// PRUNEYIELDS: %[[RESULT:[a-zA-Z0-9]+]]:2 = scf.for +// PRUNEYIELDS-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]]) +// PRUNEYIELDS: %[[INNER_RESULT:[a-zA-Z0-9]+]]:2 = scf.for +// PRUNEYIELDS-SAME: iter_args(%[[INNER_ITERARG0:[a-zA-Z0-9]+]] = %[[ITERARG0]], %[[INNER_ITERARG1:[a-zA-Z0-9]+]] = %[[ITERARG1]]) +// PRUNEYIELDS-DAG: %[[INSERT0:.+]] = tensor.insert_slice %{{.+}} into %[[INNER_ITERARG0]] +// PRUNEYIELDS-DAG: %[[INSERT1:.+]] = tensor.insert_slice %{{.+}} into %[[INNER_ITERARG1]] +// PRUNEYIELDS: scf.yield %[[INSERT0]], %[[INSERT1]] +// PRUNEYIELDS: scf.yield %[[INNER_RESULT]]#0, %[[INNER_RESULT]]#1 +// PRUNEYIELDS: return %[[RESULT]]#1, %[[RESULT]]#0 + +// ----- + +func.func @matmul_sequence_fusion(%arg0: tensor, %arg1: tensor, + %arg2: tensor, %arg3: tensor, %arg4: tensor, + %arg5: tensor, %arg6: tensor) + -> (tensor, tensor, tensor) { + %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor // [M, N0] * [N0, N1] + %1 = linalg.matmul ins(%0, %arg3 : tensor, tensor) + outs(%arg4 : tensor) -> tensor // [M, N1] * [N1, N2] + %2 = linalg.matmul + {__internal_linalg_transform__ = "gemm_sequence_fusion"} + ins(%1, %arg5 : tensor, tensor) + outs(%arg6 : tensor) -> tensor // [M, N2] * [N2, N3] + return %0, %1, %2 : tensor, tensor, tensor +} + +// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK: func @matmul_sequence_fusion( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: 
%[[ARG5:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor) +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[ORIG_GEMM1:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] : +// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]] +// CHECK-DAG: %[[ORIG_GEMM2:.+]] = linalg.matmul ins(%[[ORIG_GEMM1]], %[[ARG3]] : +// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]] +// CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]] +// CHECK-DAG: %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]] +// CHECK: %[[R0:[a-zA-Z0-9]+]]:3 = scf.for %[[IV:[a-zA-Z0-9_]+]] = +// CHECK-SAME: iter_args(%[[ARG8:[a-zA-Z0-9]+]] = %[[ARG6]], %[[ARG9:[a-zA-Z0-9]+]] = %[[ARG4]], %[[ARG10:[a-zA-Z0-9]+]] = %[[ARG2]]) +// CHECK-DAG: %[[TILE_M:.+]] = affine.min #[[MAP]](%[[IV]])[%{{.+}}, %[[M]]] +// CHECK-DAG: %[[SLICE_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[TILE_M]], %[[N0]]] +// CHECK-DAG: %[[SLICE_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, 0] [%[[N0]], %[[N1]]] +// CHECK-DAG: %[[SLICE_ARG2:.+]] = tensor.extract_slice %[[ARG10]][%[[IV]], 0] [%[[TILE_M]], %[[N1]]] +// CHECK-DAG: %[[TILE_GEMM1:.+]] = linalg.matmul ins(%[[SLICE_ARG0]], %[[SLICE_ARG1]] : +// CHECK-SAME: outs(%[[SLICE_ARG2]] : +// CHECK-DAG: %[[SLICE_ARG3:.+]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[N1]], %[[N2]]] +// CHECK-DAG: %[[SLICE_ARG4:.+]] = tensor.extract_slice %[[ARG9]][%[[IV]], 0] [%[[TILE_M]], %[[N2]]] +// CHECK-DAG: %[[TILE_GEMM2:.+]] = linalg.matmul ins(%[[TILE_GEMM1]], %[[SLICE_ARG3]] : +// CHECK-SAME: outs(%[[SLICE_ARG4]] : +// CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]] +// CHECK-DAG: %[[SLICE_ARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]] +// CHECK-DAG: %[[TILE_GEMM3:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[TILE_GEMM2]], %[[SLICE_ARG5]] : +// CHECK-SAME: 
outs(%[[SLICE_ARG6]] : +// CHECK-DAG: %[[UPDATE0:.+]] = tensor.insert_slice %[[TILE_GEMM3]] into %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]] +// CHECK-DAG: %[[UPDATE1:.+]] = tensor.insert_slice %[[TILE_GEMM2]] into %[[ARG9]][%[[IV]], 0] [%[[TILE_M]], %[[N2]]] +// CHECK-DAG: %[[UPDATE2:.+]] = tensor.insert_slice %[[TILE_GEMM1]] into %[[ARG10]][%[[IV]], 0] [%[[TILE_M]], %[[N1]]] +// CHECK: scf.yield %[[UPDATE0]], %[[UPDATE1]], %[[UPDATE2]] +// CHECK: return %[[R0]]#2, %[[R0]]#1, %[[R0]]#0 diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -71,25 +71,26 @@ /// (that tiles and fuses operations using the `TilingInterface` with `scf.for` /// ops for iterating over the tiles) while using a `filter` to avoid recursive /// application. -struct TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter - : public scf::TileConsumerAndFuseProducersUsingSCFForOp { - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter( - MLIRContext *context, scf::SCFTilingOptions options, +struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOpWithFilter + : public scf::TileConsumerAndFuseProducersGreedilyUsingSCFForOp { + TestTileConsumerAndFuseProducersGreedilyUsingSCFForOpWithFilter( + MLIRContext *context, scf::SCFTileAndFuseOptions options, linalg::LinalgTransformationFilter filter = linalg::LinalgTransformationFilter(), PatternBenefit benefit = 1) - : scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options, - benefit), + : scf::TileConsumerAndFuseProducersGreedilyUsingSCFForOp(context, options, + benefit), filter(filter) {} /// Construct a generic pattern applied to `opName`. 
- TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter( - StringRef opName, MLIRContext *context, scf::SCFTilingOptions options, + TestTileConsumerAndFuseProducersGreedilyUsingSCFForOpWithFilter( + StringRef opName, MLIRContext *context, + scf::SCFTileAndFuseOptions options, linalg::LinalgTransformationFilter filter = linalg::LinalgTransformationFilter(), PatternBenefit benefit = 1) - : scf::TileConsumerAndFuseProducersUsingSCFForOp(context, options, - benefit), + : scf::TileConsumerAndFuseProducersGreedilyUsingSCFForOp(context, options, + benefit), filter(filter) {} LogicalResult matchAndRewrite(TilingInterface op, @@ -141,12 +142,25 @@ "with scf.for operations"), llvm::cl::init(false)}; + Option testTileConsumerAndFuseProducerWithoutRedundancy{ + *this, "tile-consumer-and-fuse-yield-producer-using-scf-for", + llvm::cl::desc( + "Test tile and fuse transformation using TilingInterface with " + "scf.for operations, when producers can be fused without redundancy"), + llvm::cl::init(false)}; + Option testLoweringToScalar{ *this, "lower-to-scalar-using-scf-for", llvm::cl::desc("Test lowering to scalar implementation using " "TilingInterface with scf.for operations"), llvm::cl::init(false)}; + Option removeUnusedYields{ + *this, "remove-unused-yields", + llvm::cl::desc( + "Include pattern to remove results of `scf.for` that are dead"), + llvm::cl::init(false)}; + void runOnOperation() override; private: @@ -154,70 +168,96 @@ }; } // namespace -template -static void -addPatternForTiling(MLIRContext *context, RewritePatternSet &patterns, - StringRef filterName, ArrayRef tileSizes, - ArrayRef interchange = {}) { +static void addPatternForTiling(MLIRContext *context, + RewritePatternSet &patterns, + StringRef filterName, + ArrayRef tileSizes, + ArrayRef interchange = {}) { scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); linalg::LinalgTransformationFilter filter( StringAttr::get(context, filterName), 
StringAttr::get(context, "tiled")); - patterns.add(context, tilingOptions, filter); + patterns.add(context, tilingOptions, filter); +} + +static void +addPatternForTileAndFuse(MLIRContext *context, RewritePatternSet &patterns, + StringRef filterName, ArrayRef tileSizes, + bool producerCanBeFusedWithoutRedundancy = false, + ArrayRef interchange = {}) { + scf::SCFTileAndFuseOptions tileAndFuseOptions; + tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange( + interchange); + tileAndFuseOptions.setProducerCanBeFusedWithoutRedundantComputations( + producerCanBeFusedWithoutRedundancy); + linalg::LinalgTransformationFilter filter( + StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); + patterns.add( + context, tileAndFuseOptions, filter); } void TestTilingInterfacePass::addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) { + if (removeUnusedYields) + patterns.add(context); if (testTiling) { // 1. Tiling M and N dims of `linalg.matmul` on tensors. - addPatternForTiling( - context, patterns, "simple_gemm", {10, 20}); + addPatternForTiling(context, patterns, "simple_gemm", {10, 20}); // 2. Tiling M, N and K of `linalg.matmul` on buffers. - addPatternForTiling( - context, patterns, "simple_gemm_memref", {10, 20, 30}); + addPatternForTiling(context, patterns, "simple_gemm_memref", {10, 20, 30}); // 3. Tiling 3D parallel generic op which implements a transpose - addPatternForTiling( - context, patterns, "parallel_generic_transpose", {10, 0, 20}); + addPatternForTiling(context, patterns, "parallel_generic_transpose", + {10, 0, 20}); // 4. Tiling 2D conv op. - addPatternForTiling( - context, patterns, "simple_conv", {0, 0, 0, 0, 10, 20, 30}); + addPatternForTiling(context, patterns, "simple_conv", + {0, 0, 0, 0, 10, 20, 30}); // 5. Tiling a simple op with `linalg.index` inside. 
- addPatternForTiling( - context, patterns, "indexed_semantics", {10, 20}); + addPatternForTiling(context, patterns, "indexed_semantics", {10, 20}); // 6. Tiling + interchange of an operation - addPatternForTiling( - context, patterns, "gemm_interchange", {10, 20, 30}, {1, 2, 0}); + addPatternForTiling(context, patterns, "gemm_interchange", {10, 20, 30}, + {1, 2, 0}); // 7. Tiling for 2D pad tensor operations. - addPatternForTiling( - context, patterns, "pad_2dtiling", {2, 3}); + addPatternForTiling(context, patterns, "pad_2dtiling", {2, 3}); // 8. Tiling inner dimension of 2d pad tensor operations. - addPatternForTiling( - context, patterns, "pad_inner_tiling", {0, 3}); + addPatternForTiling(context, patterns, "pad_inner_tiling", {0, 3}); // 9. Tiling inner dimension of 2d pad tensor operations. - addPatternForTiling( - context, patterns, "pad_outer_tiling", {2, 3}); + addPatternForTiling(context, patterns, "pad_outer_tiling", {2, 3}); return; } if (testTileConsumerAndFuseProducer) { - // 1. Tile and fuse of gemm with bias-add operation. - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "fusion", {10, 20}); - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "gemm_fusion", {10}); - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "gemm_interchange_fusion", {10, 20}, {1, 0}); - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "gemm_plus_gemm_fusion", {10, 20}); - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "gemm_sequence_fusion", {10}); + // 1. Tile and fuse of gemm with fill producer and bias-add consumer. + addPatternForTileAndFuse(context, patterns, "fusion", {10, 20}); + // 2. Tile and fuse sequence of GEMMs, by fusing only along M. 
+ addPatternForTileAndFuse(context, patterns, "gemm_fusion", {10}); + // 3. Tile and fuse gemm with consumer + interchange of tiled loops. + addPatternForTileAndFuse(context, patterns, "gemm_interchange_fusion", + {10, 20}, false, {1, 0}); + // 4. Tile and fuse matmul + transpose(matmul). Will introduce redundant + // computations. + addPatternForTileAndFuse(context, patterns, "gemm_plus_gemm_fusion", + {10, 20}); + // 5. Tile and fuse a sequence of GEMMs by tiling and fusing only along M + // dimension. + addPatternForTileAndFuse(context, patterns, "gemm_sequence_fusion", {10}); return; } + if (testTileConsumerAndFuseProducerWithoutRedundancy) { + // 1. Tile and fuse of gemm with fill producer and bias-add consumer, + // returning both the gemm value and the bias-add value. + addPatternForTileAndFuse(context, patterns, "fusion", {10, 20}, true); + // 2. Tile and fuse sequence of GEMMs, by fusing only along M, return all + // the individual GEMM values. + addPatternForTileAndFuse(context, patterns, "gemm_fusion", {10}, true); + // 3. Tile and fuse gemm with consumer + interchange of tiled loops. + addPatternForTileAndFuse(context, patterns, "gemm_interchange_fusion", + {10, 20}, true, {1, 0}); + // 4. Tile and fuse a sequence of GEMMs by tiling and fusing only along M + // dimension. + addPatternForTileAndFuse(context, patterns, "gemm_sequence_fusion", {10}, + true); + return; + } if (testLoweringToScalar) { patterns.add(context); }