diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -229,6 +229,12 @@
                          ArrayRef<DependenceType> depTypes = {
                              DependenceType::RAW, DependenceType::WAW}) const;
 
+  /// Print the RAW and WAW dependences recorded in this graph to `os`.
+  void print(raw_ostream &os) const;
+
+  /// Print the dependence graph to `llvm::errs()`.
+  void dump() const;
+
 private:
   // Keep dependences in both directions, this is not just a performance gain
   // but it also reduces usage errors.
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -162,6 +162,8 @@
 }
 
 void LinalgDependenceGraph::addDependencesBetween(LinalgOp src, LinalgOp dst) {
+  LLVM_DEBUG(dbgs() << "addDependencesBetween " << *src.getOperation()
+                    << " and " << *dst.getOperation() << "\n");
   if (src.hasTensorSemantics() && dst.hasTensorSemantics()) {
     for (OpOperand &dstOpOperand : dst.getInputOpOperands()) {
       // Check if the operand is defined by the src.
@@ -170,6 +172,18 @@
         addDependenceElem(DependenceType::RAW, dstOpOperand.get(),
                           &dstOpOperand);
     }
+    for (OpOperand &dstOpOperand : dst.getOutputOpOperands()) {
+      // Check if the operand is defined by the src.
+      auto definingOp = dstOpOperand.get().getDefiningOp();
+      if (definingOp && definingOp == src) {
+        if (dst.isInitTensor(&dstOpOperand)) {
+          addDependenceElem(DependenceType::RAW, dstOpOperand.get(),
+                            &dstOpOperand);
+        }
+        addDependenceElem(DependenceType::WAW, dstOpOperand.get(),
+                          &dstOpOperand);
+      }
+    }
     return;
   }
   assert(src.hasBufferSemantics() && dst.hasBufferSemantics() &&
@@ -322,3 +336,21 @@
   dependentOperations.append(t.begin(), t.end());
   return dependentOperations;
 }
+
+void LinalgDependenceGraph::print(raw_ostream &os) const {
+  for (auto dt : {
+           LinalgDependenceGraph::DependenceType::RAW,
+           LinalgDependenceGraph::DependenceType::WAW,
+       }) {
+    const auto &fromGraph = dependencesFromGraphs[dt];
+    for (const auto &it : fromGraph) {
+      os << "[LinalgDependenceGraph] DT " << dt << " from: " << *it.first
+         << ":\n";
+      for (const auto &dep : it.second) {
+        os << "\tDT " << dt << " " << *dep.getDependentOp() << ":\n";
+      }
+    }
+  }
+}
+
+void LinalgDependenceGraph::dump() const { print(llvm::errs()); }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -29,6 +29,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
@@ -331,6 +332,10 @@
 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
 findFusableProducer(OpOperand &consumerOpOperand,
                     const LinalgDependenceGraph &dependenceGraph) {
+  LLVM_DEBUG(llvm::dbgs() << "findFusableProducer for: "
+                          << consumerOpOperand.get() << " @"
+                          << consumerOpOperand.getOperandNumber() << " in "
+                          << *consumerOpOperand.getOwner() << "\n");
   LinalgOp consumerOp = dyn_cast<LinalgOp>(consumerOpOperand.getOwner());
   if (!consumerOp)
     return {};
@@ -340,9 +345,14 @@
            LinalgDependenceGraph::DependenceType::RAW,
            LinalgDependenceGraph::DependenceType::WAW,
        }) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Dependencies into: " << *consumerOp.getOperation() << "\n");
     for (auto dependence : llvm::make_filter_range(
              dependenceGraph.getDependencesInto(consumerOp, depType),
              [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
+               LLVM_DEBUG(llvm::dbgs() << "Inspect dependence btw: "
+                                       << elem.getIndexingValue() << " and "
+                                       << elem.getDependentValue() << "\n");
                Value v = elem.getIndexingValue();
                Optional<unsigned> operandNum =
                    elem.getIndexingOpViewOperandNum();
@@ -783,12 +793,15 @@
 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
 /// `tiledOp`.
 static SmallVector<LinalgOp, 1>
-fuseOperations(OpBuilder &builder, LinalgOp rootOp, LinalgOp tiledOp,
+fuseOperations(OpBuilder &builder, LinalgOp rootOp,
+               const TiledLinalgOp &tiledLinalgOp,
                ArrayRef<LinalgOp> fusionCandidates,
                const FusableOpDependencesTy &fusableDependences,
                const std::set<unsigned> &fusedLoops) {
+  LinalgOp tiledOp = tiledLinalgOp.op;
   OpBuilder::InsertionGuard guard(builder);
   builder.setInsertionPoint(tiledOp);
+
   DenseMap<unsigned, Range> fusedLoopsAndRanges;
   for (unsigned loop : fusedLoops) {
     ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
@@ -804,27 +817,53 @@
     LinalgOp fusedOp = fuse(builder, origOp, fusedLoopsAndRanges);
     origOpToFusedOp[origOp.getOperation()] = fusedOp;
     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
+
+    // Prepare the builder for the next insertion point.
+    auto resetInsertionPoint =
+        llvm::make_scope_exit([&]() { builder.setInsertionPoint(fusedOp); });
+    if (!origOp.hasTensorSemantics())
+      continue;
+
     // If the producer consumer operations are linalg operations on tensors, the
     // dependence is due to value produced (as a return tensor) by the producer
     // and used in the consumer. The returned value of the fused op needs to be
     // made the operand of the tiled/fused consumer operation. By construction
     // the value returned by the producer is the value used by the consumer.
     for (auto &dependence : fusableDependences.lookup(origOp.getOperation())) {
-      if (origOp.hasTensorSemantics() &&
-          dependence.dependenceType ==
-              LinalgDependenceGraph::DependenceType::RAW) {
-        unsigned resultIndex =
-            dependence.getDependentOpViewResultNum().getValue();
-        LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
-        if (!consumer)
-          continue;
-        Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
-        consumer.getOperation()->setOperand(
-            dependence.getIndexingOpViewOperandNum().getValue(),
-            replacementValue);
-      }
+      if (dependence.dependenceType !=
+          LinalgDependenceGraph::DependenceType::RAW)
+        continue;
+
+      unsigned resultIndex =
+          dependence.getDependentOpViewResultNum().getValue();
+      LinalgOp consumer = origOpToFusedOp.lookup(dependence.getIndexingOp());
+      if (!consumer)
+        continue;
+
+      Value replacementValue = fusedOp.getOperation()->getResult(resultIndex);
+      consumer.getOperation()->setOperand(
+          dependence.getIndexingOpViewOperandNum().getValue(),
+          replacementValue);
     }
-    builder.setInsertionPoint(fusedOp);
+
+    // At this point, all Linalg uses of the tensors produced by `origOp` have
+    // been replaced. However, there may still be "output tensor"-like uses
+    // coming from WAW dependencies.
+    // All these uses are iter_args of the outermost loop (TODO: add a check).
+    // Such iter_args uses serve 2 purposes:
+    //  1. give a shape to the output
+    //  2. encode destructive updates that may be inplaceable by bufferization.
+    // To keep the second type of information while letting the unfused op die
+    // unused, we need to forward the producer output operand.
+    // Guard against an empty loop nest or a non-`scf.for` outermost loop
+    // instead of asserting in `cast`.
+    Operation *outerLoop =
+        tiledLinalgOp.loops.empty() ? nullptr : tiledLinalgOp.loops.front();
+    if (auto outerForOp = dyn_cast_or_null<scf::ForOp>(outerLoop))
+      for (auto &operand : outerForOp.getIterOpOperands())
+        if (auto opResult = operand.get().dyn_cast<OpResult>())
+          if (opResult.getOwner() == origOp)
+            operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
   }
   return fusedOps;
 }
@@ -860,18 +899,23 @@
   ScopedContext scope(builder, rootOp.getLoc());
 
   // Find all the producers.
+  LLVM_DEBUG(llvm::dbgs() << "findAllFusableDependences\n");
   FusableOpDependencesTy fusableDependences =
       findAllFusableDependences(ops, dependenceGraph);
-  if (fusableDependences.empty())
+  if (fusableDependences.empty()) {
+    LLVM_DEBUG(llvm::dbgs() << "no fusable dependencies found\n");
     return llvm::None;
+  }
 
   TiledAndFusedLinalgOps ret;
   // Find the loops that can be tiled and fused.
+  LLVM_DEBUG(llvm::dbgs() << "collectFusableLoops\n");
   ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
 
   // If there are no fusable dependences or there are no tile+fusable loops,
   // just return.
   if (ret.fusedLoopDims.empty()) {
+    LLVM_DEBUG(llvm::dbgs() << "no fusable loops found\n");
     return llvm::None;
   }
 
@@ -888,8 +932,9 @@
   ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
 
   // Fuse the other operations into the fused inter-tile loops produced above.
-  ret.fusedProducers = fuseOperations(builder, rootOp, ret.op, ops.drop_back(),
-                                      fusableDependences, ret.fusedLoopDims);
+  ret.fusedProducers =
+      fuseOperations(builder, rootOp, *tiledRootOp, ops.drop_back(),
+                     fusableDependences, ret.fusedLoopDims);
   return ret;
 }
 
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor-pattern.mlir
@@ -143,3 +143,34 @@
 // CHECK: scf.yield %[[UPDATE]]
 // CHECK: scf.yield %[[YIELD]]
 // CHECK: return %[[RESULT]]
+
+// -----
+
+module {
+  func @matmul_out_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>,
+                          %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+    %c0 = constant 0.0 : f32
+    %0 = linalg.fill(%arg0, %c0) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+    %1 = linalg.matmul {__internal_linalg_transform__ = "out_fusion"}
+      ins(%arg1, %arg2 : tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+    return %1 : tensor<?x?xf32>
+  }
+}
+
+// CHECK-LABEL: func @matmul_out_fusion(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
+// CHECK: %[[C0:.*]] = constant 0.0{{.*}} : f32
+// CHECK-NOT: fill
+// CHECK: scf.for %[[I:.*]]{{.*}}iter_args(%{{.*}} = %[[ARG0]]) -> (tensor<?x?xf32>) {
+// CHECK: scf.for %[[J:.*]]
+// CHECK: %[[ST:.*]] = subtensor %[[ARG0]]
+// CHECK: %[[ST_FILL:.*]] = linalg.fill(%[[ST]], %[[C0]]) {__internal_linalg_transform__ = "after_out_fusion_producer"} : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
+// CHECK: %[[ST_MM_RES:.*]] = scf.for %[[K:.*]]{{.*}}iter_args(%[[BB:.*]] = %[[ST_FILL]]) -> (tensor<?x?xf32>) {
+// CHECK-NOT: fill
+// CHECK: %[[ST_MM:.*]] = linalg.matmul {__internal_linalg_transform__ = "after_out_fusion"} ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[BB]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: scf.yield %[[ST_MM]] : tensor<?x?xf32>
+// CHECK: %[[MM:.*]] = subtensor_insert %[[ST_MM_RES]] into {{.*}}
+// CHECK: scf.yield %[[MM]] : tensor<?x?xf32>
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -52,6 +52,19 @@
           ArrayRef<Identifier>(),
           Identifier::get("after_lhs_fusion_original", context)));
 
+  patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
+      context, dependenceGraph,
+      LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),
+      LinalgFusionOptions().setIndicesToFuse({2}),
+      LinalgTransformationFilter(Identifier::get("out_fusion", context),
+                                 Identifier::get("after_out_fusion", context)),
+      LinalgTransformationFilter(
+          ArrayRef<Identifier>(),
+          Identifier::get("after_out_fusion_producer", context)),
+      LinalgTransformationFilter(
+          ArrayRef<Identifier>(),
+          Identifier::get("after_out_fusion_original", context)));
+
   patterns.add<LinalgTileAndFusePattern<MatmulOp>>(
       context, dependenceGraph,
       LinalgTilingOptions().setTileSizes({32, 64, 16}).setLoopType(LoopType),