diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -860,67 +860,128 @@
                                 PatternRewriter &rewriter) const override {
     // Create a map from argument position in the original op to the argument
     // position in the new op. If the argument is dropped it wont have an entry.
-    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
-    unsigned numNewArgs = 0;
     SmallVector<OpOperand *> droppedOpOperands;
-    llvm::SmallDenseSet<unsigned> droppedOutputs;
 
     // Information needed to build the new op.
     SmallVector<Value> newInputOperands, newOutputOperands;
     SmallVector<AffineMap> newIndexingMaps;
+
+    // Gather information about duplicate input operands.
+    llvm::SmallDenseMap<unsigned, unsigned> origInsToNewInsPos =
+        deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands,
+                                 newIndexingMaps);
+
+    // Gather information about the dropped outputs.
+    llvm::SmallDenseMap<unsigned, unsigned> origOutsToNewOutsPos =
+        deduplicateOutputOperands(genericOp, droppedOpOperands,
+                                  newOutputOperands, newIndexingMaps);
+
+    // Check if there is any change to operands.
+    if (newInputOperands.size() + newOutputOperands.size() ==
+        static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
+      return failure();
+
+    // Create the new op with the body being empty.
+    Location loc = genericOp.getLoc();
     SmallVector<Type> newResultTypes;
+    if (genericOp.hasTensorSemantics()) {
+      newResultTypes = llvm::to_vector(llvm::map_range(
+          newOutputOperands, [](Value v) { return v.getType(); }));
+    }
+    auto newOp = rewriter.create<GenericOp>(
+        loc, newResultTypes, newInputOperands, newOutputOperands,
+        rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        genericOp.iterator_types(), genericOp.docAttr(),
+        genericOp.library_callAttr(),
+        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
+          return;
+        });
+    // Copy over unknown attributes. They might be load bearing for some flow.
+    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
+    for (NamedAttribute kv : genericOp->getAttrs())
+      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
+        newOp->setAttr(kv.getName(), kv.getValue());
 
-    // Input argument can be dropped if
-    // - it has no uses, or,
-    // - there is a duplicate operand which is accessed using the same
-    // indexing map.
-    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
-    auto indexingMaps = genericOp.getIndexingMaps();
-    ArrayRef<AffineMap> unprocessedIndexingMaps(indexingMaps);
-    for (OpOperand *inputOpOperand : genericOp.getInputOperands()) {
-      BlockArgument arg = genericOp.getTiedBlockArgument(inputOpOperand);
-      unsigned argNum = arg.getArgNumber();
-      unprocessedIndexingMaps = unprocessedIndexingMaps.drop_front();
+    // Fix up the payload of the canonicalized operation.
+    populateOpPayload(genericOp, newOp, origInsToNewInsPos,
+                      origOutsToNewOutsPos, rewriter);
+    // Replace all live uses of the op.
+    SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
+    for (auto result : llvm::enumerate(genericOp.getResults())) {
+      auto it = origOutsToNewOutsPos.find(result.index());
+      if (it == origOutsToNewOutsPos.end())
+        continue;
+      replacementsVals[result.index()] = newOp.getResult(it->second);
+    }
+    rewriter.replaceOp(genericOp, replacementsVals);
+    return success();
+  }
+
+private:
+  // Deduplicate input operands, and return the
+  // - Mapping from operand position in the original op, to operand position in
+  //   the canonicalized op.
+  // - The preserved input operands list (by reference).
+  llvm::SmallDenseMap<unsigned, unsigned>
+  deduplicateInputOperands(GenericOp genericOp,
+                           SmallVector<OpOperand *> &droppedOpOperands,
+                           SmallVector<Value> &newInputOperands,
+                           SmallVector<AffineMap> &newIndexingMaps) const {
+    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedInputs;
+    for (auto inputOpOperand : llvm::enumerate(genericOp.getInputOperands())) {
       // Check if operand is dead and if dropping the indexing map makes the
       // loops to shape computation invalid.
-      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) {
+      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand.value())) {
         // Add the current operands to the list of potentially droppable
         // operands. If it cannot be dropped, this needs to be popped back.
-        droppedOpOperands.push_back(inputOpOperand);
+        droppedOpOperands.push_back(inputOpOperand.value());
         if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
           continue;
         droppedOpOperands.pop_back();
       }
 
       // Check if this operand is a duplicate.
-      AffineMap indexingMap = genericOp.getTiedIndexingMap(inputOpOperand);
+      AffineMap indexingMap =
+          genericOp.getTiedIndexingMap(inputOpOperand.value());
       auto it = dedupedInputs.find(
-          std::make_pair(inputOpOperand->get(), indexingMap));
+          std::make_pair(inputOpOperand.value()->get(), indexingMap));
       if (it != dedupedInputs.end()) {
-        origToNewPos[argNum] = it->second;
-        droppedOpOperands.push_back(inputOpOperand);
+        origToNewPos[inputOpOperand.index()] = it->second;
+        droppedOpOperands.push_back(inputOpOperand.value());
         continue;
       }
 
       // This is a preserved argument.
-      origToNewPos[argNum] = numNewArgs;
-      dedupedInputs[{inputOpOperand->get(), indexingMap}] = numNewArgs;
-      newInputOperands.push_back(inputOpOperand->get());
+      origToNewPos[inputOpOperand.index()] = newInputOperands.size();
+      dedupedInputs[{inputOpOperand.value()->get(), indexingMap}] =
+          newInputOperands.size();
+      newInputOperands.push_back(inputOpOperand.value()->get());
       newIndexingMaps.push_back(indexingMap);
-      numNewArgs++;
     }
+    return origToNewPos;
+  }
+  // Deduplicate output operands, and return the
+  // - Mapping from operand position in the original op, to operand position in
+  //   the canonicalized op.
+  // - The preserved output operands list (by reference).
+  llvm::SmallDenseMap<unsigned, unsigned>
+  deduplicateOutputOperands(GenericOp genericOp,
+                            SmallVector<OpOperand *> &droppedOpOperands,
+                            SmallVector<Value> &newOutputOperands,
+                            SmallVector<AffineMap> &newIndexingMaps) const {
+    llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
     // If the op doesnt have tensor semantics, keep all the outputs as
     // preserved.
     if (!genericOp.hasTensorSemantics()) {
-      for (OpOperand *outputOpOperand : genericOp.getOutputOperands()) {
-        unprocessedIndexingMaps = unprocessedIndexingMaps.drop_front();
-        BlockArgument arg = genericOp.getTiedBlockArgument(outputOpOperand);
-        origToNewPos[arg.getArgNumber()] = numNewArgs++;
-        newOutputOperands.push_back(outputOpOperand->get());
+      for (auto outputOpOperand :
+           llvm::enumerate(genericOp.getOutputOperands())) {
+        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+        newOutputOperands.push_back(outputOpOperand.value()->get());
         newIndexingMaps.push_back(
-            genericOp.getTiedIndexingMap(outputOpOperand));
+            genericOp.getTiedIndexingMap(outputOpOperand.value()));
       }
     } else {
       // Output argument can be dropped if the result has
@@ -928,12 +989,9 @@
       // - it is not used in the payload, and
      // - the corresponding indexing maps are not needed for loop bound
       // computation.
-      for (const auto &outputOpOperand :
+      for (auto outputOpOperand :
            llvm::enumerate(genericOp.getOutputOperands())) {
-        unprocessedIndexingMaps = unprocessedIndexingMaps.drop_front();
         Value result = genericOp.getResult(outputOpOperand.index());
-        BlockArgument arg =
-            genericOp.getTiedBlockArgument(outputOpOperand.value());
         if (result.use_empty() &&
             !genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
           // Check if the opoperand can be dropped without affecting loop bound
@@ -941,77 +999,75 @@
           // checking. If it cannot be dropped, need to pop the value back.
           droppedOpOperands.push_back(outputOpOperand.value());
           if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
-            droppedOutputs.insert(outputOpOperand.index());
             continue;
           }
           droppedOpOperands.pop_back();
         }
-        origToNewPos[arg.getArgNumber()] = numNewArgs++;
+        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
         newOutputOperands.push_back(outputOpOperand.value()->get());
         newIndexingMaps.push_back(
             genericOp.getTiedIndexingMap(outputOpOperand.value()));
-        newResultTypes.push_back(result.getType());
       }
     }
 
-    // Check if there is any change to operands.
-    if (newInputOperands.size() + newOutputOperands.size() ==
-        static_cast<size_t>(genericOp.getNumInputsAndOutputs()))
-      return failure();
-
-    // Create the new op with the body being empty.
-    Location loc = genericOp.getLoc();
-    auto newOp = rewriter.create<GenericOp>(
-        loc, newResultTypes, newInputOperands, newOutputOperands,
-        rewriter.getAffineMapArrayAttr(newIndexingMaps),
-        genericOp.iterator_types(), genericOp.docAttr(),
-        genericOp.library_callAttr(),
-        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
-          return;
-        });
-    // Copy over unknown attributes. They might be load bearing for some flow.
-    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
-    for (NamedAttribute kv : genericOp->getAttrs())
-      if (!llvm::is_contained(odsAttrs, kv.getName().getValue()))
-        newOp->setAttr(kv.getName(), kv.getValue());
+    return origToNewPos;
+  }
+  // Populate the body of the canonicalized operation.
+  void populateOpPayload(
+      GenericOp genericOp, GenericOp newOp,
+      const llvm::SmallDenseMap<unsigned, unsigned> &origInsToNewInsPos,
+      const llvm::SmallDenseMap<unsigned, unsigned> &origOutsToNewOutsPos,
+      PatternRewriter &rewriter) const {
     // Merge the body of the original op with the new op.
     Block *newOpBlock = &newOp.region().front();
+    assert(newOpBlock->empty() && "expected new op to have an empty payload");
     Block *origOpBlock = &genericOp.region().front();
     SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
-    for (auto argNum : llvm::seq<unsigned>(0, origOpBlock->getNumArguments())) {
-      auto it = origToNewPos.find(argNum);
-      if (it != origToNewPos.end())
-        replacements[argNum] = newOpBlock->getArgument(it->second);
-    }
+
+    // Replace all arguments in the original op, with arguments from the
+    // canonicalized op.
+    auto updateReplacements =
+        [&](OpOperandVector &origOperands, OpOperandVector &newOperands,
+            const llvm::SmallDenseMap<unsigned, unsigned> &map) {
+          for (auto origOperand : llvm::enumerate(origOperands)) {
+            auto it = map.find(origOperand.index());
+            if (it == map.end())
+              continue;
+            OpOperand *newOperand = newOperands[it->second];
+            replacements[origOperand.value()->getOperandNumber()] =
+                newOpBlock->getArgument(newOperand->getOperandNumber());
+          }
+        };
+
+    OpOperandVector origInputOperands = genericOp.getInputOperands();
+    OpOperandVector newInputOperands = newOp.getInputOperands();
+    updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos);
+
+    OpOperandVector origOutputOperands = genericOp.getOutputOperands();
+    OpOperandVector newOutputOperands = newOp.getOutputOperands();
+    updateReplacements(origOutputOperands, newOutputOperands,
+                       origOutsToNewOutsPos);
+
     rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
 
     // Drop the unused yield args.
-    Block *block = &newOp.region().front();
-    if (!droppedOutputs.empty()) {
+    if (newOp.getNumOutputs() != genericOp.getNumOutputs()) {
       OpBuilder::InsertionGuard g(rewriter);
-      SmallVector<Value> newYieldVals;
-      YieldOp origYieldOp = cast<YieldOp>(block->getTerminator());
+      YieldOp origYieldOp = cast<YieldOp>(newOpBlock->getTerminator());
       rewriter.setInsertionPoint(origYieldOp);
+
+      SmallVector<Value> newYieldVals(newOp.getNumOutputs(), nullptr);
      for (const auto &yieldOpOperands : llvm::enumerate(origYieldOp.values())) {
-        if (!droppedOutputs.count(yieldOpOperands.index())) {
-          newYieldVals.push_back(yieldOpOperands.value());
+        auto it = origOutsToNewOutsPos.find(yieldOpOperands.index());
+        if (it == origOutsToNewOutsPos.end())
           continue;
-        }
+        newYieldVals[it->second] = yieldOpOperands.value();
       }
       rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
     }
-
-    // Replace all live uses of the op.
-    SmallVector<Value> replacementsVals(genericOp->getNumResults(), nullptr);
-    unsigned newResultNum = 0;
-    for (const auto &result : llvm::enumerate(genericOp.getResults()))
-      if (!droppedOutputs.count(result.index()))
-        replacementsVals[result.index()] = newOp.getResult(newResultNum++);
-    rewriter.replaceOp(genericOp, replacementsVals);
-    return success();
   }
 };