diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -172,7 +172,7 @@
 /// general, but non-zero offsets are not handled by SPIR-V backend at this
 /// point (and potentially cannot be handled).
 std::optional<SmallVector<ReassociationIndices>>
-getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
+getReassociationMapForFoldingUnitDims(ArrayRef<int64_t> targetShape);
 
 /// Return the identity numeric value associated to the give op. Return
 /// std::nullopt if there is no known neutral element.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -603,7 +603,8 @@
     SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
     SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
     SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
-    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+    auto reassociation =
+        getReassociationMapForFoldingUnitDims(resultType.getShape());
     if (!reassociation ||
         reassociation->size() == static_cast<size_t>(resultType.getRank()))
       return failure();
@@ -632,9 +633,9 @@
                                 PatternRewriter &rewriter) const override {
     RankedTensorType sourceType = insertSliceOp.getSourceType();
     SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets();
-    SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes();
     SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides();
-    auto reassociation = getReassociationMapForFoldingUnitDims(sizes);
+    auto reassociation =
+        getReassociationMapForFoldingUnitDims(sourceType.getShape());
     if (!reassociation ||
         reassociation->size() == static_cast<size_t>(sourceType.getRank()))
       return failure();
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -1109,15 +1109,14 @@
 /// but non-zero offsets are not handled by SPIR-V backend at this point (and
 /// potentially cannot be handled).
 std::optional<SmallVector<ReassociationIndices>>
-getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes) {
+getReassociationMapForFoldingUnitDims(ArrayRef<int64_t> targetShape) {
   SmallVector<ReassociationIndices> reassociation;
   ReassociationIndices curr;
-  for (const auto &it : llvm::enumerate(mixedSizes)) {
+  for (const auto &it : llvm::enumerate(targetShape)) {
     auto dim = it.index();
     auto size = it.value();
     curr.push_back(dim);
-    auto attr = size.dyn_cast<Attribute>();
-    if (attr && attr.cast<IntegerAttr>().getInt() == 1)
+    if (size == 1)
      continue;
     reassociation.emplace_back(ReassociationIndices{});
     std::swap(reassociation.back(), curr);
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -489,6 +489,18 @@
 // -----
 
+func.func @rank_reduced_extract_slice(%arg0: tensor<1x1x3x1x3xf32>) -> tensor<1x3x3xf32> {
+  %0 = tensor.extract_slice %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x1x3x1x3xf32> to tensor<1x3x3xf32>
+  return %0 : tensor<1x3x3xf32>
+}
+// CHECK-LABEL: func @rank_reduced_extract_slice
+// CHECK:         %[[SLICE:.+]] = tensor.extract_slice
+// CHECK-SAME:      tensor<1x1x3x1x3xf32> to tensor<3x3xf32>
+// CHECK:         %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] {{\[}}[0, 1], [2]]
+// CHECK:         return %[[RESULT]]
+
+// -----
+
 func.func @insert_slice_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x3xf32> {
   %0 = tensor.insert_slice %arg1 into %arg0[0, 2] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x3xf32>
   return %0 : tensor<1x3xf32>
 }
@@ -501,6 +513,18 @@
 // -----
 
+func.func @rank_reduced_insert_slice(%arg0: tensor<1x1x3x1x3xf32>, %arg1: tensor<1x3x3xf32>) -> tensor<1x1x3x1x3xf32> {
+  %0 = tensor.insert_slice %arg1 into %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x3x3xf32> into tensor<1x1x3x1x3xf32>
+  return %0 : tensor<1x1x3x1x3xf32>
+}
+// CHECK-LABEL: func @rank_reduced_insert_slice
+// CHECK:         %[[RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1], [2]]
+// CHECK:         %[[RESULT:.+]] = tensor.insert_slice %[[RESHAPE]]
+// CHECK-SAME:      tensor<3x3xf32> into tensor<1x1x3x1x3xf32>
+// CHECK:         return %[[RESULT]]
+
+// -----
+
 #accesses = [
   affine_map<(i, j, k, l, m) -> (i, k, m)>,
   affine_map<(i, j, k, l, m) -> ()>,