diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -658,6 +658,7 @@
 
   let verifier = [{ return ::verify(*this); }];
 
+  let hasCanonicalizer = 1;
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -671,6 +671,146 @@
 
 static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); }
 
+namespace {
+// Deduplicate redundant args of a linalg generic op.
+// An arg is redundant if it has the same Value and indexing map as another.
+struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    // Associate each input to an equivalent "canonical" input that has the same
+    // Value and indexing map.
+    //
+    // In the non-duplicate case, input `i` will have canonical input `i`. But
+    // in the case of duplicated inputs, the canonical input could be some other
+    // input `< i`. That is, a later input will have some earlier input as its
+    // canonical input.
+    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
+    // For later remapping tasks like deduplicating payload block arguments,
+    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
+    // convenient.
+    SmallVector<unsigned> canonicalInputIndices;
+    for (OpOperand *opOperand : genericOp.getInputOperands()) {
+      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+      // STL-like maps have a convenient behavior for our use case here. In the
+      // case of duplicate keys, the insertion is rejected, and the returned
+      // iterator gives access to the value already in the map.
+      auto pair = canonicalInput.insert(
+          {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
+      canonicalInputIndices.push_back(pair.first->second);
+    }
+
+    // If there are no duplicate args, then bail out.
+    if (canonicalInput.size() == genericOp.getNumInputs())
+      return failure();
+
+    // The operands for the newly canonicalized op.
+    SmallVector<Value> newInputOperands;
+    for (OpOperand *opOperand : genericOp.getInputOperands())
+      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
+          opOperand->getOperandNumber())
+        newInputOperands.push_back(opOperand->get());
+
+    // Repair the indexing maps by filtering out the ones that have been
+    // eliminated.
+    SmallVector<AffineMap> newIndexingMaps;
+    for (OpOperand *opOperand : genericOp.getInputOperands())
+      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
+          opOperand->getOperandNumber())
+        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+    for (OpOperand *opOperand : genericOp.getOutputOperands())
+      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
+
+    // Clone the old op with new operands.
+    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+    auto newOp = rewriter.create<GenericOp>(
+        genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
+        outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        genericOp.iterator_types(), genericOp.docAttr(),
+        genericOp.library_callAttr());
+    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
+                                newOp.region().begin());
+
+    // Repair the payload entry block by RAUW'ing redundant arguments and
+    // erasing them.
+    Block &payload = newOp.region().front();
+    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
+    for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
+      // Iterate in reverse, so that we erase later args first, preventing the
+      // argument list from shifting unexpectedly and invalidating all our
+      // indices.
+      unsigned operandNumber = opOperand->getOperandNumber();
+      if (canonicalInputIndices[operandNumber] == operandNumber)
+        continue;
+      payload.getArgument(operandNumber)
+          .replaceAllUsesWith(
+              payload.getArgument(canonicalInputIndices[operandNumber]));
+      payload.eraseArgument(operandNumber);
+    }
+
+    rewriter.replaceOp(genericOp, newOp->getResults());
+    return success();
+  }
+};
+
+/// Remove generic operations (on tensors) that are just copying
+/// the values from inputs to the results. Requirements are
+/// 1) All indexing maps are identity maps.
+/// 2) The body contains just a yield operation with the yielded values being
+///    the arguments corresponding to the operands.
+/// 3) Each forwarded operand has exactly the type of the result it replaces;
+///    an input may be less static than the result (e.g. `tensor<?xf32>` vs
+///    `tensor<4xf32>`), and forwarding it would produce invalid IR.
+struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    // Check all indexing maps are identity.
+    if (llvm::any_of(genericOp.getIndexingMaps(),
+                     [](AffineMap map) { return !map.isIdentity(); }))
+      return failure();
+
+    // Check that the body of the linalg operation is just a linalg.yield
+    // operation.
+    Block &body = genericOp.region().front();
+    if (!llvm::hasSingleElement(body))
+      return failure();
+    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+    if (!yieldOp)
+      return failure();
+    // Keeps the per-result type check below in bounds.
+    if (yieldOp.values().size() != genericOp->getNumResults())
+      return failure();
+
+    // Get the argument number of the returned values. That is the operand
+    // number to use for replacing uses of this operation.
+    SmallVector<Value> returnedArgs;
+    for (auto en : llvm::enumerate(yieldOp.values())) {
+      auto yieldArg = en.value().dyn_cast<BlockArgument>();
+      if (!yieldArg || yieldArg.getOwner() != &body)
+        return failure();
+      unsigned argumentNumber = yieldArg.getArgNumber();
+      Value returnedArg = genericOp->getOperand(argumentNumber);
+      // Replacing a result with a value of a different type is invalid.
+      if (returnedArg.getType() != genericOp->getResult(en.index()).getType())
+        return failure();
+      returnedArgs.push_back(returnedArg);
+    }
+    rewriter.replaceOp(genericOp, returnedArgs);
+    return success();
+  }
+};
+} // namespace
+
+void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                            MLIRContext *context) {
+  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
+}
+
 //===----------------------------------------------------------------------===//
 // InitTensorOp
 //===----------------------------------------------------------------------===//
@@ -2539,143 +2679,6 @@
 };
 } // namespace
 
-namespace {
-// Deduplicate redundant args of a linalg op.
-// An arg is redundant if it has the same Value and indexing map as another.
-struct DeduplicateInputs : public OpInterfaceRewritePattern<LinalgOp> {
-  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
-  LogicalResult matchAndRewrite(LinalgOp op,
-                                PatternRewriter &rewriter) const override {
-    // This pattern reduces the number of arguments of an op, which breaks
-    // the invariants of semantically charged named ops.
-    if (!isa<GenericOp>(op))
-      return failure();
-
-    // Associate each input to an equivalent "canonical" input that has the same
-    // Value and indexing map.
-    //
-    // In the non-duplicate case, input `i` will have canonical input `i`. But
-    // in the case of duplicated inputs, the canonical input could be some other
-    // input `< i`. That is, a later input will have some earlier input as its
-    // canonical input.
-    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
-    // For later remapping tasks like deduplicating payload block arguments,
-    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
-    // convenient.
-    SmallVector<unsigned> canonicalInputIndices;
-    for (OpOperand *opOperand : op.getInputOperands()) {
-      AffineMap indexingMap = op.getTiedIndexingMap(opOperand);
-      // STL-like maps have a convenient behavior for our use case here. In the
-      // case of duplicate keys, the insertion is rejected, and the returned
-      // iterator gives access to the value already in the map.
-      auto pair = canonicalInput.insert(
-          {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
-      canonicalInputIndices.push_back(pair.first->second);
-    }
-
-    // If there are no duplicate args, then bail out.
-    if (canonicalInput.size() == op.getNumInputs())
-      return failure();
-
-    // The operands for the newly canonicalized op.
-    SmallVector<Value> newOperands;
-    for (OpOperand *opOperand : op.getInputOperands())
-      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
-          opOperand->getOperandNumber())
-        newOperands.push_back(opOperand->get());
-    SmallVector<Value> outputOperands = op.getOutputOperands();
-    llvm::append_range(newOperands, outputOperands);
-
-    // Repair the indexing maps by filtering out the ones that have been
-    // eliminated.
-    SmallVector<AffineMap> newIndexingMaps;
-    for (OpOperand *opOperand : op.getInputOperands())
-      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
-          opOperand->getOperandNumber())
-        newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
-    for (OpOperand *opOperand : op.getOutputOperands())
-      newIndexingMaps.push_back(op.getTiedIndexingMap(opOperand));
-
-    // Clone the old op with new operands.
-    Operation *newOp =
-        op.clone(rewriter, op->getLoc(), op->getResultTypes(), newOperands);
-    auto newLinalgOp = cast<LinalgOp>(newOp);
-    newOp->setAttr("indexing_maps",
-                   rewriter.getAffineMapArrayAttr(newIndexingMaps));
-
-    // Set the number of inputs to the new value. The `clone` call above kept
-    // the value from the original op.
-    newLinalgOp.setNumInputs(canonicalInput.size());
-
-    // Repair the payload entry block by RAUW'ing redundant arguments and
-    // erasing them.
-    Block &payload = newOp->getRegion(0).front();
-    SmallVector<OpOperand *> inputOperands = op.getInputOperands();
-    for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
-      // Iterate in reverse, so that we erase later args first, preventing the
-      // argument list from shifting unexpectedly and invalidating all our
-      // indices.
-      unsigned operandNumber = opOperand->getOperandNumber();
-      if (canonicalInputIndices[operandNumber] == operandNumber)
-        continue;
-      payload.getArgument(operandNumber)
-          .replaceAllUsesWith(
-              payload.getArgument(canonicalInputIndices[operandNumber]));
-      payload.eraseArgument(operandNumber);
-    }
-
-    rewriter.replaceOp(op, newOp->getResults());
-    return success();
-  }
-};
-
-/// Remove generic operations (on tensors) that are just copying
-/// the values from inputs to the results. Requirements are
-/// 1) All iterator types are parallel
-/// 2) The body contains just a yield operation with the yielded values being
-///    the arguments corresponding to the operands.
-struct RemoveIdentityLinalgOps : public OpInterfaceRewritePattern<LinalgOp> {
-  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;
-
-  LogicalResult matchAndRewrite(LinalgOp op,
-                                PatternRewriter &rewriter) const override {
-    if (!isa<GenericOp>(op))
-      return failure();
-    if (!op.hasTensorSemantics())
-      return failure();
-    // Check all indexing maps are identity.
-    if (llvm::any_of(op.getIndexingMaps(),
-                     [](AffineMap map) { return !map.isIdentity(); }))
-      return failure();
-
-    // Check that the body of the linalg operation is just a linalg.yield
-    // operation.
-    Block &body = op->getRegion(0).front();
-    if (!llvm::hasSingleElement(body))
-      return failure();
-    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
-    if (!yieldOp)
-      return failure();
-
-    // Get the argument number of the returned values. That is the operand
-    // number to use for replacing uses of this operation.
-    SmallVector<Value> returnedArgs;
-    for (Value yieldVal : yieldOp.values()) {
-      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
-      if (!yieldArg || yieldArg.getOwner() != &body)
-        return failure();
-      unsigned argumentNumber = yieldArg.getArgNumber();
-      returnedArgs.push_back(op->getOperand(argumentNumber));
-    }
-    if (returnedArgs.size() != op.getOperation()->getNumResults())
-      return failure();
-    rewriter.replaceOp(op, returnedArgs);
-    return success();
-  }
-};
-} // namespace
-
 #define LINALGOP_FOLDERS(XXX)                                                  \
   LogicalResult XXX::fold(ArrayRef<Attribute>,                                 \
                           SmallVectorImpl<OpFoldResult> &) {                   \
@@ -2699,6 +2702,5 @@
 
 void LinalgDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
-  results.add<DeduplicateInputs, EraseDeadLinalgOp, FoldTensorCastOp,
-              RemoveIdentityLinalgOps>(getContext());
+  results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
 }