diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -77,6 +77,9 @@ /// work on primitive types, if possible. std::unique_ptr createLinalgDetensorizePass(); +/// Create a pass to tile a LinalgOp and fuse its producers. +std::unique_ptr> createLinalgTileAndFuseTensorOpsPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -243,4 +243,18 @@ }]; } +def LinalgTileAndFuseTensorOps + : FunctionPass<"linalg-tile-and-fuse-tensor-ops"> { + let summary = "Tile a LinalgOp and fuse its producers."; + let constructor = "mlir::createLinalgTileAndFuseTensorOpsPass()"; + let options = [ + ListOption<"tileSizes", "tile-sizes", "int64_t", "Tile sizes", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">, + ListOption<"tileInterchange", "tile-interchange", "int64_t", + "Tile loop interchange", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">, + ]; + let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"]; +} + #endif // MLIR_DIALECT_LINALG_PASSES diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -172,6 +172,34 @@ OpResult producerOpResult, OpOperand &consumerOpOperand); +//===----------------------------------------------------------------------===// +// Fusion on tensor utilities +//===----------------------------------------------------------------------===// + +/// A struct to manage the tile loop nest 
specific information. +struct TileLoopNest { + /// Tiled root operation. + LinalgOp rootOp; + /// Tile loop operations from outermost to innermost. + SmallVector loopOps; + /// Tiled root operation loop dimensions from outermost to innermost. + SmallVector loopDims; +}; + +/// Fuses the producer of `consumerOpOperand` in place of `sliceOp` if possible. +/// Assumes `tileLoopNest` tiles the consumer and the producer is a LinalgOp. +// TODO: add replace uses callback to support passes and patterns. +FailureOr fuseProducer(OpBuilder &b, tensor::ExtractSliceOp sliceOp, + OpOperand *consumerOpOperand, + TileLoopNest &tileLoopNest); + +/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the +/// `tileSizes` and `tileInterchange` parameters to control the tiling. +FailureOr +tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp, + ArrayRef tileSizes, + ArrayRef tileInterchange); + +//===----------------------------------------------------------------------===// +// Distribution utilities +//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ ElementwiseOpFusion.cpp ElementwiseToLinalg.cpp Fusion.cpp + FusionOnTensors.cpp Generalization.cpp Hoisting.cpp InlineScalarOperands.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -0,0 +1,521 @@ +//===- FusionOnTensors.cpp - Implementation of linalg Fusion --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements linalg fusion on tensors +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; +using namespace linalg; + +//===----------------------------------------------------------------------===// +// TileLoopNest specific helpers. +//===----------------------------------------------------------------------===// + +/// Returns true if the tile loop nest has no tile loops. +static bool isEmptyTileLoopNest(TileLoopNest &tileLoopNest) { + return tileLoopNest.loopOps.empty(); +} + +/// Returns true if the tile loop nest invariants are satisfied: +/// - The number of tile loop operations and dimensions match. +/// - The innermost tile loop is the parent of `tiledOp`. +/// - The tile loops are directly nested. +// TODO: relax to support additional control flow, e.g., IfOp. +static bool isValidTileLoopNest(TileLoopNest &tileLoopNest) { + // Check if the number of `tileLoopOps` and `tileLoopDims` match. + if (tileLoopNest.loopOps.size() != tileLoopNest.loopDims.size()) + return false; + + // Check if the innermost tile loop is the parent of `tiledOp`. + if (tileLoopNest.rootOp->getParentOp() != tileLoopNest.loopOps.back()) + return false; + + // Check if the tile loops are directly nested. 
+ return std::adjacent_find(tileLoopNest.loopOps.begin(), + tileLoopNest.loopOps.end(), + [](Operation *op1, Operation *op2) { + return op1 != op2->getParentOp(); + }) == tileLoopNest.loopOps.end(); +} + +/// Updates the tile loop nest after tiling given the newly tiled op `tiledOp` +/// and its tile loops `loopOps` as well as the tiling options needed to keep +/// track of the tile loop dimensions. +static void tileTileLoopNest(TileLoopNest &tileLoopNest, LinalgOp tiledOp, + ArrayRef loopOps, + ArrayRef tileSizes, + ArrayRef tileInterchange) { + tileLoopNest.rootOp = tiledOp; + + // Check all tile loops are ForOps and append them to the tile loop nest. + assert(all_of(loopOps, [](Operation *op) { return isa(op); }) && + "expect tile loop of type scf::ForOp"); + tileLoopNest.loopOps.append(loopOps.begin(), loopOps.end()); + + // Search the tiled loop dimensions and add them to `tiledLoopDims`. + for (auto en : enumerate(tileSizes)) { + if (en.value() != 0) + tileLoopNest.loopDims.push_back(tileInterchange[en.index()]); + } + assert(isValidTileLoopNest(tileLoopNest) && + "expect tile loop nest is valid after updating it"); +} + +/// Searches the tile loop nest block arguments tied to a block argument `bbArg` +/// of the innermost tile loop. Returns the block argument from outermost to +/// innermost or an empty vector if none are found. +static SmallVector +getTileLoopNestBBArgs(TileLoopNest &tileLoopNest, BlockArgument bbArg) { + // Search all tile loop block arguments from inner to outer. + SmallVector bbArgs; + for (auto tileLoop : reverse(tileLoopNest.loopOps)) { + assert(bbArg && bbArg.getOwner()->getParentOp() == tileLoop && + "expect a tile loop block argument"); + bbArgs.push_back(bbArg); + OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg); + bbArg = iterArg->get().dyn_cast(); + } + + // Reverse the block arguments to order them from outer to inner. 
+ return {bbArgs.rbegin(), bbArgs.rend()}; +} + +/// Returns the iteration argument of the outermost tile loop mapped to a block +/// argument `bbArg` of the innermost tile loop. +static OpOperand *getTileLoopNestIterArg(TileLoopNest &tileLoopNest, + BlockArgument bbArg) { + SmallVector bbArgs = + getTileLoopNestBBArgs(tileLoopNest, bbArg); + assert(!bbArgs.empty() && bbArgs.size() == tileLoopNest.loopOps.size() && + "expect to find a block argument for every tile loop"); + return &tileLoopNest.loopOps.front().getOpOperandForRegionIterArg( + bbArgs.front()); +} + +/// Returns true if the block argument `bbArg` has other uses than `sliceOp` and +/// its dependencies. Only if there are no other uses, the producer output +/// iteration argument may be reused to pass the producer result after fusion. +static bool hasTileLoopNestBBArgOtherUses(TileLoopNest &tileLoopNest, + BlockArgument bbArg, + tensor::ExtractSliceOp sliceOp) { + // Check the innermost block argument is either used by the ExtractSliceOp + // `sliceOp`, the matching InsertSliceOp, or by a DimOp. Handle other uses + // conservatively. + for (Operation *op : bbArg.getUsers()) { + if (!isa(op)) + return false; + if (auto extractSliceOp = dyn_cast(op)) { + if (extractSliceOp != sliceOp) + return false; + } + if (auto insertSliceOp = dyn_cast(op)) { + SetVector backwardSlice; + getBackwardSlice(insertSliceOp.source(), &backwardSlice, + [](Operation *op) { + return isa(op); + }); + if (backwardSlice.empty() || backwardSlice.front() != sliceOp) + return false; + } + } + + // Check the block arguments, except for the innermost one, have one use. + SmallVector bbArgs = + getTileLoopNestBBArgs(tileLoopNest, bbArg); + return !all_of(bbArgs, [&](BlockArgument bbArg) { + return bbArg.hasOneUse() || bbArg == bbArgs.back(); + }); +} + +//===----------------------------------------------------------------------===// +// StructuredOp specific helpers. 
+//===----------------------------------------------------------------------===// + +/// Relate the producer to the consumer loop iterations that access the same +/// producer result element: +/// consumerToProducerLoops = +/// inverse(producerIndexingMap).compose(consumerIndexingMap). +/// Return `consumerToProducerLoops` or none if the inversion fails. +static Optional +getConsumerToProducerLoopsMap(AffineMap producerIndexingMap, + AffineMap consumerIndexingMap) { + assert(consumerIndexingMap.getNumResults() == + producerIndexingMap.getNumResults() && + "expect the number of indexing map results to match"); + // Ensure the producer indexing map is a projected permutation. + if (!producerIndexingMap.isProjectedPermutation()) + return None; + AffineMap inverseIndexingMap = + inverseAndBroadcastProjectedPermuation(producerIndexingMap); + return inverseIndexingMap.compose(consumerIndexingMap); +} + +/// Returns the producer result slice dimensions tiled by the tile loop nest or +/// an empty vector if `getConsumerToProducerLoopsMap` returns none. +// TODO: replace by Fourier-Motzkin and/or compute starting from consumer. +SmallVector getTiledSliceDims(OpResult producerResult, + OpOperand *consumerOperand, + TileLoopNest &tileLoopNest) { + LinalgOp consumerOp = consumerOperand->getOwner(); + LinalgOp producerOp = producerResult.getOwner(); + OpOperand *opOperand = + producerOp.getOutputOperand(producerResult.getResultNumber()); + + // Compute the `consumerToProducerLoopsMap` and exit if the computation fails. + AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(opOperand); + Optional consumerToProducerLoopsMap = + getConsumerToProducerLoopsMap( + producerIndexingMap, consumerOp.getTiedIndexingMap(consumerOperand)); + if (!consumerToProducerLoopsMap.hasValue()) + return {}; + + // Compute the set of tiled producer loops. 
+ DenseSet tiledProducerLoops; + for (auto en : enumerate(consumerToProducerLoopsMap->getResults())) { + for (int64_t dim : tileLoopNest.loopDims) { + if (en.value().isFunctionOfDim(dim)) + tiledProducerLoops.insert(en.index()); + } + } + + // Compute the slice dimensions for the tiled producer loops. + SmallVector tiledSliceDims; + for (auto en : enumerate(producerIndexingMap.getResults())) { + auto dimExpr = en.value().dyn_cast(); + if (dimExpr && tiledProducerLoops.count(dimExpr.getPosition()) != 0) + tiledSliceDims.push_back(en.index()); + } + return tiledSliceDims; +} + +/// Returns the producer fused in place of `sliceOp`. Tile the producer operands +/// along the `tiledSliceDims` and clone the producer. Consider the case of +/// fusion of an output tensor: +/// ``` +/// %1 = producer ins(...) outs(%0) +/// %2 = consumer ins(...) outs(%1) +/// ``` +/// When consumer is tiled, %1 appears in the loop iter_args: +/// ``` +/// %1 = producer ins(...) outs(%0) +/// %2 = scf.for ... iter_args(%1) .. (%bbarg) { +/// %t1 = tensor.extract_slice %bbarg[..] +/// %t2 = consumer ins(...) outs(%t1) +/// %r = tensor.insert_slice %t2, %bbarg[...] +/// } +/// ``` +/// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0): +/// ``` +/// %2 = scf.for ... iter_args(%0) .. (%bbarg) { +/// %t0 = tensor.extract_slice %bbarg[..] +/// %t1 = producer ins(...) outs(%t0) +/// %t2 = consumer ins(...) outs(%t1) +/// %r = tensor.insert_slice %t2, %bbarg[...] +/// } +/// ``` +/// This transformation is only valid if %bbarg is exclusively used by the +/// output ExtractSliceOp / InsertSliceOp pair. The `fuseProducer` method +/// checks this precondition using the `hasOtherUses` method of tile loop nest +/// and fails if it is not satisfied. +/// TODO: instead of check and failure, insert new iter_args each time a +/// producer is fused into a consumer and fold away unused iter_args. 
+static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult, + tensor::ExtractSliceOp sliceOp, + ArrayRef tiledSliceDims, + TileLoopNest &tileLoopNest) { + // Clone the producer after `sliceOp` since the slice may be reused to pass in + // the producer result. + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointAfter(sliceOp); + + // Get the producer. + LinalgOp producerOp = producerResult.getOwner(); + Location loc = producerOp.getLoc(); + + // Obtain the `producerOp` loop bounds and the `sliceOp` ranges. + SmallVector producerLoopBounds; + transform(producerOp.createLoopRanges(b, loc), + std::back_inserter(producerLoopBounds), + [](Range range) { return range.size; }); + SmallVector sliceOpRanges = sliceOp.getOrCreateRanges(b, loc); + + // Get the producer result indexing map. + AffineMap producerIndexingMap = producerOp.getTiedIndexingMap( + producerOp.getOutputOperand(producerResult.getResultNumber())); + + // Tile the producer operands given the `sliceOp` ranges. Iterate the + // `tiledSliceDims` and store the tile offset and size for the tiled slice + // dimension. Assumes the mapping from slice dimensions to producer loops is a + // permutation. 
+ auto zero = b.create(loc, 0); + SmallVector tileIvs(producerOp.getNumLoops(), nullptr); + SmallVector tileSizes(producerOp.getNumLoops(), zero); + SmallVector allIvs(producerOp.getNumLoops(), nullptr); + for (int64_t tiledSliceDim : tiledSliceDims) { + AffineExpr result = producerIndexingMap.getResults()[tiledSliceDim]; + assert(result.isa() && + "expect producer indexing map is a projected permutation"); + int64_t tiledProducerLoop = result.cast().getPosition(); + tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset; + tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size; + allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop]; + } + erase_value(tileIvs, nullptr); + SmallVector tiledOperands = producerOp.getInputAndOutputOperands(); + tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs, + tileSizes, producerLoopBounds); + + // Update the iteration argument of the outermost tile loop in case of output + // fusion. Set the iteration argument to the producer output and use the + // `sliceOp` result instead of the tiled output operand. + BlockArgument bbArg = sliceOp.source().dyn_cast(); + if (bbArg) { + OpOperand *iterArg = getTileLoopNestIterArg(tileLoopNest, bbArg); + OpOperand *outputOperand = + producerOp.getOutputOperand(producerResult.getResultNumber()); + iterArg->set(outputOperand->get()); + tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult(); + } + + // Clone the producer using the tiled producer operands. + TypeRange resultTypes = ValueRange(tiledOperands) + .take_back(producerOp.getNumOutputs()) + .getTypes(); + LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands); + + return clonedOp; +} + +//===----------------------------------------------------------------------===// +// Tile and fuse entry-points. 
+//===----------------------------------------------------------------------===// + +FailureOr mlir::linalg::fuseProducer(OpBuilder &b, + tensor::ExtractSliceOp sliceOp, + OpOperand *consumerOpOperand, + TileLoopNest &tileLoopNest) { + // Check `tileLoopNest` is non-empty and satisfies all invariants. + if (isEmptyTileLoopNest(tileLoopNest) || !isValidTileLoopNest(tileLoopNest)) + return failure(); + + // Check `tileLoopNest` tiles `sliceOp` and `consumerOpOperand`. + if (sliceOp->getParentOp() != tileLoopNest.rootOp->getParentOp() || + consumerOpOperand->getOwner() != tileLoopNest.rootOp) + return failure(); + + // Check `consumerOpOperand` is defined by `sliceOp`. + if (consumerOpOperand->get() != sliceOp.getResult()) + return failure(); + + // Search the `producerResult` and check it is a LinalgOp. + auto producerResult = sliceOp.source().dyn_cast(); + if (auto bbArg = sliceOp.source().dyn_cast()) { + // Check the block argument is a tile loop iteration argument. + if (bbArg.getOwner()->getParentOp() != tileLoopNest.loopOps.back()) + return failure(); + producerResult = + getTileLoopNestIterArg(tileLoopNest, bbArg)->get().dyn_cast(); + // Check the block argument may be used to pass in the producer output. + if (hasTileLoopNestBBArgOtherUses(tileLoopNest, bbArg, sliceOp)) + return failure(); + } + if (!producerResult || !isa(producerResult.getOwner())) + return failure(); + + // TODO: support producers that have index semantics. + if (cast(producerResult.getOwner()).hasIndexSemantics()) + return failure(); + + // Compute the slice dimensions tiled by `tileLoopNest`. + SmallVector tiledSliceDims = + getTiledSliceDims(producerResult, consumerOpOperand, tileLoopNest); + if (tiledSliceDims.empty()) + return failure(); + + // Tile the producer operands and clone the producer in place of `sliceOp`. 
+ LinalgOp clonedOp = getTiledProducer(b, producerResult, sliceOp, + tiledSliceDims, tileLoopNest); + + // Cast the `clonedOp` result to gap type mismatches before canonicalization. + Type consumerOperandType = consumerOpOperand->get().getType(); + Value newResult = clonedOp->getResult(producerResult.getResultNumber()); + if (newResult.getType() != consumerOperandType) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointAfter(clonedOp); + newResult = b.create(producerResult.getLoc(), + consumerOperandType, newResult); + } + + // Replace the `sliceOp` uses except for the `clonedOp` output uses. + sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp); + return clonedOp; +} + +FailureOr +mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp, + ArrayRef tileSizes, + ArrayRef tileInterchange) { + assert(tileSizes.size() == tileInterchange.size() && + "expect the number of tile sizes and interchange dims to match"); + + // Create an empty tile loop nest. + TileLoopNest tileLoopNest = {consumerOp}; + + // Helper function to further tile the tile loop nest. + auto tileConsumer = [&](ArrayRef tileSizes) { + // Tile the consumer if any of the tile sizes is non-zero. + if (tileSizes.size() == static_cast(count(tileSizes, 0))) + return success(); + LinalgTilingOptions tilingOptions; + tilingOptions = tilingOptions.setInterchange(tileInterchange) + .setTileSizes(tileSizes) + .setLoopType(LinalgTilingLoopType::Loops); + LinalgOp tiledOp = tileLoopNest.rootOp; + Optional tiledConsumer = + tileLinalgOp(b, tiledOp, tilingOptions); + if (!tiledConsumer.hasValue()) + return failure(); + // Update the tile loop nest. 
+ tileTileLoopNest(tileLoopNest, tiledConsumer.getValue().op, + tiledConsumer.getValue().loops, tileSizes, + tileInterchange); + tiledOp->replaceAllUsesWith(tiledConsumer.getValue().tensorResults); + return success(); + }; + + // Helper function to fuse the `opOperand` producers + auto fuseProducers = [&](ArrayRef opOperands) { + // Try to fuse the consumer operands if the tile loop nest is non-empty. + if (isEmptyTileLoopNest(tileLoopNest)) + return; + for (OpOperand *opOperand : opOperands) { + tensor::ExtractSliceOp sliceOp = + opOperand->get().getDefiningOp(); + assert(sliceOp && "expect to find a slice op after tiling"); + (void)fuseProducer(b, sliceOp, opOperand, tileLoopNest); + } + }; + + // Search the number of outer parallel loops to separate them from possible + // inner reduction dimensions. + SmallVector iterTypes = + llvm::to_vector<6>(consumerOp.iterator_types().getAsRange()); + applyPermutationToVector(iterTypes, tileInterchange); + auto *it = find_if(iterTypes, [&](StringAttr iterType) { + return !isParallelIterator(iterType); + }); + int64_t split = std::distance(iterTypes.begin(), it); + + // Tile the outer parallel loops and fuse the output operands. + SmallVector outerTileSizes; + outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split); + outerTileSizes.append(tileSizes.size() - split, 0); + if (failed(tileConsumer(outerTileSizes))) + return failure(); + SmallVector outputOperands = + tileLoopNest.rootOp.getOutputOperands(); + fuseProducers(outputOperands); + + // Tile the remaining loops and fuse the input operands. 
+ SmallVector innerTileSizes; + innerTileSizes.append(split, 0); + innerTileSizes.append(tileSizes.begin() + split, tileSizes.end()); + if (failed(tileConsumer(innerTileSizes))) + return failure(); + SmallVector inputOperands = + tileLoopNest.rootOp.getInputOperands(); + fuseProducers(inputOperands); + return tileLoopNest; +} + +namespace { +struct LinalgTileAndFuseTensorOps + : public LinalgTileAndFuseTensorOpsBase { + + void notifyFailure(StringRef message) { + llvm::errs() << " - LinalgTileAndFuseTensorOps: " << message << "\n"; + signalPassFailure(); + } + + void runOnFunction() override { + FuncOp funcOp = getFunction(); + OpBuilder b(funcOp.getContext()); + + // Check the `tileInterchange` vector is either empty or its size matches + // the `tileSizes` vector size. + if (!tileInterchange.empty() && + tileInterchange.size() != tileSizes.size()) { + return notifyFailure( + "expect the number of tile sizes and interchange dims to match"); + } + + // Heuristic to search a good root operation to tile and start fusion. Walk + // all operations and select the candidate with the maximal backward slice. + LinalgOp rootOp = nullptr; + int64_t numFusionCandidates = -1; + funcOp.walk([&](LinalgOp linalgOp) { + SetVector backwardSlice; + getBackwardSlice(linalgOp, &backwardSlice); + int64_t backwardSliceSize = count_if( + backwardSlice, [](Operation *op) { return isa(op); }); + if (backwardSliceSize > numFusionCandidates) { + rootOp = linalgOp; + numFusionCandidates = backwardSliceSize; + } + }); + if (!rootOp) + return notifyFailure("expect to find a root operation"); + + // Check `tileSizes` contains a tile size for every `rootOp` loop dimension. + if (tileSizes.size() < rootOp.getNumLoops()) + return notifyFailure("expect #tile sizes >= #loops"); + + // Copy the `tileSizes` and `tileInterchange` vectors and limit their length + // to the number of `rootOp` loop dimensions. 
+ SmallVector rootTileSizes(tileSizes.begin(), tileSizes.end()); + SmallVector rootInterchange(tileInterchange.begin(), + tileInterchange.end()); + rootTileSizes.resize(rootOp.getNumLoops()); + rootInterchange.resize(rootOp.getNumLoops()); + + // Check the `tileInterchange` is a permutation. + SmallVector rootInterchangeExprs; + transform(rootInterchange, std::back_inserter(rootInterchangeExprs), + [&](int64_t dim) { return b.getAffineDimExpr(dim); }); + AffineMap rootInterchangeMap = AffineMap::get( + rootOp.getNumLoops(), 0, rootInterchangeExprs, funcOp.getContext()); + if (!rootInterchangeMap.isPermutation()) + return notifyFailure("expect tile interchange is a permutation"); + + // Tile `rootOp` and fuse its producers. + if (failed(tileConsumerAndFuseProducers(b, rootOp, rootTileSizes, + rootInterchange))) + return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly"); + } +}; +} // namespace + +std::unique_ptr> +mlir::createLinalgTileAndFuseTensorOpsPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir @@ -0,0 +1,190 @@ +// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=5,4,7 tile-interchange=1,0,2" -cse -split-input-file | FileCheck %s + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 24)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)> + +// CHECK: fuse_input +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32> +builtin.func @fuse_input(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { + %c0 = constant 0 : index + %c12 = constant 12 : index + %c25 = constant 25 : index + %c24 = constant 24 : index + %c4 = 
constant 4 : index + %cst = constant 0.000000e+00 : f32 + %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32> + + // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = + // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = + // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]]) + // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = + // CHECK: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]]) + + // Tile both input operand dimensions. + // CHECK: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]]) + // CHECK: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]]) + // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]] + // CHECK-SAME: %[[IV1]], %[[IV2]] + // CHECK-SAME: %[[UB1]], %[[UB2]] + // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // CHECK: %{{.*}} = linalg.matmul ins(%[[T1]] + %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + return %1 : tensor<24x25xf32> +} + +// ----- + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 25)> + +// CHECK: fuse_output +// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32> +builtin.func @fuse_output(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { + %c0 = constant 0 : index + %c12 = constant 12 : index + %c25 = constant 25 : index + %c24 = constant 24 : index + %c4 = constant 4 : index + %cst = constant 0.000000e+00 : f32 + %0 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32> + + // Update the iteration argument of the outermost tile loop. + // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] + // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]] + // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]]) + // CHECK: %[[TS0:.*]] = affine.min #[[MAP1]](%[[IV0]]) + + // Tile the both output operand dimensions. 
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG4]] + // CHECK-SAME: %[[IV1]], %[[IV0]] + // CHECK-SAME: %[[TS1]], %[[TS0]] + // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // CHECK: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]] + // CHECK: %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG5]] + %1 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%0 : tensor<24x25xf32>) -> tensor<24x25xf32> + return %1 : tensor<24x25xf32> +} + +// ----- + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 25)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)> +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> + +// CHECK: fuse_reduction +// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32> +// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x7x25xf32> +builtin.func @fuse_reduction(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>, + %arg3: tensor<12x7x25xf32>) -> tensor<24x25xf32> { + %c0 = constant 0 : index + %c12 = constant 12 : index + %c25 = constant 25 : index + %c24 = constant 24 : index + %c4 = constant 4 : index + %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg3 : tensor<12x7x25xf32>) outs(%arg1 : tensor<12x25xf32>) { + ^bb0(%arg4: f32, %arg5: f32): // no predecessors + %2 = addf %arg4, %arg5 : f32 + linalg.yield %2 : f32 + } -> tensor<12x25xf32> + + // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = + // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = + // CHECK: %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]]) + // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = + // CHECK: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]]) + // CHECK: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]]) + // CHECK: %[[UB0:.*]] = affine.min #[[MAP2]](%[[TS0]], %[[IV0]]) + 
+ // Tile only the parallel dimensions but not the reduction dimension. + // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]] + // CHECK-SAME: %[[IV2]], 0, %[[IV0]] + // CHECK-SAME: %[[UB2]], 7, %[[UB0]] + // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]] + // CHECK-SAME: %[[IV2]], %[[IV0]] + // CHECK-SAME: %[[UB2]], %[[UB0]] + // CHECK: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]] + // CHECK: %{{.*}} = linalg.matmul ins(%{{.*}}, %[[T2]] + %1 = linalg.matmul ins(%arg0, %0 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + return %1 : tensor<24x25xf32> +} + +// ----- + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK: fuse_transposed +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32> +// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x24xf32> +builtin.func @fuse_transposed(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>, + %arg3: tensor<12x24xf32>) -> tensor<24x25xf32> { + %c0 = constant 0 : index + %c12 = constant 12 : index + %c25 = constant 25 : index + %c24 = constant 24 : index + %c4 = constant 4 : index + %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg3 : tensor<12x24xf32>) outs(%arg0 : tensor<24x12xf32>) { + ^bb0(%arg4: f32, %arg5: f32): // no predecessors + %2 = addf %arg4, %arg5 : f32 + linalg.yield %2 : f32 + } -> tensor<24x12xf32> + + // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = + // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = + // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = + + // Swap the input operand slice offsets due to the transposed indexing map. 
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]] + // CHECK-SAME: %[[IV2]], %[[IV1]] + // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]] + // CHECK-SAME: %[[IV1]], %[[IV2]] + // CHECK: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]] + // CHECK: %{{.*}} = linalg.matmul ins(%[[T2]] + %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + return %1 : tensor<24x25xf32> +} + +// ----- + +// CHECK: fuse_input_and_output +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32> +// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32> +builtin.func @fuse_input_and_output(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { + %c0 = constant 0 : index + %c12 = constant 12 : index + %c25 = constant 25 : index + %c24 = constant 24 : index + %c4 = constant 4 : index + %cst = constant 0.000000e+00 : f32 + %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32> + %1 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32> + + // Fuse both producers to the appropriate tile loops. + // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] + // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]] + // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG4]] + // CHECK-SAME: %[[IV1]], %[[IV0]] + // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // CHECK: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]] + // CHECK: %[[T2:.*]] = tensor.extract_slice %[[ARG0]] + // CHECK-SAME: %[[IV1]], %[[IV2]] + // CHECK: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]]) + // CHECK: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[ARG5]] + %2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32> + return %2 : tensor<24x25xf32> +}