diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -77,6 +77,9 @@ /// work on primitive types, if possible. std::unique_ptr createLinalgDetensorizePass(); +/// Create a pass to tile a LinalgOp and fuse its producers. +std::unique_ptr> createLinalgTileAndFuseTensorOpsPass(); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -243,4 +243,18 @@ }]; } +def LinalgTileAndFuseTensorOps + : FunctionPass<"linalg-tile-and-fuse-tensor-ops"> { + let summary = "Tile a LinalgOp and fuse its producers."; + let constructor = "mlir::createLinalgTileAndFuseTensorOpsPass()"; + let options = [ + ListOption<"tileSizes", "tile-sizes", "int64_t", "Tile sizes", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">, + ListOption<"tileInterchange", "tile-interchange", "int64_t", + "Tile loop interchange", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">, + ]; + let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"]; +} + #endif // MLIR_DIALECT_LINALG_PASSES diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -172,6 +172,64 @@ OpResult producerOpResult, OpOperand &consumerOpOperand); +//===----------------------------------------------------------------------===// +// Fusion on tensor utilities +//===----------------------------------------------------------------------===// + +/// A struct to manage the tile loop nest 
specific information. +class TileLoopNest { +public: + TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {} + + /// Tile the root operation using the given `tileSizes` and `tileInterchange`. + LogicalResult tileRootOp(OpBuilder &b, ArrayRef tileSizes, + ArrayRef tileInterchange); + + /// Fuse the producer of `rootOpOperand` into the tile loop nest. Returns the + /// fused producer or fails if fusion is not possible. + // TODO: add replace uses callback to support passes and patterns. + FailureOr fuseProducer(OpBuilder &b, OpOperand *rootOpOperand); + + /// Returns the tiled root operation. + LinalgOp getRootOp() { return rootOp; } + +private: + /// Returns true if the tile loop nest has no tile loops. + bool isEmpty(); + + /// Returns true if the tile loop nest invariants are satisfied: + /// - The number of tile loop operations and dimensions match. + /// - The innermost tile loop is the parent of `rootOp`. + /// - The tile loops are directly nested. + // TODO: relax to support additional control flow, e.g., IfOp. + bool isValid(); + + /// Searches the block arguments tied to a block argument `bbArg` of the + /// innermost tile loop. Returns the block arguments from outermost to + /// innermost or an empty vector if none are found. + SmallVector getTiedBBArgs(BlockArgument bbArg); + + /// Returns the iteration argument of the outermost tile loop mapped to a + /// block argument `bbArg` of the innermost tile loop. + OpOperand *getTiedIterArg(BlockArgument bbArg); + + /// Returns true if `bbArg` has other uses than `sliceOp` and its + /// dependencies. Only if there are no other uses, the producer output + /// iteration argument may be reused to pass the producer result after fusion. + bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp); + + LinalgOp rootOp; + SmallVector loopOps; + SmallVector loopDims; +}; + +/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the +/// `tileSizes` and `tileInterchange` parameters to control the tiling.
+FailureOr +tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp, + ArrayRef tileSizes, + ArrayRef tileInterchange); + //===----------------------------------------------------------------------===// // Distribution utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ ElementwiseOpFusion.cpp ElementwiseToLinalg.cpp Fusion.cpp + FusionOnTensors.cpp Generalization.cpp Hoisting.cpp InlineScalarOperands.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -0,0 +1,481 @@ +//===- FusionOnTensors.cpp - Implementation of linalg Fusion --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements linalg fusion on tensors +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" +#include "mlir/Support/LLVM.h" + +using namespace mlir; +using namespace linalg; + +//===----------------------------------------------------------------------===// +// StructuredOp specific helpers. +//===----------------------------------------------------------------------===// + +/// Relate the producer to the consumer loop iterations that access the same +/// producer result element: +/// consumerToProducerLoops = +/// inverse(producerIndexingMap).compose(consumerIndexingMap). +/// Return `consumerToProducerLoops` or none if the inversion fails. +static Optional +getConsumerToProducerLoopsMap(AffineMap producerIndexingMap, + AffineMap consumerIndexingMap) { + assert(consumerIndexingMap.getNumResults() == + producerIndexingMap.getNumResults() && + "expect the number of indexing map results to match"); + // Ensure the producer indexing map is a projected permutation. 
+ if (!producerIndexingMap.isProjectedPermutation()) + return None; + AffineMap inverseIndexingMap = + inverseAndBroadcastProjectedPermuation(producerIndexingMap); + return inverseIndexingMap.compose(consumerIndexingMap); +} + +/// Returns the producer result slice dimensions tiled by the tile loop nest or +/// an empty vector if `getConsumerToProducerLoopsMap` returns none. +// TODO: replace by Fourier-Motzkin and/or compute starting from consumer. +SmallVector getTiledSliceDims(OpResult producerResult, + OpOperand *consumerOperand, + ArrayRef tiledLoopDims) { + LinalgOp consumerOp = consumerOperand->getOwner(); + LinalgOp producerOp = producerResult.getOwner(); + OpOperand *opOperand = + producerOp.getOutputOperand(producerResult.getResultNumber()); + + // Compute the `consumerToProducerLoopsMap` and exit if the computation fails. + AffineMap producerIndexingMap = producerOp.getTiedIndexingMap(opOperand); + Optional consumerToProducerLoopsMap = + getConsumerToProducerLoopsMap( + producerIndexingMap, consumerOp.getTiedIndexingMap(consumerOperand)); + if (!consumerToProducerLoopsMap.hasValue()) + return {}; + + // Compute the set of tiled producer loops. + DenseSet tiledProducerLoops; + for (auto en : enumerate(consumerToProducerLoopsMap->getResults())) { + for (int64_t dim : tiledLoopDims) { + if (en.value().isFunctionOfDim(dim)) + tiledProducerLoops.insert(en.index()); + } + } + + // Compute the slice dimensions for the tiled producer loops. + SmallVector tiledSliceDims; + for (auto en : enumerate(producerIndexingMap.getResults())) { + auto dimExpr = en.value().dyn_cast(); + if (dimExpr && tiledProducerLoops.count(dimExpr.getPosition()) != 0) + tiledSliceDims.push_back(en.index()); + } + return tiledSliceDims; +} + +/// Returns the producer fused in place of `sliceOp`. Tile the producer operands +/// along the `tiledSliceDims` and clone the producer. Consider the case of +/// fusion of an output tensor: +/// ``` +/// %1 = producer ins(...) 
outs(%0) +/// %2 = consumer ins(...) outs(%1) +/// ``` +/// When consumer is tiled, %1 appears in the loop iter_args: +/// ``` +/// %1 = producer ins(...) outs(%0) +/// %2 = scf.for ... iter_args(%1) .. (%bbarg) { +/// %t1 = tensor.extract_slice %bbarg[..] +/// %t2 = consumer ins(...) outs(%t1) +/// %r = tensor.insert_slice %t2, %bbarg[...] +/// } +/// ``` +/// Fusing %1 into the loop requires updating iter_args(%1) to iter_args(%0): +/// ``` +/// %2 = scf.for ... iter_args(%0) .. (%bbarg) { +/// %t0 = tensor.extract_slice %bbarg[..] +/// %t1 = producer ins(...) outs(%t0) +/// %t2 = consumer ins(...) outs(%t1) +/// %r = tensor.insert_slice %t2, %bbarg[...] +/// } +/// ``` +/// This transformation is only valid if %bbarg is exclusively used by the +/// output ExtractSliceOp / InsertSliceOp pair, which is checked by the +/// `fuseProducer` method. +/// TODO: instead of check and failure, insert new iter_args each time a +/// producer is fused into a consumer and fold away unused iter_args. +static LinalgOp getTiledProducer(OpBuilder &b, OpResult producerResult, + tensor::ExtractSliceOp sliceOp, + ArrayRef tiledSliceDims, + OpOperand *iterArg) { + // Clone the producer after `sliceOp` since the slice may be reused to pass in + // the producer result. + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointAfter(sliceOp); + + // Get the producer. + LinalgOp producerOp = producerResult.getOwner(); + Location loc = producerOp.getLoc(); + + // Obtain the `producerOp` loop bounds and the `sliceOp` ranges. + SmallVector producerLoopBounds; + transform(producerOp.createLoopRanges(b, loc), + std::back_inserter(producerLoopBounds), + [](Range range) { return range.size; }); + SmallVector sliceOpRanges = sliceOp.getOrCreateRanges(b, loc); + + // Get the producer result indexing map. + AffineMap producerIndexingMap = producerOp.getTiedIndexingMap( + producerOp.getOutputOperand(producerResult.getResultNumber())); + + // Tile the producer operands given the `sliceOp` ranges. 
Iterate the + // `tiledSliceDims` and store the tile offset and size for the tiled slice + // dimension. Assumes the mapping from slice dimensions to producer loops is a + // permutation. + auto zero = b.create(loc, 0); + SmallVector tileIvs(producerOp.getNumLoops(), nullptr); + SmallVector tileSizes(producerOp.getNumLoops(), zero); + SmallVector allIvs(producerOp.getNumLoops(), nullptr); + for (int64_t tiledSliceDim : tiledSliceDims) { + AffineExpr result = producerIndexingMap.getResults()[tiledSliceDim]; + assert(result.isa() && + "expect producer indexing map is a projected permutation"); + int64_t tiledProducerLoop = result.cast().getPosition(); + tileIvs[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].offset; + tileSizes[tiledProducerLoop] = sliceOpRanges[tiledSliceDim].size; + allIvs[tiledProducerLoop] = tileIvs[tiledProducerLoop]; + } + erase_value(tileIvs, nullptr); + SmallVector tiledOperands = producerOp.getInputAndOutputOperands(); + tiledOperands = makeTiledShapes(b, loc, producerOp, tiledOperands, tileIvs, + tileSizes, producerLoopBounds); + + // Output fusion has to update the iteration arguments of the tile loop nest. + // In particular, the iteration argument of the outermost tile loop needs to + // be set to the producer output instead of the producer result and `clonedOp` + // shall use the existing `sliceOp` result instead of the tiled producer + // output operand. + if (iterArg) { + OpOperand *outputOperand = + producerOp.getOutputOperand(producerResult.getResultNumber()); + iterArg->set(outputOperand->get()); + tiledOperands[outputOperand->getOperandNumber()] = sliceOp.getResult(); + } + + // Clone the producer using the tiled producer operands. 
+ TypeRange resultTypes = ValueRange(tiledOperands) + .take_back(producerOp.getNumOutputs()) + .getTypes(); + LinalgOp clonedOp = producerOp.clone(b, loc, resultTypes, tiledOperands); + + return clonedOp; +} + +//===----------------------------------------------------------------------===// +// TileLoopNest specific helpers. +//===----------------------------------------------------------------------===// + +bool TileLoopNest::isEmpty() { return loopOps.empty(); } + +bool TileLoopNest::isValid() { + // Check if the number of `tileLoopOps` and `tileLoopDims` match. + if (loopOps.size() != loopDims.size()) + return false; + + // Check if the innermost tile loop is the parent of `tiledOp`. + if (rootOp->getParentOp() != loopOps.back()) + return false; + + // Check if the tile loops are directly nested. + return std::adjacent_find(loopOps.begin(), loopOps.end(), + [](Operation *op1, Operation *op2) { + return op1 != op2->getParentOp(); + }) == loopOps.end(); +} + +SmallVector TileLoopNest::getTiedBBArgs(BlockArgument bbArg) { + assert(bbArg && "expect the block argument to be non-zero"); + SmallVector bbArgs; + + // Search all tile loop block arguments from inner to outer. + for (auto tileLoop : reverse(loopOps)) { + if (bbArg.getOwner()->getParentOp() != tileLoop) + return {}; + bbArgs.push_back(bbArg); + OpOperand *iterArg = &tileLoop.getOpOperandForRegionIterArg(bbArg); + bbArg = iterArg->get().dyn_cast(); + } + + // Reverse the block arguments to order them from outer to inner. + return {bbArgs.rbegin(), bbArgs.rend()}; +} + +OpOperand *TileLoopNest::getTiedIterArg(BlockArgument bbArg) { + // Search all block arguments and return the matching iteration argument. 
+ SmallVector bbArgs = getTiedBBArgs(bbArg); + if (bbArgs.size() != loopOps.size()) + return nullptr; + return &loopOps.front().getOpOperandForRegionIterArg(bbArgs.front()); +} + +bool TileLoopNest::hasOtherUses(BlockArgument bbArg, + tensor::ExtractSliceOp sliceOp) { + // Check the innermost block argument is either used by the ExtractSliceOp + // `sliceOp`, the matching InsertSliceOp, or by a DimOp. Handle other uses + // conservatively. + for (Operation *op : bbArg.getUsers()) { + if (!isa(op)) + return false; + if (auto extractSliceOp = dyn_cast(op)) { + if (extractSliceOp != sliceOp) + return false; + } + if (auto insertSliceOp = dyn_cast(op)) { + SetVector backwardSlice; + getBackwardSlice(insertSliceOp.source(), &backwardSlice, + [](Operation *op) { + return isa(op); + }); + if (backwardSlice.empty() || backwardSlice.front() != sliceOp) + return false; + } + } + + // Check the block arguments, except for the innermost one, have one use. + SmallVector bbArgs = getTiedBBArgs(bbArg); + return !all_of(bbArgs, [&](BlockArgument bbArg) { + return bbArg.hasOneUse() || bbArg == bbArgs.back(); + }); +} + +LogicalResult TileLoopNest::tileRootOp(OpBuilder &b, + ArrayRef tileSizes, + ArrayRef tileInterchange) { + // Exit if all tile sizes are zero. + if (tileSizes.size() == static_cast(count(tileSizes, 0))) + return success(); + + // Tile the root operation. + LinalgTilingOptions tilingOptions; + tilingOptions = tilingOptions + .setInterchange(SmallVector( + tileInterchange.begin(), tileInterchange.end())) + .setTileSizes(tileSizes) + .setLoopType(LinalgTilingLoopType::Loops); + Optional tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions); + + // Replace all uses of the root operation. + if (!tiledRootOp.hasValue()) + return failure(); + rootOp->replaceAllUsesWith(tiledRootOp->tensorResults); + + // Update the root operation and append the loops and tile loop dimensions. 
+ rootOp = tiledRootOp->op; + loopOps.append(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); + for (auto en : enumerate(tileSizes)) { + // Copy only the tiled loop dimensions with non-zero tile size. + if (en.value() == 0) + continue; + loopDims.push_back(tileInterchange[en.index()]); + } + assert(isValid() && "expect tile loop nest to be valid after tiling"); + + return success(); +} + +FailureOr TileLoopNest::fuseProducer(OpBuilder &b, + OpOperand *rootOpOperand) { + // Check the tile loop nest is non-empty and satisfies all invariants. + if (isEmpty() || !isValid()) + return failure(); + + // Check `rootOpOperand` is defined by an ExtractSliceOp. + auto sliceOp = rootOpOperand->get().getDefiningOp(); + if (!sliceOp) + return failure(); + + // Check `tileLoopNest` tiles `sliceOp` and `rootOpOperand`. + if (sliceOp->getParentOp() != rootOp->getParentOp() || + rootOpOperand->getOwner() != rootOp) + return failure(); + + // Check if the producer is a LinalgOp possibly passed by iteration argument. + OpOperand *iterArg = nullptr; + auto producerResult = sliceOp.source().dyn_cast(); + if (auto bbArg = sliceOp.source().dyn_cast()) { + iterArg = getTiedIterArg(bbArg); + // Check the iteration argument may be used to pass in the producer output. + if (!iterArg || hasOtherUses(bbArg, sliceOp)) + return failure(); + producerResult = iterArg->get().dyn_cast(); + } + if (!producerResult || !isa(producerResult.getOwner())) + return failure(); + + // TODO: support producers that have index semantics. + if (cast(producerResult.getOwner()).hasIndexSemantics()) + return failure(); + + // Compute the slice dimensions tiled by `tileLoopNest`. + SmallVector tiledSliceDims = + getTiledSliceDims(producerResult, rootOpOperand, loopDims); + if (tiledSliceDims.empty()) + return failure(); + + // Tile the producer operands and clone the producer in place of `sliceOp`. 
+ LinalgOp clonedOp = + getTiledProducer(b, producerResult, sliceOp, tiledSliceDims, iterArg); + + // Cast the `clonedOp` result to gap type mismatches before canonicalization. + Type consumerOperandType = rootOpOperand->get().getType(); + Value newResult = clonedOp->getResult(producerResult.getResultNumber()); + if (newResult.getType() != consumerOperandType) { + OpBuilder::InsertionGuard guard(b); + b.setInsertionPointAfter(clonedOp); + newResult = b.create(producerResult.getLoc(), + consumerOperandType, newResult); + } + + // Replace the `sliceOp` uses except for the `clonedOp` output uses. + sliceOp.getResult().replaceAllUsesExcept(newResult, clonedOp); + return clonedOp; +} + +//===----------------------------------------------------------------------===// +// Tile and fuse entry-points. +//===----------------------------------------------------------------------===// + +FailureOr +mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp, + ArrayRef tileSizes, + ArrayRef tileInterchange) { + assert(tileSizes.size() == tileInterchange.size() && + "expect the number of tile sizes and interchange dims to match"); + + // Create an empty tile loop nest. + TileLoopNest tileLoopNest(consumerOp); + + // Search the number of outer parallel loops to separate them from possible + // inner reduction dimensions. + SmallVector iterTypes = + llvm::to_vector<6>(consumerOp.iterator_types().getAsRange()); + applyPermutationToVector( + iterTypes, + SmallVector(tileInterchange.begin(), tileInterchange.end())); + auto *it = find_if(iterTypes, [&](StringAttr iterType) { + return !isParallelIterator(iterType); + }); + int64_t split = std::distance(iterTypes.begin(), it); + + // Tile the outer parallel loops and fuse the output operands. 
+ SmallVector outerTileSizes; + outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split); + outerTileSizes.append(tileSizes.size() - split, 0); + if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange))) + return failure(); + for (OpOperand *opOperand : tileLoopNest.getRootOp().getOutputOperands()) + (void)tileLoopNest.fuseProducer(b, opOperand); + + // Tile the remaining loops and fuse the input operands. + SmallVector innerTileSizes; + innerTileSizes.append(split, 0); + innerTileSizes.append(tileSizes.begin() + split, tileSizes.end()); + if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange))) + return failure(); + SmallVector inputOperands = + tileLoopNest.getRootOp().getInputOperands(); + for (OpOperand *opOperand : inputOperands) + (void)tileLoopNest.fuseProducer(b, opOperand); + + return tileLoopNest; +} + +namespace { +struct LinalgTileAndFuseTensorOps + : public LinalgTileAndFuseTensorOpsBase { + + void notifyFailure(StringRef message) { + llvm::errs() << " - LinalgTileAndFuseTensorOps: " << message << "\n"; + signalPassFailure(); + } + + void runOnFunction() override { + FuncOp funcOp = getFunction(); + OpBuilder b(funcOp.getContext()); + + // Heuristic to find a good operation to tile and start fusion. Walk all + // operations and select the one with the maximal backward slice of fusion + // candidates. + LinalgOp rootOp = nullptr; + int64_t numFusionCandidates = -1; + funcOp.walk([&](LinalgOp linalgOp) { + SetVector backwardSlice; + getBackwardSlice(linalgOp, &backwardSlice); + int64_t backwardSliceSize = count_if( + backwardSlice, [](Operation *op) { return isa(op); }); + if (backwardSliceSize > numFusionCandidates) { + rootOp = linalgOp; + numFusionCandidates = backwardSliceSize; + } + }); + if (!rootOp) + return notifyFailure("expect to find a root operation"); + + // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
+ if (tileSizes.size() < rootOp.getNumLoops()) + return notifyFailure("expect #tile sizes >= #loops"); + + // Check `tileInterchange` contains no entries or as many as `tileSizes`. + if (!tileInterchange.empty() && + tileInterchange.size() != tileSizes.size()) { + return notifyFailure( + "expect the number of tile sizes and interchange dims to match"); + } + + // Copy the `tileSizes` and `tileInterchange` prefixes needed to tile + // `rootOp` or use the identity interchange if `tileInterchange` is empty. + SmallVector rootTileSizes( + tileSizes.begin(), tileSizes.begin() + rootOp.getNumLoops()); + SmallVector rootInterchange = + tileInterchange.empty() + ? llvm::to_vector<6>(llvm::seq(0, rootOp.getNumLoops())) + : SmallVector(tileInterchange.begin(), + tileInterchange.begin() + + rootOp.getNumLoops()); + + // As a tiling can only tile a loop dimension once, `rootInterchange` has to + // be a permutation of the `rootOp` loop dimensions. + SmallVector rootInterchangeExprs; + transform(rootInterchange, std::back_inserter(rootInterchangeExprs), + [&](int64_t dim) { return b.getAffineDimExpr(dim); }); + AffineMap rootInterchangeMap = AffineMap::get( + rootOp.getNumLoops(), 0, rootInterchangeExprs, funcOp.getContext()); + if (!rootInterchangeMap.isPermutation()) + return notifyFailure( + "expect the tile interchange permutes the root loops"); + + // Tile `rootOp` and fuse its producers.
+ if (failed(tileConsumerAndFuseProducers(b, rootOp, rootTileSizes, + rootInterchange))) + return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly"); + } +}; +} // namespace + +std::unique_ptr> +mlir::createLinalgTileAndFuseTensorOpsPass() { + return std::make_unique(); +} diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir @@ -0,0 +1,190 @@ +// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=5,4,7 tile-interchange=1,0,2" -cse -split-input-file | FileCheck %s + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 24)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)> + +// CHECK: fuse_input +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32> +builtin.func @fuse_input(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { + %c0 = constant 0 : index + %c12 = constant 12 : index + %c25 = constant 25 : index + %c24 = constant 24 : index + %c4 = constant 4 : index + %cst = constant 0.000000e+00 : f32 + %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32> + + // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = + // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = + // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]]) + // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = + // CHECK: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]]) + + // Tile both input operand dimensions. 
+ // CHECK: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]]) + // CHECK: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]]) + // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]] + // CHECK-SAME: %[[IV1]], %[[IV2]] + // CHECK-SAME: %[[UB1]], %[[UB2]] + // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // CHECK: %{{.*}} = linalg.matmul ins(%[[T1]] + %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + return %1 : tensor<24x25xf32> +} + +// ----- + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 25)> + +// CHECK: fuse_output +// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32> +builtin.func @fuse_output(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { + %c0 = constant 0 : index + %c12 = constant 12 : index + %c25 = constant 25 : index + %c24 = constant 24 : index + %c4 = constant 4 : index + %cst = constant 0.000000e+00 : f32 + %0 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32> + + // Update the iteration argument of the outermost tile loop. + // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] + // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]] + // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]]) + // CHECK: %[[TS0:.*]] = affine.min #[[MAP1]](%[[IV0]]) + + // Tile the both output operand dimensions. 
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG4]] + // CHECK-SAME: %[[IV1]], %[[IV0]] + // CHECK-SAME: %[[TS1]], %[[TS0]] + // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // CHECK: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]] + // CHECK: %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG5]] + %1 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%0 : tensor<24x25xf32>) -> tensor<24x25xf32> + return %1 : tensor<24x25xf32> +} + +// ----- + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 25)> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)> +#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> + +// CHECK: fuse_reduction +// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32> +// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x7x25xf32> +builtin.func @fuse_reduction(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>, + %arg3: tensor<12x7x25xf32>) -> tensor<24x25xf32> { + %c0 = constant 0 : index + %c12 = constant 12 : index + %c25 = constant 25 : index + %c24 = constant 24 : index + %c4 = constant 4 : index + %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg3 : tensor<12x7x25xf32>) outs(%arg1 : tensor<12x25xf32>) { + ^bb0(%arg4: f32, %arg5: f32): // no predecessors + %2 = addf %arg4, %arg5 : f32 + linalg.yield %2 : f32 + } -> tensor<12x25xf32> + + // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = + // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = + // CHECK: %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]]) + // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = + // CHECK: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]]) + // CHECK: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]]) + // CHECK: %[[UB0:.*]] = affine.min #[[MAP2]](%[[TS0]], %[[IV0]]) + 
+ // Tile only the parallel dimensions but not the reduction dimension. + // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]] + // CHECK-SAME: %[[IV2]], 0, %[[IV0]] + // CHECK-SAME: %[[UB2]], 7, %[[UB0]] + // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]] + // CHECK-SAME: %[[IV2]], %[[IV0]] + // CHECK-SAME: %[[UB2]], %[[UB0]] + // CHECK: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]] + // CHECK: %{{.*}} = linalg.matmul ins(%{{.*}}, %[[T2]] + %1 = linalg.matmul ins(%arg0, %0 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + return %1 : tensor<24x25xf32> +} + +// ----- + +#map0 = affine_map<(d0, d1) -> (d1, d0)> +#map1 = affine_map<(d0, d1) -> (d0, d1)> + +// CHECK: fuse_transposed +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32> +// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x24xf32> +builtin.func @fuse_transposed(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>, + %arg3: tensor<12x24xf32>) -> tensor<24x25xf32> { + %c0 = constant 0 : index + %c12 = constant 12 : index + %c25 = constant 25 : index + %c24 = constant 24 : index + %c4 = constant 4 : index + %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%arg3 : tensor<12x24xf32>) outs(%arg0 : tensor<24x12xf32>) { + ^bb0(%arg4: f32, %arg5: f32): // no predecessors + %2 = addf %arg4, %arg5 : f32 + linalg.yield %2 : f32 + } -> tensor<24x12xf32> + + // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = + // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = + // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = + + // Swap the input operand slice offsets due to the transposed indexing map. 
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]] + // CHECK-SAME: %[[IV2]], %[[IV1]] + // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]] + // CHECK-SAME: %[[IV1]], %[[IV2]] + // CHECK: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]] + // CHECK: %{{.*}} = linalg.matmul ins(%[[T2]] + %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> + return %1 : tensor<24x25xf32> +} + +// ----- + +// CHECK: fuse_input_and_output +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32> +// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32> +builtin.func @fuse_input_and_output(%arg0: tensor<24x12xf32>, + %arg1: tensor<12x25xf32>, + %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { + %c0 = constant 0 : index + %c12 = constant 12 : index + %c25 = constant 25 : index + %c24 = constant 24 : index + %c4 = constant 4 : index + %cst = constant 0.000000e+00 : f32 + %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32> + %1 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32> + + // Fuse both producers to the appropriate tile loops. + // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] + // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]] + // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG4]] + // CHECK-SAME: %[[IV1]], %[[IV0]] + // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // CHECK: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]] + // CHECK: %[[T2:.*]] = tensor.extract_slice %[[ARG0]] + // CHECK-SAME: %[[IV1]], %[[IV2]] + // CHECK: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]]) + // CHECK: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[ARG5]] + %2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32> + return %2 : tensor<24x25xf32> +}