diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1151,43 +1151,24 @@ // contraction of dimensions. //===---------------------------------------------------------------------===// -/// For an `indexingMap` that is a projected permutation, if the range is to be -/// collapsed using the given `reassociation`, get the reassociation in the -/// domain that would keep the map a projected permutation. -static SmallVector +/// For a given list of indices in the range of the `indexingMap` that are +/// folded, return the indices of the corresponding domain. `indexingMap` must +/// be a projected permutation (this is asserted), which ensures that all the +/// elements of the returned reassociation are distinct. +static ReassociationIndices getDomainReassociation(AffineMap indexingMap, - ArrayRef rangeReassociation) { + ReassociationIndicesRef rangeReassociation) { assert(indexingMap.isProjectedPermutation() && - "expected projected permutation map"); - unsigned counter = 0; - SmallVector domainReassociation; - llvm::SmallDenseSet processedDomainDims; - // Iterate over the reassociation indices. - for (ReassociationIndicesRef foldedRangeDims : rangeReassociation) { - ReassociationIndices foldedDomainDims; - for (auto rangeDim : foldedRangeDims) { - (void)rangeDim; - AffineDimExpr dimExpr = - indexingMap.getResult(counter++).cast(); - foldedDomainDims.push_back(dimExpr.getPosition()); - processedDomainDims.insert(dimExpr.getPosition()); - } - domainReassociation.emplace_back(std::move(foldedDomainDims)); - } - // Fill in the missing domain dims.
- for (auto dim : llvm::seq(0, indexingMap.getNumDims())) { - if (processedDomainDims.count(dim)) - continue; - ReassociationIndices vec = {dim}; - domainReassociation.emplace_back(std::move(vec)); - } + "expected projected permutation"); - // Sort the reassociation using the first dimension of the folded range to - // not create unnecessary transposes. - llvm::sort(domainReassociation, - [](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) { - return lhs[0] < rhs[0]; - }); + ReassociationIndices domainReassociation = llvm::to_vector<4>( + llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t { + return indexingMap.getResults()[pos] + .cast() + .getPosition(); + })); + // The projected permutation semantics ensures that there is no repetition of + // the domain indices. return domainReassociation; } @@ -1235,100 +1216,238 @@ return true; } -// Check if a generic op can be fused along an operand by collapsing dimensions. -static bool isFusableWithReshapeByDimCollapse( - GenericOp genericOp, OpOperand *fusableOperand, - ArrayRef reassociation) { +// Return the list of dimensions of the iteration domain that can be +// collapsed to allow for fusion with a producer that is an expand_shape +// operation. If all dimensions created by expansion can be collapsed in the +// iteration space then the reshape is defunct. +// +// Example: +// +// ```mlir +// #map = affine_map<(d0, d1) -> (d0, d1)> +// %1 = tensor.expand_shape %0 [[0, 1]] : tensor into tensor +// %2 = linalg.init_tensor [..] : tensor +// %3 = linalg.generic { +// indexing_maps = [#map, #map], +// iterator_types = ["parallel" ,"parallel"]} +// ins(%1 : tensor) outs(%2 : tensor) {.. } +// ``` +// +// can be fused by collapsing the dimensions of the iteration space. +// +// ```mlir +// #map = affine_map<(d0) -> (d0)> +// %2 = linalg.init_tensor [..] : tensor +// %3 = linalg.generic { +// indexing_maps = [#map, #map], +// iterator_types = ["parallel"]} +// ins(%0 : tensor) outs(%2 : tensor) {..
} +// %4 = tensor.expand_shape %3 [[0, 1]] : tensor into tensor +// ``` +// +// In the following example, +// +// ```mlir +// #map0 = affine_map<(d0, d1) -> (d0, d1)> +// #map1 = affine_map<(d0, d1) -> (d1, d0)> +// %1 = tensor.expand_shape %0 [[0, 1]] : tensor into tensor +// %2 = linalg.init_tensor [..] : tensor<4x?xf32> +// %3 = linalg.generic { +// indexing_maps = [#map0, #map1], +// iterator_types = ["parallel" ,"parallel"]} +// ins(%1 : tensor) outs(%2 : tensor<4x?xf32>) {.. } +// ``` +// +// the reshape cannot be fused with the generic op by collapsing the op +// dimensions since the indexing maps will have to contain mods and divs +// to preserve the access pattern. When no dimensions of the iteration +// space are collapsable, an empty vector is returned. +static SmallVector +getCollapsableIterationSpaceDims(GenericOp genericOp, OpOperand *fusableOperand, + ArrayRef reassociation) { // Some basic checks for this fusion to be valid. if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1) - return false; + return {}; if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) { return map.isProjectedPermutation(); })) { - return false; + return {}; } - // Get the reassociation for the iteration space. - SmallVector iterationReassociation = - getDomainReassociation(genericOp.getTiedIndexingMap(fusableOperand), - reassociation); - if (iterationReassociation.empty()) { - // If the domain reassociation indices is empty, then this is a scalar op. - // Nothing to do. - return false; + // Compute all the loops with the reduction iterator types.
+ SmallVector reductionDims; + for (auto iteratorType : llvm::enumerate(genericOp.iterator_types())) { + if (isReductionIterator(iteratorType.value())) { + reductionDims.push_back(iteratorType.index()); + } } + llvm::SmallDenseSet processedIterationDims; + AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand); auto iteratorTypes = genericOp.iterator_types().getValue(); - ArrayRef iteratorTypesRef(iteratorTypes); - for (ReassociationIndicesRef foldedIterDims : iterationReassociation) { - // Check that for all indexing maps, the folded dimensions sequence is - // preserved. - if (!llvm::all_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) { - return isDimSequencePreserved(indexingMap, foldedIterDims); + SmallVector iterationSpaceReassociation; + for (ReassociationIndicesRef foldedRangeDims : reassociation) { + assert(!foldedRangeDims.empty() && "unexpected empty reassociation"); + + // Ignore dims that are not folded. + if (foldedRangeDims.size() == 1) + continue; + + ReassociationIndices foldedIterationSpaceDims = + getDomainReassociation(indexingMap, foldedRangeDims); + + // Check that the folded iteration dims do not contain already processed + // dims. + if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) { + return processedIterationDims.count(dim); })) - return false; - unsigned startDim = foldedIterDims[0]; - ArrayRef foldedIteratorTypes = - iteratorTypesRef.drop_front(startDim).take_front(foldedIterDims.size()); - // Check that all folded iterator types are either all parallel, or all - // reduction. - if (!llvm::all_of( - foldedIteratorTypes, - [](Attribute attr) { return isParallelIterator(attr); }) && - !llvm::all_of(foldedIteratorTypes, - [](Attribute attr) { return isReductionIterator(attr); })) - return false; + continue; + + // Check that all folded iterator types are all parallel or all reductions. 
+ Attribute startIteratorType = iteratorTypes[foldedIterationSpaceDims[0]]; + if (!isParallelIterator(startIteratorType) && + !isReductionIterator(startIteratorType)) + continue; + if (llvm::any_of(foldedIterationSpaceDims, [&](int64_t dim) { + return iteratorTypes[dim] != startIteratorType; + })) + continue; + + // If the folded dimensions correspond to a "reduction" iterator type, + // the folded dimensions need to be "in-order". Strictly speaking this is + // not necessary for reductions that are associative and commutative, but + // a stricter definition of reduction is used for now. + if (isReductionIterator(startIteratorType)) { + bool isContiguous = false; + for (auto startDim : llvm::enumerate(reductionDims)) { + // Move window in `reductionDims` to start of the folded iteration dims. + if (startDim.value() != foldedIterationSpaceDims[0]) + continue; + // If the sizes don't match, the dims are trivially not contiguous. This + // condition should not be hit. + if (startDim.index() + foldedIterationSpaceDims.size() > + reductionDims.size()) + break; + // Check that the contiguity is maintained. + isContiguous = true; + for (auto foldedDim : llvm::enumerate(foldedIterationSpaceDims)) { + if (reductionDims[foldedDim.index() + startDim.index()] != + foldedDim.value()) { + isContiguous = false; + break; + } + } + break; + } + if (!isContiguous) + continue; + } + + // Check that the sequence is preserved in all indexing maps. + if (llvm::any_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) { + return !isDimSequencePreserved(indexingMap, foldedIterationSpaceDims); + })) + continue; + + processedIterationDims.insert(foldedIterationSpaceDims.begin(), + foldedIterationSpaceDims.end()); + iterationSpaceReassociation.emplace_back( + std::move(foldedIterationSpaceDims)); } - return true; + + return iterationSpaceReassociation; } /// Helper class to carry state while collapsing the `linalg.generic` op.
namespace { class CollapsingInfo { public: - CollapsingInfo(SmallVector &&reassociation) { - iterationReassociation = std::move(reassociation); - for (const auto &foldedIterDims : enumerate(iterationReassociation)) { - foldedDimStartToSequenceMap[foldedIterDims.value()[0]] = - foldedIterDims.index(); + LogicalResult initialize(unsigned origNumLoops, + ArrayRef foldedIterationDims) { + llvm::SmallDenseSet processedDims; + // Find all the dims that are folded. + for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) { + if (foldedIterationDim.empty()) + continue; + // If the folded dims contain dims already folded, that's illegal + // specification. Repetition within a list is also illegal. + for (auto dim : foldedIterationDim) { + if (dim >= origNumLoops) + return failure(); + if (processedDims.count(dim)) + return failure(); + processedDims.insert(dim); + } + collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(), + foldedIterationDim.end()); + } + if (processedDims.size() > origNumLoops) + return failure(); + + // Add all the preserved dims of the original op as single + // elements to `collapsedOpToOrigOpIterationDim`. + for (auto dim : llvm::seq(0, origNumLoops)) { + if (processedDims.count(dim)) + continue; + collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim}); } - } - // Returns the iteration space reassociation. 
- ArrayRef getReassociationIndices() { - return iterationReassociation; + llvm::sort(collapsedOpToOrigOpIterationDim, + [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) { + return lhs[0] < rhs[0]; + }); + origOpToCollapsedOpIterationDim.resize(origNumLoops); + for (auto foldedDims : llvm::enumerate(collapsedOpToOrigOpIterationDim)) { + for (auto dim : enumerate(foldedDims.value())) + origOpToCollapsedOpIterationDim[dim.value()] = + std::make_pair(foldedDims.index(), dim.index()); + } + return success(); } - // Returns true if the given dimension is the start of a sequence of folded - // dimensions. - bool isDimStartOfFoldedDims(unsigned dim) { - return foldedDimStartToSequenceMap.count(dim); + /// Return mapping from collapsed loop domain to original loop domain. + ArrayRef getCollapsedOpToOrigOpMapping() const { + return collapsedOpToOrigOpIterationDim; } - // Return the folded dimensions starting at `dim`. - ReassociationIndicesRef getFoldedDimsStartingAt(unsigned dim) { - assert(foldedDimStartToSequenceMap.count(dim) && - "invalid start dim of folded dim " - "sequence"); - return iterationReassociation[foldedDimStartToSequenceMap[dim]]; + /// Return mapping from original loop domain to collapsed loop domain. The + /// mapping is a pair. First value is the dimension in the collapsed loop that + /// the original loop is mapped to. Second is the relative position in folded + /// list of this domain. For example if the original loop domain is 5D, and + /// the collapsed loop domain folds it into two dimensions, i.e.
+ /// + /// ``` + /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]] + /// ``` + /// + /// then + /// + /// ``` + /// origOpToCollapsedOpMapping[0] = {0, 0}; + /// origOpToCollapsedOpMapping[1] = {0, 1}; + /// origOpToCollapsedOpMapping[2] = {0, 2}; + /// origOpToCollapsedOpMapping[3] = {1, 0}; + /// origOpToCollapsedOpMapping[4] = {1, 1}; + /// ``` + /// + ArrayRef> getOrigOpToCollapsedOpMapping() const { + return origOpToCollapsedOpIterationDim; } - // For a dim in the original op, return the dim in the collapsed op, that it - // is mapped to. Expectes `dim` to be start of a folded dimension sequence. - unsigned getDimInCollapsedOpForStartOfFoldedDims(unsigned dim) { - assert(foldedDimStartToSequenceMap.count(dim) && - "invalid start dim of folded dim sequence"); - return foldedDimStartToSequenceMap[dim]; + /// Return the collapsed op iteration domain rank. + unsigned getCollapsedOpIterationRank() const { + return collapsedOpToOrigOpIterationDim.size(); } private: - /// Reassociation describing the folded iteration space dimensions. - SmallVector iterationReassociation; + /// Map from the iteration domain index in collapsed op to the iteration + /// domain indices in the original op. + SmallVector collapsedOpToOrigOpIterationDim; - /// Map from the starting dimensions of the folded dimension sequences to - /// their index in `iterationReassociation`. - llvm::DenseMap foldedDimStartToSequenceMap; + /// Map from iteration domain index in the original op to the iteration domain + /// index in the collapsed op. + SmallVector> origOpToCollapsedOpIterationDim; }; } // namespace @@ -1336,10 +1455,10 @@ /// iterator types and collapsed dimensions.
static SmallVector getCollapsedOpIteratorTypes(ArrayRef iteratorTypes, - CollapsingInfo &collapsingInfo) { + const CollapsingInfo &collapsingInfo) { SmallVector collapsedIteratorTypes; for (ReassociationIndicesRef foldedIterDims : - collapsingInfo.getReassociationIndices()) { + collapsingInfo.getCollapsedOpToOrigOpMapping()) { assert(!foldedIterDims.empty() && "reassociation indices expected to have non-empty sets"); // Just pick the iterator type of the first folded dim. Pre-condition checks @@ -1353,35 +1472,50 @@ /// Compute the indexing map in the collapsed op that corresponds to the given /// `indexingMap` of the original operation. -static AffineMap getCollapsedOpIndexingMap(AffineMap indexingMap, - CollapsingInfo &collapsingInfo) { +static AffineMap +getCollapsedOpIndexingMap(AffineMap indexingMap, + const CollapsingInfo &collapsingInfo) { MLIRContext *context = indexingMap.getContext(); assert(indexingMap.isProjectedPermutation() && "expected indexing map to be projected permutation"); SmallVector resultExprs; + auto origOpToCollapsedOpMapping = + collapsingInfo.getOrigOpToCollapsedOpMapping(); for (auto expr : indexingMap.getResults()) { unsigned dim = expr.cast().getPosition(); - if (collapsingInfo.isDimStartOfFoldedDims(dim)) { - resultExprs.push_back(getAffineDimExpr( - collapsingInfo.getDimInCollapsedOpForStartOfFoldedDims(dim), - context)); - } + // If the dim is not the first of the collapsed dim, do nothing. + if (origOpToCollapsedOpMapping[dim].second != 0) + continue; + // The next n-dims are guaranteed to be collapsed. So just use the + // iteration dimension of the collapsed op. 
+ resultExprs.push_back( + getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context)); } - return AffineMap::get(collapsingInfo.getReassociationIndices().size(), 0, + return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0, resultExprs, context); } /// Return the `reassociation` indices to use to collapse the operand when the /// iteration space of a generic op is collapsed. static SmallVector -getOperandReassociation(AffineMap indexingMap, CollapsingInfo &collapsingInfo) { +getOperandReassociation(AffineMap indexingMap, + const CollapsingInfo &collapsingInfo) { unsigned counter = 0; SmallVector operandReassociation; - for (auto expr : indexingMap.getResults()) { - unsigned dim = expr.cast().getPosition(); - if (collapsingInfo.isDimStartOfFoldedDims(dim)) { + auto origOpToCollapsedOpMapping = + collapsingInfo.getOrigOpToCollapsedOpMapping(); + auto collapsedOpToOrigOpMapping = + collapsingInfo.getCollapsedOpToOrigOpMapping(); + while (counter < indexingMap.getNumResults()) { + unsigned dim = + indexingMap.getResult(counter).cast().getPosition(); + if (origOpToCollapsedOpMapping[dim].second == 0) { + // This is the start of a sequence of collapsed iteration dimensions that + // is guaranteed to be preserved in the indexing map. The number of folded + // dims is obtained from the collapsed op to original op mapping. unsigned numFoldedDims = - collapsingInfo.getFoldedDimsStartingAt(dim).size(); + collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first] + .size(); auto range = llvm::seq(counter, counter + numFoldedDims); operandReassociation.emplace_back(range.begin(), range.end()); counter += numFoldedDims; @@ -1393,7 +1527,7 @@ /// Get the new value to use for a given `OpOperand` in the collapsed operation.
static Value getCollapsedOpOperand(Location loc, GenericOp genericOp, OpOperand *opOperand, - CollapsingInfo &collapsingInfo, + const CollapsingInfo &collapsingInfo, OpBuilder &builder) { AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand); SmallVector operandReassociation = @@ -1414,7 +1548,7 @@ /// Modify the `linalg.index` operations in the original generic op, to its /// value in the collapsed operation. void generateCollapsedIndexingRegion(Location loc, Block *block, - CollapsingInfo &collapsingInfo, + const CollapsingInfo &collapsingInfo, ValueRange loopRange, PatternRewriter &rewriter) { OpBuilder::InsertionGuard g(rewriter); @@ -1431,7 +1565,8 @@ // i1 = (i_{folded} / d2) % d1 // i0 = i_{folded} / (d1 * d2) llvm::DenseMap indexReplacementVals; - for (auto &foldedDims : enumerate(collapsingInfo.getReassociationIndices())) { + for (auto &foldedDims : + enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) { ReassociationIndicesRef foldedDimsRef(foldedDims.value()); Value newIndexVal = rewriter.create(loc, foldedDims.index()); @@ -1451,24 +1586,22 @@ } /// Implementation of fusion with reshape operation by collapsing dimensions. -static Optional> -fuseWithReshapeByCollapsing(GenericOp genericOp, Operation *reshapeOp, - OpOperand *fusableOpOperand, - PatternRewriter &rewriter) { - SmallVector reassociation = - isa(reshapeOp) - ? cast(reshapeOp).getReassociationIndices() - : cast(reshapeOp).getReassociationIndices(); - assert(isFusableWithReshapeByDimCollapse(genericOp, fusableOpOperand, - reassociation) && - "preconditions for fusing with reshape by collapse failed"); - - CollapsingInfo collapsingInfo(getDomainReassociation( - genericOp.getTiedIndexingMap(fusableOpOperand), reassociation)); - // Check for trivial no transformation cases. In that case return nothing. 
- if (collapsingInfo.getReassociationIndices().size() == - genericOp.getNumLoops()) - return llvm::None; +static FailureOr> collapseGenericOpIterationDims( + GenericOp genericOp, ArrayRef foldedIterationDims, + OpOperand *fusableOpOperand, PatternRewriter &rewriter) { + // Bail on trivial no-op cases. + if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() || + llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) { + return foldedDims.size() <= 1; + })) + return failure(); + + CollapsingInfo collapsingInfo; + if (failed(collapsingInfo.initialize(genericOp.getNumLoops(), + foldedIterationDims))) { + return rewriter.notifyMatchFailure( + genericOp, "illegal to collapse specified dimensions"); + } // Get the iterator types for the operand. SmallVector iteratorTypes = getCollapsedOpIteratorTypes( @@ -1480,8 +1613,7 @@ return getCollapsedOpIndexingMap(map, collapsingInfo); })); - Location loc = - rewriter.getFusedLoc({genericOp->getLoc(), reshapeOp->getLoc()}); + Location loc = genericOp->getLoc(); // Get the input operands. 
auto inputOperands = llvm::to_vector( @@ -1576,14 +1708,17 @@ if (!reshapeOp) continue; - if (!isFusableWithReshapeByDimCollapse( - genericOp, opOperand, reshapeOp.getReassociationIndices()) || + SmallVector collapsableIterationDims = + getCollapsableIterationSpaceDims(genericOp, opOperand, + reshapeOp.getReassociationIndices()); + if (collapsableIterationDims.empty() || !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) { continue; } - Optional> replacements = fuseWithReshapeByCollapsing( - genericOp, reshapeOp, opOperand, rewriter); + Optional> replacements = + collapseGenericOpIterationDims(genericOp, collapsableIterationDims, + opOperand, rewriter); if (!replacements) { return rewriter.notifyMatchFailure( genericOp, "failed to do the fusion by collapsing transformation"); diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir --- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir +++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir @@ -399,3 +399,128 @@ // CHECK: %[[GENERIC:.+]] = linalg.generic // CHECK-SAME: ins(%[[EXPAND]] : // CHECK: return %[[GENERIC]] + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func @fuse_only_one_reassociation(%arg0 : tensor, %arg1 : tensor<4x?x?x8xf32>) -> tensor<4x?x?x8xf32> { + %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor into tensor + %1 = linalg.generic { + indexing_maps = [#map0, #map1, #map1], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%0, %arg1 : tensor, tensor<4x?x?x8xf32>) + outs(%arg1 : tensor<4x?x?x8xf32>) { + ^bb0(%b0: f32, %b1 : f32, %b2 : f32): + %2 = arith.addf %b0, %b1 : f32 + linalg.yield %2 : f32 + } -> tensor<4x?x?x8xf32> + return %1 : tensor<4x?x?x8xf32> +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +// CHECK-DAG: #[[MAP1:.+]] = 
affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK: func @fuse_only_one_reassociation( +// CHECK-SAME: %[[ARG0:.+]]: tensor +// CHECK-SAME: %[[ARG1:.+]]: tensor<4x?x?x8xf32> +// CHECK-DAG: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} +// CHECK-DAG: %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}} +// CHECK-DAG: %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}} +// CHECK-DAG: %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}} +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"] +// CHECK-SAME: ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] : +// CHECK-SAME: outs(%[[COLLAPSE_ARG1_1]] : +// CHECK: %[[EXPAND_GENERIC:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1], [2, 3]{{\]}} +// CHECK: return %[[EXPAND_GENERIC]] + +// ----- + +#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d1, d0, d2)> +func @fold_non_consecutive_dims(%arg0 : tensor) -> tensor { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor into tensor + %d0 = tensor.dim %0, %c0 : tensor + %d1 = tensor.dim %0, %c2 : tensor + %init = linalg.init_tensor [%d1, 8, %d0, 4] : tensor + %1 = linalg.generic { + indexing_maps = [#map0, #map1], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%0 : tensor) outs(%init : tensor) { + ^bb0(%b0 : i32, %b1 : i32): + %2 = linalg.index 0 : index + %3 = linalg.index 1 : index + %4 = linalg.index 2 : index + %5 = linalg.index 3 : index + %6 = arith.addi %2, %3 : index + %7 = arith.addi %6, %4 : index + %8 = arith.addi %7, %5 : index + %9 = arith.index_cast %8 : index to i32 + linalg.yield %9: i32 + } -> tensor + return %1 : tensor +} +// CHECK-DAG: 
#[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)> +// CHECK: func @fold_non_consecutive_dims( +// CHECK-SAME: %[[ARG0:.+]]: tensor) +// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index +// CHECK: %[[INIT:.+]] = linalg.init_tensor +// CHECK: %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}} +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: ins(%[[ARG0]] : +// CHECK-SAME: outs(%[[COLLAPSE_INIT]] : +// CHECK-NEXT: ^bb{{[0-9]}} +// CHECK: %[[ID0:.+]] = linalg.index 0 +// CHECK-DAG: %[[T0:.+]] = arith.remui %[[ID0]], %[[C4]] +// CHECK-DAG: %[[T1:.+]] = arith.divui %[[ID0]], %[[C4]] +// CHECK: %[[ID1:.+]] = linalg.index 1 +// CHECK-DAG: %[[T2:.+]] = arith.remui %[[ID1]], %[[C8]] +// CHECK-DAG: %[[T3:.+]] = arith.divui %[[ID1]], %[[C8]] +// CHECK-DAG: %[[T4:.+]] = arith.addi %[[T1]], %[[T2]] +// CHECK-DAG: %[[T5:.+]] = arith.addi %[[T4]], %[[T0]] +// CHECK-DAG: %[[T6:.+]] = arith.addi %[[T5]], %[[T3]] +// CHECK-DAG: %[[T7:.+]] = arith.index_cast %[[T6]] +// CHECK: linalg.yield %[[T7]] +// CHECK: %[[EXPAND_GENERIC:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1], [2, 3]{{\]}} +// CHECK: return %[[EXPAND_GENERIC]] + +// ----- + +// None of the folded iteration space dims are contiguous reduction dimensions. +// So no change in the code. 
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)> +#map1 = affine_map<(d0, d1, d2, d3) -> ()> +func @no_fold_non_consecutive_reduction_dims(%arg0 : tensor) -> tensor { + %c0 = arith.constant 0 : index + %c2 = arith.constant 2 : index + %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor into tensor + %init = linalg.init_tensor [] : tensor + %1 = linalg.generic { + indexing_maps = [#map0, #map1], + iterator_types = ["reduction", "reduction", "reduction", "reduction"]} + ins(%0 : tensor) outs(%init : tensor) { + ^bb0(%b0 : i32, %b1 : i32): + %2 = linalg.index 0 : index + %3 = linalg.index 1 : index + %4 = linalg.index 2 : index + %5 = linalg.index 3 : index + %6 = arith.addi %2, %3 : index + %7 = arith.addi %6, %4 : index + %8 = arith.addi %7, %5 : index + %9 = arith.index_cast %8 : index to i32 + linalg.yield %9: i32 + } -> tensor + return %1 : tensor +} +// CHECK: func @no_fold_non_consecutive_reduction_dims( +// CHECK-SAME: %[[ARG0:.+]]: tensor) +// CHECK: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}} +// CHECK: %[[GENERIC:.+]] = linalg.generic +// CHECK-SAME: ins(%[[EXPAND_ARG0]] : +// CHECK: return %[[GENERIC]]