diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterface.td
@@ -726,6 +726,18 @@
         getNumShapedOperands());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the `i`^th shaped operand as an OpOperand.
+      }],
+      /*retTy=*/"OpOperand&",
+      /*methodName=*/"getShapedOpOperand",
+      /*args=*/(ins "unsigned":$i),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return *(this->getShapedOpOperands().begin() + i);
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the range over input and output operands.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -35,6 +35,7 @@
   LinalgOp op;
   SmallVector<Operation *, 8> loops;
   SmallVector<Value, 4> tensorResults;
+  TiledLinalgOp &operator=(const TiledLinalgOp &) = default;
 };
 
 /// Populates patterns for vectorization of all ConvN-D ops.
@@ -412,9 +413,8 @@
                           LinalgTilingOptions options,
                           LinalgMarker marker = LinalgMarker(),
                           PatternBenefit benefit = 1);
-  LogicalResult
-  matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
-                      SmallVectorImpl<Value> &tensorResults) const;
+  LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter,
+                                    TiledLinalgOp &result) const;
 
 private:
   /// LinalgTransformMarker handles special attribute manipulations.
@@ -432,14 +432,14 @@
                                 marker, benefit) {}
   LogicalResult matchAndRewrite(Operation *op,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<Value, 4> tensorResults;
+    TiledLinalgOp tiledLinalgOp;
    if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter,
-                                                            tensorResults)))
+                                                            tiledLinalgOp)))
      return failure();
-    if (tensorResults.empty())
+    if (tiledLinalgOp.tensorResults.empty())
      rewriter.eraseOp(op);
    else
-      rewriter.replaceOp(op, tensorResults);
+      rewriter.replaceOp(op, tiledLinalgOp.tensorResults);
    return success();
  }
};
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -92,26 +92,31 @@
 
 /// Fuses producer into consumer if the producer is structurally feasible and
 /// the fusion would not violate dependencies.
-/// Implements the fusion part of the "tileAndFuse on buffers"
-/// transformation and thus requires the `consumerdIdx`^th operand of `consumer`
-/// to be a `subview` op (generally obtained by applying the tiling
-/// transformation).
-Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
-                                          unsigned consumerIdx,
+/// Implements the fusion part of the "tileAndFuse on buffers" transformation
+/// and thus requires the `consumerOpOperand` to be a `subview` op (generally
+/// obtained by applying the tiling transformation).
+Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b,
+                                          OpOperand &consumerOpOperand,
                                           const LinalgDependenceGraph &graph);
 /// Tensor counterpart of `fuseProducerOfBuffer`.
 /// This implements the fusion part of the "tileAndFuse on tensors"
-/// transformation and thus requires the `consumerdIdx`^th operand of `consumer`
-/// to be the result of a `subtensor` op (generally obtained by applying the
-/// tiling transformation).
-Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b, LinalgOp consumer,
-                                          unsigned consumerIdx);
+/// transformation and thus requires the `consumerOpOperand` to be a `subtensor`
+/// op (generally obtained by applying the tiling transformation).
+Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
+                                          OpOperand &consumerOpOperand);
+/// Tensor counterpart of `fuseProducerOfBuffer`.
+/// This implements the fusion part of the "tileAndFuse on tensors"
+/// transformation and thus requires the `consumerOpOperand` to be a `subtensor`
+/// op (generally obtained by applying the tiling transformation).
+/// Assumes `producerOpResult` is a Linalg op result that produces
+/// `consumerOpOperand`.
+Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
+                                          OpResult producerOpResult,
+                                          OpOperand &consumerOpOperand);
 
-/// Fuse linalg operation on tensors, with the producer of the operand at
-/// position `consumerIdx` of the consumer.
+/// Fuse linalg operation on tensors, with the producer of `consumerOpOperand`.
 Optional<SmallVector<Value, 1>> fuseTensorOps(PatternRewriter &rewriter,
-                                              Operation *consumer,
-                                              unsigned consumerIdx);
+                                              OpOperand &consumerOpOperand);
 
 /// Like `getShape`, but only returns statically-known information, without
 /// generating any new IR. For each shape dimension, returns >=0 if that
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -258,20 +258,19 @@
 /// `producer.getOutputBuffers()`.
 /// 2. Tensor case: `producerIdx` is the index of the tensor in
 /// `producer.getResults()`.
-static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
-                     LinalgOp consumer, unsigned consumerIdx) {
-  AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
-  LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
+static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp,
+                     unsigned producerOutNumber, OpOperand &consumerOpOperand) {
+  AffineMap producerMap = producerOp.getOutputIndexingMap(producerOutNumber);
+  LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerOutNumber
                           << ", producer map: " << producerMap << "\n");
   DenseMap<unsigned, Range> fusedLoopsAndRanges;
-  Location loc = consumer.getLoc();
-  Value shapedOperand = consumer.getShapedOperand(consumerIdx);
+  Value shapedOperand = consumerOpOperand.get();
   for (auto en : llvm::enumerate(producerMap.getResults())) {
     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
-    fusedLoopsAndRanges[posInProducerLoop] =
-        getRangeFromOperandShape(b, loc, shapedOperand, en.index());
+    fusedLoopsAndRanges[posInProducerLoop] = getRangeFromOperandShape(
+        b, consumerOpOperand.getOwner()->getLoc(), shapedOperand, en.index());
   }
-  return fuse(b, producer, fusedLoopsAndRanges);
+  return fuse(b, producerOp, fusedLoopsAndRanges);
 }
 
 // Encode structural fusion safety preconditions.
@@ -378,9 +377,10 @@
 }
 
 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
+findFusableProducer(OpOperand &consumerOpOperand,
                     const LinalgDependenceGraph &dependenceGraph) {
-  assert(consumer.hasBufferSemantics() && "revisit usage of shaped operand");
+  LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner());
+  assert(consumerOp.hasBufferSemantics() && "revisit usage of shaped operand");
 
   // Only consider RAW and WAW atm.
   for (auto depType : {
@@ -388,21 +388,16 @@
            LinalgDependenceGraph::DependenceType::WAW,
        }) {
     for (auto dependence : llvm::make_filter_range(
-             dependenceGraph.getDependencesInto(consumer, depType),
-             [consumerIdx](
-                 LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
-               return elem.indexingOpView->getOperandNumber() == consumerIdx;
+             dependenceGraph.getDependencesInto(consumerOp, depType),
+             [&](LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
+               return elem.indexingOpView->get() == consumerOpOperand.get() &&
+                      elem.indexingOpView->getOperandNumber() ==
+                          consumerOpOperand.getOperandNumber();
             })) {
-      // Check that the dependence is indeed on the input `consumerIdx` view.
-      Value consumedView = dependence.indexingOpView->get();
-      if (!isSameSubView(consumer.getShapedOperand(consumerIdx), consumedView))
-        continue;
-
       // Consumer consumes this view, `isStructurallyFusableProducer` also
       // checks whether it is a strict subview of the producer view.
       auto producer = cast<LinalgOp>(dependence.dependentOpView->getOwner());
-      Value producedView = dependence.dependentOpView->get();
       LLVM_DEBUG(llvm::dbgs()
                  << "\n"
                  << LinalgDependenceGraph::getDependenceTypeStr(depType)
@@ -412,10 +407,10 @@
                  << dependence.dependentOpView->getOperandNumber() -
                         producer.getNumInputs()
                  << "\n");
-      (void)producedView;
 
       // Simple fusability checks.
-      if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
+      if (!isFusableInto(dependenceGraph, consumerOp, consumerOpOperand.get(),
+                         producer))
         continue;
 
       return dependence;
@@ -425,29 +420,28 @@
 }
 
 Optional<FusionInfo>
-mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
-                                   unsigned consumerIdx,
+mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, OpOperand &consumerOpOperand,
                                    const LinalgDependenceGraph &graph) {
   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
-      findFusableProducer(consumer, consumerIdx, graph);
+      findFusableProducer(consumerOpOperand, graph);
   if (!fusableDependence)
     return {};
 
   LinalgOp producerOp =
       cast<LinalgOp>(fusableDependence->dependentOpView->getOwner());
   // If producer is already in the same block as consumer, we are done.
-  if (consumer->getBlock() == producerOp->getBlock())
+  if (consumerOpOperand.get().getParentBlock() ==
+      fusableDependence->dependentOpView->get().getParentBlock())
     return {};
 
   unsigned producerIdx =
       fusableDependence->dependentOpView->getOperandNumber() -
       producerOp.getNumInputs();
-  Value consumerView = consumer.getShapedOperand(consumerIdx);
 
   // Must be a subview or a slice to guarantee there are loops we can fuse
   // into.
-  auto subView = consumerView.getDefiningOp<SubViewOp>();
-  auto slice = consumerView.getDefiningOp<SliceOp>();
+  auto subView = consumerOpOperand.get().getDefiningOp<SubViewOp>();
+  auto slice = consumerOpOperand.get().getDefiningOp<SliceOp>();
   if (!subView && !slice) {
     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
     return {};
@@ -455,25 +449,25 @@
 
   // Fuse `producer` just before `consumer`.
   OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(consumer.getOperation());
-  ScopedContext scope(b, consumer.getLoc());
-  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
+  b.setInsertionPoint(consumerOpOperand.getOwner());
+  ScopedContext scope(b, consumerOpOperand.getOwner()->getLoc());
+  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: "
+                          << *consumerOpOperand.getOwner() << "\n");
 
-  auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
+  auto fusedProducer = fuse(b, producerOp, producerIdx, consumerOpOperand);
   return FusionInfo{producerOp, fusedProducer};
 }
 
 /// Walk back use-def chain through scf::For yields.
-/// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
-static void getProducerOfTensor(Value tensor, LinalgOp &producer,
-                                unsigned &outputIndex) {
+/// Sets `opResult` to the producing OpResult if it finds a producer LinalgOp.
+static void getProducerOfTensor(Value tensor, OpResult &opResult) {
   if (!tensor.getType().isa<RankedTensorType>())
     return;
 
   while (true) {
+    LLVM_DEBUG(llvm::dbgs() << "\ngetProducerOfTensor: " << tensor);
     if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
-      producer = linalgOp;
-      outputIndex = tensor.cast<OpResult>().getResultNumber();
+      opResult = tensor.cast<OpResult>();
       return;
     }
     if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
@@ -482,7 +476,7 @@
     }
     if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
       if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
-        tensor = forOp.getResult(blockArg.getArgNumber());
+        tensor = *(forOp.getIterOperands().begin() + blockArg.getArgNumber());
         continue;
       }
     }
@@ -490,45 +484,58 @@
   }
 }
 
-Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
-                                                        LinalgOp consumer,
-                                                        unsigned consumerIdx) {
-  Value inputTensor = consumer.getInput(consumerIdx);
-  LinalgOp producerOp;
-  unsigned producerIdx;
-  getProducerOfTensor(inputTensor, producerOp, producerIdx);
+Optional<FusionInfo>
+mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
+  Value inputTensor = consumerOpOperand.get();
+  OpResult producerOpResult;
+  getProducerOfTensor(inputTensor, producerOpResult);
+  if (!producerOpResult) {
+    LLVM_DEBUG(llvm::dbgs() << "\nUnable to find producer");
+    return {};
+  }
+  return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
+}
+
+Optional<FusionInfo>
+mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
+                                   OpOperand &consumerOpOperand) {
+  auto producerOp = dyn_cast<LinalgOp>(producerOpResult.getOwner());
+  assert(producerOp && "expected Linalg producer");
+  LinalgOp consumerOp = cast<LinalgOp>(consumerOpOperand.getOwner());
+  Value inputTensor = consumerOpOperand.get();
 
   // Must be a subtensor to guarantee there are loops we can fuse into.
   auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
-  if (!subTensor || !producerOp) {
-    LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
+  if (!subTensor) {
+    LLVM_DEBUG(llvm::dbgs()
               << "\nNot fusable, not a subtensor: " << inputTensor);
    return {};
  }

  // If producer is already in the same block as consumer, we are done.
-  if (consumer->getBlock() == producerOp->getBlock())
+  if (consumerOpOperand.get().getParentBlock() ==
+      producerOpResult.getParentBlock())
    return {};

  // Insert fused `producer` just before `consumer`.
   OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(consumer.getOperation());
-  ScopedContext scope(b, consumer.getLoc());
-  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
-  LinalgOp fusedProducer =
-      fuse(b, producerOp, producerIdx, consumer, consumerIdx);
+  b.setInsertionPoint(consumerOp);
+  ScopedContext scope(b, consumerOp->getLoc());
+  LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
+  LinalgOp fusedProducer = fuse(
+      b, producerOp, producerOpResult.getResultNumber(), consumerOpOperand);
 
   // Replace use.
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation
   // invariant that types are compatible.
-  Value def = fusedProducer->getResult(producerIdx);
-  OpOperand &use = consumer->getOpOperand(consumerIdx);
-  Type consumerType = use.get().getType();
+  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
+  Type consumerType = consumerOpOperand.get().getType();
   if (consumerType != def.getType())
     def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
-  use.set(def);
-  return FusionInfo{producerOp, fusedProducer};
+  consumerOpOperand.set(def);
+  return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer};
 }
 
 /// Prune all dimensions that are of reduction iterator type from `map`.
@@ -734,11 +741,9 @@
   // in the meanwhile disallow such a fusion.
   DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
   for (LinalgOp op : reverse(ops)) {
-    for (auto operandIndex :
-         llvm::seq<unsigned>(0, op.getNumShapedOperands())) {
+    for (OpOperand &opOperand : op.getShapedOpOperands()) {
       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-          fusableDependence =
-              findFusableProducer(op, operandIndex, dependenceGraph);
+          fusableDependence = findFusableProducer(opOperand, dependenceGraph);
       if (!fusableDependence)
         continue;
       LinalgOp producerOp =
@@ -759,7 +764,7 @@
         op.emitRemark(
             "unhandled non permutation indexing map for fused view in "
             "producer for operand at index ")
-            << operandIndex;
+            << opOperand.getOperandNumber();
         return FusableOpDependencesTy{};
       }
@@ -770,7 +775,7 @@
         op.emitRemark(
            "unhandled case where indexing map for fused view in the consumer "
            "is not a projected permutation while fusing at index ")
-            << operandIndex;
+            << opOperand.getOperandNumber();
        return FusableOpDependencesTy{};
      }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -178,8 +178,10 @@
 }
 
 static Optional<SmallVector<Value, 1>>
-fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx,
+fuseTensorOpsImpl(LinalgOp producer, OpOperand &consumerOpOperand,
                   PatternRewriter &rewriter) {
+  LinalgOp consumer = cast<LinalgOp>(consumerOpOperand.getOwner());
+  unsigned consumerIdx = consumerOpOperand.getOperandNumber();
   if (!areTensorOpsFusable(producer, consumer, consumerIdx))
     return llvm::None;
 
@@ -1027,21 +1029,19 @@
 } // namespace
 
 Optional<SmallVector<Value, 1>>
-mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
-                            unsigned consumerIdx) {
-  if (consumerIdx >= consumer->getNumOperands())
-    return llvm::None;
-  Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
+mlir::linalg::fuseTensorOps(PatternRewriter &rewriter,
+                            OpOperand &consumerOpOperand) {
+  Operation *producer = consumerOpOperand.get().getDefiningOp();
   if (!producer || producer->getNumResults() != 1)
     return llvm::None;
 
   // Fuse when consumer is GenericOp or IndexedGenericOp.
-  if (!isa<GenericOp, IndexedGenericOp>(consumer) ||
+  if (!isa<GenericOp, IndexedGenericOp>(consumerOpOperand.getOwner()) ||
       !isa<GenericOp, IndexedGenericOp>(producer))
     return llvm::None;
 
-  return fuseTensorOpsImpl(cast<LinalgOp>(producer), cast<LinalgOp>(consumer),
-                           consumerIdx, rewriter);
+  return fuseTensorOpsImpl(cast<LinalgOp>(producer), consumerOpOperand,
+                           rewriter);
 }
 
 namespace {
@@ -1053,12 +1053,12 @@
   LogicalResult matchAndRewrite(LinalgOpTy op,
                                 PatternRewriter &rewriter) const override {
     // Find the first operand that is defined by another generic op on tensors.
-    for (auto operandNum : llvm::seq<unsigned>(0, op->getNumOperands())) {
-      Operation *producer = op->getOperand(operandNum).getDefiningOp();
+    for (OpOperand &opOperand : op.getShapedOpOperands()) {
+      Operation *producer = opOperand.get().getDefiningOp();
       if (!producer)
         continue;
       Optional<SmallVector<Value, 1>> fusedOpResults =
-          fuseTensorOps(rewriter, op, operandNum);
+          fuseTensorOps(rewriter, opOperand);
       if (fusedOpResults) {
         rewriter.replaceOp(op, *fusedOpResults);
         if (producer->use_empty())
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -117,8 +117,7 @@
       options(options) {}
 
 LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
-    Operation *op, PatternRewriter &rewriter,
-    SmallVectorImpl<Value> &tensorResults) const {
+    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
   LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
   if (!linalgOp)
     return failure();
@@ -131,7 +130,7 @@
     return failure();
 
   // Return relevant information to derived pattern.
-  tensorResults = res->tensorResults;
+  result = *res;
 
   // New marker if specified.
   marker.replaceLinalgMarker(rewriter, res->op.getOperation());
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -135,14 +135,14 @@
   // Tile and Fuse for tensors inputs (TODO: all tensor operands).
   bool changed = false;
   for (LinalgOp linalgOp : llvm::reverse(linalgOps)) {
-    for (auto en : llvm::enumerate(linalgOp.getShapedOperands())) {
-      if (en.value().getType().isa<MemRefType>()) {
+    for (OpOperand &opOperand : linalgOp.getShapedOpOperands()) {
+      if (opOperand.get().getType().isa<MemRefType>()) {
         // TODO: LinalgDependenceGraph should be able to update itself.
         // The current naive and expensive reconstruction of the graph should
         // be removed.
         linalg::Aliases aliases;
         linalg::LinalgDependenceGraph graph(aliases, linalgOps);
-        if (auto info = fuseProducerOfBuffer(b, linalgOp, en.index(), graph)) {
+        if (auto info = fuseProducerOfBuffer(b, opOperand, graph)) {
           auto *originalOp = info->originalProducer.getOperation();
           eraseSet.insert(originalOp);
           auto *originalOpInLinalgOpsVector =
@@ -151,11 +151,11 @@
           changed = true;
         }
       } else {
-        assert(en.value().getType().isa<RankedTensorType>());
-        // Tile and Fuse tensor input (TODO: init_tensors too).
-        if (en.index() >= linalgOp.getNumInputs())
+        assert(opOperand.get().getType().isa<RankedTensorType>());
+        // Tile and Fuse tensor input.
+        if (opOperand.getOperandNumber() >= linalgOp.getNumInputs())
           continue;
-        if (auto info = fuseProducerOfTensor(b, linalgOp, en.index())) {
+        if (auto info = fuseProducerOfTensor(b, opOperand)) {
           auto *originalOp = info->originalProducer.getOperation();
           auto *originalOpInLinalgOpsVector =
               std::find(linalgOps.begin(), linalgOps.end(), originalOp);
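---
Usage note: call sites now identify the operand to fuse by handing over the
OpOperand itself rather than a (LinalgOp, unsigned) pair; the OpOperand carries
both its owning op and its operand number. A minimal migration sketch, not part
of the patch: the names `b`, `consumer`, and `graph` stand in for a builder, a
LinalgOp with buffer semantics, and a prebuilt LinalgDependenceGraph from the
surrounding pass, and the operand index 0 is purely illustrative.

  // Before: the fused operand was addressed by an index on the consumer.
  //   Optional<FusionInfo> info =
  //       fuseProducerOfBuffer(b, consumer, /*consumerIdx=*/0, graph);
  // After: fetch the OpOperand via the new interface method and pass it along.
  OpOperand &operand = consumer.getShapedOpOperand(0);
  if (Optional<FusionInfo> info = fuseProducerOfBuffer(b, operand, graph)) {
    // info->originalProducer and info->fusedProducer are available as before.
  }
  // The tensor path is analogous; the new overload additionally lets a caller
  // pin the producer explicitly through its OpResult:
  //   fuseProducerOfTensor(b, producerOpResult, operand);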