diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -37,14 +37,6 @@ SmallVector tensorResults; }; -struct TiledAndFusedLinalgOps { - LinalgOp op; - SmallVector fusedProducers; - SmallVector originalProducers; - SmallVector fusedLoops; - SmallVector unfusedLoops; -}; - /// Populates patterns for vectorization of all ConvN-D ops. void populateConvVectorizationPatterns( MLIRContext *context, SmallVectorImpl &patterns, @@ -73,14 +65,11 @@ Optional tileLinalgOp(OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options); -/// Tile and fuse the `op` with its producers. The tile and fuse proceeds in -/// three steps -/// - Find tile loops that are fusable with its producer tile loops (a.k.a. tile -/// + fuse loops). -/// - Tile just these loops of the consumer (root operation) and fuse with -/// the producer. -/// - Tile again the tiled consumer operation produced above to do rest of -/// the tiling specified by the `tilingOptions`. +/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This +/// proceeds as follows: +/// - Find outer parallel loops in these ops that can be fused. +/// - Tile fusable outer parallel loops of the last operation in the sequence. +/// - Fuse the remaining operations with the tiled operation /// /// For example, consider the sequence of matmul below /// @@ -107,36 +96,39 @@ /// : memref<256x32xf32> to memref<16x32xf32, #map0> /// %3 = subview %arg1[0, 0] [32, 32] [1, 1] /// : memref<32x32xf32> to memref<32x32xf32, #map1> +/// %4 = subview %arg3[0, 0] [32, 32] [1, 1] +/// : memref<32x32xf32> to memref<32x32xf32, #map1> /// linalg.matmul /// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>) /// outs(%0 : memref<16x32xf32, #map0>) -/// scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) { -/// scf.for %arg7 = %c0 to %c32 step %c4 { -/// %4 = subview %0[0, %arg7] [16, 4] [1, 1] -/// : memref<16x32xf32, #map0> to memref<16x4xf32, #map0> -/// %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1] -/// : memref<32x32xf32> to memref<4x8xf32, #map0> -/// %6 = subview %1[0, %arg6] [16, 8] [1, 1] -/// : memref<16x32xf32, #map0> to memref<16x8xf32, #map0> -/// linalg.matmul -/// ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>) -/// outs(%6 : memref<16x8xf32, #map0>) -/// } -/// scf.yield -/// } -/// scf.yield +/// linalg.matmul +/// ins(%0, %4 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>) +/// outs(%1 : memref<16x8xf32, #map0>) /// } /// -/// The following tiling options are handled differently in tile+fuse (compared -/// to tile only) +/// `tilingOptions` are used to tile the corresponding operation in `ops` (the +/// size of the former should be same as size of the latter. Based on how +/// tile+fuse is implemented, the fused loops are generated based on the last +/// operation in the sequence. For example, the tile sizes for the fused loops +/// is obtained from `tilingOptions.back()`. The following tiling options are +/// handled differently in tile+fuse (compared to tile only) /// - Interchange of the tiling loops is not supported right now. -/// - Distribution is only done for the tile+fuse loops. The tiled loops -/// generated by the second tiling is not distributed. +/// - Only the fused loops are distributed. 
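For orientation, the following is a minimal, hypothetical sketch of how a caller might drive the new `tileAndFuseLinalgOps` overload declared below; it mirrors the test pass added at the end of this patch. The helper name `tileAndFuseSequence` and the values `funcOp`, `linalgOps` (the sequence ordered producer-to-consumer, root operation last), and `tileSizes` are assumptions for illustration only, not part of the patch.

```c++
// Sketch only: exercise the new sequence-based tile-and-fuse entry point.
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;
using namespace mlir::linalg;

static void tileAndFuseSequence(FuncOp funcOp, ArrayRef<LinalgOp> linalgOps,
                                ArrayRef<int64_t> tileSizes) {
  // Dependences are computed over the whole sequence; the last op in
  // `linalgOps` is treated as the root that gets tiled.
  Aliases aliases;
  LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
  OpBuilder builder(funcOp.getContext());
  Optional<TiledAndFusedLinalgOps> tiledAndFused = tileAndFuseLinalgOps(
      builder, linalgOps, dependenceGraph,
      LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(
          LinalgTilingLoopType::ParallelLoops));
  if (!tiledAndFused)
    return;
  // On success: `op` is the tiled root, `fusedProducers` holds the producer
  // clones placed inside the inter-tile loops, and `fusedLoopDims` /
  // `fusedLoops` record which dimensions were fused and the generated loops.
}
```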
+struct TiledAndFusedLinalgOps { + /// Operation obtained by tiling the last operation in sequence of `ops` + /// passed to `tileAndFuseLinalgOps`. + LinalgOp op; + /// The dimension of the loops that are fused. + std::set fusedLoopDims; + /// The fused operations. + SmallVector fusedProducers; + /// The fused loop generated. + SmallVector fusedLoops; +}; Optional -tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op, +tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef ops, const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions, - const LinalgFusionOptions &fusionOptions); + const LinalgTilingOptions &tilingOptions); /// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`. /// This is an in-place transformation controlled by `interchangeVector`. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -178,6 +178,9 @@ Value shape = en.value(); SmallVector shapeRanges(map.getNumResults(), nullptr); for (auto en2 : llvm::enumerate(map.getResults())) { + auto dimExpr = en2.value().dyn_cast(); + if (!dimExpr) + continue; if (loopDepth == en2.value().cast().getPosition()) { LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: " << loopDepth << "\n"); @@ -190,49 +193,18 @@ llvm_unreachable("Expect to be able to extract a shape defining loop range"); } -/// Fuses the producer of `producerIdx` into the loop immediately enclosing -/// `consumer`. This is achieved by "recomputing" the `producer` at the time it -/// is needed just before the `consumer. -/// -/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are -/// 2 cases: -/// 1. Buffer case: `producerIdx` is the index of the buffer in -/// `producer.getOutputBuffers()`. -/// 2. Tensor case: `producerIdx` is the index of the tensor in -/// `producer.getResults()`. -static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, - LinalgOp consumer, unsigned consumerIdx) { - Operation *shapeProducingOp = - consumer.getShapedOperand(consumerIdx).getDefiningOp(); - assert((isa(shapeProducingOp) || - isa(shapeProducingOp)) && - "SubviewOp or SubTensorOp expected"); - - // loopToOperandRangesMaps are permutations-only by construction: - // we can always identify a data dimension with a (at least one) loop - // dimension. - // TODO: extend this with range inference. - AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); - LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx - << ", producer map: " << producerMap << "\n"); +/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges` +/// provides the loop range information for the fused loops. The rest are +/// obtained from the producer itself, since they are not tiled + fused. +static LinalgOp fuse(OpBuilder &b, LinalgOp producer, + const DenseMap &fusedLoopsAndRanges) { unsigned nPar = producer.getNumParallelLoops(); unsigned nRed = producer.getNumReductionLoops(); unsigned nWin = producer.getNumWindowLoops(); SmallVector loopRanges(nPar + nRed + nWin); - - // Iterate over dimensions identified by the producer map for `producerIdx`. - // This defines a subset of the loop ranges that we need to complete later. 
- auto loc = consumer.getLoc(); - for (auto en : llvm::enumerate(producerMap.getResults())) { - unsigned posInProducerLoop = en.value().cast().getPosition(); - loopRanges[posInProducerLoop] = - isa(shapeProducingOp) - ? cast(shapeProducingOp) - .getOrCreateRanges(b, loc)[en.index()] - : cast(shapeProducingOp) - .getOrCreateRanges(b, loc)[en.index()]; - } + for (auto fusedLoops : fusedLoopsAndRanges) + loopRanges[fusedLoops.first] = fusedLoops.second; // Iterate over all dimensions. For the dimensions not identified by the // producer map for `producerIdx`, we need to explicitly compute the shape @@ -250,7 +222,46 @@ } } - return cloneWithLoopRanges(b, loc, producer, loopRanges); + return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges); +} + +/// Helper function to get the loop range for a dimension `dim` based on the +/// `shapedOperand`. It is expected to be defined by a subview op or a subtensor +/// op. +static Range getRangeFromOperandShape(OpBuilder &b, Location loc, + Value shapedOperand, unsigned dim) { + Operation *shapeProducingOp = shapedOperand.getDefiningOp(); + if (auto subViewOp = dyn_cast(shapeProducingOp)) + return subViewOp.getOrCreateRanges(b, loc)[dim]; + if (auto subTensorOp = dyn_cast(shapeProducingOp)) + return subTensorOp.getOrCreateRanges(b, loc)[dim]; + llvm_unreachable("SubviewOp or SubTensorOp expected"); +} + +/// Fuses the producer of `producerIdx` into the loop immediately enclosing +/// `consumer`. This is achieved by "recomputing" the `producer` at the time it +/// is needed just before the `consumer. +/// +/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are +/// 2 cases: +/// 1. Buffer case: `producerIdx` is the index of the buffer in +/// `producer.getOutputBuffers()`. +/// 2. Tensor case: `producerIdx` is the index of the tensor in +/// `producer.getResults()`. +static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, + LinalgOp consumer, unsigned consumerIdx) { + AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); + LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx + << ", producer map: " << producerMap << "\n"); + DenseMap fusedLoopsAndRanges; + Location loc = consumer.getLoc(); + Value shapedOperand = consumer.getShapedOperand(consumerIdx); + for (auto en : llvm::enumerate(producerMap.getResults())) { + unsigned posInProducerLoop = en.value().cast().getPosition(); + fusedLoopsAndRanges[posInProducerLoop] = + getRangeFromOperandShape(b, loc, shapedOperand, en.index()); + } + return fuse(b, producer, fusedLoopsAndRanges); } // Encode structural fusion safety preconditions. @@ -525,6 +536,68 @@ Operation *, SmallVector>; +/// Returns the mapping from iterations in the consumer that write to the same +/// location as the iterations in the producer. 
To do so use +/// - indexing map of the fused view in the consumer : consumerIndexMap +/// - indexing map of the fused view in the producer : producerIndexMap +/// consumerLoopToProducerLoop = +/// inverse(producerIndexMap).compose(consumerIndexMap) +static Optional getConsumerLoopToProducerLoopMap( + LinalgDependenceGraph::LinalgDependenceGraphElem dependence) { + auto producer = cast(dependence.dependentOpView.op); + AffineMap producerIndexingMap = + producer.getIndexingMap(dependence.dependentOpView.operandIndex); + auto consumer = cast(dependence.indexingOpView.op); + AffineMap consumerIndexingMap = + consumer.getIndexingMap(dependence.indexingOpView.operandIndex); + + AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( + producer.iterator_types().getValue(), producerIndexingMap); + if (!prunedProducerIndexingMap.isPermutation()) + return None; + + if (consumerIndexingMap.getNumResults() != + prunedProducerIndexingMap.getNumResults()) + return None; + + LLVM_DEBUG({ + llvm::dbgs() << "\t producerMap : "; + producerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << " pruned : "; + prunedProducerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + llvm::dbgs() << "\t consumerMap : "; + consumerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap); + if (!invProducerIndexMap) + return None; + + return invProducerIndexMap.compose(consumerIndexingMap); +} + +/// Given a projected permutation `map`, returns true if the map changes the +/// order in which the fused loop dimension appear. +static bool doesTransposeAccess(AffineMap map, + const std::set &fusableLoops) { + Optional lastFusableLoop; + for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) { + return expr.cast().getPosition(); + })) { + if (!fusableLoops.count(pos)) + continue; + if (!lastFusableLoop) { + lastFusableLoop = pos; + continue; + } + if (pos <= lastFusableLoop.getValue()) + return true; + } + return false; +} + /// Returns the positions of the loop in `op` that can be tiled based on the /// operations that are to be fused with it. For example, in a /// @@ -538,13 +611,7 @@ /// 2. Of the parallel loops only some can be fused. Only those loops can be /// fused such where the fusable loops iteration space only touches one tile /// of the fused operation. This is because the producer (which is writing -/// the fused subview) has update semantics. To compute this, -/// a. Find the mapping from iterations in the consumer that write to the -/// same location as the iterations in the producer. To do so use -/// - indexing map of the fused view in the consumer : consumerIndexMap -/// - indexing map of the fused view in the producer : producerIndexMap -/// consumerLoopToProducerLoop = -/// inverse(producerIndexMap).compose(consumerIndexMap) +/// the fused subview) has update semantics. /// /// Since an inverse computation is needed, we need to consider the projection /// of the producerIndexMap w.r.t the parallel loops. 
The actual fusable loops @@ -582,8 +649,9 @@ /// submap with only parallel loops = affine_map<(i, j) -> (j)> /// Fused dimensions : j static std::set -collectTileAndFuseLoops(LinalgOp op, +collectTileAndFuseLoops(ArrayRef ops, const FusableOpDependencesTy &fusableDependences) { + assert(!ops.empty()); auto getNumOuterParallelLoops = [](LinalgOp linalgOp) { return linalgOp.iterator_types() .getValue() @@ -594,88 +662,57 @@ .size(); }; - LLVM_DEBUG({ - llvm::dbgs() << "Op : "; - op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); - llvm::dbgs() << "\n"; - }); - - size_t numOuterParallelLoops = getNumOuterParallelLoops(op); - for (auto dependence : fusableDependences) { - linalg::LinalgOp producer = cast(dependence.first); + size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back()); + for (auto op : ops.drop_back()) { numOuterParallelLoops = - std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer)); + std::min(numOuterParallelLoops, getNumOuterParallelLoops(op)); } std::set fusableLoops; auto range = llvm::seq(0, numOuterParallelLoops); fusableLoops.insert(range.begin(), range.end()); - for (auto dependence : fusableDependences) { - LLVM_DEBUG({ - llvm::dbgs() << "\t fusable :"; - for (unsigned i : fusableLoops) - llvm::dbgs() << " " << i; - llvm::dbgs() << "\n"; - }); - linalg::LinalgOp producer = cast(dependence.first); - - assert(!dependence.second.empty() && - "unexpected producer but not dependences"); - AffineMap producerIndexingMap = producer.getIndexingMap( - dependence.second.front().dependentOpView.operandIndex); - AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( - producer.iterator_types().getValue(), producerIndexingMap); - if (!prunedProducerIndexingMap.isPermutation()) - return {}; - - AffineMap consumerIndexingMap = op.getIndexingMap( - dependence.second.front().indexingOpView.operandIndex); - if (consumerIndexingMap.getNumResults() != - prunedProducerIndexingMap.getNumResults()) - return {}; - - LLVM_DEBUG({ - llvm::dbgs() << "\t producerMap : "; - producerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << " pruned : "; - prunedProducerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - llvm::dbgs() << "\t consumerMap : "; - consumerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - }); - - AffineMap invProducerIndexMap = - inversePermutation(prunedProducerIndexingMap); - if (!invProducerIndexMap) - return {}; - - AffineMap consumerLoopToProducerLoop = - invProducerIndexMap.compose(consumerIndexingMap); - - LLVM_DEBUG({ - llvm::dbgs() << "\t consumerLoopToProducerLoop : "; - consumerLoopToProducerLoop.print(llvm::dbgs()); - }); - - std::set candidates; - for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) { - AffineDimExpr dimExpr = expr.dyn_cast(); - if (!dimExpr) - continue; - unsigned position = dimExpr.getPosition(); - if (fusableLoops.count(position)) - candidates.insert(position); + + for (auto op : reverse(ops)) { + for (auto dependence : fusableDependences.lookup(op)) { + LLVM_DEBUG({ + llvm::dbgs() << "\t fusable :"; + for (unsigned i : fusableLoops) + llvm::dbgs() << " " << i; + llvm::dbgs() << "\n"; + }); + + Optional consumerLoopToProducerLoop = + getConsumerLoopToProducerLoopMap(dependence); + if (!consumerLoopToProducerLoop) { + op.emitRemark("failed to get map from consumer loop to producer loop"); + return {}; + } + // todo: This condition is only an implementation limitation. 
When fusing + // the operation, if the accesses in the producer/consumer are transposes + // of each other, the loop bounds for the tiled producer can be + // manipulated accordingly. This requires some additional bookkeeping in + // the implementation of tile+fuse that is defered to later. + if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) { + op.emitRemark("unhandled fusion when fusion requires permutation"); + return {}; + } + + std::set candidates; + for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) { + unsigned position = expr.cast().getPosition(); + if (fusableLoops.count(position)) + candidates.insert(position); + } + LLVM_DEBUG({ + llvm::dbgs() << "\t candidates :"; + for (unsigned i : candidates) + llvm::dbgs() << " " << i; + llvm::dbgs() << "\n"; + }); + if (candidates.empty()) + return {}; + std::swap(candidates, fusableLoops); } - LLVM_DEBUG({ - llvm::dbgs() << "\t candidates :"; - for (unsigned i : candidates) - llvm::dbgs() << " " << i; - llvm::dbgs() << "\n"; - }); - if (candidates.empty()) - return {}; - std::swap(candidates, fusableLoops); } return fusableLoops; @@ -683,60 +720,69 @@ /// Find all dependences that are to be fusable. static FusableOpDependencesTy -findAllFusableDependences(LinalgOp op, - const LinalgDependenceGraph &dependenceGraph, - const LinalgFusionOptions &fusionOptions) { +findAllFusableDependences(ArrayRef ops, + const LinalgDependenceGraph &dependenceGraph) { FusableOpDependencesTy fusableDependences; // TODO: Currently fusion would not be legal if the fusable dependence is to // the same producer but different indexing map in the consumer. Fix this, but // in the meanwhile disallow such a fusion. DenseMap fusedProducerIndexingMap; - for (auto operandIndex : fusionOptions.indicesToFuse) { - auto fusableDependence = - findFusableProducer(op, operandIndex, dependenceGraph); - if (!fusableDependence) - return FusableOpDependencesTy{}; - LinalgOp producerOp = cast(fusableDependence->dependentOpView.op); - // Do not fuse dependences that are to operations not in the same basic - // block. This avoid moving fused operations across loops that might - // themselves carry dependency making the fusion illegal. - if (producerOp.getOperation()->getBlock() != - op.getOperation()->getBlock()) { - op.emitRemark("unhandled fusion of ops in different basic blocks"); - return FusableOpDependencesTy{}; - } - // Make sure that the indexing map of the view used for fusion in the - // producer is a projected permutation. - unsigned producerIdx = fusableDependence->dependentOpView.operandIndex; - AffineMap producerMap = producerOp.getIndexingMap(producerIdx); - if (!producerMap.isProjectedPermutation()) { - op.emitRemark("unhandled non permutation indexing map for fused view in " - "producer for operand at index ") - << operandIndex; - return FusableOpDependencesTy{}; - } + for (LinalgOp op : reverse(ops)) { + for (unsigned operandIndex : + llvm::seq(0, op.getNumInputsAndOutputBuffers())) { + Optional + fusableDependence = + findFusableProducer(op, operandIndex, dependenceGraph); + if (!fusableDependence) + continue; + LinalgOp producerOp = + cast(fusableDependence->dependentOpView.op); + // Do not fuse dependences that are to operations not in the same basic + // block. This avoid moving fused operations across loops that might + // themselves carry dependency making the fusion illegal. 
+ if (producerOp.getOperation()->getBlock() != + op.getOperation()->getBlock()) { + op.emitRemark("unhandled fusion of ops in different basic blocks"); + return FusableOpDependencesTy{}; + } + // Make sure that the indexing map of the view used for fusion in the + // producer is a projected permutation. + unsigned producerIdx = fusableDependence->dependentOpView.operandIndex; + AffineMap producerMap = producerOp.getIndexingMap(producerIdx); + if (!producerMap.isProjectedPermutation()) { + op.emitRemark( + "unhandled non permutation indexing map for fused view in " + "producer for operand at index ") + << operandIndex; + return FusableOpDependencesTy{}; + } - unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex; - AffineMap consumerMap = op.getIndexingMap(consumerIdx); - if (!consumerMap.isProjectedPermutation()) { - op.emitRemark( - "unhandled case where indexing map for fused view in the consumer is " - "not a projected permutation while fusing at index ") - << operandIndex; - return FusableOpDependencesTy{}; - } + unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex; + AffineMap consumerMap = op.getIndexingMap(consumerIdx); + if (!consumerMap.isProjectedPermutation()) { + op.emitRemark( + "unhandled case where indexing map for fused view in the consumer " + "is " + "not a projected permuration while fusing at index ") + << operandIndex; + return FusableOpDependencesTy{}; + } - // Check if the producer is already a fusion candidate. Cannot fuse this - // dependence if it has a different indexing map when used in the consumer. - if (fusedProducerIndexingMap.count(producerOp.getOperation()) && - fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) { - op.emitRemark("unhandled fusion to the same producer but with different " - "indexing maps"); - return FusableOpDependencesTy{}; - } - fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap; + // Check if the producer is already a fusion candidate. Cannot fuse this + // dependence if it has a different indexing map when used in the + // consumer. + if (fusedProducerIndexingMap.count(producerOp.getOperation()) && + fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) { + op.emitRemark( + "unhandled fusion to the same producer but with different " + "indexing maps"); + return FusableOpDependencesTy{}; + } + fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap; - fusableDependences[producerOp.getOperation()].push_back(*fusableDependence); + fusableDependences[producerOp.getOperation()].push_back( + *fusableDependence); + } } return fusableDependences; } @@ -747,136 +793,121 @@ return false; } +/// Tile the fused loops in the root operation, by setting the tile sizes for +/// all other loops to zero (those will be tiled later). +static Optional tileRootOperation( + OpBuilder &builder, LinalgOp op, ArrayRef tileSizeVector, + const LinalgTilingOptions &options, const std::set &fusedLoops) { + SmallVector tileSizes(tileSizeVector.begin(), tileSizeVector.end()); + auto zero = std_constant_index(0); + for (unsigned i = 0, e = tileSizes.size(); i != e; ++i) { + if (!fusedLoops.count(i)) + tileSizes[i] = zero; + } + LinalgTilingOptions tileFusedLoopsOptions = options; + tileFusedLoopsOptions.setTileSizes(tileSizes); + return tileLinalgOp(builder, op, tileFusedLoopsOptions); +} + +/// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected +/// to be a tiled operation such that it is valid to fuse all operations in +/// `fusionCandidates`, i.e. 
move the operation within the inter-tile loops of +/// `tiledOp`. +static SmallVector +fuseOperations(OpBuilder &builder, LinalgOp tiledOp, + ArrayRef fusionCandidates, + const FusableOpDependencesTy &fusableDependences, + const std::set &fusedLoops) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(tiledOp); + DenseMap fusedLoopsAndRanges; + for (unsigned loop : fusedLoops) { + ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop); + fusedLoopsAndRanges[loop] = getRangeFromOperandShape( + builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension); + } + SmallVector fusedOps(fusionCandidates.size()); + for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) { + LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges); + fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp; + builder.setInsertionPoint(fusedOp); + } + return fusedOps; +} + template static Optional -tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op, +tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef ops, const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions, - const LinalgFusionOptions &fusionOptions) { - assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); - // Some of the tiling options might not be supportable with tile and fuse. - // TODO: Support interchange with tile + fuse. + const LinalgTilingOptions &tilingOptions) { + if (ops.empty()) + return llvm::None; + LinalgOp rootOp = ops.back(); + for (auto op : enumerate(ops)) { + // TODO: Nothing in the fusion of sequence of ops is specific to + // buffers. This check can be removed after it is tested on tensors. + LinalgOp linalgOp = op.value(); + if (!linalgOp.hasBufferSemantics()) { + linalgOp.emitError("tile and fuse only tested for buffer operation"); + return llvm::None; + } + } + // TODO: Support interchange with tile + fuse. This might actually help do + // better fusion. if (!tilingOptions.interchangeVector.empty()) { - op.emitError("unable to handle tile and fuse with interchange"); + rootOp.emitError("unable to handle tile and fuse with interchange"); return llvm::None; } - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(op); - ScopedContext scope(rewriter, op.getLoc()); + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(rootOp); + ScopedContext scope(builder, rootOp.getLoc()); // Find all the producers. FusableOpDependencesTy fusableDependences = - findAllFusableDependences(op, dependenceGraph, fusionOptions); + findAllFusableDependences(ops, dependenceGraph); if (fusableDependences.empty()) return llvm::None; - // Enforce the convention that "tiling by zero" skips tiling a particular - // dimension. This convention is significantly simpler to handle instead of - // adjusting affine maps to account for missing dimensions. - auto nLoops = op.getNumLoops(); - SmallVector tileSizeVector = - tilingOptions.tileSizeComputationFunction(rewriter, op); - if (tileSizeVector.size() < nLoops) { - auto zero = std_constant_index(0); - tileSizeVector.append(nLoops - tileSizeVector.size(), zero); - } - TiledAndFusedLinalgOps ret; - // Find the loops that can be tiled and fused. - std::set tileFuseLoops = - collectTileAndFuseLoops(op, fusableDependences); + ret.fusedLoopDims = collectTileAndFuseLoops(ops, fusableDependences); // If there are no fusable dependences or there are no tile+fusable loops, // just return. 
- if (tileFuseLoops.empty()) { + if (ret.fusedLoopDims.empty()) { return llvm::None; } - // Get the tile sizes for the first and second tiling steps. For the first - // step the tile size are set to zero for the loops that arent - // fused. Similarly for the second step, the tile sizes are set to zero for - // the loops that are fused. For example, if for the following input - // - // ``` - // linalg.add ins(%a, %b) outs(%c) - // linalg.matmul ins(%d, %c) outs(%e) - // ``` - // - // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}` - // respectively, and since only `j` can be tiled and fused. The tile sizes - // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable - // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile - // the tiled matmul generated by the first tiling step. - SmallVector tileAndFuseSizes, tileSizes; - for (auto tileSize : enumerate(tileSizeVector)) { - auto zero = std_constant_index(0); - if (tileFuseLoops.count(tileSize.index())) { - tileAndFuseSizes.push_back(tileSize.value()); - tileSizes.push_back(zero); - } else { - tileSizes.push_back(tileSize.value()); - tileAndFuseSizes.push_back(zero); - } - } - - // Tile for the loops that can be fused. - LinalgTilingOptions firstTilingOptions = tilingOptions; - firstTilingOptions.setTileSizes(tileAndFuseSizes); - Optional firstTiledOp = - tileLinalgOp(rewriter, op, firstTilingOptions); - if (!firstTiledOp) + // Tile the fused loops in the last operation in the list. + SmallVector tileSizeVector = + tilingOptions.tileSizeComputationFunction(builder, rootOp); + Optional tiledRootOp = tileRootOperation( + builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims); + if (!tiledRootOp) { + rootOp.emitError("failed to tile the fused loops"); return llvm::None; - ret.op = firstTiledOp->op; - ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end()); - - rewriter.setInsertionPoint(ret.op); - // Fuse the operands. - for (auto dependence : fusableDependences) { - LinalgOp producerOp = cast(dependence.first); - unsigned producerIdx = - dependence.second.front().dependentOpView.operandIndex; - unsigned consumerIdx = - dependence.second.front().indexingOpView.operandIndex; - LinalgOp fusedOp = fuse(rewriter, producerOp, - producerOp.getOutputIndex(producerIdx).getValue(), - ret.op, consumerIdx); - ret.fusedProducers.push_back(fusedOp); - ret.originalProducers.push_back(producerOp); - } - - if (!llvm::all_of(tileSizes, isZero)) { - // Tile the remaining loops of the root operation. - LinalgTilingOptions secondTilingOptions = tilingOptions; - // The distribution is done only for the tile+fused loops. - secondTilingOptions.distribution = llvm::None; - secondTilingOptions.setTileSizes(tileSizes); - Optional secondTiledOp = - tileLinalgOp(rewriter, ret.op, secondTilingOptions); - if (!secondTiledOp) - return llvm::None; - ret.unfusedLoops.assign(secondTiledOp->loops.begin(), - secondTiledOp->loops.end()); - rewriter.eraseOp(ret.op); - ret.op = secondTiledOp->op; } + ret.op = tiledRootOp->op; + ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); + // Fuse the other operations into the fused inter-tile loops produced above. 
+ ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(), + fusableDependences, ret.fusedLoopDims); return ret; } Optional -mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op, +mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef ops, const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions, - const LinalgFusionOptions &fusionOptions) { + const LinalgTilingOptions &tilingOptions) { switch (tilingOptions.loopType) { case LinalgTilingLoopType::Loops: - return tileAndFuseLinalgOpsImpl(rewriter, op, dependenceGraph, - tilingOptions, fusionOptions); + return tileAndFuseLinalgOpsImpl(builder, ops, dependenceGraph, + tilingOptions); case LinalgTilingLoopType::ParallelLoops: return tileAndFuseLinalgOpsImpl( - rewriter, op, dependenceGraph, tilingOptions, fusionOptions); + builder, ops, dependenceGraph, tilingOptions); default:; } return llvm::None; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -165,17 +165,71 @@ if (!linalgOp.hasBufferSemantics()) return failure(); + llvm::SetVector producers; + producers.insert(linalgOp); + for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) { + if (!fusionOptions.indicesToFuse.count( + dependence.indexingOpView.operandIndex)) + continue; + if (isa(dependence.dependentOpView.op)) + producers.insert(dependence.dependentOpView.op); + } + + SmallVector fusionOps; + for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie; + ++it) { + auto producerLinalgOp = dyn_cast(&(*it)); + if (producerLinalgOp && producers.count(producerLinalgOp)) + fusionOps.push_back(producerLinalgOp); + } + fusionOps.push_back(linalgOp); + + // Tile and fuse only tiles the fused loops. So the unfused loops of the + // operation are to be tiled. Get the tile sizes passed + SmallVector tileSizes = + tilingOptions.tileSizeComputationFunction(rewriter, op); + LinalgTilingOptions instanceTilingOptions = tilingOptions; + instanceTilingOptions.setTileSizes(tileSizes); Optional tiledAndFusedOps = tileAndFuseLinalgOps( - rewriter, op, dependenceGraph, tilingOptions, fusionOptions); + rewriter, fusionOps, dependenceGraph, instanceTilingOptions); if (!tiledAndFusedOps) return failure(); + + // Tile the unfused loops; + SmallVector unfusedLoopTileSizes; + for (auto tileSize : enumerate(tileSizes)) { + if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index())) + unfusedLoopTileSizes.push_back( + rewriter.create(op->getLoc(), 0)); + else + unfusedLoopTileSizes.push_back(tileSize.value()); + } + // Tile the loop only if there is a non-zero tile size. 
+ if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops()) + unfusedLoopTileSizes.resize(linalgOp.getNumLoops()); + if (llvm::any_of(unfusedLoopTileSizes, [](Value val) { + if (auto cst = val.getDefiningOp()) + return cst.getValue() != 0; + return true; + })) { + LinalgTilingOptions unfusedTilingOptions = tilingOptions; + unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes); + Optional unfusedTiledOp = + tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions); + if (!unfusedTiledOp) + return failure(); + rewriter.eraseOp(tiledAndFusedOps->op); + tiledAndFusedOps->op = unfusedTiledOp->op; + } + marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation()); for (auto fusedOp : tiledAndFusedOps->fusedProducers) { fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation()); } - for (auto origProducerOp : tiledAndFusedOps->originalProducers) + for (auto origProducerOp : ArrayRef(fusionOps).drop_back()) { originalOpMarker.replaceLinalgMarker(rewriter, origProducerOp.getOperation()); + } rewriter.updateRootInPlace( op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); }); return success(); diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir @@ -47,7 +47,9 @@ // CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]] // CHECK: %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]] // CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]] -// CHECK: linalg.fill(%[[SV3]], %[[CST]]) +// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]] +// CHECK: linalg.fill(%[[SV3_2]], %[[CST]]) // CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer" // CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] { // CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]] @@ -109,9 +111,12 @@ // CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]] // CHECK: %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]] // CHECK-SAME: [%[[M]], %[[TILE_N_2]]] +// CHECK: %[[K_2:.+]] = dim %[[ARG1]], %[[C0]] // CHECK: %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]] -// CHECK-SAME: [%[[K]], %[[TILE_N]]] -// CHECK: linalg.copy(%[[SV3]], %[[SV1]]) +// CHECK-SAME: [%[[K_2]], %[[TILE_N]]] +// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][0, %[[IV0]]] +// CHECK-SAME: [%[[K_2]], %[[TILE_N]]] +// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]]) // CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_producer" // CHECK-NOT: linalg.fill // CHECK-DAG: %[[M_2:.+]] = dim %[[ARG0]], %[[C0]] @@ -186,11 +191,16 @@ // CHECK: %[[N:.+]] = dim %[[ARG3]], %[[C1]] // CHECK: %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N]]] +// CHECK: %[[SV2_2:.+]] = subview %[[ARG3]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[N]]] +// CHECK: %[[K_2:.+]] = dim %[[ARG0]], %[[C1]] // CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K]]] -// CHECK: linalg.copy(%[[SV3]], %[[SV1]]) +// CHECK-SAME: [%[[TILE_M]], %[[K_2]]] +// CHECK: %[[SV3_2:.+]] = subview %[[ARG1]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[K_2]]] +// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]]) // CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer" -// CHECK: linalg.fill(%[[SV2]], %[[CST]]) +// CHECK: linalg.fill(%[[SV2_2]], %[[CST]]) // CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer" // CHECK-DAG: 
%[[N_2:.+]] = dim %[[ARG2]], %[[C1]] // CHECK: scf.parallel (%[[IV1:.+]]) = @@ -261,15 +271,18 @@ // CHECK: %[[N:.+]] = dim %[[ARG4]], %[[C1]] // CHECK: %[[SV2:.+]] = subview %[[ARG4]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N]]] +// CHECK: %[[K2_2:.+]] = dim %[[ARG1]], %[[C1]] // CHECK: %[[K1:.+]] = dim %[[ARG0]], %[[C1]] // CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M]], %[[K1]]] -// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2]]] +// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2_2]]] +// CHECK: %[[SV1_2:.+]] = subview %[[ARG2]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[K2_2]]] // CHECK: linalg.matmul // CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer" // CHECK-SAME: ins(%[[SV3]], %[[SV4]] // CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV1]] : memref) +// CHECK-SAME: outs(%[[SV1_2]] : memref) // CHECK-DAG: %[[N_2:.+]] = dim %[[ARG3]], %[[C1]] // CHECK: scf.parallel (%[[IV1:.+]]) = // CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) { diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir @@ -0,0 +1,133 @@ +// RUN: mlir-opt -pass-pipeline="func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),canonicalize,cse" -split-input-file %s | FileCheck %s + +module { + func @three_op_fusion(%arg0: memref, %arg1: memref, + %arg2: memref, %arg3 : memref) { + %cst = constant 0.000000e+00 : f32 + %c0 = constant 0 : index + %c1 = constant 1 : index + %d0 = dim %arg0, %c0 : memref + %d1 = dim %arg1, %c1 : memref + %0 = alloc(%d0, %d1) : memref + linalg.fill(%0, %cst) : memref, f32 + linalg.matmul ins(%arg0, %arg1 : memref, memref) + outs(%0 : memref) + linalg.generic + {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%0, %arg2 : memref, memref) + outs(%arg3 : memref) { + ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) : + %5 = addf %arg4, %arg5 : f32 + linalg.yield %5 : f32 + } + return + } +} + +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK: func @three_op_fusion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref +// CHECK: %[[TEMP:.+]] = alloc(%{{.*}}, %{{.*}}) : memref +// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} { +// CHECK-DAG: %[[SV_TEMP:.+]] = subview %[[TEMP]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[IV1]]] +// CHECK-DAG: %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0] +// CHECK-DAG: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[IV1]]] +// CHECK: linalg.fill(%[[SV_TEMP]], %{{.+}}) +// CHECK: linalg.matmul +// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]] +// CHECK-SAME: : memref, memref) +// CHECK-SAME: outs(%[[SV_TEMP]] : memref) +// CHECK: linalg.generic +// CHECK-SAME: ins(%[[SV_TEMP]], %[[SV_ARG2]] +// CHECK-SAME: : memref, memref) +// CHECK-SAME: outs(%[[SV_ARG3]] : memref) +// CHECK: scf.yield +// CHECK: } + +// ----- + +module { + func @sequence_of_matmul(%arg0: memref, %arg1: memref, + %arg2: memref, %arg3: memref, + %arg4: memref) { + %cst = constant 0.000000e+00 : f32 + 
%c0 = constant 0 : index + %c1 = constant 1 : index + %m = dim %arg0, %c0 : memref + %n1 = dim %arg1, %c1 : memref + %n2 = dim %arg2, %c1 : memref + %n3 = dim %arg3, %c1 : memref + %0 = alloc(%m, %n1) : memref + %1 = alloc(%m, %n2) : memref + linalg.fill(%0, %cst) : memref, f32 + linalg.matmul ins(%arg0, %arg1 : memref, memref) + outs(%0 : memref) + linalg.fill(%1, %cst) : memref, f32 + linalg.matmul ins(%0, %arg2 : memref, memref) + outs(%1 : memref) + linalg.fill(%arg4, %cst) : memref, f32 + linalg.matmul ins(%1, %arg3 : memref, memref) + outs(%arg4 : memref) + return + } +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +// CHECK: func @sequence_of_matmul +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C16:.+]] = constant 16 : index +// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[N1:.+]] = dim %[[ARG1]], %[[C1]] +// CHECK-DAG: %[[N2:.+]] = dim %[[ARG2]], %[[C1]] +// CHECK: %[[ALLOC1:.+]] = alloc(%[[M]], %[[N1]]) +// CHECK: %[[ALLOC2:.+]] = alloc(%[[M]], %[[N2]]) +// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]]) +// CHECK-SAME: step (%[[C16]]) { +// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] +// CHECK: %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[N2]]] +// CHECK: %[[M_2:.+]] = dim %[[ARG4]], %[[C0]] +// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]] +// CHECK: %[[N3:.+]] = dim %[[ARG4]], %[[C1]] +// CHECK: %[[SV_ARG4:.+]] = subview %[[ARG4]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]] +// CHECK: %[[SV_ARG4_2:.+]] = subview %[[ARG4]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[N3]]] +// CHECK: %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[N1]]] +// CHECK: %[[SV_ARG2:.+]] = subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]] +// CHECK: %[[N0:.+]] = dim %[[ARG0]], %[[C1]] +// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M:.+]], %[[N0]]] +// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]] +// CHECK: linalg.fill(%[[SV_ALLOC1]], %{{.+}}) +// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[SV_ARG1]] +// CHECK-SAME: : memref, memref) +// CHECK-SAME: outs(%[[SV_ALLOC1]] : memref) +// CHECK: linalg.fill(%[[SV_ALLOC2]], %{{.+}}) +// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[SV_ARG2]] +// CHECK-SAME: : memref, memref) +// CHECK-SAME: outs(%[[SV_ALLOC2]] : memref) +// CHECK: linalg.fill(%[[SV_ARG4_2]], %{{.+}}) +// CHECK: linalg.matmul ins(%[[SV_ALLOC2]], %[[ARG3]] +// CHECK-SAME: : memref, memref) +// CHECK-SAME: outs(%[[SV_ARG4]] : memref) +// CHECK: scf.yield +// CHECK: } + diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -197,6 +197,44 @@ } } }; + +/// Pass to test tile and fuse of sequence of operations. Intended only for +/// testing. 
+struct TestLinalgTileAndFuseSequencePass + : public PassWrapper { + TestLinalgTileAndFuseSequencePass() = default; + TestLinalgTileAndFuseSequencePass( + const TestLinalgTileAndFuseSequencePass &pass){}; + + ListOption tileSizes{ + *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnFunction() override { + FuncOp funcOp = getOperation(); + auto &blocks = funcOp.getBody().getBlocks(); + if (!llvm::hasSingleElement(blocks)) { + return; + } + SmallVector linalgOps = + llvm::to_vector<2>(blocks.front().getOps()); + Aliases aliases; + LinalgDependenceGraph dependenceGraph(aliases, linalgOps); + OpBuilder builder(funcOp.getContext()); + Optional tileAndFuseOps = tileAndFuseLinalgOps( + builder, linalgOps, dependenceGraph, + LinalgTilingOptions().setTileSizes(tileSizes).setLoopType( + LinalgTilingLoopType::ParallelLoops)); + if (!tileAndFuseOps) + return signalPassFailure(); + for (auto op : linalgOps) + op.erase(); + } +}; } // namespace namespace mlir { @@ -211,5 +249,12 @@ "test-linalg-greedy-fusion", "Test Linalg fusion by applying a greedy test transformation."); } +void registerTestLinalgTileAndFuseSequencePass() { + PassRegistration + testTileAndFuseSequencePass( + "test-linalg-tile-and-fuse", + "Test Linalg tiling and fusion of a sequence of Linalg operations."); +} + } // namespace test } // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -74,6 +74,7 @@ void registerTestLinalgFusionTransforms(); void registerTestLinalgGreedyFusion(); void registerTestLinalgHoisting(); +void registerTestLinalgTileAndFuseSequencePass(); void registerTestLinalgTransforms(); void registerTestLivenessPass(); void registerTestLoopFusion(); @@ -140,6 +141,7 @@ test::registerTestLinalgFusionTransforms(); test::registerTestLinalgGreedyFusion(); test::registerTestLinalgHoisting(); + test::registerTestLinalgTileAndFuseSequencePass(); test::registerTestLinalgTransforms(); test::registerTestLivenessPass(); test::registerTestLoopFusion();
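Because the new entry point only tiles the fused loops, any remaining loops still have to be tiled in a second step, which is what the updated `LinalgTileAndFusePattern` in Transforms.cpp does. Below is a hedged sketch of that follow-up, assuming `builder`, a `tiledAndFused` result from `tileAndFuseLinalgOps`, and `tileSizes` (the materialized `Value` tile sizes for all loops) are available; it is illustrative, not an exact copy of the pattern.

```c++
// Sketch only: re-tile the loops that were not fused. Fused dimensions get a
// zero tile size so this second tiling leaves them untouched.
Location loc = tiledAndFused->op.getLoc();
SmallVector<Value, 4> unfusedTileSizes;
for (auto en : llvm::enumerate(tileSizes)) {
  if (tiledAndFused->fusedLoopDims.count(en.index()))
    unfusedTileSizes.push_back(builder.create<ConstantIndexOp>(loc, 0));
  else
    unfusedTileSizes.push_back(en.value());
}
// Only run the second tiling if at least one tile size is non-zero.
if (llvm::any_of(unfusedTileSizes, [](Value val) {
      if (auto cst = val.getDefiningOp<ConstantIndexOp>())
        return cst.getValue() != 0;
      return true;
    })) {
  LinalgTilingOptions unfusedOptions =
      LinalgTilingOptions().setTileSizes(unfusedTileSizes);
  if (Optional<TiledLinalgOp> retiled =
          tileLinalgOp(builder, tiledAndFused->op, unfusedOptions)) {
    tiledAndFused->op.erase();
    tiledAndFused->op = retiled->op;
  }
}
```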