diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -857,6 +857,42 @@
                         outputBuffers);
 }
 
+bool isOutputOpOperandDead(linalg::GenericOp genericOp,
+                           OpOperand *outputOpOperand,
+                           BlockArgument &outputArg, Value result) {
+  if (!result.use_empty())
+    return false;
+  // If the out operand is not used in the payload, we can drop it.
+  if (!genericOp.payloadUsesValueFromOperand(outputOpOperand))
+    return true;
+
+  // An out operand that is used in the payload can still be dropped if these
+  // conditions are met:
+  // - The result tied to the out operand is dead.
+  // - The only user of the block argument is the yield.
+  // - The block argument's data is not used by any other out argument.
+
+  // Check that the block argument of the out operand has a single use.
+  if (!outputArg.hasOneUse())
+    return false;
+  Operation *argUserOp = *outputArg.user_begin();
+
+  // Check that the user of the argument has no uses of its own.
+  if (!argUserOp->use_empty())
+    return false;
+
+  // Check that the user of the argument is a yield.
+  if (!isa<YieldOp>(argUserOp))
+    return false;
+
+  // Check that the argument is only yielded back into its own result
+  // position, i.e. its data is not used by any other out argument.
+  int64_t outputIndex = outputArg.getArgNumber() - genericOp.getNumInputs();
+  if (argUserOp->getOperand(outputIndex) != outputArg)
+    return false;
+
+  return true;
+}
+
 LogicalResult GenericOp::verify() { return success(); }
 
 namespace {
 
@@ -995,57 +1031,58 @@
       newIndexingMaps.push_back(
           genericOp.getMatchingIndexingMap(outputOpOperand.value()));
     }
-  } else {
-    // Output argument can be dropped if the result has
-    // - no users, and
-    // - it is not used in the payload, and
-    // - the corresponding indexing maps are not needed for loop bound
-    //   computation.
-    auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
-    for (const auto &outputOpOperand :
-         llvm::enumerate(genericOp.getOutputOperands())) {
-      Value result = genericOp.getResult(outputOpOperand.index());
-      AffineMap indexingMap =
-          genericOp.getMatchingIndexingMap(outputOpOperand.value());
-      auto key =
-          std::make_tuple(outputOpOperand.value()->get(), indexingMap,
-                          yieldOp->getOperand(outputOpOperand.index()));
-
-      // Do not drop an out if its value is used in the payload.
-      if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
-        if (result.use_empty()) {
-          // Check if the opoperand can be dropped without affecting loop
-          // bound computation. Add the operand to the list of dropped op
-          // operand for checking. If it cannot be dropped, need to pop the
-          // value back.
-          droppedOpOperands.push_back(outputOpOperand.value());
-          if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
-            continue;
-          }
-          droppedOpOperands.pop_back();
-        }
-
-        // The out operand can also be dropped if it is computed redundantly
-        // by another result, the conditions for that are
-        // - The same operand is used as the out operand
-        // - The same indexing map is used
-        // - The same yield value is used.
-        auto it = dedupedOutpts.find(key);
-        if (it != dedupedOutpts.end()) {
-          origToNewPos[outputOpOperand.index()] = it->second;
-          droppedOpOperands.push_back(outputOpOperand.value());
-          continue;
-        }
+    return origToNewPos;
+  }
+  // An output operand can be dropped if
+  // - its result has no users, and
+  // - its value is not used in the payload, and
+  // - its indexing map is not needed for loop bound computation.
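+  // In addition, an output operand that is used in the payload can still be
+  // dropped when its only payload use is yielding its own block argument back
+  // at the same result position; see isOutputOpOperandDead above.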
+  auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
+  for (const auto &outputOpOperand :
+       llvm::enumerate(genericOp.getOutputOperands())) {
+    Value result = genericOp.getResult(outputOpOperand.index());
+    AffineMap indexingMap =
+        genericOp.getMatchingIndexingMap(outputOpOperand.value());
+    auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap,
+                               yieldOp->getOperand(outputOpOperand.index()));
+    assert(genericOp.getNumOutputs() >= outputOpOperand.index() &&
+           "Output op idx greater than number of outputs.");
+    BlockArgument outputArg =
+        genericOp.getRegionOutputArgs()[outputOpOperand.index()];
+    if (isOutputOpOperandDead(genericOp, outputOpOperand.value(), outputArg,
+                              result)) {
+      // Check if the OpOperand can be dropped without affecting loop
+      // bound computation. Add the operand to the list of dropped op
+      // operands for checking. If it cannot be dropped, pop the value
+      // back off.
+      droppedOpOperands.push_back(outputOpOperand.value());
+      if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) {
+        continue;
+      }
+      droppedOpOperands.pop_back();
+    }
-      origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
-      dedupedOutpts[key] = newOutputOperands.size();
-      newOutputOperands.push_back(outputOpOperand.value()->get());
-      newIndexingMaps.push_back(
-          genericOp.getMatchingIndexingMap(outputOpOperand.value()));
+    if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
+      // The out operand can also be dropped if it is computed redundantly
+      // by another result; the conditions for that are:
+      // - The same operand is used as the out operand.
+      // - The same indexing map is used.
+      // - The same yield value is used.
+      auto it = dedupedOutpts.find(key);
+      if (it != dedupedOutpts.end()) {
+        origToNewPos[outputOpOperand.index()] = it->second;
+        droppedOpOperands.push_back(outputOpOperand.value());
+        continue;
+      }
+    }
-  }
+    origToNewPos[outputOpOperand.index()] = newOutputOperands.size();
+    dedupedOutpts[key] = newOutputOperands.size();
+    newOutputOperands.push_back(outputOpOperand.value()->get());
+    newIndexingMaps.push_back(
+        genericOp.getMatchingIndexingMap(outputOpOperand.value()));
+  }
   return origToNewPos;
 }
@@ -1068,8 +1105,16 @@
                        const llvm::SmallDenseMap<unsigned, unsigned> &map) {
   for (const auto &origOperand : llvm::enumerate(origOperands)) {
     auto it = map.find(origOperand.index());
-    if (it == map.end())
+    if (it == map.end()) {
+      uint64_t argIndex = origOperand.value()->getOperandNumber();
+      BlockArgument blockArg = origOpBlock->getArgument(argIndex);
+      Type operandType = blockArg.getType();
+      Value placeHolder = rewriter.create<arith::ConstantOp>(
+          blockArg.getLoc(), operandType, rewriter.getZeroAttr(operandType));
+      blockArg.replaceAllUsesWith(placeHolder);
       continue;
+    }
     OpOperand *newOperand = newOperands[it->second];
     replacements[origOperand.value()->getOperandNumber()] =
         newOpBlock->getArgument(newOperand->getOperandNumber());
@@ -1178,13 +1223,71 @@
     return success();
   }
 };
+
+/// Remove unused cycles.
+/// An unused cycle within the payload of a generic op can be removed if these
+/// conditions are met:
+/// - The result tied to the out operand is dead.
+/// - The block argument of the out operand has a single use in the cycle
+///   operation.
+/// - The cycle operation has a single use, and that use is the yield.
+struct RemoveUnusedCycleInGenericOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not apply to a generic op that does an in-place update.
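+    // Such an op has no tensor results whose liveness could be checked.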
+    if (genericOp.getNumResults() == 0) {
+      return failure();
+    }
+    auto yieldOp = cast<YieldOp>(genericOp.getBody()->getTerminator());
+    for (const auto &outputOpOperand :
+         llvm::enumerate(genericOp.getOutputOperands())) {
+      Value result = genericOp.getResult(outputOpOperand.index());
+
+      // Check that the result tied to the out operand is dead.
+      if (!result.use_empty())
+        continue;
+
+      BlockArgument outputArg =
+          genericOp.getRegionOutputArgs()[outputOpOperand.index()];
+
+      // Check that the block argument has exactly one use, the cycle op.
+      if (!outputArg.hasOneUse())
+        continue;
+
+      // Check that the cycle op has exactly one use.
+      Operation *cycleOp = *outputArg.user_begin();
+      if (!cycleOp->hasOneUse())
+        continue;
+
+      // Check that the cycle's single user is a yield.
+      Operation *cycleUserOp = *cycleOp->user_begin();
+      if (!isa<YieldOp>(cycleUserOp))
+        continue;
+
+      // Check that the cycle is yielded at the out operand's own index;
+      // otherwise its data is used by another result.
+      if (cycleUserOp->getOperand(outputOpOperand.index()) !=
+          cycleOp->getResult(0))
+        continue;
+
+      // Directly replace the cycle with the block argument so that the
+      // deduplicate pattern can eliminate it along with the unused yield
+      // operand.
+      rewriter.replaceOp(cycleOp, outputArg);
+      rewriter.updateRootInPlace(genericOp, [] {});
+      return success();
+    }
+    return failure();
+  }
+};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results
-      .add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp>(
-          context);
+  results.add<DeduplicateAndRemoveDeadOperandsAndResults,
+              EraseIdentityGenericOp, RemoveUnusedCycleInGenericOp>(context);
 }
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
diff --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
@@ -286,3 +286,162 @@
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: outs(%[[ARG0]] :
 // CHECK: return %[[GENERIC]]
+
+// -----
+
+// Drop dead results with different tensors.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func.func @drop_dead_results_with_different_tensors(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %init0 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+  %0:4 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %arg0, %init0, %init0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32):
+      linalg.yield %b0, %b0, %b3, %b4 : f32, f32, f32, f32
+  } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK: func @drop_dead_results_with_different_tensors(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>)
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: outs(%[[ARG0]], %[[ARG0]] :
+// CHECK: return %[[GENERIC]]#0, %[[GENERIC]]#1
+
+// -----
+
+// Drop dead results with unused cycles.
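+// Results #2 and #3 are unused, and the only payload uses of %b3 and %b4 are
+// the adds feeding those dead results, so both outputs and their cycles can
+// be removed.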
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func.func @drop_dead_results_with_unused_cycles(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %init0 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+  %0:4 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %arg0, %init0, %init0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32):
+      %1 = arith.addf %b0, %b0 : f32
+      %2 = arith.addf %b0, %b3 : f32
+      %3 = arith.addf %b0, %b4 : f32
+      linalg.yield %1, %1, %2, %3 : f32, f32, f32, f32
+  } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK: func @drop_dead_results_with_unused_cycles(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>)
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: outs(%[[ARG0]], %[[ARG0]] :
+// CHECK: return %[[GENERIC]]#0, %[[GENERIC]]#1
+
+// -----
+
+// Drop results only when they are not used by others.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+func.func @drop_only_results_when_not_used_by_others(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %init0 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+  %0:3 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %init0, %init0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32):
+      linalg.yield %b2, %b1, %b3 : f32, f32, f32
+  } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0 : tensor<?x?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK: func @drop_only_results_when_not_used_by_others(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>)
+// CHECK: %[[INIT:.+]] = tensor.empty
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: outs(%[[ARG0]], %[[INIT]] :
+// CHECK: return %[[GENERIC]]#0
+
+// -----
+
+// Drop only the cycle when it is not used by others.
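+// Result #2 is dead and %b3 only feeds it, so output #2 and its cycle are
+// removed. Output #1 is kept even though its result is dead, because %b2
+// feeds the live result #0.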
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+func.func @drop_only_cycles_when_not_used_by_others(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %init0 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32>
+  %0:3 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %init0, %init0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32):
+      %1 = arith.addf %b1, %b2 : f32
+      %2 = arith.addf %b1, %b3 : f32
+      linalg.yield %1, %b1, %2 : f32, f32, f32
+  } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0 : tensor<?x?x?xf32>
+}
+
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+// CHECK: func @drop_only_cycles_when_not_used_by_others(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>)
+// CHECK: %[[INIT:.+]] = tensor.empty
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: outs(%[[ARG0]], %[[INIT]] :
+// CHECK: return %[[GENERIC]]#0