diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -500,7 +500,14 @@
       } else {
         return failure();
       }
+
       foldedDim++;
+      // If every expanded dim has already been consumed, the next folded dim
+      // must not be a unit dim; otherwise that dim is left unmapped and the
+      // rewrite would produce an illegal reshape.
+      if (expandedDim == expandedShape.size() &&
+          foldedDim < foldedShape.size() && foldedShape[foldedDim] == 1)
+        return failure();
     }
     if (expandedDim != expandedShape.size())
       return failure();
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -582,3 +582,21 @@
   }
   return
 }
+
+// -----
+
+func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>] : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d3)>] : tensor<3x2x2x1xf32> into tensor<12x1xf32>
+  return %1 : tensor<12x1xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+// CHECK: func @no_fold_reshape_empty_expr
+// CHECK-SAME: %[[ARG0:.+]]: tensor<3x2x2xf32>
+// CHECK: %[[RARG0:.+]] = linalg.tensor_reshape %[[ARG0]] [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK: %[[RES:.+]] = linalg.tensor_reshape %[[RARG0]] [#[[MAP3]], #[[MAP4]]]
+// CHECK: return %[[RES]] : tensor<12x1xf32>