diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -848,6 +848,41 @@
         outputBuffers);
 }
 
+/// Returns true if `outputOpOperand`, an output operand of `genericOp`, is
+/// dead and hence a candidate for removal. `outputArg` is the payload block
+/// argument associated with the operand and `result` is the op result
+/// produced from it. The operand is dead when `result` has no uses and either
+///  - the payload does not use the value of the operand, or
+///  - the only user of `outputArg` is a `linalg.yield`.
+/// Callers must still check that dropping it does not affect loop bounds.
+bool isOutputOpOperandDead(linalg::GenericOp genericOp,
+                           OpOperand *outputOpOperand, BlockArgument &outputArg,
+                           Value result) {
+  // A result that still has users keeps its out operand alive.
+  if (!result.use_empty())
+    return false;
+  // If the out operand is not used in the payload, we can drop it.
+  if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
+    return true;
+
+  // An out operand that is part of the payload can still be dropped if
+  // these conditions are met:
+  // - Result from out operand is dead (established above).
+  // - The only user of the block arg is the yield.
+
+  // Check that the block arg has exactly one use.
+  if (!outputArg.hasOneUse())
+    return false;
+  Operation *argUserOp = *outputArg.user_begin();
+
+  // Check that no result of argUser is used anywhere.
+  if (!argUserOp->use_empty())
+    return false;
+
+  // The operand is dead only if that single user is the yield.
+  return isa<linalg::YieldOp>(argUserOp);
+}
+
 LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
@@ -986,57 +1021,58 @@
         newIndexingMaps.push_back(
             genericOp.getMatchingIndexingMap(outputOpOperand.value()));
       }
-    } else {
-      // Output argument can be dropped if the result has
-      // - no users, and
-      // - it is not used in the payload, and
-      // - the corresponding indexing maps are not needed for loop bound
-      //   computation.
-      auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
-      for (const auto &outputOpOperand :
-           llvm::enumerate(genericOp.getOutputOperands())) {
-        Value result = genericOp.getResult(outputOpOperand.index());
-        AffineMap indexingMap =
-            genericOp.getMatchingIndexingMap(outputOpOperand.value());
-        auto key =
-            std::make_tuple(outputOpOperand.value()->get(), indexingMap,
-                            yieldOp->getOperand(outputOpOperand.index()));
-
-        // Do not drop an out if its value is used in the payload.
-        if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
-          if (result.use_empty()) {
-            // Check if the opoperand can be dropped without affecting loop
-            // bound computation. Add the operand to the list of dropped op
-            // operand for checking. If it cannot be dropped, need to pop the
-            // value back.
-            droppedOpOperands.push_back(outputOpOperand.value());
-            if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
-              continue;
-            }
-            droppedOpOperands.pop_back();
-          }
-
-          // The out operand can also be dropped if it is computed redundantly
-          // by another result, the conditions for that are
-          // - The same operand is used as the out operand
-          // - The same indexing map is used
-          // - The same yield value is used.
-          auto it = dedupedOutpts.find(key);
-          if (it != dedupedOutpts.end()) {
-            origToNewPos[outputOpOperand.index()] = it->second;
-            droppedOpOperands.push_back(outputOpOperand.value());
-            continue;
-          }
+      return origToNewPos;
+    }
+    // Output argument can be dropped if the result has
+    // - no users, and
+    // - it is not used in the payload, and
+    // - the corresponding indexing maps are not needed for loop bound
+    //   computation.
+    auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
+    for (const auto &outputOpOperand :
+         llvm::enumerate(genericOp.getOutputOperands())) {
+      Value result = genericOp.getResult(outputOpOperand.index());
+      AffineMap indexingMap =
+          genericOp.getMatchingIndexingMap(outputOpOperand.value());
+      auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap,
+                                 yieldOp->getOperand(outputOpOperand.index()));
+      assert(genericOp.getNumOutputs() > outputOpOperand.index() &&
+             "Output op idx greater than number of outputs.");
+      BlockArgument outputArg =
+          genericOp.getRegionOutputArgs()[outputOpOperand.index()];
+      if (isOutputOpOperandDead(genericOp, outputOpOperand.value(), outputArg,
+                                result)) {
+        // Check if the OpOperand can be dropped without affecting loop
+        // bound computation. Add the operand to the list of dropped op
+        // operands for checking. If it cannot be dropped, need to pop the
+        // value back.
+        droppedOpOperands.push_back(outputOpOperand.value());
+        if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+          continue;
         }
+        droppedOpOperands.pop_back();
+      }
 
-        origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
-        dedupedOutpts[key] = newOutputOperands.size();
-        newOutputOperands.push_back(outputOpOperand.value()->get());
-        newIndexingMaps.push_back(
-            genericOp.getMatchingIndexingMap(outputOpOperand.value()));
+      if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
+        // The out operand can also be dropped if it is computed redundantly
+        // by another result, the conditions for that are
+        // - The same operand is used as the out operand
+        // - The same indexing map is used
+        // - The same yield value is used.
+        auto it = dedupedOutpts.find(key);
+        if (it != dedupedOutpts.end()) {
+          origToNewPos[outputOpOperand.index()] = it->second;
+          droppedOpOperands.push_back(outputOpOperand.value());
+          continue;
+        }
       }
-    }
 
+      origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+      dedupedOutpts[key] = newOutputOperands.size();
+      newOutputOperands.push_back(outputOpOperand.value()->get());
+      newIndexingMaps.push_back(
+          genericOp.getMatchingIndexingMap(outputOpOperand.value()));
+    }
     return origToNewPos;
   }
 
@@ -1059,8 +1095,21 @@
             const llvm::SmallDenseMap<unsigned, unsigned> &map) {
           for (const auto &origOperand : llvm::enumerate(origOperands)) {
             auto it = map.find(origOperand.index());
-            if (it == map.end())
+            if (it == map.end()) {
+              // The operand has no entry in `map`, i.e. no position in the
+              // new operand list. Redirect any remaining uses of its block
+              // argument to a zero-valued placeholder constant.
+              // NOTE(review): getZeroAttr may not support every element
+              // type - confirm that a null attribute cannot reach here.
+              uint64_t argIndex = origOperand.value()->getOperandNumber();
+              BlockArgument blockArg = origOpBlock->getArgument(argIndex);
+              Type operandType = blockArg.getType();
+              Value placeHolder = rewriter.create<arith::ConstantOp>(
+                  blockArg.getLoc(), operandType,
+                  rewriter.getZeroAttr(operandType));
+              blockArg.replaceAllUsesWith(placeHolder);
               continue;
+            }
             OpOperand *newOperand = newOperands[it->second];
             replacements[origOperand.value()->getOperandNumber()] =
                 newOpBlock->getArgument(newOperand->getOperandNumber());
@@ -1169,13 +1218,65 @@
     return success();
   }
 };
+
+/// Remove unused cycles.
+/// We can remove an unused cycle within the payload of a generic op region
+/// if these conditions are met:
+/// - Result from out operand is dead.
+/// - Block arg from out operand has a single use in the %cycle
+///   instruction.
+/// - Cycle has a single use and it is in yield.
+struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    for (const auto &outputOpOperand :
+         llvm::enumerate(genericOp.getOutputOperands())) {
+      // Check that result from out operand is dead.
+      Value result = genericOp.getResult(outputOpOperand.index());
+      if (!result.use_empty())
+        continue;
+
+      // Check that blockArg has one use in cycle.
+      BlockArgument outputArg =
+          genericOp.getRegionOutputArgs()[outputOpOperand.index()];
+      if (!outputArg.hasOneUse())
+        continue;
+
+      // Check that the cycle has exactly one result with exactly one use;
+      // the replacement below supplies a single value for the whole op.
+      Operation *cycleOp = *outputArg.user_begin();
+      if (cycleOp->getNumResults() != 1 || !cycleOp->hasOneUse())
+        continue;
+
+      // Check that the cycleUser is a yield.
+      Operation *cycleUserOp = *cycleOp->user_begin();
+      if (!isa<linalg::YieldOp>(cycleUserOp))
+        continue;
+
+      // Check that arg index matches yield index, otherwise it's an invalid
+      // cycle.
+      if (cycleUserOp->getOperand(outputOpOperand.index()) !=
+          cycleOp->getResult(0))
+        continue;
+
+      // Directly replace the cycle with the blockArg such that
+      // DeduplicateAndRemoveDeadOperandsAndResults pattern can handle and
+      // remove it.
+      rewriter.replaceOp(cycleOp, outputArg);
+      rewriter.updateRootInPlace(genericOp, [] {});
+      return success();
+    }
+    return failure();
+  }
+};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results
-      .add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp>(
-          context);
+  results.add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp,
+              RemoveUnusedCycleInGenericOp>(context);
 }
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,