diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h --- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h +++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h @@ -9,6 +9,7 @@ #ifndef MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ #define MLIR_DIALECT_LINALG_ANALYSIS_DEPENDENCEANALYSIS_H_ +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/OpDefinition.h" @@ -67,7 +68,7 @@ // Builds a linalg dependence graph for the ops of type LinalgOp under `f`. static LinalgDependenceGraph buildDependenceGraph(Aliases &aliases, FuncOp f); - LinalgDependenceGraph(Aliases &aliases, ArrayRef ops); + LinalgDependenceGraph(Aliases &aliases, ArrayRef ops); /// Returns the X such that op -> X is a dependence of type dt. dependence_range getDependencesFrom(Operation *src, DependenceType dt) const; @@ -168,7 +169,7 @@ ArrayRef types) const; Aliases &aliases; - SmallVector linalgOps; + SmallVector linalgOps; DenseMap linalgOpPositions; }; } // namespace linalg diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp --- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp +++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp @@ -86,21 +86,21 @@ LinalgDependenceGraph LinalgDependenceGraph::buildDependenceGraph(Aliases &aliases, FuncOp f) { - SmallVector linalgOps; + SmallVector linalgOps; f.walk([&](LinalgOp op) { linalgOps.push_back(op); }); return LinalgDependenceGraph(aliases, linalgOps); } LinalgDependenceGraph::LinalgDependenceGraph(Aliases &aliases, - ArrayRef ops) + ArrayRef ops) : aliases(aliases), linalgOps(ops.begin(), ops.end()) { for (auto en : llvm::enumerate(linalgOps)) { - assert(isa(en.value()) && "Expected value for LinalgOp"); - linalgOpPositions.insert(std::make_pair(en.value(), en.index())); + linalgOpPositions.insert( + std::make_pair(en.value().getOperation(), en.index())); } for (unsigned i = 0, e = ops.size(); i < e; ++i) { for (unsigned j = i + 1; j < e; ++j) { - addDependencesBetween(cast(ops[i]), cast(ops[j])); + addDependencesBetween(ops[i], ops[j]); } } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -28,6 +28,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include #include #define DEBUG_TYPE "linalg-fusion" @@ -190,48 +191,15 @@ llvm_unreachable("Expect to be able to extract a shape defining loop range"); } -/// Fuses the producer of `producerIdx` into the loop immediately enclosing -/// `consumer`. This is achieved by "recomputing" the `producer` at the time it -/// is needed just before the `consumer. -/// -/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are -/// 2 cases: -/// 1. Buffer case: `producerIdx` is the index of the buffer in -/// `producer.getOutputBuffers()`. -/// 2. Tensor case: `producerIdx` is the index of the tensor in -/// `producer.getResults()`. -static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, - LinalgOp consumer, unsigned consumerIdx) { - Operation *shapeProducingOp = - consumer.getShapedOperand(consumerIdx).getDefiningOp(); - assert((isa(shapeProducingOp) || - isa(shapeProducingOp)) && - "SubviewOp or SubTensorOp expected"); - - // loopToOperandRangesMaps are permutations-only by construction: - // we can always identify a data dimension with a (at least one) loop - // dimension. - // TODO: extend this with range inference. - AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); - LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx - << ", producer map: " << producerMap << "\n"); +static LinalgOp fuse(OpBuilder &b, LinalgOp producer, + const DenseMap &fusedLoopsAndRanges) { unsigned nPar = producer.getNumParallelLoops(); unsigned nRed = producer.getNumReductionLoops(); unsigned nWin = producer.getNumWindowLoops(); SmallVector loopRanges(nPar + nRed + nWin); - - // Iterate over dimensions identified by the producer map for `producerIdx`. - // This defines a subset of the loop ranges that we need to complete later. - auto loc = consumer.getLoc(); - for (auto en : llvm::enumerate(producerMap.getResults())) { - unsigned posInProducerLoop = en.value().cast().getPosition(); - loopRanges[posInProducerLoop] = - isa(shapeProducingOp) - ? cast(shapeProducingOp) - .getOrCreateRanges(b, loc)[en.index()] - : cast(shapeProducingOp) - .getOrCreateRanges(b, loc)[en.index()]; + for (auto fusedLoops : fusedLoopsAndRanges) { + loopRanges[fusedLoops.first] = fusedLoops.second; } // Iterate over all dimensions. For the dimensions not identified by the @@ -250,7 +218,45 @@ } } - return cloneWithLoopRanges(b, loc, producer, loopRanges); + return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges); +} + +static Range getRangeFromOperandShape(OpBuilder &b, Location loc, + Value shapedOperand, unsigned dim) { + Operation *shapeProducingOp = shapedOperand.getDefiningOp(); + if (auto subViewOp = dyn_cast(shapeProducingOp)) { + return subViewOp.getOrCreateRanges(b, loc)[dim]; + } + if (auto subTensorOp = dyn_cast(shapeProducingOp)) { + return subTensorOp.getOrCreateRanges(b, loc)[dim]; + } + llvm_unreachable("SubviewOp or SubTensorOp expected"); +} + +/// Fuses the producer of `producerIdx` into the loop immediately enclosing +/// `consumer`. This is achieved by "recomputing" the `producer` at the time it +/// is needed just before the `consumer. +/// +/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are +/// 2 cases: +/// 1. Buffer case: `producerIdx` is the index of the buffer in +/// `producer.getOutputBuffers()`. +/// 2. Tensor case: `producerIdx` is the index of the tensor in +/// `producer.getResults()`. +static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, + LinalgOp consumer, unsigned consumerIdx) { + AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); + LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx + << ", producer map: " << producerMap << "\n"); + DenseMap fusedLoopsAndRanges; + Location loc = consumer.getLoc(); + Value shapedOperand = consumer.getShapedOperand(consumerIdx); + for (auto en : llvm::enumerate(producerMap.getResults())) { + unsigned posInProducerLoop = en.value().cast().getPosition(); + fusedLoopsAndRanges[posInProducerLoop] = + getRangeFromOperandShape(b, loc, shapedOperand, en.index()); + } + return fuse(b, producer, fusedLoopsAndRanges); } // Encode structural fusion safety preconditions. @@ -600,8 +606,9 @@ /// submap with only parallel loops = affine_map<(i, j) -> (j)> /// Fused dimensions : j static std::set -collectTileAndFuseLoops(LinalgOp op, +collectTileAndFuseLoops(ArrayRef ops, const FusableOpDependencesTy &fusableDependences) { + assert(!ops.empty()); auto getNumOuterParallelLoops = [](LinalgOp linalgOp) { return linalgOp.iterator_types() .getValue() @@ -612,79 +619,73 @@ .size(); }; - LLVM_DEBUG({ - llvm::dbgs() << "Op : "; - op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); - llvm::dbgs() << "\n"; - }); - - size_t numOuterParallelLoops = getNumOuterParallelLoops(op); - for (auto dependence : fusableDependences) { - linalg::LinalgOp producer = cast(dependence.first); + size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back()); + for (auto op : ops.drop_back()) { numOuterParallelLoops = - std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer)); + std::min(numOuterParallelLoops, getNumOuterParallelLoops(op)); } std::set fusableLoops; auto range = llvm::seq(0, numOuterParallelLoops); fusableLoops.insert(range.begin(), range.end()); - for (auto dependence : fusableDependences) { - LLVM_DEBUG({ - llvm::dbgs() << "\t fusable :"; - for (unsigned i : fusableLoops) - llvm::dbgs() << " " << i; - llvm::dbgs() << "\n"; - }); - linalg::LinalgOp producer = cast(dependence.first); - - assert(!dependence.second.empty() && - "unexpected producer but not dependences"); - AffineMap producerIndexingMap = producer.getIndexingMap( - dependence.second.front().dependentOpView.operandNum); - AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( - producer.iterator_types().getValue(), producerIndexingMap); - if (!prunedProducerIndexingMap.isPermutation()) - return {}; - - AffineMap consumerIndexingMap = - op.getIndexingMap(dependence.second.front().indexingOpView.operandNum); - if (consumerIndexingMap.getNumResults() != - prunedProducerIndexingMap.getNumResults()) - return {}; - - LLVM_DEBUG({ - llvm::dbgs() << "\t producerMap : "; - producerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << " pruned : "; - prunedProducerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - llvm::dbgs() << "\t consumerMap : "; - consumerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - }); - - AffineMap invProducerIndexMap = - inversePermutation(prunedProducerIndexingMap); - if (!invProducerIndexMap) - return {}; - - AffineMap consumerLoopToProducerLoop = - invProducerIndexMap.compose(consumerIndexingMap); - std::set candidates; - for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) { - unsigned position = expr.cast().getPosition(); - if (fusableLoops.count(position)) - candidates.insert(position); + + for (auto op : reverse(ops)) { + for (auto dependence : fusableDependences.lookup(op)) { + LLVM_DEBUG({ + llvm::dbgs() << "\t fusable :"; + for (unsigned i : fusableLoops) + llvm::dbgs() << " " << i; + llvm::dbgs() << "\n"; + }); + + AffineMap producerIndexingMap = + op.getIndexingMap(dependence.dependentOpView.operandNum); + AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( + op.iterator_types().getValue(), producerIndexingMap); + if (!prunedProducerIndexingMap.isPermutation()) + return {}; + + LinalgOp consumer = cast(dependence.indexingOpView.op); + AffineMap consumerIndexingMap = + consumer.getIndexingMap(dependence.indexingOpView.operandNum); + if (consumerIndexingMap.getNumResults() != + prunedProducerIndexingMap.getNumResults()) + return {}; + + LLVM_DEBUG({ + llvm::dbgs() << "\t producerMap : "; + producerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << " pruned : "; + prunedProducerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + llvm::dbgs() << "\t consumerMap : "; + consumerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + AffineMap invProducerIndexMap = + inversePermutation(prunedProducerIndexingMap); + if (!invProducerIndexMap) + return {}; + + AffineMap consumerLoopToProducerLoop = + invProducerIndexMap.compose(consumerIndexingMap); + std::set candidates; + for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) { + unsigned position = expr.cast().getPosition(); + if (fusableLoops.count(position)) + candidates.insert(position); + } + LLVM_DEBUG({ + llvm::dbgs() << "\t candidates :"; + for (unsigned i : candidates) + llvm::dbgs() << " " << i; + llvm::dbgs() << "\n"; + }); + if (candidates.empty()) + return {}; + std::swap(candidates, fusableLoops); } - LLVM_DEBUG({ - llvm::dbgs() << "\t candidates :"; - for (unsigned i : candidates) - llvm::dbgs() << " " << i; - llvm::dbgs() << "\n"; - }); - if (candidates.empty()) - return {}; - std::swap(candidates, fusableLoops); } return fusableLoops; @@ -692,60 +693,69 @@ /// Find all dependences that are to be fusable. static FusableOpDependencesTy -findAllFusableDependences(LinalgOp op, - const LinalgDependenceGraph &dependenceGraph, - const LinalgFusionOptions &fusionOptions) { +findAllFusableDependences(ArrayRef ops, + const LinalgDependenceGraph &dependenceGraph) { FusableOpDependencesTy fusableDependences; // TODO: Currently fusion would not be legal if the fusable dependence is to // the same producer but different indexing map in the consumer. Fix this, but // in the meanwhile disallow such a fusion. DenseMap fusedProducerIndexingMap; - for (auto operandIndex : fusionOptions.indicesToFuse) { - auto fusableDependence = - findFusableProducer(op, operandIndex, dependenceGraph); - if (!fusableDependence) - return FusableOpDependencesTy{}; - LinalgOp producerOp = cast(fusableDependence->dependentOpView.op); - // Do not fuse dependences that are to operations not in the same basic - // block. This avoid moving fused operations across loops that might - // themselves carry dependency making the fusion illegal. - if (producerOp.getOperation()->getBlock() != - op.getOperation()->getBlock()) { - op.emitRemark("unhandled fusion of ops in different basic blocks"); - return FusableOpDependencesTy{}; - } - // Make sure that the indexing map of the view used for fusion in the - // producer is a projected permutation. - unsigned producerIdx = fusableDependence->dependentOpView.operandNum; - AffineMap producerMap = producerOp.getIndexingMap(producerIdx); - if (!producerMap.isProjectedPermutation()) { - op.emitRemark("unhandled non permutation indexing map for fused view in " - "producer for operand at index ") - << operandIndex; - return FusableOpDependencesTy{}; - } + for (LinalgOp op : reverse(ops)) { + for (unsigned operandNum : + llvm::seq(0, op.getNumInputsAndOutputBuffers())) { + Optional + fusableDependence = + findFusableProducer(op, operandNum, dependenceGraph); + if (!fusableDependence) + continue; + LinalgOp producerOp = + cast(fusableDependence->dependentOpView.op); + // Do not fuse dependences that are to operations not in the same basic + // block. This avoid moving fused operations across loops that might + // themselves carry dependency making the fusion illegal. + if (producerOp.getOperation()->getBlock() != + op.getOperation()->getBlock()) { + op.emitRemark("unhandled fusion of ops in different basic blocks"); + return FusableOpDependencesTy{}; + } + // Make sure that the indexing map of the view used for fusion in the + // producer is a projected permutation. + unsigned producerIdx = fusableDependence->dependentOpView.operandNum; + AffineMap producerMap = producerOp.getIndexingMap(producerIdx); + if (!producerMap.isProjectedPermutation()) { + op.emitRemark( + "unhandled non permutation indexing map for fused view in " + "producer for operand at index ") + << operandNum; + return FusableOpDependencesTy{}; + } - unsigned consumerIdx = fusableDependence->indexingOpView.operandNum; - AffineMap consumerMap = op.getIndexingMap(consumerIdx); - if (!consumerMap.isProjectedPermutation()) { - op.emitRemark( - "unhandled case where indexing map for fused view in the consumer is " - "not a projected permutation while fusing at index ") - << operandIndex; - return FusableOpDependencesTy{}; - } + unsigned consumerIdx = fusableDependence->indexingOpView.operandNum; + AffineMap consumerMap = op.getIndexingMap(consumerIdx); + if (!consumerMap.isProjectedPermutation()) { + op.emitRemark( + "unhandled case where indexing map for fused view in the consumer " + "is " + "not a projected permuration while fusing at index ") + << operandNum; + return FusableOpDependencesTy{}; + } - // Check if the producer is already a fusion candidate. Cannot fuse this - // dependence if it has a different indexing map when used in the consumer. - if (fusedProducerIndexingMap.count(producerOp.getOperation()) && - fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) { - op.emitRemark("unhandled fusion to the same producer but with different " - "indexing maps"); - return FusableOpDependencesTy{}; - } - fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap; + // Check if the producer is already a fusion candidate. Cannot fuse this + // dependence if it has a different indexing map when used in the + // consumer. + if (fusedProducerIndexingMap.count(producerOp.getOperation()) && + fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) { + op.emitRemark( + "unhandled fusion to the same producer but with different " + "indexing maps"); + return FusableOpDependencesTy{}; + } + fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap; - fusableDependences[producerOp.getOperation()].push_back(*fusableDependence); + fusableDependences[producerOp.getOperation()].push_back( + *fusableDependence); + } } return fusableDependences; } @@ -756,46 +766,110 @@ return false; } +/// Tile the fused loops in the root operation, by setting the tile sizes for +/// all other loops to zero (those will be tiled later). +static Optional tileRootOperation( + OpBuilder &builder, LinalgOp op, ArrayRef tileSizeVector, + const LinalgTilingOptions &options, const std::set &fusedLoops) { + SmallVector tileSizes(tileSizeVector.begin(), tileSizeVector.end()); + auto zero = std_constant_index(0); + for (unsigned i = 0, e = tileSizes.size(); i != e; ++i) { + if (!fusedLoops.count(i)) + tileSizes[i] = zero; + } + LinalgTilingOptions tileFusedLoopsOptions = options; + tileFusedLoopsOptions.setTileSizes(tileSizes); + return tileLinalgOp(builder, op, tileFusedLoopsOptions); +} + +/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected +/// to be a tiled operation such that it is valid to fuse all operations in +/// `fusionCandidates`, i.e. move the operation within the inter-tile loops of +/// `tiledOp`. +static SmallVector +fuseOperations(OpBuilder &builder, LinalgOp tiledOp, + ArrayRef fusionCandidates, + const std::set &fusedLoops) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(tiledOp); + DenseMap fusedLoopsAndRanges; + for (unsigned loop : fusedLoops) { + ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop); + fusedLoopsAndRanges[loop] = getRangeFromOperandShape( + builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension); + } + SmallVector fusedOps(fusionCandidates.size()); + for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) { + LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges); + fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp; + builder.setInsertionPoint(fusedOp); + } + return fusedOps; +} + +/// Post fusion, tile all the unfused loops of the fused operations. +static Optional tileUnfusedLoops( + OpBuilder &builder, LinalgOp op, ArrayRef tileSizeVector, + const LinalgTilingOptions &options, const std::set &fusedLoops) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(op); + SmallVector tileSizes(tileSizeVector.begin(), tileSizeVector.end()); + auto zero = std_constant_index(0); + for (unsigned i = 0, e = tileSizes.size(); i != e; ++i) { + if (fusedLoops.count(i)) + tileSizes[i] = zero; + } + tileSizes.resize(op.getNumLoops()); + if (llvm::all_of(tileSizes, isZero)) + return TiledLinalgOp{op, {}}; + LinalgTilingOptions tileUnfusedLoopsOptions = options; + tileUnfusedLoopsOptions.setTileSizes(tileSizes); + tileUnfusedLoopsOptions.distribution = llvm::None; + return tileLinalgOp(builder, op, tileUnfusedLoopsOptions); +} + template static Optional -tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op, +tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef ops, const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions, - const LinalgFusionOptions &fusionOptions) { - assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); - // Some of the tiling options might not be supportable with tile and fuse. - // TODO: Support interchange with tile + fuse. - if (!tilingOptions.interchangeVector.empty()) { - op.emitError("unable to handle tile and fuse with interchange"); + ArrayRef tilingOptions) { + if (ops.empty()) return llvm::None; + LinalgOp rootOp = ops.back(); + if (ops.size() != tilingOptions.size()) { + rootOp.emitError( + "need tiling options for all the operations that are to be fused"); + return llvm::None; + } + for (auto op : enumerate(ops)) { + // TODO: Nothing in the fusion of sequence of ops is specific to + // buffers. This check can be removed after it is tested on tensors. + LinalgOp linalgOp = op.value(); + if (!linalgOp.hasBufferSemantics()) { + linalgOp.emitError("tile and fuse only tested for buffer operation"); + return llvm::None; + } + // TODO: Support interchange with tile + fuse. This might actually help do + // better fusion. + if (!tilingOptions[op.index()].interchangeVector.empty()) { + linalgOp.emitError("unable to handle tile and fuse with interchange"); + return llvm::None; + } } - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(op); - ScopedContext scope(rewriter, op.getLoc()); + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(rootOp); + ScopedContext scope(builder, rootOp.getLoc()); // Find all the producers. FusableOpDependencesTy fusableDependences = - findAllFusableDependences(op, dependenceGraph, fusionOptions); + findAllFusableDependences(ops, dependenceGraph); if (fusableDependences.empty()) return llvm::None; - // Enforce the convention that "tiling by zero" skips tiling a particular - // dimension. This convention is significantly simpler to handle instead of - // adjusting affine maps to account for missing dimensions. - auto nLoops = op.getNumLoops(); - SmallVector tileSizeVector = - tilingOptions.tileSizeComputationFunction(rewriter, op); - if (tileSizeVector.size() < nLoops) { - auto zero = std_constant_index(0); - tileSizeVector.append(nLoops - tileSizeVector.size(), zero); - } - - TiledAndFusedLinalgOps ret; - // Find the loops that can be tiled and fused. std::set tileFuseLoops = - collectTileAndFuseLoops(op, fusableDependences); + collectTileAndFuseLoops(ops, fusableDependences); // If there are no fusable dependences or there are no tile+fusable loops, // just return. @@ -803,88 +877,112 @@ return llvm::None; } - // Get the tile sizes for the first and second tiling steps. For the first - // step the tile size are set to zero for the loops that arent - // fused. Similarly for the second step, the tile sizes are set to zero for - // the loops that are fused. For example, if for the following input - // - // ``` - // linalg.add ins(%a, %b) outs(%c) - // linalg.matmul ins(%d, %c) outs(%e) - // ``` - // - // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}` - // respectively, and since only `j` can be tiled and fused. The tile sizes - // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable - // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile - // the tiled matmul generated by the first tiling step. - SmallVector tileAndFuseSizes, tileSizes; - for (auto tileSize : enumerate(tileSizeVector)) { - auto zero = std_constant_index(0); - if (tileFuseLoops.count(tileSize.index())) { - tileAndFuseSizes.push_back(tileSize.value()); - tileSizes.push_back(zero); - } else { - tileSizes.push_back(tileSize.value()); - tileAndFuseSizes.push_back(zero); + // Tile the fused loops in the last operation in the list. + SmallVector tileSizeVector = + tilingOptions.back().tileSizeComputationFunction(builder, rootOp); + TiledAndFusedLinalgOps ret; + Optional tiledRootOp = tileRootOperation( + builder, rootOp, tileSizeVector, tilingOptions.back(), tileFuseLoops); + if (!tiledRootOp) { + rootOp.emitError("failed to tile the fused loops"); + return llvm::None; + } + ret.op = tiledRootOp->op; + ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); + + // Fuse the other operations into the fused inter-tile loops produced above. + ret.fusedProducers = + fuseOperations(builder, ret.op, ops.drop_back(), tileFuseLoops); + + // Tile the unfused loops for the fused ops. + ret.unfusedLoops.resize(ops.size()); + for (auto op : enumerate(ret.fusedProducers)) { + const LinalgTilingOptions &option = tilingOptions[op.index()]; + if (!option.tileSizeComputationFunction) + continue; + SmallVector producerTileSizeVector = + option.tileSizeComputationFunction(builder, op.value()); + Optional tiledOp = tileUnfusedLoops( + builder, op.value(), producerTileSizeVector, option, tileFuseLoops); + if (!tiledOp) { + op.value().emitError("unable to tile unfused loops"); + return llvm::None; } + if (tiledOp->op != op.value()) { + builder.eraseOp(op.value()); + ret.fusedProducers[op.index()] = tiledOp->op; + } + ret.unfusedLoops[op.index()].assign(tiledOp->loops.begin(), + tiledOp->loops.end()); + builder.setInsertionPoint(tiledOp->op); } - // Tile for the loops that can be fused. - LinalgTilingOptions firstTilingOptions = tilingOptions; - firstTilingOptions.setTileSizes(tileAndFuseSizes); - Optional firstTiledOp = - tileLinalgOp(rewriter, op, firstTilingOptions); - if (!firstTiledOp) + // Tile the unfused loops for the tiled root operation. + Optional rootTileUnfusedLoopsOp = tileUnfusedLoops( + builder, ret.op, tileSizeVector, tilingOptions.back(), tileFuseLoops); + if (!rootTileUnfusedLoopsOp) { + ret.op.emitError("unable to tile unfused loops"); return llvm::None; - ret.op = firstTiledOp->op; - ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end()); - - rewriter.setInsertionPoint(ret.op); - // Fuse the operands. - for (auto dependence : fusableDependences) { - LinalgOp producerOp = cast(dependence.first); - unsigned producerIdx = dependence.second.front().dependentOpView.operandNum; - unsigned consumerIdx = dependence.second.front().indexingOpView.operandNum; - LinalgOp fusedOp = fuse(rewriter, producerOp, - producerOp.getOutputIndex(producerIdx).getValue(), - ret.op, consumerIdx); - ret.fusedProducers.push_back(fusedOp); - ret.originalProducers.push_back(producerOp); - } - - if (!llvm::all_of(tileSizes, isZero)) { - // Tile the remaining loops of the root operation. - LinalgTilingOptions secondTilingOptions = tilingOptions; - // The distribution is done only for the tile+fused loops. - secondTilingOptions.distribution = llvm::None; - secondTilingOptions.setTileSizes(tileSizes); - Optional secondTiledOp = - tileLinalgOp(rewriter, ret.op, secondTilingOptions); - if (!secondTiledOp) - return llvm::None; - ret.unfusedLoops.assign(secondTiledOp->loops.begin(), - secondTiledOp->loops.end()); - rewriter.eraseOp(ret.op); - ret.op = secondTiledOp->op; } - + if (rootTileUnfusedLoopsOp->op != ret.op) { + builder.eraseOp(ret.op); + ret.op = rootTileUnfusedLoopsOp->op; + } + ret.unfusedLoops.back().assign(rootTileUnfusedLoopsOp->loops.begin(), + rootTileUnfusedLoopsOp->loops.end()); return ret; } -Optional -mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op, - const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions, - const LinalgFusionOptions &fusionOptions) { - switch (tilingOptions.loopType) { +Optional mlir::linalg::tileAndFuseLinalgOps( + OpBuilder &builder, ArrayRef ops, + const LinalgDependenceGraph &dependenceGraph, + ArrayRef tilingOptions) { + switch (tilingOptions.back().loopType) { case LinalgTilingLoopType::Loops: - return tileAndFuseLinalgOpsImpl(rewriter, op, dependenceGraph, - tilingOptions, fusionOptions); + return tileAndFuseLinalgOpsImpl(builder, ops, dependenceGraph, + tilingOptions); case LinalgTilingLoopType::ParallelLoops: return tileAndFuseLinalgOpsImpl( - rewriter, op, dependenceGraph, tilingOptions, fusionOptions); + builder, ops, dependenceGraph, tilingOptions); default:; } return llvm::None; } + +namespace { +/// Pass to test tile and fuse of sequence of operations. Intended only for +/// testing. +struct LinalgTileAndFusePass + : public LinalgTileAndFuseOpsBase { + LinalgTileAndFusePass() = default; + LinalgTileAndFusePass(ArrayRef sizes) { tileSizes = sizes; } + + void runOnFunction() override { + FuncOp funcOp = getOperation(); + auto &blocks = funcOp.getBody().getBlocks(); + if (!llvm::hasSingleElement(blocks)) { + return; + } + SmallVector linalgOps = + llvm::to_vector<2>(blocks.front().getOps()); + Aliases aliases; + LinalgDependenceGraph dependenceGraph(aliases, linalgOps); + SmallVector tilingOptions( + linalgOps.size(), + LinalgTilingOptions().setTileSizes(tileSizes).setLoopType( + LinalgTilingLoopType::ParallelLoops)); + OpBuilder builder(funcOp.getContext()); + Optional tileAndFuseOps = tileAndFuseLinalgOps( + builder, linalgOps, dependenceGraph, tilingOptions); + if (!tileAndFuseOps) + return signalPassFailure(); + for (auto op : linalgOps) + builder.eraseOp(op); + } +}; +} // namespace + +std::unique_ptr> +mlir::createLinalgTileAndFusePass(ArrayRef tileSizes) { + return std::make_unique(tileSizes); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -165,17 +165,42 @@ if (!linalgOp.hasBufferSemantics()) return failure(); + llvm::SetVector producers; + producers.insert(linalgOp); + for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) { + if (!fusionOptions.indicesToFuse.count( + dependence.indexingOpView.operandNum)) + continue; + if (isa(dependence.dependentOpView.op)) { + producers.insert(dependence.dependentOpView.op); + } + } + + SmallVector fusionOps; + for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie; + ++it) { + auto producerLinalgOp = dyn_cast(&(*it)); + if (producerLinalgOp && producers.count(producerLinalgOp)) + fusionOps.push_back(producerLinalgOp); + } + + SmallVector tilingOptionsVec(fusionOps.size(), + LinalgTilingOptions()); + fusionOps.push_back(linalgOp); + tilingOptionsVec.push_back(tilingOptions); + Optional tiledAndFusedOps = tileAndFuseLinalgOps( - rewriter, op, dependenceGraph, tilingOptions, fusionOptions); + rewriter, fusionOps, dependenceGraph, tilingOptionsVec); if (!tiledAndFusedOps) return failure(); marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation()); for (auto fusedOp : tiledAndFusedOps->fusedProducers) { fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation()); } - for (auto origProducerOp : tiledAndFusedOps->originalProducers) + for (auto origProducerOp : ArrayRef(fusionOps).drop_back()) { originalOpMarker.replaceLinalgMarker(rewriter, origProducerOp.getOperation()); + } rewriter.updateRootInPlace( op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); }); return success(); diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -124,7 +124,7 @@ DenseSet eraseSet; // Save original Linalg ops, we only want to make a pass over those. - SmallVector linalgOps; + SmallVector linalgOps; f.walk([&](LinalgOp op) { // TODO: support multi-results. if (op.getOperation()->getNumResults() <= 1) @@ -133,8 +133,7 @@ // Tile and Fuse for tensors inputs (TODO: all tensor operands). bool changed = false; - for (auto *op : llvm::reverse(linalgOps)) { - LinalgOp linalgOp = cast(op); + for (LinalgOp linalgOp : llvm::reverse(linalgOps)) { for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) { if (en.value().getType().isa()) { // TODO: LinalgDependenceGraph should be able to update itself. @@ -142,7 +141,7 @@ // removed. linalg::Aliases aliases; linalg::LinalgDependenceGraph graph(aliases, linalgOps); - if (auto info = fuseProducerOfBuffer(b, op, en.index(), graph)) { + if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) { auto *originalOp = info->originalProducer.getOperation(); eraseSet.insert(originalOp); auto *originalOpInLinalgOpsVector = @@ -155,7 +154,7 @@ // Tile and Fuse tensor input (TODO: init_tensors too). if (en.index() >= linalgOp.getNumInputs()) continue; - if (auto info = fuseProducerOfTensor(b, op, en.index())) { + if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) { auto *originalOp = info->originalProducer.getOperation(); auto *originalOpInLinalgOpsVector = std::find(linalgOps.begin(), linalgOps.end(), originalOp);