diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -27,6 +27,10 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
+//===---------------------------------------------------------------------===//
+// Methods and patterns that fuse elementwise `linalg.generic` operations.
+//===---------------------------------------------------------------------===//
+
 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
 /// the `producer` to use in the fused operation given the indexing map of the
 /// result of the producer in the consumer.
@@ -345,6 +349,58 @@
   return SmallVector<Value>(fusedOp->getResults());
 }
 
+static Optional<SmallVector<Value>>
+fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
+                   GenericOp producer,
+                   const ControlElementwiseOpsFusionFn &controlFn) {
+  if (producer->getNumResults() != 1)
+    return llvm::None;
+
+  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
+                                rewriter);
+}
+
+namespace {
+/// Pattern to fuse a generic op with the producer of its operands.
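+///
+/// For example (an illustrative sketch; types and payload ops are assumed,
+/// not taken from this patch), the pair of elementwise ops
+///
+///   %0 = linalg.generic ... ins(%a, %b : tensor<?xf32>, tensor<?xf32>)
+///        outs(%init0 : tensor<?xf32>) {
+///     ^bb0(%x: f32, %y: f32, %out: f32):
+///       %sum = arith.addf %x, %y : f32
+///       linalg.yield %sum : f32
+///   } -> tensor<?xf32>
+///   %1 = linalg.generic ... ins(%0, %c : tensor<?xf32>, tensor<?xf32>)
+///        outs(%init1 : tensor<?xf32>) {
+///     ^bb0(%s: f32, %z: f32, %out: f32):
+///       %prod = arith.mulf %s, %z : f32
+///       linalg.yield %prod : f32
+///   } -> tensor<?xf32>
+///
+/// fuses into a single linalg.generic whose region computes
+/// mulf(addf(%x, %y), %z), leaving the intermediate tensor %0 dead.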
+class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
+public:
+  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
+                     PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    // Find the first operand that is defined by another generic op on tensors.
+    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
+      auto producer =
+          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
+      if (!producer || !producer.hasTensorSemantics())
+        continue;
+      Optional<SmallVector<Value>> fusedOpResults =
+          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
+      if (fusedOpResults) {
+        rewriter.replaceOp(genericOp, *fusedOpResults);
+        return success();
+      }
+    }
+    return failure();
+  }
+
+private:
+  ControlElementwiseOpsFusionFn controlFn;
+};
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods and patterns that fuse reshape ops with elementwise operations by
+// linearization of indexing maps.
+//===---------------------------------------------------------------------===//
+
+// TODO(ravishankarm): These patterns need to be deprecated. The indexing maps
+// these produce in the general case are detrimental to transformations.
+// They are useful now only in the limited case of unit-dimension folding.
+// Remove these in favor of more general folding by dimension contraction.
+
 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
 /// provided, given the shape of the source tensor that corresponds to the
 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
@@ -445,6 +501,157 @@
   return true;
 }
 
+namespace {
+/// Pattern to fold tensor_expand_shape op with its consumer by using the
+/// source of the reshape op as the operand in the consumer (instead of the
+/// result of the reshape op). The corresponding index map in the consumer
+/// needs to be modified to linearize the folded dimension.
+///
+/// For example,
+///
+/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
+///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
+/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
+///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
+///        -> tensor<?x?x4x?xf32>
+///
+/// can be folded into
+///
+/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
+/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
+///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
+///        -> tensor<?x?x4x?xf32>
+template <typename TensorReshapeOp, bool foldUnitDimReshapesOnly>
+struct FoldProducerReshapeOpByLinearization
+    : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (!genericOp.hasTensorSemantics())
+      return failure();
+    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
+    for (const auto &en : llvm::enumerate(inputOperands)) {
+      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
+      if (!reshapeOp)
+        continue;
+
+      if (!isTensorReshapeOpFoldableByLinearization(
+              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
+              /*asProducer=*/true) ||
+          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
+        continue;
+
+      // Compute the fused operands list.
+      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
+      fusedOperands[en.index()] = reshapeOp.src();
+      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+      llvm::append_range(fusedOperands, outputOperands);
+
+      // Compute indexing_maps for the fused operation. The indexing_maps for
+      // the operands of the consumers that aren't fused are the same.
+      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
+
+      // Compute the indexing map to use for the result of the producer.
+      AffineMap modifiedMap =
+          linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
+      // The modified map cannot have symbols.
+      if (modifiedMap.getNumSymbols())
+        return failure();
+      for (AffineExpr expr : modifiedMap.getResults()) {
+        if (!expr.isPureAffine())
+          return failure();
+      }
+      fusedIndexMaps[en.index()] = modifiedMap;
+
+      // Further check that the resulting index maps can be fused and
+      // inverted. Without this the resultant op is not legal.
+      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
+        return rewriter.notifyMatchFailure(
+            genericOp, "fused op loop bound computation failed");
+      }
+
+      rewriter.startRootUpdate(genericOp);
+      genericOp->setOperands(fusedOperands);
+      genericOp.indexing_mapsAttr(
+          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
+      rewriter.finalizeRootUpdate(genericOp);
+      return success();
+    }
+    return failure();
+  }
+};
+
+/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
+/// producer. The corresponding index map in the consumer needs to be modified
+/// to linearize the folded dimension.
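+///
+/// For example (an illustrative sketch; shapes are assumed, not taken from
+/// this patch):
+///
+/// #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+/// %0 = linalg.generic { indexing_maps = [#map0, #map0], ... }
+///        ins(%arg0 : tensor<?x4x?xf32>) ... -> tensor<?x4x?xf32>
+/// %1 = tensor.collapse_shape %0 [[0, 1], [2]]
+///      tensor<?x4x?xf32> into tensor<?x?xf32>
+///
+/// can be folded into
+///
+/// #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+/// #map1 = affine_map<(d0, d1, d2) -> (d0 * 4 + d1, d2)>
+/// %1 = linalg.generic { indexing_maps = [#map0, #map1], ... }
+///        ins(%arg0 : tensor<?x4x?xf32>) ... -> tensor<?x?xf32>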
+template <typename TensorReshapeOp, bool foldUnitDimReshapesOnly>
+struct FoldConsumerReshapeOpByLinearization
+    : public OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
+    if (!producer || !producer.hasTensorSemantics() ||
+        producer.getNumOutputs() != 1 ||
+        !isTensorReshapeOpFoldableByLinearization(
+            reshapeOp,
+            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
+            /*asProducer=*/false) ||
+        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
+      return failure();
+    // The indexing_maps for the operands of the fused operation are same as
+    // those for the operands of the producer.
+    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
+
+    // Compute the indexing map to use for the operand of the producer.
+    AffineMap modifiedMap = linearizeCollapsedDims(
+        producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
+    for (AffineExpr expr : modifiedMap.getResults()) {
+      if (!expr.isPureAffine()) {
+        return rewriter.notifyMatchFailure(
+            producer, "fused op indexing map is not affine");
+      }
+    }
+    fusedIndexMaps.back() = modifiedMap;
+
+    // Further check that the resulting index maps can be fused and
+    // inverted. Without this the resultant op is not legal.
+    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
+      return rewriter.notifyMatchFailure(
+          producer, "fused op loop bound computation failed");
+    }
+
+    Location loc = producer.getLoc();
+    SmallVector<Value> inputOperands = producer.getInputOperands();
+    Value output = rewriter.create<TensorReshapeOp>(
+        loc, producer.getOutputOperand(0)->get(),
+        reshapeOp.getReassociationExprs());
+    auto fusedOp = rewriter.create<GenericOp>(
+        loc, reshapeOp.getResultType(),
+        /*inputs=*/inputOperands,
+        // TODO: handle outputs.
+        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
+        producer.iterator_types(),
+        /*doc=*/nullptr,
+        /*library_call=*/nullptr);
+    auto &fusedRegion = fusedOp->getRegion(0);
+    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
+                               fusedRegion.begin());
+    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
+    return success();
+  }
+};
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods and patterns that fuse reshape ops with elementwise operations by
+// expanding the dimensionality of the elementwise operations.
+//===---------------------------------------------------------------------===//
+
 /// Conditions for folding a generic operation with a reshape op by expanding
 /// the iteration space dimensionality for tensor operations. These are
 /// preconditions assumed by `foldReshapeByDimExpansion` which implements the
@@ -612,9 +819,9 @@
 /// Note that this could be extended to handle dynamic case, but the
 /// implementation below uses `affine.apply` which seems to have issues when the
 /// shapes are not static.
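+/// For example (an illustrative sketch): if loop `i` of the original op is
+/// expanded into loops `i0, i1` with a static inner extent of 4, the original
+/// index is recomputed in the expanded op's body as
+///   %i = affine.apply affine_map<(d0, d1) -> (d0 * 4 + d1)>(%i0, %i1)
+/// which is only well-defined when the expanded extents are static.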
-LogicalResult isGenericOpExpandable(GenericOp genericOp,
-                                    const ExpansionInfo &expansionInfo,
-                                    PatternRewriter &rewriter) {
+static LogicalResult isGenericOpExpandable(GenericOp genericOp,
+                                           const ExpansionInfo &expansionInfo,
+                                           PatternRewriter &rewriter) {
   if (!genericOp.hasIndexSemantics())
     return success();
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
@@ -863,88 +1070,85 @@
 namespace {
-/// Pattern to fold tensor_expand_shape op with its consumer by using the source
-/// of the reshape op as the operand in the consumer (instead of the result of
-/// the tensor_collapse_shape). The corresponding index map in the consumer
-/// needs to be modified to linearize the folded dimension.
-///
-/// For example,
-///
-/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]]
-///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
-/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
-///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
-///        -> tensor<?x?x4x?xf32>
-///
-/// can be folded into
-///
-/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
-/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
-///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
-///        -> tensor<?x?x4x?xf32>
-template <typename TensorReshapeOp, bool foldUnitDimReshapesOnly>
-struct FoldProducerReshapeOpByLinearization
+/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
+/// when the reshape op is collapsing dimensions. The dimensionality of the loop
+/// in the consumer is expanded.
+class FoldWithProducerReshapeOpByExpansion
     : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
+public:
+  FoldWithProducerReshapeOpByExpansion(
+      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
-      return failure();
-    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
-    for (const auto &en : llvm::enumerate(inputOperands)) {
-      auto reshapeOp = en.value()->get().getDefiningOp<TensorReshapeOp>();
+    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
+      tensor::CollapseShapeOp reshapeOp =
+          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
       if (!reshapeOp)
         continue;
-
-      if (!isTensorReshapeOpFoldableByLinearization(
-              reshapeOp, genericOp.getTiedIndexingMap(en.value()),
-              /*asProducer =*/true) ||
-          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
+      // Fold only if
+      // - The tensor reshape op is folding.
+      // - All constraints of fusing with reshape by expansion are met.
+      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
+          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
         continue;
 
-      // Compute the fused operands list,
-      SmallVector<Value> fusedOperands = genericOp.getInputOperands();
-      fusedOperands[en.index()] = reshapeOp.src();
-      SmallVector<Value> outputOperands = genericOp.getOutputOperands();
-      llvm::append_range(fusedOperands, outputOperands);
-
-      // Compute indexing_maps for the fused operation. The indexing_maps for
-      // the operands of the consumers that arent fused are the same.
-      SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
-
-      // Compute the indexing map to use for the result of the producer.
-      AffineMap modifiedMap =
-          linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
-      // The modified map cannot have symbols.
-      if (modifiedMap.getNumSymbols())
+      Optional<SmallVector<Value>> replacementValues =
+          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
+      if (!replacementValues)
         return failure();
-      for (AffineExpr expr : modifiedMap.getResults()) {
-        if (!expr.isPureAffine())
-          return failure();
-      }
-      fusedIndexMaps[en.index()] = modifiedMap;
-
-      // Further check that the resulting index maps can be fused and
-      // inverted. Without this the resultant op is not legal.
-      if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
-        return rewriter.notifyMatchFailure(
-            genericOp, "fused op loop bound computation failed");
-      }
-
-      rewriter.startRootUpdate(genericOp);
-      genericOp->setOperands(fusedOperands);
-      genericOp.indexing_mapsAttr(
-          rewriter.getAffineMapArrayAttr(fusedIndexMaps));
-      rewriter.finalizeRootUpdate(genericOp);
+      rewriter.replaceOp(genericOp, replacementValues.getValue());
       return success();
     }
     return failure();
   }
+
+private:
+  ControlElementwiseOpsFusionFn controlFoldingReshapes;
 };
 
+/// Pattern to fold a tensor_expand_shape op with its producer generic op
+/// by expanding the dimensionality of the loop in the producer op.
+struct FoldReshapeWithGenericOpByExpansion
+    : public OpRewritePattern<tensor::ExpandShapeOp> {
+
+  FoldReshapeWithGenericOpByExpansion(
+      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
+        controlFoldingReshapes(std::move(foldReshapes)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    // Fold only if all constraints of fusing with reshape by expansion are met.
+    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
+    if (!producer || producer.getNumOutputs() != 1 ||
+        !isFusableWithReshapeByDimExpansion(producer,
+                                            producer.getOutputOperand(0)) ||
+        !controlFoldingReshapes(producer->getResult(0),
+                                reshapeOp->getOpOperand(0)))
+      return failure();
+    Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
+        producer, reshapeOp, producer.getOutputOperand(0), rewriter);
+    if (!replacementValues)
+      return failure();
+    rewriter.replaceOp(reshapeOp, replacementValues.getValue());
+    return success();
+  }
+
+private:
+  ControlElementwiseOpsFusionFn controlFoldingReshapes;
+};
+} // namespace
+
+//===---------------------------------------------------------------------===//
+// Methods and patterns to convert tensor.expand_shape -> linalg.generic
+// into linalg.generic -> tensor.expand_shape, i.e. push the reshape down.
+//===---------------------------------------------------------------------===//
+
 static SmallVector<ReassociationIndices>
 getReassociationIndices(ArrayRef<AffineMap> maps) {
   SmallVector<ReassociationIndices> reassociation;
@@ -959,6 +1163,7 @@
   return reassociation;
 }
 
+namespace {
 /// Pattern to move rank reducing reshape after an elementwise linalg generic
 /// op. This is useful to expose more fusion opportunities between named ops and
 /// generic ops. This can only be done if there is no broadcast or permutation
@@ -1100,142 +1305,13 @@
     return success();
   }
 };
+} // namespace
 
-/// Pattern to fuse a tensor_collapse_shape op with its consumer generic op,
-/// when the reshape op is collapsing dimensions. The dimensionality of the loop
-/// in the consumer is expanded.
-class FoldWithProducerReshapeOpByExpansion
-    : public OpRewritePattern<GenericOp> {
-public:
-  FoldWithProducerReshapeOpByExpansion(
-      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
-      PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit),
-        controlFoldingReshapes(std::move(foldReshapes)) {}
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
-      tensor::CollapseShapeOp reshapeOp =
-          opOperand->get().getDefiningOp<tensor::CollapseShapeOp>();
-      if (!reshapeOp)
-        continue;
-      // Fold only if
-      // - The tensor reshape op is folding.
-      // - All constraints of fusing with reshape by expansion are met.
-      if (!isFusableWithReshapeByDimExpansion(genericOp, opOperand) ||
-          (!controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)))
-        continue;
-
-      Optional<SmallVector<Value>> replacementValues =
-          fuseWithReshapeByExpansion(genericOp, reshapeOp, opOperand, rewriter);
-      if (!replacementValues)
-        return failure();
-      rewriter.replaceOp(genericOp, replacementValues.getValue());
-      return success();
-    }
-    return failure();
-  }
-
-private:
-  ControlElementwiseOpsFusionFn controlFoldingReshapes;
-};
-
-/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its
-/// producer. The corresponding index map in the consumer needs to be modified
-/// to linearize the folded dimension.
-template <typename TensorReshapeOp, bool foldUnitDimReshapesOnly>
-struct FoldConsumerReshapeOpByLinearization
-    : public OpRewritePattern<TensorReshapeOp> {
-  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
-                                PatternRewriter &rewriter) const override {
-    GenericOp producer = reshapeOp.src().template getDefiningOp<GenericOp>();
-    if (!producer || !producer.hasTensorSemantics() ||
-        producer.getNumOutputs() != 1 ||
-        !isTensorReshapeOpFoldableByLinearization(
-            reshapeOp,
-            producer.getTiedIndexingMap(producer.getOutputOperand(0)),
-            /*asProducer =*/false) ||
-        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
-      return failure();
-    // The indexing_maps for the operands of the fused operation are same as
-    // those for the operands of the producer.
-    SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
-
-    // Compute the indexing map to use for the operand of the producer.
-    AffineMap modifiedMap = linearizeCollapsedDims(
-        producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
-    for (AffineExpr expr : modifiedMap.getResults()) {
-      if (!expr.isPureAffine()) {
-        return rewriter.notifyMatchFailure(
-            producer, "fused op indexing map is not affine");
-      }
-    }
-    fusedIndexMaps.back() = modifiedMap;
-
-    // Further check that the resulting index maps can be fused and
-    // inverted. Without this the resultant op is not legal.
-    if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) {
-      return rewriter.notifyMatchFailure(
-          producer, "fused op loop bound computation failed");
-    }
-
-    Location loc = producer.getLoc();
-    SmallVector<Value> inputOperands = producer.getInputOperands();
-    Value output = rewriter.create<TensorReshapeOp>(
-        loc, producer.getOutputOperand(0)->get(),
-        reshapeOp.getReassociationExprs());
-    auto fusedOp = rewriter.create<GenericOp>(
-        loc, reshapeOp.getResultType(),
-        /*inputs=*/inputOperands,
-        // TODO: handle outputs.
-        /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps),
-        producer.iterator_types(),
-        /*doc=*/nullptr,
-        /*library_call=*/nullptr);
-    auto &fusedRegion = fusedOp->getRegion(0);
-    rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
-                               fusedRegion.begin());
-    rewriter.replaceOp(reshapeOp, fusedOp->getResults());
-    return success();
-  }
-};
-
-/// Pattern to fold a tensor_expand_shape op with its producer generic op
-/// by expanding the dimensionality of the loop in the producer op.
-struct FoldReshapeWithGenericOpByExpansion
-    : public OpRewritePattern<tensor::ExpandShapeOp> {
-
-  FoldReshapeWithGenericOpByExpansion(
-      MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes,
-      PatternBenefit benefit = 1)
-      : OpRewritePattern<tensor::ExpandShapeOp>(context, benefit),
-        controlFoldingReshapes(std::move(foldReshapes)) {}
-
-  LogicalResult matchAndRewrite(tensor::ExpandShapeOp reshapeOp,
-                                PatternRewriter &rewriter) const override {
-    // Fold only if all constraints of fusing with reshape by expansion are met.
-    GenericOp producer = reshapeOp.src().getDefiningOp<GenericOp>();
-    if (!producer || producer.getNumOutputs() != 1 ||
-        !isFusableWithReshapeByDimExpansion(producer,
-                                            producer.getOutputOperand(0)) ||
-        !controlFoldingReshapes(producer->getResult(0),
-                                reshapeOp->getOpOperand(0)))
-      return failure();
-    Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
-        producer, reshapeOp, producer.getOutputOperand(0), rewriter);
-    if (!replacementValues)
-      return failure();
-    rewriter.replaceOp(reshapeOp, replacementValues.getValue());
-    return success();
-  }
-
-private:
-  ControlElementwiseOpsFusionFn controlFoldingReshapes;
-};
+//===---------------------------------------------------------------------===//
+// Methods and patterns that fuse constants with linalg.generic operations.
+//===---------------------------------------------------------------------===//
 
+namespace {
 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
 /// handle cases where the constant is not single-valued.
 class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
@@ -1624,98 +1700,11 @@
 } // namespace
 
-static Optional<SmallVector<Value>>
-fuseElementwiseOps(PatternRewriter &rewriter, OpOperand *consumerOpOperand,
-                   GenericOp producer,
-                   const ControlElementwiseOpsFusionFn &controlFn) {
-  if (producer->getNumResults() != 1)
-    return llvm::None;
-
-  return fuseElementwiseOpsImpl(producer, consumerOpOperand, controlFn,
-                                rewriter);
-}
-
-bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
-                                      OpOperand &consumer) {
-  if (auto producerCollapseOp =
-          dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
-    return !isUnitDimExpansionOnly(producerCollapseOp);
-  }
-  if (auto consumerExpandOp =
-          dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
-    return !isUnitDimExpansionOnly(consumerExpandOp);
-  }
-  return true;
-}
+//===---------------------------------------------------------------------===//
+// Miscellaneous patterns that help fusion.
+//===---------------------------------------------------------------------===//
 
 namespace {
-/// Patterns to fuse a generic op, with the producer of its operands.
-class FuseElementwiseOps : public OpRewritePattern<GenericOp> {
-public:
-  FuseElementwiseOps(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
-                     PatternBenefit benefit = 1)
-      : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
-
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    // Find the first operand that is defined by another generic op on tensors.
-    for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
-      auto producer =
-          dyn_cast_or_null<GenericOp>(opOperand->get().getDefiningOp());
-      if (!producer || !producer.hasTensorSemantics())
-        continue;
-      Optional<SmallVector<Value>> fusedOpResults =
-          fuseElementwiseOps(rewriter, opOperand, producer, controlFn);
-      if (fusedOpResults) {
-        rewriter.replaceOp(genericOp, *fusedOpResults);
-        return success();
-      }
-    }
-    return failure();
-  }
-
-private:
-  ControlElementwiseOpsFusionFn controlFn;
-};
-
-/// Pass that fuses generic ops on tensors. Used only for testing.
-struct LinalgElementwiseOpFusionPass
-    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
-  void runOnOperation() override {
-    Operation *op = getOperation();
-    RewritePatternSet patterns(op->getContext());
-    ControlElementwiseOpsFusionFn allowFoldingFn =
-        [](const OpResult &producer, const OpOperand &consumer) {
-          return true;
-        };
-    populateElementwiseOpsFusionPatterns(
-        patterns,
-        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
-            allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
-
-    // Use TopDownTraversal for compile time reasons
-    GreedyRewriteConfig grc;
-    grc.useTopDownTraversal = true;
-    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
-                                       grc);
-  }
-};
-
-/// Pass to test folding of reshape ops with generic ops by linearization.
-struct FoldReshapeOpsByLinearizationPass
-    : public LinalgFoldReshapeOpsByLinearizationBase<
-          FoldReshapeOpsByLinearizationPass> {
-  void runOnOperation() override {
-    Operation *op = getOperation();
-    RewritePatternSet patterns(op->getContext());
-    populateFoldReshapeOpsByLinearizationPatterns(patterns);
-    if (allowFoldingUnitDimReshapes) {
-      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
-    }
-    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
-  }
-};
-
 /// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
 /// the value of the `outs` operand is not used within the op. This is only
 /// implemented for `linalg.generic` operations for now, but should hold for all
@@ -1761,9 +1750,12 @@
     return success();
   }
 };
-
 } // namespace
 
+//===---------------------------------------------------------------------===//
+// Methods that add patterns described in this file to a pattern list.
+//===---------------------------------------------------------------------===//
+
 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
     RewritePatternSet &patterns) {
   patterns
@@ -1815,6 +1807,65 @@
   patterns.add<PushExpandingReshape>(context);
 }
 
+//===---------------------------------------------------------------------===//
+// Passes
+//===---------------------------------------------------------------------===//
+
+bool mlir::linalg::skipUnitDimReshape(const OpResult &producer,
+                                      OpOperand &consumer) {
+  if (auto producerCollapseOp =
+          dyn_cast<tensor::CollapseShapeOp>(producer.getOwner())) {
+    return !isUnitDimExpansionOnly(producerCollapseOp);
+  }
+  if (auto consumerExpandOp =
+          dyn_cast<tensor::ExpandShapeOp>(consumer.getOwner())) {
+    return !isUnitDimExpansionOnly(consumerExpandOp);
+  }
+  return true;
+}
+
+namespace {
+
+/// Pass that fuses generic ops on tensors. Used only for testing.
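+/// (Typically exercised with `mlir-opt -linalg-fuse-elementwise-ops`; the
+/// flag name here is assumed from the pass's tablegen registration.)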
+struct LinalgElementwiseOpFusionPass
+    : public LinalgElementwiseOpFusionBase<LinalgElementwiseOpFusionPass> {
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    ControlElementwiseOpsFusionFn allowFoldingFn =
+        [](const OpResult &producer, const OpOperand &consumer) {
+          return true;
+        };
+    populateElementwiseOpsFusionPatterns(
+        patterns,
+        LinalgElementwiseFusionOptions().setControlFoldingReshapes(
+            allowFoldingUnitDimReshapes ? allowFoldingFn : skipUnitDimReshape));
+
+    // Use TopDownTraversal for compile time reasons.
+    GreedyRewriteConfig grc;
+    grc.useTopDownTraversal = true;
+    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns),
+                                       grc);
+  }
+};
+
+/// Pass to test folding of reshape ops with generic ops by linearization.
+struct FoldReshapeOpsByLinearizationPass
+    : public LinalgFoldReshapeOpsByLinearizationBase<
+          FoldReshapeOpsByLinearizationPass> {
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    RewritePatternSet patterns(op->getContext());
+    populateFoldReshapeOpsByLinearizationPatterns(patterns);
+    if (allowFoldingUnitDimReshapes) {
+      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
+    }
+    (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
+  }
+};
+
+} // namespace
+
 std::unique_ptr<Pass> mlir::createLinalgElementwiseOpFusionPass() {
   return std::make_unique<LinalgElementwiseOpFusionPass>();
 }
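
Usage sketch (illustrative, not part of the patch; the pass-manager API names
are assumed for this MLIR revision):

  // Run the elementwise fusion pass over every function in a module.
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::FuncOp>(mlir::createLinalgElementwiseOpFusionPass());
  if (mlir::failed(pm.run(module)))
    llvm::errs() << "elementwise fusion pipeline failed\n";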