diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -28,7 +28,7 @@ /// Conditions for elementwise fusion of generic operations. static bool areElementwiseOpsFusable(GenericOp producer, GenericOp consumer, - unsigned consumerIdx) { + OpOperand *consumerOpOperand) { // Producer and consumer must have tensor semantics. if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics()) return false; @@ -40,12 +40,12 @@ // Only allow fusing the producer of an input operand for now. // TODO: allow fusing the producer of an output operand. - if (consumerIdx >= consumer.getNumInputs()) + if (!consumer.isInputTensor(consumerOpOperand)) return false; // Get the consumer index map. The number of results of the consumer index // map must match the number of loops of the producer. - AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx); + AffineMap consumerIndexMap = consumer.getTiedIndexingMap(consumerOpOperand); if (consumerIndexMap.getNumResults() != producer.getNumLoops()) return false; @@ -55,7 +55,8 @@ // Finally the index_map for the result must be invertible. For now just // verify it is a permutation. - AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); + AffineMap producerResultIndexMap = + producer.getTiedIndexingMap(producer.getOutputOperand(0)); return producerResultIndexMap.isPermutation(); } @@ -63,7 +64,7 @@ /// the `producer` to use in the fused operation given the indexing map of the /// result of the producer in the consumer. 
static AffineMap getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( - OpOperand &producerOpOperand, AffineMap producerResultIndexMap, + OpOperand *producerOpOperand, AffineMap producerResultIndexMap, AffineMap fusedConsumerArgIndexMap) { // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map // from consumer loop -> consumer arg tensor index/producer result tensor @@ -78,10 +79,9 @@ assert(invProducerResultIndexMap && "expected producer result indexig map to be invertible"); - LinalgOp producer = cast(producerOpOperand.getOwner()); + LinalgOp producer = cast(producerOpOperand->getOwner()); // argMap is a map from producer loop -> producer arg tensor index. - AffineMap argMap = - producer.getIndexingMap(producerOpOperand.getOperandNumber()); + AffineMap argMap = producer.getTiedIndexingMap(producerOpOperand); // Compose argMap with invProducerResultIndexMap to get a map from // producer result tensor index -> producer arg tensor index. @@ -94,11 +94,10 @@ /// Generate the region of the fused tensor operation. The region of the fused /// op must be empty. -static void -generateFusedElementwiseOpRegion(PatternRewriter &rewriter, GenericOp fusedOp, - GenericOp producer, GenericOp consumer, - AffineMap consumerToProducerLoopsMap, - unsigned consumerIdx, unsigned nloops) { +static void generateFusedElementwiseOpRegion( + PatternRewriter &rewriter, GenericOp fusedOp, GenericOp producer, + GenericOp consumer, AffineMap consumerToProducerLoopsMap, + OpOperand *consumerOpOperand, unsigned nloops) { // Build the region of the fused op. Block &producerBlock = producer->getRegion(0).front(); Block &consumerBlock = consumer->getRegion(0).front(); @@ -129,11 +128,11 @@ } } // TODO: allow fusing the producer of an output operand. - assert(consumerIdx < consumer.getNumInputs() && + assert(consumer.isInputTensor(consumerOpOperand) && "expected producer of input operand"); // 3. Consumer input operands up to consumerIdx (exclusive). 
for (BlockArgument bbArg : consumerBlock.getArguments().take_front( - consumerIdx)) // input assumption. + consumerOpOperand->getOperandNumber())) // input assumption. mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); // Replacing consumerIdx requires getting the cloned, yielded, value from @@ -147,7 +146,7 @@ // 4.b. Producer output operand/map that is fused needs to be mapped to the // producer bbArg if it is an "initTensor" (i.e. its value is actually read). assert(producer->getNumResults() == 1 && "expected single result producer"); - if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) { + if (producer.isInitTensor(producer.getOutputOperand(0))) { BlockArgument bbArg = producerBlock.getArguments() .drop_front(producer.getNumInputs()) // TODO: bbArg index of @@ -155,9 +154,10 @@ mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); } // 5. Remaining consumer's input operands (drop past index `consumerIdx`). - for (BlockArgument bbArg : consumerBlock.getArguments() - .take_front(consumer.getNumInputs()) - .drop_front(consumerIdx + 1)) + for (BlockArgument bbArg : + consumerBlock.getArguments() + .take_front(consumer.getNumInputs()) + .drop_front(consumerOpOperand->getOperandNumber() + 1)) mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType())); // 6. All of consumer's output operands. for (BlockArgument bbArg : @@ -191,7 +191,8 @@ assert(!producer->isAncestor(replacement.getDefiningOp()) && "yielded value must have been mapped"); } - mapper.map(consumerBlock.getArgument(consumerIdx), replacement); + mapper.map(consumerBlock.getArgument(consumerOpOperand->getOperandNumber()), + replacement); // 10. Clone operations from the consumer to the fused op. 
for (auto &op : consumerBlock.getOperations()) rewriter.clone(op, mapper); @@ -202,17 +203,16 @@ } static Optional> -fuseElementwiseOpsImpl(GenericOp producer, OpOperand &consumerOpOperand, +fuseElementwiseOpsImpl(GenericOp producer, OpOperand *consumerOpOperand, const ControlElementwiseOpsFusionFn &controlFn, PatternRewriter &rewriter) { - auto consumer = cast(consumerOpOperand.getOwner()); - unsigned consumerIdx = consumerOpOperand.getOperandNumber(); - if (!areElementwiseOpsFusable(producer, consumer, consumerIdx) || - !controlFn(producer->getResult(0), consumerOpOperand)) + auto consumer = cast(consumerOpOperand->getOwner()); + if (!areElementwiseOpsFusable(producer, consumer, consumerOpOperand) || + !controlFn(producer->getResult(0), *consumerOpOperand)) return llvm::None; // TODO: allow fusing the producer of an output operand. - assert(consumerIdx < consumer.getNumInputs() && + assert(consumer.isInputTensor(consumerOpOperand) && "expected producer of input operand"); // Compute the fused operands list and indexing maps. @@ -224,62 +224,66 @@ consumer->getNumOperands()); // In the following, numbering matches that of `generateFusedTensorOpRegion`. // 3. Consumer input operands/maps up to consumerIdx (exclusive). - llvm::append_range(fusedOperands, - consumer.getInputs().take_front(consumerIdx)); - llvm::append_range( - fusedIndexMaps, - ArrayRef{consumer.getInputIndexingMaps()}.take_front( - consumerIdx)); + SmallVector consumerInputs = consumer.getInputOperands(); + SmallVector::iterator it = + llvm::find(consumerInputs, consumerOpOperand); + assert(it != consumerInputs.end() && "expected to find the consumer operand"); + for (OpOperand *opOperand : llvm::make_range(consumerInputs.begin(), it)) { + fusedOperands.push_back(opOperand->get()); + fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); + } // 4. Splice in producer's input operands/maps. 
- llvm::append_range(fusedOperands, producer.getInputs()); assert(producer->getNumResults() == 1 && "expected single result producer"); - AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0); - for (auto &inputOpOperand : producer.getInputOpOperands()) { + AffineMap producerResultIndexMap = + producer.getTiedIndexingMap(producer.getOutputOperand(0)); + for (OpOperand *opOperand : producer.getInputOperands()) { + fusedOperands.push_back(opOperand->get()); // Compute indexing maps for the producer args in the fused operation. AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( - inputOpOperand, producerResultIndexMap, - consumer.getInputIndexingMap(consumerIdx)); + opOperand, producerResultIndexMap, + consumer.getTiedIndexingMap(consumerOpOperand)); fusedIndexMaps.push_back(map); } // 4.b. Producer output operand/map that is fused needs to be passed if it is // an "initTensor" (i.e. its value is actually read). assert(producer->getNumResults() == 1 && "expected single result producer"); - if (producer.isInitTensor(&producer.getOutputOpOperands()[0])) { - llvm::append_range(fusedOperands, producer.getOutputs().take_front()); + if (producer.isInitTensor(producer.getOutputOperand(0))) { + fusedOperands.push_back(producer.getOutputOperand(0)->get()); // Compute indexing maps for the producer args in the fused operation. AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( - producer.getOutputOpOperands().front(), producerResultIndexMap, - consumer.getOutputIndexingMap(0)); + producer.getOutputOperand(0), producerResultIndexMap, + consumer.getTiedIndexingMap(consumerOpOperand)); fusedIndexMaps.push_back(map); } // 5. Remaining consumer's input operands/maps (drop past index // `consumerIdx`). 
- llvm::append_range(fusedOperands, - consumer.getInputs().drop_front(consumerIdx + 1)); - llvm::append_range( - fusedIndexMaps, - ArrayRef{consumer.getInputIndexingMaps()}.drop_front( - consumerIdx + 1)); + for (OpOperand *opOperand : + llvm::make_range(std::next(it), consumerInputs.end())) { + fusedOperands.push_back(opOperand->get()); + fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); + } // 6. All of consumer's output operands (skip operands: added by the builder). - // llvm::append_range(fusedOperands, consumer.getOutputs()); - llvm::append_range(fusedIndexMaps, consumer.getOutputIndexingMaps()); + for (OpOperand *opOperand : consumer.getOutputOperands()) + fusedIndexMaps.push_back(consumer.getTiedIndexingMap(opOperand)); // 7. All of producer's output operands/maps except the one fused. // TODO: allow fusion of multi-result producers. assert(producer->getNumResults() == 1 && "expected single result producer"); // Generate the fused op. + SmallVector consumerOutputs = consumer.getOutputOperands(); auto fusedOp = rewriter.create( consumer.getLoc(), consumer->getResultTypes(), /*inputs=*/fusedOperands, // TODO: handle outputs. - consumer.getOutputs(), rewriter.getAffineMapArrayAttr(fusedIndexMaps), + consumerOutputs, rewriter.getAffineMapArrayAttr(fusedIndexMaps), consumer.iterator_types(), /*doc=*/nullptr, /*library_call=*/nullptr); // Construct an AffineMap from consumer loops to producer loops. 
// consumer loop -> tensor index - AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx); + AffineMap consumerResultIndexMap = + consumer.getTiedIndexingMap(consumerOpOperand); // tensor index -> producer loop AffineMap invProducerResultIndexMap = inversePermutation(producerResultIndexMap); @@ -290,8 +294,8 @@ invProducerResultIndexMap.compose(consumerResultIndexMap); generateFusedElementwiseOpRegion(rewriter, fusedOp, producer, consumer, - consumerToProducerLoopsMap, consumerIdx, - consumer.getNumLoops()); + consumerToProducerLoopsMap, + consumerOpOperand, consumer.getNumLoops()); return SmallVector(fusedOp->getResults()); } @@ -449,7 +453,7 @@ /// The added reshapes are again expanding patterns, so they will get fused /// with its producers if possible. static bool isFusableWithReshapeByDimExpansion(GenericOp genericOp, - unsigned fusedTensorIndex) { + OpOperand *fusableOpOperand) { // Is fusable only if: // - All the indexing maps for operands and results are projected // permutations. @@ -462,7 +466,7 @@ .getValue() .isProjectedPermutation(); }) && - genericOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 && + genericOp.getTiedIndexingMap(fusableOpOperand).getNumResults() > 0 && llvm::all_of(genericOp.iterator_types(), [](Attribute attr) { return attr.cast().getValue() == getParallelIteratorTypeName(); @@ -478,7 +482,7 @@ // of the expanded op given the `indexingMap` of the fused operand/result of // the generic op, the `reassocationMaps` of the reshape op and the shape of // the expanded op. 
- LogicalResult compute(LinalgOp linalgOp, unsigned fusedTensorIndex, + LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand, ArrayRef reassociationMaps, ArrayRef expandedShape, PatternRewriter &rewriter); @@ -503,13 +507,13 @@ } // namespace LogicalResult ExpansionInfo::compute(LinalgOp linalgOp, - unsigned fusedTensorIndex, + OpOperand *fusableOpOperand, ArrayRef reassociationMaps, ArrayRef expandedShape, PatternRewriter &rewriter) { if (reassociationMaps.empty()) return failure(); - AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex); + AffineMap fusedIndexMap = linalgOp.getTiedIndexingMap(fusableOpOperand); Optional> originalLoopRange = linalgOp.getStaticLoopRanges(); @@ -676,9 +680,9 @@ /// been satisfied. static Optional> fuseWithReshapeByExpansion(GenericOp genericOp, TensorReshapeOp reshapeOp, - unsigned fusedTensorIndex, + OpOperand *fusableOpOperand, PatternRewriter &rewriter) { - assert(isFusableWithReshapeByDimExpansion(genericOp, fusedTensorIndex) && + assert(isFusableWithReshapeByDimExpansion(genericOp, fusableOpOperand) && "preconditions for fuse operation failed"); // Check if reshape is expanding or collapsing. bool isExpanding = @@ -687,7 +691,7 @@ isExpanding ? 
reshapeOp.getResultType() : reshapeOp.getSrcType(); ExpansionInfo expansionInfo; - if (failed(expansionInfo.compute(genericOp, fusedTensorIndex, + if (failed(expansionInfo.compute(genericOp, fusableOpOperand, reshapeOp.getReassociationMaps(), expandedType.getShape(), rewriter))) return llvm::None; @@ -701,39 +705,39 @@ })); SmallVector expandedOpOperands; - for (auto operand : llvm::enumerate(genericOp.getInputs())) { - if (operand.index() == fusedTensorIndex) { + for (OpOperand *opOperand : genericOp.getInputOperands()) { + if (opOperand == fusableOpOperand) { expandedOpOperands.push_back(reshapeOp.src()); continue; } - AffineMap indexingMap = genericOp.getInputIndexingMap(operand.index()); + AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); RankedTensorType expandedOperandType = - getExpandedType(operand.value().getType().cast(), + getExpandedType(opOperand->get().getType().cast(), indexingMap, expansionInfo); - if (expandedOperandType != operand.value().getType()) { + if (expandedOperandType != opOperand->get().getType()) { // Reshape the operand to get the right type. 
SmallVector reassociation = getReassociationForExpansion(indexingMap, expansionInfo); expandedOpOperands.push_back(rewriter.create( - genericOp.getLoc(), expandedOperandType, operand.value(), + genericOp.getLoc(), expandedOperandType, opOperand->get(), reassociation)); continue; } - expandedOpOperands.push_back(operand.value()); + expandedOpOperands.push_back(opOperand->get()); } Location loc = genericOp.getLoc(); SmallVector outputs; - for (auto result : llvm::enumerate(genericOp.getOutputs())) { - AffineMap indexingMap = genericOp.getOutputIndexingMap(result.index()); + for (OpOperand *opOperand : genericOp.getOutputOperands()) { + AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); RankedTensorType expandedOutputType = - getExpandedType(result.value().getType().cast(), + getExpandedType(opOperand->get().getType().cast(), indexingMap, expansionInfo); - if (expandedOutputType != result.value().getType()) { + if (expandedOutputType != opOperand->get().getType()) { SmallVector reassociation = getReassociationForExpansion(indexingMap, expansionInfo); outputs.push_back(rewriter.create( - genericOp.getLoc(), expandedOutputType, result.value(), + genericOp.getLoc(), expandedOutputType, opOperand->get(), reassociation)); } } @@ -757,17 +761,19 @@ // Reshape the result values to their original shape if this is a collapsing // reshape folded into its consumer. 
SmallVector resultVals; - for (auto result : llvm::enumerate(genericOp->getResults())) { - if (!isExpanding && - resultTypes[result.index()] != result.value().getType()) { + for (OpResult opResult : genericOp->getOpResults()) { + int64_t resultNumber = opResult.getResultNumber(); + if (!isExpanding && resultTypes[resultNumber] != opResult.getType()) { SmallVector reassociation = getReassociationForExpansion( - genericOp.getOutputIndexingMap(result.index()), expansionInfo); + genericOp.getTiedIndexingMap( + genericOp.getOutputOperand(resultNumber)), + expansionInfo); resultVals.push_back(rewriter.create( - genericOp.getLoc(), result.value().getType(), - fusedOp->getResult(result.index()), reassociation)); + genericOp.getLoc(), opResult.getType(), + fusedOp->getResult(resultNumber), reassociation)); } else { - resultVals.push_back(fusedOp->getResult(result.index())); + resultVals.push_back(fusedOp->getResult(resultNumber)); } } // Assuming a single result. @@ -809,12 +815,12 @@ PatternRewriter &rewriter) const override { if (!genericOp.hasTensorSemantics()) return failure(); - for (auto operand : llvm::enumerate(genericOp.getInputs())) { + for (auto en : llvm::enumerate(genericOp.getInputOperands())) { TensorReshapeOp reshapeOp = - operand.value().getDefiningOp(); + en.value()->get().getDefiningOp(); if (!reshapeOp || !isTensorReshapeOpFoldableByLinearization( - reshapeOp, genericOp.getInputIndexingMap(operand.index()), + reshapeOp, genericOp.getTiedIndexingMap(en.value()), /*asProducer =*/true) || (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(), @@ -822,18 +828,17 @@ continue; // Compute the fused operands list, - SmallVector fusedOperands(genericOp.getInputs()); - fusedOperands[operand.index()] = reshapeOp.src(); - fusedOperands.append(genericOp.getOutputs().begin(), - genericOp.getOutputs().end()); + SmallVector fusedOperands = genericOp.getInputOperands(); + fusedOperands[en.index()] = reshapeOp.src(); + SmallVector 
outputOperands = genericOp.getOutputOperands(); + llvm::append_range(fusedOperands, outputOperands); // Compute indexing_maps for the fused operation. The indexing_maps for // the operands of the consumers that arent fused are the same. - SmallVector fusedIndexMaps = llvm::to_vector<4>( - genericOp.indexing_maps().template getAsValueRange()); + SmallVector fusedIndexMaps = genericOp.getIndexingMaps(); // Accepted consumer maps are either identity or permutation. - auto invMap = inversePermutation(fusedIndexMaps[operand.index()]); + auto invMap = inversePermutation(fusedIndexMaps[en.index()]); // Compute the indexing map to use for the result of the producer. AffineMap modifiedMap = @@ -843,7 +848,7 @@ if (!expr.isPureAffine()) return failure(); } - fusedIndexMaps[operand.index()] = modifiedMap; + fusedIndexMaps[en.index()] = modifiedMap; // Further check that the resulting index maps can be fused and // inverted. Without this the resultant op is not legal. @@ -917,35 +922,35 @@ return failure(); // Only support identity output maps. It could be extended to permuations if // needed. - if (llvm::any_of(genericOp.getOutputIndexingMaps(), - [](AffineMap map) { return !map.isIdentity(); })) + if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) { + return !genericOp.getTiedIndexingMap(opOperand).isIdentity(); + })) return failure(); int64_t destRank = genericOp.getNumParallelLoops(); - SmallVector newOperands = - llvm::to_vector<4>(genericOp.getInputs()); + SmallVector newOperands = genericOp.getInputOperands(); TensorReshapeOp reshapeFound; // 1. Look for tensor_reshape operands and figure out save the dimensions // merged. 
- for (auto operand : llvm::enumerate(genericOp.getInputs())) { + for (auto en : llvm::enumerate(genericOp.getInputOperands())) { TensorReshapeOp reshapeOp = - operand.value().template getDefiningOp(); + en.value()->get().template getDefiningOp(); if (!reshapeOp || reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank()) { continue; } // TODO: We could support non-identity map as long as the merged // dimensions are still contiguous. - if (!genericOp.getIndexingMaps()[operand.index()].isIdentity()) + if (!genericOp.getTiedIndexingMap(en.value()).isIdentity()) continue; if (reshapeFound) { // Only support a second reshape op if it has the same reassociate maps. if (reshapeFound.getReassociationMaps() == reshapeOp.getReassociationMaps()) - newOperands[operand.index()] = reshapeOp.src(); + newOperands[en.index()] = reshapeOp.src(); continue; } reshapeFound = reshapeOp; - newOperands[operand.index()] = reshapeOp.src(); + newOperands[en.index()] = reshapeOp.src(); } if (!reshapeFound) return failure(); @@ -962,9 +967,9 @@ // 2. Verify that we can merge the dimensions in the linalg and that we // don't need to create new reshapes operands. Inserting new reshape // operands would defeat the purpose of the transformation. 
- for (auto operand : llvm::enumerate(genericOp.getInputs())) { - if (operand.value() == newOperands[operand.index()]) { - AffineMap map = genericOp.getIndexingMaps()[operand.index()]; + for (auto en : llvm::enumerate(genericOp.getInputOperands())) { + if (en.value()->get() == newOperands[en.index()]) { + AffineMap map = genericOp.getTiedIndexingMap(en.value()); for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) { if (reassociation[remap[map.getDimPosition(i)]].size() > 1) return failure(); @@ -1036,9 +1041,9 @@ LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - for (auto operand : llvm::enumerate(genericOp.getInputs())) { + for (OpOperand *opOperand : genericOp.getInputOperands()) { TensorReshapeOp reshapeOp = - operand.value().getDefiningOp(); + opOperand->get().getDefiningOp(); if (!reshapeOp) continue; // Fold only if @@ -1046,15 +1051,12 @@ // - All constraints of fusing with reshape by expansion are met. if (reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank() || - !isFusableWithReshapeByDimExpansion(genericOp, operand.index()) || - (!controlFoldingReshapes( - reshapeOp->getResult(0), - genericOp.getInputOpOperands()[operand.index()]))) + !isFusableWithReshapeByDimExpansion(genericOp, opOperand) || + (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand))) continue; Optional> replacementValues = - fuseWithReshapeByExpansion(genericOp, reshapeOp, operand.index(), - rewriter); + fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter); if (!replacementValues) return failure(); rewriter.replaceOp(genericOp, replacementValues.getValue()); @@ -1080,7 +1082,8 @@ if (!producer || !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 || !isTensorReshapeOpFoldableByLinearization( - reshapeOp, producer.getOutputIndexingMap(0), + reshapeOp, + producer.getTiedIndexingMap(producer.getOutputOperand(0)), /*asProducer =*/false) || (foldUnitDimReshapesOnly && 
!isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(), @@ -1088,10 +1091,10 @@ return failure(); // The indexing_maps for the operands of the fused operation are same as // those for the operands of the producer. - SmallVector fusedIndexMaps = llvm::to_vector<4>( - producer.indexing_maps().getAsValueRange()); + SmallVector fusedIndexMaps = producer.getIndexingMaps(); - auto invMap = inversePermutation(producer.getOutputIndexingMap(0)); + auto invMap = inversePermutation( + producer.getTiedIndexingMap(producer.getOutputOperand(0))); // Compute the indexing map to use for the operand of the producer. AffineMap modifiedMap = @@ -1113,11 +1116,13 @@ } Location loc = producer.getLoc(); + SmallVector inputOperands = producer.getInputOperands(); Value output = rewriter.create( - loc, producer.getOutputs()[0], reshapeOp.getReassociationExprs()); + loc, producer.getOutputOperand(0)->get(), + reshapeOp.getReassociationExprs()); auto fusedOp = rewriter.create( loc, reshapeOp.getResultType(), - /*inputs=*/producer.getInputs(), + /*inputs=*/inputOperands, // TODO: handle outputs. 
/*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps), producer.iterator_types(), @@ -1147,12 +1152,12 @@ GenericOp producer = reshapeOp.src().getDefiningOp(); if (!producer || producer.getNumOutputs() != 1 || !isFusableWithReshapeByDimExpansion(producer, - producer.getNumInputs()) || + producer.getOutputOperand(0)) || isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(), reshapeOp.getReassociationMaps())) return failure(); Optional> replacementValues = fuseWithReshapeByExpansion( - producer, reshapeOp, producer.getNumInputs(), rewriter); + producer, reshapeOp, producer.getOutputOperand(0), rewriter); if (!replacementValues) return failure(); rewriter.replaceOp(reshapeOp, replacementValues.getValue()); @@ -1171,21 +1176,29 @@ PatternRewriter &rewriter) const override { if (!genericOp.hasTensorSemantics()) return failure(); - for (auto operand : llvm::enumerate(genericOp.getInputOpOperands())) { - Operation *def = operand.value().get().getDefiningOp(); + for (OpOperand *opOperand : genericOp.getInputOperands()) { + Operation *def = opOperand->get().getDefiningOp(); DenseElementsAttr constantAttr; if (!def || !matchPattern(def, m_Constant(&constantAttr)) || - !constantAttr.isSplat() || - !controlFn(def->getResult(0), operand.value())) + !constantAttr.isSplat() || !controlFn(def->getResult(0), *opOperand)) continue; - // The indexing_maps for the operands of the fused operation are same as - // those for the operands of the genericOp without the indexing map at - // operand.index() - SmallVector fusedIndexMaps = llvm::to_vector<4>( - genericOp.indexing_maps().getAsValueRange()); - fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index())); + // The operands and the indexing_maps of the fused operation are the same + // as the operands and indexing_maps of the generic operation with the + // values at the constant index dropped. 
+ SmallVector fusedIndexMaps; + SmallVector fusedOperands; + fusedIndexMaps.reserve(genericOp.getNumInputsAndOutputs()); + fusedOperands.reserve(genericOp.getNumInputs()); + for (OpOperand *inputOperand : genericOp.getInputOperands()) { + if (inputOperand == opOperand) + continue; + fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(inputOperand)); + fusedOperands.push_back(inputOperand->get()); + } + for (OpOperand *outputOperand : genericOp.getOutputOperands()) + fusedIndexMaps.push_back(genericOp.getTiedIndexingMap(outputOperand)); // Check if the operation shapes to loops map is computable. if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { @@ -1193,20 +1206,16 @@ genericOp, "fused op loop bound computation failed"); } - // The operands list is same as the genericOp with the argument for - // constant index dropped. - SmallVector fusedOperands(genericOp.getInputs()); - fusedOperands.erase(std::next(fusedOperands.begin(), operand.index())); - // Create a constant scalar value from the splat constant. 
Value scalarConstant = rewriter.create( def->getLoc(), constantAttr.getSplatValue(), constantAttr.getType().getElementType()); + SmallVector outputOperands = genericOp.getOutputOperands(); auto fusedOp = rewriter.create( rewriter.getUnknownLoc(), genericOp->getResultTypes(), /*inputs=*/fusedOperands, - /*outputs=*/genericOp.getOutputs(), + /*outputs=*/outputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps), genericOp.iterator_types(), /*doc=*/nullptr, @@ -1217,7 +1226,8 @@ Region &region = genericOp->getRegion(0); Block &entryBlock = *region.begin(); BlockAndValueMapping mapping; - mapping.map(entryBlock.getArgument(operand.index()), scalarConstant); + mapping.map(entryBlock.getArgument(opOperand->getOperandNumber()), + scalarConstant); Region &fusedRegion = fusedOp->getRegion(0); rewriter.cloneRegionBefore(region, fusedRegion, fusedRegion.begin(), mapping); @@ -1233,7 +1243,7 @@ } // namespace static Optional> -fuseElementwiseOps(PatternRewriter &rewriter, OpOperand &consumerOpOperand, +fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand, GenericOp producer, const ControlElementwiseOpsFusionFn &controlFn) { if (producer->getNumResults() != 1) @@ -1261,9 +1271,9 @@ LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { // Find the first operand that is defined by another generic op on tensors. 
- for (OpOperand &opOperand : genericOp.getShapedOpOperands()) { + for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) { auto producer = - dyn_cast_or_null(opOperand.get().getDefiningOp()); + dyn_cast_or_null(opOperand->get().getDefiningOp()); if (!producer || !producer.hasTensorSemantics()) continue; Optional> fusedOpResults = @@ -1322,9 +1332,9 @@ rewriter.startRootUpdate(op); bool modifiedOutput = false; Location loc = op.getLoc(); - for (OpOperand &opOperand : op.getOutputOpOperands()) { - if (!op.payloadUsesValueFromOpOperand(&opOperand)) { - Value operandVal = opOperand.get(); + for (OpOperand *opOperand : op.getOutputOperands()) { + if (!op.payloadUsesValueFromOperand(opOperand)) { + Value operandVal = opOperand->get(); auto operandType = operandVal.getType().dyn_cast(); if (!operandType) continue; @@ -1344,7 +1354,7 @@ Value initTensor = rewriter.create( loc, dynamicDims, operandType.getShape(), operandType.getElementType()); - op->setOperand(opOperand.getOperandNumber(), initTensor); + op->setOperand(opOperand->getOperandNumber(), initTensor); } } if (!modifiedOutput) {