Index: mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
===================================================================
--- mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -136,7 +136,8 @@
 /// but non-zero offsets are not handled by SPIR-V backend at this point (and
 /// potentially cannot be handled).
 std::optional<SmallVector<ReassociationIndices>>
-getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
+getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes,
+                                      bool insert = false);

 /// Return the identity numeric value associated to the give op. Return
 /// std::nullopt if there is no known neutral element.
Index: mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -634,7 +634,8 @@
     SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
     SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
-    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+    auto reassociation =
+        getReassociationMapForFoldingUnitDims(sizes, /*insert=*/true);
     if (!reassociation ||
         reassociation->size() == static_cast<size_t>(sourceType.getRank()))
       return failure();
Index: mlir/lib/Dialect/Linalg/Utils/Utils.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -1039,16 +1039,21 @@
 /// but non-zero offsets are not handled by SPIR-V backend at this point (and
 /// potentially cannot be handled).
 std::optional<SmallVector<ReassociationIndices>>
-getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
+getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes,
+                                      bool insert) {
   SmallVector<ReassociationIndices> reassociation;
   ReassociationIndices curr;
   for (const auto &it : llvm::enumerate(mixedSizes)) {
     auto dim = it.index();
     auto size = it.value();
-    curr.push_back(dim);
+    if (!insert)
+      curr.push_back(dim);
     auto attr = size.dyn_cast<Attribute>();
-    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
+    if (attr && attr.cast<IntegerAttr>().getInt() == 1) {
+      if (insert)
+        curr.push_back(dim);
       continue;
+    }
     reassociation.emplace_back(ReassociationIndices{});
     std::swap(reassociation.back(), curr);
   }
Index: mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
===================================================================
--- mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -501,6 +501,18 @@

 // -----

+func.func @insert_slice_lower_rank(%arg0: tensor<1x1x32xi32>, %arg1: tensor<1x32xi32>) -> tensor<1x1x32xi32> {
+  %0 = tensor.insert_slice %arg1 into %arg0[0, 0, 0] [1, 1, 32] [1, 1, 1] : tensor<1x32xi32> into tensor<1x1x32xi32>
+  return %0 : tensor<1x1x32xi32>
+}
+// CHECK-LABEL: func @insert_slice_lower_rank
+// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1]] : tensor<1x32xi32> into tensor<32xi32>
+// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[RESHAPE]]
+// CHECK-SAME: tensor<32xi32> into tensor<1x1x32xi32>
+// CHECK: return %[[RESULT]]
+
+// -----
+
 #accesses = [
   affine_map<(i, j, k, l, m) -> (i, k, m)>,
   affine_map<(i, j, k, l, m) -> ()>,
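
For illustration, a minimal standalone sketch (not part of the patch) of the grouping loop above, showing how the new insert flag changes which dimensions end up in a reassociation group for the sizes [1, 1, 32] used in the new test. The function name groupUnitDims and the printing driver are invented for this example; it assumes all sizes are static (so OpFoldResult is replaced by plain int64_t) and omits the post-loop handling of trailing unit dims, which lies outside the hunk shown here.

// Standalone sketch of the grouping loop in getReassociationMapForFoldingUnitDims.
// Assumptions: all sizes are static, and the post-loop handling of trailing
// unit dims (outside the hunk above) is omitted.
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

using Group = std::vector<int64_t>;

static std::vector<Group> groupUnitDims(const std::vector<int64_t> &sizes,
                                        bool insert) {
  std::vector<Group> reassociation;
  Group curr;
  for (int64_t dim = 0, e = sizes.size(); dim < e; ++dim) {
    if (!insert)
      curr.push_back(dim); // extract-style: every dim joins the open group.
    if (sizes[dim] == 1) {
      if (insert)
        curr.push_back(dim); // insert-style: only unit dims join a group.
      continue;
    }
    // A non-unit dim closes the current group.
    reassociation.emplace_back();
    std::swap(reassociation.back(), curr);
  }
  return reassociation;
}

int main() {
  // Sizes from the new test case: tensor.insert_slice with sizes [1, 1, 32].
  for (bool insert : {false, true}) {
    std::printf("insert=%d:", insert ? 1 : 0);
    for (const Group &group : groupUnitDims({1, 1, 32}, insert)) {
      std::printf(" [");
      for (int64_t dim : group)
        std::printf(" %lld", static_cast<long long>(dim));
      std::printf(" ]");
    }
    std::printf("\n");
  }
  // Prints:
  //   insert=0: [ 0 1 2 ]  (pre-existing behavior: unit dims fold into the
  //                         following non-unit dim)
  //   insert=1: [ 0 1 ]    (new behavior: only the unit dims form a group,
  //                         matching the [[0, 1]] collapse of tensor<1x32xi32>
  //                         into tensor<32xi32> checked in the test)
  return 0;
}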