diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1027,8 +1027,49 @@
             droppedOpOperands.push_back(outputOpOperand.value());
             continue;
           }
-        }
+        } else {
+          // An out operand that is used by the payload can still be dropped
+          // when all of these conditions are met:
+          // - The result tied to the out operand is dead.
+          // - The block arg of the out operand has a single use, in the
+          //   %cycle instruction.
+          // - The cycle has a single use in the yield, or is itself the yield.
+          // If any condition fails, control must fall through to the
+          // bookkeeping after this branch so the operand is kept; a bare
+          // `continue` would skip that bookkeeping and lose the operand
+          // without recording it as dropped.
+          auto isOnlyForwardedToYield = [&]() {
+            // Obtain the block argument tied to this out operand.
+            Block *genericBlock = &genericOp.getRegion().front();
+            uint64_t argIndex =
+                genericOp.getNumInputs() + outputOpOperand.index();
+            if (argIndex >= genericBlock->getNumArguments())
+              return false;
+
+            Value outArg = genericBlock->getArgument(argIndex);
+            if (!outArg.hasOneUse())
+              return false;
+            Operation *cycleOp = *outArg.user_begin();
+
+            // No uses: the block arg must feed the yield directly.
+            if (cycleOp->use_empty())
+              return isa<linalg::YieldOp>(cycleOp);
+
+            // Otherwise the cycle's only use must be in the yield. Inspect
+            // the cycle's user here, not the block arg's user (which is the
+            // cycle itself).
+            return cycleOp->hasOneUse() &&
+                   isa<linalg::YieldOp>(*cycleOp->user_begin());
+          };
+
+          if (result.use_empty() && isOnlyForwardedToYield()) {
+            droppedOpOperands.push_back(outputOpOperand.value());
+            if (genericOp.canOpOperandsBeDropped(droppedOpOperands))
+              continue;
+            droppedOpOperands.pop_back();
+          }
+        }
 
         origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
         dedupedOutpts[key] = newOutputOperands.size();
         newOutputOperands.push_back(outputOpOperand.value()->get());
@@ -1059,8 +1100,20 @@
             const llvm::SmallDenseMap<unsigned, unsigned> &map) {
           for (const auto &origOperand : llvm::enumerate(origOperands)) {
             auto it = map.find(origOperand.index());
-            if (it == map.end())
+            if (it == map.end()) {
+              // Create a placeholder constant for the deleted OpOperand. The
+              // other `replacements` entries are payload block arguments, so
+              // the placeholder takes the operand's element type, not its
+              // shaped type.
+              Type elementType =
+                  getElementTypeOrSelf(origOperand.value()->get().getType());
+              Value placeHolder = rewriter.create<arith::ConstantOp>(
+                  newOpBlock->getParent()->getLoc(), elementType,
+                  rewriter.getZeroAttr(elementType));
+              replacements[origOperand.value()->getOperandNumber()] =
+                  placeHolder;
               continue;
+            }
             OpOperand *newOperand = newOperands[it->second];
             replacements[origOperand.value()->getOperandNumber()] =
                 newOpBlock->getArgument(newOperand->getOperandNumber());