diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -48,7 +48,12 @@ /// Populate patterns for splitting a `LinalgOp` with multiple statements within /// its payload into multiple `GenericOp` that have a single statement. -void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns); +/// The option `removeDeadArgsAndResults` adds patterns to remove dead arguments +/// and results from the generated decomposed ops. This is default `true` since +/// the core decomposition patterns relies on these clean up patterns. It is set +/// to false only for testing purposes. +void populateDecomposeLinalgOpsPattern(RewritePatternSet &patterns, + bool removeDeadArgsAndResults = true); /// Populate patterns for vectorizing low-D convolution ops. This is a step in /// progressive lowering for convolution ops, it assume high-D convolution ops @@ -76,6 +81,10 @@ RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion); +/// Pattern to remove dead operands and results of `linalg.generic` operations. +/// This is effectively DCE for a linalg op. +void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns); + /// Function type to control generic op dimension collapsing. It is expected /// to return an array of `ReassociationIndices` representing dimensions that /// should be merged. diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -871,285 +871,10 @@ getDpsInputOperands(), getDpsInitOperands()); } -static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) { - if (!result.use_empty()) - return false; - // If out operand not used in payload, we can drop it. 
- OpOperand *outputOpOperand = - genericOp.getDpsInitOperand(result.getResultNumber()); - if (!genericOp.payloadUsesValueFromOperand(outputOpOperand)) - return true; - - // The out operand that is part of a payload can be dropped if - // these conditions are met: - // - Result from out operand is dead. - // - User of arg is yield. - // - outArg data is not being used by other outArgs. - - // Check block arg and cycle from out operand has a single use. - BlockArgument outputArg = - genericOp.getRegionOutputArgs()[result.getResultNumber()]; - if (!outputArg.hasOneUse()) - return false; - Operation *argUserOp = *outputArg.user_begin(); - - // Check argUser has no other use. - if (!argUserOp->use_empty()) - return false; - - // Check that argUser is a yield. - auto yieldOp = dyn_cast(argUserOp); - if (!yieldOp) - return false; - - // Check outArg data is not being used by other outArgs. - if (yieldOp.getOperand(result.getResultNumber()) != outputArg) - return false; - - return true; -} - LogicalResult GenericOp::verify() { return success(); } namespace { -struct DeduplicateAndRemoveDeadOperandsAndResults - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GenericOp genericOp, - PatternRewriter &rewriter) const override { - // Create a map from argument position in the original op to the argument - // position in the new op. If the argument is dropped it wont have an entry. - SmallVector droppedOpOperands; - - // Information needed to build the new op. - SmallVector newInputOperands, newOutputOperands; - SmallVector newIndexingMaps; - - // Gather information about duplicate input operands. - llvm::SmallDenseMap origInsToNewInsPos = - deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands, - newIndexingMaps); - - // Gather information about the dropped outputs. 
- llvm::SmallDenseMap origOutsToNewOutsPos = - deduplicateOutputOperands(genericOp, droppedOpOperands, - newOutputOperands, newIndexingMaps); - - // Check if there is any change to operands. - if (newInputOperands.size() + newOutputOperands.size() == - genericOp->getNumOperands()) - return failure(); - - // Create the new op with the body being empty. - Location loc = genericOp.getLoc(); - SmallVector newResultTypes; - for (Value v : newOutputOperands) - if (v.getType().isa()) - newResultTypes.push_back(v.getType()); - auto newOp = rewriter.create( - loc, newResultTypes, newInputOperands, newOutputOperands, - rewriter.getAffineMapArrayAttr(newIndexingMaps), - genericOp.getIteratorTypes(), genericOp.getDocAttr(), - genericOp.getLibraryCallAttr(), - [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) { - return; - }); - // Copy over unknown attributes. They might be load bearing for some flow. - ArrayRef odsAttrs = genericOp.getAttributeNames(); - for (NamedAttribute kv : genericOp->getAttrs()) - if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) - newOp->setAttr(kv.getName(), kv.getValue()); - - // Fix up the payload of the canonicalized operation. - populateOpPayload(genericOp, newOp, origInsToNewInsPos, - origOutsToNewOutsPos, rewriter); - - // Replace all live uses of the op. - SmallVector replacementsVals(genericOp->getNumResults(), nullptr); - for (const auto &result : llvm::enumerate(genericOp.getResults())) { - auto it = origOutsToNewOutsPos.find(result.index()); - if (it == origOutsToNewOutsPos.end()) - continue; - replacementsVals[result.index()] = newOp.getResult(it->second); - } - rewriter.replaceOp(genericOp, replacementsVals); - return success(); - } - -private: - // Deduplicate input operands, and return the - // - Mapping from operand position in the original op, to operand position in - // the canonicalized op. - // - The preserved input operands list (by reference). 
- llvm::SmallDenseMap - deduplicateInputOperands(GenericOp genericOp, - SmallVector &droppedOpOperands, - SmallVector &newInputOperands, - SmallVector &newIndexingMaps) const { - llvm::SmallDenseMap origToNewPos; - llvm::SmallDenseMap, unsigned> dedupedInputs; - for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) { - OpOperand *inputOpOperand = en.value(); - // Check if operand is dead and if dropping the indexing map makes the - // loops to shape computation invalid. - if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) { - // Add the current operands to the list of potentially droppable - // operands. If it cannot be dropped, this needs to be popped back. - droppedOpOperands.push_back(inputOpOperand); - if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) - continue; - droppedOpOperands.pop_back(); - } - - // Check if this operand is a duplicate. - AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand); - auto it = dedupedInputs.find( - std::make_pair(inputOpOperand->get(), indexingMap)); - if (it != dedupedInputs.end()) { - origToNewPos[en.index()] = it->second; - droppedOpOperands.push_back(inputOpOperand); - continue; - } - - // This is a preserved argument. - origToNewPos[en.index()] = newInputOperands.size(); - dedupedInputs[{inputOpOperand->get(), indexingMap}] = - newInputOperands.size(); - newInputOperands.push_back(inputOpOperand->get()); - newIndexingMaps.push_back(indexingMap); - } - return origToNewPos; - } - - // Deduplicate output operands, and return the - // - Mapping from operand position in the original op, to operand position in - // the canonicalized op. - // - The preserved output operands list (by reference). 
- llvm::SmallDenseMap - deduplicateOutputOperands(GenericOp genericOp, - SmallVector &droppedOpOperands, - SmallVector &newOutputOperands, - SmallVector &newIndexingMaps) const { - llvm::SmallDenseMap origToNewPos; - llvm::SmallDenseMap, unsigned> - dedupedOutpts; - // If the op doesnt have tensor semantics, keep all the outputs as - // preserved. - if (!genericOp.hasTensorSemantics()) { - for (const auto &en : llvm::enumerate(genericOp.getDpsInitOperands())) { - origToNewPos[en.index()] = newOutputOperands.size(); - newOutputOperands.push_back(en.value()->get()); - newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(en.value())); - } - return origToNewPos; - } - // Output argument can be dropped if the result has - // - no users, and - // - it is not used in the payload, and - // - the corresponding indexing maps are not needed for loop bound - // computation. - auto yieldOp = cast(genericOp.getBody()->getTerminator()); - for (const auto &outputOpOperand : - llvm::enumerate(genericOp.getDpsInitOperands())) { - OpResult result = genericOp.getTiedOpResult(outputOpOperand.value()); - AffineMap indexingMap = - genericOp.getMatchingIndexingMap(outputOpOperand.value()); - auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap, - yieldOp->getOperand(outputOpOperand.index())); - if (isResultValueDead(genericOp, result)) { - // Check if the opoperand can be dropped without affecting loop - // bound computation. Add the operand to the list of dropped op - // operand for checking. If it cannot be dropped, need to pop the - // value back. 
- droppedOpOperands.push_back(outputOpOperand.value()); - if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) { - continue; - } - droppedOpOperands.pop_back(); - } - - if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) { - // The out operand can also be dropped if it is computed redundantly - // by another result, the conditions for that are - // - The same operand is used as the out operand - // - The same indexing map is used - // - The same yield value is used. - auto it = dedupedOutpts.find(key); - if (it != dedupedOutpts.end()) { - origToNewPos[outputOpOperand.index()] = it->second; - droppedOpOperands.push_back(outputOpOperand.value()); - continue; - } - } - - origToNewPos[outputOpOperand.index()] = newOutputOperands.size(); - dedupedOutpts[key] = newOutputOperands.size(); - newOutputOperands.push_back(outputOpOperand.value()->get()); - newIndexingMaps.push_back( - genericOp.getMatchingIndexingMap(outputOpOperand.value())); - } - return origToNewPos; - } - - // Populate the body of the canonicalized operation. - void populateOpPayload( - GenericOp genericOp, GenericOp newOp, - const llvm::SmallDenseMap &origInsToNewInsPos, - const llvm::SmallDenseMap &origOutsToNewOutsPos, - PatternRewriter &rewriter) const { - // Merge the body of the original op with the new op. - Block *newOpBlock = &newOp.getRegion().front(); - assert(newOpBlock->empty() && "expected new op to have an empty payload"); - Block *origOpBlock = &genericOp.getRegion().front(); - SmallVector replacements(origOpBlock->getNumArguments(), nullptr); - - // Replace all arguments in the original op, with arguments from the - // canonicalized op. 
- auto updateReplacements = - [&](OpOperandVector &origOperands, OpOperandVector &newOperands, - const llvm::SmallDenseMap &map) { - for (const auto &origOperand : llvm::enumerate(origOperands)) { - auto it = map.find(origOperand.index()); - if (it == map.end()) - continue; - OpOperand *newOperand = newOperands[it->second]; - replacements[origOperand.value()->getOperandNumber()] = - newOpBlock->getArgument(newOperand->getOperandNumber()); - } - }; - - OpOperandVector origInputOperands = genericOp.getDpsInputOperands(); - OpOperandVector newInputOperands = newOp.getDpsInputOperands(); - updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos); - - OpOperandVector origOutputOperands = genericOp.getDpsInitOperands(); - OpOperandVector newOutputOperands = newOp.getDpsInitOperands(); - updateReplacements(origOutputOperands, newOutputOperands, - origOutsToNewOutsPos); - - // Drop the unused yield args. - if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) { - OpBuilder::InsertionGuard g(rewriter); - YieldOp origYieldOp = cast(origOpBlock->getTerminator()); - rewriter.setInsertionPoint(origYieldOp); - - SmallVector newYieldVals(newOp.getNumDpsInits(), nullptr); - for (const auto &yieldOpOperands : - llvm::enumerate(origYieldOp.getValues())) { - auto it = origOutsToNewOutsPos.find(yieldOpOperands.index()); - if (it == origOutsToNewOutsPos.end()) - continue; - newYieldVals[it->second] = yieldOpOperands.value(); - } - rewriter.replaceOpWithNewOp(origYieldOp, newYieldVals); - } - - rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements); - } -}; - /// Remove generic operations (on tensors) that are just copying /// the values from inputs to the results. Requirements are /// 1) All iterator types are parallel @@ -1227,74 +952,11 @@ } }; -/// Remove unused cycles. -/// We can remove unused cycle within a payload of generic region -/// if these conditions are met: -/// - Result from out operand is dead. 
-/// - Block arg from out operand has a single use in the %cycle -/// instruction. -/// - Cycle has a single use and it is in yield. -struct RemoveUnusedCycleInGenericOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GenericOp genericOp, - PatternRewriter &rewriter) const override { - - // If the op doesnt have tensor semantics, preserve the outputs as is. - if (!genericOp.hasTensorSemantics()) - return failure(); - - bool hasRemovedCycles = false; - // Iterate over output operands and remove any unused cycles. - for (const auto &outputOpOperand : - llvm::enumerate(genericOp.getDpsInitOperands())) { - - // Check that result from out operand is dead. - Value result = genericOp.getResult(outputOpOperand.index()); - if (!result.use_empty()) - continue; - - // Check that outputArg has one use in cycle. - BlockArgument outputArg = - genericOp.getRegionOutputArgs()[outputOpOperand.index()]; - if (!outputArg.hasOneUse()) - continue; - - // Check cycle has at most one use. - Operation *cycleOp = *outputArg.user_begin(); - if (!cycleOp->hasOneUse()) - continue; - - // Check that the cycleUser is a yield. - Operation *cycleUserOp = *cycleOp->user_begin(); - if (!isa(cycleUserOp)) - continue; - - // Check that argIndex matches yieldIndex, else data is being used. - if (cycleUserOp->getOperand(outputOpOperand.index()) != - cycleOp->getResult(0)) - continue; - - // Directly replace the cycle with the blockArg such that - // Deduplicate pattern can eliminate it along with unused yield. 
- rewriter.replaceOp(cycleOp, outputArg); - rewriter.updateRootInPlace(genericOp, [] {}); - hasRemovedCycles = true; - } - - if (hasRemovedCycles) { - return success(); - } - - return failure(); - } -}; } // namespace void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } LogicalResult GenericOp::fold(ArrayRef, diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -8,6 +8,7 @@ DropUnitDims.cpp ElementwiseOpFusion.cpp ElementwiseToLinalg.cpp + EraseUnusedOperandsAndResults.cpp FusePadOpWithLinalgProducer.cpp Fusion.cpp FusionOnTensors.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp @@ -376,6 +376,9 @@ } void mlir::linalg::populateDecomposeLinalgOpsPattern( - RewritePatternSet &patterns) { + RewritePatternSet &patterns, bool removeDeadArgsAndResults) { patterns.insert(patterns.getContext()); + // Add the patterns to clean up the dead operands and results. + if (removeDeadArgsAndResults) + populateEraseUnusedOperandsAndResultsPatterns(patterns); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1780,6 +1780,8 @@ patterns.add(context, controlElementwiseOpsFusion); patterns.add(context); + // Add the patterns that clean up dead operands and results. 
+ populateEraseUnusedOperandsAndResultsPatterns(patterns); } void mlir::linalg::populateCollapseDimensions( diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp @@ -0,0 +1,362 @@ +//===- EraseUnusedOperandsAndResults.cpp ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" + +#include "mlir/Dialect/Linalg/IR/Linalg.h" + +using namespace mlir; +using namespace mlir::linalg; + +/// Return `true` if the `result` of an operation `genericOp` is dead. +static bool isResultValueDead(linalg::GenericOp genericOp, OpResult result) { + if (!result.use_empty()) + return false; + // If out operand not used in payload, we can drop it. + OpOperand *outputOpOperand = + genericOp.getDpsInitOperand(result.getResultNumber()); + if (!genericOp.payloadUsesValueFromOperand(outputOpOperand)) + return true; + + // The out operand that is part of a payload can be dropped if + // these conditions are met: + // - Result from out operand is dead. + // - User of arg is yield. + // - outArg data is not being used by other outArgs. + + // Check block arg and cycle from out operand has a single use. + BlockArgument outputArg = + genericOp.getRegionOutputArgs()[result.getResultNumber()]; + if (!outputArg.hasOneUse()) + return false; + Operation *argUserOp = *outputArg.user_begin(); + + // Check argUser has no other use. + if (!argUserOp->use_empty()) + return false; + + // Check that argUser is a yield. 
+ auto yieldOp = dyn_cast(argUserOp); + if (!yieldOp) + return false; + + // Check outArg data is not being used by other outArgs. + if (yieldOp.getOperand(result.getResultNumber()) != outputArg) + return false; + + return true; +} + +namespace { + +struct DeduplicateAndRemoveDeadOperandsAndResults + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + // Create a map from argument position in the original op to the argument + // position in the new op. If the argument is dropped it wont have an entry. + SmallVector droppedOpOperands; + + // Information needed to build the new op. + SmallVector newInputOperands, newOutputOperands; + SmallVector newIndexingMaps; + + // Gather information about duplicate input operands. + llvm::SmallDenseMap origInsToNewInsPos = + deduplicateInputOperands(genericOp, droppedOpOperands, newInputOperands, + newIndexingMaps); + + // Gather information about the dropped outputs. + llvm::SmallDenseMap origOutsToNewOutsPos = + deduplicateOutputOperands(genericOp, droppedOpOperands, + newOutputOperands, newIndexingMaps); + + // Check if there is any change to operands. + if (newInputOperands.size() + newOutputOperands.size() == + genericOp->getNumOperands()) + return failure(); + + // Create the new op with the body being empty. + Location loc = genericOp.getLoc(); + SmallVector newResultTypes; + for (Value v : newOutputOperands) + if (v.getType().isa()) + newResultTypes.push_back(v.getType()); + auto newOp = rewriter.create( + loc, newResultTypes, newInputOperands, newOutputOperands, + rewriter.getAffineMapArrayAttr(newIndexingMaps), + genericOp.getIteratorTypes(), genericOp.getDocAttr(), + genericOp.getLibraryCallAttr(), + [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) { + return; + }); + // Copy over unknown attributes. They might be load bearing for some flow. 
+ ArrayRef odsAttrs = genericOp.getAttributeNames(); + for (NamedAttribute kv : genericOp->getAttrs()) + if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) + newOp->setAttr(kv.getName(), kv.getValue()); + + // Fix up the payload of the canonicalized operation. + populateOpPayload(genericOp, newOp, origInsToNewInsPos, + origOutsToNewOutsPos, rewriter); + + // Replace all live uses of the op. + SmallVector replacementsVals(genericOp->getNumResults(), nullptr); + for (const auto &result : llvm::enumerate(genericOp.getResults())) { + auto it = origOutsToNewOutsPos.find(result.index()); + if (it == origOutsToNewOutsPos.end()) + continue; + replacementsVals[result.index()] = newOp.getResult(it->second); + } + rewriter.replaceOp(genericOp, replacementsVals); + return success(); + } + +private: + // Deduplicate input operands, and return the + // - Mapping from operand position in the original op, to operand position in + // the canonicalized op. + // - The preserved input operands list (by reference). + llvm::SmallDenseMap + deduplicateInputOperands(GenericOp genericOp, + SmallVector &droppedOpOperands, + SmallVector &newInputOperands, + SmallVector &newIndexingMaps) const { + llvm::SmallDenseMap origToNewPos; + llvm::SmallDenseMap, unsigned> dedupedInputs; + for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) { + OpOperand *inputOpOperand = en.value(); + // Check if operand is dead and if dropping the indexing map makes the + // loops to shape computation invalid. + if (!genericOp.payloadUsesValueFromOperand(inputOpOperand)) { + // Add the current operands to the list of potentially droppable + // operands. If it cannot be dropped, this needs to be popped back. + droppedOpOperands.push_back(inputOpOperand); + if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) + continue; + droppedOpOperands.pop_back(); + } + + // Check if this operand is a duplicate. 
+ AffineMap indexingMap = genericOp.getMatchingIndexingMap(inputOpOperand); + auto it = dedupedInputs.find( + std::make_pair(inputOpOperand->get(), indexingMap)); + if (it != dedupedInputs.end()) { + origToNewPos[en.index()] = it->second; + droppedOpOperands.push_back(inputOpOperand); + continue; + } + + // This is a preserved argument. + origToNewPos[en.index()] = newInputOperands.size(); + dedupedInputs[{inputOpOperand->get(), indexingMap}] = + newInputOperands.size(); + newInputOperands.push_back(inputOpOperand->get()); + newIndexingMaps.push_back(indexingMap); + } + return origToNewPos; + } + + // Deduplicate output operands, and return the + // - Mapping from operand position in the original op, to operand position in + // the canonicalized op. + // - The preserved output operands list (by reference). + llvm::SmallDenseMap + deduplicateOutputOperands(GenericOp genericOp, + SmallVector &droppedOpOperands, + SmallVector &newOutputOperands, + SmallVector &newIndexingMaps) const { + llvm::SmallDenseMap origToNewPos; + llvm::SmallDenseMap, unsigned> + dedupedOutpts; + // If the op doesnt have tensor semantics, keep all the outputs as + // preserved. + if (!genericOp.hasTensorSemantics()) { + for (const auto &en : llvm::enumerate(genericOp.getDpsInitOperands())) { + origToNewPos[en.index()] = newOutputOperands.size(); + newOutputOperands.push_back(en.value()->get()); + newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(en.value())); + } + return origToNewPos; + } + // Output argument can be dropped if the result has + // - no users, and + // - it is not used in the payload, and + // - the corresponding indexing maps are not needed for loop bound + // computation. 
+ auto yieldOp = cast(genericOp.getBody()->getTerminator()); + for (const auto &outputOpOperand : + llvm::enumerate(genericOp.getDpsInitOperands())) { + OpResult result = genericOp.getTiedOpResult(outputOpOperand.value()); + AffineMap indexingMap = + genericOp.getMatchingIndexingMap(outputOpOperand.value()); + auto key = std::make_tuple(outputOpOperand.value()->get(), indexingMap, + yieldOp->getOperand(outputOpOperand.index())); + if (isResultValueDead(genericOp, result)) { + // Check if the opoperand can be dropped without affecting loop + // bound computation. Add the operand to the list of dropped op + // operand for checking. If it cannot be dropped, need to pop the + // value back. + droppedOpOperands.push_back(outputOpOperand.value()); + if (genericOp.canOpOperandsBeDropped(droppedOpOperands)) { + continue; + } + droppedOpOperands.pop_back(); + } + + if (!genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) { + // The out operand can also be dropped if it is computed redundantly + // by another result, the conditions for that are + // - The same operand is used as the out operand + // - The same indexing map is used + // - The same yield value is used. + auto it = dedupedOutpts.find(key); + if (it != dedupedOutpts.end()) { + origToNewPos[outputOpOperand.index()] = it->second; + droppedOpOperands.push_back(outputOpOperand.value()); + continue; + } + } + + origToNewPos[outputOpOperand.index()] = newOutputOperands.size(); + dedupedOutpts[key] = newOutputOperands.size(); + newOutputOperands.push_back(outputOpOperand.value()->get()); + newIndexingMaps.push_back( + genericOp.getMatchingIndexingMap(outputOpOperand.value())); + } + return origToNewPos; + } + + // Populate the body of the canonicalized operation. 
+ void populateOpPayload( + GenericOp genericOp, GenericOp newOp, + const llvm::SmallDenseMap &origInsToNewInsPos, + const llvm::SmallDenseMap &origOutsToNewOutsPos, + PatternRewriter &rewriter) const { + // Merge the body of the original op with the new op. + Block *newOpBlock = &newOp.getRegion().front(); + assert(newOpBlock->empty() && "expected new op to have an empty payload"); + Block *origOpBlock = &genericOp.getRegion().front(); + SmallVector replacements(origOpBlock->getNumArguments(), nullptr); + + // Replace all arguments in the original op, with arguments from the + // canonicalized op. + auto updateReplacements = + [&](OpOperandVector &origOperands, OpOperandVector &newOperands, + const llvm::SmallDenseMap &map) { + for (const auto &origOperand : llvm::enumerate(origOperands)) { + auto it = map.find(origOperand.index()); + if (it == map.end()) + continue; + OpOperand *newOperand = newOperands[it->second]; + replacements[origOperand.value()->getOperandNumber()] = + newOpBlock->getArgument(newOperand->getOperandNumber()); + } + }; + + OpOperandVector origInputOperands = genericOp.getDpsInputOperands(); + OpOperandVector newInputOperands = newOp.getDpsInputOperands(); + updateReplacements(origInputOperands, newInputOperands, origInsToNewInsPos); + + OpOperandVector origOutputOperands = genericOp.getDpsInitOperands(); + OpOperandVector newOutputOperands = newOp.getDpsInitOperands(); + updateReplacements(origOutputOperands, newOutputOperands, + origOutsToNewOutsPos); + + // Drop the unused yield args. 
+ if (newOp.getNumDpsInits() != genericOp.getNumDpsInits()) { + OpBuilder::InsertionGuard g(rewriter); + YieldOp origYieldOp = cast(origOpBlock->getTerminator()); + rewriter.setInsertionPoint(origYieldOp); + + SmallVector newYieldVals(newOp.getNumDpsInits(), nullptr); + for (const auto &yieldOpOperands : + llvm::enumerate(origYieldOp.getValues())) { + auto it = origOutsToNewOutsPos.find(yieldOpOperands.index()); + if (it == origOutsToNewOutsPos.end()) + continue; + newYieldVals[it->second] = yieldOpOperands.value(); + } + rewriter.replaceOpWithNewOp(origYieldOp, newYieldVals); + } + + rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements); + } +}; + +/// Remove unused cycles. +/// We can remove unused cycle within a payload of generic region +/// if these conditions are met: +/// - Result from out operand is dead. +/// - Block arg from out operand has a single use in the %cycle +/// instruction. +/// - Cycle has a single use and it is in yield. +struct RemoveUnusedCycleInGenericOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + + // If the op doesnt have tensor semantics, preserve the outputs as is. + if (!genericOp.hasTensorSemantics()) + return failure(); + + bool hasRemovedCycles = false; + // Iterate over output operands and remove any unused cycles. + for (const auto &outputOpOperand : + llvm::enumerate(genericOp.getDpsInitOperands())) { + + // Check that result from out operand is dead. + Value result = genericOp.getResult(outputOpOperand.index()); + if (!result.use_empty()) + continue; + + // Check that outputArg has one use in cycle. + BlockArgument outputArg = + genericOp.getRegionOutputArgs()[outputOpOperand.index()]; + if (!outputArg.hasOneUse()) + continue; + + // Check cycle has at most one use. 
+      Operation *cycleOp = *outputArg.user_begin();
+      if (!cycleOp->hasOneUse())
+        continue;
+
+      // Check that the cycleUser is a yield.
+      Operation *cycleUserOp = *cycleOp->user_begin();
+      if (!isa<YieldOp>(cycleUserOp))
+        continue;
+
+      // Check that argIndex matches yieldIndex, else data is being used.
+      if (cycleUserOp->getOperand(outputOpOperand.index()) !=
+          cycleOp->getResult(0))
+        continue;
+
+      // Directly replace the cycle with the blockArg such that
+      // Deduplicate pattern can eliminate it along with unused yield.
+      rewriter.replaceOp(cycleOp, outputArg);
+      rewriter.updateRootInPlace(genericOp, [] {});
+      hasRemovedCycles = true;
+    }
+
+    if (hasRemovedCycles) {
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults,
+                  RemoveUnusedCycleInGenericOp>(patterns.getContext());
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -272,54 +272,6 @@
   return
 }
 
-// -----
-
-// CHECK-LABEL: func @remove_deadargs_generic_basic
-// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
-// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
-// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
-#map0 = affine_map<(d0) -> (d0)>
-func.func @remove_deadargs_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
-  %c0 = arith.constant 0 : index
-  %cst = arith.constant 7.0 : f32
-  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
-  %1 = tensor.empty(%0) : tensor<?xf32>
-  %2 = tensor.empty(%0) : tensor<?xf32>
-  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %1 : tensor<?xf32>, tensor<?xf32>) outs (%2:tensor<?xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
-    %4 = arith.addf %arg1, %cst : f32
-    linalg.yield %4 : f32
-  } -> tensor<?xf32>
-  return %3 : tensor<?xf32>
-}
-
-// -----
-
-// CHECK-LABEL: func @remove_deadargs_generic_mixedaccess
-// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
-// CHECK-NOT: ins
-// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
-#map0 = affine_map<(d0, d1) -> (d0, d1)>
-#map1 = affine_map<(d0, d1) -> (d1, d0)>
-func.func @remove_deadargs_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
-  %c0 = arith.constant 0 : index
-  %c1 = arith.constant 0 : index
-  %cst1 = arith.constant 7.0 : f32
-  %cst2 = arith.constant 6.0 : f32
-  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
-  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
-  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
-  %3 = tensor.empty(%1, %0) : tensor<?x?xf32>
-  %4 = tensor.empty(%0, %1) : tensor<?x?xf32>
-  %5 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%2, %3 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%4:tensor<?x?xf32>) {
-  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
-    %6 = arith.divf %cst1, %cst2 : f32
-    linalg.yield %6 : f32
-  } -> tensor<?x?xf32>
-  return %5 : tensor<?x?xf32>
-}
-
 // -----
 
 // CHECK-LABEL: func @fold_fill_reshape()
 func.func @fold_fill_reshape() -> tensor<6x4xf32> {
diff --git a/mlir/test/Dialect/Linalg/decompose-ops.mlir b/mlir/test/Dialect/Linalg/decompose-ops.mlir
--- a/mlir/test/Dialect/Linalg/decompose-ops.mlir
+++ b/mlir/test/Dialect/Linalg/decompose-ops.mlir
@@ -1,5 +1,5 @@
 // RUN: mlir-opt -test-linalg-decompose-ops -cse -split-input-file %s | FileCheck %s
-// RUN: mlir-opt -test-linalg-decompose-ops -cse -canonicalize -split-input-file %s | FileCheck %s --check-prefix=CANONICALIZECHECK
+// RUN: mlir-opt -test-linalg-decompose-ops=remove-dead-args-and-results -cse -split-input-file %s | FileCheck %s --check-prefix=CANONICALIZECHECK
 
 func.func @simple_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>, %arg2 : tensor<?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
diff --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir
rename from mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
rename to mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
+++ b/mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir
@@ -1,4 +1,52 @@
-// RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-erase-unused-operands-and-results | FileCheck %s
+
+// CHECK-LABEL: func @remove_deadargs_generic_basic
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] : tensor<?xf32>)
+// CHECK-SAME: outs({{.*}} : tensor<?xf32>) {
+#map0 = affine_map<(d0) -> (d0)>
+func.func @remove_deadargs_generic_basic(%arg0: tensor<?xf32>) -> (tensor<?xf32>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 7.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = tensor.empty(%0) : tensor<?xf32>
+  %2 = tensor.empty(%0) : tensor<?xf32>
+  %3 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %1 : tensor<?xf32>, tensor<?xf32>) outs (%2:tensor<?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %4 = arith.addf %arg1, %cst : f32
+    linalg.yield %4 : f32
+  } -> tensor<?xf32>
+  return %3 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @remove_deadargs_generic_mixedaccess
+// CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+// CHECK-NOT: ins
+// CHECK-SAME: outs({{.*}} : tensor<?x?xf32>) {
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1, d0)>
+func.func @remove_deadargs_generic_mixedaccess(%arg0: tensor<?x?xf32>) -> (tensor<?x?xf32>) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst1 = arith.constant 7.0 : f32
+  %cst2 = arith.constant 6.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %3 = tensor.empty(%1, %0) : tensor<?x?xf32>
+  %4 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %5 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%2, %3 : tensor<?x?xf32>, tensor<?x?xf32>) outs (%4:tensor<?x?xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32, %arg3: f32):
+    %6 = arith.divf %cst1, %cst2 : f32
+    linalg.yield %6 : f32
+  } -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+
+// -----
 
 // Test case: Most basic case. Adding a vector to itself.
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-options.mlir
@@ -1,5 +1,4 @@
 // RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-generic-ops -split-input-file -canonicalize | FileCheck %s --check-prefix=CANONICALIZE
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #binary2Dpointwise = {
@@ -50,6 +49,7 @@
   } -> tensor<?x?xf32>
   return %6 : tensor<?x?xf32>
 }
+
 // CHECK-LABEL: func @test_fusion_limit
 // CHECK-SAME: %[[ARG0:[a-zA-z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME: %[[ARG1:[a-zA-z0-9_]+]]: tensor<?x?xf32>
@@ -59,17 +59,5 @@
 // CHECK-SAME: %[[ARG2:[a-zA-z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME: %[[ARG3:[a-zA-z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME: %[[ARG4:[a-zA-z0-9_]+]]: tensor<?x?xf32>
 // CHECK-SAME: %[[ARG5:[a-zA-z0-9_]+]]: tensor<?x?xf32>
 // CHECK: %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
 // CHECK: %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
-// CHECK: %[[OP3:.+]]:2 = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
-// CHECK: return %[[OP3]]#1
-
-// CANONICALIZE-LABEL: func @test_fusion_limit
-// CANONICALIZE-SAME: %[[ARG0:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-// CANONICALIZE-SAME: %[[ARG1:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-// CANONICALIZE-SAME: %[[ARG2:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-// CANONICALIZE-SAME: %[[ARG3:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-// CANONICALIZE-SAME: %[[ARG4:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-// CANONICALIZE-SAME: %[[ARG5:[a-zA-z0-9_]+]]: tensor<?x?xf32>
-// CANONICALIZE: %[[OP1:.+]] = linalg.generic {{.+}} ins(%[[ARG2]], %[[ARG3]]
-// CANONICALIZE: %[[OP2:.+]] = linalg.generic {{.+}} ins(%[[ARG4]], %[[ARG5]]
-// CANONICALIZE: %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
-// CANONICALIZE: return %[[OP3]]
+// CHECK: %[[OP3:.+]] = linalg.generic {{.+}} ins(%[[ARG0]], %[[ARG1]], %[[OP1]], %[[OP2]]
+// CHECK: return %[[OP3]]
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDecomposeOps.cpp
@@ -22,8 +22,8 @@
     : public PassWrapper<TestLinalgDecomposeOps, OperationPass<>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgDecomposeOps)
-  TestLinalgDecomposeOps() = default;
-  TestLinalgDecomposeOps(const TestLinalgDecomposeOps &pass) = default;
+  TestLinalgDecomposeOps() = default;
+  TestLinalgDecomposeOps(const TestLinalgDecomposeOps &pass) : PassWrapper(pass) {}
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert<AffineDialect, linalg::LinalgDialect>();
   }
@@ -32,10 +32,16 @@
     return "Test Linalg decomposition patterns";
   }
 
+  Option<bool> removeDeadArgsAndResults{
+      *this, "remove-dead-args-and-results",
+      llvm::cl::desc("Test patterns to erase unused operands and results"),
+      llvm::cl::init(false)};
+
   void runOnOperation() override {
     MLIRContext *context = &this->getContext();
     RewritePatternSet decompositionPatterns(context);
-    linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns);
+    linalg::populateDecomposeLinalgOpsPattern(decompositionPatterns,
+                                              removeDeadArgsAndResults);
     if (failed(applyPatternsAndFoldGreedily(
             getOperation(), std::move(decompositionPatterns)))) {
       return signalPassFailure();
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -109,6 +109,10 @@
       llvm::cl::desc(
           "Test patterns to swap tensor.extract_slice(linalg.fill())"),
       llvm::cl::init(false)};
+  Option<bool> testEraseUnusedOperandsAndResults{
+      *this, "test-erase-unused-operands-and-results",
+      llvm::cl::desc("Test patterns to erase unused operands and results"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -175,6 +179,12 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyEraseUnusedOperandsAndResultsPatterns(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateEraseUnusedOperandsAndResultsPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -193,6 +203,8 @@
     return applyBubbleUpExtractSliceOpPattern(getOperation());
   if (testSwapExtractSliceWithFill)
     return applySwapExtractSliceWithFillPattern(getOperation());
+  if (testEraseUnusedOperandsAndResults)
+    return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
 }
 
 namespace mlir {