diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -60,38 +60,48 @@ } }; +/// Transformation information returned after tiling. struct SCFTilingResult { + /// The tiled operation generated. Operation *tiledOp; + /// The `scf.for` operations that iterate over the tiles. SmallVector loops; + /// Values to use as replacements for the untiled op, which is left in place + /// by the tiling method. Has one entry per result of the untiled op. + SmallVector replacements; }; -/// Pattern to tile an op that implements the `TilingInterface` using +/// Method to tile an op that implements the `TilingInterface` using /// `scf.for` for iterating over the tiles. -struct TileUsingSCFForOp : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all TilingInterface ops. - TileUsingSCFForOp(MLIRContext *context, SCFTilingOptions options, - PatternBenefit benefit = 1); - - /// Construct a generic pattern applied to `opName`. - TileUsingSCFForOp(StringRef opName, MLIRContext *context, - SCFTilingOptions options, PatternBenefit benefit = 1); - - /// `matchAndRewrite` implementation that returns the significant transformed - /// pieces of IR. - FailureOr - returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(op, rewriter); +FailureOr tileUsingSCFForOp(RewriterBase &rewriter, + TilingInterface op, + SCFTilingOptions options); + +/// Options used to control tile + fuse. +struct SCFTileAndFuseOptions { + /// The tiling options used to control the tiling of the consumer.
+ SCFTilingOptions tilingOptions; + SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) { + tilingOptions = options; + return *this; } +}; -private: - /// Options to control tiling; - SCFTilingOptions options; +/// Transformation information returned after tile and fuse. +struct SCFTileAndFuseResult { + /// List of untiled operations that were fused with the tiled consumer. + llvm::SetVector fusedProducers; + /// List of tiled and fused operations generated. The first one in this list + /// is guaranteed to be the tiled operation generated during tiling of the + /// consumer operation. + llvm::SetVector tiledAndFusedOps; + /// The `scf.for` operations that iterate over the tiles. + SmallVector loops; + /// Replacement values keyed by the untiled values they replace. + llvm::DenseMap replacements; }; -/// Pattern to tile and fuse a sequence of operations, by tiling the consumer +/// Method to tile and fuse a sequence of operations, by tiling the consumer /// and fusing its producers. Note that this assumes that it is valid to /// tile+fuse the producer into the innermost tiled loop. Its up to the caller /// to ensure that the tile sizes provided make this fusion valid. @@ -99,64 +109,32 @@ /// For example, for the following sequence /// /// ```mlir -/// %0 = linalg.fill ... -/// %1 = linalg.matmul ... outs(%0 : ...) ... +/// %0 = ... // destination tensor of the fill +/// %1 = linalg.fill ... outs(%0 : ... ) +/// %2 = linalg.matmul ... outs(%1 : ...) ... /// ``` /// /// it is legal to fuse the fill with the matmul only if the matmul is tiled /// along the parallel dimensions and not the reduction dimension, i.e. the tile -/// size for the reduction dimension should be 0. -struct SCFTileAndFuseResult { - SmallVector tiledAndFusedOps; - SmallVector loops; -}; -struct TileConsumerAndFuseProducersUsingSCFForOp - : public OpInterfaceRewritePattern { - - /// Construct a generic pattern applied to all TilingInterface ops.
- TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, - SCFTilingOptions options, - PatternBenefit benefit = 1); - - /// Construct a generic pattern applied to `opName`. - TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, - MLIRContext *context, - SCFTilingOptions options, - PatternBenefit benefit = 1); - - /// `matchAndRewrite` implementation that returns the significant transformed - /// pieces of IR. - FailureOr - returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(op, rewriter); - } - -private: - /// This pattern uses the tiling pattern. Instead of using inheritance, use - /// the patterns as private object that is instantiated at the same time as - /// this pattern. - TileUsingSCFForOp tilingPattern; -}; - -/// Pattern to lower operations that implement the `TilingInterface` to -/// loops/scalar IR using `scf.for`. -struct LowerToLoopsUsingSCFForOp - : public OpInterfaceRewritePattern { - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - - /// `matchAndRewrite` implementation that returns the significant transformed - /// pieces of IR. - FailureOr> - returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(op, rewriter); - } -}; +/// size for the reduction dimension should be 0. The resulting fused +/// transformation is +/// +/// ```mlir +/// %1 = scf.for ... iter_args(%arg0 = %0) { +/// %2 = tensor.extract_slice %arg0 +/// %3 = linalg.fill .. outs(%2 : ... ) +/// %4 = linalg.matmul .. outs(%3 : ...)
+/// } +/// ``` +FailureOr +tileConsumerAndFuseProducerGreedilyUsingSCFForOp(RewriterBase &rewriter, + TilingInterface consumer, + SCFTileAndFuseOptions options); + +/// Method to lower an `op` that implements the `TilingInterface` to +/// loops/scalars. The caller is responsible for erasing `op`. +FailureOr> +lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op); } // namespace scf } // namespace mlir diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -87,7 +87,7 @@ } //===----------------------------------------------------------------------===// -// TileUsingSCFForOp pattern implementation. +// tileUsingSCFForOp implementation. //===----------------------------------------------------------------------===// // Check if `stride` evenly divides the trip count `size - offset`. @@ -167,7 +167,65 @@ return loops; } -/// If the tiled operation is in destination passing style, update the +/// For each value in `yieldedValues` to be yielded from the loop nest `loops`, +/// construct the destructive update pattern that inserts the yielded +/// value into the destination tensor provided by `initValues` at offsets +/// `tileOffsetsList` and sizes `tileSizesList`. For example, +/// +/// ```mlir +/// scf.for %iv0 = ... { +/// %0 = tiled_op +/// } +/// ``` +/// +/// is transformed to +/// +/// ```mlir +/// scf.for %iv0 = ... iter_args(%arg = %init) { +/// %1 = tensor.extract_slice %arg +/// %2 = tiled_op +/// %3 = tensor.insert_slice %2 into %arg +/// scf.yield %3 +/// } +/// ``` +/// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
+static FailureOr> +yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, + ValueRange yieldedValues, + ArrayRef> tileOffsetsList, + ArrayRef> tileSizesList, + MutableArrayRef loops) { + NewYieldValueFn yieldValueFn = + [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) -> SmallVector { + SmallVector inserts; + for (auto yieldedValue : llvm::enumerate(yieldedValues)) { + ArrayRef tileOffsets = + tileOffsetsList[yieldedValue.index()]; + ArrayRef tileSizes = tileSizesList[yieldedValue.index()]; + SmallVector tileStrides(tileOffsets.size(), + b.getIndexAttr(1)); + Value insert = b.create( + loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], + tileOffsets, tileSizes, tileStrides); + inserts.push_back(insert); + } + return inserts; + }; + + SmallVector newLoops = + replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, + /*replaceIterOperandsUsesInLoop =*/false); + for (const auto &loop : llvm::enumerate(loops)) { + rewriter.eraseOp(loop.value()); + loops[loop.index()] = newLoops[loop.index()]; + } + return llvm::to_vector(llvm::map_range( + loops.front().getResults().take_back(yieldedValues.size()), + [](OpResult r) -> Value { return r; })); +} + +/// If the tiled operation is in destination passing style, update the /// slice of the destination used (which refers to the untiled destination) /// to use the corresponding region argument of the innermost loop. /// @@ -191,8 +249,6 @@ /// scf.yield %3 /// } /// ``` -/// TODO: This can be made much cleaner when `DestinationStyleOp` interface is -/// available generally.
static void updateDestinationOperandsForTiledOp(OpBuilder &builder, ValueRange tiledOpDestinationValues, @@ -205,22 +261,11 @@ } } -scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)) {} - -scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, - MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)) {} - +/// Implementation of tiling transformation of `op` that implements the +/// `TilingInterface` using `scf.for` to iterate over the tiles. FailureOr -scf::TileUsingSCFForOp::returningMatchAndRewrite( - TilingInterface op, PatternRewriter &rewriter) const { +mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, + scf::SCFTilingOptions options) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(op); @@ -282,132 +327,93 @@ offsets = applyPermutationToVector(offsets, inversePermutation); sizes = applyPermutationToVector(sizes, inversePermutation); } + } - LLVM_DEBUG({ - if (!tilingResult.loops.empty()) { - llvm::errs() << "LoopNest shell :\n"; - tilingResult.loops.front().dump(); - llvm::errs() << "\n"; - } - }); - - // 4. Generate the tiled implementation within the inner most loop.
- if (!tilingResult.loops.empty()) - rewriter.setInsertionPoint( - tilingResult.loops.back().getBody()->getTerminator()); - SmallVector tiledImplementation = - op.getTiledImplementation(rewriter, offsets, sizes); - if (tiledImplementation.size() != 1) { - return rewriter.notifyMatchFailure( - op, "expected tiled implementation to return a single op"); + LLVM_DEBUG({ + if (!tilingResult.loops.empty()) { + llvm::dbgs() << "LoopNest shell :\n"; + tilingResult.loops.front().dump(); + llvm::dbgs() << "\n"; } - tilingResult.tiledOp = tiledImplementation[0]; - - LLVM_DEBUG({ - if (!tilingResult.loops.empty()) { - llvm::errs() << "After tiled implementation :\n"; - tilingResult.loops.front().dump(); - llvm::errs() << "\n"; - } - }); + }); + + // 4. Generate the tiled implementation within the inner most loop. + if (!tilingResult.loops.empty()) + rewriter.setInsertionPoint( + tilingResult.loops.back().getBody()->getTerminator()); + SmallVector tiledImplementation = + op.getTiledImplementation(rewriter, offsets, sizes); + if (tiledImplementation.size() != 1) { + return rewriter.notifyMatchFailure( + op, "expected tiled implementation to return a single op"); } - + tilingResult.tiledOp = tiledImplementation[0]; if (op->getNumResults() == 0) { - rewriter.eraseOp(op); + // nothing more to do. return tilingResult; } - // 5. If the original operations has results, modify the loop nest to yield - // the replacement values. + // Every replacement value is taken from a result of the tiled op, so it must + // produce exactly as many results as the untiled op. + if (tilingResult.tiledOp->getNumResults() != op->getNumResults()) { + return rewriter.notifyMatchFailure( + tilingResult.tiledOp, + "expected tiled op to have as many results as the untiled operation"); + } + // If loops are empty, the tiled op is used as the replacement for the untiled + // op. if (tilingResult.loops.empty()) { - // 5a. If there were no loops, the tiled implementation results are the - // replacements. 
- rewriter.replaceOp(op, tilingResult.tiledOp->getResults()); + tilingResult.replacements = llvm::to_vector( + llvm::map_range(tiledImplementation[0]->getResults(), + [](OpResult result) -> Value { return result; })); return tilingResult; } - // 6.
Yield the results of the tiled operation from the loop nest as - // replacements for the original untiled ops. - if (tilingResult.tiledOp->getNumResults() != op->getNumResults()) { - return rewriter.notifyMatchFailure( - tilingResult.tiledOp, - "expected tiled op to have as many results as the untiled operation"); + // 5. Yield all the results of the tiled operation. The surrounding loop + // nest is modified to insert a destructive update pattern to yield + // from the loop nest values to replace the untiled op with. + unsigned numResults = op->getNumResults(); + SmallVector> resultOffsetsList(numResults), + resultSizesList(numResults); + for (auto result : llvm::enumerate(op->getResults())) { + if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, + sizes, + resultOffsetsList[result.index()], + resultSizesList[result.index()]))) { + return rewriter.notifyMatchFailure( + op, "failed to get slice of result produced"); + } } - // `scf.for` with tensor semantics requires the loop nest to yield the - // replacement values using destructive updates. Use the `TilingInterface` - // to get the position of the result tiles and use that to generate the - // destructive update pattern, i.e., - // - // ```mlir - // scf.for %iv0 = ... { - // %0 = tiled_op - // } - // ``` - // - // is transformed to - // - // ```mlir - // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { - // %0 = tiled_op - // %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
- // scf.yield %1 - // } - // ``` - NewYieldValueFn yieldValueFn = - [&](OpBuilder &b, Location loc, - ArrayRef newBBArgs) -> SmallVector { - SmallVector yieldedValues; - Attribute one = b.getIndexAttr(1); - for (auto resultNum : llvm::seq(0, op->getNumResults())) { - SmallVector resultTileOffsets, resultTileSizes; - if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, - resultTileOffsets, - resultTileSizes))) { - op.emitOpError("unable to get position of result ") - << resultNum << " of the tiled implementation"; - return {}; - } - SmallVector resultTileStrides(resultTileOffsets.size(), - one); - Value yieldedValue = b.create( - op->getLoc(), tilingResult.tiledOp->getResult(resultNum), - newBBArgs[resultNum], resultTileOffsets, resultTileSizes, - resultTileStrides); - yieldedValues.push_back(yieldedValue); - } - return yieldedValues; - }; - SmallVector newLoops = replaceLoopNestWithNewYields( - rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), - yieldValueFn, /*replaceIterOperandsUsesInLoops =*/false); - for (const auto &loop : llvm::enumerate(tilingResult.loops)) { - rewriter.eraseOp(loop.value()); - tilingResult.loops[loop.index()] = newLoops[loop.index()]; + FailureOr> replacementOr = + yieldTiledValues(rewriter, op.getDestinationOperands(rewriter), + tilingResult.tiledOp->getResults(), resultOffsetsList, + resultSizesList, tilingResult.loops); + if (failed(replacementOr)) + return rewriter.notifyMatchFailure(op, "failed to yield replacement"); + if (auto tiledInterfaceOp = dyn_cast(tilingResult.tiledOp)) { + auto innerMostLoop = tilingResult.loops.back(); + updateDestinationOperandsForTiledOp( + rewriter, tiledInterfaceOp.getDestinationOperands(rewriter), + innerMostLoop.getRegionIterArgs()); } - rewriter.replaceOp(op, tilingResult.loops.front().getResults()); + + tilingResult.replacements = replacementOr.value(); + + LLVM_DEBUG({ + if (!tilingResult.loops.empty()) { + llvm::dbgs() << "After tiled implementation :\n"; +
tilingResult.loops.front().dump(); + llvm::dbgs() << "\n"; + } + }); return tilingResult; } //===----------------------------------------------------------------------===// -// TileConsumerAndFuseProducersUsingSCFForOp pattern implementation. +// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. //===----------------------------------------------------------------------===// -scf::TileConsumerAndFuseProducersUsingSCFForOp:: - TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - tilingPattern(context, std::move(options)) {} - -scf::TileConsumerAndFuseProducersUsingSCFForOp:: - TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, - MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - tilingPattern(context, std::move(options)) {} - /// Return the untiled producer whose slice is used in a tiled consumer. The /// method traverses the tile loop nest (`loops`) if needed, and returns the /// `iter_args` of the outer most that is encountered. Traversing the iter_args @@ -430,28 +436,41 @@ return {source->get().dyn_cast(), destinationIterArg}; } +/// Implementation of tile consumer and fuse producer greedily. FailureOr -scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( - TilingInterface op, PatternRewriter &rewriter) const { +mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( + RewriterBase &rewriter, TilingInterface consumer, + scf::SCFTileAndFuseOptions options) { // This transformation is only valid for ops that return values (i.e. not // valid to use with operations that have memref operands). - if (!op->getNumResults()) { + if (!consumer->getNumResults()) { return rewriter.notifyMatchFailure( - op, "invalid pattern for op with no results"); + consumer, "invalid pattern for op with no results"); } // 1.
First tile the consumer. - SCFTileAndFuseResult tileAndFuseResult; + scf::SCFTileAndFuseResult tileAndFuseResult; + llvm::SmallDenseMap yieldedValueToResultNumber; { - FailureOr tilingResult = - tilingPattern.returningMatchAndRewrite(op, rewriter); - if (failed(tilingResult)) { - return failure(); - } - tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp); + FailureOr tilingResult = + tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); + if (failed(tilingResult)) + return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); + tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp); tileAndFuseResult.loops = std::move(tilingResult->loops); + for (auto result : llvm::enumerate( + llvm::zip(consumer->getResults(), tilingResult->replacements))) { + tileAndFuseResult.replacements[std::get<0>(result.value())] = + std::get<1>(result.value()); + yieldedValueToResultNumber[tilingResult->tiledOp->getResult( + result.index())] = result.index(); + } } + // If there are no loops generated, fusion is immaterial. + if (tileAndFuseResult.loops.empty()) + return tileAndFuseResult; + // 2. Typically, the operands of the tiled operation are slices of the // operands of the untiled operation. These are expressed in IR using // `tensor.extract_slice` operations with source being the operands of the @@ -495,7 +514,7 @@ // values produced by operations that implement the `TilingInterface`. // Add these operations to the worklist. Operation *fusedProducer = fusedProducerValue->getDefiningOp(); - tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer); + tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer); addCandidateSlices(fusedProducer, candidates); // 2e. If the slice is for a destination operand, for example, @@ -577,20 +596,19 @@ } //===----------------------------------------------------------------------===// -// LowerToLoopsUsingSCFForOp +// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===// FailureOr> -scf::LowerToLoopsUsingSCFForOp::returningMatchAndRewrite( - TilingInterface op, PatternRewriter &rewriter) const { - SmallVector domain = op.getIterationDomain(rewriter); - +mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, + TilingInterface op) { // TODO: Handle cases where the op has results if needed. if (op->getNumResults() > 0) { return rewriter.notifyMatchFailure( op, "unable to lower to loops operations with return values"); } + SmallVector domain = op.getIterationDomain(rewriter); SmallVector ivs; SmallVector loops; Location loc = op.getLoc(); @@ -610,6 +628,5 @@ if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { return failure(); } - rewriter.eraseOp(op); return loops; } diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -36,38 +36,46 @@ /// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using /// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while /// using a `filter` to avoid recursive application.
-struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp { - TestTileUsingSCFForOpWithFilter(MLIRContext *context, - scf::SCFTilingOptions options, - linalg::LinalgTransformationFilter filter = - linalg::LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : scf::TileUsingSCFForOp(context, std::move(options), benefit), - filter(std::move(filter)) {} +struct TestTileUsingSCFForOp + : public OpInterfaceRewritePattern { + TestTileUsingSCFForOp(MLIRContext *context, scf::SCFTilingOptions options, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter(), + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)), filter(std::move(filter)) {} /// Construct a generic pattern applied to `opName`. - TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context, - scf::SCFTilingOptions options, - linalg::LinalgTransformationFilter filter = - linalg::LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : scf::TileUsingSCFForOp(context, std::move(options), benefit), - filter(std::move(filter)) {} + TestTileUsingSCFForOp(StringRef opName, MLIRContext *context, + scf::SCFTilingOptions options, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter(), + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)), filter(std::move(filter)) {} LogicalResult matchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const override { if (failed(filter.checkAndNotify(rewriter, op))) return failure(); - auto tilingResult = returningMatchAndRewrite(op, rewriter); - if (failed(tilingResult)) { - return failure(); + FailureOr tilingResult = + scf::tileUsingSCFForOp(rewriter, op, options); + if (failed(tilingResult)) + return rewriter.notifyMatchFailure(op, "failed to tile operation"); + + if (op->getNumResults()) { + rewriter.replaceOp(op, tilingResult->replacements); + } else { +
rewriter.eraseOp(op); } + filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp); return success(); } private: + scf::SCFTilingOptions options; linalg::LinalgTransformationFilter filter; }; @@ -75,45 +83,74 @@ /// (that tiles and fuses operations using the `TilingInterface` with `scf.for` /// ops for iterating over the tiles) while using a `filter` to avoid recursive /// application. -struct TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter - : public scf::TileConsumerAndFuseProducersUsingSCFForOp { - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter( - MLIRContext *context, scf::SCFTilingOptions options, +struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp + : public OpInterfaceRewritePattern { + TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp( + MLIRContext *context, scf::SCFTileAndFuseOptions options, linalg::LinalgTransformationFilter filter = linalg::LinalgTransformationFilter(), PatternBenefit benefit = 1) - : scf::TileConsumerAndFuseProducersUsingSCFForOp( - context, std::move(options), benefit), - filter(std::move(filter)) {} + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)), filter(std::move(filter)) {} /// Construct a generic pattern applied to `opName`.
- TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter( - StringRef opName, MLIRContext *context, scf::SCFTilingOptions options, + TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp( + StringRef opName, MLIRContext *context, + scf::SCFTileAndFuseOptions options, linalg::LinalgTransformationFilter filter = linalg::LinalgTransformationFilter(), PatternBenefit benefit = 1) - : scf::TileConsumerAndFuseProducersUsingSCFForOp( - context, std::move(options), benefit), - filter(std::move(filter)) {} + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)), filter(std::move(filter)) {} LogicalResult matchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const override { if (failed(filter.checkAndNotify(rewriter, op))) return failure(); - auto tileAndFuseResult = returningMatchAndRewrite(op, rewriter); + FailureOr tileAndFuseResult = + scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op, + options); if (failed(tileAndFuseResult)) { return failure(); } + // Replace the untiled op with the replacement values. + SmallVector replacements(op->getNumResults()); + for (auto result : llvm::enumerate(op->getResults())) { + replacements[result.index()] = + tileAndFuseResult->replacements.lookup(result.value()); + } + rewriter.replaceOp(op, replacements); + filter.replaceLinalgTransformationFilter( rewriter, tileAndFuseResult->tiledAndFusedOps.front()); return success(); } private: + scf::SCFTileAndFuseOptions options; linalg::LinalgTransformationFilter filter; }; +/// Pattern to lower operations that implement the `TilingInterface` to +/// loops/scalar IR using `scf.for`. +struct LowerToLoopsUsingSCFForOp + : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + + /// Lowers `op` to loops with `scf::lowerToLoopsUsingSCFForOp`, then erases + /// the original op, which the lowering method leaves in place.
+ LogicalResult matchAndRewrite(TilingInterface op, + PatternRewriter &rewriter) const override { + FailureOr> loops = + scf::lowerToLoopsUsingSCFForOp(rewriter, op); + if (failed(loops)) + return rewriter.notifyMatchFailure(op, "failed to lower to loops"); + rewriter.eraseOp(op); + return success(); + } +}; + /// Test pass for testing the use of `TilingInterface`. struct TestTilingInterfacePass : public PassWrapper> { @@ -158,72 +195,78 @@ }; } // namespace -template -static void -addPatternForTiling(MLIRContext *context, RewritePatternSet &patterns, - StringRef filterName, ArrayRef tileSizes, - ArrayRef interchange = {}) { +static void addPatternForTiling(MLIRContext *context, + RewritePatternSet &patterns, + StringRef filterName, + ArrayRef tileSizes, + ArrayRef interchange = {}) { scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); linalg::LinalgTransformationFilter filter( StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); - patterns.add(context, tilingOptions, filter); + patterns.add(context, tilingOptions, filter); +} + +static void addPatternForTileAndFuse(MLIRContext *context, + RewritePatternSet &patterns, + StringRef filterName, + ArrayRef tileSizes, + ArrayRef interchange = {}) { + scf::SCFTileAndFuseOptions tileAndFuseOptions; + tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange( + interchange); + linalg::LinalgTransformationFilter filter( + StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); + patterns.add( + context, tileAndFuseOptions, filter); } void TestTilingInterfacePass::addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) { if (testTiling) { // 1. Tiling M and N dims of `linalg.matmul` on tensors. - addPatternForTiling( - context, patterns, "simple_gemm", {10, 20}); + addPatternForTiling(context, patterns, "simple_gemm", {10, 20}); // 2. Tiling M, N and K of `linalg.matmul` on buffers.
- addPatternForTiling( - context, patterns, "simple_gemm_memref", {10, 20, 30}); + addPatternForTiling(context, patterns, "simple_gemm_memref", {10, 20, 30}); // 3. Tiling 3D parallel generic op which implements a transpose - addPatternForTiling( - context, patterns, "parallel_generic_transpose", {10, 0, 20}); + addPatternForTiling(context, patterns, "parallel_generic_transpose", + {10, 0, 20}); // 4. Tiling 2D conv op. - addPatternForTiling( - context, patterns, "simple_conv", {0, 0, 0, 0, 10, 20, 30}); + addPatternForTiling(context, patterns, "simple_conv", + {0, 0, 0, 0, 10, 20, 30}); // 5. Tiling a simple op with `linalg.index` inside. - addPatternForTiling( - context, patterns, "indexed_semantics", {10, 20}); + addPatternForTiling(context, patterns, "indexed_semantics", {10, 20}); // 6. Tiling + interchange of an operation - addPatternForTiling( - context, patterns, "gemm_interchange", {10, 20, 30}, {1, 2, 0}); + addPatternForTiling(context, patterns, "gemm_interchange", {10, 20, 30}, + {1, 2, 0}); // 7. Tiling for 2D pad tensor operations. - addPatternForTiling( - context, patterns, "pad_2dtiling", {2, 3}); + addPatternForTiling(context, patterns, "pad_2dtiling", {2, 3}); // 8. Tiling inner dimension of 2d pad tensor operations. - addPatternForTiling( - context, patterns, "pad_inner_tiling", {0, 3}); + addPatternForTiling(context, patterns, "pad_inner_tiling", {0, 3}); // 9. Tiling inner dimension of 2d pad tensor operations. - addPatternForTiling( - context, patterns, "pad_outer_tiling", {2, 3}); + addPatternForTiling(context, patterns, "pad_outer_tiling", {2, 3}); return; } if (testTileConsumerAndFuseProducer) { - // 1. Tile and fuse of gemm with bias-add operation.
- addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "fusion", {10, 20}); - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "gemm_fusion", {10}); - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "gemm_interchange_fusion", {10, 20}, {1, 0}); - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "gemm_plus_gemm_fusion", {10, 20}); - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "gemm_sequence_fusion", {10}); + // 1. Tile and fuse of gemm with fill producer and bias-add consumer. + addPatternForTileAndFuse(context, patterns, "fusion", {10, 20}); + // 2. Tile and fuse sequence of GEMMs, by fusing only along M. + addPatternForTileAndFuse(context, patterns, "gemm_fusion", {10}); + // 3. Tile and fuse gemm with consumer + interchange of tiled loops. + addPatternForTileAndFuse(context, patterns, "gemm_interchange_fusion", + {10, 20}, {1, 0}); + // 4. Tile and fuse matmul + transpose(matmul). Will introduce redundant + // computations. + addPatternForTileAndFuse(context, patterns, "gemm_plus_gemm_fusion", + {10, 20}); + // 5. Tile and fuse a sequence of GEMMs by tiling and fusing only along M + // dimension. + addPatternForTileAndFuse(context, patterns, "gemm_sequence_fusion", {10}); return; } if (testLoweringToScalar) { - patterns.add(context); + patterns.add(context); } }