diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -459,7 +459,10 @@ reassociation.emplace_back(ReassociationIndices{}); std::swap(reassociation.back(), curr); } - if (!curr.empty()) + // When the reassociation list is not empty, fold the remaining unit + // dimensions into its last group. If the list built so far is empty, leave + // it empty; this folds everything into a rank-0 tensor. + if (!curr.empty() && !reassociation.empty()) reassociation.back().append(curr.begin(), curr.end()); return reassociation; } diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -496,3 +496,27 @@ // CHECK-SAME: outs(%[[FILL]] : tensor) // CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_reshape %[[RESULT]] {{\[}}[0, 1]] // CHECK: return %[[RESULT_RESHAPE]] + +// ----- + +func @subtensor_unit_dims(%arg0: tensor<1x3xf32>) -> tensor<1x1xf32> { + %0 = subtensor %arg0[0, 2] [1, 1] [1, 1] : tensor<1x3xf32> to tensor<1x1xf32> + return %0 : tensor<1x1xf32> +} +// CHECK-LABEL: func @subtensor_unit_dims +// CHECK: %[[SUBTENSOR:.+]] = subtensor +// CHECK-SAME: tensor<1x3xf32> to tensor +// CHECK: %[[RESULT:.+]] = linalg.tensor_reshape %[[SUBTENSOR]] [] +// CHECK: return %[[RESULT]] + +// ----- + +func @subtensor_insert_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x3xf32> { + %0 = subtensor_insert %arg1 into %arg0[0, 2] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} +// CHECK-LABEL: func @subtensor_insert_unit_dims +// CHECK: %[[RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} [] +// CHECK: %[[RESULT:.+]] = subtensor_insert %[[RESHAPE]] +// CHECK-SAME: 
tensor into tensor<1x3xf32> +// CHECK: return %[[RESULT]]