diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -39,9 +39,8 @@ struct TiledAndFusedLinalgOps { LinalgOp op; SmallVector fusedProducers; - SmallVector originalProducers; SmallVector fusedLoops; - SmallVector unfusedLoops; + SmallVector, 1> unfusedLoops; }; /// Populates patterns for vectorization of all ConvN-D ops. @@ -72,14 +71,12 @@ Optional tileLinalgOp(OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options); -/// Tile and fuse the `op` with its producers. The tile and fuse proceeds in -/// three steps -/// - Find tile loops that are fusable with its producer tile loops (a.k.a. tile -/// + fuse loops). -/// - Tile just these loops of the consumer (root operation) and fuse with -/// the producer. -/// - Tile again the tiled consumer operation produced above to do rest of -/// the tiling specified by the `tilingOptions`. +/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This +/// proceeds as follows: +/// - Find outer parallel loops in these ops that can be fused. +/// - Tile fusable outer parallel loops of the last operation in the sequence. +/// - Fuse the remaining operations with the tiled operation. +/// - Tile the unfused loops in each of the fused operations if needed. 
/// /// For example, consider the sequence of matmul below /// @@ -110,32 +107,35 @@ /// ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>) /// outs(%0 : memref<16x32xf32, #map0>) /// scf.parallel (%arg6) = (%c0) to (%c32) step (%c8) { -/// scf.for %arg7 = %c0 to %c32 step %c4 { -/// %4 = subview %0[0, %arg7] [16, 4] [1, 1] -/// : memref<16x32xf32, #map0> to memref<16x4xf32, #map0> -/// %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1] -/// : memref<32x32xf32> to memref<4x8xf32, #map0> -/// %6 = subview %1[0, %arg6] [16, 8] [1, 1] -/// : memref<16x32xf32, #map0> to memref<16x8xf32, #map0> -/// linalg.matmul -/// ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>) -/// outs(%6 : memref<16x8xf32, #map0>) +/// scf.for %arg7 = %c0 to %c32 step %c4 { +/// %4 = subview %0[0, %arg7] [16, 4] [1, 1] +/// : memref<16x32xf32, #map0> to memref<16x4xf32, #map0> +/// %5 = subview %arg3[%arg7, %arg6] [4, 8] [1, 1] +/// : memref<32x32xf32> to memref<4x8xf32, #map0> +/// %6 = subview %1[0, %arg6] [16, 8] [1, 1] +/// : memref<16x32xf32, #map0> to memref<16x8xf32, #map0> +/// linalg.matmul +/// ins(%4, %5 : memref<16x4xf32, #map0>, memref<4x8xf32, #map0>) +/// outs(%6 : memref<16x8xf32, #map0>) /// } /// scf.yield /// } /// scf.yield /// } /// -/// The following tiling options are handled differently in tile+fuse (compared -/// to tile only) +/// `tilingOptions` are used to tile the corresponding operation in `ops` (the +/// size of the former should be the same as the size of the latter). Based on how +/// tile+fuse is implemented, the fused loops are generated based on the last +/// operation in the sequence. For example, the tile sizes for the fused loops +/// are obtained from `tilingOptions.back()`. The following tiling options are +/// handled differently in tile+fuse (compared to tile only): /// - Interchange of the tiling loops is not supported right now. -/// - Distribution is only done for the tile+fuse loops. 
The tiled loops +/// - Distribution is only done for the fused loops. The tiled loops /// generated by the second tiling is not distributed. Optional -tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op, +tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef ops, const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions, - const LinalgFusionOptions &fusionOptions); + ArrayRef tilingOptions); /// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`. /// This is an in-place transformation controlled by `interchangeVector`. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -28,6 +28,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include #include #define DEBUG_TYPE "linalg-fusion" @@ -190,48 +191,18 @@ llvm_unreachable("Expect to be able to extract a shape defining loop range"); } -/// Fuses the producer of `producerIdx` into the loop immediately enclosing -/// `consumer`. This is achieved by "recomputing" the `producer` at the time it -/// is needed just before the `consumer. -/// -/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are -/// 2 cases: -/// 1. Buffer case: `producerIdx` is the index of the buffer in -/// `producer.getOutputBuffers()`. -/// 2. Tensor case: `producerIdx` is the index of the tensor in -/// `producer.getResults()`. 
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, - LinalgOp consumer, unsigned consumerIdx) { - Operation *shapeProducingOp = - consumer.getShapedOperand(consumerIdx).getDefiningOp(); - assert((isa(shapeProducingOp) || - isa(shapeProducingOp)) && - "SubviewOp or SubTensorOp expected"); - - // loopToOperandRangesMaps are permutations-only by construction: - // we can always identify a data dimension with a (at least one) loop - // dimension. - // TODO: extend this with range inference. - AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); - LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx - << ", producer map: " << producerMap << "\n"); +/// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges` +/// provides the loop range information for the fused loops. The rest are +/// obtained from the producer itself, since they are not tiled + fused. +static LinalgOp fuse(OpBuilder &b, LinalgOp producer, + const DenseMap &fusedLoopsAndRanges) { unsigned nPar = producer.getNumParallelLoops(); unsigned nRed = producer.getNumReductionLoops(); unsigned nWin = producer.getNumWindowLoops(); SmallVector loopRanges(nPar + nRed + nWin); - - // Iterate over dimensions identified by the producer map for `producerIdx`. - // This defines a subset of the loop ranges that we need to complete later. - auto loc = consumer.getLoc(); - for (auto en : llvm::enumerate(producerMap.getResults())) { - unsigned posInProducerLoop = en.value().cast().getPosition(); - loopRanges[posInProducerLoop] = - isa(shapeProducingOp) - ? cast(shapeProducingOp) - .getOrCreateRanges(b, loc)[en.index()] - : cast(shapeProducingOp) - .getOrCreateRanges(b, loc)[en.index()]; + for (auto fusedLoops : fusedLoopsAndRanges) { + loopRanges[fusedLoops.first] = fusedLoops.second; } // Iterate over all dimensions. 
For the dimensions not identified by the @@ -250,7 +221,48 @@ } } - return cloneWithLoopRanges(b, loc, producer, loopRanges); + return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges); +} + +/// Helper function to get the loop range for a dimension `dim` based on the +/// `shapedOperand`. It is expected to be defined by a subview op or a subtensor +/// op. +static Range getRangeFromOperandShape(OpBuilder &b, Location loc, + Value shapedOperand, unsigned dim) { + Operation *shapeProducingOp = shapedOperand.getDefiningOp(); + if (auto subViewOp = dyn_cast(shapeProducingOp)) { + return subViewOp.getOrCreateRanges(b, loc)[dim]; + } + if (auto subTensorOp = dyn_cast(shapeProducingOp)) { + return subTensorOp.getOrCreateRanges(b, loc)[dim]; + } + llvm_unreachable("SubviewOp or SubTensorOp expected"); +} + +/// Fuses the producer of `producerIdx` into the loop immediately enclosing +/// `consumer`. This is achieved by "recomputing" the `producer` at the time it +/// is needed just before the `consumer`. +/// +/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are +/// 2 cases: +/// 1. Buffer case: `producerIdx` is the index of the buffer in +/// `producer.getOutputBuffers()`. +/// 2. Tensor case: `producerIdx` is the index of the tensor in +/// `producer.getResults()`. 
+static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, + LinalgOp consumer, unsigned consumerIdx) { + AffineMap producerMap = producer.getOutputIndexingMap(producerIdx); + LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx + << ", producer map: " << producerMap << "\n"); + DenseMap fusedLoopsAndRanges; + Location loc = consumer.getLoc(); + Value shapedOperand = consumer.getShapedOperand(consumerIdx); + for (auto en : llvm::enumerate(producerMap.getResults())) { + unsigned posInProducerLoop = en.value().cast().getPosition(); + fusedLoopsAndRanges[posInProducerLoop] = + getRangeFromOperandShape(b, loc, shapedOperand, en.index()); + } + return fuse(b, producer, fusedLoopsAndRanges); } // Encode structural fusion safety preconditions. @@ -601,8 +613,9 @@ /// submap with only parallel loops = affine_map<(i, j) -> (j)> /// Fused dimensions : j static std::set -collectTileAndFuseLoops(LinalgOp op, +collectTileAndFuseLoops(ArrayRef ops, const FusableOpDependencesTy &fusableDependences) { + assert(!ops.empty()); auto getNumOuterParallelLoops = [](LinalgOp linalgOp) { return linalgOp.iterator_types() .getValue() @@ -613,79 +626,73 @@ .size(); }; - LLVM_DEBUG({ - llvm::dbgs() << "Op : "; - op.getOperation()->print(llvm::dbgs(), OpPrintingFlags().useLocalScope()); - llvm::dbgs() << "\n"; - }); - - size_t numOuterParallelLoops = getNumOuterParallelLoops(op); - for (auto dependence : fusableDependences) { - linalg::LinalgOp producer = cast(dependence.first); + size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back()); + for (auto op : ops.drop_back()) { numOuterParallelLoops = - std::min(numOuterParallelLoops, getNumOuterParallelLoops(producer)); + std::min(numOuterParallelLoops, getNumOuterParallelLoops(op)); } std::set fusableLoops; auto range = llvm::seq(0, numOuterParallelLoops); fusableLoops.insert(range.begin(), range.end()); - for (auto dependence : fusableDependences) { - LLVM_DEBUG({ - llvm::dbgs() << "\t fusable :"; - for 
(unsigned i : fusableLoops) - llvm::dbgs() << " " << i; - llvm::dbgs() << "\n"; - }); - linalg::LinalgOp producer = cast(dependence.first); - - assert(!dependence.second.empty() && - "unexpected producer but not dependences"); - AffineMap producerIndexingMap = producer.getIndexingMap( - dependence.second.front().dependentOpView.operandNum); - AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( - producer.iterator_types().getValue(), producerIndexingMap); - if (!prunedProducerIndexingMap.isPermutation()) - return {}; - - AffineMap consumerIndexingMap = - op.getIndexingMap(dependence.second.front().indexingOpView.operandNum); - if (consumerIndexingMap.getNumResults() != - prunedProducerIndexingMap.getNumResults()) - return {}; - - LLVM_DEBUG({ - llvm::dbgs() << "\t producerMap : "; - producerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << " pruned : "; - prunedProducerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - llvm::dbgs() << "\t consumerMap : "; - consumerIndexingMap.print(llvm::dbgs()); - llvm::dbgs() << "\n"; - }); - - AffineMap invProducerIndexMap = - inversePermutation(prunedProducerIndexingMap); - if (!invProducerIndexMap) - return {}; - - AffineMap consumerLoopToProducerLoop = - invProducerIndexMap.compose(consumerIndexingMap); - std::set candidates; - for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) { - unsigned position = expr.cast().getPosition(); - if (fusableLoops.count(position)) - candidates.insert(position); + + for (auto op : reverse(ops)) { + for (auto dependence : fusableDependences.lookup(op)) { + LLVM_DEBUG({ + llvm::dbgs() << "\t fusable :"; + for (unsigned i : fusableLoops) + llvm::dbgs() << " " << i; + llvm::dbgs() << "\n"; + }); + + AffineMap producerIndexingMap = + op.getIndexingMap(dependence.dependentOpView.operandNum); + AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( + op.iterator_types().getValue(), producerIndexingMap); + if (!prunedProducerIndexingMap.isPermutation()) + 
return {}; + + LinalgOp consumer = cast(dependence.indexingOpView.op); + AffineMap consumerIndexingMap = + consumer.getIndexingMap(dependence.indexingOpView.operandNum); + if (consumerIndexingMap.getNumResults() != + prunedProducerIndexingMap.getNumResults()) + return {}; + + LLVM_DEBUG({ + llvm::dbgs() << "\t producerMap : "; + producerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << " pruned : "; + prunedProducerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + llvm::dbgs() << "\t consumerMap : "; + consumerIndexingMap.print(llvm::dbgs()); + llvm::dbgs() << "\n"; + }); + + AffineMap invProducerIndexMap = + inversePermutation(prunedProducerIndexingMap); + if (!invProducerIndexMap) + return {}; + + AffineMap consumerLoopToProducerLoop = + invProducerIndexMap.compose(consumerIndexingMap); + std::set candidates; + for (AffineExpr expr : consumerLoopToProducerLoop.getResults()) { + unsigned position = expr.cast().getPosition(); + if (fusableLoops.count(position)) + candidates.insert(position); + } + LLVM_DEBUG({ + llvm::dbgs() << "\t candidates :"; + for (unsigned i : candidates) + llvm::dbgs() << " " << i; + llvm::dbgs() << "\n"; + }); + if (candidates.empty()) + return {}; + std::swap(candidates, fusableLoops); } - LLVM_DEBUG({ - llvm::dbgs() << "\t candidates :"; - for (unsigned i : candidates) - llvm::dbgs() << " " << i; - llvm::dbgs() << "\n"; - }); - if (candidates.empty()) - return {}; - std::swap(candidates, fusableLoops); } return fusableLoops; @@ -693,60 +700,69 @@ /// Find all dependences that are to be fusable. static FusableOpDependencesTy -findAllFusableDependences(LinalgOp op, - const LinalgDependenceGraph &dependenceGraph, - const LinalgFusionOptions &fusionOptions) { +findAllFusableDependences(ArrayRef ops, + const LinalgDependenceGraph &dependenceGraph) { FusableOpDependencesTy fusableDependences; // TODO: Currently fusion would not be legal if the fusable dependence is to // the same producer but different indexing map in the consumer. 
Fix this, but // in the meanwhile disallow such a fusion. DenseMap fusedProducerIndexingMap; - for (auto operandIndex : fusionOptions.indicesToFuse) { - auto fusableDependence = - findFusableProducer(op, operandIndex, dependenceGraph); - if (!fusableDependence) - return FusableOpDependencesTy{}; - LinalgOp producerOp = cast(fusableDependence->dependentOpView.op); - // Do not fuse dependences that are to operations not in the same basic - // block. This avoid moving fused operations across loops that might - // themselves carry dependency making the fusion illegal. - if (producerOp.getOperation()->getBlock() != - op.getOperation()->getBlock()) { - op.emitRemark("unhandled fusion of ops in different basic blocks"); - return FusableOpDependencesTy{}; - } - // Make sure that the indexing map of the view used for fusion in the - // producer is a projected permutation. - unsigned producerIdx = fusableDependence->dependentOpView.operandNum; - AffineMap producerMap = producerOp.getIndexingMap(producerIdx); - if (!producerMap.isProjectedPermutation()) { - op.emitRemark("unhandled non permutation indexing map for fused view in " - "producer for operand at index ") - << operandIndex; - return FusableOpDependencesTy{}; - } + for (LinalgOp op : reverse(ops)) { + for (unsigned operandNum : + llvm::seq(0, op.getNumInputsAndOutputBuffers())) { + Optional + fusableDependence = + findFusableProducer(op, operandNum, dependenceGraph); + if (!fusableDependence) + continue; + LinalgOp producerOp = + cast(fusableDependence->dependentOpView.op); + // Do not fuse dependences that are to operations not in the same basic + // block. This avoid moving fused operations across loops that might + // themselves carry dependency making the fusion illegal. 
+ if (producerOp.getOperation()->getBlock() != + op.getOperation()->getBlock()) { + op.emitRemark("unhandled fusion of ops in different basic blocks"); + return FusableOpDependencesTy{}; + } + // Make sure that the indexing map of the view used for fusion in the + // producer is a projected permutation. + unsigned producerIdx = fusableDependence->dependentOpView.operandNum; + AffineMap producerMap = producerOp.getIndexingMap(producerIdx); + if (!producerMap.isProjectedPermutation()) { + op.emitRemark( + "unhandled non permutation indexing map for fused view in " + "producer for operand at index ") + << operandNum; + return FusableOpDependencesTy{}; + } - unsigned consumerIdx = fusableDependence->indexingOpView.operandNum; - AffineMap consumerMap = op.getIndexingMap(consumerIdx); - if (!consumerMap.isProjectedPermutation()) { - op.emitRemark( - "unhandled case where indexing map for fused view in the consumer is " - "not a projected permutation while fusing at index ") - << operandIndex; - return FusableOpDependencesTy{}; - } + unsigned consumerIdx = fusableDependence->indexingOpView.operandNum; + AffineMap consumerMap = op.getIndexingMap(consumerIdx); + if (!consumerMap.isProjectedPermutation()) { + op.emitRemark( + "unhandled case where indexing map for fused view in the consumer " + "is " + "not a projected permuration while fusing at index ") + << operandNum; + return FusableOpDependencesTy{}; + } - // Check if the producer is already a fusion candidate. Cannot fuse this - // dependence if it has a different indexing map when used in the consumer. - if (fusedProducerIndexingMap.count(producerOp.getOperation()) && - fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) { - op.emitRemark("unhandled fusion to the same producer but with different " - "indexing maps"); - return FusableOpDependencesTy{}; - } - fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap; + // Check if the producer is already a fusion candidate. 
Cannot fuse this + // dependence if it has a different indexing map when used in the + // consumer. + if (fusedProducerIndexingMap.count(producerOp.getOperation()) && + fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) { + op.emitRemark( + "unhandled fusion to the same producer but with different " + "indexing maps"); + return FusableOpDependencesTy{}; + } + fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap; - fusableDependences[producerOp.getOperation()].push_back(*fusableDependence); + fusableDependences[producerOp.getOperation()].push_back( + *fusableDependence); + } } return fusableDependences; } @@ -757,46 +773,110 @@ return false; } +/// Tile the fused loops in the root operation, by setting the tile sizes for +/// all other loops to zero (those will be tiled later). +static Optional tileRootOperation( + OpBuilder &builder, LinalgOp op, ArrayRef tileSizeVector, + const LinalgTilingOptions &options, const std::set &fusedLoops) { + SmallVector tileSizes(tileSizeVector.begin(), tileSizeVector.end()); + auto zero = std_constant_index(0); + for (unsigned i = 0, e = tileSizes.size(); i != e; ++i) { + if (!fusedLoops.count(i)) + tileSizes[i] = zero; + } + LinalgTilingOptions tileFusedLoopsOptions = options; + tileFusedLoopsOptions.setTileSizes(tileSizes); + return tileLinalgOp(builder, op, tileFusedLoopsOptions); +} + +/// Fuse the operations in `fusionCandidates` with `tiledOp`. The latter is expected +/// to be a tiled operation such that it is valid to fuse all operations in +/// `fusionCandidates`, i.e. move the operations within the inter-tile loops of +/// `tiledOp`. 
+static SmallVector +fuseOperations(OpBuilder &builder, LinalgOp tiledOp, + ArrayRef fusionCandidates, + const std::set &fusedLoops) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(tiledOp); + DenseMap fusedLoopsAndRanges; + for (unsigned loop : fusedLoops) { + ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop); + fusedLoopsAndRanges[loop] = getRangeFromOperandShape( + builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension); + } + SmallVector fusedOps(fusionCandidates.size()); + for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) { + LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges); + fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp; + builder.setInsertionPoint(fusedOp); + } + return fusedOps; +} + +/// Post fusion, tile all the unfused loops of the fused operations. +static Optional tileUnfusedLoops( + OpBuilder &builder, LinalgOp op, ArrayRef tileSizeVector, + const LinalgTilingOptions &options, const std::set &fusedLoops) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(op); + SmallVector tileSizes(tileSizeVector.begin(), tileSizeVector.end()); + auto zero = std_constant_index(0); + for (unsigned i = 0, e = tileSizes.size(); i != e; ++i) { + if (fusedLoops.count(i)) + tileSizes[i] = zero; + } + tileSizes.resize(op.getNumLoops()); + if (llvm::all_of(tileSizes, isZero)) + return TiledLinalgOp{op, {}, {}}; + LinalgTilingOptions tileUnfusedLoopsOptions = options; + tileUnfusedLoopsOptions.setTileSizes(tileSizes); + tileUnfusedLoopsOptions.distribution = llvm::None; + return tileLinalgOp(builder, op, tileUnfusedLoopsOptions); +} + template static Optional -tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op, +tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef ops, const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions, - const LinalgFusionOptions &fusionOptions) { - 
assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); - // Some of the tiling options might not be supportable with tile and fuse. - // TODO: Support interchange with tile + fuse. - if (!tilingOptions.interchangeVector.empty()) { - op.emitError("unable to handle tile and fuse with interchange"); + ArrayRef tilingOptions) { + if (ops.empty()) return llvm::None; + LinalgOp rootOp = ops.back(); + if (ops.size() != tilingOptions.size()) { + rootOp.emitError( + "need tiling options for all the operations that are to be fused"); + return llvm::None; + } + for (auto op : enumerate(ops)) { + // TODO: Nothing in the fusion of sequence of ops is specific to + // buffers. This check can be removed after it is tested on tensors. + LinalgOp linalgOp = op.value(); + if (!linalgOp.hasBufferSemantics()) { + linalgOp.emitError("tile and fuse only tested for buffer operation"); + return llvm::None; + } + // TODO: Support interchange with tile + fuse. This might actually help do + // better fusion. + if (!tilingOptions[op.index()].interchangeVector.empty()) { + linalgOp.emitError("unable to handle tile and fuse with interchange"); + return llvm::None; + } } - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(op); - ScopedContext scope(rewriter, op.getLoc()); + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPoint(rootOp); + ScopedContext scope(builder, rootOp.getLoc()); // Find all the producers. FusableOpDependencesTy fusableDependences = - findAllFusableDependences(op, dependenceGraph, fusionOptions); + findAllFusableDependences(ops, dependenceGraph); if (fusableDependences.empty()) return llvm::None; - // Enforce the convention that "tiling by zero" skips tiling a particular - // dimension. This convention is significantly simpler to handle instead of - // adjusting affine maps to account for missing dimensions. 
- auto nLoops = op.getNumLoops(); - SmallVector tileSizeVector = - tilingOptions.tileSizeComputationFunction(rewriter, op); - if (tileSizeVector.size() < nLoops) { - auto zero = std_constant_index(0); - tileSizeVector.append(nLoops - tileSizeVector.size(), zero); - } - - TiledAndFusedLinalgOps ret; - // Find the loops that can be tiled and fused. std::set tileFuseLoops = - collectTileAndFuseLoops(op, fusableDependences); + collectTileAndFuseLoops(ops, fusableDependences); // If there are no fusable dependences or there are no tile+fusable loops, // just return. @@ -804,87 +884,73 @@ return llvm::None; } - // Get the tile sizes for the first and second tiling steps. For the first - // step the tile size are set to zero for the loops that arent - // fused. Similarly for the second step, the tile sizes are set to zero for - // the loops that are fused. For example, if for the following input - // - // ``` - // linalg.add ins(%a, %b) outs(%c) - // linalg.matmul ins(%d, %c) outs(%e) - // ``` - // - // if the tile sizes of the `{i, j, k}` loops where given as `{ti, tj, tk}` - // respectively, and since only `j` can be tiled and fused. The tile sizes - // would be `{0, t_j, 0}` for the first tiling that tiles just the fusable - // loops. The second tiling would be use tile sizes of `{t_i, 0, t_k}` to tile - // the tiled matmul generated by the first tiling step. - SmallVector tileAndFuseSizes, tileSizes; - for (auto tileSize : enumerate(tileSizeVector)) { - auto zero = std_constant_index(0); - if (tileFuseLoops.count(tileSize.index())) { - tileAndFuseSizes.push_back(tileSize.value()); - tileSizes.push_back(zero); - } else { - tileSizes.push_back(tileSize.value()); - tileAndFuseSizes.push_back(zero); + // Tile the fused loops in the last operation in the list. 
+ SmallVector tileSizeVector = + tilingOptions.back().tileSizeComputationFunction(builder, rootOp); + TiledAndFusedLinalgOps ret; + Optional tiledRootOp = tileRootOperation( + builder, rootOp, tileSizeVector, tilingOptions.back(), tileFuseLoops); + if (!tiledRootOp) { + rootOp.emitError("failed to tile the fused loops"); + return llvm::None; + } + ret.op = tiledRootOp->op; + ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end()); + + // Fuse the other operations into the fused inter-tile loops produced above. + ret.fusedProducers = + fuseOperations(builder, ret.op, ops.drop_back(), tileFuseLoops); + + // Tile the unfused loops for the fused ops. + ret.unfusedLoops.resize(ops.size()); + for (auto op : enumerate(ret.fusedProducers)) { + const LinalgTilingOptions &option = tilingOptions[op.index()]; + if (!option.tileSizeComputationFunction) + continue; + SmallVector producerTileSizeVector = + option.tileSizeComputationFunction(builder, op.value()); + Optional tiledOp = tileUnfusedLoops( + builder, op.value(), producerTileSizeVector, option, tileFuseLoops); + if (!tiledOp) { + op.value().emitError("unable to tile unfused loops"); + return llvm::None; + } + if (tiledOp->op != op.value()) { + builder.eraseOp(op.value()); + ret.fusedProducers[op.index()] = tiledOp->op; } + ret.unfusedLoops[op.index()].assign(tiledOp->loops.begin(), + tiledOp->loops.end()); + builder.setInsertionPoint(tiledOp->op); } - // Tile for the loops that can be fused. - LinalgTilingOptions firstTilingOptions = tilingOptions; - firstTilingOptions.setTileSizes(tileAndFuseSizes); - Optional firstTiledOp = - tileLinalgOp(rewriter, op, firstTilingOptions); - if (!firstTiledOp) + // Tile the unfused loops for the tiled root operation. 
+ Optional rootTileUnfusedLoopsOp = tileUnfusedLoops( + builder, ret.op, tileSizeVector, tilingOptions.back(), tileFuseLoops); + if (!rootTileUnfusedLoopsOp) { + ret.op.emitError("unable to tile unfused loops"); return llvm::None; - ret.op = firstTiledOp->op; - ret.fusedLoops.assign(firstTiledOp->loops.begin(), firstTiledOp->loops.end()); - - rewriter.setInsertionPoint(ret.op); - // Fuse the operands. - for (auto dependence : fusableDependences) { - LinalgOp producerOp = cast(dependence.first); - unsigned producerIdx = dependence.second.front().dependentOpView.operandNum; - unsigned consumerIdx = dependence.second.front().indexingOpView.operandNum; - LinalgOp fusedOp = fuse(rewriter, producerOp, - producerOp.getOutputIndex(producerIdx).getValue(), - ret.op, consumerIdx); - ret.fusedProducers.push_back(fusedOp); - ret.originalProducers.push_back(producerOp); - } - - if (!llvm::all_of(tileSizes, isZero)) { - // Tile the remaining loops of the root operation. - LinalgTilingOptions secondTilingOptions = tilingOptions; - // The distribution is done only for the tile+fused loops. 
- secondTilingOptions.distribution = llvm::None; - secondTilingOptions.setTileSizes(tileSizes); - Optional secondTiledOp = - tileLinalgOp(rewriter, ret.op, secondTilingOptions); - if (!secondTiledOp) - return llvm::None; - ret.unfusedLoops.assign(secondTiledOp->loops.begin(), - secondTiledOp->loops.end()); - rewriter.eraseOp(ret.op); - ret.op = secondTiledOp->op; } - + if (rootTileUnfusedLoopsOp->op != ret.op) { + builder.eraseOp(ret.op); + ret.op = rootTileUnfusedLoopsOp->op; + } + ret.unfusedLoops.back().assign(rootTileUnfusedLoopsOp->loops.begin(), + rootTileUnfusedLoopsOp->loops.end()); return ret; } -Optional -mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op, - const LinalgDependenceGraph &dependenceGraph, - const LinalgTilingOptions &tilingOptions, - const LinalgFusionOptions &fusionOptions) { - switch (tilingOptions.loopType) { +Optional mlir::linalg::tileAndFuseLinalgOps( + OpBuilder &builder, ArrayRef ops, + const LinalgDependenceGraph &dependenceGraph, + ArrayRef tilingOptions) { + switch (tilingOptions.back().loopType) { case LinalgTilingLoopType::Loops: - return tileAndFuseLinalgOpsImpl(rewriter, op, dependenceGraph, - tilingOptions, fusionOptions); + return tileAndFuseLinalgOpsImpl(builder, ops, dependenceGraph, + tilingOptions); case LinalgTilingLoopType::ParallelLoops: return tileAndFuseLinalgOpsImpl( - rewriter, op, dependenceGraph, tilingOptions, fusionOptions); + builder, ops, dependenceGraph, tilingOptions); default:; } return llvm::None; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -165,17 +165,42 @@ if (!linalgOp.hasBufferSemantics()) return failure(); + llvm::SetVector producers; + producers.insert(linalgOp); + for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) { + if (!fusionOptions.indicesToFuse.count( + 
dependence.indexingOpView.operandNum)) + continue; + if (isa(dependence.dependentOpView.op)) { + producers.insert(dependence.dependentOpView.op); + } + } + + SmallVector fusionOps; + for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie; + ++it) { + auto producerLinalgOp = dyn_cast(&(*it)); + if (producerLinalgOp && producers.count(producerLinalgOp)) + fusionOps.push_back(producerLinalgOp); + } + + SmallVector tilingOptionsVec(fusionOps.size(), + LinalgTilingOptions()); + fusionOps.push_back(linalgOp); + tilingOptionsVec.push_back(tilingOptions); + Optional tiledAndFusedOps = tileAndFuseLinalgOps( - rewriter, op, dependenceGraph, tilingOptions, fusionOptions); + rewriter, fusionOps, dependenceGraph, tilingOptionsVec); if (!tiledAndFusedOps) return failure(); marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation()); for (auto fusedOp : tiledAndFusedOps->fusedProducers) { fusedOpMarker.replaceLinalgMarker(rewriter, fusedOp.getOperation()); } - for (auto origProducerOp : tiledAndFusedOps->originalProducers) + for (auto origProducerOp : ArrayRef(fusionOps).drop_back()) { originalOpMarker.replaceLinalgMarker(rewriter, origProducerOp.getOperation()); + } rewriter.updateRootInPlace( op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); }); return success(); diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir @@ -47,7 +47,9 @@ // CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]] // CHECK: %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]] // CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]] -// CHECK: linalg.fill(%[[SV3]], %[[CST]]) +// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]] +// CHECK: linalg.fill(%[[SV3_2]], %[[CST]]) // CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer" // 
CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] { // CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]] @@ -109,9 +111,12 @@ // CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]] // CHECK: %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]] // CHECK-SAME: [%[[M]], %[[TILE_N_2]]] +// CHECK: %[[K_2:.+]] = dim %[[ARG1]], %[[C0]] // CHECK: %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]] -// CHECK-SAME: [%[[K]], %[[TILE_N]]] -// CHECK: linalg.copy(%[[SV3]], %[[SV1]]) +// CHECK-SAME: [%[[K_2]], %[[TILE_N]]] +// CHECK: %[[SV3_2:.+]] = subview %[[ARG2]][0, %[[IV0]]] +// CHECK-SAME: [%[[K_2]], %[[TILE_N]]] +// CHECK: linalg.copy(%[[SV3]], %[[SV3_2]]) // CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_producer" // CHECK-NOT: linalg.fill // CHECK-DAG: %[[M_2:.+]] = dim %[[ARG0]], %[[C0]] @@ -186,11 +191,16 @@ // CHECK: %[[N:.+]] = dim %[[ARG3]], %[[C1]] // CHECK: %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N]]] +// CHECK: %[[SV2_2:.+]] = subview %[[ARG3]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[N]]] +// CHECK: %[[K_2:.+]] = dim %[[ARG0]], %[[C1]] // CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K]]] -// CHECK: linalg.copy(%[[SV3]], %[[SV1]]) +// CHECK-SAME: [%[[TILE_M]], %[[K_2]]] +// CHECK: %[[SV1_2:.+]] = subview %[[ARG1]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[K_2]]] +// CHECK: linalg.copy(%[[SV3]], %[[SV1_2]]) // CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer" -// CHECK: linalg.fill(%[[SV2]], %[[CST]]) +// CHECK: linalg.fill(%[[SV2_2]], %[[CST]]) // CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_producer" // CHECK-DAG: %[[N_2:.+]] = dim %[[ARG2]], %[[C1]] // CHECK: scf.parallel (%[[IV1:.+]]) = @@ -253,23 +263,26 @@ // CHECK: scf.parallel (%[[IV0:.+]]) = // CHECK-SAME: (%[[C0]]) to (%[[M]]) step (%[[C32]]) { // CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] -// CHECK: 
%[[K2:.+]] = dim %[[ARG2]], %[[C1]] +// CHECK: %[[K:.+]] = dim %[[ARG2]], %[[C1]] // CHECK: %[[SV1:.+]] = subview %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K2]]] +// CHECK-SAME: [%[[TILE_M]], %[[K]]] // CHECK: %[[M_2:.+]] = dim %[[ARG4]], %[[C0]] // CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]] // CHECK: %[[N:.+]] = dim %[[ARG4]], %[[C1]] // CHECK: %[[SV2:.+]] = subview %[[ARG4]][%[[IV0]], 0] // CHECK-SAME: [%[[TILE_M_2]], %[[N]]] -// CHECK: %[[K1:.+]] = dim %[[ARG0]], %[[C1]] +// CHECK: %[[N_2:.+]] = dim %[[ARG1]], %[[C1]] +// CHECK: %[[K_2:.+]] = dim %[[ARG0]], %[[C1]] // CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K1]]] -// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K1]], %[[K2]]] +// CHECK-SAME: [%[[TILE_M]], %[[K_2]]] +// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K_2]], %[[N_2]]] +// CHECK: %[[SV5:.+]] = subview %[[ARG2]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[N_2]]] // CHECK: linalg.matmul // CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer" // CHECK-SAME: ins(%[[SV3]], %[[SV4]] // CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV1]] : memref) +// CHECK-SAME: outs(%[[SV5]] : memref) // CHECK-DAG: %[[N_2:.+]] = dim %[[ARG3]], %[[C1]] // CHECK: scf.parallel (%[[IV1:.+]]) = // CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) { diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/fusion-sequence.mlir @@ -0,0 +1,207 @@ +// RUN: mlir-opt -pass-pipeline="test-linalg-tile-and-fuse{tile-sizes=16,32,64},canonicalize,cse" -split-input-file %s | FileCheck %s + +module { + func @three_op_fusion(%arg0: memref, %arg1: memref, + %arg2: memref, %arg3 : memref) { + %cst = constant 0.000000e+00 : f32 + %c0 = constant 0 : index + %c1 = constant 1 : index + %d0 = dim %arg0, %c0 : memref + %d1 = dim %arg1, %c1 : memref + %0 = alloc(%d0, %d1)
: memref + linalg.fill(%0, %cst) : memref, f32 + linalg.matmul ins(%arg0, %arg1 : memref, memref) + outs(%0 : memref) + linalg.generic + {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d1)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%0, %arg2 : memref, memref) + outs(%arg3 : memref) { + ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) : + %5 = addf %arg4, %arg5 : f32 + linalg.yield %5 : f32 + } + return + } +} + +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK: func @three_op_fusion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref +// CHECK: %[[TEMP:.+]] = alloc(%{{.*}}, %{{.*}}) : memref +// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} { +// CHECK-DAG: %[[SV_TEMP:.+]] = subview %[[TEMP]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[SV_ARG2:.+]] = subview %[[ARG2]][%[[IV1]]] +// CHECK-DAG: %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0] +// CHECK-DAG: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, %[[IV1]]] +// CHECK: linalg.fill(%[[SV_TEMP]], %{{.+}}) +// CHECK: scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %{{.+}} { +// CHECK: %[[SV_ARG0_K:.+]] = subview %[[SV_ARG0]][0, %[[IV2]]] +// CHECK: %[[SV_ARG1_K:.+]] = subview %[[SV_ARG1]][%[[IV2]], 0] +// CHECK: linalg.matmul +// CHECK-SAME: ins(%[[SV_ARG0_K]], %[[SV_ARG1_K]] +// CHECK-SAME: : memref, memref) +// CHECK-SAME: outs(%[[SV_TEMP]] : memref) +// CHECK: } +// CHECK: linalg.generic +// CHECK-SAME: ins(%[[SV_TEMP]], %[[SV_ARG2]] +// CHECK-SAME: : memref, memref) +// CHECK-SAME: outs(%[[SV_ARG3]] : memref) +// CHECK: scf.yield +// CHECK: } + +// ----- + +module { + func @sequence_of_matmul(%arg0: memref, %arg1: memref, + %arg2:
memref, %arg3: memref, + %arg4: memref) { + %cst = constant 0.000000e+00 : f32 + %c0 = constant 0 : index + %c1 = constant 1 : index + %m = dim %arg0, %c0 : memref + %n1 = dim %arg1, %c1 : memref + %n2 = dim %arg2, %c1 : memref + %n3 = dim %arg3, %c1 : memref + %0 = alloc(%m, %n1) : memref + %1 = alloc(%m, %n2) : memref + linalg.fill(%0, %cst) : memref, f32 + linalg.matmul ins(%arg0, %arg1 : memref, memref) + outs(%0 : memref) + linalg.fill(%1, %cst) : memref, f32 + linalg.matmul ins(%0, %arg2 : memref, memref) + outs(%1 : memref) + linalg.fill(%arg4, %cst) : memref, f32 + linalg.matmul ins(%1, %arg3 : memref, memref) + outs(%arg4 : memref) + return + } +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)> +// CHECK: func @sequence_of_matmul +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C16:.+]] = constant 16 : index +// CHECK-DAG: %[[C32:.+]] = constant 32 : index +// CHECK-DAG: %[[C64:.+]] = constant 64 : index +// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[N1:.+]] = dim %[[ARG1]], %[[C1]] +// CHECK-DAG: %[[N2:.+]] = dim %[[ARG2]], %[[C1]] +// CHECK: %[[ALLOC1:.+]] = alloc(%[[M]], %[[N1]]) +// CHECK: %[[ALLOC2:.+]] = alloc(%[[M]], %[[N2]]) +// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]]) +// CHECK-SAME: step (%[[C16]]) { +// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] +// CHECK: %[[SV_ALLOC2:.+]] = subview %[[ALLOC2]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[N2]]] +//
CHECK: %[[M_2:.+]] = dim %[[ARG4]], %[[C0]] +// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]] +// CHECK: %[[N3:.+]] = dim %[[ARG4]], %[[C1]] +// CHECK: %[[SV_ARG4:.+]] = subview %[[ARG4]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]] +// CHECK: %[[SV_ARG4_2:.+]] = subview %[[ARG4]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[N3]]] +// CHECK: %[[SV_ALLOC1:.+]] = subview %[[ALLOC1]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[N1]]] +// CHECK: %[[SV_ARG2:.+]] = subview %[[ARG2]][0, 0] [%[[N1]], %[[N2]]] +// CHECK: %[[N0:.+]] = dim %[[ARG0]], %[[C1]] +// CHECK: %[[SV_ARG0:.+]] = subview %[[ARG0]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[N0]]] +// CHECK: %[[SV_ARG1:.+]] = subview %[[ARG1]][0, 0] [%[[N0]], %[[N1]]] + +// CHECK: scf.parallel (%[[IV1:.+]]) = (%[[C0]]) to (%[[N1]]) +// CHECK-SAME: step (%[[C32]]) { +// CHECK: %[[TILE_N1:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N1]]] +// CHECK: %[[SV_SV_ALLOC1:.+]] = subview %[[SV_ALLOC1]][0, %[[IV1]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_N1]]] +// CHECK: linalg.fill(%[[SV_SV_ALLOC1]], %{{.+}}) +// CHECK: } + +// CHECK: scf.parallel (%[[IV1:.+]]) = (%[[C0]]) to (%[[N1]]) +// CHECK-SAME: step (%[[C32]]) { +// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[N0]] step %[[C64]] { +// CHECK: %[[TILE_N0:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[N0]]] +// CHECK: %[[SV_SV_ARG0:.+]] = subview %[[SV_ARG0]][0, %[[IV2]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_N0]]] +// CHECK: %[[TILE_N1:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N1]]] +// CHECK: %[[SV_SV_ARG1:.+]] = subview %[[SV_ARG1]][%[[IV2]], %[[IV1]]] +// CHECK-SAME: [%[[TILE_N0]], %[[TILE_N1]]] +// CHECK: %[[SV_SV_ALLOC1:.+]] = subview %[[SV_ALLOC1]][0, %[[IV1]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_N1]]] +// CHECK: linalg.matmul ins(%[[SV_SV_ARG0]], %[[SV_SV_ARG1]] +// CHECK-SAME: : memref, memref) +// CHECK-SAME: outs(%[[SV_SV_ALLOC1]] : memref) +// CHECK: } +// CHECK: scf.yield +// CHECK: } + +// CHECK: scf.parallel (%[[IV1:.+]]) = (%[[C0]]) to
(%[[N2]]) +// CHECK-SAME: step (%[[C32]]) { +// CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N2]]] +// CHECK: %[[SV_SV_ALLOC2:.+]] = subview %[[SV_ALLOC2]][0, %[[IV1]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_N2]]] +// CHECK: linalg.fill(%[[SV_SV_ALLOC2]], %{{.+}}) +// CHECK: } +// CHECK: scf.parallel (%[[IV1:.+]]) = (%[[C0]]) to (%[[N2]]) +// CHECK-SAME: step (%[[C32]]) { +// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[N1]] step %[[C64]] { +// CHECK: %[[TILE_N1:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[N1]]] +// CHECK: %[[SV_SV_ALLOC1:.+]] = subview %[[SV_ALLOC1]][0, %[[IV2]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_N1]]] +// CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N2]]] +// CHECK: %[[SV_SV_ARG2:.+]] = subview %[[SV_ARG2]][%[[IV2]], %[[IV1]]] +// CHECK-SAME: [%[[TILE_N1]], %[[TILE_N2]]] +// CHECK: %[[SV_SV_ALLOC2:.+]] = subview %[[SV_ALLOC2]][0, %[[IV1]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_N2]]] +// CHECK: linalg.matmul ins(%[[SV_SV_ALLOC1]], %[[SV_SV_ARG2]] +// CHECK-SAME: : memref, memref) +// CHECK-SAME: outs(%[[SV_SV_ALLOC2]] : memref) +// CHECK: } +// CHECK: scf.yield +// CHECK: } + +// CHECK: scf.parallel (%[[IV1:.+]]) = (%[[C0]]) to (%[[N3]]) +// CHECK-SAME: step (%[[C32]]) { +// CHECK: %[[TILE_N3:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N3]]] +// CHECK: %[[SV_SV_ARG4:.+]] = subview %[[SV_ARG4_2]][0, %[[IV1]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_N3]]] +// CHECK: linalg.fill(%[[SV_SV_ARG4]], %{{.+}}) +// CHECK: } +// CHECK: %[[N3_2:.+]] = dim %[[ARG3]], %[[C1]] +// CHECK: scf.parallel (%[[IV1:.+]]) = (%[[C0]]) to (%[[N3_2]]) +// CHECK-SAME: step (%[[C32]]) { +// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[N2]] step %[[C64]] { +// CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[N2]]] +// CHECK: %[[SV_SV_ALLOC2:.+]] = subview %[[SV_ALLOC2]][0, %[[IV2]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_N2]]] +// CHECK: %[[N2_2:.+]] = dim %[[ARG3]], %[[C0]] +// CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[N2_2]]] +// 
CHECK: %[[TILE_N3:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N3_2]]] +// CHECK: %[[SV_ARG3:.+]] = subview %[[ARG3]][%[[IV2]], %[[IV1]]] +// CHECK-SAME: [%[[TILE_N2]], %[[TILE_N3]]] +// CHECK: %[[TILE_N3_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N3]]] +// CHECK: %[[SV_SV_ARG4:.+]] = subview %[[SV_ARG4]][0, %[[IV1]]] +// CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N3_2]]] +// CHECK: linalg.matmul ins(%[[SV_SV_ALLOC2]], %[[SV_ARG3]] +// CHECK-SAME: : memref, memref) +// CHECK-SAME: outs(%[[SV_SV_ARG4]] : memref) +// CHECK: } +// CHECK: scf.yield +// CHECK: } + diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp --- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -197,6 +197,46 @@ } } }; + +/// Pass to test tile and fuse of sequence of operations. Intended only for +/// testing. +struct TestLinalgTileAndFuseSequencePass + : public PassWrapper { + TestLinalgTileAndFuseSequencePass() = default; + TestLinalgTileAndFuseSequencePass( + const TestLinalgTileAndFuseSequencePass &pass) {} + + ListOption tileSizes{ + *this, "tile-sizes", llvm::cl::desc("Tile sizes to use for ops"), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + + void runOnFunction() override { + FuncOp funcOp = getOperation(); + auto &blocks = funcOp.getBody().getBlocks(); + if (!llvm::hasSingleElement(blocks)) { + return; + } + SmallVector linalgOps = + llvm::to_vector<2>(blocks.front().getOps()); + Aliases aliases; + LinalgDependenceGraph dependenceGraph(aliases, linalgOps); + SmallVector tilingOptions( + linalgOps.size(), + LinalgTilingOptions().setTileSizes(tileSizes).setLoopType( + LinalgTilingLoopType::ParallelLoops)); + OpBuilder builder(funcOp.getContext()); + Optional tileAndFuseOps = tileAndFuseLinalgOps( + builder, linalgOps, dependenceGraph,
tilingOptions); + if (!tileAndFuseOps) + return signalPassFailure(); + for (auto op : linalgOps) + op.erase(); + } +}; } // namespace namespace mlir { @@ -211,5 +251,12 @@ "test-linalg-greedy-fusion", "Test Linalg fusion by applying a greedy test transformation."); } +void registerTestLinalgTileAndFuseSequencePass() { + PassRegistration + testTileAndFuseSequencePass( + "test-linalg-tile-and-fuse", + "Test Linalg tiling and fusion of a sequence of Linalg operations."); +} + } // namespace test } // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -74,6 +74,7 @@ void registerTestLinalgFusionTransforms(); void registerTestLinalgGreedyFusion(); void registerTestLinalgHoisting(); +void registerTestLinalgTileAndFuseSequencePass(); void registerTestLinalgTransforms(); void registerTestLivenessPass(); void registerTestLoopFusion(); @@ -138,6 +139,7 @@ test::registerTestLinalgFusionTransforms(); test::registerTestLinalgGreedyFusion(); test::registerTestLinalgHoisting(); + test::registerTestLinalgTileAndFuseSequencePass(); test::registerTestLinalgTransforms(); test::registerTestLivenessPass(); test::registerTestLoopFusion();