diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -169,71 +169,6 @@ ArrayRef<int64_t> peeledLoops,
                                    LinalgTilingLoopType loopType);
 
-/// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This
-/// proceeds as follows:
-/// - Find outer parallel loops in these ops that can be fused.
-/// - Tile fusable outer parallel loops of the last operation in the sequence.
-/// - Fuse the remaining operations with the tiled operation.
-///
-/// For example, consider the sequence of matmuls below
-///
-///   linalg.matmul ins(%arg0, %arg1 : memref<256x32xf32>, memref<32x32xf32>)
-///                 outs(%arg2 : memref<256x32xf32>)
-///   linalg.matmul ins(%arg2, %arg3 : memref<256x32xf32>, memref<32x32xf32>)
-///                 outs(%arg4 : memref<256x32xf32>)
-///
-/// It is legal to fuse the RAW dependence (through %arg2) by only fusing the
-/// matmuls row-wise. For example, the fused computation for the above is shown
-/// below. The outer `scf.parallel` loop is the "fused" loop obtained by tiling
-/// along the rows of the matrix. The entire rows of the first matmul operation
-/// need to be computed before they can be used for the second matmul. The
-/// second matmul is further tiled (similar to normal tiling).
-///
-/// #map0 = affine_map<(d0, d1)[s0] -> (d0 * 32 + s0 + d1)>
-/// #map1 = affine_map<(d0, d1) -> (d0 * 32 + d1)>
-/// scf.parallel (%arg5) = (%c0) to (%c256) step (%c16) {
-///   %0 = subview %arg2[%arg5, 0] [16, 32] [1, 1]
-///     : memref<256x32xf32> to memref<16x32xf32, #map0>
-///   %1 = subview %arg4[%arg5, 0] [16, 32] [1, 1]
-///     : memref<256x32xf32> to memref<16x32xf32, #map0>
-///   %2 = subview %arg0[%arg5, 0] [16, 32] [1, 1]
-///     : memref<256x32xf32> to memref<16x32xf32, #map0>
-///   %3 = subview %arg1[0, 0] [32, 32] [1, 1]
-///     : memref<32x32xf32> to memref<32x32xf32, #map1>
-///   %4 = subview %arg3[0, 0] [32, 32] [1, 1]
-///     : memref<32x32xf32> to memref<32x32xf32, #map1>
-///   linalg.matmul
-///     ins(%2, %3 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
-///     outs(%0 : memref<16x32xf32, #map0>)
-///   linalg.matmul
-///     ins(%0, %4 : memref<16x32xf32, #map0>, memref<32x32xf32, #map1>)
-///     outs(%1 : memref<16x32xf32, #map0>)
-/// }
-///
-/// `tilingOptions` are used to tile the corresponding operation in `ops` (the
-/// size of the former should be the same as the size of the latter). Based on
-/// how tile+fuse is implemented, the fused loops are generated based on the
-/// last operation in the sequence. For example, the tile sizes for the fused
-/// loops are obtained from `tilingOptions.back()`. The following tiling
-/// options are handled differently in tile+fuse (compared to tile only):
-/// - Interchange of the tiling loops is not supported right now.
-/// - Only the fused loops are distributed.
-struct TiledAndFusedLinalgOps {
-  /// Operation obtained by tiling the last operation in the sequence of `ops`
-  /// passed to `tileAndFuseLinalgOps`.
-  LinalgOp op;
-  /// The dimensions of the loops that are fused.
-  std::set<unsigned> fusedLoopDims;
-  /// The generated fused operations (created within the fused loops).
-  SmallVector<LinalgOp> fusedProducers;
-  /// The fused loops generated.
-  SmallVector<Operation *> fusedLoops;
-};
-FailureOr<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
                     const LinalgDependenceGraph &dependenceGraph,
                     const LinalgTilingOptions &tilingOptions);
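For reference, a minimal sketch of driving this (now removed) entry point. It is illustrative only: `tileAndFuseExample`, the tile sizes, and the loop type are placeholders, and `ops` is assumed to hold the producers followed by the root consumer, all in one block with buffer semantics, per the restrictions documented in Fusion.cpp below.

// Sketch under the assumptions above; uses the removed API as declared.
static void tileAndFuseExample(OpBuilder &b, ArrayRef<LinalgOp> ops) {
  // Build the dependence graph that the API requires.
  Aliases aliases;
  LinalgDependenceGraph dependenceGraph(aliases, ops);
  // Tile sizes apply to the loops of the last op in `ops`.
  LinalgTilingOptions tilingOptions;
  tilingOptions.setTileSizes({16, 32, 64})
      .setLoopType(LinalgTilingLoopType::ParallelLoops);
  FailureOr<TiledAndFusedLinalgOps> result =
      tileAndFuseLinalgOps(b, ops, dependenceGraph, tilingOptions);
  if (failed(result))
    return;
  // result->op is the tiled root; result->fusedProducers holds the clones
  // created inside result->fusedLoops.
}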
-
 /// Interchange the `iterator_types` and `iterator_maps` dimensions and adapts
 /// the index accesses of `op`. This is an in-place transformation controlled
 /// by `interchangeVector`. An empty vector is interpreted as the identity
@@ -847,62 +782,6 @@ LinalgTransformationFilter filter;
 };
-struct LinalgFusionOptions {
-  /// List of operand indices to use for fusion.
-  llvm::SmallSet<unsigned, 1> indicesToFuse = {};
-  LinalgFusionOptions &setIndicesToFuse(ArrayRef<int64_t> operands) {
-    indicesToFuse.insert(operands.begin(), operands.end());
-    return *this;
-  }
-};
-
-struct LinalgBaseTileAndFusePattern : public RewritePattern {
-  LinalgBaseTileAndFusePattern(
-      StringRef opName, MLIRContext *context,
-      const LinalgDependenceGraph &dependenceGraph,
-      LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
-      LinalgTransformationFilter originalOpMarker =
-          LinalgTransformationFilter(),
-      PatternBenefit benefit = 1);
-  LogicalResult matchAndRewrite(Operation *op,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  /// Dependence graph needed for fusion.
-  const LinalgDependenceGraph &dependenceGraph;
-  /// Options to control tiling.
-  LinalgTilingOptions tilingOptions;
-  /// Options to control fusion.
-  LinalgFusionOptions fusionOptions;
-  /// Marker to control application of the pattern.
-  LinalgTransformationFilter filter;
-  /// Marker set on the fused op after tile and fuse.
-  LinalgTransformationFilter fusedOpMarker;
-  /// The dependenceGraph is not modifiable, i.e. if the Linalg operations
-  /// used to build the dependence graph change, then the dependenceGraph
-  /// currently needs to be recomputed. To avoid invalidating the
-  /// dependenceGraph as the transformation happens, the original producer
-  /// can be tagged with a filter that can later be used to delete the
-  /// original operations.
-  LinalgTransformationFilter originalOpMarker;
-};
-
-template <typename OpTy>
-struct LinalgTileAndFusePattern : public LinalgBaseTileAndFusePattern {
-  LinalgTileAndFusePattern(
-      MLIRContext *context, const LinalgDependenceGraph &dependenceGraph,
-      LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      LinalgTransformationFilter fusedOpMarker = LinalgTransformationFilter(),
-      LinalgTransformationFilter originalOpMarker =
-          LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : LinalgBaseTileAndFusePattern(
-            OpTy::getOperationName(), context, dependenceGraph, tilingOptions,
-            fusionOptions, f, fusedOpMarker, originalOpMarker, benefit) {}
-};
-
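A sketch of how this pattern was typically registered, mirroring the `__internal_linalg_transform__` markers used by the deleted fusion-pattern tests below; the tile sizes, the operand index, and the helper name are illustrative.

// Sketch only; assumes the usual Linalg/MLIR headers are in scope.
static void populateTileAndFusePattern(
    RewritePatternSet &patterns, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph) {
  patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
      context, dependenceGraph,
      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(
          LinalgTilingLoopType::ParallelLoops),
      // Fuse the producer of operand 2 (the output), e.g. a linalg.fill.
      LinalgFusionOptions().setIndicesToFuse({2}),
      LinalgTransformationFilter(
          StringAttr::get(context, "basic_fusion"),
          StringAttr::get(context, "after_basic_fusion")),
      LinalgTransformationFilter(
          ArrayRef<StringAttr>(),
          StringAttr::get(context, "after_basic_fusion_producer")),
      LinalgTransformationFilter(
          ArrayRef<StringAttr>(),
          StringAttr::get(context, "after_basic_fusion_original")));
}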
 ///
 /// Linalg tile and fuse tensor ops pattern.
 ///
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -460,436 +460,3 @@ consumerOpOperand.set(def);
   return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
 }
-
-/// Prune all dimensions that are of reduction iterator type from `map`.
-static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
-                                           AffineMap map) {
-  llvm::SmallBitVector projectedDims(iteratorTypes.size());
-  for (const auto &attr : llvm::enumerate(iteratorTypes)) {
-    if (!isParallelIterator(attr.value()))
-      projectedDims.set(attr.index());
-  }
-  return getProjectedMap(map, projectedDims);
-}
-
-/// Returns the mapping from iterations in the consumer that write to the same
-/// location as the iterations in the producer. To do so use
-/// - indexing map of the fused view in the consumer : consumerIndexMap
-/// - indexing map of the fused view in the producer : producerIndexMap
-///   consumerLoopToProducerLoop =
-///     inverse(producerIndexMap).compose(consumerIndexMap)
-static FailureOr<AffineMap> getConsumerLoopToProducerLoopMap(
-    LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
-  auto producer = dyn_cast<LinalgOp>(dependence.getDependentOp());
-  if (!producer)
-    return failure();
-
-  Optional<AffineMap> producerIndexingMap =
-      dependence.getDependentOpViewIndexingMap();
-  Optional<AffineMap> consumerIndexingMap =
-      dependence.getIndexingOpViewIndexingMap();
-  if (!producerIndexingMap || !consumerIndexingMap)
-    return failure();
-
-  AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
-      producer.iterator_types().getValue(), *producerIndexingMap);
-  if (!prunedProducerIndexingMap.isPermutation())
-    return failure();
-
-  if (consumerIndexingMap->getNumResults() !=
-      prunedProducerIndexingMap.getNumResults())
-    return failure();
-
-  LLVM_DEBUG({
-    llvm::dbgs() << "\t producerMap : ";
-    producerIndexingMap->print(llvm::dbgs());
-    llvm::dbgs() << " pruned : ";
-    prunedProducerIndexingMap.print(llvm::dbgs());
-    llvm::dbgs() << "\n";
-    llvm::dbgs() << "\t consumerMap : ";
-    consumerIndexingMap->print(llvm::dbgs());
-    llvm::dbgs() << "\n";
-  });
-
-  AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
-  if (!invProducerIndexMap)
-    return failure();
-
-  return invProducerIndexMap.compose(*consumerIndexingMap);
-}
-
-/// Given a projected permutation `map`, returns true if the map changes the
-/// order in which the fused loop dimensions appear.
-static bool doesTransposeAccess(AffineMap map,
-                                const std::set<unsigned> &fusableLoops) {
-  Optional<unsigned> lastFusableLoop;
-  for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
-         return expr.cast<AffineDimExpr>().getPosition();
-       })) {
-    if (!fusableLoops.count(pos))
-      continue;
-    if (!lastFusableLoop) {
-      lastFusableLoop = pos;
-      continue;
-    }
-    if (pos <= *lastFusableLoop)
-      return true;
-    lastFusableLoop = pos;
-  }
-  return false;
-}
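The inverse-then-compose step, and the transpose check that follows it, can be exercised in isolation on hand-built maps. A sketch for a transposed consumer access; `composeExample` is a made-up name and the maps are chosen for illustration.

// Standalone illustration of
//   consumerLoopToProducerLoop = inverse(producerIndexMap) o consumerIndexMap
void composeExample(MLIRContext *ctx) {
  AffineExpr i, j;
  bindDims(ctx, i, j);
  // Producer writes with (i, j) -> (i, j); with all-parallel iterators the
  // reduction pruning is a no-op and the map is already a permutation.
  AffineMap producerIndexMap = AffineMap::get(2, 0, {i, j}, ctx);
  // Consumer reads the fused view transposed: (i, j) -> (j, i).
  AffineMap consumerIndexMap = AffineMap::get(2, 0, {j, i}, ctx);
  AffineMap consumerLoopToProducerLoop =
      inversePermutation(producerIndexMap).compose(consumerIndexMap);
  // The composition is (i, j) -> (j, i): both dims map to producer loops,
  // but in reversed order, so doesTransposeAccess flags the access and the
  // fusable-loop collection below bails out.
  std::set<unsigned> fusableLoops = {0, 1};
  assert(doesTransposeAccess(consumerLoopToProducerLoop, fusableLoops));
}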
-
-/// Returns the positions of the loops in `op` that can be tiled based on the
-/// operations that are to be fused with it. For example, in a
-///
-///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
-///
-/// if the producer of %a needs to be fused with this op, only the `i` loop of
-/// the matmul can be tiled while fusing. If the producers of %a and %b are
-/// both to be fused, then no loops can be tiled while fusing. The conditions
-/// used are:
-/// 1. Only parallel loops can be used for tile + fuse. Find the number of
-///    common outer parallel loops between the op and its producers being
-///    fused.
-/// 2. Of the parallel loops, only some can be fused: only those loops whose
-///    fused iteration space touches a single tile of the fused operation.
-///    This is because the producer (which is writing the fused subview) has
-///    update semantics.
-///
-/// Since an inverse computation is needed, we need to consider the projection
-/// of the producerIndexMap w.r.t. the parallel loops. The actual fusable
-/// loops are the dimensions of the consumerLoopToProducerLoop map that
-/// correspond to parallel loops and appear in the result of the map.
-///
-/// Example 1:
-///   linalg.fill(%cst, %c)
-///   linalg.matmul ins(%a, %b) outs(%c)
-///
-///   Number of parallel loops : 2
-///   producerIndexMap = affine_map<(i, j) -> (i, j)>
-///   consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
-///   consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
-///   Fused dimensions : i, j
-///
-/// Example 2:
-///   linalg.matmul ins(%a, %b) outs(%c)
-///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
-///                   iterator_types = ["parallel", "parallel"]}
-///     ins(%c) ...
-///
-///   Number of parallel loops = 2:
-///   producerIndexMap (projected to parallel loops) =
-///     affine_map<(i, j) -> (i, j)>
-///   consumerLoopToProducerLoop = affine_map<(i, j) -> (j, i)>
-///   Fused dimensions : i, j
-///
-/// Example 3:
-///   memref.copy(%s, %b)
-///   linalg.matmul ins(%a, %b) outs(%c)
-///
-///   Number of parallel loops = 2
-///   producerIndexMap : affine_map<(i, j) -> (i, j)>
-///   consumerLoopToProducerLoops = affine_map<(i, j, k) -> (k, j)>
-///   submap with only parallel loops = affine_map<(i, j) -> (j)>
-///   Fused dimensions : j
-static std::set<unsigned>
-collectFusableLoops(ArrayRef<LinalgOp> ops,
-                    const FusableOpDependencesTy &fusableDependences) {
-  assert(!ops.empty());
-  auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
-    return linalgOp.iterator_types()
-        .getValue()
-        .take_while([](Attribute attr) -> bool {
-          return attr.cast<StringAttr>().getValue() ==
-                 getParallelIteratorTypeName();
-        })
-        .size();
-  };
-
-  size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
-  for (auto op : ops.drop_back()) {
-    numOuterParallelLoops =
-        std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
-  }
-
-  std::set<unsigned> fusableLoops;
-  auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
-  fusableLoops.insert(range.begin(), range.end());
-
-  for (auto op : reverse(ops)) {
-    for (auto dependence : fusableDependences.lookup(op)) {
-      LLVM_DEBUG({
-        llvm::dbgs() << "\t fusable :";
-        for (unsigned i : fusableLoops)
-          llvm::dbgs() << " " << i;
-        llvm::dbgs() << "\n";
-      });
-
-      Optional<AffineMap> consumerLoopToProducerLoop =
-          getConsumerLoopToProducerLoopMap(dependence);
-      if (!consumerLoopToProducerLoop) {
-        op.emitRemark("failed to get map from consumer loop to producer loop");
-        return {};
-      }
-      // TODO: This condition is only an implementation limitation. When
-      // fusing the operation, if the accesses in the producer/consumer are
-      // transposes of each other, the loop bounds for the tiled producer can
-      // be manipulated accordingly. This requires some additional bookkeeping
-      // in the implementation of tile+fuse that is deferred to later.
-      if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
-        op.emitRemark("unhandled fusion when fusion requires permutation");
-        return {};
-      }
-
-      std::set<unsigned> candidates;
-      for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
-        unsigned position = expr.cast<AffineDimExpr>().getPosition();
-        if (fusableLoops.count(position))
-          candidates.insert(position);
-      }
-      LLVM_DEBUG({
-        llvm::dbgs() << "\t candidates :";
-        for (unsigned i : candidates)
-          llvm::dbgs() << " " << i;
-        llvm::dbgs() << "\n";
-      });
-      if (candidates.empty())
-        return {};
-      std::swap(candidates, fusableLoops);
-    }
-  }
-
-  return fusableLoops;
-}
-
-/// Find all dependences that are fusable.
-FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
-    ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
-  FusableOpDependencesTy fusableDependences;
-  DenseMap<Operation *, SmallVector<AffineMap>> fusedProducerIndexingMap;
-  for (LinalgOp op : reverse(ops)) {
-    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
-      Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-          fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
-      if (!fusableDependence)
-        continue;
-      LinalgOp producerOp =
-          dyn_cast<LinalgOp>(fusableDependence->getDependentOp());
-      if (!producerOp)
-        continue;
-      // Do not fuse dependences that are to operations not in the same basic
-      // block. This avoids moving fused operations across loops that might
-      // themselves carry dependences, making the fusion illegal.
-      if (producerOp->getBlock() != op->getBlock())
-        continue;
-
-      // Make sure that the indexing map of the view used for fusion in the
-      // producer is a projected permutation.
-      Optional<AffineMap> producerMap =
-          fusableDependence->getDependentOpViewIndexingMap();
-      Optional<AffineMap> consumerMap =
-          fusableDependence->getIndexingOpViewIndexingMap();
-      assert(
-          consumerMap &&
-          "unable to find indexing map of operand/result of indexing OpView");
-      fusedProducerIndexingMap[producerOp.getOperation()].push_back(
-          *consumerMap);
-      if (!producerMap || !producerMap->isProjectedPermutation() ||
-          !consumerMap->isProjectedPermutation())
-        continue;
-
-      fusableDependences[producerOp.getOperation()].push_back(
-          *fusableDependence);
-    }
-  }
-  // TODO: Currently fusion would not be legal if the fusable dependence is to
-  // the same producer but with a different indexing map in the consumer. Fix
-  // this, but in the meantime disallow such a fusion.
-  for (auto useIndexingMapsList : fusedProducerIndexingMap) {
-    AffineMap map1 = useIndexingMapsList.second.front();
-    for (AffineMap map2 :
-         ArrayRef<AffineMap>(useIndexingMapsList.second).drop_front()) {
-      if (map1 != map2) {
-        fusableDependences.erase(useIndexingMapsList.first);
-        break;
-      }
-    }
-  }
-  return fusableDependences;
-}
-
-/// Tile the fused loops in the root operation by setting the tile sizes for
-/// all other loops to zero (those will be tiled later).
-static FailureOr<TiledLinalgOp>
-tileRootOperation(OpBuilder &b, LinalgOp op, ArrayRef<Value> tileSizeVector,
-                  const LinalgTilingOptions &options,
-                  const std::set<unsigned> &fusedLoops) {
-  SmallVector<Value> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
-  auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0);
-  for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
-    if (!fusedLoops.count(i))
-      tileSizes[i] = zero;
-  LinalgTilingOptions tileFusedLoopsOptions = options;
-  tileFusedLoopsOptions.setTileSizes(tileSizes);
-  // TODO: Propagate RewriterBase everywhere.
-  IRRewriter rewriter(b);
-  return tileLinalgOp(rewriter, op, tileFusedLoopsOptions);
-}
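The zero-tile-size convention above is what splits tiling into two phases; the same convention reappears as `unfusedLoopTileSizes` in Transforms.cpp further down. A standalone sketch of the split, with `splitTileSizes` a made-up helper using static sizes instead of `Value`s:

// A zero entry means "do not tile this loop in this phase".
static SmallVector<int64_t> splitTileSizes(ArrayRef<int64_t> tileSizes,
                                           const std::set<unsigned> &fusedLoops,
                                           bool forFusedPass) {
  SmallVector<int64_t> result;
  for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
    result.push_back((fusedLoops.count(i) != 0) == forFusedPass ? tileSizes[i]
                                                                : 0);
  return result;
}
// With tileSizes = {16, 32, 64} and fusedLoops = {0}:
//   fused phase   -> {16, 0, 0}   (tileRootOperation above)
//   unfused phase -> {0, 32, 64}  (second tiling of the tiled root)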
-
-/// Fuse the operations in `fusionCandidates` with `tiledOp`. The latter is
-/// expected to be a tiled operation such that it is valid to fuse all
-/// operations in `fusionCandidates`, i.e. to move them within the inter-tile
-/// loops of `tiledOp`.
-static SmallVector<LinalgOp>
-fuseOperations(OpBuilder &b, LinalgOp rootOp, TiledLinalgOp tiledLinalgOp,
-               ArrayRef<LinalgOp> fusionCandidates,
-               const FusableOpDependencesTy &fusableDependences,
-               const std::set<unsigned> &fusedLoops) {
-  LinalgOp tiledOp = tiledLinalgOp.op;
-  OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPoint(tiledOp);
-
-  DenseMap<unsigned, Range> fusedLoopsAndRanges;
-  for (unsigned loop : fusedLoops) {
-    ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
-    fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
-        b, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
-  }
-
-  SmallVector<LinalgOp> fusedOps(fusionCandidates.size());
-  DenseMap<Operation *, LinalgOp> origOpToFusedOp;
-  origOpToFusedOp[rootOp.getOperation()] = tiledOp;
-  for (const auto &candidate : enumerate(llvm::reverse(fusionCandidates))) {
-    LinalgOp origOp = candidate.value();
-    LinalgOp fusedOp = fuse(b, origOp, fusedLoopsAndRanges);
-    origOpToFusedOp[origOp.getOperation()] = fusedOp;
-    fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
-
-    // Prepare the builder for the next insertion point.
-    auto guard = llvm::make_scope_exit([&]() { b.setInsertionPoint(fusedOp); });
-    if (!origOp.hasTensorSemantics())
-      continue;
-
-    // If the producer and consumer operations are linalg operations on
-    // tensors, the dependence is due to the value produced (as a return
-    // tensor) by the producer and used in the consumer. The value returned by
-    // the fused op needs to be made the operand of the tiled/fused consumer
-    // operation. By construction the value returned by the producer is the
-    // value used by the consumer.
-    for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
-      if (dependence.dependenceType !=
-          LinalgDependenceGraph::DependenceType::RAW)
-        continue;
-
-      unsigned resultIndex = dependence.getDependentOpViewResultNum().value();
-      LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
-      if (!consumer)
-        continue;
-
-      Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
-      consumer.getOperation()->setOperand(
-          dependence.getIndexingOpViewOperandNum().value(), replacementValue);
-    }
-
-    // At this point, all Linalg uses of the tensors produced by `origOp` have
-    // been replaced. However, there may still be "output tensor"-like uses
-    // coming from WAW dependences.
-    // All these uses are iter_args of the outermost loop (TODO: add a check).
-    // Such iter_args uses serve 2 purposes:
-    //  1. give a shape to the output,
-    //  2. encode destructive updates that may be inplaceable by bufferization.
-    // To keep the second type of information while letting the unfused op die
-    // unused, we need to forward the producer output operand.
-    if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
-      for (auto &operand : forOp.getIterOpOperands()) {
-        if (auto opResult = operand.get().dyn_cast<OpResult>()) {
-          if (opResult.getOwner() == origOp) {
-            Value output =
-                origOp.getOutputOperand(opResult.getResultNumber())->get();
-            assert(output.getType().isa<RankedTensorType>());
-            operand.set(output);
-          }
-        }
-      }
-    }
-  }
-  return fusedOps;
-}
-
-static FailureOr<TiledAndFusedLinalgOps>
-tileAndFuseLinalgOpsImpl(OpBuilder &b, ArrayRef<LinalgOp> ops,
-                         const LinalgDependenceGraph &dependenceGraph,
-                         const LinalgTilingOptions &tilingOptions) {
-  if (ops.size() < 2)
-    return failure();
-  LinalgOp rootOp = ops.back();
-  if (!llvm::all_of(
-          ops,
-          [](LinalgOp linalgOp) { return linalgOp.hasBufferSemantics(); }) &&
-      !llvm::all_of(ops, [](LinalgOp linalgOp) {
-        return linalgOp.hasTensorSemantics();
-      })) {
-    rootOp.emitError(
-        "unable to fuse operations that have tensor semantics with operations "
-        "that have buffer semantics and vice versa.");
-    return failure();
-  }
-  // TODO: Support interchange with tile + fuse. This might actually help do
-  // better fusion.
-  if (!tilingOptions.interchangeVector.empty()) {
-    rootOp.emitRemark("unable to handle tile and fuse with interchange");
-    return failure();
-  }
-
-  OpBuilder::InsertionGuard guard(b);
-  b.setInsertionPoint(rootOp);
-
-  // Find all the producers.
-  LLVM_DEBUG(llvm::dbgs() << "findAllFusableDependences\n");
-  FusableOpDependencesTy fusableDependences =
-      findAllFusableDependences(ops, dependenceGraph);
-  if (fusableDependences.empty()) {
-    LLVM_DEBUG(llvm::dbgs() << "no fusable dependences found\n");
-    return failure();
-  }
-
-  TiledAndFusedLinalgOps ret;
-  // Find the loops that can be tiled and fused.
-  LLVM_DEBUG(llvm::dbgs() << "collectFusableLoops\n");
-  ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
-
-  // If there are no fusable dependences or there are no tile+fusable loops,
-  // just return.
-  if (ret.fusedLoopDims.empty()) {
-    LLVM_DEBUG(llvm::dbgs() << "no fusable loops found\n");
-    return failure();
-  }
-
-  // Tile the fused loops in the last operation in the list.
-  SmallVector<Value> tileSizeVector =
-      tilingOptions.tileSizeComputationFunction(b, rootOp);
-  FailureOr<TiledLinalgOp> tiledRootOp = tileRootOperation(
-      b, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
-  if (failed(tiledRootOp)) {
-    rootOp.emitRemark("failed to tile the fused loops");
-    return failure();
-  }
-  ret.op = tiledRootOp->op;
-  ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
-
-  // Fuse the other operations into the fused inter-tile loops produced above.
-  ret.fusedProducers = fuseOperations(b, rootOp, *tiledRootOp, ops.drop_back(),
-                                      fusableDependences, ret.fusedLoopDims);
-
-  return ret;
-}
-
-FailureOr<TiledAndFusedLinalgOps>
-mlir::linalg::tileAndFuseLinalgOps(OpBuilder &b, ArrayRef<LinalgOp> ops,
-                                   const LinalgDependenceGraph &dependenceGraph,
-                                   const LinalgTilingOptions &tilingOptions) {
-  switch (tilingOptions.loopType) {
-  case LinalgTilingLoopType::Loops:
-  case LinalgTilingLoopType::ParallelLoops:
-  case LinalgTilingLoopType::TiledLoops:
-    return tileAndFuseLinalgOpsImpl(b, ops, dependenceGraph, tilingOptions);
-  default:;
-  }
-  return failure();
-}
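Only the three generated-loop kinds take the fused path in the dispatch above. A sketch of options that reach `tileAndFuseLinalgOpsImpl`; the helper name and tile sizes are illustrative.

static LinalgTilingOptions makeTileAndFuseOptions() {
  LinalgTilingOptions options;
  options.setTileSizes({16, 32, 64});
  // Loops, ParallelLoops, and TiledLoops reach tileAndFuseLinalgOpsImpl;
  // any other kind (e.g. AffineLoops) falls through to failure().
  options.setLoopType(LinalgTilingLoopType::ParallelLoops);
  return options;
}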
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -356,111 +356,6 @@ return tiledOp.loops.front()->getResults();
 }
 
-static ValueRange
-getTiledAndFusedOpResult(TiledAndFusedLinalgOps tiledAndFusedOp) {
-  if (tiledAndFusedOp.fusedLoops.empty())
-    return tiledAndFusedOp.op.getOperation()->getResults();
-  return tiledAndFusedOp.fusedLoops.front()->getResults();
-}
-
-mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
-    StringRef opName, MLIRContext *context,
-    const LinalgDependenceGraph &dependenceGraph,
-    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
-    LinalgTransformationFilter f, LinalgTransformationFilter fusedOpMarker,
-    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
-    : RewritePattern(opName, benefit, context, {}),
-      dependenceGraph(dependenceGraph), tilingOptions(std::move(tilingOptions)),
-      fusionOptions(std::move(fusionOptions)), filter(std::move(f)),
-      fusedOpMarker(std::move(fusedOpMarker)),
-      originalOpMarker(std::move(originalOpMarker)) {}
-
-LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
-    Operation *op, PatternRewriter &rewriter) const {
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
-  // TODO: Remove the hasIndexSemantics check once index ops are supported.
-  if (!linalgOp || linalgOp.hasIndexSemantics())
-    return failure();
-  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
-    return failure();
-
-  DenseSet<Operation *> producers;
-  producers.insert(linalgOp);
-  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
-    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
-    // When looking at dependences into, the indexing op view is always an
-    // OpOperand. We could assert, but continue if this is not the case.
-    if (!operandNumber)
-      continue;
-    if (!fusionOptions.indicesToFuse.count(*operandNumber))
-      continue;
-    if (isa<LinalgOp>(dependence.getDependentOp()))
-      producers.insert(dependence.getDependentOp());
-  }
-
-  SmallVector<LinalgOp> fusionOps;
-  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
-       ++it) {
-    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
-    if (producerLinalgOp && producers.count(producerLinalgOp))
-      fusionOps.push_back(producerLinalgOp);
-  }
-  fusionOps.push_back(linalgOp);
-
-  SmallVector<Value> tileSizes =
-      tilingOptions.tileSizeComputationFunction(rewriter, op);
-  LinalgTilingOptions instanceTilingOptions = tilingOptions;
-  instanceTilingOptions.setTileSizes(tileSizes);
-  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
-      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
-  if (!tiledAndFusedOps)
-    return failure();
-
-  // Tile the unfused loops.
-  SmallVector<Value> unfusedLoopTileSizes;
-  Value zero = rewriter.create<arith::ConstantIndexOp>(op->getLoc(), 0);
-  for (const auto &tileSize : enumerate(tileSizes)) {
-    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
-      unfusedLoopTileSizes.push_back(zero);
-    else
-      unfusedLoopTileSizes.push_back(tileSize.value());
-  }
-  // Tile the loops only if there is a non-zero tile size.
-  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
-    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
-  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
-        if (auto cst = val.getDefiningOp<arith::ConstantIndexOp>())
-          return cst.value() != 0;
-        return true;
-      })) {
-    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
-    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
-    FailureOr<TiledLinalgOp> unfusedTiledOp =
-        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
-    if (failed(unfusedTiledOp))
-      return failure();
-    rewriter.replaceOp(tiledAndFusedOps->op,
-                       getTiledOpResult(unfusedTiledOp.value()));
-    tiledAndFusedOps->op = unfusedTiledOp->op;
-  }
-  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.value()));
-
-  filter.replaceLinalgTransformationFilter(rewriter,
-                                           tiledAndFusedOps->op.getOperation());
-  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
-    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
-                                                    fusedOp.getOperation());
-  }
-  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
-    originalOpMarker.replaceLinalgTransformationFilter(
-        rewriter, origProducerOp.getOperation());
-  }
-  rewriter.updateRootInPlace(op, [&]() {
-    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
-  });
-  return success();
-}
-
 /// Linalg tiling pattern.
mlir::linalg::LinalgTilingPattern::LinalgTilingPattern( MLIRContext *context, LinalgTilingOptions options, diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir +++ /dev/null @@ -1,307 +0,0 @@ -// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file | FileCheck %s - -module { - func.func @basic_fusion(%arg0: memref, %arg1: memref, - %arg2: memref) { - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%arg2 : memref) - linalg.matmul {__internal_linalg_transform__ = "basic_fusion"} - ins(%arg0, %arg1 : memref, memref) - outs(%arg2 : memref) - return - } -} - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 32)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 32)> -// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 64)> -// CHECK: func @basic_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.0{{.*}} : f32 -// CHECK-DAG: linalg.fill -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original" -// CHECK-SAME: ins(%[[CST]]{{.*}}outs(%[[ARG2]] -// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]] -// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = -// CHECK-SAME: to (%[[M]], %[[N]]) -// CHECK-SAME: step (%[[C32]], %[[C64]]) { -// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] -// CHECK: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]] -// CHECK: %[[SV1:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K]]] -// CHECK: %[[K_2:.+]] = memref.dim %[[ARG1]], %[[C0]] -// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP2]](%[[IV1]])[%[[N]]] -// CHECK: %[[SV2:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]] -// CHECK-SAME: %[[K_2]], %[[TILE_N]] -// CHECK: %[[SV3:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]] -// CHECK: %[[M_2:.+]] = memref.dim %[[ARG2]], %[[C0]] -// CHECK: %[[N_2:.+]] = memref.dim %[[ARG2]], %[[C1]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_2]], %[[M]]] -// CHECK: %[[TILE_N_3:.+]] = affine.min #[[MAP5]](%[[IV1]])[%[[N_2]], %[[N]]] -// CHECK: %[[SV3_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N_3]]] -// CHECK: linalg.fill -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer" -// CHECK-SAME: ins(%[[CST]]{{.*}}outs(%[[SV3_2]] -// CHECK: scf.for %[[IV2:.+]] = %[[C0]] to %[[K]] step %[[C16]] { -// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP3]](%[[IV2]])[%[[K]]] -// CHECK: %[[SV4:.+]] = memref.subview %[[SV1]][0, %[[IV2]]] -// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]] -// CHECK: %[[SV5:.+]] = memref.subview %[[SV2]][%[[IV2]], 0] -// CHECK-SAME: [%[[TILE_K]], 
%[[TILE_N]]] -// CHECK: linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion" -// CHECK-SAME: ins(%[[SV4]], %[[SV5]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV3]] : memref) -// CHECK: } -// CHECK: } -// CHECK: linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original" - -// ----- - -module { - func.func @matmul_fusion(%arg0: memref, %arg1: memref, - %arg2: memref, %arg3: memref, - %arg4: memref) { - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%arg2 : memref) - linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"} - ins(%arg2, %arg3 : memref, memref) - outs(%arg4 : memref) - return - } -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 32)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 32)> -// CHECK: func @matmul_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index -// CHECK: linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original" -// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG2]], %[[C0]] -// CHECK: scf.parallel (%[[IV0:.+]]) = -// CHECK-SAME: (%[[C0]]) to (%[[M]]) step (%[[C32]]) { -// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] -// CHECK: %[[K2:.+]] = memref.dim %[[ARG2]], %[[C1]] -// CHECK: %[[SV1:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[K2]]] -// CHECK: %[[N:.+]] = memref.dim %[[ARG4]], %[[C1]] -// CHECK: %[[SV2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N]]] -// CHECK: %[[M_3:.+]] = memref.dim %[[ARG0]], %[[C0]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP4]](%[[IV0]])[%[[M_3]], %[[M]]] -// CHECK: %[[K1:.+]] = memref.dim %[[ARG0]], %[[C1]] -// CHECK: %[[SV3:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_3]], %[[K1]]] -// CHECK: %[[SV1_2:.+]] = memref.subview %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_3]], %[[K2]]] -// CHECK: linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer" -// CHECK-SAME: ins(%[[SV3]], %[[ARG1]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV1_2]] : memref) -// CHECK: %[[N_2:.+]] = memref.dim %[[ARG3]], %[[C1]] -// CHECK: scf.parallel (%[[IV1:.+]]) = -// CHECK-SAME: (%[[C0]]) to (%[[N_2]]) step (%[[C64]]) { -// CHECK-NEXT: scf.for %[[IV2:.+]] = %[[C0]] to %[[K2]] step %[[C16]] { -// CHECK: %[[TILE_K:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[K2]]] -// CHECK: %[[SV6:.+]] = memref.subview %[[SV1]][0, %[[IV2]]] -// CHECK-SAME: [%[[TILE_M]], %[[TILE_K]]] -// CHECK: %[[TILE_N:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N_2]]] -// CHECK: %[[SV7:.+]] = memref.subview %[[ARG3]][%[[IV2]], %[[IV1]]] -// CHECK-SAME: [%[[TILE_K]], %[[TILE_N]]] -// CHECK: %[[SV8:.+]] = memref.subview %[[SV2]][0, %[[IV1]]] -// CHECK-SAME: [%[[TILE_M]], %[[TILE_N]]] -// CHECK: 
linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion" -// CHECK-SAME: ins(%[[SV6]], %[[SV7]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV8]] : memref) -// CHECK: } -// CHECK: } -// CHECK: } -// CHECK: linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_original" - -// ----- - -module { - func.func @matmul_plus_matmul(%arg0: memref, %arg1: memref, - %arg2: memref) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = memref.dim %arg2, %c0 : memref - %1 = memref.dim %arg2, %c1 : memref - %2 = memref.alloc(%0, %1) : memref - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%2 : memref) - linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"], - __internal_linalg_transform__ = "transpose_fusion"} - ins(%2, %2 : memref, memref) - outs(%arg2 : memref) { - ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : - %3 = arith.addf %arg3, %arg4 : f32 - linalg.yield %3 : f32 - } - return - } -} -// CHECK: func @matmul_plus_matmul -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref -// CHECK: %[[T2:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref -// CHECK: linalg.matmul -// CHECK-SAME: after_transpose_fusion_original -// CHECK: scf.parallel (%[[ARG3:[a-zA-Z0-9_]+]], %[[ARG4:.[a-zA-Z0-9_]+]]) -// CHECK: %[[T5:.+]] = memref.subview %[[T2]][%[[ARG3]], %[[ARG4]]] -// CHECK: %[[T6:.+]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]]] -// CHECK: %[[T8:.+]] = memref.subview %[[ARG0]][%[[ARG3]], 0] -// CHECK: %[[T9:.+]] = memref.subview %[[ARG1]][0, %[[ARG4]]] -// CHECK: %[[T10:.+]] = memref.subview %[[T2]][%[[ARG3]], %[[ARG4]]] -// CHECK: linalg.matmul -// CHECK-SAME: after_transpose_fusion_producer -// CHECK-SAME: ins(%[[T8]], %[[T9]] -// CHECK-SAME: outs(%[[T10]] -// CHECK-NOT: linalg.matmul -// CHECK: linalg.generic -// CHECK-SAME: ins(%[[T5]], %[[T5]] -// CHECK-SAME: outs(%[[T6]] -// CHECK-SAME: after_transpose_fusion - -// ----- - -module { - func.func @matmul_plus_transpose_matmul(%arg0: memref, - %arg1: memref, - %arg2: memref) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = memref.dim %arg2, %c0 : memref - %1 = memref.dim %arg2, %c1 : memref - %2 = memref.alloc(%0, %1) : memref - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%2 : memref) - linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d1, d0)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"], - __internal_linalg_transform__ = "transpose_fusion"} - ins(%2, %2 : memref, memref) - outs(%arg2 : memref) { - ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : - %3 = arith.addf %arg3, %arg4 : f32 - linalg.yield %3 : f32 - } - return - } -} -// CHECK-LABEL: func @matmul_plus_transpose_matmul -// CHECK-NOT: scf.parallel -// CHECK-NOT: scf.for -// CHECK: linalg.matmul -// CHECK-NOT: scf.parallel -// CHECK-NOT: scf.for -// CHECK: linalg.generic -// CHECK-NOT: scf.parallel -// CHECK-NOT: scf.for - -// ----- - -#map0 = affine_map<(d0)[s0] -> (32, -d0 + s0)> -#map1 = affine_map<(d0)[s0] -> (64, -d0 + s0)> -#map2 = affine_map<(d0)[s0] -> (16, -d0 + s0)> -#map3 = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -module { - func.func @basic_no_fusion(%arg0: memref, %arg1: memref, - %arg2: memref) { - %c0 = arith.constant 0 : index - %c1 = 
arith.constant 1 : index - %c2 = arith.constant 2 : index - %c32 = arith.constant 32 : index - %c64 = arith.constant 64 : index - %c16 = arith.constant 16 : index - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%arg2 : memref) - %0 = memref.dim %arg0, %c0 : memref - %1 = memref.dim %arg1, %c1 : memref - %2 = memref.dim %arg0, %c1 : memref - scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%0, %1) step (%c32, %c64) { - scf.for %arg5 = %c0 to %2 step %c16 { - %3 = affine.min #map0(%arg3)[%0] - %4 = affine.min #map1(%arg4)[%1] - %5 = affine.min #map2(%arg5)[%2] - %6 = memref.subview %arg0[%arg3, %arg5] [%3, %5] [1, 1] : memref to memref - %7 = memref.subview %arg1[%arg5, %arg4] [%5, %4] [1, 1] : memref to memref - %8 = memref.subview %arg2[%arg3, %arg4] [%3, %4] [1, 1] : memref to memref - linalg.matmul {__internal_linalg_transform__ = "basic_fusion"} - ins(%6, %7 : memref, memref) - outs(%8 : memref) - } - scf.yield - } - return - } -} -// CHECK-LABEL: func @basic_no_fusion -// CHECK-NOT: scf.parallel -// CHECK: linalg.fill -// CHECK: scf.parallel -// CHECK: scf.for -// CHECK-NOT: linalg.fill -// CHECK: linalg.matmul - -// ----- - -module { - func.func @basic_conv_fusion(%arg0: memref, %arg1: memref, - %arg2: memref) { - %cst = arith.constant 0.000000e+00 : f32 - linalg.fill ins(%cst : f32) outs(%arg2 : memref) - linalg.conv_2d {__internal_linalg_transform__ = "basic_fusion"} - ins(%arg1, %arg0 : memref, memref) outs(%arg2 : memref) - return - } -} -// CHECK: func @basic_conv_fusion -// CHECK: linalg.fill -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original" -// CHECK: scf.parallel (%{{.+}}, %{{.+}}, %{{.+}}) -// CHECK-SAME: { -// CHECK: linalg.fill -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_producer" -// CHECK: linalg.conv_2d -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion" -// CHECK: } -// CHECK: linalg.conv_2d -// CHECK-SAME: __internal_linalg_transform__ = "after_basic_fusion_original" diff --git a/mlir/test/Dialect/Linalg/fusion-sequence.mlir b/mlir/test/Dialect/Linalg/fusion-sequence.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Linalg/fusion-sequence.mlir +++ /dev/null @@ -1,252 +0,0 @@ -// RUN: mlir-opt -pass-pipeline="func.func(test-linalg-tile-and-fuse{tile-sizes=16,32,64}),resolve-shaped-type-result-dims,canonicalize,cse" -split-input-file %s | FileCheck %s - -module { - func.func @three_op_fusion(%arg0: memref, %arg1: memref, - %arg2: memref, %arg3 : memref) { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %d0 = memref.dim %arg0, %c0 : memref - %d1 = memref.dim %arg1, %c1 : memref - %0 = memref.alloc(%d0, %d1) : memref - linalg.fill ins(%cst : f32) outs(%0 : memref) - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%0 : memref) - linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%0, %arg2 : memref, memref) - outs(%arg3 : memref) { - ^bb0(%arg4 : f32, %arg5 : f32, %arg6 : f32) : - %5 = arith.addf %arg4, %arg5 : f32 - linalg.yield %5 : f32 - } - return - } -} - -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECK: func @three_op_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: 
memref -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref -// CHECK: %[[TEMP:.+]] = memref.alloc(%{{.*}}, %{{.*}}) : memref -// CHECK: scf.parallel (%[[IV0:.+]], %[[IV1:.+]]) = {{.*}} { -// CHECK: %[[SV_TEMP_1:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]] -// CHECK-DAG: %[[SV_ARG2:.+]] = memref.subview %[[ARG2]][%[[IV1]]] -// CHECK-DAG: %[[SV_ARG3:.+]] = memref.subview %[[ARG3]][%[[IV0]], %[[IV1]]] -// CHECK-DAG: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] -// CHECK-DAG: %[[SV_ARG1:.+]] = memref.subview %[[ARG1]][0, %[[IV1]]] -// CHECK: %[[SV_TEMP_2:.+]] = memref.subview %[[TEMP]][%[[IV0]], %[[IV1]]] -// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_TEMP_1]] -// CHECK: linalg.matmul -// CHECK-SAME: ins(%[[SV_ARG0]], %[[SV_ARG1]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_TEMP_2]] : memref) -// CHECK: linalg.generic -// CHECK-SAME: ins(%[[SV_TEMP_1]], %[[SV_ARG2]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_ARG3]] : memref) -// CHECK: scf.yield -// CHECK: } - -// ----- - -module { - func.func @sequence_of_matmul(%arg0: memref, %arg1: memref, - %arg2: memref, %arg3: memref, - %arg4: memref) { - %cst = arith.constant 0.000000e+00 : f32 - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %m = memref.dim %arg0, %c0 : memref - %n1 = memref.dim %arg1, %c1 : memref - %n2 = memref.dim %arg2, %c1 : memref - %n3 = memref.dim %arg3, %c1 : memref - %0 = memref.alloc(%m, %n1) : memref - %1 = memref.alloc(%m, %n2) : memref - linalg.fill ins(%cst : f32) outs(%0 : memref) - linalg.matmul ins(%arg0, %arg1 : memref, memref) - outs(%0 : memref) - linalg.fill ins(%cst : f32) outs(%1 : memref) - linalg.matmul ins(%0, %arg2 : memref, memref) - outs(%1 : memref) - linalg.fill ins(%cst : f32) outs(%arg4 : memref) - linalg.matmul ins(%1, %arg3 : memref, memref) - outs(%arg4 : memref) - return - } -} - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)> - - -// CHECK: func @sequence_of_matmul -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: memref -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: memref -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index -// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]] -// CHECK-DAG: %[[N1:.+]] = memref.dim %[[ARG1]], %[[C1]] -// CHECK-DAG: %[[N2:.+]] = memref.dim %[[ARG2]], %[[C1]] -// CHECK: %[[ALLOC1:.+]] = memref.alloc(%[[M]], %[[N1]]) -// CHECK: %[[ALLOC2:.+]] = memref.alloc(%[[M]], %[[N2]]) -// CHECK: scf.parallel (%[[IV0:.+]]) = (%[[C0]]) to (%[[M]]) -// CHECK-SAME: step (%[[C16]]) { -// CHECK: %[[TILE_M:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] -// CHECK: %[[SV_ALLOC3:.+]] = memref.subview %[[ALLOC2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N2]]] -// CHECK: %[[N3:.+]] = memref.dim %[[ARG4]], %[[C1]] -// CHECK: %[[SV_ARG4:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N3]]] -// CHECK: %[[M_2:.+]] = memref.dim %[[ARG4]], %[[C0]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M_2]], %[[M]]] -// CHECK: %[[SV_ARG4_2:.+]] = memref.subview %[[ARG4]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_3]], %[[N3]]] -// CHECK: %[[SV_ALLOC1:.+]] = 
memref.subview %[[ALLOC1]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M]], %[[N1]]] -// CHECK: %[[TILE_M_5:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M]], %[[M]]] -// CHECK: %[[N0:.+]] = memref.dim %[[ARG0]], %[[C1]] -// CHECK: %[[SV_ARG0:.+]] = memref.subview %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_5]], %[[N0]]] -// CHECK: %[[SV_ALLOC4:.+]] = memref.subview %[[ALLOC1]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_5]], %[[N1]]] -// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ALLOC1]] -// CHECK: linalg.matmul ins(%[[SV_ARG0]], %[[ARG1]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_ALLOC4]] : memref) -// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ALLOC3]] -// CHECK: linalg.matmul ins(%[[SV_ALLOC1]], %[[ARG2]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_ALLOC3]] : memref) -// CHECK: linalg.fill ins(%{{.+}}{{.*}}outs(%[[SV_ARG4_2]] -// CHECK: linalg.matmul ins(%[[SV_ALLOC3]], %[[ARG3]] -// CHECK-SAME: : memref, memref) -// CHECK-SAME: outs(%[[SV_ARG4]] : memref) -// CHECK: scf.yield -// CHECK: } - - -// ----- - -module { - func.func @tensor_op_fusion(%arg0: tensor, %arg1: tensor, - %arg2: tensor, %arg3: tensor) - -> tensor { - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) - outs(%arg2 : tensor) -> tensor - %1 = tensor.dim %0, %c0 : tensor - %2 = tensor.dim %0, %c1 : tensor - %3 = linalg.init_tensor [%1, %2] : tensor - %4 = linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%0, %arg3 : tensor, tensor) - outs(%3 : tensor) { - ^bb0(%arg4: f32, %arg5: f32, %arg6: f32): - %5 = arith.addf %arg4, %arg5 : f32 - linalg.yield %5 : f32 - } -> tensor - return %4 : tensor - } -} -// CHECK-LABEL: func @tensor_op_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[INIT:.+]] = linalg.init_tensor -// CHECK: %[[R0:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG5:.+]] = %[[INIT]]) -> (tensor) { -// CHECK: %[[R1:.+]] = scf.for %{{.+}} to %{{.+}} step %{{.+}} iter_args(%[[ARG7:.+]] = %[[ARG5]]) -> (tensor) { -// CHECK-DAG: %[[STARG3:.+]] = tensor.extract_slice %[[ARG3]] -// CHECK-DAG: %[[STARG7:.+]] = tensor.extract_slice %[[ARG7]] -// CHECK-DAG: %[[STARG0:.+]] = tensor.extract_slice %[[ARG0]] -// CHECK-DAG: %[[STARG1:.+]] = tensor.extract_slice %[[ARG1]] -// CHECK-DAG: %[[STARG2:.+]] = tensor.extract_slice %[[ARG2]] -// CHECK: %[[T0:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[STARG0]], %[[STARG1]] : tensor, tensor) -// CHECK-SAME: outs(%[[STARG2]] : tensor) -> tensor -// CHECK: %[[T1:.+]] = linalg.generic -// CHECK-SAME: ins(%[[T0:.+]], %[[STARG3]] : tensor, tensor) -// CHECK-SAME: outs(%[[STARG7]] : tensor) -// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[T1]] into %[[ARG7]] -// CHECK: scf.yield %[[RESULT]] -// CHECK: } -// CHECK: scf.yield %[[R1]] -// CHECK: } -// CHECK: return %[[R0]] - -// ----- - -module { - func.func @tensor_matmul_fusion(%arg0: tensor, %arg1: tensor, - %arg2: tensor, %arg3: tensor, - %arg4: tensor, %arg5: tensor, - %arg6: tensor) -> tensor { - %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) - outs(%arg2 : tensor) -> tensor // [M, N0] * [N0, N1] - %1 = linalg.matmul ins(%0, %arg3 : tensor, tensor) - outs(%arg4 : tensor) -> tensor // [M, N1] * [N1, N2] - %2 = 
linalg.matmul ins(%1, %arg5 : tensor, tensor) - outs(%arg6 : tensor) -> tensor // [M, N2] * [N2, N3] - return %2 : tensor - } -} - -// CHECK: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> -// CHECK: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)> - -// CHECK: func @tensor_matmul_fusion( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor) -> tensor { -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK: %[[M:.+]] = tensor.dim %[[ARG0]], %c0 : tensor -// CHECK: %[[R0:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] = -// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor) { -// CHECK: %[[TILE_M_1:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[M]]] -// CHECK: %[[N3:.+]] = tensor.dim %[[ARG8]], %[[C1]] -// CHECK: %[[STARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_1]], %[[N3]]] -// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[M]], %[[M]]] -// CHECK: %[[N2:.+]] = tensor.dim %[[ARG4]], %[[C1]] -// CHECK: %[[STARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_2]], %[[N2]]] -// CHECK: %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[STARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_2]], %[[N0]]] -// CHECK: %[[N1:.+]] = tensor.dim %[[ARG2]], %[[C1]] -// CHECK: %[[STARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_2]], %[[N1]]] -// CHECK: %[[T0:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[STARG0]], %[[ARG1]] : tensor, tensor -// CHECK-SAME: ) outs(%[[STARG2]] : tensor) -// CHECK: %[[T1:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[T0]], %arg3 : tensor, tensor -// CHECK-SAME: ) outs(%[[STARG4]] : tensor) -// CHECK: %[[T2:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[T1]], %arg5 : tensor, tensor -// CHECK-SAME: ) outs(%[[STARG6]] : tensor) -// CHECK: %[[R1:.+]] = tensor.insert_slice %[[T2]] -// CHECK-SAME: into %[[ARG8]][%[[IV0]], 0] [%[[TILE_M_1]], %[[N3]]] -// CHECK: scf.yield %[[R1]] : tensor -// CHECK: } diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir +++ /dev/null @@ -1,193 +0,0 @@ -// RUN: mlir-opt %s -test-linalg-tensor-fusion-transform-patterns -resolve-shaped-type-result-dims -canonicalize -cse --split-input-file | FileCheck %s - -module { - func.func @matmul_fusion(%A: tensor, %B: tensor, - %AB_init: tensor, %C: tensor, - %ABC_init: tensor) -> tensor { - %AB = linalg.matmul ins(%A, %B : tensor, tensor) - outs(%AB_init : tensor) -> tensor // - %ABC = linalg.matmul {__internal_linalg_transform__ = "lhs_fusion"} - ins(%AB, %C : tensor, tensor) - outs(%ABC_init : tensor) -> tensor // - return %ABC : tensor - } -} -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 32)> -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 64)> -// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 32)> - -// CHECK: func @matmul_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// 
CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor - -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index -// CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index -// CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index -// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] -// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = -// CHECK-SAME: %[[C0]] to %[[M]] step %[[C32]] -// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) -> (tensor) { -// CHECK: %[[TILE_M_2:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[M]]] -// CHECK: %[[N3:.+]] = tensor.dim %[[ARG6]], %[[C1]] -// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_2]], %[[N3]]] -// CHECK: %[[TILE_M_3:.+]] = affine.min #[[MAP5]](%[[IV0]])[%[[M]], %[[M]]] -// CHECK: %[[N1:.+]] = tensor.dim %[[ARG0]], %[[C1]] -// CHECK: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_3]], %[[N1]]] -// CHECK: %[[N2_2:.+]] = tensor.dim %[[ARG2]], %[[C1]] -// CHECK: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], 0] -// CHECK-SAME: [%[[TILE_M_3]], %[[N2_2]]] -// CHECK: %[[LHS:.+]] = linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion_producer" -// CHECK-SAME: ins(%[[ST_ARG0]], %[[ARG1]] : tensor, tensor) -// CHECK-SAME: outs(%[[ST_ARG2]] : tensor) -// CHECK: %[[N2:.+]] = tensor.dim %[[ARG1]], %[[C1]] -// CHECK: %[[N3_2:.+]] = tensor.dim %[[ARG3]], %[[C1]] -// CHECK: %[[YIELD0:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = -// CHECK-SAME: %[[C0]] to %[[N3_2]] step %[[C64]] -// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ST_ARG6]]) -> (tensor) { -// CHECK: %[[YIELD1:.+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] = -// CHECK-SAME: %[[C0]] to %[[N2]] step %[[C16]] -// CHECK-SAME: iter_args(%[[ARG10:.+]] = %[[ARG8]]) -> (tensor) { -// CHECK: %[[TILE_N2:.+]] = affine.min #[[MAP2]](%[[IV2]])[%[[N2]]] -// CHECK: %[[ST_LHS:.+]] = tensor.extract_slice %[[LHS]][0, %[[IV2]]] -// CHECK-SAME: [%[[TILE_M_3]], %[[TILE_N2]]] -// CHECK: %[[TILE_N3:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N3_2]]] -// CHECK: %[[ST_ARG3:.+]] = tensor.extract_slice %[[ARG3]][%[[IV2]], %[[IV1]]] -// CHECK-SAME: [%[[TILE_N2]], %[[TILE_N3]]] -// CHECK: %[[M_4:.+]] = tensor.dim %[[ARG10]], %[[C0]] -// CHECK: %[[ST_ARG4:.+]] = tensor.extract_slice %[[ARG10]][0, %[[IV1]]] -// CHECK-SAME: [%[[M_4]], %[[TILE_N3]]] -// CHECK: %[[ST_RESULT:.+]] = linalg.matmul -// CHECK-SAME: __internal_linalg_transform__ = "after_lhs_fusion" -// CHECK-SAME: ins(%[[ST_LHS]], %[[ST_ARG3]] -// CHECK-SAME: : tensor, tensor) -// CHECK-SAME: outs(%[[ST_ARG4]] : tensor) -// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[ST_RESULT]] -// CHECK-SAME: into %[[ARG10]][0, %[[IV1]]] [%[[M_4]], %[[TILE_N3]]] -// CHECK: scf.yield %[[UPDATE1]] -// CHECK: } -// CHECK: scf.yield %[[YIELD1]] -// CHECK: } -// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[YIELD0]] into -// CHECK-SAME: %[[ARG6]][%[[IV0]], 0] [%[[TILE_M_2]], %[[N3]]] -// CHECK: scf.yield %[[UPDATE0]] -// CHECK: } -// CHECK: return %[[RESULT]] - -// ----- - -module { - func.func @matmul_plus_matmul(%arg0: tensor, %arg1: tensor, - %arg2: tensor) -> tensor{ - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = tensor.dim %arg2, %c0 : tensor - %1 = tensor.dim %arg2, %c1 : tensor - %2 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) - outs(%arg2 : tensor) -> tensor - 
%3 = tensor.dim %2, %c0 : tensor - %4 = tensor.dim %2, %c1 : tensor - %5 = linalg.init_tensor [%3, %4] : tensor - %6 = linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"], - __internal_linalg_transform__ = "transpose_fusion"} - ins(%2, %2 : tensor, tensor) - outs(%5 : tensor) { - ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : - %7 = arith.addf %arg3, %arg4 : f32 - linalg.yield %7 : f32 - } -> tensor - return %6 : tensor - } -} -// CHECK: func @matmul_plus_matmul -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]] -// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}}) -// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]] -// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]]) -// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]] -// CHECK: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK: %[[LHS:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] -// CHECK-SAME: : tensor, tensor) -// CHECK-SAME: outs(%[[ST_ARG2]] : tensor) -// CHECK: %[[ST_RESULT:.+]] = linalg.generic -// CHECK-SAME: ins(%[[LHS]] : tensor) -// CHECK-SAME: outs(%[[ST_ARG6]] : tensor) -// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]] -// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]] -// CHECK: scf.yield %[[UPDATE]] -// CHECK: scf.yield %[[YIELD]] -// CHECK: return %[[RESULT]] - -// ----- - -module { - func.func @matmul_out_fusion(%arg0: tensor, %arg1: tensor, - %arg2: tensor) -> tensor { - %c0 = arith.constant 0.0 : f32 - %0 = linalg.fill ins(%c0 : f32) outs(%arg0 : tensor) -> tensor - %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"} - ins(%arg1, %arg2 : tensor, tensor) - outs(%0 : tensor) -> tensor - return %1 : tensor - } -} - -// CHECK-LABEL: func @matmul_out_fusion( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[C0:.*]] = arith.constant 0.0{{.*}} : f32 -// CHECK-NOT: fill -// CHECK: scf.for %[[I:.*]]{{.*}}iter_args(%{{.*}} = %[[ARG0]]) -> (tensor) { -// CHECK: scf.for %[[J:.*]] -// CHECK: %[[ST:.*]] = tensor.extract_slice %[[ARG0]] -// CHECK: %[[ST_FILL:.*]] = linalg.fill -// CHECK-SAME: {__internal_linalg_transform__ = "after_out_fusion_producer"} -// CHECK-SAME: ins(%[[C0]] : f32) outs(%[[ST]] : tensor) -> tensor -// CHECK: %[[ST_MM_RES:.*]] = scf.for %[[K:.*]]{{.*}}iter_args(%[[BB:.*]] = %[[ST_FILL]]) -> (tensor) { -// CHECK-NOT: fill -// CHECK: %[[ST_FILL_SUB:.*]] = tensor.extract_slice %[[BB]][0, 0] -// CHECK: %[[ST_MM_SUB:.*]] = linalg.matmul {__internal_linalg_transform__ = "after_out_fusion"} ins(%{{.*}}, %{{.*}} : tensor, tensor) outs(%[[ST_FILL_SUB]] : tensor) -> tensor -// CHECK: %[[ST_MM:.*]] = tensor.insert_slice %[[ST_MM_SUB]] into %[[BB]] -// CHECK: scf.yield %[[ST_MM]] : tensor -// CHECK: %[[MM:.*]] = tensor.insert_slice %[[ST_MM_RES]] into {{.*}} -// CHECK: scf.yield %[[MM]] : tensor - -// ----- - -module { - func.func @generic_plus_matmul(%arg0: tensor, %arg1: tensor, - %arg2: tensor) -> tensor { - %c0 = arith.constant 0.0 : f32 - %0 = linalg.generic { - 
-      indexing_maps = [affine_map<(m, n) -> ()>, affine_map<(m, n) -> (m, n)>],
-      iterator_types = ["parallel", "parallel"]}
-      ins(%c0 : f32)
-      outs(%arg0: tensor<?x?xf32>) {
-      ^bb(%0: f32, %1: f32) :
-        linalg.yield %0 : f32
-    } -> tensor<?x?xf32>
-    %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"}
-        ins(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
-        outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
-    return %1 : tensor<?x?xf32>
-  }
-}
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -183,3 +183,180 @@
 // CHECK-SAME: outs(%[[OUTS_TILE]] :
 // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
 // CHECK: scf.yield %[[INSERT]]
+
+// -----
+
+module {
+  func.func @matmul_plus_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                                %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
+    %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
+    %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
+    %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
+    %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
+    %6 = linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"],
+       __internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
+      ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%5 : tensor<?x?xf32>) {
+      ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
+        %7 = arith.addf %arg3, %arg4 : f32
+        linalg.yield %7 : f32
+      } -> tensor<?x?xf32>
+    return %6 : tensor<?x?xf32>
+  }
+}
+// This fuses as expected, but the producer gemm is inlined twice, once for
+// each use of its result. The two copies should be CSE-d, but that does not
+// happen today.
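+// The CHECK lines below make the duplication visible: the same slices of
+// %arg0, %arg1, and %arg2 are extracted twice, and two identical tiled
+// linalg.matmul ops are created, one feeding each operand of the fused
+// linalg.generic.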
+
+// CHECK: func @matmul_plus_matmul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]])
+// CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+// CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+// CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[LHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] :
+// CHECK-SAME: outs(%[[ST_ARG2]] :
+// CHECK-DAG: %[[ST_ARG0_1:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+// CHECK-DAG: %[[ST_ARG1_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+// CHECK-DAG: %[[ST_ARG2_1:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[RHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ST_ARG0_1]], %[[ST_ARG1_1]] :
+// CHECK-SAME: outs(%[[ST_ARG2_1]] :
+// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[ST_RESULT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[ST_ARG6]] :
+// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
+// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: scf.yield %[[UPDATE]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+module {
+  func.func @matmul_plus_transpose_matmul(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                                          %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %0 = tensor.dim %arg2, %c0 : tensor<?x?xf32>
+    %1 = tensor.dim %arg2, %c1 : tensor<?x?xf32>
+    %2 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    %3 = tensor.dim %2, %c0 : tensor<?x?xf32>
+    %4 = tensor.dim %2, %c1 : tensor<?x?xf32>
+    %5 = linalg.init_tensor [%3, %4] : tensor<?x?xf32>
+    %6 = linalg.generic
+      {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                        affine_map<(d0, d1) -> (d1, d0)>,
+                        affine_map<(d0, d1) -> (d0, d1)>],
+       iterator_types = ["parallel", "parallel"],
+       __internal_linalg_transform__ = "gemm_plus_gemm_fusion"}
+      ins(%2, %2 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%5 : tensor<?x?xf32>) {
+      ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) :
+        %7 = arith.addf %arg3, %arg4 : f32
+        linalg.yield %7 : f32
+      } -> tensor<?x?xf32>
+    return %6 : tensor<?x?xf32>
+  }
+}
+// CHECK: func @matmul_plus_transpose_matmul
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG4:.+]] = %{{[a-zA-Z0-9_]+}})
+// CHECK: %[[YIELD:.+]] = scf.for %[[IV1:[a-zA-Z0-9_]+]]
+// CHECK-SAME: iter_args(%[[ARG6:.+]] = %[[ARG4]])
+// CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0]
+// CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]]
+// CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[LHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]]
+// CHECK-SAME: : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ST_ARG2]] : tensor<?x?xf32>)
+// CHECK-DAG: %[[STR_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0]
+// CHECK-DAG: %[[STR_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV0]]]
+// CHECK-DAG: %[[STR_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV1]], %[[IV0]]]
+// CHECK: %[[RHS:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[STR_ARG0]], %[[STR_ARG1]] :
+// CHECK-SAME: outs(%[[STR_ARG2]] :
+// CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: %[[ST_RESULT:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] :
+// CHECK-SAME: outs(%[[ST_ARG6]] :
+// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]]
+// CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]]
+// CHECK: scf.yield %[[UPDATE]]
+// CHECK: scf.yield %[[YIELD]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+module {
+  func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+      %arg2: tensor<?x?xf32>, %arg3: tensor<?x?xf32>, %arg4: tensor<?x?xf32>,
+      %arg5: tensor<?x?xf32>, %arg6: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %0 = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%arg2 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N0] * [N0, N1]
+    %1 = linalg.matmul ins(%0, %arg3 : tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%arg4 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N1] * [N1, N2]
+    %2 = linalg.matmul
+        {__internal_linalg_transform__ = "gemm_sequence_fusion"}
+        ins(%1, %arg5 : tensor<?x?xf32>, tensor<?x?xf32>)
+        outs(%arg6 : tensor<?x?xf32>) -> tensor<?x?xf32> // [M, N2] * [N2, N3]
+    return %2 : tensor<?x?xf32>
+  }
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK: func @matmul_sequence_fusion(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG5:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[ORIG_GEMM1:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
+// CHECK-DAG: %[[ORIG_GEMM2:.+]] = linalg.matmul ins(%[[ORIG_GEMM1]], %[[ARG3]] :
+// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]]
+// CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]]
+// CHECK-DAG: %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]]
+// CHECK: %[[R0:.+]] = scf.for %[[IV:[a-zA-Z0-9_]+]] =
+// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
+// CHECK-DAG: %[[TILE_M:.+]] = affine.min #[[MAP]](%[[IV]])[%{{.+}}, %[[M]]]
+// CHECK-DAG: %[[SLICE_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[TILE_M]], %[[N0]]]
+// CHECK-DAG: %[[SLICE_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
+// CHECK-DAG: %[[SLICE_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV]], 0] [%[[TILE_M]], %[[N1]]]
+// CHECK-DAG: %[[TILE_GEMM1:.+]] = linalg.matmul ins(%[[SLICE_ARG0]], %[[SLICE_ARG1]] :
+// CHECK-SAME: outs(%[[SLICE_ARG2]] :
+// CHECK-DAG: %[[SLICE_ARG3:.+]] = tensor.extract_slice %[[ARG3]][0, 0] [%[[N1]], %[[N2]]]
+// CHECK-DAG: %[[SLICE_ARG4:.+]] = tensor.extract_slice %[[ARG4]][%[[IV]], 0] [%[[TILE_M]], %[[N2]]]
+// CHECK-DAG: %[[TILE_GEMM2:.+]] = linalg.matmul ins(%[[TILE_GEMM1]], %[[SLICE_ARG3]] :
+// CHECK-SAME: outs(%[[SLICE_ARG4]] :
+// CHECK-DAG: %[[SLICE_ARG5:.+]] = tensor.extract_slice %[[ARG5]][0, 0] [%[[N2]], %[[N3]]]
+// CHECK-DAG: %[[SLICE_ARG6:.+]] = tensor.extract_slice %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
+// CHECK-DAG: %[[TILE_GEMM3:.+]] = linalg.matmul
+// CHECK-SAME: ins(%[[TILE_GEMM2]], %[[SLICE_ARG5]] :
+// CHECK-SAME: outs(%[[SLICE_ARG6]] :
+// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[TILE_GEMM3]] into %[[ARG8]][%[[IV]], 0] [%[[TILE_M]], %[[N3]]]
+// CHECK: scf.yield %[[UPDATE]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFusionTransforms.cpp
@@ -23,130 +23,6 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
-/// Use this to safely fill patterns for this test, since RewritePatternSet::add
-/// forwards Rvalues only to the first pattern.
-template <typename OpTy, LinalgTilingLoopType LoopType>
-static void fillFusionPattern(MLIRContext *context,
-                              const LinalgDependenceGraph &dependenceGraph,
-                              RewritePatternSet &patterns,
-                              const Twine &testCase,
-                              ArrayRef<int64_t> tileSizes,
-                              ArrayRef<int64_t> indicesToFuse) {
-  patterns.add<LinalgTileAndFusePattern<OpTy>>(
-      context, dependenceGraph,
-      LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(LoopType),
-      LinalgFusionOptions().setIndicesToFuse(indicesToFuse),
-      LinalgTransformationFilter(
-          StringAttr::get(context, testCase + "_fusion"),
-          StringAttr::get(context, "after_" + testCase + "_fusion")),
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_" + testCase + "_fusion_producer")),
-      LinalgTransformationFilter(
-          ArrayRef<StringAttr>(),
-          StringAttr::get(context, "after_" + testCase + "_fusion_original")));
-}
-
-template <LinalgTilingLoopType LoopType>
-static void fillFusionPatterns(MLIRContext *context,
-                               const LinalgDependenceGraph &dependenceGraph,
-                               RewritePatternSet &patterns) {
-  fillFusionPattern(context, dependenceGraph, patterns,
-                    /*testCase=*/"basic",
-                    /*tileSizes=*/{32, 64, 16},
-                    /*indicesToFuse=*/{2});
-
-  auto fillMatmulPattern = [&](const Twine &testCase,
-                               ArrayRef<int64_t> indicesToFuse) {
-    fillFusionPattern<MatmulOp, LoopType>(context, dependenceGraph, patterns,
-                                          testCase, /*tileSizes=*/{32, 64, 16},
-                                          indicesToFuse);
-  };
-  fillMatmulPattern(/*testCase=*/"basic",
-                    /*indicesToFuse=*/{2});
-  fillMatmulPattern(/*testCase=*/"lhs",
-                    /*indicesToFuse=*/{0});
-  fillMatmulPattern(/*testCase=*/"out",
-                    /*indicesToFuse=*/{2});
-  fillMatmulPattern(/*testCase=*/"rhs",
-                    /*indicesToFuse=*/{1});
-  fillMatmulPattern(/*testCase=*/"two_operand",
-                    /*indicesToFuse=*/{0, 2});
-
-  fillFusionPattern(context, dependenceGraph, patterns,
-                    /*testCase=*/"transpose",
-                    /*tileSizes=*/{32, 64},
-                    /*indicesToFuse=*/{0, 1});
-}
-
-namespace {
-template <LinalgTilingLoopType LoopType>
-struct TestLinalgFusionTransforms
-    : public PassWrapper<TestLinalgFusionTransforms<LoopType>,
-                         OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransforms)
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert();
-  }
-  TestLinalgFusionTransforms() = default;
-  TestLinalgFusionTransforms(const TestLinalgFusionTransforms &pass) {}
-
-  void runOnOperation() override {
-    MLIRContext *context = &this->getContext();
-    func::FuncOp funcOp = this->getOperation();
-    RewritePatternSet fusionPatterns(context);
-    Aliases alias;
-    LinalgDependenceGraph dependenceGraph =
-        LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
-    fillFusionPatterns<LoopType>(context, dependenceGraph, fusionPatterns);
-    (void)applyPatternsAndFoldGreedily(funcOp, std::move(fusionPatterns));
-  }
-};
-
-struct TestLinalgFusionTransformsParallelLoops
-    : public TestLinalgFusionTransforms<LinalgTilingLoopType::ParallelLoops> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      TestLinalgFusionTransformsParallelLoops)
-
-  StringRef getArgument() const final {
-    return "test-linalg-fusion-transform-patterns";
-  }
-  StringRef getDescription() const final {
-    return "Test Linalg fusion transformation patterns by applying them "
-           "greedily.";
-  }
-};
-
-struct TestLinalgFusionTransformsLoops
-    : public TestLinalgFusionTransforms<LinalgTilingLoopType::Loops> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFusionTransformsLoops)
-
-  StringRef getArgument() const final {
-    return "test-linalg-tensor-fusion-transform-patterns";
-  }
-  StringRef getDescription() const final {
-    return "Test Linalg on tensor fusion transformation "
-           "patterns by applying them greedily.";
-  }
-};
-
-struct TestLinalgFusionTransformsTiledLoops
-    : public TestLinalgFusionTransforms<LinalgTilingLoopType::TiledLoops> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      TestLinalgFusionTransformsTiledLoops)
-
-  StringRef getArgument() const final {
-    return "test-linalg-tiled-loop-fusion-transform-patterns";
-  }
-  StringRef getDescription() const final {
-    return "Test Linalg on tensor fusion transformation "
-           "patterns by applying them greedily.";
-  }
-};
-} // namespace
-
 static LogicalResult fuseLinalgOpsGreedily(func::FuncOp f) {
   OpBuilder b(f);
   DenseSet<Operation *> eraseSet;
@@ -236,82 +112,13 @@
     } while (succeeded(fuseLinalgOpsGreedily(getOperation())));
   }
 };
-
-/// Pass to test tile and fuse of sequence of operations. Intended only for
-/// testing.
-struct TestLinalgTileAndFuseSequencePass
-    : public PassWrapper<TestLinalgTileAndFuseSequencePass,
-                         OperationPass<func::FuncOp>> {
-  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
-      TestLinalgTileAndFuseSequencePass)
-
-  StringRef getArgument() const final { return "test-linalg-tile-and-fuse"; }
-  StringRef getDescription() const final {
-    return "Test Linalg tiling and fusion of a sequence of Linalg operations.";
-  }
-  TestLinalgTileAndFuseSequencePass() = default;
-  TestLinalgTileAndFuseSequencePass(
-      const TestLinalgTileAndFuseSequencePass &pass)
-      : PassWrapper(pass){};
-
-  ListOption<int64_t> tileSizes{*this, "tile-sizes",
-                                llvm::cl::desc("Tile sizes to use for ops")};
-
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert();
-  }
-
-  void runOnOperation() override {
-    func::FuncOp funcOp = getOperation();
-    auto &blocks = funcOp.getBody().getBlocks();
-    if (!llvm::hasSingleElement(blocks)) {
-      return;
-    }
-    SmallVector<LinalgOp, 2> linalgOps =
-        llvm::to_vector<2>(blocks.front().getOps<LinalgOp>());
-    Aliases aliases;
-    LinalgDependenceGraph dependenceGraph(aliases, linalgOps);
-    OpBuilder builder(funcOp.getContext());
-    linalg::LinalgTilingLoopType loopType = LinalgTilingLoopType::ParallelLoops;
-    if (llvm::any_of(linalgOps, [](LinalgOp linalgOp) {
-          return linalgOp.hasTensorSemantics();
-        }))
-      loopType = LinalgTilingLoopType::Loops;
-    Optional<TiledAndFusedLinalgOps> tileAndFuseOps = tileAndFuseLinalgOps(
-        builder, linalgOps, dependenceGraph,
-        LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType));
-    if (!tileAndFuseOps)
-      return signalPassFailure();
-    if (linalgOps.back().hasTensorSemantics()) {
-      linalgOps.back().getOperation()->replaceAllUsesWith(
-          tileAndFuseOps->fusedLoops.front());
-    }
-    for (auto op : linalgOps)
-      if (op.hasBufferSemantics())
-        op.erase();
-  }
-};
-
 } // namespace
 
 namespace mlir {
 namespace test {
-void registerTestLinalgFusionTransforms() {
-  PassRegistration<TestLinalgFusionTransformsParallelLoops>();
-}
-void registerTestLinalgTensorFusionTransforms() {
-  PassRegistration<TestLinalgFusionTransformsLoops>();
-}
-void registerTestLinalgTiledLoopFusionTransforms() {
-  PassRegistration<TestLinalgFusionTransformsTiledLoops>();
-}
 void registerTestLinalgGreedyFusion() {
   PassRegistration<TestLinalgGreedyFusion>();
 }
-void registerTestLinalgTileAndFuseSequencePass() {
-  PassRegistration<TestLinalgTileAndFuseSequencePass>();
-}
 } // namespace test
 } // namespace mlir
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -184,6 +184,12 @@
     addPatternForTiling<
         TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
         context, {10}, "gemm_fusion", patterns);
+    addPatternForTiling<
+        TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
+        context, {10, 20}, "gemm_plus_gemm_fusion", patterns);
+    addPatternForTiling<
+        TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
+        context, {10}, "gemm_sequence_fusion", patterns);
     return;
   }
 }
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -88,12 +88,8 @@
 void registerTestLastModifiedPass();
 void registerTestLinalgDecomposeOps();
 void registerTestLinalgElementwiseFusion();
-void registerTestLinalgFusionTransforms();
-void registerTestLinalgTensorFusionTransforms();
-void registerTestLinalgTiledLoopFusionTransforms();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgHoisting();
-void registerTestLinalgTileAndFuseSequencePass();
 void registerTestLinalgTransforms();
 void registerTestLivenessPass();
 void registerTestLoopFusion();
@@ -187,12 +183,8 @@
   mlir::test::registerTestLastModifiedPass();
   mlir::test::registerTestLinalgDecomposeOps();
   mlir::test::registerTestLinalgElementwiseFusion();
-  mlir::test::registerTestLinalgFusionTransforms();
-  mlir::test::registerTestLinalgTensorFusionTransforms();
-  mlir::test::registerTestLinalgTiledLoopFusionTransforms();
   mlir::test::registerTestLinalgGreedyFusion();
   mlir::test::registerTestLinalgHoisting();
-  mlir::test::registerTestLinalgTileAndFuseSequencePass();
   mlir::test::registerTestLinalgTransforms();
   mlir::test::registerTestLivenessPass();
   mlir::test::registerTestLoopFusion();