diff --git a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
--- a/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
+++ b/mlir/include/mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h
@@ -48,16 +48,104 @@
   // TODO: OpOperand tracks dependencies on buffer operands. Tensor result will
   // need an extension to use OpResult.
   struct LinalgDependenceGraphElem {
+    using OpView = PointerUnion<OpOperand *, Value>;
     // dependentOpView may be either:
     //   1. src in the case of dependencesIntoGraphs.
     //   2. dst in the case of dependencesFromDstGraphs.
-    OpOperand *dependentOpView;
+    OpView dependentOpView;
     // View in the op that is used to index in the graph:
     //   1. src in the case of dependencesFromDstGraphs.
     //   2. dst in the case of dependencesIntoGraphs.
-    OpOperand *indexingOpView;
+    OpView indexingOpView;
     // Type of the dependence.
     DependenceType dependenceType;
+
+    // Return the Operation that owns the operand or result represented in
+    // `opView`.
+    static Operation *getOwner(OpView opView) {
+      if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
+        return operand->getOwner();
+      return opView.get<Value>().cast<OpResult>().getOwner();
+    }
+    // Return the operand or the result Value represented by the `opView`.
+    static Value getValue(OpView opView) {
+      if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
+        return operand->get();
+      return opView.get<Value>();
+    }
+    // Return the indexing map of the operand/result in `opView` specified in
+    // the owning LinalgOp. If the owner is not a LinalgOp returns llvm::None.
+    static Optional<AffineMap> getIndexingMap(OpView opView) {
+      auto owner = dyn_cast<LinalgOp>(getOwner(opView));
+      if (!owner)
+        return llvm::None;
+      if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
+        return owner.getIndexingMap(operand->getOperandNumber());
+      return owner.getOutputIndexingMap(
+          opView.get<Value>().cast<OpResult>().getResultNumber());
+    }
+    // Return the operand number if the `opView` is an OpOperand *. Otherwise
+    // return llvm::None.
+    static Optional<unsigned> getOperandNumber(OpView opView) {
+      if (OpOperand *operand = opView.dyn_cast<OpOperand *>())
+        return operand->getOperandNumber();
+      return llvm::None;
+    }
+    // Return the result number if the `opView` holds an OpResult. Return
+    // llvm::None for an OpOperand * (null Value) or any non-result Value.
+    static Optional<unsigned> getResultNumber(OpView opView) {
+      if (auto result = opView.dyn_cast<Value>().dyn_cast_or_null<OpResult>())
+        return result.getResultNumber();
+      return llvm::None;
+    }
+
+    // Return the owner of the dependent OpView.
+    Operation *getDependentOp() const { return getOwner(dependentOpView); }
+
+    // Return the owner of the indexing OpView.
+    Operation *getIndexingOp() const { return getOwner(indexingOpView); }
+
+    // Return the operand or result stored in the dependentOpView.
+    Value getDependentValue() const { return getValue(dependentOpView); }
+
+    // Return the operand or result stored in the indexingOpView.
+    Value getIndexingValue() const { return getValue(indexingOpView); }
+
+    // If the dependent OpView is an operand, return operand number. Return
+    // llvm::None otherwise.
+    Optional<unsigned> getDependentOpViewOperandNum() const {
+      return getOperandNumber(dependentOpView);
+    }
+
+    // If the indexing OpView is an operand, return operand number. Return
+    // llvm::None otherwise.
+    Optional<unsigned> getIndexingOpViewOperandNum() const {
+      return getOperandNumber(indexingOpView);
+    }
+
+    // If the dependent OpView is a result value, return the result
+    // number. Return llvm::None otherwise.
+    Optional<unsigned> getDependentOpViewResultNum() const {
+      return getResultNumber(dependentOpView);
+    }
+
+    // If the indexing OpView is a result value, return the result
+    // number. Return llvm::None otherwise.
+    Optional<unsigned> getIndexingOpViewResultNum() const {
+      return getResultNumber(indexingOpView);
+    }
+
+    // Return the indexing map of the operand/result in the dependent OpView as
+    // specified in the owner of the OpView.
+    Optional<AffineMap> getDependentOpViewIndexingMap() const {
+      return getIndexingMap(dependentOpView);
+    }
+
+    // Return the indexing map of the operand/result in the indexing OpView as
+    // specified in the owner of the OpView.
+    Optional<AffineMap> getIndexingOpViewIndexingMap() const {
+      return getIndexingMap(indexingOpView);
+    }
   };
   using LinalgDependences = SmallVector<LinalgDependenceGraphElem, 8>;
   using DependenceGraph = DenseMap<Operation *, LinalgDependences>;
diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -218,15 +218,14 @@
   // TODO: we are not considering paths yet, just interleaved positions.
   for (auto dt : types) {
     for (auto dependence : getDependencesFrom(src, dt)) {
-      auto interimPos =
-          linalgOpPositions.lookup(dependence.dependentOpView->getOwner());
+      auto interimPos = linalgOpPositions.lookup(dependence.getDependentOp());
       // Skip if not interleaved.
if (interimPos >= dstPos || interimPos <= srcPos) continue; - Value consumerView = dependence.indexingOpView->get(); + Value consumerView = dependence.getIndexingValue(); if (view && !aliases.alias(view, consumerView)) continue; - auto *op = dependence.dependentOpView->getOwner(); + auto *op = dependence.getDependentOp(); LLVM_DEBUG(dbgs() << "\n***Found covering dependence of type " << getDependenceTypeStr(dt) << ": " << *src << " -> " << *op << " on " << consumerView); @@ -241,7 +240,7 @@ ArrayRef depTypes) const { for (auto dep : depTypes) for (auto dependence : getDependencesInto(dstLinalgOp, dep)) - if (dependence.dependentOpView->getOwner() == srcLinalgOp) + if (dependence.getDependentOp() == srcLinalgOp) return true; return false; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -258,11 +258,9 @@ /// `producer.getOutputBuffers()`. /// 2. Tensor case: `producerIdx` is the index of the tensor in /// `producer.getResults()`. 
-static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, - unsigned producerOutNumber, OpOperand &consumerOpOperand) { - AffineMap producerMap = producerOp.getOutputIndexingMap(producerOutNumber); - LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerOutNumber - << ", producer map: " << producerMap << "\n"); +static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap, + OpOperand &consumerOpOperand) { + LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n"); DenseMap fusedLoopsAndRanges; Value shapedOperand = consumerOpOperand.get(); for (auto en : llvm::enumerate(producerMap.getResults())) { @@ -354,6 +352,8 @@ findFusableProducer(OpOperand &consumerOpOperand, const LinalgDependenceGraph &dependenceGraph) { LinalgOp consumerOp = cast(consumerOpOperand.getOwner()); + // Note that buffer semantics implies that the dependence will only be from + // OpOperand -> OpOperand. assert(consumerOp.hasBufferSemantics() && "revisit usage of shaped operand"); // Only consider RAW and WAW atm. @@ -364,22 +364,24 @@ for (auto dependence : llvm::make_filter_range( dependenceGraph.getDependencesInto(consumerOp, depType), [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) { - return elem.indexingOpView->get() == consumerOpOperand.get() && - elem.indexingOpView->getOperandNumber() == + Value v = elem.getIndexingValue(); + Optional operandNum = + elem.getIndexingOpViewOperandNum(); + return isa(elem.getDependentOp()) && + v == consumerOpOperand.get() && operandNum && + operandNum.getValue() == consumerOpOperand.getOperandNumber(); })) { - // Consumer consumes this view, `isStructurallyFusableProducer` also // checks whether it is a strict subview of the producer view. 
- auto producer = cast(dependence.dependentOpView->getOwner()); + auto producer = cast(dependence.getDependentOp()); LLVM_DEBUG(llvm::dbgs() << "\n" << LinalgDependenceGraph::getDependenceTypeStr(depType) - << "producer: " << *dependence.dependentOpView->getOwner() - << " view: " << dependence.dependentOpView->get() - << " output index: " - << dependence.dependentOpView->getOperandNumber() - - producer.getNumInputs() + << "producer: " << *dependence.getDependentOp() << " view: " + << dependence.getDependentValue() << " output index: " + << (dependence.getDependentOpViewOperandNum().getValue() - + producer.getNumInputs()) << "\n"); // Simple fusability checks. @@ -399,18 +401,21 @@ Optional fusableDependence = findFusableProducer(consumerOpOperand, graph); if (!fusableDependence) - return {}; + return llvm::None; + + LinalgOp producerOp = dyn_cast(fusableDependence->getDependentOp()); + if (!producerOp) + return llvm::None; - LinalgOp producerOp = - cast(fusableDependence->dependentOpView->getOwner()); // If producer is already in the same block as consumer, we are done. if (consumerOpOperand.get().getParentBlock() == - fusableDependence->dependentOpView->get().getParentBlock()) - return {}; + fusableDependence->getDependentValue().getParentBlock()) + return llvm::None; - unsigned producerIdx = - fusableDependence->dependentOpView->getOperandNumber() - - producerOp.getNumInputs(); + Optional producerMap = + fusableDependence->getDependentOpViewIndexingMap(); + if (!producerMap) + return llvm::None; // Must be a subview or a slice to guarantee there are loops we can fuse // into. @@ -418,7 +423,7 @@ auto slice = consumerOpOperand.get().getDefiningOp(); if (!subView && !slice) { LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)"); - return {}; + return llvm::None; } // Fuse `producer` just before `consumer`. 
@@ -428,7 +433,7 @@ LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOpOperand.getOwner() << "\n"); - auto fusedProducer = fuse(b, producerOp, producerIdx, consumerOpOperand); + auto fusedProducer = fuse(b, producerOp, *producerMap, consumerOpOperand); return FusionInfo{producerOp, fusedProducer}; } @@ -474,8 +479,13 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult, OpOperand &consumerOpOperand) { auto producerOp = dyn_cast(producerOpResult.getOwner()); - assert(producerOp && "expected Linalg producer"); - LinalgOp consumerOp = cast(consumerOpOperand.getOwner()); + if (!producerOp) + return llvm::None; + + LinalgOp consumerOp = dyn_cast(consumerOpOperand.getOwner()); + if (!consumerOp) + return llvm::None; + Value inputTensor = consumerOpOperand.get(); // Must be a subtensor to guarantee there are loops we can fuse into. @@ -496,8 +506,10 @@ b.setInsertionPoint(consumerOp); ScopedContext scope(b, consumerOp->getLoc()); LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n"); - LinalgOp fusedProducer = fuse( - b, producerOp, producerOpResult.getResultNumber(), consumerOpOperand); + LinalgOp fusedProducer = + fuse(b, producerOp, + producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()), + consumerOpOperand); // Replace use. 
// Canonicalizations are not guaranteed to have happened before constructing @@ -531,30 +543,34 @@ /// inverse(producerIndexMap).compose(consumerIndexMap) static Optional getConsumerLoopToProducerLoopMap( LinalgDependenceGraph::LinalgDependenceGraphElem dependence) { - auto producer = cast(dependence.dependentOpView->getOwner()); - AffineMap producerIndexingMap = - producer.getIndexingMap(dependence.dependentOpView->getOperandNumber()); - auto consumer = cast(dependence.indexingOpView->getOwner()); - AffineMap consumerIndexingMap = - consumer.getIndexingMap(dependence.indexingOpView->getOperandNumber()); + auto producer = dyn_cast(dependence.getDependentOp()); + if (!producer) + return None; + + Optional producerIndexingMap = + dependence.getDependentOpViewIndexingMap(); + Optional consumerIndexingMap = + dependence.getIndexingOpViewIndexingMap(); + if (!producerIndexingMap || !consumerIndexingMap) + return None; AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap( - producer.iterator_types().getValue(), producerIndexingMap); + producer.iterator_types().getValue(), *producerIndexingMap); if (!prunedProducerIndexingMap.isPermutation()) return None; - if (consumerIndexingMap.getNumResults() != + if (consumerIndexingMap->getNumResults() != prunedProducerIndexingMap.getNumResults()) return None; LLVM_DEBUG({ llvm::dbgs() << "\t producerMap : "; - producerIndexingMap.print(llvm::dbgs()); + producerIndexingMap->print(llvm::dbgs()); llvm::dbgs() << " pruned : "; prunedProducerIndexingMap.print(llvm::dbgs()); llvm::dbgs() << "\n"; llvm::dbgs() << "\t consumerMap : "; - consumerIndexingMap.print(llvm::dbgs()); + consumerIndexingMap->print(llvm::dbgs()); llvm::dbgs() << "\n"; }); @@ -562,7 +578,7 @@ if (!invProducerIndexMap) return None; - return invProducerIndexMap.compose(consumerIndexingMap); + return invProducerIndexMap.compose(*consumerIndexingMap); } /// Given a projected permutation `map`, returns true if the map changes the @@ -710,10 +726,7 @@ 
FusableOpDependencesTy mlir::linalg::findAllFusableDependences( ArrayRef ops, const LinalgDependenceGraph &dependenceGraph) { FusableOpDependencesTy fusableDependences; - // TODO: Currently fusion would not be legal if the fusable dependence is to - // the same producer but different indexing map in the consumer. Fix this, but - // in the meanwhile disallow such a fusion. - DenseMap fusedProducerIndexingMap; + DenseMap> fusedProducerIndexingMap; for (LinalgOp op : reverse(ops)) { for (OpOperand &opOperand : op.getShapedOpOperands()) { Optional @@ -721,54 +734,47 @@ if (!fusableDependence) continue; LinalgOp producerOp = - cast(fusableDependence->dependentOpView->getOwner()); + dyn_cast(fusableDependence->getDependentOp()); + if (!producerOp) + continue; // Do not fuse dependences that are to operations not in the same basic // block. This avoid moving fused operations across loops that might // themselves carry dependency making the fusion illegal. - if (producerOp->getBlock() != op->getBlock()) { - op.emitRemark("unhandled fusion of ops in different basic blocks"); - return FusableOpDependencesTy{}; - } + if (producerOp->getBlock() != op->getBlock()) + continue; + // Make sure that the indexing map of the view used for fusion in the // producer is a projected permutation. 
- unsigned producerIdx = - fusableDependence->dependentOpView->getOperandNumber(); - AffineMap producerMap = producerOp.getIndexingMap(producerIdx); - if (!producerMap.isProjectedPermutation()) { - op.emitRemark( - "unhandled non permutation indexing map for fused view in " - "producer for operand at index ") - << opOperand.getOperandNumber(); - return FusableOpDependencesTy{}; - } - - unsigned consumerIdx = - fusableDependence->indexingOpView->getOperandNumber(); - AffineMap consumerMap = op.getIndexingMap(consumerIdx); - if (!consumerMap.isProjectedPermutation()) { - op.emitRemark( - "unhandled case where indexing map for fused view in the consumer " - "is not a projected permutation while fusing at index ") - << opOperand.getOperandNumber(); - return FusableOpDependencesTy{}; - } - - // Check if the producer is already a fusion candidate. Cannot fuse this - // dependence if it has a different indexing map when used in the - // consumer. - if (fusedProducerIndexingMap.count(producerOp.getOperation()) && - fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) { - op.emitRemark( - "unhandled fusion to the same producer but with different " - "indexing maps"); - return FusableOpDependencesTy{}; - } - fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap; + Optional producerMap = + fusableDependence->getDependentOpViewIndexingMap(); + Optional consumerMap = + fusableDependence->getIndexingOpViewIndexingMap(); + assert( + consumerMap && + "unable to find indexing map of operand/result of indexing OpView"); + fusedProducerIndexingMap[producerOp.getOperation()].push_back( + *consumerMap); + if (!producerMap || !producerMap->isProjectedPermutation() || + !consumerMap->isProjectedPermutation()) + continue; fusableDependences[producerOp.getOperation()].push_back( *fusableDependence); } } + // TODO: Currently fusion would not be legal if the fusable dependence is to + // the same producer but different indexing map in the consumer. 
Fix this, but + // in the meanwhile disallow such a fusion. + for (auto useIndexingMapsList : fusedProducerIndexingMap) { + AffineMap map1 = useIndexingMapsList.second.front(); + for (AffineMap map2 : + ArrayRef(useIndexingMapsList.second).drop_front()) { + if (map1 != map2) { + fusableDependences.erase(useIndexingMapsList.first); + break; + } + } + } return fusableDependences; } @@ -819,7 +825,7 @@ tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef ops, const LinalgDependenceGraph &dependenceGraph, const LinalgTilingOptions &tilingOptions) { - if (ops.empty()) + if (ops.size() < 2) return llvm::None; LinalgOp rootOp = ops.back(); for (auto op : enumerate(ops)) { @@ -827,14 +833,14 @@ // buffers. This check can be removed after it is tested on tensors. LinalgOp linalgOp = op.value(); if (!linalgOp.hasBufferSemantics()) { - linalgOp.emitError("tile and fuse only tested for buffer operation"); + linalgOp.emitRemark("tile and fuse only tested for buffer operation"); return llvm::None; } } // TODO: Support interchange with tile + fuse. This might actually help do // better fusion. 
if (!tilingOptions.interchangeVector.empty()) { - rootOp.emitError("unable to handle tile and fuse with interchange"); + rootOp.emitRemark("unable to handle tile and fuse with interchange"); return llvm::None; } @@ -864,7 +870,7 @@ Optional tiledRootOp = tileRootOperation( builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims); if (!tiledRootOp) { - rootOp.emitError("failed to tile the fused loops"); + rootOp.emitRemark("failed to tile the fused loops"); return llvm::None; } ret.op = tiledRootOp->op; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -161,12 +161,16 @@ DenseSet producers; producers.insert(linalgOp); - for (auto dependence : dependenceGraph.getDependentOperations(linalgOp)) { - if (!fusionOptions.indicesToFuse.count( - dependence.indexingOpView->getOperandNumber())) + for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) { + Optional operandNumber = dependence.getIndexingOpViewOperandNum(); + // When looking at dependences into, indexingOp is always OpOperand. We + // could assert, but continue if this is not the case. 
+ if (!operandNumber) continue; - if (isa(dependence.dependentOpView->getOwner())) - producers.insert(dependence.dependentOpView->getOwner()); + if (!fusionOptions.indicesToFuse.count(operandNumber.getValue())) + continue; + if (isa(dependence.getDependentOp())) + producers.insert(dependence.getDependentOp()); } SmallVector fusionOps; diff --git a/mlir/test/Dialect/Linalg/fusion-pattern.mlir b/mlir/test/Dialect/Linalg/fusion-pattern.mlir --- a/mlir/test/Dialect/Linalg/fusion-pattern.mlir +++ b/mlir/test/Dialect/Linalg/fusion-pattern.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -test-linalg-fusion-transform-patterns -canonicalize -cse -split-input-file | FileCheck %s module { func @basic_fusion(%arg0: memref, %arg1: memref, @@ -371,7 +371,6 @@ %2 = alloc(%0, %1) : memref linalg.matmul ins(%arg0, %arg1 : memref, memref) outs(%2 : memref) - // expected-remark @+1 {{unhandled fusion to the same producer but with different indexing maps}} linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, @@ -387,6 +386,15 @@ return } } +// CHECK-LABEL: func @matmul_plus_transpose_matmul +// CHECK-NOT: scf.parallel +// CHECK-NOT: scf.for +// CHECK: linalg.matmul +// CHECK-NOT: scf.parallel +// CHECK-NOT: scf.for +// CHECK: linalg.generic +// CHECK-NOT: scf.parallel +// CHECK-NOT: scf.for // ----- @@ -416,7 +424,6 @@ %6 = subview %arg0[%arg3, %arg5] [%3, %5] [1, 1] : memref to memref %7 = subview %arg1[%arg5, %arg4] [%5, %4] [1, 1] : memref to memref %8 = subview %arg2[%arg3, %arg4] [%3, %4] [1, 1] : memref to memref - // expected-remark @+1 {{unhandled fusion of ops in different basic blocks}} linalg.matmul {__internal_linalg_transform__ = "basic_fusion"} ins(%6, %7 : memref, memref) outs(%8 : memref) @@ -426,6 +433,13 @@ return } } +// CHECK-LABEL: func @basic_no_fusion +// CHECK-NOT: scf.parallel +// CHECK: 
linalg.fill +// CHECK: scf.parallel +// CHECK: scf.for +// CHECK-NOT: linalg.fill +// CHECK: linalg.matmul // -----