diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td @@ -376,6 +376,24 @@ })); }] >, + InterfaceMethod< + /*desc=*/[{ + Return the index of `value` in inputs then output buffers, or None. + }], + /*retTy=*/"Optional", + /*methodName=*/"getIndexOfInputAndOutputBuffer", + /*args=*/(ins "Value":$value), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + Optional inputIndex = getIndexOfInput(value); + if (inputIndex.hasValue()) return inputIndex.getValue(); + Optional outputIndex = getIndexOfOutputBuffer(value); + if (outputIndex.hasValue()) { + return $_op.getNumInputs() + outputIndex.getValue(); + } + return llvm::None; + }] + >, //===------------------------------------------------------------------===// // Other interface methods. diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -18,6 +18,7 @@ namespace mlir { namespace linalg { +struct LinalgFusionOptions; struct LinalgTilingOptions; //===----------------------------------------------------------------------===// @@ -30,6 +31,13 @@ SmallVector loops; }; +struct TiledAndFusedLinalgOps { // Result of tileAndFuseLinalgOps. + LinalgOp op; // Tiled root operation. + SmallVector fusedProducers; // One fused producer per fused operand. + SmallVector originalProducers; // Unfused producers, same order. + SmallVector loops; // Tile+fuse loops, then remaining tiled loops. +}; + /// Populates patterns for vectorization of all ConvN-D ops.
void populateConvVectorizationPatterns( MLIRContext *context, SmallVectorImpl &patterns); @@ -52,6 +60,12 @@ Optional tileLinalgOp(OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options); +Optional +tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op, + const LinalgDependenceGraph &dependenceGraph, + const LinalgTilingOptions &tilingOptions, + const LinalgFusionOptions &fusionOptions); + /// Interchanges the `iterator_types` and `iterator_maps` dimensions of `op`. /// This is an in-place transformation controlled by `interchangeVector`. /// An empty vector is interpreted as the identity permutation and the @@ -322,6 +336,59 @@ } }; +struct LinalgFusionOptions { + /// Optional list of operand indices to use for fusion. When unspecified, + /// only one fusion is done, i.e., the pattern returns after the first fusion. + Optional> indicesToFuse = None; + LinalgFusionOptions &setIndicesToFuse(ArrayRef operands) { + indicesToFuse = DenseSet(); + indicesToFuse->insert(operands.begin(), operands.end()); + return *this; + } +}; + +struct LinalgBaseTileAndFusePattern : public RewritePattern { + LinalgBaseTileAndFusePattern(StringRef opName, MLIRContext *context, + const LinalgDependenceGraph &dependenceGraph, + LinalgTilingOptions tilingOptions, + LinalgFusionOptions fusionOptions, + LinalgMarker marker = LinalgMarker(), + LinalgMarker originalOpMarker = LinalgMarker(), + PatternBenefit benefit = 1); + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + +private: + /// Dependence graph needed for fusion. + const LinalgDependenceGraph &dependenceGraph; + /// Options to control tiling. + LinalgTilingOptions tilingOptions; + /// Options to control fusion. + LinalgFusionOptions fusionOptions; + /// Marker used to filter matched ops and to tag the resulting ops. + LinalgMarker marker; + /// The dependenceGraph is not modifiable, i.e.
if the Linalg operations used + /// to build the dependence graph change then the dependenceGraph needs to be + /// recomputed. To avoid invalidating the dependenceGraph as the + /// transformation happens, the original producer can be tagged with a marker + /// that can be later used to delete the original operations. + LinalgMarker originalOpMarker; +}; + +template +struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern { + LinalgTileAndFusePattern(MLIRContext *context, + const LinalgDependenceGraph &dependenceGraph, + LinalgTilingOptions tilingOptions, + LinalgFusionOptions fusionOptions, + LinalgMarker marker = LinalgMarker(), + LinalgMarker originalOpMarker = LinalgMarker(), + PatternBenefit benefit = 1) + : LinalgBaseTileAndFusePattern( + OpTy::getOperationName(), context, dependenceGraph, tilingOptions, + fusionOptions, marker, originalOpMarker, benefit) {} +}; + /// /// Linalg interchange patterns. /// diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -10,6 +10,7 @@ #define MLIR_DIALECT_LINALG_UTILS_H_ #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/EDSC/Builders.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/SCF/SCF.h" @@ -78,6 +79,20 @@ bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, Value consumedView, LinalgOp producer); +/// Returns the first RAW (then WAW) dependence into operand `consumerIdx` of +/// `consumer` whose producer can be fused with it, or None if there is none. +Optional +findFusableProducer(LinalgOp consumer, unsigned consumerIdx, + const LinalgDependenceGraph &dependenceGraph); + +/// Fuses the `producer` and `consumer` where the dependence between the two ops +/// is at operand `producerIdx` of `producer` (inputs first, then outputs) and +/// operand `consumerIdx` of the `consumer`.
Does not check the legality of the +/// generated fused op. +LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx, + LinalgOp consumer, unsigned consumerIdx, + OperationFolder *folder = nullptr); + /// Fuses producer into consumer if the producer is structurally feasible and /// the fusion would not violate dependencies. /// When non-null, the optional pointer `folder` is used to call into the diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -154,9 +154,9 @@ llvm_unreachable("Expect to be able to extract a view defining loop range"); } -static LinalgOp fuse(Value producedView, LinalgOp producer, LinalgOp consumer, - unsigned consumerIdx, unsigned producerIdx, - OperationFolder *folder) { +LinalgOp mlir::linalg::fuse(OpBuilder &b, LinalgOp producer, + unsigned producerIdx, LinalgOp consumer, + unsigned consumerIdx, OperationFolder *folder) { assert(producer.hasBufferSemantics() && "expected linalg op with buffer semantics"); assert(consumer.hasBufferSemantics() && @@ -174,9 +174,7 @@ // we can always identify a data dimension with a (at least one) loop // dimension. AffineMap producerMap = - producer.indexing_maps()[producer.getNumInputs() + producerIdx] - .cast() - .getValue(); + producer.indexing_maps()[producerIdx].cast().getValue(); LLVM_DEBUG(dbgs() << "Producer Idx: " << producerIdx << ", producer map: " << producerMap << "\n"); @@ -185,10 +183,9 @@ unsigned nWin = producer.getNumWindowLoops(); SmallVector loopRanges(nPar + nRed + nWin); - OpBuilder b(consumer.getOperation()); - auto loc = consumer.getLoc(); // Iterate over dimensions identified by the producer map for `producerIdx`. // This defines a subset of the loop ranges that we need to complete later.
+ auto loc = consumer.getLoc(); for (auto en : llvm::enumerate(producerMap.getResults())) { unsigned posInProducerLoop = en.value().cast().getPosition(); loopRanges[posInProducerLoop] = @@ -319,73 +316,80 @@ return true; } -static Optional -fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, - const LinalgDependenceGraph &graph, OperationFolder *folder, - LinalgDependenceGraph::DependenceType depType) { - assert(consumer.hasBufferSemantics() && - "expected linalg op with buffer semantics"); - LLVM_DEBUG(dbgs() << "\nStart examining consumer: " - << *consumer.getOperation()); - for (auto dependence : graph.getDependencesInto(consumer, depType)) { - LLVM_DEBUG(dbgs() << "\n***Consider producer:\t" - << *dependence.dependentOpView.op << "\n"); - auto producer = cast(dependence.dependentOpView.op); - - // Check that the dependence is indeed on the input `consumerIdx` view. - auto consumedView = dependence.indexingView; - if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView)) - continue; - - // Consumer consumes this view, `isStructurallyFusableProducer` also checks - // whether it is a strict subview of the producer view. - auto producedView = dependence.dependentOpView.view; - auto producerIdx = producer.getIndexOfOutputBuffer(producedView).getValue(); - // `consumerIdx` and `producerIdx` exist by construction. - LLVM_DEBUG(dbgs() << "\n" - << LinalgDependenceGraph::getDependenceTypeStr(depType) - << "producer: " << *producer.getOperation() << " view: " - << producedView << " output index: " << producerIdx); - - // Must be a subview or a slice to guarantee there are loops we can fuse - // into.
- auto subView = consumedView.getDefiningOp(); - auto slice = consumedView.getDefiningOp(); - if (!subView && !slice) { - LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); - continue; - } +Optional +mlir::linalg::findFusableProducer( + LinalgOp consumer, unsigned consumerIdx, + const LinalgDependenceGraph &dependenceGraph) { + for (auto depType : { + LinalgDependenceGraph::DependenceType::RAW, + LinalgDependenceGraph::DependenceType::WAW, + }) { + for (auto dependence : + dependenceGraph.getDependencesInto(consumer, depType)) { + auto producer = cast(dependence.dependentOpView.op); - // Simple fusability checks. - if (!isFusableInto(graph, consumer, consumedView, producer)) - continue; + // Check that the dependence is indeed on the input `consumerIdx` view. + auto consumedView = dependence.indexingView; + if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView)) + continue; - // Fuse `producer` just before `consumer`. - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(consumer.getOperation()); - ScopedContext scope(b, consumer.getLoc()); - LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); - auto fusedProducer = fuse(producedView, producer, consumer, consumerIdx, - producerIdx, folder); + // Consumer consumes this view, `isStructurallyFusableProducer` also + // checks whether it is a strict subview of the producer view. + auto producedView = dependence.dependentOpView.view; + auto producerIdx = + producer.getIndexOfOutputBuffer(producedView).getValue(); + // `consumerIdx` and `producerIdx` exist by construction. + LLVM_DEBUG(dbgs() << "\n" + << LinalgDependenceGraph::getDependenceTypeStr(depType) + << "producer: " << *producer.getOperation() << " view: " + << producedView << " output index: " << producerIdx); + + // Simple fusability checks.
+ if (!isFusableInto(dependenceGraph, consumer, consumedView, producer)) + continue; - return FusionInfo{producer, fusedProducer}; + return dependence; + } } - return llvm::None; + return {}; } // Only consider RAW and WAW atm. Optional mlir::linalg::fuseProducerOf( OpBuilder &b, LinalgOp consumer, unsigned consumerIdx, const LinalgDependenceGraph &graph, OperationFolder *folder) { - for (auto dep : { - LinalgDependenceGraph::DependenceType::RAW, - LinalgDependenceGraph::DependenceType::WAW, - }) { - if (auto res = - fuseProducerOfDep(b, consumer, consumerIdx, graph, folder, dep)) - return res; + Optional fusableDependence = + findFusableProducer(consumer, consumerIdx, graph); + if (!fusableDependence) + return {}; + + LinalgOp producerOp = cast(fusableDependence->dependentOpView.op); + Value producerView = fusableDependence->dependentOpView.view; + Value consumerView = fusableDependence->indexingView; + + // Must be a subview or a slice to guarantee there are loops we can fuse + // into. + auto subView = consumerView.getDefiningOp(); + auto slice = consumerView.getDefiningOp(); + if (!subView && !slice) { + LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)"); + return {}; } - return llvm::None; + + // Fuse `producer` just before `consumer`.
+ OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(consumer.getOperation()); + ScopedContext scope(b, consumer.getLoc()); + LLVM_DEBUG(dbgs() << "Fuse into consumer: " << *consumer << "\n"); + Optional producerIdxOpt = + producerOp.getIndexOfInputAndOutputBuffer(producerView); + assert(producerIdxOpt.hasValue() && + "producer view is not an operand of the producer"); + unsigned producerIdx = producerIdxOpt.getValue(); + + auto fusedProducer = + fuse(b, producerOp, producerIdx, consumer, consumerIdx, folder); + return FusionInfo{producerOp, fusedProducer}; } static void fuseLinalgOpsGreedily(FuncOp f) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -319,24 +319,9 @@ template Optional static tileLinalgOpImpl( - OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) { - OpBuilder::InsertionGuard g(b); - b.setInsertionPoint(op); - ScopedContext scope(b, op.getLoc()); - - assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); - // 1. Enforce the convention that "tiling by zero" skips tiling a particular - // dimension. This convention is significantly simpler to handle instead of - // adjusting affine maps to account for missing dimensions. + OpBuilder &b, LinalgOp op, ArrayRef tileSizes, + const LinalgTilingOptions &options) { auto nLoops = op.getNumLoops(); - SmallVector tileSizeVector = - options.tileSizeComputationFunction(b, op); - if (tileSizeVector.size() < nLoops) { - auto zero = std_constant_index(0); - tileSizeVector.append(nLoops - tileSizeVector.size(), zero); - } - - ArrayRef tileSizes = tileSizeVector; // Initial tile sizes may be too big, only take the first nLoops. tileSizes = tileSizes.take_front(nLoops); @@ -350,17 +335,7 @@ return llvm::None; } - // If interchangeVector is empty, use the identity. Build the permutation map - // otherwise.
- auto invPermutationMap = - AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext()); - if (!options.interchangeVector.empty()) - invPermutationMap = inversePermutation(AffineMap::getPermutationMap( - options.interchangeVector, b.getContext())); - if (!invPermutationMap) - return llvm::None; - - // 2. Build the tiled loop ranges. + // 1. Build the tiled loop ranges. auto allViewSizes = getViewSizes(b, op); // The flattened loopToOperandRangesMaps is expected to be an invertible // permutation map (asserted in the inverse calculation). @@ -374,17 +349,39 @@ SmallVector loopRanges; LoopIndexToRangeIndexMap loopIndexToRangeIndex; std::tie(loopRanges, loopIndexToRangeIndex) = makeTiledLoopRanges( - b, scope.getLocation(), viewSizesToLoopsMap, allViewSizes, tileSizes); - if (!options.interchangeVector.empty()) - applyPermutationToVector(loopRanges, options.interchangeVector); + b, op.getLoc(), viewSizesToLoopsMap, allViewSizes, tileSizes); + SmallVector iteratorTypes; + for (auto attr : + enumerate(op.iterator_types().cast().getValue())) { + if (loopIndexToRangeIndex.count(attr.index())) + iteratorTypes.push_back(attr.value()); + } + // If interchangeVector is empty, use the identity. Build the permutation map + // otherwise. + auto invPermutationMap = + AffineMap::getMultiDimIdentityMap(tileSizes.size(), b.getContext()); + if (!options.interchangeVector.empty()) { + // Based on the pruned iterations (due to zero tile size), recompute the + // interchange vector.
+ SmallVector interchangeVector; + interchangeVector.reserve(options.interchangeVector.size()); + for (auto pos : options.interchangeVector) { + auto it = loopIndexToRangeIndex.find(pos); + if (it == loopIndexToRangeIndex.end()) + continue; + interchangeVector.push_back(it->second); + } + invPermutationMap = inversePermutation( + AffineMap::getPermutationMap(interchangeVector, b.getContext())); + if (!invPermutationMap) + return llvm::None; + applyPermutationToVector(loopRanges, interchangeVector); + applyPermutationToVector(iteratorTypes, interchangeVector); + } - // 3. Create the tiled loops. + // 2. Create the tiled loops. LinalgOp res = op; SmallVector ivs; - SmallVector iteratorTypes = - llvm::to_vector<4>(op.iterator_types().cast().getValue()); - if (!options.interchangeVector.empty()) - applyPermutationToVector(iteratorTypes, options.interchangeVector); GenerateLoopNest::doit( loopRanges, iteratorTypes, [&](ValueRange localIvs) { @@ -409,10 +406,10 @@ }, options.distribution); - // 4. Transforms index arguments of `linalg.generic` w.r.t. to the tiling. + // 3. Transforms index arguments of `linalg.generic` w.r.t. to the tiling. transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex); - // 5. Gather the newly created loops and return them with the new op. + // 4. Gather the newly created loops and return them with the new op. SmallVector loops; loops.reserve(ivs.size()); for (auto iv : ivs) { @@ -428,17 +425,284 @@ return TiledLinalgOp{res, loops}; } +template +Optional static tileLinalgOpImpl( + OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) { + OpBuilder::InsertionGuard g(b); + b.setInsertionPoint(op); + ScopedContext scope(b, op.getLoc()); + + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); + // Enforce the convention that "tiling by zero" skips tiling a particular + // dimension.
This convention is significantly simpler to handle instead of + // adjusting affine maps to account for missing dimensions. + auto nLoops = op.getNumLoops(); + SmallVector tileSizeVector = + options.tileSizeComputationFunction(b, op); + if (tileSizeVector.size() < nLoops) { + auto zero = std_constant_index(0); + tileSizeVector.append(nLoops - tileSizeVector.size(), zero); + } + + return tileLinalgOpImpl(b, op, tileSizeVector, options); +} + Optional mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) { - if (options.loopType == LinalgTilingLoopType::Loops) + switch (options.loopType) { + case LinalgTilingLoopType::Loops: return tileLinalgOpImpl(b, op, options); - if (options.loopType == LinalgTilingLoopType::ParallelLoops) + case LinalgTilingLoopType::ParallelLoops: return tileLinalgOpImpl(b, op, options); - // TODO: Impl tiling to affine loops when it makes sense. + default:; // TODO: Impl tiling to affine loops when it makes sense. + } return llvm::None; } +/// Returns the positions of the loops in `op` that can be tiled based on the +/// operations that are to be fused with it. For example, in a +/// +/// linalg.matmul(%a, %b, %c) +/// +/// if the producer of %a needs to be fused with this op, only the `i` loop +/// of the matmul can be tiled while fusing. If the producers of %a and %b +/// are to be fused, then no loops can be tiled while fusing. +static DenseSet collectTileAndFuseLoops( + LinalgOp op, ArrayRef indices, + ArrayRef + fusableDependences) { + // 1. Only parallel loops can be used for tile + fuse. Find the number of + // common outer parallel loops between the op and its producers being fused.
+ auto getNumOuterParallelLoops = [](LinalgOp linalgOp) { + return linalgOp.iterator_types() + .getValue() + .take_while([](Attribute attr) -> bool { + return attr.cast().getValue() == + getParallelIteratorTypeName(); + }) + .size(); + }; + + size_t numOuterParallelLoops = getNumOuterParallelLoops(op); + for (auto dependence : fusableDependences) { + numOuterParallelLoops = + std::min(numOuterParallelLoops, getNumOuterParallelLoops(cast( + dependence.dependentOpView.op))); + } + + // Need to compute what tiled loops can be "fused". Given the precondition + // that the indexing map for the producer view is a permutation, a tile loop + // corresponding to the loop di, is fusable if the indexing map of the + // fused view in the consumer uses di. + // For example, + // + // linalg.copy (%arg0, %arg1) + // linalg.matmul %arg2, %arg1, %arg3 + // + // The indexingMap for the fused view in the consumer is + // + // affine_map<(d0, d1, d2) -> (d2, d1)> + // + // Since d0 is not among the result expressions of this map, it is not + // treated as a tile + fuse loop (but d1 is). + SmallVector, 1> commonTilableLoops; + for (auto dependence : fusableDependences) { + unsigned consumerIdx = + op.getIndexOfInputAndOutputBuffer(dependence.indexingView).getValue(); + AffineMap consumerAccess = op.getIndexingMap(consumerIdx); + // Previously checked that the consumerAccess map is a projected + // permutation, so all results are known to be AffineDimExprs. To remove + // this restriction, check which dimensions of the consumer loop appear in + // the `consumerAccess`. + DenseSet positions; + for (auto expr : consumerAccess.getResults()) { + positions.insert(expr.cast().getPosition()); + } + commonTilableLoops.emplace_back(std::move(positions)); + } + + // 2. Of the outer parallel loops, only those loops can be tiled + fused where
+ // the loop dim is used in the consumer's indexing map of every fused view. + DenseSet tilableParallelLoops; + for (auto index : llvm::seq(0, numOuterParallelLoops)) { + bool addAsTilable = true; + for (auto &tilableLoops : commonTilableLoops) { + if (!tilableLoops.count(index)) { + addAsTilable = false; + break; + } + } + if (addAsTilable) + tilableParallelLoops.insert(index); + } + return tilableParallelLoops; +} + +template +static Optional +tileAndFuseLinalgOpsImpl(PatternRewriter &rewriter, LinalgOp op, + const LinalgDependenceGraph &dependenceGraph, + const LinalgTilingOptions &tilingOptions, + const LinalgFusionOptions &fusionOptions) { + // Some of the tiling options might not be supportable with tile and fuse. + if (!tilingOptions.interchangeVector.empty()) { + op.emitError("unable to handle tile and fuse with interchange"); + return {}; + } + + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(op); + ScopedContext scope(rewriter, op.getLoc()); + + // Find all the producers. + SmallVector + fusableDependences; + SmallVector fusedIndices; + for (auto operand : llvm::enumerate(op.getInputsAndOutputBuffers())) { + if (fusionOptions.indicesToFuse && + !fusionOptions.indicesToFuse->count(operand.index())) + continue; + Optional + fusableDependence = + findFusableProducer(op, operand.index(), dependenceGraph); + if (!fusableDependence) { + if (fusionOptions.indicesToFuse) { + op.emitError("unable to fuse at index ") << operand.index(); + return {}; + } + continue; + } + // Make sure that the indexing map of the view used for fusion in the + // producer is a projected permutation.
+ LinalgOp producerOp = cast(fusableDependence->dependentOpView.op); + Value producerView = fusableDependence->dependentOpView.view; + unsigned producerIdx = + producerOp.getIndexOfInputAndOutputBuffer(producerView).getValue(); + AffineMap producerMap = producerOp.getIndexingMap(producerIdx); + if (!producerMap.isProjectedPermutation()) { + op.emitError("unhandled non permutation indexing map for fused view in " + "producer for operand at index ") + << operand.index(); + return {}; + } + Value consumerView = fusableDependence->indexingView; + unsigned consumerIdx = + op.getIndexOfInputAndOutputBuffer(consumerView).getValue(); + if (!op.getIndexingMap(consumerIdx).isProjectedPermutation()) { + op.emitError( + "unhandled case where indexing map for fused view in the consumer is " + "not a projected permutation while fusing at index ") + << operand.index(); + return {}; + } + fusableDependences.push_back(*fusableDependence); + fusedIndices.push_back(operand.index()); + if (!fusionOptions.indicesToFuse) + break; + } + // For now return if there is nothing to fuse. Not returning here would mean + // this method would just tile and return success. That might be acceptable + // behavior too. Evaluate based on use cases. + if (fusableDependences.empty()) { + op.emitError("nothing to fuse"); + return {}; + } + + assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); + // Enforce the convention that "tiling by zero" skips tiling a particular + // dimension. This convention is significantly simpler to handle instead of + // adjusting affine maps to account for missing dimensions.
+ auto nLoops = op.getNumLoops(); + SmallVector tileSizeVector = + tilingOptions.tileSizeComputationFunction(rewriter, op); + if (tileSizeVector.size() < nLoops) { + auto zero = std_constant_index(0); + tileSizeVector.append(nLoops - tileSizeVector.size(), zero); + } + + DenseSet tileFuseLoops = + collectTileAndFuseLoops(op, fusedIndices, fusableDependences); + + // If the number of fused loops doesn't match the number of loops mapped for + // distribution, abort the transformation. + if (tilingOptions.distribution && + tilingOptions.distribution->distributionMethod.size() != + tileFuseLoops.size()) { + op.emitError("unable to tile and fuse since not all loops marked for " + "distribution are valid for tile and fuse"); + return {}; + } + + SmallVector tileAndFuseSizes, tileSizes; + auto zero = std_constant_index(0); + for (auto tileSize : enumerate(tileSizeVector)) { + if (tileFuseLoops.count(tileSize.index())) { + tileAndFuseSizes.push_back(tileSize.value()); + tileSizes.push_back(zero); + } else { + tileSizes.push_back(tileSize.value()); + tileAndFuseSizes.push_back(zero); + } + } + + TiledAndFusedLinalgOps ret; + + // Tile for the loops that can be fused. + Optional firstTiledOp = + tileLinalgOpImpl(rewriter, op, tileAndFuseSizes, tilingOptions); + if (!firstTiledOp) + return {}; + ret.op = firstTiledOp->op; + ret.loops.append(firstTiledOp->loops.begin(), firstTiledOp->loops.end()); + if (fusableDependences.empty()) { + return ret; + } + + rewriter.setInsertionPoint(ret.op); + // Fuse the operands.
+ for (auto producer : enumerate(fusableDependences)) { + LinalgOp producerOp = cast(producer.value().dependentOpView.op); + // `fuse` expects `producerIdx` to count inputs first, then output buffers. + unsigned producerIdx = producerOp.getIndexOfInputAndOutputBuffer( + producer.value().dependentOpView.view).getValue(); + LinalgOp fusedOp = fuse(rewriter, producerOp, producerIdx, ret.op, + fusedIndices[producer.index()]); + ret.fusedProducers.push_back(fusedOp); + ret.originalProducers.push_back(producerOp); + } + + // Tile the remaining loops of the root operation. + LinalgTilingOptions secondTilingOptions = tilingOptions; + secondTilingOptions.distribution = llvm::None; + Optional secondTiledOp = tileLinalgOpImpl( + rewriter, ret.op, tileSizes, secondTilingOptions); + if (!secondTiledOp) + return {}; + ret.loops.append(secondTiledOp->loops.begin(), secondTiledOp->loops.end()); + rewriter.eraseOp(ret.op); + ret.op = secondTiledOp->op; + + return ret; +} + +Optional +mlir::linalg::tileAndFuseLinalgOps(PatternRewriter &rewriter, LinalgOp op, + const LinalgDependenceGraph &dependenceGraph, + const LinalgTilingOptions &tilingOptions, + const LinalgFusionOptions &fusionOptions) { + switch (tilingOptions.loopType) { + case LinalgTilingLoopType::Loops: + return tileAndFuseLinalgOpsImpl(rewriter, op, dependenceGraph, + tilingOptions, fusionOptions); + case LinalgTilingLoopType::ParallelLoops: + return tileAndFuseLinalgOpsImpl( + rewriter, op, dependenceGraph, tilingOptions, fusionOptions); + default:; + } + return {}; +} + namespace { /// Helper classes for type list expansion.
template
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -129,6 +129,54 @@
   return success();
 }
 
+/// Constructs a tile-and-fuse pattern rooted at ops named `opName`. Only a
+/// reference to `dependenceGraph` is kept, so the graph must outlive the
+/// pattern.
+mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
+    StringRef opName, MLIRContext *context,
+    const LinalgDependenceGraph &dependenceGraph,
+    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
+    LinalgMarker marker, LinalgMarker originalOpMarker, PatternBenefit benefit)
+    : RewritePattern(opName, {}, benefit, context),
+      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
+      fusionOptions(fusionOptions), marker(marker),
+      originalOpMarker(originalOpMarker) {}
+
+/// Tiles a LinalgOp with buffer semantics whose marker matches `marker` and
+/// fuses producers into the generated loops. The tiled op and the fused
+/// producers returned by the transformation are re-tagged through `marker`;
+/// the original root op and original producers stay in place and are re-tagged
+/// through `originalOpMarker`.
+LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+  if (!linalgOp)
+    return failure();
+  if (failed(marker.checkAndNotify(rewriter, linalgOp)))
+    return failure();
+  if (!linalgOp.hasBufferSemantics())
+    return failure();
+
+  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
+      rewriter, linalgOp, dependenceGraph, tilingOptions, fusionOptions);
+  if (!tiledAndFusedOps)
+    return failure();
+  // Ops created by the transformation can be tagged directly.
+  marker.replaceLinalgMarker(rewriter, tiledAndFusedOps->op.getOperation());
+  for (auto fusedOp : tiledAndFusedOps->fusedProducers)
+    marker.replaceLinalgMarker(rewriter, fusedOp.getOperation());
+  // Pre-existing ops are modified in place; such modifications must be
+  // reported to the rewriter.
+  for (auto origProducerOp : tiledAndFusedOps->originalProducers) {
+    Operation *producer = origProducerOp.getOperation();
+    rewriter.updateRootInPlace(producer, [&]() {
+      originalOpMarker.replaceLinalgMarker(rewriter, producer);
+    });
+  }
+  rewriter.updateRootInPlace(
+      op, [&]() { originalOpMarker.replaceLinalgMarker(rewriter, op); });
+  return success();
+}
+
 /// Linalg base interchange pattern.
mlir::linalg::LinalgBaseInterchangePattern::LinalgBaseInterchangePattern( StringRef opName, MLIRContext *context, diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir @@ -0,0 +1,285 @@ +// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file | FileCheck %s + +module { // Marker "basic_fusion": no fusion indices are set, and the checks expect the linalg.fill of the matmul output to be fused. + func @basic_fusion(%arg0: memref, %arg1: memref, + %arg2: memref) { + %cst = constant 0.000000e+00 : f32 + linalg.fill(%arg2, %cst) : memref, f32 + linalg.matmul %arg0, %arg1, %arg2 + {__internal_linalg_transform__ = "basic_fusion"} + : (memref, memref, memref) + return + } +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> +// CHECK: func @basic_fusion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C32:.+]] = constant 32 : index +// CHECK-DAG: %[[C64:.+]] = constant 64 : index +// CHECK-DAG: %[[C16:.+]] = constant 16 : index +// CHECK-DAG: %[[CST:.+]] = constant 0.0{{.*}} : f32 +// CHECK-DAG: linalg.fill(%[[ARG2]], %[[CST]]) +// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original" +// CHECK-DAG: %[[M:.+]] = dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[N:.+]] = dim %[[ARG1]], %[[C1]] +// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = +// CHECK-SAME: to (%[[M]], %[[N]]) +// CHECK-SAME: step (%[[C32]], %[[C64]]) { +// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] +// CHECK: %[[K:.+]] = dim %[[ARG0]], %[[C1]] +// CHECK: %[[SV1:.+]] = subview %[[ARG0]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[K]]] +// CHECK: %[[K_2:.+]] = dim
%[[ARG1]], %[[C0]] +// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N]]] +// CHECK: %[[SV2:.+]] = subview %[[ARG1]][0, %[[IV1]]] +// CHECK-SAME: %[[K_2]], %[[TILE_N]] +// CHECK: %[[M_2:.+]] = dim %[[ARG2]], %[[C0]] +// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]] +// CHECK: %[[N_2:.+]] = dim %[[ARG2]], %[[C1]] +// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N_2]]] +// CHECK: %[[SV3:.+]] = subview %[[ARG2]][%[[IV0]], %[[IV1]]] +// CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]] +// CHECK: linalg.fill(%[[SV3]], %[[CST]]) +// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion" +// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] { +// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]] +// CHECK: %[[SV4:.+]] = subview %[[SV1]][0, %[[IV2]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]] +// CHECK: %[[TILE_K_2:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K_2]]] +// CHECK: %[[SV5:.+]] = subview %[[SV2]][%[[IV2]], 0] +// CHECK-SAME: [%[[TILE_K_2]], %[[TILE_N]]] +// CHECK: linalg.matmul +// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion" +// CHECK-SAME: %[[SV4]], %[[SV5]], %[[SV3]] +// CHECK: } +// CHECK: } +// CHECK: linalg.matmul +// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original" + +// ----- + +module { // Marker "rhs_fusion": fuses the linalg.copy that produces the matmul RHS (%arg2); the linalg.fill is expected to stay unfused. + func @rhs_fusion(%arg0: memref, %arg1: memref, + %arg2: memref, %arg3: memref) { + %cst = constant 0.000000e+00 : f32 + linalg.copy(%arg1, %arg2) : memref, memref + linalg.fill(%arg3, %cst) : memref, f32 + linalg.matmul %arg0, %arg2, %arg3 + {__internal_linalg_transform__ = "rhs_fusion"} + : (memref, memref, memref) + return + } +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> +// CHECK: func @rhs_fusion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]:
memref +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C32:.+]] = constant 32 : index +// CHECK-DAG: %[[C64:.+]] = constant 64 : index +// CHECK-DAG: %[[C16:.+]] = constant 16 : index +// CHECK-DAG: %[[CST:.+]] = constant 0.0{{.*}} : f32 +// CHECK-DAG: linalg.copy(%[[ARG1]], %[[ARG2]]) +// CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_original" +// CHECK-DAG: %[[N:.+]] = dim %[[ARG2]], %[[C1]] +// CHECK: scf.parallel (%[[IV0:.+]]) = +// CHECK-SAME: (%[[C0]]) to (%[[N]]) step (%[[C64]]) { +// CHECK: %[[K:.+]] = dim %[[ARG2]], %[[C0]] +// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N]]] +// CHECK: %[[SV1:.+]] = subview %[[ARG2]][0, %[[IV0]]] +// CHECK-SAME: [%[[K]], %[[TILE_N]]] +// CHECK: %[[M:.+]] = dim %[[ARG3]], %[[C0]] +// CHECK: %[[N_2:.+]] = dim %[[ARG3]], %[[C1]] +// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[N_2]]] +// CHECK: %[[SV2:.+]] = subview %[[ARG3]][0, %[[IV0]]] +// CHECK-SAME: [%[[M]], %[[TILE_N_2]]] +// CHECK: %[[SV3:.+]] = subview %[[ARG1]][0, %[[IV0]]] +// CHECK-SAME: [%[[K]], %[[TILE_N]]] +// CHECK: linalg.copy(%[[SV3]], %[[SV1]]) +// CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion" +// CHECK-NOT: linalg.fill +// CHECK-DAG: %[[M_2:.+]] = dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[K_2:.+]] = dim %[[ARG0]], %[[C1]] +// CHECK: scf.parallel (%[[IV1:.+]]) = +// CHECK-SAME: (%[[C0]]) to (%[[M_2]]) step (%[[C32]]) { +// CHECK-NEXT: scf.for %[[IV2:.+]] = %[[C0]] to %[[K_2]] step %[[C16]] { +// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[M_2]]] +// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K_2]]] +// CHECK: %[[SV4:.+]] = subview %[[ARG0]][%[[IV1]], %[[IV2]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]] +// CHECK: %[[TILE_K_2:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]] +// CHECK: %[[SV5:.+]] = subview
%[[SV1]][%[[IV2]], 0] +// CHECK-SAME: [%[[TILE_K_2]], %[[TILE_N]]] +// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[M]]] +// CHECK: %[[SV6:.+]] = subview %[[SV2]][%[[IV1]], 0] +// CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]] +// CHECK: linalg.matmul +// CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion" +// CHECK-SAME: %[[SV4]], %[[SV5]], %[[SV6]] +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: linalg.matmul +// CHECK-SAME: __internal_linalg_transform__ = "after_rhs_fusion_original" + + +// ----- + +module { // Marker "two_operand_fusion": fuses both the linalg.copy producing the LHS (%arg1) and the linalg.fill of the output (%arg3). + func @two_operand_fusion(%arg0: memref, %arg1: memref, + %arg2: memref, %arg3: memref) { + %cst = constant 0.000000e+00 : f32 + linalg.copy(%arg0, %arg1) : memref, memref + linalg.fill(%arg3, %cst) : memref, f32 + linalg.matmul %arg1, %arg2, %arg3 + {__internal_linalg_transform__ = "two_operand_fusion"} + : (memref, memref, memref) + return + } +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)> +// CHECK: func @two_operand_fusion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C32:.+]] = constant 32 : index +// CHECK-DAG: %[[C64:.+]] = constant 64 : index +// CHECK-DAG: %[[C16:.+]] = constant 16 : index +// CHECK-DAG: %[[CST:.+]] = constant 0.0{{.*}} : f32 +// CHECK: linalg.copy(%[[ARG0]], %[[ARG1]]) +// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_original" +// CHECK: linalg.fill(%[[ARG3]], %[[CST]]) +// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_original" +// CHECK-DAG: %[[M:.+]] = dim %[[ARG1]], %[[C0]] +// CHECK: scf.parallel (%[[IV0:.+]]) = +// CHECK-SAME: (%[[C0]]) to
(%[[M]]) step (%[[C32]]) { +// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] +// CHECK: %[[K:.+]] = dim %[[ARG1]], %[[C1]] +// CHECK: %[[SV1:.+]] = subview %[[ARG1]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[K]]] +// CHECK: %[[M_2:.+]] = dim %[[ARG3]], %[[C0]] +// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]] +// CHECK: %[[N:.+]] = dim %[[ARG3]], %[[C1]] +// CHECK: %[[SV2:.+]] = subview %[[ARG3]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M_2]], %[[N]]] +// CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[K]]] +// CHECK: linalg.copy(%[[SV3]], %[[SV1]]) +// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion" +// CHECK: linalg.fill(%[[SV2]], %[[CST]]) +// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion" +// CHECK-DAG: %[[N_2:.+]] = dim %[[ARG2]], %[[C1]] +// CHECK: scf.parallel (%[[IV1:.+]]) = +// CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) { +// CHECK-NEXT: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] { +// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K]]] +// CHECK: %[[SV4:.+]] = subview %[[SV1]][0, %[[IV2]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]] +// CHECK: %[[K_2:.+]] = dim %[[ARG2]], %[[C0]] +// CHECK: %[[TILE_K_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K_2]]] +// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N_2]]] +// CHECK: %[[SV5:.+]] = subview %[[ARG2]][%[[IV2]], %[[IV1]]] +// CHECK-SAME: [%[[TILE_K_2]], %[[TILE_N]]] +// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N]]] +// CHECK: %[[SV6:.+]] = subview %[[SV2]][0, %[[IV1]]] +// CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]] +// CHECK: linalg.matmul +// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion" +// CHECK-SAME: %[[SV4]], %[[SV5]], %[[SV6]] +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: linalg.matmul +// CHECK-SAME: __internal_linalg_transform__ = "after_two_operand_fusion_original" + +// ----- + +module { // Marker "lhs_fusion": fuses the first linalg.matmul, which produces the LHS (%arg2) of the second one. +
func @matmul_fusion(%arg0: memref, %arg1: memref, + %arg2: memref, %arg3: memref, + %arg4: memref) { + linalg.matmul %arg0, %arg1, %arg2 + : (memref, memref, memref) + linalg.matmul %arg2, %arg3, %arg4 + {__internal_linalg_transform__ = "lhs_fusion"} + : (memref, memref, memref) + return + } +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (32, -d0 + s0)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (16, -d0 + s0)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (64, -d0 + s0)> +// CHECK: func @matmul_fusion +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C32:.+]] = constant 32 : index +// CHECK-DAG: %[[C64:.+]] = constant 64 : index +// CHECK-DAG: %[[C16:.+]] = constant 16 : index +// CHECK: linalg.matmul +// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original" +// CHECK-DAG: %[[M:.+]] = dim %[[ARG2]], %[[C0]] +// CHECK: scf.parallel (%[[IV0:.+]]) = +// CHECK-SAME: (%[[C0]]) to (%[[M]]) step (%[[C32]]) { +// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] +// CHECK: %[[K:.+]] = dim %[[ARG2]], %[[C1]] +// CHECK: %[[SV1:.+]] = subview %[[ARG2]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[K]]] +// CHECK: %[[M_2:.+]] = dim %[[ARG4]], %[[C0]] +// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M_2]]] +// CHECK: %[[N:.+]] = dim %[[ARG4]], %[[C1]] +// CHECK: %[[SV2:.+]] = subview %[[ARG4]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M_2]], %[[N]]] +// CHECK: %[[N_2:.+]] = dim %[[ARG1]], %[[C1]] +// CHECK: %[[SV3:.+]] = subview %[[ARG0]][%[[IV0]], 0] +// CHECK-SAME: [%[[TILE_M]], %[[K]]] +// CHECK: %[[SV4:.+]] = subview %[[ARG1]][0, 0] [%[[K]], %[[N_2]]] +// CHECK: %[[SV5:.+]] = subview %[[ARG2]][%[[IV0]], 0] +//
CHECK-SAME: [%[[TILE_M]], %[[N_2]]] +// CHECK: linalg.matmul +// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion" +// CHECK-SAME: %[[SV3]], %[[SV4]], %[[SV5]] +// CHECK-DAG: %[[N_2:.+]] = dim %[[ARG3]], %[[C1]] +// CHECK: scf.parallel (%[[IV1:.+]]) = +// CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) { +// CHECK-NEXT: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] { +// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K]]] +// CHECK: %[[SV4:.+]] = subview %[[SV1]][0, %[[IV2]]] +// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]] +// CHECK: %[[K_2:.+]] = dim %[[ARG3]], %[[C0]] +// CHECK: %[[TILE_K_2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K_2]]] +// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N_2]]] +// CHECK: %[[SV5:.+]] = subview %[[ARG3]][%[[IV2]], %[[IV1]]] +// CHECK-SAME: [%[[TILE_K_2]], %[[TILE_N]]] +// CHECK: %[[TILE_N_2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N]]] +// CHECK: %[[SV6:.+]] = subview %[[SV2]][0, %[[IV1]]] +// CHECK-SAME: [%[[TILE_M_2]], %[[TILE_N_2]]] +// CHECK: linalg.matmul +// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion" +// CHECK-SAME: %[[SV4]], %[[SV5]], %[[SV6]] +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: linalg.matmul +// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original" diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ TestGpuMemoryPromotion.cpp TestGpuParallelLoopMapping.cpp TestInlining.cpp + TestLinalgFusionTransforms.cpp TestLinalgHoisting.cpp TestLinalgTransforms.cpp TestLiveness.cpp diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp @@ -0,0 +1,115 @@ +//===- TestLinalgFusionTransforms.cpp - Test Linalg fusion patterns
-------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements logic for testing Linalg fusion patterns. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; +using namespace mlir::linalg; + +namespace { +struct TestLinalgFusionTransforms // Function pass that greedily applies the tile-and-fuse test patterns. + : public PassWrapper { + TestLinalgFusionTransforms() = default; + TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {} // The pass has no members, so there is nothing to copy. + + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + + void runOnFunction() override; +}; +} // namespace + +static void fillFusionPatterns(MLIRContext *context, // Registers one tile-and-fuse pattern per test marker, all with 32x64x16 tiles and parallel loops. + const LinalgDependenceGraph &dependenceGraph, + OwningRewritePatternList &patterns) { + patterns.insert>( + context, dependenceGraph, + LinalgTilingOptions() + .setTileSizes({32, 64, 16}) + .setLoopType(LinalgTilingLoopType::ParallelLoops), + LinalgFusionOptions(), // No indices set: only the first fusion is performed. + LinalgMarker(Identifier::get("basic_fusion", context), + Identifier::get("after_basic_fusion", context)), + LinalgMarker(ArrayRef(), + Identifier::get("after_basic_fusion_original", context))); + + patterns.insert>( + context, dependenceGraph, + LinalgTilingOptions() + .setTileSizes({32, 64, 16}) + .setLoopType(LinalgTilingLoopType::ParallelLoops), + LinalgFusionOptions().setIndicesToFuse({0}), // Producer of operand 0 (the LHS in the matmul tests). + LinalgMarker(Identifier::get("lhs_fusion", context), + Identifier::get("after_lhs_fusion", context)), + LinalgMarker(ArrayRef(), + Identifier::get("after_lhs_fusion_original", context))); + + patterns.insert>( + context, dependenceGraph, + LinalgTilingOptions() +
.setTileSizes({32, 64, 16}) + .setLoopType(LinalgTilingLoopType::ParallelLoops), + LinalgFusionOptions().setIndicesToFuse({1}), // Producer of operand 1 (the RHS in the matmul tests). + LinalgMarker(Identifier::get("rhs_fusion", context), + Identifier::get("after_rhs_fusion", context)), + LinalgMarker(ArrayRef(), + Identifier::get("after_rhs_fusion_original", context))); + + patterns.insert>( + context, dependenceGraph, + LinalgTilingOptions() + .setTileSizes({32, 64, 16}) + .setLoopType(LinalgTilingLoopType::ParallelLoops), + LinalgFusionOptions().setIndicesToFuse({0, 2}), // Producers of operands 0 and 2 (LHS and output). + LinalgMarker(Identifier::get("two_operand_fusion", context), + Identifier::get("after_two_operand_fusion", context)), + LinalgMarker( + ArrayRef(), + Identifier::get("after_two_operand_fusion_original", context))); + + patterns.insert>( + context, dependenceGraph, + LinalgTilingOptions() + .setTileSizes({32, 64, 16}) + .setLoopType(LinalgTilingLoopType::ParallelLoops), + LinalgFusionOptions().setIndicesToFuse({0, 1}), // NOTE(review): no test in fusion-pattern.mlir uses the two_input_fusion marker, so this pattern is unexercised. + LinalgMarker(Identifier::get("two_input_fusion", context), + Identifier::get("after_two_input_fusion", context)), + LinalgMarker( + ArrayRef(), + Identifier::get("after_two_input_fusion_original", context))); +} + +static void applyFusionPatterns(MLIRContext *context, FuncOp funcOp) { // NOTE(review): the dependence graph is computed once, before any rewrite, and is not refreshed while the patterns mutate funcOp -- confirm it stays valid. + OwningRewritePatternList fusionPatterns; + Aliases alias; + LinalgDependenceGraph dependenceGraph = + LinalgDependenceGraph::buildDependenceGraph(alias, funcOp); + fillFusionPatterns(context, dependenceGraph, fusionPatterns); + applyPatternsAndFoldGreedily(funcOp, fusionPatterns); // Return status is ignored. +} + +void TestLinalgFusionTransforms::runOnFunction() { + applyFusionPatterns(&getContext(), getFunction()); +} + +namespace mlir { +void registerTestLinalgFusionTransforms() { + PassRegistration testFusionTransformsPass( + "test-linalg-fusion-transform-patterns", + "Test Linalg fusion transformation patterns by applying them greedily."); +} +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp ---
a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -56,6 +56,7 @@ void registerTestGpuMemoryPromotionPass(); void registerTestGpuParallelLoopMappingPass(); void registerTestInterfaces(); +void registerTestLinalgFusionTransforms(); void registerTestLinalgHoisting(); void registerTestLinalgTransforms(); void registerTestLivenessPass(); @@ -110,6 +111,7 @@ registerTestExpandTanhPass(); registerTestGpuMemoryPromotionPass(); registerTestInterfaces(); + registerTestLinalgFusionTransforms(); registerTestLinalgHoisting(); registerTestLinalgTransforms(); registerTestLivenessPass();