diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -403,12 +403,23 @@
         srcType.getRank() < dstType.getRank() ||
         parentSrcType.getRank() == dstType.getRank())
       return failure();
+
+    // A rank-1 parentSrcType is reshaped straight into dstType, which always
+    // uses a single identity reassociation map over all of dstType's dims.
+    if (parentSrcType.getRank() == 1) {
+      rewriter.replaceOpWithNewOp<TensorReshapeOp>(
+          reshapeOp, dstType, parentReshapeOp.src(),
+          rewriter.getAffineMapArrayAttr({AffineMap::getMultiDimIdentityMap(
+              dstType.getRank(), rewriter.getContext())}));
+      return success();
+    }
+
     // Check if the result tensor_reshape after folding the reshapeOp and
     // parentReshapeOp are combined.
     // If the final tensor_reshape is folding, the parentReshapeOp is
     // introducing unit-dims, and the reshapeOp does an actual reshape.
-    // If the final tensor_reshape op is expanding, the reshapeOp is introducing
-    // unit-dims, and the parentReshapeOp does an actual reshape.
+    // If the final tensor_reshape op is expanding, the reshapeOp is
+    // introducing unit-dims, and the parentReshapeOp does an actual reshape.
     bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
     auto reassociationMaps =
         isFoldingPattern ? reshapeOp.getReassociationMaps()
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -240,3 +240,19 @@
     : tensor<1x4x1x512x1x1x512x1x4xf32> into tensor<4x512x1x512x4xf32>
   return %1 : tensor<4x512x1x512x4xf32>
 }
+
+// -----
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func @fold_reshape
+// CHECK: linalg.tensor_reshape %{{.*}} [#[[MAP0]]
+// CHECK-SAME: tensor<2xf32> into tensor<2x1xf32>
+func @fold_reshape(%arg0: tensor<2xf32>) -> tensor<2x1xf32>
+{
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] : tensor<2xf32> into tensor<2x1x1xf32>
+  %1 = linalg.tensor_reshape %0
+    [affine_map<(d0, d1, d2) -> (d0)>,
+     affine_map<(d0, d1, d2) -> (d1, d2)>
+    ] : tensor<2x1x1xf32> into tensor<2x1xf32>
+  return %1 : tensor<2x1xf32>
+}