diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -973,6 +973,8 @@
                             SmallVector<Value> &newOutputOperands,
                             SmallVector<AffineMap> &newIndexingMaps) const {
     llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
+    llvm::SmallDenseMap<std::tuple<Value, AffineMap, Value>, unsigned>
+        dedupedOutpts;
     // If the op doesnt have tensor semantics, keep all the outputs as
     // preserved.
     if (!genericOp.hasTensorSemantics()) {
@@ -989,22 +991,45 @@
     // - it is not used in the payload, and
     // - the corresponding indexing maps are not needed for loop bound
     //   computation.
+    auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
     for (auto outputOpOperand :
          llvm::enumerate(genericOp.getOutputOperands())) {
       Value result = genericOp.getResult(outputOpOperand.index());
-      if (result.use_empty() &&
-          !genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
-        // Check if the opoperand can be dropped without affecting loop bound
-        // computation. Add the operand to the list of dropped op operand for
-        // checking. If it cannot be dropped, need to pop the value back.
-        droppedOpOperands.push_back(outputOpOperand.value());
-        if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+      AffineMap indexingMap =
+          genericOp.getTiedIndexingMap(outputOpOperand.value());
+      auto key =
+          std::make_tuple(outputOpOperand.value()->get(), indexingMap,
+                          yieldOp->getOperand(outputOpOperand.index()));
+
+      // Do not drop an out if its value is used in the payload.
+      if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
+        if (result.use_empty()) {
+          // Check if the opoperand can be dropped without affecting loop
+          // bound computation. Add the operand to the list of dropped op
+          // operand for checking. If it cannot be dropped, need to pop the
+          // value back.
+          droppedOpOperands.push_back(outputOpOperand.value());
+          if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+            continue;
+          }
+          droppedOpOperands.pop_back();
+        }
+
+        // The out operand can also be dropped if it is computed redundantly
+        // by another result, the conditions for that are
+        // - The same operand is used as the out operand
+        // - The same indexing map is used
+        // - The same yield value is used.
+        auto it = dedupedOutpts.find(key);
+        if (it != dedupedOutpts.end()) {
+          origToNewPos[outputOpOperand.index()] = it->second;
+          droppedOpOperands.push_back(outputOpOperand.value());
           continue;
         }
-        droppedOpOperands.pop_back();
       }
       origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+      dedupedOutpts[key] = newOutputOperands.size();
       newOutputOperands.push_back(outputOpOperand.value()->get());
       newIndexingMaps.push_back(
           genericOp.getTiedIndexingMap(outputOpOperand.value()));
diff --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
@@ -262,3 +262,27 @@
 // CHECK:   %[[T5:.+]] = arith.addi %[[T4]], %[[B4]]
 // CHECK:   linalg.yield %[[T5]]
 // CHECK: return %[[RETURN]]
+
+// -----
+
+// Drop redundant results.
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @drop_redundant_results(
+    %arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0:2 = linalg.generic {
+      indexing_maps = [#map, #map, #map],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xf32>)
+      outs(%arg0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32):
+      %1 = arith.addf %b0, %b0 : f32
+      linalg.yield %1, %1 : f32, f32
+  } -> (tensor<?x?xf32>, tensor<?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK: func @drop_redundant_results
+// CHECK-SAME:     %[[ARG0:.+]]: tensor
+// CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:     outs(%[[ARG0]] :
+// CHECK:   return %[[GENERIC]]
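For reference, the canonicalized IR the new test is checking for would look roughly like the sketch below. This is illustrative only and goes beyond what the CHECK lines pin down (exact SSA names and the tensor<?x?xf32> shape are assumptions): the second out operand is dropped because it shares the same operand, indexing map, and yielded value as result #0, and both original results are replaced by the single remaining generic result.

  #map = affine_map<(d0, d1) -> (d0, d1)>
  func.func @drop_redundant_results(%arg0 : tensor<?x?xf32>)
      -> (tensor<?x?xf32>, tensor<?x?xf32>) {
    %0 = linalg.generic {
        indexing_maps = [#map, #map],
        iterator_types = ["parallel", "parallel"]}
        ins(%arg0 : tensor<?x?xf32>)
        outs(%arg0 : tensor<?x?xf32>) {
      ^bb0(%b0 : f32, %b1 : f32):
        %1 = arith.addf %b0, %b0 : f32
        linalg.yield %1 : f32
    } -> tensor<?x?xf32>
    return %0, %0 : tensor<?x?xf32>, tensor<?x?xf32>
  }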