diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -36,6 +36,11 @@ continue; indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&opOperand)); } + if (indexingMaps.empty()) { + // If there are no indexing maps, the operand can only be dropped + // if the op has no loops. + return linalgOp.getNumLoops() == 0; + } return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -143,10 +143,10 @@ /// Generate the region of the fused tensor operation. The region of the fused /// op must be empty. -static void -generateFusedElementwiseOpRegion(RewriterBase &rewriter, GenericOp fusedOp, - AffineMap consumerToProducerLoopsMap, - OpOperand *fusedOperand, unsigned nloops) { +static void generateFusedElementwiseOpRegion( + RewriterBase &rewriter, GenericOp fusedOp, + AffineMap consumerToProducerLoopsMap, OpOperand *fusedOperand, + unsigned nloops, llvm::SmallDenseSet &preservedProducerResults) { auto producer = cast(fusedOperand->get().getDefiningOp()); auto consumer = cast(fusedOperand->getOwner()); // Build the region of the fused op. @@ -202,9 +202,13 @@ mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); // 6. All of the producer's output operands - for (BlockArgument bbArg : - producerBlock.getArguments().take_back(producer.getNumDpsInits())) - mapper.map(bbArg, fusedBlock->addArgument(bbArg.getType(), bbArg.getLoc())); + for (auto bbArg : llvm::enumerate( + producerBlock.getArguments().take_back(producer.getNumDpsInits()))) { + if (!preservedProducerResults.count(bbArg.index())) + continue; + mapper.map(bbArg.value(), fusedBlock->addArgument(bbArg.value().getType(), + bbArg.value().getLoc())); + } // 7. All of consumer's output operands. for (BlockArgument bbArg : @@ -247,8 +251,11 @@ SmallVector fusedYieldValues; fusedYieldValues.reserve(producerYieldOp.getNumOperands() + consumerYieldOp.getNumOperands()); - for (auto producerYieldVal : producerYieldOp.getOperands()) - fusedYieldValues.push_back(mapper.lookupOrDefault(producerYieldVal)); + for (auto producerYieldVal : llvm::enumerate(producerYieldOp.getOperands())) { + if (preservedProducerResults.count(producerYieldVal.index())) + fusedYieldValues.push_back( + mapper.lookupOrDefault(producerYieldVal.value())); + } for (auto consumerYieldVal : consumerYieldOp.getOperands()) fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal)); rewriter.create(fusedOp.getLoc(), fusedYieldValues); @@ -269,6 +276,18 @@ // TODO: allow fusing the producer of an output operand. assert(consumer.isDpsInput(fusedOperand) && "expected producer of input operand"); + /// Find the results of the producer that have uses outside of the consumer. + llvm::SmallDenseSet preservedProducerResults; + for (auto producerResult : llvm::enumerate(producer->getResults())) { + auto outputOperand = producer.getDpsInitOperand(producerResult.index()); + if (producer.payloadUsesValueFromOperand(outputOperand) || + !producer.canOpOperandsBeDropped(outputOperand) || + llvm::any_of(producerResult.value().getUsers(), [&](Operation *user) { + return user != consumer.getOperation(); + })) { + preservedProducerResults.insert(producerResult.index()); + } + } // Compute the fused operands list and indexing maps. SmallVector fusedInputOperands, fusedOutputOperands; @@ -276,9 +295,9 @@ SmallVector fusedIndexMaps; fusedInputOperands.reserve(producer.getNumDpsInputs() + consumer.getNumDpsInputs()); - fusedOutputOperands.reserve(producer.getNumDpsInits() + + fusedOutputOperands.reserve(preservedProducerResults.size() + consumer.getNumDpsInits()); - fusedResultTypes.reserve(producer.getNumDpsInits() + + fusedResultTypes.reserve(preservedProducerResults.size() + consumer.getNumDpsInits()); fusedIndexMaps.reserve(producer->getNumOperands() + consumer->getNumOperands()); @@ -313,13 +332,16 @@ } // 6. Collect all of the producer outputs. - for (OpOperand *opOperand : producer.getDpsInitOperands()) { - fusedOutputOperands.push_back(opOperand->get()); + for (auto opOperand : llvm::enumerate(producer.getDpsInitOperands())) { + if (!preservedProducerResults.count(opOperand.index())) + continue; + + fusedOutputOperands.push_back(opOperand.value()->get()); AffineMap map = getIndexingMapOfProducerOperandsInCoordinatesOfFusedOp( - opOperand, producerResultIndexMap, + opOperand.value(), producerResultIndexMap, consumer.getMatchingIndexingMap(fusedOperand)); fusedIndexMaps.push_back(map); - fusedResultTypes.push_back(opOperand->get().getType()); + fusedResultTypes.push_back(opOperand.value()->get().getType()); } // 7. All of consumer's output operands (skip operands: added by the builder). @@ -358,9 +380,9 @@ AffineMap consumerToProducerLoopsMap = invProducerResultIndexMap.compose(consumerResultIndexMap); - generateFusedElementwiseOpRegion(rewriter, fusedOp, - consumerToProducerLoopsMap, fusedOperand, - consumer.getNumLoops()); + generateFusedElementwiseOpRegion( + rewriter, fusedOp, consumerToProducerLoopsMap, fusedOperand, + consumer.getNumLoops(), preservedProducerResults); return fusedOp.getOperation(); }