diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td @@ -669,6 +669,22 @@ return *(indexingMaps.begin() + opOperand->getOperandNumber()); }] >, + InterfaceMethod< + /*desc=*/[{ + Return the indexing map for a `result`. + }], + /*retTy=*/"AffineMap", + /*methodName=*/"getTiedIndexingMapForResult", + /*args=*/(ins "OpResult":$result), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(result.getOwner() == this->getOperation()); + auto indexingMaps = + $_op.indexing_maps().template getAsValueRange(); + return *(indexingMaps.begin() + getNumInputs() + + result.getResultNumber()); + }] + >, InterfaceMethod< /*desc=*/[{ Return the result tied to `opOperand`. diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -73,10 +73,21 @@ const ControlElementwiseOpsFusionFn &controlFoldingReshapes = skipUnitDimReshape); +/// Patterns to fold an expanding tensor.expand_shape operation with its +/// producer generic operation by collapsing the dimensions of the generic op. +void populateFoldReshapeOpsByCollapsingPatterns( + RewritePatternSet &patterns, + const ControlElementwiseOpsFusionFn &controlFoldingReshapes = + [](const OpResult & /*producer*/, OpOperand & /*consumer*/) { + return true; + }); + /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its /// producer (consumer) generic operation by linearizing the indexing map used /// to access the source (target) of the reshape operation in the generic /// operation. +/// TODO(ravishankarm): These patterns are to be deprecated in favor of using +/// the `populateFoldReshapeByCollapsingPatterns`. void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns); /// Patterns to fold a collapsing (expanding) tensor_reshape operation with its @@ -84,6 +95,8 @@ /// to access the source (target) of the reshape operation in the generic /// operation. The patterns are applied only when the tensor reshape involved is /// collapsing (introducing) unit-extent dimensions. +/// TODO(ravishankarm): These patterns are to be deprecated in favor of using +/// the `populateFoldReshapeByCollapsingPatterns`. void populateFoldUnitDimsReshapeOpsByLinearizationPatterns( RewritePatternSet &patterns); @@ -153,6 +166,8 @@ /// Patterns to push reshape op towards the end of the graph in order to expose /// more fusion opportunities. +/// TODO(ravishankarm): These patterns are to be deprecated in favor of using +/// the `populateFoldReshapeByCollapsingPatterns`. void populatePushReshapeOpsPatterns(RewritePatternSet &patterns); /// Perform standalone tiling of a single LinalgOp by `tileSizes`. diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -396,10 +396,11 @@ // linearization of indexing maps. //===---------------------------------------------------------------------===// -// TODO(ravishankarm): These patterns need to be deprecated. 
The indexing maps
+// TODO(ravishankarm): The indexing maps
 // these produce in the general case are detrimental to transformations.
-// They are useful now only in the limited case of unit-dimension folding.
-// Remove these in favor of more general folding by dimension contraction.
+// These patterns are on a deprecation path in favor of fusion by collapsing,
+// which covers the only legitimate use case of these patterns: folding
+// unit-extent dims.
 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
 /// provided, given the shape of the source tensor that corresponds to the
@@ -1144,11 +1145,470 @@
 };
 } // namespace
+//===---------------------------------------------------------------------===//
+// Methods and patterns to fuse reshape with linalg.generic operations by
+// contraction of dimensions.
+//===---------------------------------------------------------------------===//
+
+/// For an `indexingMap` that is a projected permutation, if the range is to be
+/// collapsed using the given `reassociation`, get the reassociation in the
+/// domain that would keep the map a projected permutation.
+static SmallVector<ReassociationIndices>
+getDomainReassociation(AffineMap indexingMap,
+                       ArrayRef<ReassociationIndices> rangeReassociation) {
+  assert(indexingMap.isProjectedPermutation() &&
+         "expected projected permutation map");
+  unsigned counter = 0;
+  SmallVector<ReassociationIndices> domainReassociation;
+  llvm::SmallDenseSet<unsigned> processedDomainDims;
+  // Iterate over the reassociation indices.
+  for (ReassociationIndicesRef foldedRangeDims : rangeReassociation) {
+    ReassociationIndices foldedDomainDims;
+    for (auto rangeDim : foldedRangeDims) {
+      (void)rangeDim;
+      AffineDimExpr dimExpr =
+          indexingMap.getResult(counter++).cast<AffineDimExpr>();
+      foldedDomainDims.push_back(dimExpr.getPosition());
+      processedDomainDims.insert(dimExpr.getPosition());
+    }
+    domainReassociation.emplace_back(std::move(foldedDomainDims));
+  }
+  // Fill in the missing domain dims.
+  for (auto dim : llvm::seq<unsigned>(0, indexingMap.getNumDims())) {
+    if (processedDomainDims.count(dim))
+      continue;
+    ReassociationIndices vec = {dim};
+    domainReassociation.emplace_back(std::move(vec));
+  }
+
+  // Sort the reassociation using the first dimension of the folded range so as
+  // not to create unnecessary transposes.
+  llvm::sort(domainReassociation,
+             [](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
+               return lhs[0] < rhs[0];
+             });
+  return domainReassociation;
+}
+
+/// For a given `dimSequence`, check if the sequence is preserved in the
+/// `indexingMap`. `indexingMap` is expected to be a projected permutation.
+/// Non-existence of the sequence returns true as well.
+static bool isDimSequencePreserved(AffineMap indexingMap,
+                                   ReassociationIndicesRef dimSequence) {
+  assert(!dimSequence.empty() &&
+         "expected non-empty list for dimension sequence");
+  assert(indexingMap.isProjectedPermutation() &&
+         "expected indexing map to be projected permutation");
+
+  llvm::SmallDenseSet<int64_t> sequenceElements;
+  sequenceElements.insert(dimSequence.begin(), dimSequence.end());
+
+  unsigned dimSequenceStart = dimSequence[0];
+  for (auto expr : enumerate(indexingMap.getResults())) {
+    unsigned dimInMapStart = expr.value().cast<AffineDimExpr>().getPosition();
+    // 1. Check if this is the start of the sequence.
+    if (dimInMapStart == dimSequenceStart) {
+      if (expr.index() + dimSequence.size() > indexingMap.getNumResults())
+        return false;
+      // 1a. Check if the sequence is preserved.
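+      // E.g., for the map (d0, d1, d2, d3) -> (d0, d2, d3) and the sequence
+      // [2, 3], the results d2, d3 appear contiguously starting here, so this
+      // check succeeds; for (d0, d1, d2, d3) -> (d0, d2, d1, d3) the result
+      // after d2 is d1, so the sequence is broken and false is returned.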
+      for (auto dimInSequence : enumerate(dimSequence)) {
+        unsigned dimInMap =
+            indexingMap.getResult(expr.index() + dimInSequence.index())
+                .cast<AffineDimExpr>()
+                .getPosition();
+        if (dimInMap != dimInSequence.value())
+          return false;
+      }
+      // Found the sequence. Projected permutation
+      // enforces that all AffineDimExprs in the result are unique, so no
+      // further checks are needed.
+      return true;
+    }
+    // 2. If the position in the expr (which is an AffineDimExpr) is part of
+    // the sequence, return false here. This implies the entire sequence does
+    // not exist in the indexing map.
+    if (sequenceElements.count(dimInMapStart))
+      return false;
+  }
+  // 3. No element of the sequence was found. Return true.
+  return true;
+}
+
+/// Check if a generic op can be fused along an operand by collapsing
+/// dimensions.
+static bool isFusableWithReshapeByDimCollapse(
+    GenericOp genericOp, OpOperand *fusableOperand,
+    ArrayRef<ReassociationIndices> reassociation) {
+  // Some basic checks for this fusion to be valid.
+  if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
+    return false;
+
+  if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
+        return map.isProjectedPermutation();
+      })) {
+    return false;
+  }
+
+  // Get the reassociation for the iteration space.
+  SmallVector<ReassociationIndices> iterationReassociation =
+      getDomainReassociation(genericOp.getTiedIndexingMap(fusableOperand),
+                             reassociation);
+  if (iterationReassociation.empty()) {
+    // If the domain reassociation indices are empty, then this is a scalar op.
+    // Nothing to do.
+    return false;
+  }
+
+  auto iteratorTypes = genericOp.iterator_types().getValue();
+  ArrayRef<Attribute> iteratorTypesRef(iteratorTypes);
+  for (ReassociationIndicesRef foldedIterDims : iterationReassociation) {
+    // Check that in all indexing maps the folded dimension sequence is
+    // preserved.
+    if (!llvm::all_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
+          return isDimSequencePreserved(indexingMap, foldedIterDims);
+        }))
+      return false;
+    unsigned startDim = foldedIterDims[0];
+    ArrayRef<Attribute> foldedIteratorTypes =
+        iteratorTypesRef.drop_front(startDim).take_front(foldedIterDims.size());
+    // Check that the folded iterator types are either all parallel or all
+    // reduction.
+    if (!llvm::all_of(
+            foldedIteratorTypes,
+            [](Attribute attr) { return isParallelIterator(attr); }) &&
+        !llvm::all_of(foldedIteratorTypes,
+                      [](Attribute attr) { return isReductionIterator(attr); }))
+      return false;
+  }
+  return true;
+}
+
+/// Helper class to carry state while collapsing the `linalg.generic` op.
+namespace {
+class CollapsingInfo {
+public:
+  CollapsingInfo(SmallVector<ReassociationIndices> &&reassociation) {
+    iterationReassociation = std::move(reassociation);
+    for (auto foldedIterDims : enumerate(iterationReassociation)) {
+      foldedDimStartToSequenceMap[foldedIterDims.value()[0]] =
+          foldedIterDims.index();
+    }
+  }
+
+  // Returns the iteration space reassociation.
+  ArrayRef<ReassociationIndices> getReassociationIndices() {
+    return iterationReassociation;
+  }
+
+  // Returns true if the given dimension is the start of a sequence of folded
+  // dimensions.
+  bool isDimStartOfFoldedDims(unsigned dim) {
+    return foldedDimStartToSequenceMap.count(dim);
+  }
+
+  // Return the folded dimensions starting at `dim`.
+  ReassociationIndicesRef getFoldedDimsStartingAt(unsigned dim) {
+    assert(foldedDimStartToSequenceMap.count(dim) &&
+           "invalid start dim of folded dim "
+           "sequence");
+    return iterationReassociation[foldedDimStartToSequenceMap[dim]];
+  }
+
+  // For a dim in the original op, return the dim in the collapsed op that it
+  // is mapped to.
+  // Expects `dim` to be the start of a folded dimension sequence.
+  unsigned getDimInCollapsedOpForStartOfFoldedDims(unsigned dim) {
+    assert(foldedDimStartToSequenceMap.count(dim) &&
+           "invalid start dim of folded dim sequence");
+    return foldedDimStartToSequenceMap[dim];
+  }
+
+private:
+  /// Reassociation describing the folded iteration space dimensions.
+  SmallVector<ReassociationIndices> iterationReassociation;
+
+  /// Map from the starting dimension of a folded dimension sequence to the
+  /// position of that sequence in `iterationReassociation`.
+  llvm::DenseMap<unsigned, unsigned> foldedDimStartToSequenceMap;
+};
+} // namespace
+
+/// Get the iterator types for the collapsed operation given the original
+/// iterator types and collapsed dimensions.
+static SmallVector<StringRef>
+getCollapsedOpIteratorTypes(ArrayRef<Attribute> iteratorTypes,
+                            CollapsingInfo &collapsingInfo) {
+  SmallVector<StringRef> collapsedIteratorTypes;
+  for (ReassociationIndicesRef foldedIterDims :
+       collapsingInfo.getReassociationIndices()) {
+    assert(!foldedIterDims.empty() &&
+           "reassociation indices expected to have non-empty sets");
+    // Just pick the iterator type of the first folded dim. The pre-condition
+    // checks have already verified that the iterator types of all folded
+    // dimensions are the same.
+    collapsedIteratorTypes.push_back(
+        iteratorTypes[foldedIterDims[0]].cast<StringAttr>().getValue());
+  }
+  return collapsedIteratorTypes;
+}
+
+/// Compute the indexing map in the collapsed op that corresponds to the given
+/// `indexingMap` of the original operation.
+static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap,
+                                           CollapsingInfo &collapsingInfo) {
+  MLIRContext *context = indexingMap.getContext();
+  assert(indexingMap.isProjectedPermutation() &&
+         "expected indexing map to be projected permutation");
+  SmallVector<AffineExpr> resultExprs;
+  for (auto expr : indexingMap.getResults()) {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
+      resultExprs.push_back(getAffineDimExpr(
+          collapsingInfo.getDimInCollapsedOpForStartOfFoldedDims(dim),
+          context));
+    }
+  }
+  return AffineMap::get(collapsingInfo.getReassociationIndices().size(), 0,
+                        resultExprs, context);
+}
+
+/// Return the `reassociation` indices to use to collapse the operand when the
+/// iteration space of a generic op is collapsed.
+static SmallVector<ReassociationIndices>
+getOperandReassociation(AffineMap indexingMap, CollapsingInfo &collapsingInfo) {
+  unsigned counter = 0;
+  SmallVector<ReassociationIndices> operandReassociation;
+  for (auto expr : indexingMap.getResults()) {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
+      unsigned numFoldedDims =
+          collapsingInfo.getFoldedDimsStartingAt(dim).size();
+      auto range = llvm::seq(counter, counter + numFoldedDims);
+      operandReassociation.emplace_back(range.begin(), range.end());
+      counter += numFoldedDims;
+    }
+  }
+  return operandReassociation;
+}
+
+/// Get the new value to use for a given `OpOperand` in the collapsed
+/// operation.
+static Value getCollapsedOpOperand(Location loc, GenericOp genericOp,
+                                   OpOperand *opOperand,
+                                   CollapsingInfo &collapsingInfo,
+                                   OpBuilder &builder) {
+  AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+  SmallVector<ReassociationIndices> operandReassociation =
+      getOperandReassociation(indexingMap, collapsingInfo);
+
+  // If the number of entries in the reassociation for the operand is the same
+  // as the number of results of the indexing map, there is nothing to do for
+  // this operand.
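+  // E.g., with an operand map (d0, d1, d2) -> (d0, d1) and an operand
+  // reassociation of [[0], [1]] the operand is used as is; with a
+  // reassociation of [[0, 1]] a tensor.collapse_shape is inserted below.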
+ Value operand = opOperand->get(); + if (operandReassociation.size() == indexingMap.getNumResults()) + return operand; + + // Insert a reshape to collapse the dimensions. + auto reshapeOp = builder.create( + loc, operand, operandReassociation); + return reshapeOp.getResult(); +} + +/// Modify the `linalg.index` operations in the original generic op, to its +/// value in the collapsed operation. +void generateCollapsedIndexingRegion(Location loc, Block *block, + CollapsingInfo &collapsingInfo, + ValueRange loopRange, + PatternRewriter &rewriter) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointToStart(block); + + // Collect all the original index ops. + auto indexOps = llvm::to_vector(block->getOps()); + + // For each folded dimension list resolve the original induction variable + // values in terms of the folded dimension induction variable. + // i_{folded} = (i_0 * d1 + i1) * d2 + i2. + // can be inverted to + // i2 = i_{folded} % d2 + // i1 = (i_{folded} / d2) % d1 + // i0 = i_{folded} / (d1 * d2) + llvm::DenseMap indexReplacementVals; + for (auto &foldedDims : enumerate(collapsingInfo.getReassociationIndices())) { + ReassociationIndicesRef foldedDimsRef(foldedDims.value()); + Value newIndexVal = + rewriter.create(loc, foldedDims.index()); + for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) { + indexReplacementVals[dim] = + rewriter.create(loc, newIndexVal, loopRange[dim]); + newIndexVal = + rewriter.create(loc, newIndexVal, loopRange[dim]); + } + indexReplacementVals[foldedDims.value().front()] = newIndexVal; + } + + for (auto indexOp : indexOps) { + auto dim = indexOp.dim(); + rewriter.replaceOp(indexOp, indexReplacementVals[dim]); + } +} + +/// Implementation of fusion with reshape operation by collapsing dimensions. +static Optional> +fuseWithReshapeByCollapsing(GenericOp genericOp, Operation *reshapeOp, + OpOperand *fusableOpOperand, + PatternRewriter &rewriter) { + SmallVector reassociation = + isa(reshapeOp) + ? cast(reshapeOp).getReassociationIndices() + : cast(reshapeOp).getReassociationIndices(); + assert(isFusableWithReshapeByDimCollapse(genericOp, fusableOpOperand, + reassociation) && + "preconditions for fusing with reshape by collapse failed"); + + CollapsingInfo collapsingInfo(getDomainReassociation( + genericOp.getTiedIndexingMap(fusableOpOperand), reassociation)); + // Check for trivial no transformation cases. In that case return nothing. + if (collapsingInfo.getReassociationIndices().size() == + genericOp.getNumLoops()) + return llvm::None; + + // Get the iterator types for the operand. + SmallVector iteratorTypes = getCollapsedOpIteratorTypes( + genericOp.iterator_types().getValue(), collapsingInfo); + + // Get the indexing maps. + auto indexingMaps = llvm::to_vector( + llvm::map_range(genericOp.getIndexingMaps(), [&](AffineMap map) { + return getCollapsedOpIndexingMap(map, collapsingInfo); + })); + + Location loc = + rewriter.getFusedLoc({genericOp->getLoc(), reshapeOp->getLoc()}); + + // Get the input operands. + auto inputOperands = llvm::to_vector( + llvm::map_range(genericOp.getInputOperands(), [&](OpOperand *opOperand) { + return getCollapsedOpOperand(loc, genericOp, opOperand, collapsingInfo, + rewriter); + })); + + // Get the output operands and result types. 
+ SmallVector resultTypes; + SmallVector outputOperands; + resultTypes.reserve(genericOp.getNumOutputs()); + outputOperands.reserve(genericOp.getNumOutputs()); + for (OpOperand *output : genericOp.getOutputOperands()) { + Value newOutput = + getCollapsedOpOperand(loc, genericOp, output, collapsingInfo, rewriter); + outputOperands.push_back(newOutput); + resultTypes.push_back(newOutput.getType()); + } + + // Create the generic op. + auto collapsedGenericOp = rewriter.create( + loc, resultTypes, inputOperands, outputOperands, indexingMaps, + iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); + Block *origOpBlock = &genericOp->getRegion(0).front(); + Block *collapsedOpBlock = &collapsedGenericOp->getRegion(0).front(); + rewriter.mergeBlocks(origOpBlock, collapsedOpBlock, + collapsedOpBlock->getArguments()); + + if (collapsedGenericOp.hasIndexSemantics()) { + // Collect the loop range of the generic op. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(collapsedGenericOp); + SmallVector loopRanges = + cast(genericOp.getOperation()) + .createLoopRanges(rewriter, genericOp.getLoc()); + assert(llvm::all_of(loopRanges, + [](Range range) { + return matchPattern(range.offset, m_Zero()) && + matchPattern(range.stride, m_One()); + }) && + "expected all loop ranges to have zero start and unit stride"); + SmallVector loopBound = llvm::to_vector( + llvm::map_range(loopRanges, [](Range range) { return range.size; })); + generateCollapsedIndexingRegion(loc, + &collapsedGenericOp->getRegion(0).front(), + collapsingInfo, loopBound, rewriter); + } + + // Insert expanding reshape for the result to get back the original result + // type. + SmallVector results; + for (auto originalResult : llvm::enumerate(genericOp->getResults())) { + Value collapsedOpResult = + collapsedGenericOp->getResult(originalResult.index()); + auto originalResultType = + originalResult.value().getType().cast(); + auto collapsedOpResultType = collapsedOpResult.getType().cast(); + if (collapsedOpResultType.getRank() != originalResultType.getRank()) { + AffineMap indexingMap = + genericOp.getTiedIndexingMapForResult(originalResult.value()); + SmallVector reassociation = + getOperandReassociation(indexingMap, collapsingInfo); + Value result = rewriter.create( + loc, originalResultType, collapsedOpResult, reassociation); + results.push_back(result); + } else { + results.push_back(collapsedOpResult); + } + } + return results; +} + +namespace { + +/// Pattern to fuse a tensor.expand_shape op with its consumer generic op by +/// contracting dimensions of the loop. 
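+///
+/// A sketch of the intended rewrite (assuming the reshape feeds an input of
+/// the generic op and all preconditions hold):
+///
+///   %e = tensor.expand_shape %arg [[0, 1]] : tensor<6xf32> into tensor<2x3xf32>
+///   %r = linalg.generic ... ins(%e : tensor<2x3xf32>) ...
+///
+/// is rewritten into a linalg.generic over the collapsed iteration space that
+/// consumes %arg directly, followed by tensor.expand_shape ops that recover
+/// the original result types.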
+class FoldWithProducerReshapeOpByCollapsing + : public OpRewritePattern { +public: + FoldWithProducerReshapeOpByCollapsing( + MLIRContext *context, ControlElementwiseOpsFusionFn foldReshapes, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + controlFoldingReshapes(std::move(foldReshapes)) {} + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + for (OpOperand *opOperand : genericOp.getInputTensorOperands()) { + tensor::ExpandShapeOp reshapeOp = + opOperand->get().getDefiningOp(); + if (!reshapeOp) + continue; + + if (!isFusableWithReshapeByDimCollapse( + genericOp, opOperand, reshapeOp.getReassociationIndices()) || + !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) { + continue; + } + + Optional> replacements = fuseWithReshapeByCollapsing( + genericOp, reshapeOp, opOperand, rewriter); + if (!replacements) { + return rewriter.notifyMatchFailure( + genericOp, "failed to do the fusion by collapsing transformation"); + } + + rewriter.replaceOp(genericOp, replacements.getValue()); + return success(); + } + return failure(); + } + +private: + ControlElementwiseOpsFusionFn controlFoldingReshapes; +}; +} // namespace + //===---------------------------------------------------------------------===// // Methods and patterns to convert tensor.expand_shape -> linalg.generic // into linalg.generic -> tensor.expand_shape, i.e. push the reshape down. //===---------------------------------------------------------------------===// +// TODO(ravishankarm): This pattern is to be deprecated in favor of fusion by +// collapsing that provides a more general functionality. This pattern is very +// specific to a particular use case. The fusion by collapsing can provide the +// same control to clients using the control function there. + static SmallVector getReassociationIndices(ArrayRef maps) { SmallVector reassociation; @@ -1785,6 +2245,13 @@ controlFoldingReshapes); } +void mlir::linalg::populateFoldReshapeOpsByCollapsingPatterns( + RewritePatternSet &patterns, + const ControlElementwiseOpsFusionFn &controlFoldingReshapes) { + patterns.add(patterns.getContext(), + controlFoldingReshapes); +} + void mlir::linalg::populateElementwiseOpsFusionPatterns( RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) { auto *context = patterns.getContext(); diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir @@ -0,0 +1,401 @@ +// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-collapsing -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=fuse-with-reshape-by-collapsing-control -split-input-file | FileCheck %s --check-prefix=CONTROL + +// Static problem sizes. Checks all aspects of fusion by collapsing. Rest of the +// tests only check a subset of conditions. 
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> +func @fuse_by_collapsing(%arg0 : tensor<2x12x5x336x9xi32>, + %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> { + %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] + : tensor<2x12x5x336x9xi32> into tensor<2x3x4x5x6x7x8x9xi32> + %init = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x3x4x5x6x7x8x9xi32> + %generic = linalg.generic { + indexing_maps = [#map0, #map1, #map2, #map3], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} + ins(%expand, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>) + outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) { + ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32): + %t0 = arith.addi %b0, %b1 : i32 + %t1 = arith.addi %t0, %b2 : i32 + linalg.yield %t1 : i32 + } -> tensor<2x3x4x5x6x7x8x9xi32> + return %generic : tensor<2x3x4x5x6x7x8x9xi32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> +// CHECK: func @fuse_by_collapsing( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<2x3x4xi32> +// CHECK-SAME: %[[ARG2:.+]]: tensor<5x6x7x8xi32> +// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9] +// CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1, 2]{{\]}} +// CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}} +// CHECK-DAG: %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} +// CHECK: %[[COLLAPSED_OP:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP0]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] : +// CHECK_SAME: outs(%[[INIT_RESHAPE]] : +// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2], [3], [4, 5, 6], [7]{{\]}} +// CHECK: return %[[RESULT_RESHAPE]] + +// CONTROL: func @fuse_by_collapsing( +// CONTROL-SAME: %[[ARG0:.+]]: tensor<2x12x5x336x9xi32> +// CONTROL-SAME: %[[ARG1:.+]]: tensor<2x3x4xi32> +// CONTROL-SAME: %[[ARG2:.+]]: tensor<5x6x7x8xi32> +// CONTROL: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] +// CONTROL: %[[GENERIC:.+]] = linalg.generic +// CONTROL-SAME: ins(%[[EXPAND]], +// CONTROL: return %[[GENERIC]] + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> +func @fuse_by_collapsing_indexing_op(%arg0 : tensor<2x12x5x336x9xi32>, + %arg1 : tensor<2x3x4xi32>, %arg2 : tensor<5x6x7x8xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> { + %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] + : tensor<2x12x5x336x9xi32> into 
tensor<2x3x4x5x6x7x8x9xi32> + %init = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x3x4x5x6x7x8x9xi32> + %generic = linalg.generic { + indexing_maps = [#map0, #map1, #map2, #map3], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} + ins(%expand, %arg1, %arg2 : tensor<2x3x4x5x6x7x8x9xi32>, tensor<2x3x4xi32>, tensor<5x6x7x8xi32>) + outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) { + ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32): + %iv0 = linalg.index 0: index + %iv1 = linalg.index 1: index + %t0 = arith.addi %iv0, %iv1 : index + %iv2 = linalg.index 2 : index + %t1 = arith.addi %t0, %iv2 : index + %iv3 = linalg.index 3 : index + %t2 = arith.addi %t1, %iv3 : index + %iv4 = linalg.index 4 : index + %t3 = arith.addi %t2, %iv4 : index + %iv5 = linalg.index 5 : index + %t4 = arith.addi %t3, %iv5 : index + %iv6 = linalg.index 6 : index + %t5 = arith.addi %t4, %iv6 : index + %iv7 = linalg.index 7 : index + %t6 = arith.addi %t5, %iv7 : index + %yield = arith.index_cast %t6 : index to i32 + linalg.yield %yield : i32 + } -> tensor<2x3x4x5x6x7x8x9xi32> + return %generic : tensor<2x3x4x5x6x7x8x9xi32> +} +// CHECK-LABEL: func @fuse_by_collapsing_indexing_op( +// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index +// CHECK-DAG: %[[C7:.+]] = arith.constant 7 : index +// CHECK: %[[IV0:.+]] = linalg.index 0 +// CHECK: %[[IV1:.+]] = linalg.index 1 +// CHECK: %[[REM_IV1:.+]] = arith.remui %[[IV1]], %[[C4]] +// CHECK: %[[DIV_IV1:.+]] = arith.divui %[[IV1]], %[[C4]] +// CHECK: %[[IV2:.+]] = linalg.index 2 +// CHECK: %[[IV3:.+]] = linalg.index 3 +// CHECK: %[[REM1_IV3:.+]] = arith.remui %[[IV3]], %[[C8]] +// CHECK: %[[DIV1_IV3:.+]] = arith.divui %[[IV3]], %[[C8]] +// CHECK: %[[REM2_IV3:.+]] = arith.remui %[[DIV1_IV3]], %[[C7]] +// CHECK: %[[DIV2_IV3:.+]] = arith.divui %[[DIV1_IV3]], %[[C7]] +// CHECK: %[[IV4:.+]] = linalg.index 4 +// CHECK: %[[T0:.+]] = arith.addi %[[IV0]], %[[DIV_IV1]] +// CHECK: %[[T1:.+]] = arith.addi %[[T0]], %[[REM_IV1]] +// CHECK: %[[T2:.+]] = arith.addi %[[T1]], %[[IV2]] +// CHECK: %[[T3:.+]] = arith.addi %[[T2]], %[[DIV2_IV3]] +// CHECK: %[[T4:.+]] = arith.addi %[[T3]], %[[REM2_IV3]] +// CHECK: %[[T5:.+]] = arith.addi %[[T4]], %[[REM1_IV3]] +// CHECK: %[[T6:.+]] = arith.addi %[[T5]], %[[IV4]] +// CHECK: %[[YIELD:.+]] = arith.index_cast %[[T6]] +// CHECK: linalg.yield %[[YIELD]] + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d5, d6, d0, d1, d2, d3, d4)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d0)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1, d2, d3)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> +func @fuse_by_collapsing_change_reshape_order(%arg0 : tensor<9x56x2x60x6xi32>, + %arg1 : tensor<7x8x2xi32>, %arg2 : tensor<6x3x4x5xi32>) -> tensor<2x3x4x5x6x7x8x9xi32> { + %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] + : tensor<9x56x2x60x6xi32> into tensor<9x7x8x2x3x4x5x6xi32> + %init = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9] : tensor<2x3x4x5x6x7x8x9xi32> + %generic = linalg.generic { + indexing_maps = [#map0, #map1, #map2, #map3], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} + ins(%expand, %arg1, %arg2 : tensor<9x7x8x2x3x4x5x6xi32>, tensor<7x8x2xi32>, tensor<6x3x4x5xi32>) + outs(%init : tensor<2x3x4x5x6x7x8x9xi32>) { + ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 
: i32): + %t0 = arith.addi %b0, %b1 : i32 + %t1 = arith.addi %t0, %b2 : i32 + linalg.yield %t1 : i32 + } -> tensor<2x3x4x5x6x7x8x9xi32> + return %generic : tensor<2x3x4x5x6x7x8x9xi32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4, d3, d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d3, d0)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d2, d1)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> +// CHECK: func @fuse_by_collapsing_change_reshape_order( +// CHECK-SAME: %[[ARG0:.+]]: tensor<9x56x2x60x6xi32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<7x8x2xi32> +// CHECK-SAME: %[[ARG2:.+]]: tensor<6x3x4x5xi32> +// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 5, 6, 7, 8, 9] +// CHECK-DAG: %[[ARG1_RESHAPE:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0, 1], [2]{{\]}} +// CHECK-DAG: %[[ARG2_RESHAPE:.+]] = tensor.collapse_shape %[[ARG2]] {{\[}}[0], [1, 2, 3]{{\]}} +// CHECK-DAG: %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}} +// CHECK: %[[COLLAPSED_OP:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]], #[[MAP3]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_RESHAPE]], %[[ARG2_RESHAPE]] : +// CHECK_SAME: outs(%[[INIT_RESHAPE]] : +// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[COLLAPSED_OP]] {{\[}}[0], [1, 2, 3], [4], [5, 6], [7]{{\]}} +// CHECK: return %[[RESULT_RESHAPE]] + +// ----- + +// Dynamic case. Only checks things not covered by `fuse_by_collapsing` test above. +#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d5, d6, d0, d1, d2, d3, d4)> +#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d0)> +#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d1, d2, d3)> +#map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4, d5, d6, d7)> +func @fuse_by_collapsing_dynamic(%arg0 : tensor, + %arg1 : tensor, %arg2 : tensor) -> tensor { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %expand = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6], [7]] + : tensor into tensor + %d0 = tensor.dim %arg1, %c2 : tensor + %d2 = tensor.dim %arg2, %c2 : tensor + %d4 = tensor.dim %arg2, %c0 : tensor + %d6 = tensor.dim %arg1, %c1 : tensor + %d7 = tensor.dim %arg0, %c0 : tensor + %init = linalg.init_tensor [%d0, 3, %d2, 5, %d4, 7, %d6, %d7] : tensor + %generic = linalg.generic { + indexing_maps = [#map0, #map1, #map2, #map3], + iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} + ins(%expand, %arg1, %arg2 : tensor, tensor, tensor) + outs(%init : tensor) { + ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32): + %iv0 = linalg.index 0: index + %iv1 = linalg.index 1: index + %t0 = arith.addi %iv0, %iv1 : index + %iv2 = linalg.index 2 : index + %t1 = arith.addi %t0, %iv2 : index + %iv3 = linalg.index 3 : index + %t2 = arith.addi %t1, %iv3 : index + %iv4 = linalg.index 4 : index + %t3 = arith.addi %t2, %iv4 : index + %iv5 = linalg.index 5 : index + %t4 = arith.addi %t3, %iv5 : index + %iv6 = linalg.index 6 : index + %t5 = arith.addi %t4, %iv6 : index + %iv7 = linalg.index 7 : index + %t6 = arith.addi %t5, %iv7 : index + %yield = arith.index_cast %t6 : index to i32 + linalg.yield %yield : i32 + } -> tensor + return %generic : tensor +} +// CHECK: func 
@fuse_by_collapsing_dynamic( +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[C5:.+]] = arith.constant 5 : index +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[EXPAND]], %[[C2]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[EXPAND]], %[[C5]] +// CHECK: linalg.generic +// CHECK: %[[IV0:.+]] = linalg.index 1 +// CHECK: %[[REM1_IV0:.+]] = arith.remui %[[IV0]], %[[C5]] +// CHECK: %[[DIV1_IV0:.+]] = arith.divui %[[IV0]], %[[C5]] +// CHECK: %[[REM2_IV0:.+]] = arith.remui %[[DIV1_IV0]], %[[D1]] +// CHECK: %[[DIV2_IV0:.+]] = arith.divui %[[DIV1_IV0]], %[[D1]] +// CHECK: %[[IV1:.+]] = linalg.index 3 +// CHECK: %[[REM1_IV1:.+]] = arith.remui %[[IV1]], %[[D0]] +// CHECK: %[[DIV1_IV1:.+]] = arith.divui %[[IV1]], %[[D0]] + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3)> +func @fuse_reductions(%arg0 : tensor<2x?x5xf32>, %arg1 : tensor<2x5xf32>) -> tensor<2x5xf32> { + %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x?x5xf32> into tensor<2x6x?x5xf32> + %1 = linalg.generic { + indexing_maps = [#map0, #map1], + iterator_types = ["parallel", "reduction", "reduction", "parallel"]} + ins(%0 : tensor<2x6x?x5xf32>) outs(%arg1 : tensor<2x5xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + %2 = arith.addf %b0, %b1 : f32 + linalg.yield %2 : f32 + } -> tensor<2x5xf32> + return %1 : tensor<2x5xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2)> +// CHECK: func @fuse_reductions( +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x?x5xf32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<2x5xf32>) -> tensor<2x5xf32> +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"] +// CHECK-SAME: ins(%[[ARG0]] : tensor<2x?x5xf32>) +// CHECK-SAME: outs(%[[ARG1]] : tensor<2x5xf32>) + +// ----- + +// Test no fusion because the folded dimensions are not all preserved. +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> +func @no_fuse_unpreserved_folding(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2x3xf32>) -> tensor<2x3x4x5xf32> { + %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32> + %init = linalg.init_tensor [2, 3, 4, 5] : tensor<2x3x4x5xf32> + %1 = linalg.generic { + indexing_maps = [#map0, #map1, #map0], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%0, %arg1 : tensor<2x3x4x5xf32>, tensor<2x3xf32>) outs(%init : tensor<2x3x4x5xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %2 = arith.addf %b0, %b1 : f32 + linalg.yield %2 : f32 + } -> tensor<2x3x4x5xf32> + return %1 : tensor<2x3x4x5xf32> +} +// CHECK: func @no_fuse_unpreserved_folding +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5xf32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<2x3xf32> +// CHECK: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : +// CHECK: return %[[GENERIC]] + +// ----- + +// Test no fusion because the folded dimensions are not all preserved. 
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)> +func @no_fuse_unpreserved_folding_transpose(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2xf32>) -> tensor<2x4x3x5xf32> { + %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32> + %init = linalg.init_tensor [2, 4, 3, 5] : tensor<2x4x3x5xf32> + %1 = linalg.generic { + indexing_maps = [#map0, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%0, %arg1 : tensor<2x3x4x5xf32>, tensor<2xf32>) outs(%init : tensor<2x4x3x5xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %2 = arith.addf %b0, %b1 : f32 + linalg.yield %2 : f32 + } -> tensor<2x4x3x5xf32> + return %1 : tensor<2x4x3x5xf32> +} +// CHECK: func @no_fuse_unpreserved_folding_transpose +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5xf32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<2xf32> +// CHECK: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : +// CHECK: return %[[GENERIC]] + +// ----- + +// Test no fusion because the iterator types of folded dims are not preserved. +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d3)> +func @no_fuse_mismatched_iterator_types(%arg0 : tensor<2x12x5xf32>, %arg1 : tensor<2x3xf32>) -> tensor<2x5xf32> { + %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3]] : tensor<2x12x5xf32> into tensor<2x3x4x5xf32> + %init = linalg.init_tensor [2, 5] : tensor<2x5xf32> + %1 = linalg.generic { + indexing_maps = [#map0, #map1, #map2], + iterator_types = ["parallel", "reduction", "parallel", "parallel"]} + ins(%0, %arg1 : tensor<2x3x4x5xf32>, tensor<2x3xf32>) outs(%init : tensor<2x5xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %2 = arith.addf %b0, %b1 : f32 + linalg.yield %2 : f32 + } -> tensor<2x5xf32> + return %1 : tensor<2x5xf32> +} +// CHECK: func @no_fuse_mismatched_iterator_types +// CHECK-SAME: %[[ARG0:.+]]: tensor<2x12x5xf32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<2x3xf32> +// CHECK: %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[RESHAPE]], %[[ARG1]] : +// CHECK: return %[[GENERIC]] + +// ----- + +// Test control of fusion using control function +// Test no fusion because the folded dimensions are not all preserved. 
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)> +#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func @control_fusion(%arg0 : tensor<6xf32>, %arg1 : tensor<20xf32>) -> tensor<2x3x4x5xf32> { + %0 = tensor.expand_shape %arg0 [[0, 1]] : tensor<6xf32> into tensor<2x3xf32> + %1 = tensor.expand_shape %arg1 [[0, 1]] : tensor<20xf32> into tensor<4x5xf32> + %init = linalg.init_tensor [2, 3, 4, 5] : tensor<2x3x4x5xf32> + %2 = linalg.generic { + indexing_maps = [#map0, #map1, #map2], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%0, %1 : tensor<2x3xf32>, tensor<4x5xf32>) outs(%init : tensor<2x3x4x5xf32>) { + ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): + %3 = arith.addf %b0, %b1 : f32 + linalg.yield %3 : f32 + } -> tensor<2x3x4x5xf32> + return %2 : tensor<2x3x4x5xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK: func @control_fusion( +// CHECK-SAME: %[[ARG0:.+]]: tensor<6xf32> +// CHECK-SAME: %[[ARG1:.+]]: tensor<20xf32> +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : +// CHECK-SAME: outs(%{{.+}}: tensor<6x20xf32>) +// CHECK: %[[RESHAPE1:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1, 2]{{\]}} +// CHECK: %[[RESHAPE2:.+]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1], [2], [3]{{\]}} +// CHECK: return %[[RESHAPE2]] + +// CONTROL-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CONTROL-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d2)> +// CONTROL-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CONTROL: func @control_fusion( +// CONTROL-SAME: %[[ARG0:.+]]: tensor<6xf32> +// CONTROL-SAME: %[[ARG1:.+]]: tensor<20xf32> +// CONTROL: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] +// CONTROL: %[[INIT:.+]] = linalg.init_tensor [2, 3, 4, 5] +// CONTROL: %[[INIT_RESHAPE:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0], [1], [2, 3]{{\]}} +// CONTROL: %[[GENERIC:.+]] = linalg.generic +// CONTROL-SAME: ins(%[[EXPAND]], %[[ARG1]] : +// CONTROL-SAME: outs(%[[INIT_RESHAPE]] : +// CONTROL: %[[RESULT:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1], [2, 3]{{\]}} + +// ----- + +// Corner case that isnt handled currently. 
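+// The reassociation of the 0-d expand_shape is empty, so collapsing would not
+// reduce the number of loops of the generic op; the pattern bails out and the
+// IR is left unchanged.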
+#map = affine_map<(d0) -> (d0)> +func @zero_D_test(%arg0: tensor) -> tensor<1xf32> { + %0 = tensor.expand_shape %arg0 [] : tensor into tensor<1xf32> + %init = linalg.init_tensor [1] : tensor<1xf32> + %1 = linalg.generic { + indexing_maps = [#map, #map], + iterator_types = ["parallel"]} + ins(%0: tensor<1xf32>) outs(%init : tensor<1xf32>) { + ^bb0(%b0 : f32, %b1 : f32): + linalg.yield %b0: f32 + } -> tensor<1xf32> + return %1 : tensor<1xf32> +} +// CHECK: func @zero_D_test +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[EXPAND]] : +// CHECK: return %[[GENERIC]] diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -78,6 +78,19 @@ "to generic -> expand_shape pattern"), llvm::cl::init(false)}; + Option fuseWithReshapeByCollapsing{ + *this, "fuse-with-reshape-by-collapsing", + llvm::cl::desc("Test linalg expand_shape -> generic fusion patterns that " + "collapse the iteration space of the consumer"), + llvm::cl::init(false)}; + + Option fuseWithReshapeByCollapsingWithControlFn{ + *this, "fuse-with-reshape-by-collapsing-control", + llvm::cl::desc("Test controlling the linalg expand_shape -> generic " + "fusion patterns that " + "collapse the iteration space of the consumer"), + llvm::cl::init(false)}; + void runOnOperation() override { MLIRContext *context = &this->getContext(); FuncOp funcOp = this->getOperation(); @@ -129,6 +142,26 @@ linalg::populatePushReshapeOpsPatterns(patterns); (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); } + + if (fuseWithReshapeByCollapsing) { + RewritePatternSet patterns(context); + linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns); + (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); + } + + if (fuseWithReshapeByCollapsingWithControlFn) { + RewritePatternSet patterns(context); + linalg::ControlElementwiseOpsFusionFn controlFn = + [](const OpResult &producer, OpOperand &consumer) -> bool { + if (isa(producer.getDefiningOp())) { + // Skip fusing the first operand + return consumer.getOperandNumber(); + } + return true; + }; + linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn); + (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); + } } };
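A minimal sketch of how a client pass might wire up the new entry point, mirroring the test pass above (the function name `populateMyPipelinePatterns` is hypothetical and not part of the patch; the usual pattern-rewrite headers are assumed to be included):

  void populateMyPipelinePatterns(MLIRContext *context, FuncOp funcOp) {
    RewritePatternSet patterns(context);
    // Fuse every expand_shape producer into its generic consumer; pass a
    // ControlElementwiseOpsFusionFn as the second argument to restrict which
    // (producer, consumer) pairs get fused.
    linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns);
    (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
  }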