diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -77,6 +77,9 @@
 /// work on primitive types, if possible.
 std::unique_ptr<Pass> createLinalgDetensorizePass();
 
+/// Create a pass to tile a LinalgOp and fuse its producers.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgTileAndFuseTensorOpsPass();
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -240,4 +240,18 @@
   }];
 }
 
+def LinalgTileAndFuseTensorOps
+    : FunctionPass<"linalg-tile-and-fuse-tensor-ops"> {
+  let summary = "Tile a LinalgOp and fuse its producers.";
+  let constructor = "mlir::createLinalgTileAndFuseTensorOpsPass()";
+  let options = [
+    ListOption<"tileSizes", "tile-sizes", "int64_t", "Tile sizes",
+               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+    ListOption<"tileInterchange", "tile-interchange", "int64_t",
+               "Tile loop interchange",
+               "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">,
+  ];
+  let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"];
+}
+
 #endif // MLIR_DIALECT_LINALG_PASSES
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -167,6 +167,72 @@
                                    OpResult producerOpResult,
                                    OpOperand &consumerOpOperand);
 
+//===----------------------------------------------------------------------===//
+// Fusion on tensor utilities
+//===----------------------------------------------------------------------===//
+
+/// A struct to manage the tile loop nest specific information.
+struct TileLoopNest {
+  TileLoopNest(LinalgOp tiledOp) : tiledOp(tiledOp) {}
+
+  /// Updates the tile loop nest after tiling, given the newly tiled op
+  /// `tiledOp` and its tile loops `tileLoopOps`, as well as the tiling options
+  /// needed to keep track of the tile loop dimensions.
+  void addTileLoops(LinalgOp tiledOp, ArrayRef<Operation *> tileLoopOps,
+                    ArrayRef<int64_t> tileSizes,
+                    ArrayRef<int64_t> tileInterchange);
+
+  /// Returns the iteration argument of the outermost tile loop mapped to a
+  /// block argument `bbArg` of the innermost tile loop.
+  OpOperand *getTiedIterArg(BlockArgument bbArg);
+
+  /// Returns true if the block argument `bbArg` has uses other than `sliceOp`
+  /// and its dependencies. Only if there are no other uses may the producer
+  /// output iteration argument be reused to pass the producer result after
+  /// fusion.
+  bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
+
+  /// Returns true if the tile loop nest invariants are satisfied:
+  ///   - The number of tile loop operations and dimensions match.
+  ///   - The innermost tile loop is the parent of `tiledOp`.
+  ///   - The tile loops are directly nested.
+  // TODO: relax to support additional control flow, e.g., IfOp.
+  bool isValid();
+
+  /// Returns true if the tile loop nest is empty.
+  bool isEmpty() { return tileLoopOps.empty(); }
+
+  /// Returns `tiledOp`.
+  LinalgOp getTiledOp() { return tiledOp; }
+
+  /// Returns the `tiledOp` loop dimensions tiled by the `tileLoopOps` from
+  /// outer to inner.
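+  /// For example (illustrative values only): tiling a 3-d op with
+  /// `tileSizes` = (5, 0, 7) and `tileInterchange` = (1, 0, 2) creates two
+  /// tile loops, and this method returns [1, 2], since a tile size of zero
+  /// produces no tile loop.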
+  SmallVector<int64_t> getTileLoopDims() { return tileLoopDims; }
+
+private:
+  /// Returns all tile loop block arguments tied to the block argument `bbArg`
+  /// from outer to inner.
+  SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);
+
+  LinalgOp tiledOp;
+  SmallVector<scf::ForOp> tileLoopOps;
+  SmallVector<int64_t> tileLoopDims;
+};
+
+/// Fuses the producer of `producerResult` in place of `sliceOp` if possible.
+/// Expects `tileLoopNest` to tile the consumer.
+// TODO: add replace uses callback to support passes and patterns.
+FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpResult producerResult,
+                                 tensor::ExtractSliceOp sliceOp,
+                                 OpOperand *consumerOpOperand,
+                                 TileLoopNest &tileLoopNest);
+
+/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
+/// `tileSizes` and `tileInterchange` parameters to control the tiling.
+FailureOr<TileLoopNest>
+tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
+                             ArrayRef<int64_t> tileSizes,
+                             ArrayRef<int64_t> tileInterchange);
+
 //===----------------------------------------------------------------------===//
 // Distribution utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -8,6 +8,7 @@
   ElementwiseOpFusion.cpp
   ElementwiseToLinalg.cpp
   Fusion.cpp
+  FusionOnTensors.cpp
   Generalization.cpp
   Hoisting.cpp
   InlineScalarOperands.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -0,0 +1,487 @@
+//===- FusionOnTensors.cpp - Implementation of linalg fusion on tensors --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements linalg fusion on tensors.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace linalg;
+
+//===----------------------------------------------------------------------===//
+// TileLoopNest implementation.
+//===----------------------------------------------------------------------===//
+
+void TileLoopNest::addTileLoops(LinalgOp tiledOp,
+                                ArrayRef<Operation *> tileLoopOps,
+                                ArrayRef<int64_t> tileSizes,
+                                ArrayRef<int64_t> tileInterchange) {
+  this->tiledOp = tiledOp;
+
+  // Convert the tile loops to ForOps and append them to `tileLoopOps`.
+  transform(tileLoopOps, std::back_inserter(this->tileLoopOps),
+            [](Operation *op) {
+              auto forOp = dyn_cast<scf::ForOp>(op);
+              assert(forOp && "expect tile loop of type scf.for");
+              return forOp;
+            });
+
+  // Search the tiled loop dimensions and add them to `tileLoopDims`.
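+  // A tile size of zero means the corresponding loop was not tiled, so tiling
+  // created no loop operation for it.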
+  for (auto en : enumerate(tileSizes)) {
+    if (en.value() != 0)
+      this->tileLoopDims.push_back(tileInterchange[en.index()]);
+  }
+  assert(isValid() && "expect tile loop nest is valid after updating it");
+}
+
+bool TileLoopNest::isValid() {
+  // Check the number of `tileLoopOps` and `tileLoopDims` match.
+  if (tileLoopOps.size() != tileLoopDims.size())
+    return false;
+
+  // Check the innermost tile loop is the parent of `tiledOp`.
+  if (tiledOp->getParentOp() != tileLoopOps.back())
+    return false;
+
+  // Check the tile loops are directly nested.
+  return std::adjacent_find(tileLoopOps.begin(), tileLoopOps.end(),
+                            [](scf::ForOp op1, scf::ForOp op2) {
+                              return op1 != op2->getParentOp();
+                            }) == tileLoopOps.end();
+}
+
+SmallVector<BlockArgument> TileLoopNest::getTiedBBArgs(BlockArgument bbArg) {
+  SmallVector<BlockArgument> bbArgs;
+
+  // Search all tile loop block arguments from inner to outer.
+  for (auto tileLoop : reverse(tileLoopOps)) {
+    assert(bbArg && bbArg.getOwner()->getParentOp() == tileLoop &&
+           "expect a tile loop block argument");
+    bbArgs.push_back(bbArg);
+    OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg);
+    bbArg = iterArg->get().dyn_cast<BlockArgument>();
+  }
+
+  // Reverse the block arguments to order them from outer to inner.
+  return {bbArgs.rbegin(), bbArgs.rend()};
+}
+
+OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) {
+  SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
+  assert(!bbArgs.empty() && bbArgs.size() == tileLoopOps.size() &&
+         "expect to find a block argument for every tile loop");
+  return &tileLoopOps.front().getOpOperandForRegionIterArg(bbArgs.front());
+}
+
+bool TileLoopNest::hasOtherUses(BlockArgument bbArg,
+                                tensor::ExtractSliceOp sliceOp) {
+  // Check the innermost block argument is only used by the ExtractSliceOp
+  // `sliceOp`, the matching InsertSliceOp, or a DimOp. Handle other uses
+  // conservatively.
+  for (Operation *op : bbArg.getUsers()) {
+    if (!isa<tensor::DimOp, tensor::InsertSliceOp, tensor::ExtractSliceOp>(op))
+      return true;
+    if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(op)) {
+      if (extractSliceOp != sliceOp)
+        return true;
+    }
+    if (auto insertSliceOp = dyn_cast<tensor::InsertSliceOp>(op)) {
+      SetVector<Operation *> backwardSlice;
+      getBackwardSlice(insertSliceOp.source(), &backwardSlice,
+                       [](Operation *op) {
+                         return isa<LinalgOp, tensor::ExtractSliceOp>(op);
+                       });
+      if (backwardSlice.empty() || backwardSlice.front() != sliceOp)
+        return true;
+    }
+  }
+
+  // Check that the block arguments, except for the innermost one, have
+  // exactly one use.
+  SmallVector<BlockArgument> bbArgs = getTiedBBArgs(bbArg);
+  return !all_of(bbArgs, [&](BlockArgument bbArg) {
+    return bbArg.hasOneUse() || bbArg == bbArgs.back();
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// StructuredOp specific helpers.
+//===----------------------------------------------------------------------===//
+
+/// Relate the producer to the consumer loop iterations that access the same
+/// producer result element:
+///   consumerToProducerLoops =
+///       inverse(producerIndexingMap).compose(consumerIndexingMap).
+/// Returns `consumerToProducerLoops` or none if the inversion fails.
+static Optional<AffineMap>
+getConsumerToProducerLoopsMap(AffineMap producerIndexingMap,
+                              AffineMap consumerIndexingMap) {
+  assert(consumerIndexingMap.getNumResults() ==
+             producerIndexingMap.getNumResults() &&
+         "expect the number of indexing map results to match");
+  // Ensure the producer indexing map is a projected permutation.
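+  // Only a projected permutation is invertible here; a map such as
+  // (d0, d1) -> (d0 + d1) has no unique per-dimension inverse.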
+  if (!producerIndexingMap.isProjectedPermutation())
+    return None;
+  AffineMap inverseIndexingMap =
+      inverseAndBroadcastProjectedPermutation(producerIndexingMap);
+  return inverseIndexingMap.compose(consumerIndexingMap);
+}
+
+/// Returns the producer result slice dimensions tiled by the tile loop nest or
+/// an empty vector if `getConsumerToProducerLoopsMap` returns none.
+// TODO: replace by Fourier-Motzkin and/or compute starting from consumer.
+static SmallVector<int64_t> getTiledSliceDims(OpResult producerResult,
+                                              OpOperand *consumerOperand,
+                                              TileLoopNest &tileLoopNest) {
+  LinalgOp consumerOp = consumerOperand->getOwner();
+  LinalgOp producerOp = producerResult.getOwner();
+  OpOperand *opOperand =
+      producerOp.getOutputOperand(producerResult.getResultNumber());
+
+  // Compute the `consumerToProducerLoopsMap` and exit if the computation
+  // fails.
+  AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(opOperand);
+  Optional<AffineMap> consumerToProducerLoopsMap =
+      getConsumerToProducerLoopsMap(
+          producerIndexingMap, consumerOp.getTiedIndexingMap(consumerOperand));
+  if (!consumerToProducerLoopsMap.hasValue())
+    return {};
+
+  // Compute the set of tiled producer loops.
+  DenseSet<int64_t> tiledProducerLoops;
+  for (auto en : enumerate(consumerToProducerLoopsMap->getResults())) {
+    for (int64_t dim : tileLoopNest.getTileLoopDims()) {
+      if (en.value().isFunctionOfDim(dim))
+        tiledProducerLoops.insert(en.index());
+    }
+  }
+
+  // Compute the slice dimensions for the tiled producer loops.
+  SmallVector<int64_t> tiledSliceDims;
+  for (auto en : enumerate(producerIndexingMap.getResults())) {
+    auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
+    if (dimExpr && tiledProducerLoops.count(dimExpr.getPosition()) != 0)
+      tiledSliceDims.push_back(en.index());
+  }
+  return tiledSliceDims;
+}
+
+/// Returns the producer fused in place of `sliceOp`. Tiles the producer
+/// operands along the `tiledSliceDims` and clones the producer. Consider the
+/// case of fusion of an output tensor:
+/// ```
+/// %1 = producer ins(...) outs(%0)
+/// %2 = consumer ins(...) outs(%1)
+/// ```
+/// When the consumer is tiled, %1 appears in the loop iter_args:
+/// ```
+/// %1 = producer ins(...) outs(%0)
+/// %2 = scf.for ... iter_args(%1) .. (%bbarg) {
+///   %t1 = tensor.extract_slice %bbarg[..]
+///   %t2 = consumer ins(...) outs(%t1)
+///   %r = tensor.insert_slice %t2, %bbarg[...]
+/// }
+/// ```
+/// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0):
+/// ```
+/// %2 = scf.for ... iter_args(%0) .. (%bbarg) {
+///   %t0 = tensor.extract_slice %bbarg[..]
+///   %t1 = producer ins(...) outs(%t0)
+///   %t2 = consumer ins(...) outs(%t1)
+///   %r = tensor.insert_slice %t2, %bbarg[...]
+/// }
+/// ```
+/// This transformation is only valid if %bbarg is exclusively used by the
+/// output ExtractSliceOp / InsertSliceOp pair. The `fuseProducer` method
+/// checks this precondition using the `hasOtherUses` method of the tile loop
+/// nest and fails if it is not satisfied.
+/// TODO: instead of check and failure, insert new iter_args each time a
+/// producer is fused into a consumer and fold away unused iter_args.
+static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult,
+                                 tensor::ExtractSliceOp sliceOp,
+                                 ArrayRef<int64_t> tiledSliceDims,
+                                 TileLoopNest &tileLoopNest) {
+  // Get the producer.
+  LinalgOp producerOp = producerResult.getOwner();
+  Location loc = producerOp.getLoc();
+
+  // Obtain the `producerOp` loop bounds and the `sliceOp` ranges.
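+  // The loop bounds delimit the full producer iteration space, while the
+  // `sliceOp` ranges describe the tile of the result accessed by the consumer.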
+  SmallVector<Value> producerLoopBounds;
+  transform(producerOp.createLoopRanges(b, loc),
+            std::back_inserter(producerLoopBounds),
+            [](Range range) { return range.size; });
+  SmallVector<Range> sliceOpRanges = sliceOp.getOrCreateRanges(b, loc);
+
+  // Get the producer result indexing map.
+  AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(
+      producerOp.getOutputOperand(producerResult.getResultNumber()));
+
+  // Tile the producer operands.
+  auto zero = b.create<ConstantIndexOp>(loc, 0);
+  SmallVector<Value> tileIvs(producerOp.getNumLoops(), nullptr);
+  SmallVector<Value> tileSizes(producerOp.getNumLoops(), zero);
+  SmallVector<Value> allIvs(producerOp.getNumLoops(), nullptr);
+  for (int64_t tiledSliceDim : tiledSliceDims) {
+    AffineExpr result = producerIndexingMap.getResults()[tiledSliceDim];
+    assert(result.isa<AffineDimExpr>() &&
+           "expect producer indexing map is a projected permutation");
+    int64_t tiledProducerLoop = result.cast<AffineDimExpr>().getPosition();
+    tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset;
+    tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size;
+    allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop];
+  }
+  erase_value(tileIvs, nullptr);
+  SmallVector<Value> tiledOperands = llvm::to_vector<4>(
+      llvm::map_range(producerOp.getInputAndOutputOperands(),
+                      [](OpOperand *opOperand) { return opOperand->get(); }));
+  tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs,
+                                  tileSizes, producerLoopBounds);
+
+  // Update the iteration argument of the outermost tile loop in case of output
+  // fusion. Set the iteration argument to the producer output and use the
+  // `sliceOp` result instead of the tiled output operand.
+  BlockArgument bbArg = sliceOp.source().dyn_cast<BlockArgument>();
+  if (bbArg) {
+    OpOperand *iterArg = tileLoopNest.getTiedIterArg(bbArg);
+    OpOperand *outputOperand =
+        producerOp.getOutputOperand(producerResult.getResultNumber());
+    iterArg->set(outputOperand->get());
+    tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult();
+  }
+
+  // Clone the producer using the tiled producer operands.
+  TypeRange resultTypes = ValueRange(tiledOperands)
+                              .take_back(producerOp.getNumOutputs())
+                              .getTypes();
+  LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands);
+
+  // TODO: support producers that have index semantics.
+  assert(!producerOp.hasIndexSemantics() &&
+         "expect producer to not have index semantics");
+  return clonedOp;
+}
+
+//===----------------------------------------------------------------------===//
+// Tile and fuse entry-points.
+//===----------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> mlir::linalg::fuseProducer(OpBuilder &b,
+                                               OpResult producerResult,
+                                               tensor::ExtractSliceOp sliceOp,
+                                               OpOperand *consumerOpOperand,
+                                               TileLoopNest &tileLoopNest) {
+  // Check `producerResult` and `sliceOp` are non-null since the callers may
+  // pass null values for operands that do not match the expected pattern.
+  if (!producerResult || !sliceOp)
+    return failure();
+
+  // Check `tileLoopNest` is non-empty and satisfies all invariants.
+  if (tileLoopNest.isEmpty() || !tileLoopNest.isValid())
+    return failure();
+
+  // Check `tileLoopNest` tiles `sliceOp` and `consumerOpOperand`.
+  if (sliceOp->getParentOp() != tileLoopNest.getTiledOp()->getParentOp() ||
+      consumerOpOperand->getOwner() != tileLoopNest.getTiledOp())
+    return failure();
+
+  // Check the use-def chain from `consumerOpOperand` to `producerResult`. If
+  // the producer result is passed in via iteration arguments, make sure the
+  // iteration argument has no other uses.
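+  // Two use-def patterns are accepted: `sliceOp` extracts directly from
+  // `producerResult` (input fusion), or it extracts from a block argument
+  // whose tied iteration argument is `producerResult` (output fusion).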
+  if (consumerOpOperand->get() != sliceOp.getResult())
+    return failure();
+  if (BlockArgument bbArg = sliceOp.source().dyn_cast<BlockArgument>()) {
+    if (tileLoopNest.hasOtherUses(bbArg, sliceOp) ||
+        tileLoopNest.getTiedIterArg(bbArg)->get() != producerResult)
+      return failure();
+  } else if (sliceOp.source() != producerResult) {
+    return failure();
+  }
+
+  // Compute the slice dimensions tiled by `tileLoopNest`.
+  SmallVector<int64_t> tiledSliceDims =
+      getTiledSliceDims(producerResult, consumerOpOperand, tileLoopNest);
+  if (tiledSliceDims.empty())
+    return failure();
+
+  // Insert `clonedOp` after `sliceOp` since the former may access the slice.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointAfter(sliceOp);
+
+  // Tile the producer operands and clone the producer in place of `sliceOp`.
+  LinalgOp clonedOp = getTiledProducer(b, producerResult, sliceOp,
+                                       tiledSliceDims, tileLoopNest);
+
+  // Cast the `clonedOp` result to bridge type mismatches before
+  // canonicalization.
+  Type consumerOperandType = consumerOpOperand->get().getType();
+  Value newResult = clonedOp->getResult(producerResult.getResultNumber());
+  if (newResult.getType() != consumerOperandType)
+    newResult = b.create<tensor::CastOp>(producerResult.getLoc(),
+                                         consumerOperandType, newResult);
+
+  // Replace the `sliceOp` uses except for the `clonedOp` output uses.
+  sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp);
+  return clonedOp;
+}
+
+FailureOr<TileLoopNest>
+mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
+                                           ArrayRef<int64_t> tileSizes,
+                                           ArrayRef<int64_t> tileInterchange) {
+  assert(tileSizes.size() == tileInterchange.size() &&
+         "expect the number of tile sizes and interchange dims to match");
+
+  // Create an empty tile loop nest.
+  TileLoopNest tileLoopNest(consumerOp);
+
+  // Helper function to further tile the tile loop nest.
+  auto tileConsumer = [&](ArrayRef<int64_t> tileSizes) {
+    // Tile the consumer unless all tile sizes are zero.
+    if (tileSizes.size() != static_cast<size_t>(count(tileSizes, 0))) {
+      LinalgTilingOptions tilingOptions;
+      tilingOptions = tilingOptions
+                          .setInterchange(SmallVector<unsigned>(
+                              tileInterchange.begin(), tileInterchange.end()))
+                          .setTileSizes(tileSizes)
+                          .setLoopType(LinalgTilingLoopType::Loops);
+      LinalgOp tiledOp = tileLoopNest.getTiledOp();
+      Optional<TiledLinalgOp> tiledConsumer =
+          tileLinalgOp(b, tiledOp, tilingOptions);
+      if (!tiledConsumer.hasValue())
+        return failure();
+      // Update the tile loop nest.
+      tileLoopNest.addTileLoops(tiledConsumer.getValue().op,
+                                tiledConsumer.getValue().loops, tileSizes,
+                                tileInterchange);
+      tiledOp->replaceAllUsesWith(tiledConsumer.getValue().tensorResults);
+    }
+    return success();
+  };
+
+  // Helper function to fuse the `opOperands` producers.
+  auto fuseProducers = [&](ArrayRef<OpOperand *> opOperands) {
+    // Try to fuse the consumer operands if the tile loop nest is not empty.
+    if (!tileLoopNest.isEmpty()) {
+      for (OpOperand *opOperand : opOperands) {
+        OpResult producerResult = opOperand->get().dyn_cast<OpResult>();
+        OpOperand *consumerOpOperand =
+            &tileLoopNest.getTiledOp()->getOpOperand(
+                opOperand->getOperandNumber());
+        tensor::ExtractSliceOp sliceOp =
+            consumerOpOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
+        (void)fuseProducer(b, producerResult, sliceOp, consumerOpOperand,
+                           tileLoopNest);
+      }
+    }
+  };
+
+  // Search the number of outer parallel loops to separate them from possible
+  // inner reduction dimensions.
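+  // The outer parallel loops are tiled first so that output producers can be
+  // fused before any reduction tile loop exists; fusing the producer of an
+  // output operand inside a reduction tile loop would re-initialize the
+  // accumulator in every reduction iteration.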
+  auto *it = find_if(tileInterchange, [&](int64_t dim) {
+    assert(dim >= 0 && dim < consumerOp.getNumLoops() &&
+           "expect interchange dims are >=0 and <#loops");
+    Attribute iteratorType = consumerOp.iterator_types().getValue()[dim];
+    return !isParallelIterator(iteratorType);
+  });
+  int64_t split = std::distance(tileInterchange.begin(), it);
+
+  // Tile the outer parallel loops and fuse the output operands.
+  SmallVector<int64_t> outerTileSizes;
+  outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
+  outerTileSizes.append(tileSizes.size() - split, 0);
+  if (failed(tileConsumer(outerTileSizes)))
+    return failure();
+  SmallVector<OpOperand *> outputOperands = consumerOp.getOutputOperands();
+  fuseProducers(outputOperands);
+
+  // Tile the remaining loops and fuse the input operands.
+  SmallVector<int64_t> innerTileSizes;
+  innerTileSizes.append(split, 0);
+  innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
+  if (failed(tileConsumer(innerTileSizes)))
+    return failure();
+  SmallVector<OpOperand *> inputOperands = consumerOp.getInputOperands();
+  fuseProducers(inputOperands);
+  return tileLoopNest;
+}
+
+namespace {
+struct LinalgTileAndFuseTensorOps
+    : public LinalgTileAndFuseTensorOpsBase<LinalgTileAndFuseTensorOps> {
+
+  void notifyFailure(StringRef message) {
+    llvm::errs() << " - LinalgTileAndFuseTensorOps: " << message << "\n";
+    signalPassFailure();
+  }
+
+  void runOnFunction() override {
+    FuncOp funcOp = getFunction();
+
+    // Check the size of `tileInterchange` matches `tileSizes` or is empty.
+    if (tileInterchange.size() != tileSizes.size() && !tileInterchange.empty())
+      return notifyFailure(
+          "expect the number of tile sizes and interchange dims to match");
+
+    // Search the root operation with the largest backward slice.
+    LinalgOp rootOp = nullptr;
+    int64_t numFusionCandidates = -1;
+    funcOp.walk([&](LinalgOp linalgOp) {
+      SetVector<Operation *> backwardSlice;
+      getBackwardSlice(linalgOp, &backwardSlice);
+      int64_t backwardSliceSize = count_if(
+          backwardSlice, [](Operation *op) { return isa<LinalgOp>(op); });
+      if (backwardSliceSize > numFusionCandidates) {
+        rootOp = linalgOp;
+        numFusionCandidates = backwardSliceSize;
+      }
+    });
+    if (!rootOp)
+      return notifyFailure("expect to find a root operation");
+
+    // Copy the available tile sizes and set the remaining ones to zero.
+    SmallVector<int64_t> rootTileSizes(tileSizes.begin(), tileSizes.end());
+    rootTileSizes.resize(rootOp.getNumLoops(), 0);
+
+    // Copy the interchange entries and complete them to a valid permutation.
+    SmallVector<int64_t> rootInterchange;
+    SmallVector<int64_t> dimCounts(rootOp.getNumLoops(), 0);
+    for (int64_t dim : tileInterchange) {
+      if (dim < 0 || dim >= rootOp.getNumLoops())
+        return notifyFailure("expect interchange dims are >=0 and <#loops");
+      if (dimCounts[dim] != 0)
+        return notifyFailure("expect interchange dims to appear only once");
+      dimCounts[dim]++;
+      rootInterchange.push_back(dim);
+      if (rootInterchange.size() >= rootOp.getNumLoops())
+        break;
+    }
+    for (auto en : enumerate(dimCounts)) {
+      if (en.value() == 0)
+        rootInterchange.push_back(en.index());
+    }
+
+    // Tile the root operation and fuse its producers.
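+    // At this point `rootTileSizes` and `rootInterchange` have been completed
+    // to cover all `rootOp` loops, as the matching-size assertion of the entry
+    // point requires.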
+    OpBuilder b(funcOp.getContext());
+    if (failed(tileConsumerAndFuseProducers(b, rootOp, rootTileSizes,
+                                            rootInterchange)))
+      return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly");
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgTileAndFuseTensorOpsPass() {
+  return std::make_unique<LinalgTileAndFuseTensorOps>();
+}
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -0,0 +1,190 @@
+// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=5,4,7 tile-interchange=1,0,2" -cse -split-input-file | FileCheck %s
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 24)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
+
+// CHECK: fuse_input
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+builtin.func @fuse_input(%arg0: tensor<24x12xf32>,
+                         %arg1: tensor<12x25xf32>,
+                         %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32>
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+  // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
+  // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+  // CHECK: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
+
+  // Tile both input operand dimensions.
+  // CHECK: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]])
+  // CHECK: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
+  // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME: %[[IV1]], %[[IV2]]
+  // CHECK-SAME: %[[UB1]], %[[UB2]]
+  // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  // CHECK: %{{.*}} = linalg.matmul ins(%[[T1]]
+  %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
+
+// CHECK: fuse_output
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+builtin.func @fuse_output(%arg0: tensor<24x12xf32>,
+                          %arg1: tensor<12x25xf32>,
+                          %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32>
+
+  // Update the iteration argument of the outermost tile loop.
+  // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+  // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
+  // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]])
+  // CHECK: %[[TS0:.*]] = affine.min #[[MAP1]](%[[IV0]])
+
+  // Tile both output operand dimensions.
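+  // Note: tile-interchange=1,0,2 makes the outermost loop iterate dimension
+  // one, so %[[IV0]] indexes the second tensor dimension and %[[IV1]] the
+  // first.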
+  // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
+  // CHECK-SAME: %[[IV1]], %[[IV0]]
+  // CHECK-SAME: %[[TS1]], %[[TS0]]
+  // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  // CHECK: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
+  // CHECK: %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG5]]
+  %1 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%0 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 25)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)>
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+// CHECK: fuse_reduction
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32>
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x7x25xf32>
+builtin.func @fuse_reduction(%arg0: tensor<24x12xf32>,
+                             %arg1: tensor<12x25xf32>,
+                             %arg2: tensor<24x25xf32>,
+                             %arg3: tensor<12x7x25xf32>) -> tensor<24x25xf32> {
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg3 : tensor<12x7x25xf32>) outs(%arg1 : tensor<12x25xf32>) {
+  ^bb0(%arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg4, %arg5 : f32
+    linalg.yield %2 : f32
+  } -> tensor<12x25xf32>
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+  // CHECK: %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]])
+  // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+  // CHECK: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]])
+  // CHECK: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]])
+  // CHECK: %[[UB0:.*]] = affine.min #[[MAP2]](%[[TS0]], %[[IV0]])
+
+  // Tile only the parallel dimensions but not the reduction dimension.
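+  // The reduction dimension of the generic (size 7) keeps offset 0 and its
+  // full extent in the slices below.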
+  // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
+  // CHECK-SAME: %[[IV2]], 0, %[[IV0]]
+  // CHECK-SAME: %[[UB2]], 7, %[[UB0]]
+  // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]]
+  // CHECK-SAME: %[[IV2]], %[[IV0]]
+  // CHECK-SAME: %[[UB2]], %[[UB0]]
+  // CHECK: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
+  // CHECK: %{{.*}} = linalg.matmul ins(%{{.*}}, %[[T2]]
+  %1 = linalg.matmul ins(%arg0, %0 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d1, d0)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+
+// CHECK: fuse_transposed
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x24xf32>
+builtin.func @fuse_transposed(%arg0: tensor<24x12xf32>,
+                              %arg1: tensor<12x25xf32>,
+                              %arg2: tensor<24x25xf32>,
+                              %arg3: tensor<12x24xf32>) -> tensor<24x25xf32> {
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg3 : tensor<12x24xf32>) outs(%arg0 : tensor<24x12xf32>) {
+  ^bb0(%arg4: f32, %arg5: f32):  // no predecessors
+    %2 = addf %arg4, %arg5 : f32
+    linalg.yield %2 : f32
+  } -> tensor<24x12xf32>
+
+  // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] =
+  // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] =
+  // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] =
+
+  // Swap the input operand slice offsets due to the transposed indexing map.
+  // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]]
+  // CHECK-SAME: %[[IV2]], %[[IV1]]
+  // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME: %[[IV1]], %[[IV2]]
+  // CHECK: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]]
+  // CHECK: %{{.*}} = linalg.matmul ins(%[[T2]]
+  %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  return %1 : tensor<24x25xf32>
+}
+
+// -----
+
+// CHECK: fuse_input_and_output
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32>
+builtin.func @fuse_input_and_output(%arg0: tensor<24x12xf32>,
+                                    %arg1: tensor<12x25xf32>,
+                                    %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> {
+  %c0 = constant 0 : index
+  %c12 = constant 12 : index
+  %c25 = constant 25 : index
+  %c24 = constant 24 : index
+  %c4 = constant 4 : index
+  %cst = constant 0.000000e+00 : f32
+  %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32>
+  %1 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32>
+
+  // Fuse both producers to the appropriate tile loops.
+  // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+  // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]
+  // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG4]]
+  // CHECK-SAME: %[[IV1]], %[[IV0]]
+  // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]])
+  // CHECK: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]]
+  // CHECK: %[[T2:.*]] = tensor.extract_slice %[[ARG0]]
+  // CHECK-SAME: %[[IV1]], %[[IV2]]
+  // CHECK: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]])
+  // CHECK: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[ARG5]]
+  %2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32>
+  return %2 : tensor<24x25xf32>
+}
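+
+// The test above exercises both fusion paths at once: the output fill fuses
+// into the parallel tile loops, while the input fill fuses into the innermost
+// loop since its slice depends on the k-loop induction variable.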