diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -35,13 +35,12 @@
   int64_t prodOfCollapsedDims = 1;
   while (sourceDim < sourceShape.size()) {
     unsigned targetDim = reassociationMap.size();
+    // If we have mapped all the target dimensions, stop and handle the
+    // remaining tail of size-1 dimensions explicitly.
+    if (targetDim == targetType.getRank())
+      break;
 
-    // If all the dimensions of the targetShape are exhausted, then the
-    // remaining dims in the source shape must be all 1s. So for such cases, set
-    // 1 as the target shape. The actual reassociation indices will be handled
-    // later.
-    int64_t currTargetShape =
-        (targetDim < targetType.getRank() ? targetShape[targetDim] : 1);
+    int64_t currTargetShape = targetShape[targetDim];
     while (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
            prodOfCollapsedDims * sourceShape[sourceDim] < currTargetShape &&
            sourceDim < sourceShape.size()) {
@@ -69,25 +68,23 @@
       return llvm::None;
 
     currIndices.push_back(sourceDim++);
-    // If the reassociation is empty but the currIndices is not, this by
-    // definition is folding unit-dimensions with the result being scalar type.
-    // So only append the `currIndices` if reassociation map is not empty.
-    if (targetDim == targetShape.size()) {
-      while (sourceDim < sourceShape.size())
-        currIndices.push_back(sourceDim++);
-      if (!reassociationMap.empty() && !currIndices.empty())
-        reassociationMap.back().append(currIndices.begin(), currIndices.end());
-      // Break out of the loops. We should be done here.
-      break;
-    }
     reassociationMap.emplace_back(ReassociationIndices{});
     std::swap(reassociationMap.back(), currIndices);
     prodOfCollapsedDims = 1;
   }
 
-  // All the dimensions in the two shapes must have been processed.
-  if (reassociationMap.size() != targetShape.size() ||
-      sourceDim != sourceShape.size())
+  // All the dimensions in the target must have been processed.
+  if (reassociationMap.size() != targetShape.size())
     return llvm::None;
+  // Process any remaining entries in the source shape. They all need to be
+  // 1 or dynamic.
+  for (; sourceDim < sourceShape.size(); sourceDim++) {
+    if (sourceShape[sourceDim] != ShapedType::kDynamicSize &&
+        sourceShape[sourceDim] != 1)
+      return llvm::None;
+    // The map is empty when the target type is a scalar.
+    if (!reassociationMap.empty())
+      reassociationMap.back().push_back(sourceDim);
+  }
   return reassociationMap;
 }
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -879,16 +879,17 @@
 
 // -----
 
-func @no_fold_reshapes(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
+func @fold_reshapes_unit_dims_in_middle(%arg0 : tensor<?x?x?xf32>) -> tensor<?x?xf32> {
   %0 = tensor.expand_shape %arg0 [[0], [1], [2, 3]]
       : tensor<?x?x?xf32> into tensor<?x?x1x?xf32>
   %1 = tensor.collapse_shape %0 [[0], [1, 2, 3]]
       : tensor<?x?x1x?xf32> into tensor<?x?xf32>
   return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: func @no_fold_reshapes
-//       CHECK:   tensor.expand_shape
-//       CHECK:   tensor.collapse_shape
+// CHECK-LABEL: func @fold_reshapes_unit_dims_in_middle
+//  CHECK-SAME:   (%[[ARG:.*]]: tensor<?x?x?xf32>
+//       CHECK:   tensor.collapse_shape %[[ARG]] {{\[}}[0], [1, 2]]
+//  CHECK-SAME:     tensor<?x?x?xf32> into tensor<?x?xf32>
 
 // -----
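
For reference, the grouping behaviour this change enables can be sketched in isolation. The snippet below is a minimal, self-contained approximation only: it handles static sizes, omits the dynamic-dimension checks of the real implementation, and uses illustrative names (collapseGroups, Indices) rather than the actual MLIR utility or its signature. Source dimensions are grouped greedily until each target size is matched, and a trailing run of unit dimensions in the source is then folded into the last group instead of making the fold bail out.

// Minimal standalone sketch of the grouping logic after this change,
// restricted to static sizes; collapseGroups and its types are
// illustrative only and are not the MLIR utility or its signature.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

using Indices = std::vector<int64_t>;

std::optional<std::vector<Indices>>
collapseGroups(const std::vector<int64_t> &src,
               const std::vector<int64_t> &dst) {
  std::vector<Indices> groups;
  size_t s = 0;
  for (size_t t = 0; t < dst.size(); ++t) {
    Indices group;
    int64_t prod = 1;
    // Accumulate source dims while the running product stays below the
    // target size; the next source dim must then make it an exact match.
    while (s < src.size() && prod * src[s] < dst[t]) {
      prod *= src[s];
      group.push_back(s++);
    }
    if (s == src.size() || prod * src[s] != dst[t])
      return std::nullopt; // the two shapes cannot be reconciled
    group.push_back(s++);
    groups.push_back(std::move(group));
  }
  // New behaviour: a trailing run of unit dims in the source is legal and is
  // appended to the last group (previously such cases were rejected). When
  // the target is a scalar there are no groups at all, so the trailing unit
  // dims are simply dropped, mirroring the check in the patch.
  for (; s < src.size(); ++s) {
    if (src[s] != 1)
      return std::nullopt;
    if (!groups.empty())
      groups.back().push_back(s);
  }
  return groups;
}

int main() {
  // Collapsing 2x3x1x1 into 6: expect a single group [0, 1, 2, 3].
  auto groups = collapseGroups({2, 3, 1, 1}, {6});
  if (!groups)
    return 1;
  for (const Indices &g : *groups) {
    std::cout << "[";
    for (int64_t i : g)
      std::cout << " " << i;
    std::cout << " ]\n";
  }
  return 0;
}

Running the sketch prints "[ 0 1 2 3 ]" for the 2x3x1x1 -> 6 case, which is the reassociation the patched utility is expected to produce for the analogous static shapes.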