diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1134,10 +1134,12 @@
       return llvm::None;
     currIndices.push_back(sourceDim++);
 
-    // If there are no dimensions in the target to match, then append the
-    // `currIndices` to the last element of the reassociationMap.
+    // If the reassociation is empty but the currIndices is not, this by
+    // definition is folding unit-dimensions with the result being scalar type.
+    // So only append the `currIndices` if reassociation map is not empty.
     if (targetDim == targetShape.size()) {
-      reassociationMap.back().append(currIndices.begin(), currIndices.end());
+      if (!reassociationMap.empty() && !currIndices.empty())
+        reassociationMap.back().append(currIndices.begin(), currIndices.end());
       // Break out of the loops. We should be done here.
       break;
     }
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -43,6 +43,17 @@
 
 // -----
 
+// CHECK-LABEL: zero_rank_reshape_multi
+func @zero_rank_reshape_multi(%arg0: tensor<f32>) -> tensor<f32> {
+  // CHECK: return %arg0
+  %0 = linalg.tensor_reshape %arg0 [] : tensor<f32> into tensor<1xf32>
+  %1 = linalg.tensor_reshape %0 [[0, 1]] : tensor<1xf32> into tensor<1x1xf32>
+  %2 = linalg.tensor_reshape %1 [] : tensor<1x1xf32> into tensor<f32>
+  return %2 : tensor<f32>
+}
+
+// -----
+
 func @collapsing_tensor_reshapes(%arg0 : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?xf32>
 {
   %0 = linalg.tensor_reshape %arg0 [[0, 1], [2], [3, 4]]