diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1151,43 +1151,24 @@
 // contraction of dimensions.
 //===---------------------------------------------------------------------===//
 
-/// For an `indexingMap` that is a projected permutation, if the range is to be
-/// collapsed using the given `reassociation`, get the reassociation in the
-/// domain that would keep the map a projected permutation.
-static SmallVector<ReassociationIndices>
+/// For a given list of indices in the range of the `indexingMap` that are
+/// folded, return the indices of the corresponding domain. Return `llvm::None`
+/// on failure. All elements of the returned reassociation are guaranteed to
+/// be distinct.
+static Optional<ReassociationIndices>
 getDomainReassociation(AffineMap indexingMap,
-                       ArrayRef<ReassociationIndices> rangeReassociation) {
-  assert(indexingMap.isProjectedPermutation() &&
-         "expected projected permutation map");
-  unsigned counter = 0;
-  SmallVector<ReassociationIndices> domainReassociation;
-  llvm::SmallDenseSet<unsigned> processedDomainDims;
-  // Iterate over the reassociation indices.
-  for (ReassociationIndicesRef foldedRangeDims : rangeReassociation) {
-    ReassociationIndices foldedDomainDims;
-    for (auto rangeDim : foldedRangeDims) {
-      (void)rangeDim;
-      AffineDimExpr dimExpr =
-          indexingMap.getResult(counter++).cast<AffineDimExpr>();
-      foldedDomainDims.push_back(dimExpr.getPosition());
-      processedDomainDims.insert(dimExpr.getPosition());
-    }
-    domainReassociation.emplace_back(std::move(foldedDomainDims));
-  }
-  // Fill in the missing domain dims.
-  for (auto dim : llvm::seq<unsigned>(0, indexingMap.getNumDims())) {
-    if (processedDomainDims.count(dim))
-      continue;
-    ReassociationIndices vec = {dim};
-    domainReassociation.emplace_back(std::move(vec));
-  }
+                       ReassociationIndicesRef rangeReassociation) {
+  if (!indexingMap.isProjectedPermutation())
+    return llvm::None;
 
-  // Sort the reassociation using the first dimension of the folded range to
-  // not create unnecessary transposes.
-  llvm::sort(domainReassociation,
-             [](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
-               return lhs[0] < rhs[0];
-             });
+  ReassociationIndices domainReassociation = llvm::to_vector<4>(
+      llvm::map_range(rangeReassociation, [&](int64_t pos) -> int64_t {
+        return indexingMap.getResults()[pos]
+            .cast<AffineDimExpr>()
+            .getPosition();
+      }));
+  // The projected permutation semantics ensures that there is no repetition
+  // of the domain indices.
   return domainReassociation;
 }
 
@@ -1236,99 +1217,174 @@
 }
 
 // Check if a generic op can be fused along an operand by collapsing dimensions.
-static bool isFusableWithReshapeByDimCollapse(
+static SmallVector<ReassociationIndices> isFusableWithReshapeByDimCollapse(
     GenericOp genericOp, OpOperand *fusableOperand,
     ArrayRef<ReassociationIndices> reassociation) {
   // Some basic checks for this fusion to be valid.
   if (!genericOp.hasTensorSemantics() || genericOp.getNumOutputs() != 1)
-    return false;
+    return {};
 
   if (!llvm::all_of(genericOp.getIndexingMaps(), [](AffineMap map) {
         return map.isProjectedPermutation();
      })) {
-    return false;
-  }
-
-  // Get the reassociation for the iteration space.
-  SmallVector<ReassociationIndices> iterationReassociation =
-      getDomainReassociation(genericOp.getTiedIndexingMap(fusableOperand),
-                             reassociation);
-  if (iterationReassociation.empty()) {
-    // If the domain reassociation indices is empty, then this is a scalar op.
-    // Nothing to do.
-    return false;
+    return {};
   }
 
+  llvm::SmallDenseSet<int64_t> processedIterationDims;
+  AffineMap indexingMap = genericOp.getTiedIndexingMap(fusableOperand);
   auto iteratorTypes = genericOp.iterator_types().getValue();
-  ArrayRef<Attribute> iteratorTypesRef(iteratorTypes);
-  for (ReassociationIndicesRef foldedIterDims : iterationReassociation) {
-    // Check that for all indexing maps, the folded dimensions sequence is
-    // preserved.
-    if (!llvm::all_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
-          return isDimSequencePreserved(indexingMap, foldedIterDims);
+  SmallVector<ReassociationIndices> domainReassociation;
+  for (ReassociationIndicesRef foldedRangeDims : reassociation) {
+    assert(!foldedRangeDims.empty() && "unexpected empty reassociation");
+
+    // Ignore dims that are not folded.
+    if (foldedRangeDims.size() == 1)
+      continue;
+
+    Optional<ReassociationIndices> foldedDomainDims =
+        getDomainReassociation(indexingMap, foldedRangeDims);
+    if (!foldedDomainDims)
+      continue;
+    ReassociationIndicesRef foldedDomainDimsRef(foldedDomainDims.getValue());
+
+    // Check that the folded iteration dims do not contain already processed
+    // dims.
+    if (llvm::any_of(foldedDomainDimsRef, [&](int64_t dim) {
+          return processedIterationDims.count(dim);
         }))
-      return false;
-    unsigned startDim = foldedIterDims[0];
-    ArrayRef<Attribute> foldedIteratorTypes =
-        iteratorTypesRef.drop_front(startDim).take_front(foldedIterDims.size());
-    // Check that all folded iterator types are either all parallel, or all
-    // reduction.
-    if (!llvm::all_of(
-            foldedIteratorTypes,
-            [](Attribute attr) { return isParallelIterator(attr); }) &&
-        !llvm::all_of(foldedIteratorTypes,
-                      [](Attribute attr) { return isReductionIterator(attr); }))
-      return false;
+      continue;
+
+    // Check that the folded iterator types are all parallel or all reduction.
+    Attribute startIteratorType = iteratorTypes[foldedDomainDimsRef[0]];
+    if (!isParallelIterator(startIteratorType) &&
+        !isReductionIterator(startIteratorType))
+      continue;
+    if (llvm::any_of(foldedDomainDimsRef, [&](int64_t dim) {
+          return iteratorTypes[dim] != startIteratorType;
+        }))
+      continue;
+
+    // If the folded dimensions correspond to a "reduction" iterator type, the
+    // folded dimensions need to be "in-order". Strictly speaking this is not
+    // necessary for reductions that are associative and commutative, but a
+    // stricter definition of reduction is used for now.
+    if (isReductionIterator(startIteratorType)) {
+      bool isMonotonic = true;
+      for (auto i : llvm::seq<unsigned>(0, foldedDomainDimsRef.size() - 1)) {
+        if (foldedDomainDimsRef[i] > foldedDomainDimsRef[i + 1]) {
+          isMonotonic = false;
+          break;
+        }
+      }
+      if (!isMonotonic)
+        continue;
+    }
+
+    // Check that the sequence is preserved in all indexing maps.
+    if (llvm::any_of(genericOp.getIndexingMaps(), [&](AffineMap indexingMap) {
+          return !isDimSequencePreserved(indexingMap,
+                                         foldedDomainDims.getValue());
+        }))
+      continue;
+
+    processedIterationDims.insert(foldedDomainDimsRef.begin(),
+                                  foldedDomainDimsRef.end());
+    domainReassociation.emplace_back(foldedDomainDimsRef.begin(),
+                                     foldedDomainDimsRef.end());
   }
-  return true;
+
+  return domainReassociation;
 }
 
 /// Helper class to carry state while collapsing the `linalg.generic` op.
 namespace {
 class CollapsingInfo {
 public:
-  CollapsingInfo(SmallVector<ReassociationIndices> &&reassociation) {
-    iterationReassociation = std::move(reassociation);
-    for (const auto &foldedIterDims : enumerate(iterationReassociation)) {
-      foldedDimStartToSequenceMap[foldedIterDims.value()[0]] =
-          foldedIterDims.index();
+  LogicalResult initialize(unsigned origLoopDim,
+                           ArrayRef<ReassociationIndices> foldedIterationDims) {
+    llvm::SmallDenseSet<int64_t> processedDims;
+    // Find all the dims that are folded.
+    for (ReassociationIndicesRef foldedIterationDim : foldedIterationDims) {
+      if (foldedIterationDim.empty())
+        continue;
+      // If the folded dims contain dims that are already folded, that is an
+      // illegal specification. Repetition within a list is also illegal.
+      for (auto dim : foldedIterationDim) {
+        if (dim >= origLoopDim)
+          return failure();
+        if (processedDims.count(dim))
+          return failure();
+        processedDims.insert(dim);
+      }
+      collapsedOpToOrigOpIterationDim.emplace_back(foldedIterationDim.begin(),
+                                                   foldedIterationDim.end());
+    }
+    if (processedDims.size() > origLoopDim)
+      return failure();
+
+    // Add all the preserved dims of the original op as single
+    // elements to `collapsedOpToOrigOpIterationDim`.
+    for (auto dim : llvm::seq<int64_t>(0, origLoopDim)) {
+      if (processedDims.count(dim))
+        continue;
+      collapsedOpToOrigOpIterationDim.emplace_back(ReassociationIndices{dim});
     }
-  }
 
-  // Returns the iteration space reassociation.
-  ArrayRef<ReassociationIndices> getReassociationIndices() {
-    return iterationReassociation;
+    llvm::sort(collapsedOpToOrigOpIterationDim,
+               [&](ReassociationIndicesRef lhs, ReassociationIndicesRef rhs) {
+                 return lhs[0] < rhs[0];
+               });
+    origOpToCollapsedOpIterationDim.resize(origLoopDim);
+    for (auto foldedDims : llvm::enumerate(collapsedOpToOrigOpIterationDim)) {
+      for (auto dim : enumerate(foldedDims.value()))
+        origOpToCollapsedOpIterationDim[dim.value()] =
+            std::make_pair(foldedDims.index(), dim.index());
+    }
+    return success();
   }
 
-  // Returns true if the given dimension is the start of a sequence of folded
-  // dimensions.
-  bool isDimStartOfFoldedDims(unsigned dim) {
-    return foldedDimStartToSequenceMap.count(dim);
+  /// Return mapping from collapsed loop domain to original loop domain.
+  ArrayRef<ReassociationIndices> getCollapsedOpToOrigOpMapping() {
+    return collapsedOpToOrigOpIterationDim;
   }
 
-  // Return the folded dimensions starting at `dim`.
-  ReassociationIndicesRef getFoldedDimsStartingAt(unsigned dim) {
-    assert(foldedDimStartToSequenceMap.count(dim) &&
-           "invalid start dim of folded dim "
-           "sequence");
-    return iterationReassociation[foldedDimStartToSequenceMap[dim]];
+  /// Return mapping from original loop domain to collapsed loop domain. The
+  /// mapping is a pair: the first value is the dimension in the collapsed loop
+  /// that the original loop dimension is mapped to, and the second is the
+  /// relative position of the original dimension within its folded list. For
+  /// example, if the original loop domain is 5D and all of it is folded into
+  /// two collapsed dimensions, i.e.
+  ///
+  /// ```
+  /// collapsedOpToOrigOpMapping = [[0, 1, 2], [3, 4]]
+  /// ```
+  ///
+  /// then
+  ///
+  /// ```
+  /// origOpToCollapsedOpMapping[0] = {0, 0};
+  /// origOpToCollapsedOpMapping[1] = {0, 1};
+  /// origOpToCollapsedOpMapping[2] = {0, 2};
+  /// origOpToCollapsedOpMapping[3] = {1, 0};
+  /// origOpToCollapsedOpMapping[4] = {1, 1};
+  /// ```
+  ///
+  ArrayRef<std::pair<int64_t, unsigned>> getOrigOpToCollapsedOpMapping() {
+    return origOpToCollapsedOpIterationDim;
   }
 
-  // For a dim in the original op, return the dim in the collapsed op, that it
-  // is mapped to. Expectes `dim` to be start of a folded dimension sequence.
-  unsigned getDimInCollapsedOpForStartOfFoldedDims(unsigned dim) {
-    assert(foldedDimStartToSequenceMap.count(dim) &&
-           "invalid start dim of folded dim sequence");
-    return foldedDimStartToSequenceMap[dim];
+  /// Return the collapsed op iteration domain rank.
+  unsigned getCollapsedOpIterationRank() {
+    return collapsedOpToOrigOpIterationDim.size();
   }
 
 private:
-  /// Reassociation describing the folded iteration space dimensions.
-  SmallVector<ReassociationIndices> iterationReassociation;
+  /// Map from the iteration domain index in the collapsed op to the iteration
+  /// domain indices in the original op.
+  SmallVector<ReassociationIndices> collapsedOpToOrigOpIterationDim;
 
-  /// Map from the starting dimensions of the folded dimension sequences to
-  /// their index in `iterationReassociation`.
-  llvm::DenseMap<unsigned, unsigned> foldedDimStartToSequenceMap;
+  /// Map from the iteration domain index in the original op to the iteration
+  /// domain index in the collapsed op.
+  SmallVector<std::pair<int64_t, unsigned>> origOpToCollapsedOpIterationDim;
 };
 } // namespace
 
@@ -1339,7 +1395,7 @@
                          CollapsingInfo &collapsingInfo) {
   SmallVector<StringRef> collapsedIteratorTypes;
   for (ReassociationIndicesRef foldedIterDims :
-       collapsingInfo.getReassociationIndices()) {
+       collapsingInfo.getCollapsedOpToOrigOpMapping()) {
     assert(!foldedIterDims.empty() &&
            "reassociation indices expected to have non-empty sets");
     // Just pick the iterator type of the first folded dim. Pre-condition checks
@@ -1359,15 +1415,19 @@
   assert(indexingMap.isProjectedPermutation() &&
          "expected indexing map to be projected permutation");
   SmallVector<AffineExpr> resultExprs;
+  auto origOpToCollapsedOpMapping =
+      collapsingInfo.getOrigOpToCollapsedOpMapping();
   for (auto expr : indexingMap.getResults()) {
     unsigned dim = expr.cast<AffineDimExpr>().getPosition();
-    if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
-      resultExprs.push_back(getAffineDimExpr(
-          collapsingInfo.getDimInCollapsedOpForStartOfFoldedDims(dim),
-          context));
-    }
+    // If the dim is not the first of the collapsed dims, do nothing.
+    if (origOpToCollapsedOpMapping[dim].second != 0)
+      continue;
+    // The next n dims are guaranteed to be collapsed. So just use the
+    // iteration dimension of the collapsed op.
+    resultExprs.push_back(
+        getAffineDimExpr(origOpToCollapsedOpMapping[dim].first, context));
   }
-  return AffineMap::get(collapsingInfo.getReassociationIndices().size(), 0,
+  return AffineMap::get(collapsingInfo.getCollapsedOpIterationRank(), 0,
                         resultExprs, context);
 }
 
@@ -1377,11 +1437,20 @@
 getOperandReassociation(AffineMap indexingMap, CollapsingInfo &collapsingInfo) {
   unsigned counter = 0;
   SmallVector<ReassociationIndices> operandReassociation;
-  for (auto expr : indexingMap.getResults()) {
-    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
-    if (collapsingInfo.isDimStartOfFoldedDims(dim)) {
+  auto origOpToCollapsedOpMapping =
+      collapsingInfo.getOrigOpToCollapsedOpMapping();
+  auto collapsedOpToOrigOpMapping =
+      collapsingInfo.getCollapsedOpToOrigOpMapping();
+  while (counter < indexingMap.getNumResults()) {
+    unsigned dim =
+        indexingMap.getResult(counter).cast<AffineDimExpr>().getPosition();
+    if (origOpToCollapsedOpMapping[dim].second == 0) {
+      // This is the start of a collapsed dimension of the iteration space that
+      // is guaranteed to be preserved in the indexing map. The number of
+      // folded dims is obtained from the collapsed op to original op mapping.
       unsigned numFoldedDims =
-          collapsingInfo.getFoldedDimsStartingAt(dim).size();
+          collapsedOpToOrigOpMapping[origOpToCollapsedOpMapping[dim].first]
+              .size();
       auto range = llvm::seq(counter, counter + numFoldedDims);
       operandReassociation.emplace_back(range.begin(), range.end());
       counter += numFoldedDims;
@@ -1431,7 +1500,8 @@
   //   i1 = (i_{folded} / d2) % d1
   //   i0 = i_{folded} / (d1 * d2)
   llvm::DenseMap<unsigned, Value> indexReplacementVals;
-  for (auto &foldedDims : enumerate(collapsingInfo.getReassociationIndices())) {
+  for (auto &foldedDims :
+       enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
     ReassociationIndicesRef foldedDimsRef(foldedDims.value());
     Value newIndexVal =
         rewriter.create<linalg::IndexOp>(loc, foldedDims.index());
@@ -1451,25 +1521,24 @@
 }
 
 /// Implementation of fusion with reshape operation by collapsing dimensions.
-static Optional<SmallVector<Value>>
-fuseWithReshapeByCollapsing(GenericOp genericOp, Operation *reshapeOp,
-                            OpOperand *fusableOpOperand,
-                            PatternRewriter &rewriter) {
-  SmallVector<ReassociationIndices> reassociation =
-      isa<tensor::CollapseShapeOp>(reshapeOp)
-          ? cast<tensor::CollapseShapeOp>(reshapeOp).getReassociationIndices()
-          : cast<tensor::ExpandShapeOp>(reshapeOp).getReassociationIndices();
-  assert(isFusableWithReshapeByDimCollapse(genericOp, fusableOpOperand,
-                                           reassociation) &&
-         "preconditions for fusing with reshape by collapse failed");
-
-  CollapsingInfo collapsingInfo(getDomainReassociation(
-      genericOp.getTiedIndexingMap(fusableOpOperand), reassociation));
-  // Check for trivial no transformation cases. In that case return nothing.
-  if (collapsingInfo.getReassociationIndices().size() ==
-      genericOp.getNumLoops())
+static Optional<SmallVector<Value>> collapseGenericOpIterationDims(
+    GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims,
+    OpOperand *fusableOpOperand, PatternRewriter &rewriter) {
+  // Bail on trivial no-op cases.
+  if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() ||
+      llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) {
+        return foldedDims.size() <= 1;
+      }))
     return llvm::None;
 
+  CollapsingInfo collapsingInfo;
+  if (failed(collapsingInfo.initialize(genericOp.getNumLoops(),
+                                       foldedIterationDims))) {
+    rewriter.notifyMatchFailure(genericOp,
+                                "illegal to collapse specified dimensions");
+    return llvm::None;
+  }
+
   // Get the iterator types for the operand.
   SmallVector<StringRef> iteratorTypes = getCollapsedOpIteratorTypes(
       genericOp.iterator_types().getValue(), collapsingInfo);
@@ -1480,8 +1549,7 @@
         return getCollapsedOpIndexingMap(map, collapsingInfo);
       }));
 
-  Location loc =
-      rewriter.getFusedLoc({genericOp->getLoc(), reshapeOp->getLoc()});
+  Location loc = genericOp->getLoc();
 
   // Get the input operands.
   auto inputOperands = llvm::to_vector(
 
@@ -1576,14 +1644,17 @@
     if (!reshapeOp)
       continue;
 
-    if (!isFusableWithReshapeByDimCollapse(
-            genericOp, opOperand, reshapeOp.getReassociationIndices()) ||
+    SmallVector<ReassociationIndices> collapsableIterationDims =
+        isFusableWithReshapeByDimCollapse(
+            genericOp, opOperand, reshapeOp.getReassociationIndices());
+    if (collapsableIterationDims.empty() ||
         !controlFoldingReshapes(reshapeOp->getResult(0), *opOperand)) {
       continue;
     }
 
-    Optional<SmallVector<Value>> replacements = fuseWithReshapeByCollapsing(
-        genericOp, reshapeOp, opOperand, rewriter);
+    Optional<SmallVector<Value>> replacements =
+        collapseGenericOpIterationDims(genericOp, collapsableIterationDims,
+                                       opOperand, rewriter);
     if (!replacements) {
       return rewriter.notifyMatchFailure(
           genericOp, "failed to do the fusion by collapsing transformation");
diff --git a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
--- a/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
+++ b/mlir/test/Dialect/Linalg/fuse-with-reshape-by-collapsing.mlir
@@ -399,3 +399,142 @@
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: ins(%[[EXPAND]] :
 // CHECK: return %[[GENERIC]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func @fuse_only_one_reassociation(%arg0 : tensor<?x?xf32>, %arg1 : tensor<4x?x?x8xf32>) -> tensor<4x?x?x8xf32> {
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?xf32> into tensor<?x4x?x8xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map1],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%0, %arg1 : tensor<?x4x?x8xf32>, tensor<4x?x?x8xf32>)
+      outs(%arg1 : tensor<4x?x?x8xf32>) {
+    ^bb0(%b0: f32, %b1 : f32, %b2 : f32):
+      %2 = arith.addf %b0, %b1 : f32
+      linalg.yield %2 : f32
+  } -> tensor<4x?x?x8xf32>
+  return %1 : tensor<4x?x?x8xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK: func @fuse_only_one_reassociation(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<4x?x?x8xf32>
+// CHECK-DAG: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}}
+// CHECK-DAG: %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-DAG: %[[COLLAPSE_ARG1_0:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK-DAG: %[[COLLAPSE_ARG1_1:.+]] = tensor.collapse_shape %[[ARG1]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[COLLAPSE_ARG0]], %[[COLLAPSE_ARG1_0]] :
+// CHECK-SAME: outs(%[[COLLAPSE_ARG1_1]] :
+// CHECK: %[[EXPAND_GENERIC:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0], [1], [2, 3]{{\]}}
+// CHECK: return %[[EXPAND_GENERIC]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d3, d1, d0, d2)>
+func @fold_non_consecutive_dims(%arg0 : tensor<?x?xi32>) -> tensor<?x8x?x4xi32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
+  %d0 = tensor.dim %0, %c0 : tensor<?x4x?x8xi32>
+  %d1 = tensor.dim %0, %c2 : tensor<?x4x?x8xi32>
+  %init = linalg.init_tensor [%d1, 8, %d0, 4] : tensor<?x8x?x4xi32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map1],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%0 : tensor<?x4x?x8xi32>) outs(%init : tensor<?x8x?x4xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32):
+      %2 = linalg.index 0 : index
+      %3 = linalg.index 1 : index
+      %4 = linalg.index 2 : index
+      %5 = linalg.index 3 : index
+      %6 = arith.addi %2, %3 : index
+      %7 = arith.addi %6, %4 : index
+      %8 = arith.addi %7, %5 : index
+      %9 = arith.index_cast %8 : index to i32
+      linalg.yield %9: i32
+  } -> tensor<?x8x?x4xi32>
+  return %1 : tensor<?x8x?x4xi32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK: func @fold_non_consecutive_dims(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>)
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index
+// CHECK: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: %[[COLLAPSE_INIT:.+]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1], [2, 3]{{\]}}
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: ins(%[[ARG0]] :
+// CHECK-SAME: outs(%[[COLLAPSE_INIT]] :
+// CHECK-NEXT: ^bb{{[0-9]}}
+// CHECK: %[[ID0:.+]] = linalg.index 0
+// CHECK-DAG: %[[T0:.+]] = arith.remui %[[ID0]], %[[C4]]
+// CHECK-DAG: %[[T1:.+]] = arith.divui %[[ID0]], %[[C4]]
+// CHECK: %[[ID1:.+]] = linalg.index 1
+// CHECK-DAG: %[[T2:.+]] = arith.remui %[[ID1]], %[[C8]]
+// CHECK-DAG: %[[T3:.+]] = arith.divui %[[ID1]], %[[C8]]
+// CHECK-DAG: %[[T4:.+]] = arith.addi %[[T1]], %[[T2]]
+// CHECK-DAG: %[[T5:.+]] = arith.addi %[[T4]], %[[T0]]
+// CHECK-DAG: %[[T6:.+]] = arith.addi %[[T5]], %[[T3]]
+// CHECK-DAG: %[[T7:.+]] = arith.index_cast %[[T6]]
+// CHECK: linalg.yield %[[T7]]
+// CHECK: %[[EXPAND_GENERIC:.+]] = tensor.expand_shape %[[GENERIC]] {{\[}}[0, 1], [2, 3]{{\]}}
+// CHECK: return %[[EXPAND_GENERIC]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3, d1)>
+#map1 = affine_map<(d0, d1, d2, d3) -> ()>
+func @fold_non_consecutive_reduction_dims(%arg0 : tensor<?x?xi32>) -> tensor<i32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %0 = tensor.expand_shape %arg0 [[0, 1], [2, 3]] : tensor<?x?xi32> into tensor<?x4x?x8xi32>
+  %d0 = tensor.dim %0, %c0 : tensor<?x4x?x8xi32>
+  %d1 = tensor.dim %0, %c2 : tensor<?x4x?x8xi32>
+  %init = linalg.init_tensor [] : tensor<i32>
+  %1 = linalg.generic {
+      indexing_maps = [#map0, #map1],
+      iterator_types = ["reduction", "reduction", "reduction", "reduction"]}
+      ins(%0 : tensor<?x4x?x8xi32>) outs(%init : tensor<i32>) {
+    ^bb0(%b0 : i32, %b1 : i32):
+      %2 = linalg.index 0 : index
+      %3 = linalg.index 1 : index
+      %4 = linalg.index 2 : index
+      %5 = linalg.index 3 : index
+      %6 = arith.addi %2, %3 : index
+      %7 = arith.addi %6, %4 : index
+      %8 = arith.addi %7, %5 : index
+      %9 = arith.index_cast %8 : index to i32
+      linalg.yield %9: i32
+  } -> tensor<i32>
+  return %1 : tensor<i32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> ()>
+// CHECK: func @fold_non_consecutive_reduction_dims(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xi32>)
+// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
+// CHECK: %[[EXPAND_ARG0:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2, 3]{{\]}}
+// CHECK: %[[COLLAPSE_ARG0:.+]] = tensor.collapse_shape %[[EXPAND_ARG0]] {{\[}}[0, 1], [2], [3]{{\]}}
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-NEXT: ^bb{{[0-9]}}
+// CHECK: %[[ID0:.+]] = linalg.index 0
+// CHECK-DAG: %[[T0:.+]] = arith.remui %[[ID0]], %[[C4]]
+// CHECK-DAG: %[[T1:.+]] = arith.divui %[[ID0]], %[[C4]]
+// CHECK: %[[ID1:.+]] = linalg.index 1
+// CHECK-DAG: %[[ID2:.+]] = linalg.index 2
+// CHECK-DAG: %[[T2:.+]] = arith.addi %[[T1]], %[[ID1]]
+// CHECK-DAG: %[[T3:.+]] = arith.addi %[[T2]], %[[T0]]
+// CHECK-DAG: %[[T4:.+]] = arith.addi %[[T3]], %[[ID2]]
+// CHECK-DAG: %[[T5:.+]] = arith.index_cast %[[T4]]
+// CHECK: linalg.yield %[[T5]]
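
For reference, the `arith.remui`/`arith.divui` chains checked in the two `fold_non_consecutive_*` tests above implement the delinearization the patch describes in its comments (`i2 = i % d2`, `i1 = (i / d2) % d1`, `i0 = i / (d1 * d2)`). The following is a minimal, self-contained C++ sketch of that index math under a row-major linearization assumption; it is not part of the patch, and the name `delinearizeIndex` and the use of `std::vector` are illustrative only.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Recover the original loop indices from a single collapsed index, given the
// extents of the folded dims (outermost first). Peeling innermost-first with
// i_k = remaining % d_k, remaining /= d_k mirrors the per-dimension
// arith.remui/arith.divui chain the rewrite emits.
static std::vector<int64_t>
delinearizeIndex(int64_t collapsedIndex, const std::vector<int64_t> &dimSizes) {
  std::vector<int64_t> origIndices(dimSizes.size());
  int64_t remaining = collapsedIndex;
  for (int64_t i = static_cast<int64_t>(dimSizes.size()) - 1; i >= 0; --i) {
    origIndices[i] = remaining % dimSizes[i];
    remaining /= dimSizes[i];
  }
  return origIndices;
}

int main() {
  // Folding (d0, d1, d2) = (2, 3, 4): collapsed index 17 = 1*(3*4) + 1*4 + 1,
  // so the recovered original indices are (1, 1, 1).
  std::vector<int64_t> idx = delinearizeIndex(17, {2, 3, 4});
  assert(idx[0] == 1 && idx[1] == 1 && idx[2] == 1);
  return 0;
}
```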