diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -857,6 +857,44 @@
                        outputBuffers);
 }
 
+static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) {
+  if (!result.use_empty())
+    return false;
+  // If out operand not used in payload, we can drop it.
+  OpOperand *outputOpOperand =
+      genericOp.getOutputOperand(result.getResultNumber());
+  if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
+    return true;
+
+  // The out operand that is part of a payload can be dropped if
+  // these conditions are met:
+  // - Result from out operand is dead.
+  // - User of arg is yield.
+  // - outArg data is not being used by other outArgs.
+
+  // Check block arg and cycle from out operand has a single use.
+  BlockArgument outputArg =
+      genericOp.getRegionOutputArgs()[result.getResultNumber()];
+  if (!outputArg.hasOneUse())
+    return false;
+  Operation *argUserOp = *outputArg.user_begin();
+
+  // Check argUser has no other use.
+  if (!argUserOp->use_empty())
+    return false;
+
+  // Check that argUser is a yield.
+  auto yieldOp = dyn_cast<linalg::YieldOp>(argUserOp);
+  if (!yieldOp)
+    return false;
+
+  // Check outArg data is not being used by other outArgs.
+  if (yieldOp.getOperand(result.getResultNumber()) != outputArg)
+    return false;
+
+  return true;
+}
+
 LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
@@ -995,57 +1033,55 @@
         newIndexingMaps.push_back(
             genericOp.getMatchingIndexingMap(outputOpOperand.value()));
       }
-    } else {
-      // Output argument can be dropped if the result has
-      // - no users, and
-      // - it is not used in the payload, and
-      // - the corresponding indexing maps are not needed for loop bound
-      // computation.
-      auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
-      for (const auto &outputOpOperand :
-           llvm::enumerate(genericOp.getOutputOperands())) {
-        Value result = genericOp.getResult(outputOpOperand.index());
-        AffineMap indexingMap =
-            genericOp.getMatchingIndexingMap(outputOpOperand.value());
-        auto key =
-            std::make_tuple(outputOpOperand.value()->get(), indexingMap,
-                            yieldOp->getOperand(outputOpOperand.index()));
-
-        // Do not drop an out if its value is used in the payload.
-        if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
-          if (result.use_empty()) {
-            // Check if the opoperand can be dropped without affecting loop
-            // bound computation. Add the operand to the list of dropped op
-            // operand for checking. If it cannot be dropped, need to pop the
-            // value back.
-            droppedOpOperands.push_back(outputOpOperand.value());
-            if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
-              continue;
-            }
-            droppedOpOperands.pop_back();
-          }
-
-          // The out operand can also be dropped if it is computed redundantly
-          // by another result, the conditions for that are
-          // - The same operand is used as the out operand
-          // - The same indexing map is used
-          // - The same yield value is used.
-          auto it = dedupedOutpts.find(key);
-          if (it != dedupedOutpts.end()) {
-            origToNewPos[outputOpOperand.index()] = it->second;
-            droppedOpOperands.push_back(outputOpOperand.value());
-            continue;
-          }
+    return origToNewPos;
+  }
+  // Output argument can be dropped if the result has
+  // - no users, and
+  // - it is not used in the payload, and
+  // - the corresponding indexing maps are not needed for loop bound
+  // computation.
+  auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
+  for (const auto &outputOpOperand :
+       llvm::enumerate(genericOp.getOutputOperands())) {
+    OpResult result = genericOp.getTiedOpResult(outputOpOperand.value());
+    AffineMap indexingMap =
+        genericOp.getMatchingIndexingMap(outputOpOperand.value());
+    auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap,
+                               yieldOp->getOperand(outputOpOperand.index()));
+    assert(genericOp.getNumOutputs() >= outputOpOperand.index() &&
+           "Output op idx greater than number of outputs.");
+    if (isResultValueDead(genericOp, result)) {
+      // Check if the opoperand can be dropped without affecting loop
+      // bound computation. Add the operand to the list of dropped op
+      // operand for checking. If it cannot be dropped, need to pop the
+      // value back.
+      droppedOpOperands.push_back(outputOpOperand.value());
+      if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+        continue;
+      }
+      droppedOpOperands.pop_back();
+    }
-
-          origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
-          dedupedOutpts[key] = newOutputOperands.size();
-          newOutputOperands.push_back(outputOpOperand.value()->get());
-          newIndexingMaps.push_back(
-              genericOp.getMatchingIndexingMap(outputOpOperand.value()));
+    if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
+      // The out operand can also be dropped if it is computed redundantly
+      // by another result, the conditions for that are
+      // - The same operand is used as the out operand
+      // - The same indexing map is used
+      // - The same yield value is used.
+      auto it = dedupedOutpts.find(key);
+      if (it != dedupedOutpts.end()) {
+        origToNewPos[outputOpOperand.index()] = it->second;
+        droppedOpOperands.push_back(outputOpOperand.value());
+        continue;
+      }
     }
-    }
+    origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+    dedupedOutpts[key] = newOutputOperands.size();
+    newOutputOperands.push_back(outputOpOperand.value()->get());
+    newIndexingMaps.push_back(
+        genericOp.getMatchingIndexingMap(outputOpOperand.value()));
+  }
   return origToNewPos;
 }
@@ -1085,12 +1121,10 @@
     updateReplacements(origOutputOperands, newOutputOperands,
                        origOutsToNewOutsPos);
 
-    rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
-
     // Drop the unused yield args.
     if (newOp.getNumOutputs() != genericOp.getNumOutputs()) {
       OpBuilder::InsertionGuard g(rewriter);
-      YieldOp origYieldOp = cast<YieldOp>(newOpBlock->getTerminator());
+      YieldOp origYieldOp = cast<YieldOp>(origOpBlock->getTerminator());
       rewriter.setInsertionPoint(origYieldOp);
       SmallVector<Value> newYieldVals(newOp.getNumOutputs(), nullptr);
@@ -1103,6 +1137,8 @@
       }
       rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
     }
+
+    rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
   }
 };
 
@@ -1178,13 +1214,75 @@
     return success();
   }
 };
+
+/// Remove unused cycles.
+/// We can remove unused cycle within a payload of generic region
+/// if these conditions are met:
+/// - Result from out operand is dead.
+/// - Block arg from out operand has a single use in the %cycle
+/// instruction.
+/// - Cycle has a single use and it is in yield.
+struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+
+    // If the op doesnt have tensor semantics, preserve the outputs as is.
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+
+    bool hasRemovedCycles = false;
+    // Iterate over output operands and remove any unused cycles.
+    for (const auto &outputOpOperand :
+         llvm::enumerate(genericOp.getOutputOperands())) {
+
+      // Check that result from out operand is dead.
+      Value result = genericOp.getResult(outputOpOperand.index());
+      if (!result.use_empty())
+        continue;
+
+      // Check that outputArg has one use in cycle.
+      BlockArgument outputArg =
+          genericOp.getRegionOutputArgs()[outputOpOperand.index()];
+      if (!outputArg.hasOneUse())
+        continue;
+
+      // Check cycle has at most one use.
+      Operation *cycleOp = *outputArg.user_begin();
+      if (!cycleOp->hasOneUse())
+        continue;
+
+      // Check that the cycleUser is a yield.
+      Operation *cycleUserOp = *cycleOp->user_begin();
+      if (!isa<YieldOp>(cycleUserOp))
+        continue;
+
+      // Check that argIndex matches yieldIndex, else data is being used.
+      if (cycleUserOp->getOperand(outputOpOperand.index()) !=
+          cycleOp->getResult(0))
+        continue;
+
+      // Directly replace the cycle with the blockArg such that
+      // Deduplicate pattern can eliminate it along with unused yield.
+      rewriter.replaceOp(cycleOp, outputArg);
+      rewriter.updateRootInPlace(genericOp, [] {});
+      hasRemovedCycles = true;
+    }
+
+    if (hasRemovedCycles) {
+      return success();
+    }
+
+    return failure();
+  }
+};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results
-      .add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp>(
-          context);
+  results.add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp, RemoveUnusedCycleInGenericOp>(context);
 }
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
diff --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
@@ -286,3 +286,162 @@
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: outs(%[[ARG0]] :
 // CHECK: return %[[GENERIC]]
+
+// -----
+
+// Drop dead result with different tensors.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func.func @drop_dead_results_with_different_tensors(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %init0 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+  %0:4 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %arg0, %init0, %init0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32) :
+      linalg.yield %b0, %b0, %b3, %b4 : f32, f32, f32, f32
+  } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK: func @drop_dead_results_with_different_tensors(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>)
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: outs(%[[ARG0]], %[[ARG0]] :
+// CHECK: return %[[GENERIC]]#0, %[[GENERIC]]#1
+
+// -----
+
+// Drop dead result with unused cycles.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func.func @drop_dead_results_with_unused_cycles(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %init0 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+  %0:4 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %arg0, %init0, %init0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32) :
+      %1 = arith.addf %b0, %b0: f32
+      %2 = arith.addf %b0, %b3: f32
+      %3 = arith.addf %b0, %b4: f32
+      linalg.yield %1, %1, %2, %3 : f32, f32, f32, f32
+  } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK: func @drop_dead_results_with_unused_cycles(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>)
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: outs(%[[ARG0]], %[[ARG0]] :
+// CHECK: return %[[GENERIC]]#0, %[[GENERIC]]#1
+
+// -----
+
+// Drop only the results not used by others.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+func.func @drop_only_the_results_not_used_by_others(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %init0 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+  %0:3 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %init0, %init0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32) :
+      linalg.yield %b2, %b1, %b3 : f32, f32, f32
+  } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0 : tensor<?x?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK: func @drop_only_the_results_not_used_by_others(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>)
+// CHECK: %[[INIT:.+]] = tensor.empty
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: outs(%[[ARG0]], %[[INIT]] :
+// CHECK: return %[[GENERIC]]#0
+
+// -----
+
+// Drop only the cycles not used by others.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+func.func @drop_only_the_cycles_not_used_by_others(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %init0 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+  %0:3 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %init0, %init0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32) :
+      %1 = arith.addf %b1, %b2: f32
+      %2 = arith.addf %b1, %b3 : f32
+      linalg.yield %1, %b1, %2 : f32, f32, f32
+  } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0 : tensor<?x?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK: func @drop_only_the_cycles_not_used_by_others(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>)
+// CHECK: %[[INIT:.+]] = tensor.empty
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: outs(%[[ARG0]], %[[INIT]] :
+// CHECK: return %[[GENERIC]]#0