diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -403,61 +403,58 @@
         srcType.getRank() < dstType.getRank() ||
         parentSrcType.getRank() == dstType.getRank())
       return failure();
+
     // Check if the result tensor_reshape after folding the reshapeOp and
     // parentReshapeOp are combined.
     // If the final tensor_reshape is folding, the parentReshapeOp is
     // introducing unit-dims, and the reshapeOp does an actual reshape.
-    // If the final tensor_reshape op is expanding, the reshapeOp is introducing
-    // unit-dims, and the parentReshapeOp does an actual reshape.
+    // If the final tensor_reshape op is expanding, the reshapeOp is
+    // introducing unit-dims, and the parentReshapeOp does an actual reshape.
     bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
-    auto reassociationMaps = isFoldingPattern
-                                 ? reshapeOp.getReassociationMaps()
-                                 : parentReshapeOp.getReassociationMaps();
-    DenseSet<unsigned> conservedDimensions;
-    for (auto &map : reassociationMaps) {
-      if (map.getNumResults() == 1) {
-        conservedDimensions.insert(
-            map.getResult(0).cast<AffineDimExpr>().getPosition());
-      }
-    }
-
-    // Find positions at which the unit-dims exist.
-    int64_t nonUnitDimPos = 0;
-    DenseMap<int64_t, int64_t> nonUnitSrcDims;
-    ArrayRef<int64_t> nonUnitShape =
+    ArrayRef<int64_t> expandedShape =
        isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
-    for (auto shape : enumerate(srcType.getShape())) {
-      // Case 1 : It is a conserved dimension.
-      if (conservedDimensions.count(shape.index())) {
-        nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
-        continue;
+    ArrayRef<int64_t> foldedShape =
+        isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();
+
+    unsigned expandedDim = 0, foldedDim = 0;
+    SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
+        foldedShape.size());
+    while (expandedDim < expandedShape.size() &&
+           foldedDim < foldedShape.size()) {
+      int64_t dstSize = foldedShape[foldedDim];
+      int64_t srcSize = expandedShape[expandedDim];
+      while (srcSize < dstSize && expandedDim < expandedShape.size()) {
+        reassociationExprs[foldedDim].push_back(
+            rewriter.getAffineDimExpr(expandedDim++));
+        srcSize *= expandedShape[expandedDim];
       }
-      // Case 2 : Dimensions dont match but the intermediate tensor is unit-dim.
-      if (shape.value() == 1)
-        continue;
-      // Case 3 : Dimensions match, treat it as a non-unit src dim.
-      if (nonUnitDimPos < static_cast<int64_t>(nonUnitShape.size()) &&
-          nonUnitShape[nonUnitDimPos] == shape.value()) {
-        nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
-        continue;
+      if (srcSize == dstSize) {
+        reassociationExprs[foldedDim].push_back(
+            rewriter.getAffineDimExpr(expandedDim++));
+        // If the next dim in foldedShape is not 1, also fold subsequent dims
+        // in expandedShape that are 1 into the current group.
+        if (foldedDim == foldedShape.size() - 1 ||
+            foldedShape[foldedDim + 1] != 1) {
+          while (expandedDim < expandedShape.size() &&
+                 expandedShape[expandedDim] == 1) {
+            reassociationExprs[foldedDim].push_back(
+                rewriter.getAffineDimExpr(expandedDim++));
+          }
+        }
+      } else {
+        return failure();
       }
-      return failure();
+      foldedDim++;
     }
+    if (expandedDim != expandedShape.size())
+      return failure();
 
-    // Compute reassociation maps for the final operation. Use the reassociation
-    // maps that is actually doing a reshape (and not just introducing
-    // unit-dims).
-    for (AffineMap &map : reassociationMaps) {
-      SmallVector<AffineExpr, 4> exprs;
-      exprs.reserve(nonUnitSrcDims.size());
-      for (auto result : map.getResults()) {
-        unsigned dim = result.cast<AffineDimExpr>().getPosition();
-        if (nonUnitSrcDims.count(dim))
-          exprs.push_back(rewriter.getAffineDimExpr(nonUnitSrcDims[dim]));
-      }
-      map = AffineMap::get(nonUnitSrcDims.size(), 0, exprs,
-                           rewriter.getContext());
-    }
+    SmallVector<AffineMap, 4> reassociationMaps =
+        llvm::to_vector<4>(llvm::map_range(
+            reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
+              return AffineMap::get(expandedShape.size(), 0, exprs,
+                                    rewriter.getContext());
+            }));
     rewriter.replaceOpWithNewOp<TensorReshapeOp>(
         reshapeOp, dstType, parentReshapeOp.src(),
         rewriter.getAffineMapArrayAttr(reassociationMaps));
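The grouping loop added above walks the two shapes in lockstep: it accumulates expanded dimensions into the current folded dimension until the running product matches, then absorbs trailing unit dimensions into the same group unless the next folded dimension is itself 1 (that dimension must then claim its own unit extent). The standalone sketch below restates the loop with std::vector in place of ArrayRef and AffineExpr; the name computeReassociation is made up for illustration, and the inner bounds check is written as e + 1 < expanded.size() (rather than the patch's expandedDim < expandedShape.size()) so the read of the next dimension after the post-increment cannot run past the end of the shape.

// Standalone sketch of the shape-matching loop from the patch above.
// computeReassociation is a hypothetical name; std::vector stands in for
// ArrayRef<int64_t>, and dimension indices stand in for AffineDimExprs.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

using Groups = std::vector<std::vector<unsigned>>;

// Group the dimensions of `expanded` so that the product of each group equals
// the corresponding dimension of `folded`. Trailing unit dimensions are
// absorbed into the current group unless the next folded dimension is itself
// 1. Returns std::nullopt where the pattern would return failure().
std::optional<Groups> computeReassociation(const std::vector<int64_t> &expanded,
                                           const std::vector<int64_t> &folded) {
  Groups groups(folded.size());
  unsigned e = 0;
  for (unsigned f = 0; f < folded.size(); ++f) {
    if (e == expanded.size())
      return std::nullopt;
    int64_t dstSize = folded[f];
    int64_t srcSize = expanded[e];
    // Accumulate expanded dims until the running product reaches dstSize.
    while (srcSize < dstSize && e + 1 < expanded.size()) {
      groups[f].push_back(e++);
      srcSize *= expanded[e];
    }
    if (srcSize != dstSize)
      return std::nullopt;
    groups[f].push_back(e++);
    // Absorb following unit dims, unless the next folded dim is also 1.
    if (f == folded.size() - 1 || folded[f + 1] != 1)
      while (e < expanded.size() && expanded[e] == 1)
        groups[f].push_back(e++);
  }
  // Every expanded dimension must have been assigned to some group.
  if (e != expanded.size())
    return std::nullopt;
  return groups;
}

int main() {
  // Shapes from the new test after composition: the parent source
  // tensor<2xf32> reshaped directly into the final tensor<2x1xf32>.
  auto groups = computeReassociation(/*expanded=*/{2, 1}, /*folded=*/{2});
  if (!groups)
    return 1;
  // Prints "d0 d1 |": a single group (d0, d1), i.e. MAP0 in the CHECK lines.
  for (const auto &group : *groups) {
    for (unsigned d : group)
      std::cout << "d" << d << " ";
    std::cout << "| ";
  }
  std::cout << "\n";
}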
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -240,3 +240,19 @@
     : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
   return %1 : tensor<4x512x1x512x4xf32>
 }
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @fold_reshape
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]
+// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
+func @fold_reshape(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] : tensor<2xf32> into tensor<2x1x1xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2) -> (d0)>,
+     affine_map<(d0, d1, d2) -> (d1, d2)>
+    ] : tensor<2x1x1xf32> into tensor<2x1xf32>
+  return %1 : tensor<2x1xf32>
+}
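For the new test, isFoldingPattern is false (the parent source tensor<2xf32> has lower rank than the final tensor<2x1xf32>), so the loop runs with expandedShape = [2, 1] and foldedShape = [2]: d0 matches the extent 2, the trailing unit dimension is absorbed into the same group, and the two reshapes compose into the single expanding reshape from tensor<2xf32> into tensor<2x1xf32> with the one reassociation map (d0, d1) that the CHECK lines capture as MAP0.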