diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -848,6 +848,48 @@
                         outputBuffers);
 }
 
+/// Returns true if the out operand `outputOpOperand` of `genericOp` can be
+/// dropped as far as uses of `result` and of the payload block argument
+/// `outputArg` are concerned. Whether dropping it would affect loop bound
+/// computation is not checked here.
+bool isOutputOpOperandDead(linalg::GenericOp genericOp,
+                           OpOperand *outputOpOperand, BlockArgument &outputArg,
+                           Value result) {
+  // A result that still has users keeps its out operand alive.
+  if (!result.use_empty())
+    return false;
+  // If out operand not used in payload, we can drop it.
+  if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
+    return true;
+
+  // The out operand that is part of a payload can be dropped if
+  // these conditions are met:
+  // - Result from out operand is dead (checked above).
+  // - Block arg from out operand has a single use, in the %cycle op.
+  // - %cycle is the yield itself, or its single use is in the yield.
+  // - The yielded value lands in the slot of `result`.
+  if (!outputArg.hasOneUse())
+    return false;
+  OpOperand *yieldedUse = &*outputArg.use_begin();
+  Operation *cycleOp = yieldedUse->getOwner();
+
+  // If %cycle is not the yield itself, it may only feed the yield, and only
+  // once.
+  if (!cycleOp->use_empty()) {
+    if (!cycleOp->hasOneUse())
+      return false;
+    yieldedUse = &*cycleOp->use_begin();
+  }
+  if (!isa<linalg::YieldOp>(yieldedUse->getOwner()))
+    return false;
+
+  // A value yielded into another result's slot may feed a live result, so
+  // the operand is only dead when it flows into its own (unused) result.
+  unsigned yieldPos = yieldedUse->getOperandNumber();
+  return yieldPos < genericOp->getNumResults() &&
+         genericOp->getResult(yieldPos) == result;
+}
+
 LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
@@ -986,57 +1028,58 @@
         newIndexingMaps.push_back(
             genericOp.getMatchingIndexingMap(outputOpOperand.value()));
       }
-    } else {
-      // Output argument can be dropped if the result has
-      // - no users, and
-      // - it is not used in the payload, and
-      // - the corresponding indexing maps are not needed for loop bound
-      // computation.
-      auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
-      for (const auto &outputOpOperand :
-           llvm::enumerate(genericOp.getOutputOperands())) {
-        Value result = genericOp.getResult(outputOpOperand.index());
-        AffineMap indexingMap =
-            genericOp.getMatchingIndexingMap(outputOpOperand.value());
-        auto key =
-            std::make_tuple(outputOpOperand.value()->get(), indexingMap,
-                            yieldOp->getOperand(outputOpOperand.index()));
-
-        // Do not drop an out if its value is used in the payload.
-        if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
-          if (result.use_empty()) {
-            // Check if the opoperand can be dropped without affecting loop
-            // bound computation. Add the operand to the list of dropped op
-            // operand for checking. If it cannot be dropped, need to pop the
-            // value back.
-            droppedOpOperands.push_back(outputOpOperand.value());
-            if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
-              continue;
-            }
-            droppedOpOperands.pop_back();
-          }
-
-          // The out operand can also be dropped if it is computed redundantly
-          // by another result, the conditions for that are
-          // - The same operand is used as the out operand
-          // - The same indexing map is used
-          // - The same yield value is used.
-          auto it = dedupedOutpts.find(key);
-          if (it != dedupedOutpts.end()) {
-            origToNewPos[outputOpOperand.index()] = it->second;
-            droppedOpOperands.push_back(outputOpOperand.value());
-            continue;
-          }
+      return origToNewPos;
+    }
+    // Output argument can be dropped if the result has
+    // - no users, and
+    // - its payload uses, if any, only forward it to its own yield slot, and
+    // - the corresponding indexing maps are not needed for loop bound
+    // computation.
+    auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
+    for (const auto &outputOpOperand :
+         llvm::enumerate(genericOp.getOutputOperands())) {
+      Value result = genericOp.getResult(outputOpOperand.index());
+      AffineMap indexingMap =
+          genericOp.getMatchingIndexingMap(outputOpOperand.value());
+      auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap,
+                                 yieldOp->getOperand(outputOpOperand.index()));
+      assert(genericOp.getNumOutputs() > outputOpOperand.index() &&
+             "Output op idx greater than number of outputs.");
+      BlockArgument outputArg =
+          genericOp.getRegionOutputArgs()[outputOpOperand.index()];
+      if (isOutputOpOperandDead(genericOp, outputOpOperand.value(), outputArg,
+                                result)) {
+        // Check if the opoperand can be dropped without affecting loop
+        // bound computation. Add the operand to the list of dropped op
+        // operand for checking. If it cannot be dropped, need to pop the
+        // value back.
+        droppedOpOperands.push_back(outputOpOperand.value());
+        if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+          continue;
         }
+        droppedOpOperands.pop_back();
+      }
 
-        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
-        dedupedOutpts[key] = newOutputOperands.size();
-        newOutputOperands.push_back(outputOpOperand.value()->get());
-        newIndexingMaps.push_back(
-            genericOp.getMatchingIndexingMap(outputOpOperand.value()));
+      if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
+        // The out operand can also be dropped if it is computed redundantly
+        // by another result, the conditions for that are
+        // - The same operand is used as the out operand
+        // - The same indexing map is used
+        // - The same yield value is used.
+        auto it = dedupedOutpts.find(key);
+        if (it != dedupedOutpts.end()) {
+          origToNewPos[outputOpOperand.index()] = it->second;
+          droppedOpOperands.push_back(outputOpOperand.value());
+          continue;
+        }
       }
-    }
 
+      origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+      dedupedOutpts[key] = newOutputOperands.size();
+      newOutputOperands.push_back(outputOpOperand.value()->get());
+      newIndexingMaps.push_back(
+          genericOp.getMatchingIndexingMap(outputOpOperand.value()));
+    }
     return origToNewPos;
   }
 
@@ -1059,8 +1102,16 @@
             const llvm::SmallDenseMap<unsigned, unsigned> &map) {
           for (const auto &origOperand : llvm::enumerate(origOperands)) {
             auto it = map.find(origOperand.index());
-            if (it == map.end())
+            if (it == map.end()) {
+              // Creates a placeholder constOp for deleted OpOperands.
+              Type elementType = origOperand.value()->get().getType();
+              Value placeHolder = rewriter.create<arith::ConstantOp>(
+                  newOpBlock->getParent()->getLoc(), elementType,
+                  rewriter.getZeroAttr(elementType));
+              replacements[origOperand.value()->getOperandNumber()] =
+                  placeHolder;
               continue;
+            }
             OpOperand *newOperand = newOperands[it->second];
             replacements[origOperand.value()->getOperandNumber()] =
                 newOpBlock->getArgument(newOperand->getOperandNumber());