diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -69,10 +69,9 @@
 static ShapeDimension
 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
                           bool fromSubViewOpOnly = false) {
-  auto maps = op.indexing_maps();
   // Iterate over the inputs and outputs in order.
   // Extract the subranges from the linearized ranges.
-  for (auto en : llvm::enumerate(op.getShapedOperands())) {
+  for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
     // The method `getRangeFromOperandShape` requires using SubViewOp or
     // SubTensorOps. If the value isn't defined from there continue.
     // todo: The method should be adapted to get the values from
@@ -80,27 +79,26 @@
     // currently returns a `linalg.range`. The fix here is to move this op to
     // `std` dialect and add the method to `ViewInterface`.
     if (fromSubViewOpOnly &&
         !isa_and_nonnull<memref::SubViewOp, SubTensorOp>(
-            en.value().getDefiningOp()))
+            opOperand->get().getDefiningOp()))
       continue;
-    unsigned idx = en.index();
-    auto map = maps[idx].cast<AffineMapAttr>().getValue();
-    LLVM_DEBUG(llvm::dbgs()
-               << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
+    AffineMap map = op.getTiedIndexingMap(opOperand);
+    LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange I/O idx: "
+                            << opOperand->getOperandNumber() << "\n");
     LLVM_DEBUG(llvm::dbgs()
                << "getShapeDefiningLoopRange map: " << map << "\n");
-    Value shape = en.value();
     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
-    for (auto en2 : llvm::enumerate(map.getResults())) {
-      auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
+    for (auto en : llvm::enumerate(map.getResults())) {
+      auto dimExpr = en.value().dyn_cast<AffineDimExpr>();
       if (!dimExpr)
         continue;
-      if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
+      if (loopDepth == en.value().cast<AffineDimExpr>().getPosition()) {
         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
                                 << loopDepth << "\n");
-        LLVM_DEBUG(llvm::dbgs()
-                   << "getShapeDefiningLoopRange shape: " << shape << "\n");
-        return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
+        LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange shape: "
+                                << opOperand->get() << "\n");
+        return ShapeDimension{opOperand->get(),
+                              static_cast<unsigned>(en.index())};
       }
     }
   }
@@ -122,26 +120,24 @@
 // would need to add the intermediate results to `linalg.yield`. After that a
 // canonicalization pass would move the unused output args of the `tiled_loop`
 // to the `input` section.
-static SmallVector<Value, 4> getTiledOperands(OpBuilder &b, LinalgOp producer) {
+static SmallVector<Value> getTiledOperands(OpBuilder &b, LinalgOp producer) {
   auto tiledLoop = dyn_cast<TiledLoopOp>(b.getBlock()->getParentOp());
   if (!tiledLoop)
-    return llvm::to_vector<4>(producer.getShapedOperands());
+    return producer.getInputAndOutputOperands();
 
-  SmallVector<Value, 4> tiledOperands;
+  SmallVector<Value> tiledOperands;
   assert(producer.hasTensorSemantics() &&
          "only fusion on tensors is currently supported for TiledLinalgOp");
 
-  for (auto producerInput : producer.getInputTensors()) {
-    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput);
+  for (OpOperand *producerInput : producer.getInputTensorOperands()) {
+    OpOperand *addedInput = tiledLoop.findInputOperand(producerInput->get());
     if (addedInput == nullptr)
-      addedInput = &tiledLoop.appendInputOperand(b, producerInput);
+      addedInput = &tiledLoop.appendInputOperand(b, producerInput->get());
     BlockArgument addedBlockArg = tiledLoop.getTiedBlockArgument(*addedInput);
     tiledOperands.push_back(addedBlockArg);
   }
-  for (auto &en : llvm::enumerate(producer.getOutputTensors())) {
-    Value producerOutput = en.value();
-
-    Value result = producer->getResult(en.index());
+  for (OpOperand *producerOutput : producer.getOutputTensorOperands()) {
+    OpResult result = producer.getTiedOpResult(producerOutput);
     OpOperand *resultInputOperand = tiledLoop.findInputOperand(result);
     OpOperand *resultOutputOperand = tiledLoop.findOutputOperand(result);
     assert((resultInputOperand != nullptr) ^ (resultOutputOperand != nullptr) &&
@@ -152,10 +148,11 @@
     int opNumber = isInput ? resultInputOperand->getOperandNumber()
                            : resultOutputOperand->getOperandNumber();
 
-    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput);
+    OpOperand *addedOutput = tiledLoop.findOutputOperand(producerOutput->get());
     if (addedOutput == nullptr)
-      addedOutput = isInput ? &tiledLoop.appendInputOperand(b, producerOutput)
-                            : &tiledLoop.appendOutputOperand(b, producerOutput);
+      addedOutput =
+          isInput ? &tiledLoop.appendInputOperand(b, producerOutput->get())
+                  : &tiledLoop.appendOutputOperand(b, producerOutput->get());
 
     OpOperand &resultOperand = tiledLoop->getOpOperand(opNumber);
     auto addedBlockArg = tiledLoop.getTiedBlockArgument(*addedOutput);
@@ -200,7 +197,7 @@
   }
 
   SmallVector<Value, 8> clonedShapes;
-  clonedShapes.reserve(producer.getNumShapedOperands());
+  clonedShapes.reserve(producer.getNumInputsAndOutputs());
 
   // Compute subranges for all tensor input/output operands.
   clonedShapes.append(makeTiledShapes(b, loc, producer,
@@ -267,16 +264,9 @@
   llvm_unreachable("SubviewOp or SubTensorOp expected");
 }
 
-/// Fuses the producer of `producerIdx` into the loop immediately enclosing
-/// `consumer`. This is achieved by "recomputing" the `producer` at the time it
-/// is needed just before the `consumer.
-///
-/// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
-/// 2 cases:
-///   1. Buffer case: `producerIdx` is the index of the buffer in
-///      `producer.getOutputBuffers()`.
-///   2. Tensor case: `producerIdx` is the index of the tensor in
-///      `producer.getResults()`.
+/// Fuses the producer into the loop immediately enclosing the consumer.
+/// This is achieved by "recomputing" the producer at the time it
+/// is needed just before the consumer.
 static LinalgOp fuse(OpBuilder &b, LinalgOp producerOp, AffineMap producerMap,
                      OpOperand &consumerOpOperand) {
   LLVM_DEBUG(llvm::dbgs() << "Producer map: " << producerMap << "\n");
@@ -548,9 +538,10 @@
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(consumerOp);
   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumerOp << "\n");
+  OpOperand *opOperand =
+      producerOp.getOutputOperand(producerOpResult.getResultNumber());
   LinalgOp fusedProducer =
-      fuse(b, producerOp,
-           producerOp.getOutputIndexingMap(producerOpResult.getResultNumber()),
+      fuse(b, producerOp, producerOp.getTiedIndexingMap(opOperand),
           consumerOpOperand);
 
   // Replace use.
@@ -770,9 +761,9 @@
   FusableOpDependencesTy fusableDependences;
   DenseMap<Operation *, SmallVector<AffineMap, 1>> fusedProducerIndexingMap;
   for (LinalgOp op : reverse(ops)) {
-    for (OpOperand &opOperand : op.getShapedOpOperands()) {
+    for (OpOperand *opOperand : op.getInputAndOutputOperands()) {
       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
-          fusableDependence = findFusableProducer(opOperand, dependenceGraph);
+          fusableDependence = findFusableProducer(*opOperand, dependenceGraph);
       if (!fusableDependence)
         continue;
       // Canonicalize indexed generic ops before fusion.
@@ -905,10 +896,16 @@
     // To keep the second type of information while letting the unfused op die
     // unused, we need to forward the producer output operand.
     if (auto forOp = dyn_cast<scf::ForOp>(tiledLinalgOp.loops.front())) {
-      for (auto &operand : forOp.getIterOpOperands())
-        if (auto opResult = operand.get().dyn_cast<OpResult>())
-          if (opResult.getOwner() == origOp)
-            operand.set(origOp.getOutputTensors()[opResult.getResultNumber()]);
+      for (auto &operand : forOp.getIterOpOperands()) {
+        if (auto opResult = operand.get().dyn_cast<OpResult>()) {
+          if (opResult.getOwner() == origOp) {
+            Value output =
+                origOp.getOutputOperand(opResult.getResultNumber())->get();
+            assert(output.getType().isa<RankedTensorType>());
+            operand.set(output);
+          }
+        }
+      }
     }
   }
   return fusedOps;
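
Note: throughout, the patch swaps Value-based accessors (getShapedOperands, getOutputTensors, getOutputIndexingMap, ...) for the OpOperand*-based interface (getInputAndOutputOperands, getTiedIndexingMap, getTiedOpResult, ...), so each operand carries its operand number and tied indexing map instead of relying on positional zipping. A minimal sketch of the resulting traversal pattern, using only interface methods that appear in the hunks above; the helper function itself is illustrative and not part of the patch:

    // Illustrative only: walk every input/output operand of a LinalgOp and
    // query the indexing map tied to it, rather than zipping indexing_maps()
    // with positional indices as the removed code did.
    #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
    #include "llvm/Support/raw_ostream.h"

    static void dumpTiedIndexingMaps(mlir::linalg::LinalgOp linalgOp) {
      for (mlir::OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) {
        mlir::AffineMap map = linalgOp.getTiedIndexingMap(opOperand);
        llvm::errs() << "operand #" << opOperand->getOperandNumber()
                     << " indexing map: " << map << "\n";
      }
    }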