diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -505,6 +505,12 @@
     if (expandedDim != expandedShape.size())
       return failure();
 
+    // Bail out if folding left any reassociation group empty: such a group
+    // covers no expanded dimension and cannot form a valid reassociation map.
+    if (llvm::any_of(reassociationExprs,
+                     [](ArrayRef<AffineExpr> exprs) { return exprs.empty(); }))
+      return failure();
+
     SmallVector<AffineMap, 4> reassociationMaps =
         llvm::to_vector<4>(llvm::map_range(
             reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -582,3 +582,21 @@
   }
   return
 }
+
+// -----
+
+func @no_fold_reshape_empty_expr(%arg0: tensor<3x2x2xf32>) -> tensor<12x1xf32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2, d3) -> (d0)>, affine_map<(d0, d1, d2, d3) -> (d1)>, affine_map<(d0, d1, d2, d3) -> (d2, d3)>] : tensor<3x2x2xf32> into tensor<3x2x2x1xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>, affine_map<(d0, d1, d2, d3) -> (d3)>] : tensor<3x2x2x1xf32> into tensor<12x1xf32>
+  return %1 : tensor<12x1xf32>
+}
+//  CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0)>
+//  CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d1)>
+//  CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
+//  CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+//  CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)>
+//      CHECK: func @no_fold_reshape_empty_expr
+// CHECK-SAME:    %[[ARG0:.+]]: tensor<3x2x2xf32>
+//      CHECK:    %[[RARG0:.+]] = linalg.tensor_reshape %[[ARG0]] [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+//      CHECK:    %[[RES:.+]] = linalg.tensor_reshape %[[RARG0]] [#[[MAP3]], #[[MAP4]]]
+//      CHECK:    return %[[RES]] : tensor<12x1xf32>