diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -449,8 +449,13 @@ reassociation.emplace_back(ReassociationIndices{}); std::swap(reassociation.back(), curr); } - if (!curr.empty()) - reassociation.back().append(curr.begin(), curr.end()); + if (!curr.empty()) { + // If reassociation groups were already formed, fold the remaining + // unit dimensions into the last group. If none were formed, leave the + // reassociation empty, which folds everything into a rank-0 tensor. + if (!reassociation.empty()) + reassociation.back().append(curr.begin(), curr.end()); + } return reassociation; } diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -514,3 +514,27 @@ // CHECK-SAME: outs(%[[FILL]] : tensor) // CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] [#[[MAP2]]] // CHECK: return %[[RESULT_RESHAPE]] + +// ----- + +func @subtensor_unit_dims(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { + %0 = subtensor %arg0[0, 2] [1, 1] [1, 1] : tensor<1x3xf32> to tensor<1x1xf32> + return %0 : tensor<1x1xf32> +} +// CHECK-LABEL: func @subtensor_unit_dims +// CHECK: %[[SUBTENSOR:.+]] = subtensor +// CHECK-SAME: tensor<1x3xf32> to tensor +// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[SUBTENSOR]] [] +// CHECK: return %[[RESULT]] + +// ----- + +func @subtensor_insert_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x3xf32> { + %0 = subtensor_insert %arg1 into %arg0[0, 2] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} +// CHECK-LABEL: func @subtensor_insert_unit_dims +// CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} [] +// 
CHECK: %[[RESULT:.+]] = subtensor_insert %[[RESHAPE]] +// CHECK-SAME: tensor into tensor<1x3xf32> +// CHECK: return %[[RESULT]]