diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -90,31 +90,11 @@ def LinalgElementwiseOpFusion : Pass<"linalg-fuse-elementwise-ops"> { let summary = "Fuse elementwise operations on tensors"; let constructor = "mlir::createLinalgElementwiseOpFusionPass()"; - let options = [ - Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes", - "bool", /*default=*/"false", - "Allow fusing linalg.tensor_reshape ops that performs unit " - "dimension collapsing"> - ]; let dependentDialects = [ "AffineDialect", "linalg::LinalgDialect", "memref::MemRefDialect" ]; } -def LinalgFoldReshapeOpsByLinearization : - Pass<"linalg-fold-reshape-ops-by-linearization"> { - let summary = "Fold TensorReshapeOps with generic/indexed generic ops by " - "linearization"; - let constructor = "mlir::createFoldReshapeOpsByLinearizationPass()"; - let options = [ - Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes", - "bool", /*default=*/"false", - "Allow fusing linalg.tensor_reshape ops that performs unit " - "dimension collapsing"> - ]; - let dependentDialects = ["AffineDialect", "memref::MemRefDialect"]; -} - def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> { let summary = "Convert from one named linalg op to another."; let constructor = "mlir::createLinalgNamedOpConversionPass()"; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -37,10 +37,6 @@ struct LinalgFusionOptions; struct LinalgTilingOptions; -/// Default function to control reshape folding. Skips folding unit dimension -/// reshapes. -bool skipUnitDimReshape(const OpResult &producer, OpOperand &consumer); - //===----------------------------------------------------------------------===// // Transformations exposed as function calls. //===----------------------------------------------------------------------===// @@ -91,24 +87,6 @@ void populateConstantFoldLinalgOperations(RewritePatternSet &patterns, const ControlFusionFn &controlFn); -/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its -/// producer (consumer) generic operation by linearizing the indexing map used -/// to access the source (target) of the reshape operation in the generic -/// operation. -/// TODO(ravishankarm): These patterns are to be deprecated in favor of using -/// the `populateFoldReshapeByCollapsingPatterns`. -void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns); - -/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its -/// producer (consumer) generic operation by linearizing the indexing map used -/// to access the source (target) of the reshape operation in the generic -/// operation. The patterns are applied only when the tensor reshape involved is -/// collapsing (introducing) unit-extent dimensions. -/// TODO(ravishankarm): These patterns are to be deprecated in favor of using -/// the `populateFoldReshapeByCollapsingPatterns`. -void populateFoldUnitDimsReshapeOpsByLinearizationPatterns( - RewritePatternSet &patterns); - /// Pattern to fuse a `tensor.pad` operation with the producer of its source, /// if the producer is a `linalg` operation with all parallel iterator types. void populateFuseTensorPadWithProducerLinalgOpPatterns( @@ -128,12 +106,6 @@ /// Patterns that are used to bubble up extract slice op above linalg op. void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns); -/// Patterns to push reshape op towards the end of the graph in order to expose -/// more fusion opportunities. -/// TODO(ravishankarm): These patterns are to be deprecated in favor of using -/// the `populateFoldReshapeByCollapsingPatterns`. -void populatePushReshapeOpsPatterns(RewritePatternSet &patterns); - /// Perform standalone tiling of a single LinalgOp by `tileSizes`. /// and permute the loop nest according to `interchangeVector` /// The permutation is expressed as a list of integers that specify diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -392,263 +392,6 @@ }; } // namespace -//===---------------------------------------------------------------------===// -// Methods and patterns that fuse reshape ops with elementwise operations by -// linearization of indexing maps. -//===---------------------------------------------------------------------===// - -// TODO(ravishankarm): The indexing maps -// these produce in the general case are detrimental to transformations. -// These patterns are on deprecation path in favor of using fusion by -// collapsing, which covers the only legitimate use case of this pattern of -// folding unit-extent dims. - -/// Linearize the expressions in `sourceMap` based on the `reassociationMaps` -/// provided, given the shape of the source tensor that corresponds to the -/// `sourceMap`. Note that this implicitly assumes that the tensors dimensions -/// are "row-major" ordered logically. -/// -/// For example: -/// -/// %0 = op ... : tensor -/// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>` -/// -/// and reshape: -/// %1 = tensor.collapse_shape %0 [[0], [0, 1, 2]] : -/// tensor into tensor -/// -/// would be rewritten into: -/// %0 = op ... : tensor -/// with output index_map -/// `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>` -template -static AffineMap linearizeCollapsedDims(AffineMap sourceMap, - TensorReshapeOp reshapeOp) { - constexpr bool isExpanding = - std::is_same::value; - ArrayRef sourceShape = - (isExpanding ? reshapeOp.getResultType().getShape() - : reshapeOp.getSrcType().getShape()); - SmallVector resultExprs; - ArrayRef sourceExprs = sourceMap.getResults(); - MLIRContext *context = sourceMap.getContext(); - - // Compute the result exprs based on the reassociation maps. - for (auto &indices : reshapeOp.getReassociationIndices()) { - // Assume that they are in-order and contiguous (already checked in - // verifier). - assert(!indices.empty()); - SmallVector sizes; - SmallVector dimExprs; - for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()), - sourceExprs.slice(indices[0], indices.size()))) { - if (std::get<0>(en) == 1) - continue; - sizes.push_back(std::get<0>(en)); - dimExprs.push_back(std::get<1>(en)); - } - AffineExpr linearizedExpr = - makeCanonicalStridedLayoutExpr(sizes, dimExprs, context); - resultExprs.push_back(linearizedExpr); - } - // The new affine map cannot drop unused dimension but some new symbols may - // have been added. Create a map with at least as many dimensions/symbols as - // the original affine map. - int64_t maxDim = -1; - int64_t maxSym = -1; - getMaxDimAndSymbol>({resultExprs}, maxDim, maxSym); - unsigned numDims = std::max(unsigned(maxDim + 1), sourceMap.getNumDims()); - unsigned numSyms = std::max(unsigned(maxSym + 1), sourceMap.getNumSymbols()); - return AffineMap::get(numDims, numSyms, resultExprs, context); -} - -// tensor::ExpandShapeOp is fusable with its consumer (i.e. reshape as a -// producer). Fusing when operand has higher rank will require use of mods and -// divs in the indexing maps of the fused op which would make it non-invertible. -static bool isTensorReshapeOpFoldableByLinearization( - tensor::ExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) { - if (!asProducer) - return false; - return useIndexMap.isPermutation(); -} - -// tensor::CollapseShapeOp is fusable with its producer (i.e. reshape as a -// consumer). -static bool -isTensorReshapeOpFoldableByLinearization(tensor::CollapseShapeOp collapseOp, - AffineMap useIndexMap, - bool asProducer) { - if (asProducer) - return false; - return useIndexMap.isPermutation(); -} - -/// Check if the reshape operation is only expansion into/collapsing of -/// unit-dimension. -template -static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) { - constexpr bool isExpanding = - std::is_same::value; - ArrayRef expandedShape = - (isExpanding ? reshapeOp.getResultType().getShape() - : reshapeOp.getSrcType().getShape()); - for (auto &indices : reshapeOp.getReassociationIndices()) { - unsigned numUnitDims = 0; - for (int64_t position : indices) - if (expandedShape[position] == 1) - numUnitDims++; - if (numUnitDims != indices.size() - 1) - return false; - } - return true; -} - -namespace { -/// Pattern to fold tensor_expand_shape op with its consumer by using the source -/// of the reshape op as the operand in the consumer (instead of the result of -/// the tensor_collapse_shape). The corresponding index map in the consumer -/// needs to be modified to linearize the folded dimension. -/// -/// For example, -/// -/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -/// %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] -/// tensor into tensor -/// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... } -/// ins(%0, %arg1 : tensor, tensor) ... -/// -> tensor -/// -/// can be folded into -/// -/// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)> -/// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -/// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... } -/// ins(%arg0, %arg1 : tensor, tensor) ... -/// -> tensor -template -struct FoldProducerReshapeOpByLinearization - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GenericOp genericOp, - PatternRewriter &rewriter) const override { - if (!genericOp.hasTensorSemantics()) - return failure(); - SmallVector inputOperands = genericOp.getInputOperands(); - for (const auto &en : llvm::enumerate(inputOperands)) { - auto reshapeOp = en.value()->get().getDefiningOp(); - if (!reshapeOp) - continue; - - if (!isTensorReshapeOpFoldableByLinearization( - reshapeOp, genericOp.getTiedIndexingMap(en.value()), - /*asProducer =*/true) || - (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp))) - continue; - - // Compute the fused operands list, - SmallVector fusedOperands = genericOp.getInputOperands(); - fusedOperands[en.index()] = reshapeOp.src(); - SmallVector outputOperands = genericOp.getOutputOperands(); - llvm::append_range(fusedOperands, outputOperands); - - // Compute indexing_maps for the fused operation. The indexing_maps for - // the operands of the consumers that arent fused are the same. - SmallVector fusedIndexMaps = genericOp.getIndexingMaps(); - - // Compute the indexing map to use for the result of the producer. - AffineMap modifiedMap = - linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp); - // The modified map cannot have symbols. - if (modifiedMap.getNumSymbols()) - return failure(); - for (AffineExpr expr : modifiedMap.getResults()) { - if (!expr.isPureAffine()) - return failure(); - } - fusedIndexMaps[en.index()] = modifiedMap; - - // Further check that the resulting index maps can be fused and - // inverted. Without this the resultant op is not legal. - if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { - return rewriter.notifyMatchFailure( - genericOp, "fused op loop bound computation failed"); - } - - rewriter.startRootUpdate(genericOp); - genericOp->setOperands(fusedOperands); - genericOp.indexing_mapsAttr( - rewriter.getAffineMapArrayAttr(fusedIndexMaps)); - rewriter.finalizeRootUpdate(genericOp); - return success(); - } - return failure(); - } -}; - -/// Pattern to fold tensor_collapse_shape or tensor_expand_shape op with its -/// producer. The corresponding index map in the consumer needs to be modified -/// to linearize the folded dimension. -template -struct FoldConsumerReshapeOpByLinearization - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, - PatternRewriter &rewriter) const override { - GenericOp producer = reshapeOp.src().template getDefiningOp(); - if (!producer || !producer.hasTensorSemantics() || - producer.getNumOutputs() != 1 || - !isTensorReshapeOpFoldableByLinearization( - reshapeOp, - producer.getTiedIndexingMap(producer.getOutputOperand(0)), - /*asProducer =*/false) || - (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp))) - return failure(); - // The indexing_maps for the operands of the fused operation are same as - // those for the operands of the producer. - SmallVector fusedIndexMaps = producer.getIndexingMaps(); - - // Compute the indexing map to use for the operand of the producer. - AffineMap modifiedMap = linearizeCollapsedDims( - producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp); - for (AffineExpr expr : modifiedMap.getResults()) { - if (!expr.isPureAffine()) { - return rewriter.notifyMatchFailure( - producer, "fused op indexing map is not affine"); - } - } - fusedIndexMaps.back() = modifiedMap; - - // Further check that the resulting index maps can be fused and - // inverted. Without this the resultant op is not legal. - if (!inversePermutation(concatAffineMaps(fusedIndexMaps))) { - return rewriter.notifyMatchFailure( - producer, "fused op loop bound computation failed"); - } - - Location loc = producer.getLoc(); - SmallVector inputOperands = producer.getInputOperands(); - Value output = rewriter.create( - loc, producer.getOutputOperand(0)->get(), - reshapeOp.getReassociationExprs()); - auto fusedOp = rewriter.create( - loc, reshapeOp.getResultType(), - /*inputs=*/inputOperands, - // TODO: handle outputs. - /*outputs=*/output, rewriter.getAffineMapArrayAttr(fusedIndexMaps), - producer.iterator_types(), - /*doc=*/nullptr, - /*library_call=*/nullptr); - auto &fusedRegion = fusedOp->getRegion(0); - rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion, - fusedRegion.begin()); - rewriter.replaceOp(reshapeOp, fusedOp->getResults()); - return success(); - } -}; -} // namespace - //===---------------------------------------------------------------------===// // Methods and patterns that fuse reshape ops with elementwise operations by // expanding the dimensionality of the elementwise operations. @@ -1737,174 +1480,6 @@ }; } // namespace -//===---------------------------------------------------------------------===// -// Methods and patterns to convert tensor.expand_shape -> linalg.generic -// into linalg.generic -> tensor.expand_shape, i.e. push the reshape down. -//===---------------------------------------------------------------------===// - -// TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by -// collapsing that provides a more general functionality. This pattern is very -// specific to a particular use case. The fusion by collapsing can provide the -// same control to clients using the control function there. - -static SmallVector -getReassociationIndices(ArrayRef maps) { - SmallVector reassociation; - for (AffineMap map : maps) { - ReassociationIndices indices; - for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { - unsigned pos = map.getResult(i).cast().getPosition(); - indices.push_back(pos); - } - reassociation.push_back(indices); - } - return reassociation; -} - -namespace { -/// Pattern to move rank reducing reshape after an elementwise linalg generic -/// op. This is useful to expose more fusion opportunities between named ops and -/// generic ops. This can only be done if there is no broadcast or permuation -/// within the dimensions we need to merge. -/// -/// For example, -/// -/// %0 = tensor.expand_shape %A [[0, 1], [2]] -/// : tensor<12544x16xf32> into tensor<112x112x16xf32> -/// %2 = linalg.generic {indexing_maps = [ -/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>, -/// affine_map<(d0, d1, d2) -> (d2)>, -/// affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = -/// ["parallel", "parallel", "parallel"]} { -/// } -> tensor<112x112x16xf32> -/// -/// into -/// -/// %2 = linalg.generic {indexing_maps = [ -/// affine_map<(d0, d1) -> (d0, d1)>, -/// affine_map<(d0, d1) -> (d1)>, -/// affine_map<(d0, d1) -> (d0, d1)>], -/// iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 -/// : tensor<12544x16xf32>, tensor<16xf32>) outs(%1 : tensor<12544x16xf32>) { -/// } -> tensor<12544x16xf32> -/// %3 = tensor.expand_shape %2 [[0, 1], [2]] -/// : tensor<12544x16xf32> into tensor<112x112x16xf32> -struct PushExpandingReshape : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GenericOp genericOp, - PatternRewriter &rewriter) const override { - // Only apply to elementwise linalg on tensor. - if (!genericOp.hasTensorSemantics() || genericOp.hasIndexSemantics() || - genericOp.getNumParallelLoops() != genericOp.getNumLoops()) - return failure(); - // Only support identity output maps. It could be extended to permuations if - // needed. - if (llvm::any_of(genericOp.getOutputOperands(), [&](OpOperand *opOperand) { - return !genericOp.getTiedIndexingMap(opOperand).isIdentity(); - })) - return failure(); - int64_t destRank = genericOp.getNumParallelLoops(); - SmallVector newOperands = genericOp.getInputOperands(); - tensor::ExpandShapeOp reshapeFound; - // 1. Look for tensor_expand_shape operands and figure out save the - // dimensions merged. - SmallVector inputOperands = genericOp.getInputOperands(); - for (const auto &en : llvm::enumerate(inputOperands)) { - auto reshapeOp = - en.value()->get().template getDefiningOp(); - if (!reshapeOp) - continue; - // TODO: We could support non-identity map as long as the merged - // dimensions are still contiguous. - if (!genericOp.getTiedIndexingMap(en.value()).isIdentity()) - continue; - if (reshapeFound) { - // Only support a second reshape op if it has the same reassociate maps. - if (reshapeFound.getReassociationMaps() == - reshapeOp.getReassociationMaps()) - newOperands[en.index()] = reshapeOp.src(); - continue; - } - reshapeFound = reshapeOp; - newOperands[en.index()] = reshapeOp.src(); - } - if (!reshapeFound) - return failure(); - - // Calculate the reassociation indices and rassociated reverse map. - SmallVector reassociation = - getReassociationIndices(reshapeFound.getReassociationMaps()); - SmallVector remap(destRank); - for (auto &indices : llvm::enumerate(reassociation)) { - for (int64_t index : indices.value()) { - remap[index] = indices.index(); - } - } - // 2. Verify that we can merge the dimensions in the linalg and that we - // don't need to create new reshapes operands. Inserting new reshape - // operands would defeat the purpose of the transformation. - for (const auto &en : llvm::enumerate(inputOperands)) { - if (en.value()->get() == newOperands[en.index()]) { - AffineMap map = genericOp.getTiedIndexingMap(en.value()); - for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) { - if (reassociation[remap[map.getDimPosition(i)]].size() > 1) - return failure(); - } - } - } - - // 3. Calculate the affine map remapping and the reassociation to apply to - // output tensors. - SmallVector newMaps; - unsigned newRank = reassociation.size(); - for (auto map : genericOp.getIndexingMaps()) { - SmallVector newExprs; - for (auto expr : map.getResults()) { - unsigned position = expr.template cast().getPosition(); - // Skip dimension merged except for the last of the group. - if (reassociation[remap[position]].back() == position) { - newExprs.push_back( - getAffineDimExpr(remap[position], genericOp.getContext())); - } - } - newMaps.push_back( - AffineMap::get(newRank, 0, newExprs, genericOp.getContext())); - } - - // 4. Reshape the output tensors. - SmallVector newOutputs; - SmallVector newOutputTypes; - for (auto output : genericOp.outputs()) { - auto newOutputType = RankedTensorType::get( - reshapeFound.getSrcType().getShape(), - output.getType().template cast().getElementType()); - Value newOutput = rewriter.create( - genericOp->getLoc(), newOutputType, output, reassociation); - newOutputTypes.push_back(newOutputType); - newOutputs.push_back(newOutput); - } - // 5. Create a new generic op with lowerer rank. - SmallVector iteratorTypes(newRank, - getParallelIteratorTypeName()); - auto newOp = rewriter.create(genericOp->getLoc(), newOutputTypes, - newOperands, newOutputs, newMaps, - iteratorTypes); - rewriter.inlineRegionBefore(genericOp.region(), newOp.region(), - newOp.region().begin()); - // 6. Reshape the so that the type matches the uses. - SmallVector newResults; - for (const auto &result : llvm::enumerate(newOp->getResults())) { - newResults.push_back(rewriter.create( - genericOp->getLoc(), genericOp.getOutputTensorTypes()[result.index()], - result.value(), reassociation)); - } - rewriter.replaceOp(genericOp, newResults); - return success(); - } -}; -} // namespace - //===---------------------------------------------------------------------===// // Methods and patterns that fuse constants with linalg.generic operations. //===---------------------------------------------------------------------===// @@ -2093,27 +1668,6 @@ } }; } // namespace -//===---------------------------------------------------------------------===// -// Methods that add patterns described in this file to a pattern list. -//===---------------------------------------------------------------------===// - -void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns( - RewritePatternSet &patterns) { - patterns.add< - FoldProducerReshapeOpByLinearization, - FoldProducerReshapeOpByLinearization, - FoldConsumerReshapeOpByLinearization>( - patterns.getContext()); -} - -void mlir::linalg::populateFoldUnitDimsReshapeOpsByLinearizationPatterns( - RewritePatternSet &patterns) { - patterns - .add, - FoldProducerReshapeOpByLinearization, - FoldConsumerReshapeOpByLinearization>( - patterns.getContext()); -} void mlir::linalg::populateFoldReshapeOpsByExpansionPatterns( RewritePatternSet &patterns, @@ -2140,28 +1694,10 @@ RemoveOutsDependency>(context); } -void mlir::linalg::populatePushReshapeOpsPatterns(RewritePatternSet &patterns) { - auto *context = patterns.getContext(); - patterns.add(context); -} - //===---------------------------------------------------------------------===// // Passes //===---------------------------------------------------------------------===// -bool mlir::linalg::skipUnitDimReshape(const OpResult &producer, - OpOperand &consumer) { - if (auto producerCollapseOp = - dyn_cast(producer.getOwner())) { - return !isUnitDimExpansionOnly(producerCollapseOp); - } - if (auto consumerExpandOp = - dyn_cast(consumer.getOwner())) { - return !isUnitDimExpansionOnly(consumerExpandOp); - } - return true; -} - namespace { /// Pass that fuses generic ops on tensors. Used only for testing. @@ -2186,9 +1722,7 @@ // Add elementwise op fusion patterns. populateElementwiseOpsFusionPatterns(patterns, defaultControlFn); - populateFoldReshapeOpsByExpansionPatterns( - patterns, - allowFoldingUnitDimReshapes ? defaultControlFn : skipUnitDimReshape); + populateFoldReshapeOpsByExpansionPatterns(patterns, defaultControlFn); // Add the sparse tensor rewriting patterns. populateSparseTensorRewriting(patterns); @@ -2212,27 +1746,8 @@ } }; -/// Pass to test folding of reshape ops with generic ops by linearization. -struct FoldReshapeOpsByLinearizationPass - : public LinalgFoldReshapeOpsByLinearizationBase< - FoldReshapeOpsByLinearizationPass> { - void runOnOperation() override { - Operation *op = getOperation(); - RewritePatternSet patterns(op->getContext()); - populateFoldReshapeOpsByLinearizationPatterns(patterns); - if (allowFoldingUnitDimReshapes) { - populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns); - } - (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns)); - } -}; - } // namespace std::unique_ptr mlir::createLinalgElementwiseOpFusionPass() { return std::make_unique(); } - -std::unique_ptr mlir::createFoldReshapeOpsByLinearizationPass() { - return std::make_unique(); -} diff --git a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir --- a/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir +++ b/mlir/test/Dialect/Linalg/fusion-push-reshape.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=push-expanding-reshape -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-collapsing -split-input-file | FileCheck %s // CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d1)> @@ -124,30 +124,3 @@ // CHECK-SAME: outs(%{{.+}} : tensor<6x5xf32>) // CHECK: tensor.expand_shape %[[OP]] // CHECK-SAME: tensor<6x5xf32> into tensor<2x3x5xf32> - -// ----- - -func.func @generic_op_index_semantics(%A: tensor, %B: tensor<16xi64>, %init: tensor) -> tensor { - %0 = tensor.expand_shape %A [[0, 1], [2]] - : tensor into tensor - %2 = linalg.generic {indexing_maps = [ - affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d2)>, - affine_map<(d0, d1, d2) -> (d0, d1, d2)>], - iterator_types = ["parallel", "parallel", "parallel"]} - ins(%0, %B : tensor, tensor<16xi64>) - outs(%init : tensor) { - ^bb0(%arg1: i64, %arg2: i64, %arg3: i64): // no predecessors - %index = linalg.index 0 : index - %1 = arith.index_cast %index : index to i64 - %add = arith.addi %arg1, %1 : i64 - %s = arith.subi %add, %arg2 : i64 - linalg.yield %s : i64 - } -> tensor - return %2 : tensor -} -// CHECK: func @generic_op_index_semantics -// CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]] -// CHECK: %[[RESULT:.+]] = linalg.generic -// CHECK-SAME: ins(%[[RESHAPE]] -// CHECK: return %[[RESULT]] diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -linalg-fuse-elementwise-ops="allow-folding-unit-dim-reshapes=false" -split-input-file | FileCheck %s -// RUN: mlir-opt %s -linalg-fuse-elementwise-ops="allow-folding-unit-dim-reshapes=true" -split-input-file | FileCheck %s --check-prefix=FOLDUNITDIM +// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-expansion -split-input-file | FileCheck %s + #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)> #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)> #map2 = affine_map<(d0, d1, d2) -> ()> @@ -14,7 +14,7 @@ indexing_maps = [#map0, #map1, #map2, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0, %arg1, %arg2 : tensor, tensor, f32) - outs(%0 : tensor) { + outs(%arg1 : tensor) { ^bb0(%arg3: f32, %arg4: f32, %arg5: f32, %s: f32): %1 = arith.mulf %arg3, %arg4 : f32 %2 = arith.addf %1, %arg5 : f32 @@ -30,15 +30,15 @@ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: f32 -// CHECK: %[[T0:.+]] = tensor.collapse_shape %[[ARG0]] -// CHECK-SAME: [0], [1, 2], [3] // CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] // CHECK-SAME: [0], [1], [2, 3] +// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] +// CHECK-SAME: [0], [1], [2, 3] // CHECK: %[[T3:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]], #[[MAP6]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[ARG0]], %[[T1]], %[[ARG2]] : tensor, tensor, f32) -// CHECK-SAME: outs(%{{.+}} : tensor) +// CHECK-SAME: outs(%[[T2]] : tensor) // CHECK: %[[T4:.+]] = tensor.collapse_shape %[[T3]] // CHECK-SAME: [0], [1], [2, 3] // CHECK-SAME: tensor into tensor @@ -80,12 +80,14 @@ // CHECK-SAME: tensor into tensor // CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] // CHECK-SAME: [0], [1, 2, 3] +// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK-SAME: [0], [1, 2, 3] // CHECK-SAME: tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP3]], #[[MAP2]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[T0]], %[[T1]], %[[ARG2]] : tensor, tensor, f32) -// CHECK-SAME: outs(%{{.+}} : tensor) +// CHECK-SAME: outs(%[[T2]] : tensor) // CHECK: return %[[T3]] : tensor @@ -121,11 +123,14 @@ // CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] // CHECK-SAME: [0, 1, 2], [3] // CHECK-SAME: tensor into tensor<3x4x?x?xf32> +// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK-SAME: [0, 1], [2], [3, 4, 5]] +// CHECK-SAME: tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>) -// CHECK-SAME: outs(%{{.+}} : tensor) +// CHECK-SAME: outs(%[[T2]] : tensor) // CHECK: return %[[T3]] : tensor // ----- @@ -155,14 +160,19 @@ // CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: func @generic_op_reshape_consumer_static // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<264x4xf32> +// CHECK-DAG: %[[CST:.+]] = arith.constant +// CHECK-SAME: : tensor<8x33x4xf32> +// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [264, 4] // CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] // CHECK-SAME: [0, 1], [2] // CHECK-SAME: tensor<264x4xf32> into tensor<8x33x4xf32> -// CHECK: %[[T1:.+]] = linalg.init_tensor [8, 33, 4] +// CHECK: %[[T1:.+]] = tensor.expand_shape %[[INIT]] +// CHECK-SAME: [0, 1], [2] +// CHECK-SAME: : tensor<264x4xf32> into tensor<8x33x4xf32> // CHECK: %[[T2:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]] +// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]] // CHECK-SAME: ["parallel", "parallel", "parallel"] -// CHECK-SAME: ins(%[[T0]] : tensor<8x33x4xf32>) +// CHECK-SAME: ins(%[[T0]], %[[CST]] : // CHECK-SAME: outs(%[[T1]] : tensor<8x33x4xf32>) // CHECK: return %[[T2]] : tensor<8x33x4xf32> @@ -246,7 +256,8 @@ } // Only check the body in the indexed version of the test. -// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 5 + d2 * 20)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 4)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 5)> // CHECK: func @indexed_producer_reshape_consumer_fusion // CHECK: linalg.generic // CHECK: ^{{.*}}( @@ -256,11 +267,12 @@ // CHECK-DAG: %[[IDX1:.+]] = linalg.index 1 : index // CHECK-DAG: %[[IDX2:.+]] = linalg.index 2 : index // CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index -// CHECK-DAG: %[[T3:.+]] = affine.apply #[[MAP]](%[[IDX3]], %[[IDX2]], %[[IDX1]]) +// CHECK: %[[T1:.+]] = affine.apply #[[MAP1]](%[[IDX2]], %[[IDX1]]) +// CHECK: %[[T2:.+]] = affine.apply #[[MAP2]](%[[IDX3]], %[[T1]]) // CHECK: %[[T4:.+]] = arith.muli %[[ARG3]], %[[ARG4]] // CHECK: %[[T5:.+]] = arith.index_cast %[[IDX0]] // CHECK: %[[T6:.+]] = arith.addi %[[T4]], %[[T5]] -// CHECK: %[[T7:.+]] = arith.index_cast %[[T3]] +// CHECK: %[[T7:.+]] = arith.index_cast %[[T2]] // CHECK: %[[T8:.+]] = arith.addi %[[T6]], %[[T7]] // CHECK: linalg.yield %[[T8]] @@ -295,24 +307,29 @@ return %d : tensor<2x3x4x5x6x7xi32> } +// ----- -// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)> -// CHECK-DAG: #[[MAP6:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)> -// CHECK-DAG: #[[MAP7:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)> -// CHECK-DAG: #[[MAP8:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)> -// CHECK-DAG: #[[MAP9:.+]] = affine_map<(d0, d1, d2) -> (d0 + d1 * 7 + d2 * 42)> +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d0, d1, d5)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d3, d4, d5)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d2, d3, d4)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 6)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 7)> // CHECK: func @reshape_as_consumer_permutation // CHECK-SAME: %[[ARG0:.+]]: tensor<210x6x4xi32> // CHECK-SAME: %[[ARG1:.+]]: tensor<210x4xi32> +// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [6, 4, 210] // CHECK-DAG: %[[T1:.+]] = tensor.expand_shape %[[ARG0]] // CHECK-SAME: [0, 1, 2], [3, 4], [5] // CHECK-DAG: %[[T2:.+]] = tensor.expand_shape %[[ARG1]] // CHECK-SAME: [0, 1, 2], [3] -// CHECK-DAG: %[[T0:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7] +// CHECK-DAG: %[[T3:.+]] = tensor.expand_shape %[[INIT]] +// CHECK-SAME: [0, 1], [2], [3, 4, 5] +// CHECK-SAME: : tensor<6x4x210xi32> into tensor<2x3x4x5x6x7xi32> // CHECK: %[[T4:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP7]]] +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] // CHECK-SAME: ins(%[[T1]], %[[T2]] : tensor<5x6x7x2x3x4xi32>, tensor<5x6x7x4xi32>) -// CHECK-SAME: outs(%[[T0]] : tensor<2x3x4x5x6x7xi32>) +// CHECK-SAME: outs(%[[T3]] : tensor<2x3x4x5x6x7xi32>) // CHECK: ^{{.+}}( // CHECK-SAME: %[[ARG8:[a-zA-Z0-9]+]]: i32, %[[ARG9:[a-zA-Z0-9]+]]: i32, // CHECK-SAME: %[[ARG10:[a-zA-Z0-9]+]]: i32) @@ -322,15 +339,16 @@ // CHECK-DAG: %[[IDX3:.+]] = linalg.index 3 : index // CHECK-DAG: %[[IDX4:.+]] = linalg.index 4 : index // CHECK-DAG: %[[IDX5:.+]] = linalg.index 5 : index -// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP8]](%[[IDX1]], %[[IDX0]]) -// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP9]](%[[IDX4]], %[[IDX3]], %[[IDX2]]) -// CHECK-DAG: %[[T7:.+]] = arith.addi %[[ARG8]], %[[ARG9]] -// CHECK: %[[T8:.+]] = arith.index_cast %[[T5]] -// CHECK: %[[T9:.+]] = arith.addi %[[T7]], %[[T8]] -// CHECK: %[[T10:.+]] = arith.index_cast %[[T6]] -// CHECK: %[[T11:.+]] = arith.addi %[[T9]], %[[T10]] -// CHECK: %[[T12:.+]] = arith.index_cast %[[IDX5]] -// CHECK: %[[T13:.+]] = arith.addi %[[T11]], %[[T12]] +// CHECK-DAG: %[[T5:.+]] = affine.apply #[[MAP3]](%[[IDX1]], %[[IDX0]]) +// CHECK-DAG: %[[T6:.+]] = affine.apply #[[MAP4]](%[[IDX3]], %[[IDX2]]) +// CHECK-DAG: %[[T7:.+]] = affine.apply #[[MAP5]](%[[IDX4]], %[[T6]]) +// CHECK-DAG: %[[T8:.+]] = arith.addi %[[ARG8]], %[[ARG9]] +// CHECK: %[[T9:.+]] = arith.index_cast %[[T5]] +// CHECK: %[[T10:.+]] = arith.addi %[[T8]], %[[T9]] +// CHECK: %[[T11:.+]] = arith.index_cast %[[T7]] +// CHECK: %[[T12:.+]] = arith.addi %[[T10]], %[[T11]] +// CHECK: %[[T13:.+]] = arith.index_cast %[[IDX5]] +// CHECK: %[[T14:.+]] = arith.addi %[[T12]], %[[T13]] // ----- @@ -421,94 +439,18 @@ // CHECK: %[[T1:.+]] = tensor.expand_shape %[[ARG1]] // CHECK-SAME: [0, 1, 2], [3] // CHECK-SAME: tensor into tensor +// CHECK: %[[T2:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK-SAME: [0], [1, 2, 3] +// CHECK-SAME: tensor into tensor // CHECK: %[[T3:.+]] = linalg.generic // CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]] // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"] // CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor, tensor) -// CHECK-SAME: outs(%{{.+}} : tensor) +// CHECK-SAME: outs(%[[T2]] : tensor) // CHECK: return %[[T3]] : tensor // ----- -func.func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> { - %0 = tensor.collapse_shape %arg0 [[0, 1]] - : tensor<1x5xf32> into tensor<5xf32> - %1 = linalg.init_tensor [5, 5] : tensor<5x5xf32> - %2 = linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%0 : tensor<5xf32>) outs(%1 : tensor<5x5xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - linalg.yield %arg2 : f32 - } -> tensor<5x5xf32> - return %2 : tensor<5x5xf32> -} -// CHECK: func @unit_dim_reshape_expansion -// CHECK-DAG: tensor.collapse_shape -// CHECK-DAG: linalg.init_tensor -// CHECK: linalg.generic - -// ----- - -func.func @unit_dim_reshape_collapse(%arg0 : tensor<5xf32>) -> tensor<5x1x5xf32> { - %0 = linalg.init_tensor [5, 5] : tensor<5x5xf32> - %1 = linalg.generic - {indexing_maps = [affine_map<(d0, d1) -> (d0)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<5x5xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - linalg.yield %arg2 : f32 - } -> tensor<5x5xf32> - %2 = tensor.expand_shape %1 [[0, 1], [2]] - : tensor<5x5xf32> into tensor<5x1x5xf32> - return %2 : tensor<5x1x5xf32> -} -// CHECK: func @unit_dim_reshape_collapse -// CHECK: linalg.init_tensor -// CHECK: linalg.generic -// CHECK: tensor.expand_shape - -// ----- - -func.func @unit_dim_reshape_expansion_full - (%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor) - -> tensor { - %c1 = arith.constant 1 : index - %0 = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]] - : tensor<1x?x1x2x1x4xf32> into tensor - %1 = tensor.dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32> - %2 = linalg.init_tensor [%1, 2, 4] : tensor - %3 = linalg.generic - {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, - affine_map<(d0, d1, d2) -> (d0, d1, d2)>, - affine_map<(d0, d1, d2) -> (d0, d1, d2)>], - iterator_types = ["parallel", "parallel", "parallel"]} - ins(%0, %arg1 : tensor, tensor) - outs(%2 : tensor) { - ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): - %4 = arith.mulf %arg2, %arg3 : f32 - linalg.yield %4 : f32 - } -> tensor - return %3 : tensor -} -// CHECK: func @unit_dim_reshape_expansion_full -// CHECK-DAG: tensor.collapse_shape -// CHECK-DAG: linalg.init_tensor -// CHECK: linalg.generic -// CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor, tensor) - -// FOLDUNITDIM: func @unit_dim_reshape_expansion_full -// FOLDUNITDIM-SAME: %[[ARG0:.+]]: tensor<1x?x1x2x1x4xf32> -// FOLDUNITDIM-SAME: %[[ARG1:.+]]: tensor -// FOLDUNITDIM-DAG: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG1]] -// FOLDUNITDIM: linalg.generic -// FOLDUNITDIM-SAME: ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>) -// FOLDUNITDIM-SAME: outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>) - -// ----- - func.func @no_fuse_dynamic_dims(%arg0: tensor) -> tensor { %c0 = arith.constant 0 : index %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor into tensor @@ -554,7 +496,6 @@ // CHECK-SAME: %[[ARG0:.+]]: tensor<2x1xi64> // CHECK-SAME: %[[ARG1:.+]]: tensor // CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] -// CHECK: %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor to tensor<2xi64> // CHECK: %[[GENERIC:.+]] = linalg.generic -// CHECK-SAME: ins(%[[RESHAPE]], %[[CAST]] : tensor<2xi64>, tensor<2xi64>) +// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : tensor<2xi64>, tensor) // CHECK: return %[[GENERIC]] diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir +++ /dev/null @@ -1,287 +0,0 @@ -// RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s - -// Note: These tests fuse the reshape ops by linearization. This can create -// indexing maps which are hard to analyse later on. These patterns are useful -// only if the folded dimensions in the reshape op are unit extent. Tests here -// are more general for testing purposes, but use of these pattern for non-unit -// dimensions should be deprecated. - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -func.func @generic_op_reshape_producer_fusion(%arg0 : tensor) - -> tensor { - %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : - tensor into tensor - %1 = linalg.generic { - indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel", "parallel", "parallel"] } - ins(%0 : tensor) - outs(%0 : tensor) { - ^bb0(%arg6: i32, %arg7 : i32): - %idx = linalg.index 0 : index - %2 = arith.index_cast %idx : index to i32 - %3 = arith.addi %arg6, %2 : i32 - linalg.yield %3 : i32 - } -> tensor - return %1 : tensor -} -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)> -// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK: func @generic_op_reshape_producer_fusion -// CHECK-SAME: %[[ARG0:.+]]: tensor -// CHECK: %[[T0:.+]] = tensor.expand_shape %[[ARG0]] -// CHECK-SAME: [0], [1, 2], [3] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]]] -// CHECK-SAME: ins(%[[ARG0]] : tensor) -// CHECK-SAME: outs(%[[T0]] : tensor) -// CHECK: %[[IDX:.+]] = linalg.index 0 : index -// CHECK-NEXT: %[[IDX_CASTED:.+]] = arith.index_cast %[[IDX]] : index to i32 - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -func.func @generic_op_reshape_consumer_fusion(%arg0 : tensor) - -> tensor { - %0 = linalg.generic { - indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel", "parallel", "parallel"] } - ins(%arg0 : tensor) outs(%arg0 : tensor) { - ^bb0(%arg6: i32, %arg7: i32): - %idx = linalg.index 0 : index - %2 = arith.index_cast %idx : index to i32 - %3 = arith.addi %arg6, %2 : i32 - linalg.yield %3 : i32 - } -> tensor - %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : - tensor into tensor - return %1 : tensor -} -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)> -// CHECK: func @generic_op_reshape_consumer_fusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[T0:.+]] = tensor.collapse_shape %[[ARG0]] -// CHECK-SAME: [0], [1, 2, 3] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]] -// CHECK-SAME: outs(%[[T0]] : tensor) -// CHECK: %[[IDX:.+]] = linalg.index 0 : index -// CHECK-NEXT: %[[IDX_CASTED:.+]] = arith.index_cast %[[IDX]] : index to i32 -// CHECK-NOT: tensor.collapse_shape - -// ----- - -#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> -#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -func.func @generic_op_021_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> { - %0 = tensor.expand_shape %arg0 [[0], [1, 2]] - : tensor<3x35xf32> into tensor<3x5x7xf32> - %1 = linalg.init_tensor [3, 7, 5] : tensor<3x7x5xf32> - %2 = linalg.generic - {indexing_maps = [#map2, #map3], - iterator_types = ["parallel", "parallel", "parallel"]} - ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<3x7x5xf32>) { - ^bb0(%arg2: f32, %arg3 : f32): - linalg.yield %arg2 : f32 - } -> tensor<3x7x5xf32> - return %2 : tensor<3x7x5xf32> -} - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @generic_op_021_permultation_reshape_producer_fusion -// CHECK-NOT: tensor.expand_shape -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] - -// ----- - -#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> -#map3 = affine_map<(d0, d1, d2) -> (d0, d2, d1)> -func.func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> { - %0 = tensor.expand_shape %arg0 [[0], [1, 2]] - : tensor<3x35xf32> into tensor<3x5x7xf32> - %1 = linalg.init_tensor [5, 7, 3] : tensor<5x7x3xf32> - %2 = linalg.generic - {indexing_maps = [#map2, #map3], - iterator_types = ["parallel", "parallel", "parallel"]} - ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<5x7x3xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - linalg.yield %arg2 : f32 - } -> tensor<5x7x3xf32> - return %2 : tensor<5x7x3xf32> -} - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)> -// CHECK: func @generic_op_120_permutation_reshape_producer_fusion -// CHECK-NOT: tensor.expand_shape -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] - -// ----- - -#map0 = affine_map<(d0, d1, d2) -> (d0)> -#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> -#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> -#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -func.func @generic_op_102_permultation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> { - %0 = tensor.expand_shape %arg0 [[0], [1, 2]] - : tensor<3x35xf32> into tensor<3x5x7xf32> - %1 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32> - %2 = linalg.generic - {indexing_maps = [#map2, #map3], - iterator_types = ["parallel", "parallel", "parallel"]} - ins(%0 : tensor<3x5x7xf32>) outs(%1 : tensor<5x3x7xf32>) { - ^bb0(%arg2: f32, %arg3: f32): - linalg.yield %arg2 : f32 - } -> tensor<5x3x7xf32> - return %2 : tensor<5x3x7xf32> -} - - -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK: func @generic_op_102_permultation_reshape_producer_fusion -// CHECK-NOT: tensor.expand_shape -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] - -// ----- - -#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)> -#map2 = affine_map<(d0, d1, d2) -> (d0)> -#map3 = affine_map<(d0, d1, d2) -> (d1, d2)> -func.func @generic_op_102_permultation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> { - %0 = linalg.init_tensor [5, 3, 7] : tensor<5x3x7xf32> - %1 = linalg.generic - {indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "parallel", "parallel"]} - ins(%arg0 : tensor<3x5x7xf32>) outs(%0 : tensor<5x3x7xf32>) { - ^bb0(%arg2: f32, %arg3 : f32): - linalg.yield %arg2 : f32 - } -> tensor<5x3x7xf32> - %2 = tensor.collapse_shape %1 [[0], [1, 2]] - : tensor<5x3x7xf32> into tensor<5x21xf32> - return %2 : tensor<5x21xf32> -} -// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)> -// CHECK: func @generic_op_102_permultation_reshape_consumer_fusion -// CHECK-SAME: %[[ARG0:.+]]: tensor<3x5x7xf32> -// CHECK: %[[T0:.+]] = linalg.init_tensor [5, 3, 7] -// CHECK: %[[T1:.+]] = tensor.collapse_shape %[[T0]] -// CHECK-SAME: [0], [1, 2] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]] -// CHECK-SAME: ins(%[[ARG0]] : tensor<3x5x7xf32>) -// CHECK-SAME: outs(%[[T1]] : tensor<5x21xf32>) - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -func.func @generic_op_reshape_consumer_nofusion(%arg0 : tensor, - %arg1 : tensor) -> - tensor -{ - %0 = linalg.generic { - indexing_maps = [#map0, #map0, #map0], - iterator_types = ["parallel", "parallel", "parallel", "parallel"]} - ins(%arg0, %arg1 : tensor, tensor) - outs(%arg0 : tensor) { - ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): - %1 = arith.mulf %arg3, %arg4 : f32 - linalg.yield %1 : f32 - } -> tensor - %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]] : - tensor into tensor - return %1 : tensor -} -// CHECK-LABEL: func @generic_op_reshape_consumer_nofusion -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor -// CHECK: %[[NOFUSE:.+]] = linalg.generic -// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] -// CHECK: %[[RESULT:.+]] = tensor.collapse_shape %[[NOFUSE]] -// CHECK: return %[[RESULT]] - - -// ----- - -func.func @generic_op_permultation_reshape_consumer_fusion_unused_dim(%arg0 : tensor<6x1xf32>) -> tensor<6xi32> { - %0 = linalg.init_tensor [6, 1] : tensor<6x1xi32> - %1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : tensor<6x1xf32>) outs(%0 : tensor<6x1xi32>) { - ^bb0(%arg3: f32, %arg4: i32): - %5 = arith.fptosi %arg3 : f32 to i32 - linalg.yield %5 : i32 - } -> tensor<6x1xi32> - %6 = tensor.collapse_shape %1 [[0, 1]] : tensor<6x1xi32> into tensor<6xi32> - return %6 : tensor<6xi32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0)> -// CHECK: func @generic_op_permultation_reshape_consumer_fusion_unused_dim -// CHECK-SAME: %[[ARG0:.+]]: tensor<6x1xf32> -// CHECK: %[[T0:.+]] = linalg.init_tensor [6, 1] -// CHECK: %[[T1:.+]] = tensor.collapse_shape %[[T0]] -// CHECK-SAME: [0, 1] -// CHECK: linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: ins(%[[ARG0]] : tensor<6x1xf32>) -// CHECK-SAME: outs(%[[T1]] : tensor<6xi32>) - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d0, d6, d3, d5, d1)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)> -func.func @permuted_dims_fusion_expand_shape(%arg0 : tensor<3x8x7x240xf32>) -> tensor<4x6x3x8x2x5x7xf32> { - %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6]] - : tensor<3x8x7x240xf32> into tensor<3x2x4x7x8x5x6xf32> - %1 = linalg.init_tensor [4, 6, 3, 8, 2, 5, 7] : tensor<4x6x3x8x2x5x7xf32> - %2 = linalg.generic { - indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} - ins(%0 : tensor<3x2x4x7x8x5x6xf32>) outs(%1 : tensor<4x6x3x8x2x5x7xf32>) { - ^bb0(%arg1 : f32, %arg2 : f32): - linalg.yield %arg1 : f32 - } -> tensor<4x6x3x8x2x5x7xf32> - return %2 : tensor<4x6x3x8x2x5x7xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d0 + d4 * 4, d6, d1 + d3 * 30 + d5 * 6)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)> -// CHECK: func @permuted_dims_fusion_expand_shape( -// CHECK-SAME: %[[ARG0:.+]]: tensor<3x8x7x240xf32>) -// CHECK: %[[RESULT:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: ins(%[[ARG0]] : -// CHECK: return %[[RESULT]] - -// ----- - -#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d0, d6, d3, d5, d1)> -#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)> -func.func @permuted_dims_fusion_collapse_shape(%arg0 : tensor<4x6x3x8x2x5x7xf32>) -> tensor<3x8x7x240xf32> { - %0 = linalg.init_tensor [3, 2, 4, 7, 8, 5, 6] : tensor<3x2x4x7x8x5x6xf32> - %1 = linalg.generic { - indexing_maps = [#map1, #map0], - iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} - ins(%arg0 : tensor<4x6x3x8x2x5x7xf32>) outs(%0 : tensor<3x2x4x7x8x5x6xf32>) { - ^bb0(%arg1 : f32, %arg2 : f32): - linalg.yield %arg1 : f32 - } -> tensor<3x2x4x7x8x5x6xf32> - %2 = tensor.collapse_shape %1 [[0], [1, 2], [3], [4, 5, 6]] - : tensor<3x2x4x7x8x5x6xf32> into tensor<3x8x7x240xf32> - return %2 : tensor<3x8x7x240xf32> -} -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d0 + d4 * 4, d6, d1 + d3 * 30 + d5 * 6)> -// CHECK: func @permuted_dims_fusion_collapse_shape( -// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x3x8x2x5x7xf32>) -// CHECK: %[[RESULT:.+]] = linalg.generic -// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] -// CHECK-SAME: ins(%[[ARG0]] : -// CHECK: return %[[RESULT]] diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir deleted file mode 100644 --- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir +++ /dev/null @@ -1,52 +0,0 @@ -// RUN: mlir-opt -linalg-fold-reshape-ops-by-linearization=allow-folding-unit-dim-reshapes -split-input-file %s | FileCheck %s - -#map = affine_map<(d0, d1) -> (d0, d1)> -func.func @do_not_fold1(%arg0 : tensor, %arg1 : tensor) -> tensor -{ - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = tensor.dim %arg0, %c0 : tensor - %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor - %3 = linalg.generic { - indexing_maps = [#map, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : tensor, tensor) - outs(%2 : tensor) { - ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32): - %4 = arith.addf %arg2, %arg3 : f32 - linalg.yield %4 : f32 - } -> tensor - %4 = tensor.expand_shape %3 [[0], [1, 2]] : tensor into tensor - return %4 : tensor -} -// CHECK-LABEL: func @do_not_fold1 -// CHECK: %[[VAL:.+]] = linalg.generic -// CHECK: tensor.expand_shape %[[VAL]] - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -func.func @do_not_fold2(%arg0 : tensor, %arg1 : tensor) -> tensor -{ - %c0 = arith.constant 0 : index - %c1 = arith.constant 1 : index - %0 = tensor.collapse_shape %arg0 [[0], [1, 2]] : tensor into tensor - %1 = tensor.dim %arg1, %c0 : tensor - %2 = tensor.dim %arg1, %c1 : tensor - %3 = linalg.init_tensor [%1, %2] : tensor - %4 = linalg.generic { - indexing_maps = [#map, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%0, %arg1 : tensor, tensor) - outs(%3 : tensor) { - ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32): - %4 = arith.addf %arg2, %arg3 : f32 - linalg.yield %4 : f32 - } -> tensor - return %4 : tensor -} -// CHECK-LABEL: func @do_not_fold2 -// CHECK: %[[VAL:.+]] = tensor.collapse_shape -// CHECK: linalg.generic -// CHECK-SAME: ins(%[[VAL]], %{{.+}} : tensor, tensor) diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -70,18 +70,18 @@ llvm::cl::desc("Test fusion of generic operations."), llvm::cl::init(false)}; + Option fuseWithReshapeByExpansion{ + *this, "fuse-with-reshape-by-expansion", + llvm::cl::desc( + "Test fusion of generic operations with reshape by expansion"), + llvm::cl::init(false)}; + Option controlFuseByExpansion{ *this, "control-fusion-by-expansion", llvm::cl::desc( "Test controlling fusion of reshape with generic op by expansion"), llvm::cl::init(false)}; - Option pushExpandingReshape{ - *this, "push-expanding-reshape", - llvm::cl::desc("Test linalg expand_shape -> generic " - "to generic -> expand_shape pattern"), - llvm::cl::init(false)}; - Option fuseWithReshapeByCollapsing{ *this, "fuse-with-reshape-by-collapsing", llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that " @@ -109,6 +109,17 @@ return; } + if (fuseWithReshapeByExpansion) { + RewritePatternSet fusionPatterns(context); + linalg::populateFoldReshapeOpsByExpansionPatterns( + fusionPatterns, [](const OpResult & /*producer*/, + OpOperand & /*consumer*/) { return true; }); + if (failed(applyPatternsAndFoldGreedily(funcOp.getBody(), + std::move(fusionPatterns)))) + return signalPassFailure(); + return; + } + if (controlFuseByExpansion) { RewritePatternSet fusionPatterns(context); @@ -128,8 +139,9 @@ if (linalgOp && linalgOp.isOutputTensor(&use)) return true; } + return false; } - return linalg::skipUnitDimReshape(producer, consumer); + return true; }; linalg::populateFoldReshapeOpsByExpansionPatterns(fusionPatterns, @@ -139,12 +151,6 @@ return; } - if (pushExpandingReshape) { - RewritePatternSet patterns(context); - linalg::populatePushReshapeOpsPatterns(patterns); - (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); - } - if (fuseWithReshapeByCollapsing) { RewritePatternSet patterns(context); linalg::populateFoldReshapeOpsByCollapsingPatterns(