diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -60,38 +60,48 @@ } }; +/// Transformation information returned after tiling. struct SCFTilingResult { + /// The tiled operation generated. Operation *tiledOp; + /// The `scf.for` operations that iterate over the tiles. SmallVector loops; + /// Values to use as replacements for the untiled op. Is the same size as the + /// number of results of the untiled op. + SmallVector replacements; }; -/// Pattern to tile an op that implements the `TilingInterface` using +/// Method to tile an op that implements the `TilingInterface` using /// `scf.for` for iterating over the tiles. -struct TileUsingSCFForOp : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all TilingInterface ops. - TileUsingSCFForOp(MLIRContext *context, SCFTilingOptions options, - PatternBenefit benefit = 1); - - /// Construct a generic pattern applied to `opName`. - TileUsingSCFForOp(StringRef opName, MLIRContext *context, - SCFTilingOptions options, PatternBenefit benefit = 1); - - /// `matchAndRewrite` implementation that returns the significant transformed - /// pieces of IR. - FailureOr - returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(op, rewriter); +FailureOr tileConsumerUsingSCFForOp(RewriterBase &rewriter, + TilingInterface op, + SCFTilingOptions options); + +/// Options used to control tile + fuse. +struct SCFTileAndFuseOptions { + /// The tiling options used to control the tiling of the consumer. 
+ SCFTilingOptions tilingOptions; + SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) { + tilingOptions = options; + return *this; } +}; -private: - /// Options to control tiling; - SCFTilingOptions options; +/// Transformation information returned after tile and fuse. +struct SCFTileAndFuseResult { + /// List of untiled operations that were fused with the tiled consumer. + llvm::SetVector fusedProducers; + /// List of tiled and fused operations generated. The first one in this list + /// is guaranteed to be the tiled operation generated during tiling of the + /// consumer operation. + llvm::SetVector tiledAndFusedOps; + /// The `scf.for` operations that iterate over the tiles. + SmallVector loops; + /// The replacement values to use for the tiled and fused operations. + llvm::DenseMap replacements; }; -/// Pattern to tile and fuse a sequence of operations, by tiling the consumer +/// Method to tile and fuse a sequence of operations, by tiling the consumer /// and fusing its producers. Note that this assumes that it is valid to /// tile+fuse the producer into the innermost tiled loop. Its up to the caller /// to ensure that the tile sizes provided make this fusion valid. @@ -99,64 +109,32 @@ /// For example, for the following sequence /// /// ```mlir -/// %0 = linalg.fill ... -/// %1 = linalg.matmul ... outs(%0 : ...) ... +/// %0 = +/// %1 = linalg.fill ... outs(%0 : ... ) +/// %2 = linalg.matmul ... outs(%1 : ...) ... /// ``` /// /// it is legal to fuse the fill with the matmul only if the matmul is tiled /// along the parallel dimensions and not the reduction dimension, i.e. the tile -/// size for the reduction dimension should be 0. -struct SCFTileAndFuseResult { - SmallVector tiledAndFusedOps; - SmallVector loops; -}; -struct TileConsumerAndFuseProducersUsingSCFForOp - : public OpInterfaceRewritePattern { - - /// Construct a generic pattern applied to all TilingInterface ops. 
- TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, - SCFTilingOptions options, - PatternBenefit benefit = 1); - - /// Construct a generic pattern applied to `opName`. - TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, - MLIRContext *context, - SCFTilingOptions options, - PatternBenefit benefit = 1); - - /// `matchAndRewrite` implementation that returns the significant transformed - /// pieces of IR. - FailureOr - returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(op, rewriter); - } - -private: - /// This pattern uses the tiling pattern. Instead of using inheritance, use - /// the patterns as private object that is instantiated at the same time as - /// this pattern. - TileUsingSCFForOp tilingPattern; -}; - -/// Pattern to lower operations that implement the `TilingInterface` to -/// loops/scalar IR using `scf.for`. -struct LowerToLoopsUsingSCFForOp - : public OpInterfaceRewritePattern { - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - - /// `matchAndRewrite` implementation that returns the significant transformed - /// pieces of IR. - FailureOr> - returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const; - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - return returningMatchAndRewrite(op, rewriter); - } -}; +/// size for the reduction dimension should be 0. The resulting fused +/// transformation is +/// +/// ```mlir +/// %1 = scf.for ... iter_args(%arg0 = %0) +/// %2 = tensor.extract_slice %arg0 +/// %3 = linalg.fill .. outs(%2 : ... ) +/// %4 = linalg.matmul .. outs(%3 : ...) 
+/// } +/// ``` +FailureOr +tileConsumerAndFuseProducerGreedilyUsingSCFForOp(RewriterBase &rewriter, + TilingInterface consumer, + SCFTileAndFuseOptions options); + +/// Method to lower an `op` that implements the `TilingInterface` to +/// loops/scalars. +FailureOr> +lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op); } // namespace scf } // namespace mlir diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -44,13 +44,15 @@ /// - `loop` isnt erased, but is left in a "no-op" state where the body of the /// loop just yields the basic block arguments that correspond to the /// initialization values of a loop. The loop is dead after this method. -/// - All uses of the `newIterOperands` within the generated new loop +/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the +/// `newIterOperands` within the generated new loop /// are replaced with the corresponding `BlockArgument` in the loop body. using NewYieldValueFn = std::function( OpBuilder &b, Location loc, ArrayRef newBBArgs)>; scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, ValueRange newIterOperands, - const NewYieldValueFn &newYieldValuesFn); + const NewYieldValueFn &newYieldValuesFn, + bool replaceIterOperandsUsesInLoop = true); /// Update a perfectly nested loop nest to yield new values from the innermost /// loop and propagating it up through the loop nest. This function @@ -64,12 +66,14 @@ /// the body of the loop just yields the basic block arguments that correspond /// to the initialization values of a loop. The original loops are dead after /// this method. 
-/// - All uses of the `newIterOperands` within the generated new loop +/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the +/// `newIterOperands` within the generated new loop /// are replaced with the corresponding `BlockArgument` in the loop body. SmallVector replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef loopNest, ValueRange newIterOperands, - const NewYieldValueFn &newYieldValueFn); + const NewYieldValueFn &newYieldValueFn, + bool replaceIterOperandsUsesInLoop = true); /// Outline a region with a single block into a new FuncOp. /// Assumes the FuncOp result types is the type of the yielded operands of the diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -240,7 +240,7 @@ using PointerUnion::PointerUnion; public: - void dump() { llvm::errs() << *this << "\n"; } + void dump() const { llvm::errs() << *this << "\n"; } }; /// Allow printing to a stream. diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -87,7 +87,7 @@ } //===----------------------------------------------------------------------===// -// TileUsingSCFForOp pattern implementation. +// tileConsumerUsingSCFForOp implementation. //===----------------------------------------------------------------------===// // Check if `stride` evenly divides the trip count `size - offset`. 
@@ -167,22 +167,106 @@ return loops; } -scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)) {} +/// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, +/// construct the destructive update pattern that inserts the yielded +/// value into a destination tensor provided by `initValue` at offset +/// `tileOffsets` and size `tileSizes`. For example, +/// +/// ```mlir +/// scf.for %iv0 = ... { +/// %0 = tiled_op +/// } +/// ``` +/// +/// is transformed to +/// +/// ```mlir +/// %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { +/// %0 = tiled_op +/// %1 = tensor.insert_slice %0 into %arg[..] [..] [..] +/// scf.yield %1 +/// } +/// ``` +/// +/// To reduce overhead, batch process these by accepting a list of +/// `yieldedValue`, `initValue`, `tileOffset` and `tileSizes`. +static FailureOr> +yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, + ValueRange yieldedValues, + ArrayRef> tileOffsetsList, + ArrayRef> tileSizesList, + MutableArrayRef loops) { + NewYieldValueFn yieldValueFn = + [&](OpBuilder &b, Location loc, + ArrayRef newBBArgs) -> SmallVector { + SmallVector inserts; + for (auto yieldedValue : llvm::enumerate(yieldedValues)) { + ArrayRef tileOffsets = + tileOffsetsList[yieldedValue.index()]; + ArrayRef tileSizes = tileSizesList[yieldedValue.index()]; + SmallVector tileStrides(tileOffsets.size(), + b.getIndexAttr(1)); + Value insert = b.create( + loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], + tileOffsets, tileSizes, tileStrides); + inserts.push_back(insert); + } + return inserts; + }; + + SmallVector newLoops = + replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, + /*replaceIterOperandsUsesInLoop =*/false); + for (const auto &loop : llvm::enumerate(loops)) { + rewriter.eraseOp(loop.value()); + loops[loop.index()] = 
newLoops[loop.index()]; + } + return llvm::to_vector(llvm::map_range( + loops.front().getResults().take_back(yieldedValues.size()), + [](OpResult r) -> Value { return r; })); +} -scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName, - MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)) {} +/// If the tiled operation is destination passing style, update the +/// slice of the destination used (which refers to the untiled destination) +/// to use the corresponding region argument of the innermost loop. +/// +/// ```mlir +/// %0 = +/// scf.for %iv0 = ... iter_args(%arg = %0) { +/// %1 = tensor.extract_slice %0 +/// %2 = tiled_op +/// %3 = tensor.insert_slice %2 into %arg +/// scf.yield %3 +/// } +/// ``` +/// +/// is transformed to +/// +/// ```mlir +/// scf.for %iv0 = ... iter_args(%arg = %0) { +/// %1 = tensor.extract_slice %arg +/// %2 = tiled_op +/// %3 = tensor.insert_slice %2 into %arg +/// scf.yield %3 +/// } +/// ``` +static void +updateDestinationOperandsForTiledOp(OpBuilder &builder, + ValueRange tiledOpDestinationValues, + ValueRange bbArgsList) { + for (auto destValue : llvm::enumerate(tiledOpDestinationValues)) { + auto sliceOp = destValue.value().getDefiningOp(); + if (!sliceOp) + continue; + sliceOp.setOperand(0, bbArgsList[destValue.index()]); + } +} +/// Implementation of tiling transformation of `op` that implements the +/// `TilingInterface` using `scf.for` to iterate over the tiles. 
FailureOr -scf::TileUsingSCFForOp::returningMatchAndRewrite( - TilingInterface op, PatternRewriter &rewriter) const { +mlir::scf::tileConsumerUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, + scf::SCFTilingOptions options) { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfter(op); @@ -244,292 +328,262 @@ offsets = applyPermutationToVector(offsets, inversePermutation); sizes = applyPermutationToVector(sizes, inversePermutation); } + } - LLVM_DEBUG({ - if (!tilingResult.loops.empty()) { - llvm::errs() << "LoopNest shell :\n"; - tilingResult.loops.front().dump(); - llvm::errs() << "\n"; - } - }); - - // 4. Generate the tiled implementation within the inner most loop. - if (!tilingResult.loops.empty()) - rewriter.setInsertionPoint( - tilingResult.loops.back().getBody()->getTerminator()); - SmallVector tiledImplementation = - op.getTiledImplementation(rewriter, offsets, sizes); - if (tiledImplementation.size() != 1) { - return rewriter.notifyMatchFailure( - op, "expected tiled implementation to return a single op"); + LLVM_DEBUG({ + if (!tilingResult.loops.empty()) { + llvm::errs() << "LoopNest shell :\n"; + tilingResult.loops.front().dump(); + llvm::errs() << "\n"; } - tilingResult.tiledOp = tiledImplementation[0]; - - LLVM_DEBUG({ - if (!tilingResult.loops.empty()) { - llvm::errs() << "After tiled implementation :\n"; - tilingResult.loops.front().dump(); - llvm::errs() << "\n"; - } - }); + }); + + // 4. Generate the tiled implementation within the inner most loop. 
+ if (!tilingResult.loops.empty()) + rewriter.setInsertionPoint( + tilingResult.loops.back().getBody()->getTerminator()); + SmallVector tiledImplementation = + op.getTiledImplementation(rewriter, offsets, sizes); + if (tiledImplementation.size() != 1) { + return rewriter.notifyMatchFailure( + op, "expected tiled implementation to return a single op"); } - + tilingResult.tiledOp = tiledImplementation[0]; if (op->getNumResults() == 0) { - rewriter.eraseOp(op); + // nothing more to do. return tilingResult; } - // 5. If the original operations has results, modify the loop nest to yield - // the replacement values. - SmallVector replacements; + // If loops are empty, the tiled op is used as the replacement for the untiled + // op. if (tilingResult.loops.empty()) { - // 5a. If there were no loops, the tiled implementation results are the - // replacements. - rewriter.replaceOp(op, tilingResult.tiledOp->getResults()); + tilingResult.replacements = llvm::to_vector( + llvm::map_range(tiledImplementation[0]->getResults(), + [](OpResult result) -> Value { return result; })); return tilingResult; } - // 5b. `scf.for` with tensor semantics requires the loop nest to yield the - // replacement values using destructive updates. Use the `TilingInterface` - // to get the position of the result tiles and use that to generate the - // destructive update pattern, i.e., - // - // ```mlir - // scf.for %iv0 = ... { - // %0 = tiled_op - // } - // ``` - // - // is transformed to - // - // ```mlir - // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. { - // %0 = tiled_op - // %1 = tensor.insert_slice %0 into %arg[..] [..] [..] 
- // scf.yield %1 - // } - // ``` - NewYieldValueFn yieldValueFn = - [&](OpBuilder &b, Location loc, - ArrayRef newBBArgs) -> SmallVector { - SmallVector yieldedValues; - Attribute one = b.getIndexAttr(1); - for (auto resultNum : llvm::seq(0, op->getNumResults())) { - SmallVector resultTileOffsets, resultTileSizes; - if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes, - resultTileOffsets, - resultTileSizes))) { - op.emitOpError("unable to get position of result ") - << resultNum << " of the tiled implementation"; - return {}; - } - SmallVector resultTileStrides(resultTileOffsets.size(), - one); - Value yieldedValue = b.create( - op->getLoc(), tilingResult.tiledOp->getResult(resultNum), - newBBArgs[resultNum], resultTileOffsets, resultTileSizes, - resultTileStrides); - yieldedValues.push_back(yieldedValue); + // 5. Yield all the results of the tiled operation. The surrounding loop + // nest is modified to insert a destructive update pattern to yield + // from the loop nest values to replace the untiled op with. 
+ unsigned numResults = op->getNumResults(); + SmallVector> resultOffsetsList(numResults), + resultSizesList(numResults); + for (auto result : llvm::enumerate(op->getResults())) { + if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, + sizes, + resultOffsetsList[result.index()], + resultSizesList[result.index()]))) { + return rewriter.notifyMatchFailure( + op, "failed to get slice of result produced"); } - return yieldedValues; - }; - SmallVector newLoops = replaceLoopNestWithNewYields( - rewriter, tilingResult.loops, op.getDestinationOperands(rewriter), - yieldValueFn); - for (const auto &loop : llvm::enumerate(tilingResult.loops)) { - rewriter.eraseOp(loop.value()); - tilingResult.loops[loop.index()] = newLoops[loop.index()]; } - rewriter.replaceOp(op, tilingResult.loops.front().getResults()); + + FailureOr> replacementOr = + yieldTiledValues(rewriter, op.getDestinationOperands(rewriter), + tilingResult.tiledOp->getResults(), resultOffsetsList, + resultSizesList, tilingResult.loops); + if (failed(replacementOr)) + return rewriter.notifyMatchFailure(op, "failed to yield replacement"); + + tilingResult.replacements = replacementOr.value(); + if (auto tiledInterfaceOp = dyn_cast(tilingResult.tiledOp)) { + scf::ForOp innerMostLoop = tilingResult.loops.back(); + SmallVector tiledOpDestinationValues = + tiledInterfaceOp.getDestinationOperands(rewriter); + updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationValues, + innerMostLoop.getRegionIterArgs()); + } + + LLVM_DEBUG({ + if (!tilingResult.loops.empty()) { + llvm::errs() << "After tiled implementation :\n"; + tilingResult.loops.front().dump(); + llvm::errs() << "\n"; + } + }); return tilingResult; } //===----------------------------------------------------------------------===// -// TileConsumerAndFuseProducersUsingSCFForOp pattern implementation. +// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. 
//===----------------------------------------------------------------------===// -scf::TileConsumerAndFuseProducersUsingSCFForOp:: - TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - tilingPattern(context, std::move(options)) {} - -scf::TileConsumerAndFuseProducersUsingSCFForOp:: - TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName, - MLIRContext *context, - scf::SCFTilingOptions options, - PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - tilingPattern(context, std::move(options)) {} - -/// Return the `Value` that is defined by an operation that implements -/// the `TilingInterface`. Looks through `iter_args` of scf.for nest -/// if required. -static Optional getFusableProducer(Value v) { - while (auto blockArg = v.dyn_cast()) { - auto loopOp = dyn_cast(blockArg.getOwner()->getParentOp()); - if (!loopOp) - return llvm::None; - v = loopOp.getOpOperandForRegionIterArg(blockArg).get(); - } - if (!isa_and_nonnull(v.getDefiningOp())) - return llvm::None; - return v.cast(); -} - -// Replace iter args of the outer most loop with region args of the inner most -// one. -static void replaceIterArgs(scf::ForOp outerFor, scf::ForOp innerFor, - PatternRewriter &rewriter) { - assert(outerFor.getNumIterOperands() == innerFor.getNumIterOperands() && - "expect same number of iter args"); - Block *block = &(*innerFor.getRegion().begin()); - for (auto it : - llvm::zip(outerFor.getIterOperands(), innerFor.getRegionIterArgs())) { - Value source = std::get<0>(it); - Value target = std::get<1>(it); - source.replaceUsesWithIf(target, [&](OpOperand &use) { - return use.getOwner()->getBlock() == block; - }); +/// Return the untiled producer whose slice is used in a tiled consumer. The +/// method traverses the tile loop nest (`loops`) if needed, and returns the +/// `iter_args` of the outermost loop that is encountered.
If there was no loop +/// traversal needed, the second value of the returned tuple is empty. +static std::tuple<OpResult, Optional<OpOperand *>> +getProducerOp(OpOperand *source, ArrayRef<scf::ForOp> loops) { + Optional<OpOperand *> destinationIterArg; + auto loopIt = loops.rbegin(); + while (auto iterArg = source->get().dyn_cast<BlockArgument>()) { + // Stop once every tile loop has been walked: the value is then a block + // argument defined outside the loop nest (e.g. a function argument), and + // dereferencing `loops.rend()` would be undefined behavior. + if (loopIt == loops.rend()) + break; + scf::ForOp loop = *loopIt; + if (iterArg.getOwner()->getParentOp() != loop) + break; + source = &loop.getOpOperandForRegionIterArg(iterArg); + loopIt++; } + if (loopIt == loops.rend()) + destinationIterArg = source; + return {source->get().dyn_cast<OpResult>(), destinationIterArg}; } +/// Implementation of tile consumer and fuse producer greedily. FailureOr -scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite( - TilingInterface op, PatternRewriter &rewriter) const { +mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( + RewriterBase &rewriter, TilingInterface consumer, + scf::SCFTileAndFuseOptions options) { // This transformation is only valid for ops that return values (i.e. not // valid to use with operations that have memref operands). - if (!op->getNumResults()) { + if (!consumer->getNumResults()) { return rewriter.notifyMatchFailure( - op, "invalid pattern for op with no results"); + consumer, "invalid pattern for op with no results"); } // 1. First tile the consumer. 
- SCFTileAndFuseResult tileAndFuseResult; + scf::SCFTileAndFuseResult tileAndFuseResult; + llvm::SmallDenseMap yieldedValueToResultNumber; { - FailureOr tilingResult = - tilingPattern.returningMatchAndRewrite(op, rewriter); - if (failed(tilingResult)) { - return failure(); - } - tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp); + FailureOr tilingResult = + tileConsumerUsingSCFForOp(rewriter, consumer, options.tilingOptions); + if (failed(tilingResult)) + return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); + tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp); tileAndFuseResult.loops = std::move(tilingResult->loops); + for (auto result : llvm::enumerate( + llvm::zip(consumer->getResults(), tilingResult->replacements))) { + tileAndFuseResult.replacements[std::get<0>(result.value())] = + std::get<1>(result.value()); + yieldedValueToResultNumber[tilingResult->tiledOp->getResult( + result.index())] = result.index(); + } } - // 2. Typically, the operands of the tiled operation are slices of the - // operands of the untiled operation. These are expressed in IR using - // `tensor.extract_slice` operations with source being the operands of the - // untiled operation. Create a worklist of these `tensor.extract_slice` - // operations. If the producers of the source of the `tensor.extract_slice` - // can be tiled such that the tiled value is generated in-place, that - // effectively tiles + fuses the operations. - auto addCandidateSlices = [](Operation *fusedOp, - std::deque &candidates) { - for (Value operand : fusedOp->getOperands()) - if (auto sliceOp = operand.getDefiningOp()) - candidates.push_back(sliceOp); + // If there are no loops generated, fusion is immaterial. + if (tileAndFuseResult.loops.empty()) + return tileAndFuseResult; + + // 2. Use a worklist algorithm to fuse greedily with the producer. The + // worklist contains `tensor.extract_slice` operations. 
If they + // are produced by an operation that implements the `TilingInterface`, + // replace the `tensor.extract_slice` with a tiled implementation of the + // producer. This in-turn might generate more `tensor.extract_slice`s + // for operands of the producer. Repeat this process till nothing else + // can be fused. + SmallVector candidateSlices; + auto populateCandidateSlices = [&](Operation *tiledOp) { + for (OpOperand &operand : tiledOp->getOpOperands()) { + if (auto sliceOp = + operand.get().getDefiningOp()) { + candidateSlices.push_back(sliceOp); + } + } }; - - std::deque candidates; - addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); - OpBuilder::InsertionGuard g(rewriter); - while (!candidates.empty()) { - // 2a. Traverse the slices in BFS fashion. - tensor::ExtractSliceOp candidateSliceOp = candidates.front(); - candidates.pop_front(); - - // 2b. Get the producer of the source (potentially walking through - // `iter_args` of nested `scf.for`) - Optional fusableProducer = - getFusableProducer(candidateSliceOp.getSource()); - if (!fusableProducer) + populateCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back()); + while (!candidateSlices.empty()) { + tensor::ExtractSliceOp sliceOp = candidateSlices.pop_back_val(); + + OpResult producer; + Optional destinationIterArg; + std::tie(producer, destinationIterArg) = + getProducerOp(&sliceOp->getOpOperand(0), tileAndFuseResult.loops); + if (!producer) continue; - // 2c. Generate the tiled implementation of the producer of the source - rewriter.setInsertionPoint(candidateSliceOp); + // 2a. Based on the slice of the producer needed, generate the tiled + // producer. 
+ OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(sliceOp); FailureOr fusedProducerValue = - tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, - fusableProducer.value()); + tensor::replaceExtractSliceWithTiledProducer(rewriter, sliceOp, + producer); if (failed(fusedProducerValue)) continue; - rewriter.replaceOp(candidateSliceOp, fusedProducerValue.value()); - - // 2d. The operands of the fused producer might themselved be slices of - // values produced by operations that implement the `TilingInterface`. - // Add these operations to the worklist. - Operation *fusedProducer = fusedProducerValue->getDefiningOp(); - tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer); - addCandidateSlices(fusedProducer, candidates); - - // 2e. If the operation being fused creates a value that is used as `outs` - // in the tiled operation, the result of the unfused operation will be - // used in the `iter_args` of the tiled loop generated. When the - // operation is fused, this use in `iter_args` needs to be modified to - // use the destination of the fused operation. For example, starting - // with - // - // ```mlir - // %0 = linalg.init_tensor ... - // %1 = linalg.fill ... outs(%0:...)... - // %2 = linalg.matmul ... outs(%1:...).... - // ``` - // - // First the `linalg.matmul` gets tiled - // - // ```mlir - // %0 = linalg.init_tensor - // %1 = linalg.fill - // %2 = scf.for .... iter_args(%arg0 = %1)... - // ... - // ... = linalg.matmul ... + + rewriter.replaceOp(sliceOp, fusedProducerValue.value()); + + // 2b. If the slice is for a destination operand, then + // - Update the iter_arg of the outer most loop to use the destination + // of + // the untiled producer. + // - Update the destination of the slice of the tiled producer generated + // to use the same basic block argument as the slice that was used to + // produce the tiled implementation. + // For example // - // ``` + // ```mlir + // %0 = linalg.init + // %1 = linalg.fill .. 
outs(%0 : ) + // %2 = scf.for .. iter_args(%arg0 = %1) { + // %3 = tensor.extract_slice %arg0 + // .. = linalg.matmul .. outs(%3 : ) + // } + // ``` // - // When the `linalg.fill` gets fused, the `iter_args` needs to be - // modified + // is transformed to // - // ```mlir - // %0 = linalg.init_tensor - // %1 = scf.for ... iter_args(%arg0 = %0)... - // ... - // %2 = linalg.fill ... - // %3 = linalg.matmul ... outs(%2: ...)... - // ``` - TilingInterface unfusedProducerOp = - cast(fusableProducer->getOwner()); - scf::ForOp outerMostTiledLoop = tileAndFuseResult.loops.front(); - SmallVector unfusedProducerOpDestValues = - unfusedProducerOp.getDestinationOperands(rewriter); - for (OpOperand &uses : unfusedProducerOp->getUses()) { - if (uses.getOwner() == outerMostTiledLoop.getOperation()) { - unsigned resultNumber = uses.get().cast().getResultNumber(); - unsigned operandNumber = uses.getOperandNumber(); - outerMostTiledLoop->setOperand( - operandNumber, unfusedProducerOpDestValues[resultNumber]); + // ```mlir + // %0 = linalg.init + // %1 = scf.for .. iter_args(%arg0 = %0) { + // %3 = tensor.extract_slice %arg0 + // %4 = linalg.fill .. outs(%3 : ) + // .. = linalg.matmul .. 
outs(%4 : ) + // } + // ``` + if (destinationIterArg) { + unsigned iterArgNumber = destinationIterArg.value()->getOperandNumber(); + unsigned resultNumber = producer.getResultNumber(); + if (auto producerOp = dyn_cast(producer.getOwner())) { + scf::ForOp outerMostLoop = tileAndFuseResult.loops.front(); + SmallVector destination = + producerOp.getDestinationOperands(rewriter); + outerMostLoop.setOperand(iterArgNumber, destination[resultNumber]); } + if (auto tiledAndFusedInterfaceOp = + fusedProducerValue.value().getDefiningOp()) { + scf::ForOp innerMostLoop = tileAndFuseResult.loops.back(); + SmallVector destination = + tiledAndFusedInterfaceOp.getDestinationOperands(rewriter); + updateDestinationOperandsForTiledOp( + rewriter, destination[resultNumber], + innerMostLoop + .getRegionIterArgs()[iterArgNumber - + innerMostLoop.getNumControlOperands()]); + } + } + tileAndFuseResult.fusedProducers.insert(producer.getOwner()); + if (auto tiledAndFusedOp = fusedProducerValue.value().getDefiningOp()) { + tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp); + populateCandidateSlices(tiledAndFusedOp); } + + LLVM_DEBUG({ + if (!tileAndFuseResult.loops.empty()) { + llvm::errs() << "After fusing producer: \n"; + tileAndFuseResult.loops.front().dump(); + llvm::errs() << "\n"; + } + }); } - replaceIterArgs(tileAndFuseResult.loops.front(), - tileAndFuseResult.loops.back(), rewriter); + return tileAndFuseResult; } //===----------------------------------------------------------------------===// -// LowerToLoopsUsingSCFForOp +// lowerToLoopsUsingSCFForOp implementation. //===----------------------------------------------------------------------===// FailureOr> -scf::LowerToLoopsUsingSCFForOp::returningMatchAndRewrite( - TilingInterface op, PatternRewriter &rewriter) const { - SmallVector domain = op.getIterationDomain(rewriter); - +mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, + TilingInterface op) { // TODO: Handle cases where the op has results if needed. 
if (op->getNumResults() > 0) { return rewriter.notifyMatchFailure( op, "unable to lower to loops operations with return values"); } + SmallVector domain = op.getIterationDomain(rewriter); SmallVector ivs; SmallVector loops; Location loc = op.getLoc(); @@ -549,6 +603,5 @@ if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { return failure(); } - rewriter.eraseOp(op); return loops; } diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -40,7 +40,8 @@ scf::ForOp mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop, ValueRange newIterOperands, - const NewYieldValueFn &newYieldValuesFn) { + const NewYieldValueFn &newYieldValuesFn, + bool replaceIterOperandsUsesInLoop) { // Create a new loop before the existing one, with the extra operands. OpBuilder::InsertionGuard g(builder); builder.setInsertionPoint(loop); @@ -79,13 +80,15 @@ llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size()))) std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); - // Replace all uses of `newIterOperands` with the corresponding basic block - // arguments. - for (auto it : llvm::zip(newIterOperands, newBBArgs)) { - std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) { - Operation *user = use.getOwner(); - return newLoop->isProperAncestor(user); - }); + if (replaceIterOperandsUsesInLoop) { + // Replace all uses of `newIterOperands` with the corresponding basic block + // arguments. 
+ for (auto it : llvm::zip(newIterOperands, newBBArgs)) { + std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) { + Operation *user = use.getOwner(); + return newLoop->isProperAncestor(user); + }); + } } // Replace all uses of the original loop with corresponding values from the @@ -104,7 +107,8 @@ SmallVector mlir::replaceLoopNestWithNewYields( OpBuilder &builder, ArrayRef loopNest, - ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn) { + ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn, + bool replaceIterOperandsUsesInLoop) { if (loopNest.empty()) return {}; SmallVector newLoopNest(loopNest.size()); @@ -121,8 +125,11 @@ newIterOperands.size())); return newYields; }; - newLoopNest[loopDepth] = replaceLoopWithNewYields( - builder, loopNest[loopDepth], newIterOperands, fn); + newLoopNest[loopDepth] = + replaceLoopWithNewYields(builder, loopNest[loopDepth], newIterOperands, + fn, replaceIterOperandsUsesInLoop); + newLoopNest[loopDepth + 1].getInitArgsMutable().assign( + newLoopNest[loopDepth].getRegionIterArgs()); } return newLoopNest; } diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -30,7 +30,7 @@ // CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : // CHECK-SAME: outs(%[[FILL_TILE]] : // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]] -// CHECK scf.yield %[[INSERT]] +// CHECK: scf.yield %[[INSERT]] // ----- @@ -68,7 +68,7 @@ // CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) // CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] // CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice 
%[[ITERARG1]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV0]], %[[IV1]]] // CHECK: %[[FILL_TILE:.+]] = linalg.fill // CHECK-SAME: outs(%[[INIT_TILE]] : // CHECK: %[[GEMM_TILE:.+]] = linalg.matmul @@ -80,7 +80,7 @@ // CHECK-SAME: ins(%[[GEMM_TILE]], %[[BIAS_TILE]] : // CHECK-SAME: outs(%[[OUTS_TILE]] : // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV0]], %[[IV1]]] -// CHECK scf.yield %[[INSERT]] +// CHECK: scf.yield %[[INSERT]] // ----- @@ -130,7 +130,7 @@ // CHECK-SAME: ins(%[[GEMM0_TILE]], %[[RHS1_TILE]] : // CHECK-SAME: outs(%[[FILL1_TILE]] : // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GEMM1_TILE]] into %[[ITERARG]][%[[IV]], 0] -// CHECK scf.yield %[[INSERT]] +// CHECK: scf.yield %[[INSERT]] // ----- @@ -182,7 +182,7 @@ // CHECK-SAME: ins(%[[GEMM_TILE]] : // CHECK-SAME: outs(%[[OUTS_TILE]] : // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]] -// CHECK scf.yield %[[INSERT]] +// CHECK: scf.yield %[[INSERT]] // ----- @@ -218,7 +218,7 @@ // CHECK-SAME: iter_args(%[[ITERARG1:.+]] = %[[ITERARG0]]) // CHECK-DAG: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] // CHECK-DAG: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]] -// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV1]], %[[IV0]]] +// CHECK-DAG: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV0]]] // CHECK: %[[FILL_TILE:.+]] = linalg.fill // CHECK-SAME: outs(%[[INIT_TILE]] : // CHECK: %[[GEMM_TILE:.+]] = linalg.matmul @@ -229,7 +229,7 @@ // CHECK-SAME: ins(%[[GEMM_TILE]] : // CHECK-SAME: outs(%[[INIT_TILE_2]] : // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]] -// CHECK scf.yield %[[INSERT]] +// CHECK: scf.yield %[[INSERT]] // ----- diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp 
b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -36,38 +36,46 @@ /// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using /// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while /// using a `filter` to avoid recursive application. -struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp { - TestTileUsingSCFForOpWithFilter(MLIRContext *context, - scf::SCFTilingOptions options, - linalg::LinalgTransformationFilter filter = - linalg::LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : scf::TileUsingSCFForOp(context, std::move(options), benefit), - filter(std::move(filter)) {} +struct TestTileUsingSCFForOp + : public OpInterfaceRewritePattern { + TestTileUsingSCFForOp(MLIRContext *context, scf::SCFTilingOptions options, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter(), + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)), filter(std::move(filter)) {} /// Construct a generic pattern applied to `opName`. 
- TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context, - scf::SCFTilingOptions options, - linalg::LinalgTransformationFilter filter = - linalg::LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : scf::TileUsingSCFForOp(context, std::move(options), benefit), - filter(std::move(filter)) {} + TestTileUsingSCFForOp(StringRef opName, MLIRContext *context, + scf::SCFTilingOptions options, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter(), + PatternBenefit benefit = 1) + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)), filter(std::move(filter)) {} LogicalResult matchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const override { if (failed(filter.checkAndNotify(rewriter, op))) return failure(); - auto tilingResult = returningMatchAndRewrite(op, rewriter); - if (failed(tilingResult)) { - return failure(); + FailureOr tilingResult = + scf::tileConsumerUsingSCFForOp(rewriter, op, options); + if (failed(tilingResult)) + return rewriter.notifyMatchFailure(op, "failed to tile operation"); + + if (op->getNumResults()) { + rewriter.replaceOp(op, tilingResult->replacements); + } else { + rewriter.eraseOp(op); } + filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp); return success(); } private: + scf::SCFTilingOptions options; linalg::LinalgTransformationFilter filter; }; @@ -75,45 +83,74 @@ /// (that tiles and fuses operations using the `TilingInterface` with `scf.for` /// ops for iterating over the tiles) while using a `filter` to avoid recursive /// application. 
-struct TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter - : public scf::TileConsumerAndFuseProducersUsingSCFForOp { - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter( - MLIRContext *context, scf::SCFTilingOptions options, +struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp + : public OpInterfaceRewritePattern { + TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp( + MLIRContext *context, scf::SCFTileAndFuseOptions options, linalg::LinalgTransformationFilter filter = linalg::LinalgTransformationFilter(), PatternBenefit benefit = 1) - : scf::TileConsumerAndFuseProducersUsingSCFForOp( - context, std::move(options), benefit), - filter(std::move(filter)) {} + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)), filter(std::move(filter)) {} /// Construct a generic pattern applied to `opName`. - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter( - StringRef opName, MLIRContext *context, scf::SCFTilingOptions options, + TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp( + StringRef opName, MLIRContext *context, + scf::SCFTileAndFuseOptions options, linalg::LinalgTransformationFilter filter = linalg::LinalgTransformationFilter(), PatternBenefit benefit = 1) - : scf::TileConsumerAndFuseProducersUsingSCFForOp( - context, std::move(options), benefit), - filter(std::move(filter)) {} + : OpInterfaceRewritePattern(context, benefit), + options(std::move(options)), filter(std::move(filter)) {} LogicalResult matchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const override { if (failed(filter.checkAndNotify(rewriter, op))) return failure(); - auto tileAndFuseResult = returningMatchAndRewrite(op, rewriter); + FailureOr tileAndFuseResult = + scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op, + options); if (failed(tileAndFuseResult)) { return failure(); } + // Replace the tiled op with replacements. 
+ SmallVector replacements(op->getNumResults()); + for (auto result : llvm::enumerate(op->getResults())) { + replacements[result.index()] = + tileAndFuseResult->replacements.lookup(result.value()); + } + rewriter.replaceOp(op, replacements); + filter.replaceLinalgTransformationFilter( rewriter, tileAndFuseResult->tiledAndFusedOps.front()); return success(); } private: + scf::SCFTileAndFuseOptions options; linalg::LinalgTransformationFilter filter; }; +/// Pattern to lower operations that implement the `TilingInterface` to +/// loops/scalar IR using `scf.for`. +struct LowerToLoopsUsingSCFForOp + : public OpInterfaceRewritePattern { + using OpInterfaceRewritePattern::OpInterfaceRewritePattern; + + /// `matchAndRewrite` implementation that returns the significant transformed + /// pieces of IR. + LogicalResult matchAndRewrite(TilingInterface op, + PatternRewriter &rewriter) const override { + FailureOr> loops = + scf::lowerToLoopsUsingSCFForOp(rewriter, op); + if (failed(loops)) + return rewriter.notifyMatchFailure(op, "failed to lower to loops"); + rewriter.eraseOp(op); + return loops; + } +}; + /// Test pass for testing the use of `TilingInterface`. 
struct TestTilingInterfacePass : public PassWrapper> { @@ -158,72 +195,78 @@ }; } // namespace -template -static void -addPatternForTiling(MLIRContext *context, RewritePatternSet &patterns, - StringRef filterName, ArrayRef tileSizes, - ArrayRef interchange = {}) { +static void addPatternForTiling(MLIRContext *context, + RewritePatternSet &patterns, + StringRef filterName, + ArrayRef tileSizes, + ArrayRef interchange = {}) { scf::SCFTilingOptions tilingOptions; tilingOptions.setTileSizes(tileSizes).setInterchange(interchange); linalg::LinalgTransformationFilter filter( StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); - patterns.add(context, tilingOptions, filter); + patterns.add(context, tilingOptions, filter); +} + +static void addPatternForTileAndFuse(MLIRContext *context, + RewritePatternSet &patterns, + StringRef filterName, + ArrayRef tileSizes, + ArrayRef interchange = {}) { + scf::SCFTileAndFuseOptions tileAndFuseOptions; + tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange( + interchange); + linalg::LinalgTransformationFilter filter( + StringAttr::get(context, filterName), StringAttr::get(context, "tiled")); + patterns.add( + context, tileAndFuseOptions, filter); } void TestTilingInterfacePass::addTestPatterns(MLIRContext *context, RewritePatternSet &patterns) { if (testTiling) { // 1. Tiling M and N dims of `linalg.matmul` on tensors. - addPatternForTiling( - context, patterns, "simple_gemm", {10, 20}); + addPatternForTiling(context, patterns, "simple_gemm", {10, 20}); // 2. Tiling M, N and K of `linalg.matmul` on buffers. - addPatternForTiling( - context, patterns, "simple_gemm_memref", {10, 20, 30}); + addPatternForTiling(context, patterns, "simple_gemm_memref", {10, 20, 30}); // 3. 
Tiling 3D parallel generic op which implements a transpose - addPatternForTiling( - context, patterns, "parallel_generic_transpose", {10, 0, 20}); + addPatternForTiling(context, patterns, "parallel_generic_transpose", + {10, 0, 20}); // 4. Tiling 2D conv op. - addPatternForTiling( - context, patterns, "simple_conv", {0, 0, 0, 0, 10, 20, 30}); + addPatternForTiling(context, patterns, "simple_conv", + {0, 0, 0, 0, 10, 20, 30}); // 5. Tiling a simple op with `linalg.index` inside. - addPatternForTiling( - context, patterns, "indexed_semantics", {10, 20}); + addPatternForTiling(context, patterns, "indexed_semantics", {10, 20}); // 6. Tiling + interchange of an operation - addPatternForTiling( - context, patterns, "gemm_interchange", {10, 20, 30}, {1, 2, 0}); + addPatternForTiling(context, patterns, "gemm_interchange", {10, 20, 30}, + {1, 2, 0}); // 7. Tiling for 2D pad tensor operations. - addPatternForTiling( - context, patterns, "pad_2dtiling", {2, 3}); + addPatternForTiling(context, patterns, "pad_2dtiling", {2, 3}); // 8. Tiling inner dimension of 2d pad tensor operations. - addPatternForTiling( - context, patterns, "pad_inner_tiling", {0, 3}); + addPatternForTiling(context, patterns, "pad_inner_tiling", {0, 3}); // 9. Tiling inner dimension of 2d pad tensor operations. - addPatternForTiling( - context, patterns, "pad_outer_tiling", {2, 3}); + addPatternForTiling(context, patterns, "pad_outer_tiling", {2, 3}); return; } if (testTileConsumerAndFuseProducer) { - // 1. Tile and fuse of gemm with bias-add operation. 
- addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "fusion", {10, 20}); - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "gemm_fusion", {10}); - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "gemm_interchange_fusion", {10, 20}, {1, 0}); - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "gemm_plus_gemm_fusion", {10, 20}); - addPatternForTiling< - TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>( - context, patterns, "gemm_sequence_fusion", {10}); + // 1. Tile and fuse of gemm with fill producer and bias-add consumer. + addPatternForTileAndFuse(context, patterns, "fusion", {10, 20}); + // 2. Tile and fuse sequence of GEMMs, by fusing only along M. + addPatternForTileAndFuse(context, patterns, "gemm_fusion", {10}); + // 3. Tile and fuse gemm with consumer + interchange of tiled loops. + addPatternForTileAndFuse(context, patterns, "gemm_interchange_fusion", + {10, 20}, {1, 0}); + // 4. Tile and fuse matmul + transpose(matmul). Will introduce redundant + // computations. + addPatternForTileAndFuse(context, patterns, "gemm_plus_gemm_fusion", + {10, 20}); + // 5. Tile and fuse a sequence of GEMMs by tiling and fusing only along M + // dimension. + addPatternForTileAndFuse(context, patterns, "gemm_sequence_fusion", {10}); return; } if (testLoweringToScalar) { - patterns.add(context); + patterns.add(context); } }