diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -73,6 +73,8 @@
     // definition is folding unit-dimensions with the result being scalar type.
     // So only append the `currIndices` if reassociation map is not empty.
     if (targetDim == targetShape.size()) {
+      while (sourceDim < sourceShape.size())
+        currIndices.push_back(sourceDim++);
       if (!reassociationMap.empty() && !currIndices.empty())
         reassociationMap.back().append(currIndices.begin(), currIndices.end());
       // Break out of the loops. We should be done here.
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -279,6 +279,20 @@
 
 // -----
 
+func @fold_reshape_trailing_unit_dims(%arg0: tensor<12x42x1x1xf32>) -> tensor<12x42xf32>
+{
+  %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2], [3, 4]]
+      : tensor<12x42x1x1xf32> into tensor<12x42x1x1x1xf32>
+  %1 = linalg.tensor_collapse_shape %0 [[0], [1, 2, 3, 4]]
+      : tensor<12x42x1x1x1xf32> into tensor<12x42xf32>
+  return %1 : tensor<12x42xf32>
+}
+// CHECK: func @fold_reshape_trailing_unit_dims
+// CHECK: linalg.tensor_collapse_shape %{{.*}} {{\[}}[0], [1, 2, 3]]
+// CHECK-SAME: tensor<12x42x1x1xf32> into tensor<12x42xf32>
+
+// -----
+
 func @no_fold_reshapes(%arg0 : tensor) -> tensor
 {
   %0 = linalg.tensor_expand_shape %arg0 [[0], [1], [2, 3]]
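
For reference, the intent of the ReshapeOpsUtils.cpp change is: once the target shape has been fully matched, any source dimensions that remain must be trailing unit dimensions, and they are appended to the last reassociation group instead of being dropped. Below is a minimal, self-contained C++ sketch of that idea; the helper name getReassociation and the simplified grouping loop are illustrative assumptions, not the actual MLIR utility.

// Simplified, standalone sketch of computing reassociation indices when
// collapsing `sourceShape` into `targetShape`. It mirrors the idea of the
// patch above (fold trailing unit source dims into the last group once the
// target is exhausted) but is NOT the real MLIR ReshapeOpsUtils code.
#include <cstdint>
#include <iostream>
#include <vector>

using ReassociationIndices = std::vector<int64_t>;

static std::vector<ReassociationIndices>
getReassociation(const std::vector<int64_t> &sourceShape,
                 const std::vector<int64_t> &targetShape) {
  std::vector<ReassociationIndices> reassociationMap;
  ReassociationIndices currIndices;
  size_t sourceDim = 0;
  for (size_t targetDim = 0; targetDim < targetShape.size(); ++targetDim) {
    int64_t prod = 1;
    // Group consecutive source dims until their product covers the target dim.
    while (sourceDim < sourceShape.size() && prod < targetShape[targetDim]) {
      prod *= sourceShape[sourceDim];
      currIndices.push_back(sourceDim++);
    }
    // A unit target dim maps to a single (unit) source dim.
    if (currIndices.empty() && sourceDim < sourceShape.size())
      currIndices.push_back(sourceDim++);
    reassociationMap.push_back(currIndices);
    currIndices.clear();
  }
  // The fix: once the target shape is exhausted, append any remaining
  // (necessarily unit) source dims to the last reassociation group.
  while (sourceDim < sourceShape.size())
    currIndices.push_back(sourceDim++);
  if (!reassociationMap.empty() && !currIndices.empty())
    reassociationMap.back().insert(reassociationMap.back().end(),
                                   currIndices.begin(), currIndices.end());
  return reassociationMap;
}

int main() {
  // Shapes from the new test: collapse 12x42x1x1 into 12x42.
  auto map = getReassociation({12, 42, 1, 1}, {12, 42});
  for (const auto &group : map) {
    std::cout << "[";
    for (size_t i = 0; i < group.size(); ++i)
      std::cout << (i ? ", " : "") << group[i];
    std::cout << "] ";
  }
  std::cout << "\n"; // expected: [0] [1, 2, 3]
}

Running the sketch on the shapes from the new test (tensor<12x42x1x1xf32> into tensor<12x42xf32>) prints [0] [1, 2, 3], which corresponds to the [[0], [1, 2, 3]] reassociation checked by the CHECK lines above.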