diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -600,13 +600,17 @@ LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp, PatternRewriter &rewriter) const override { RankedTensorType resultType = sliceOp.getType(); - SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets(); - SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes(); - SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides(); - auto reassociation = getReassociationMapForFoldingUnitDims(sizes); + SmallVector<OpFoldResult> targetShape; + for (auto size : resultType.getShape()) + targetShape.push_back(rewriter.getIndexAttr(size)); + auto reassociation = getReassociationMapForFoldingUnitDims(targetShape); if (!reassociation || reassociation->size() == static_cast<size_t>(resultType.getRank())) return failure(); + + SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets(); + SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides(); + SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes(); auto rankReducedType = tensor::ExtractSliceOp::inferCanonicalRankReducedResultType( reassociation->size(), sliceOp.getSourceType(), offsets, sizes, @@ -631,13 +635,14 @@ LogicalResult matchAndRewrite(InsertOpTy insertSliceOp, PatternRewriter &rewriter) const override { RankedTensorType sourceType = insertSliceOp.getSourceType(); - SmallVector<OpFoldResult> offsets = insertSliceOp.getMixedOffsets(); - SmallVector<OpFoldResult> sizes = insertSliceOp.getMixedSizes(); - SmallVector<OpFoldResult> strides = insertSliceOp.getMixedStrides(); - auto reassociation = getReassociationMapForFoldingUnitDims(sizes); + SmallVector<OpFoldResult> targetShape; + for (auto size : sourceType.getShape()) + targetShape.push_back(rewriter.getIndexAttr(size)); + auto reassociation = getReassociationMapForFoldingUnitDims(targetShape); if (!reassociation || reassociation->size() == static_cast<size_t>(sourceType.getRank())) return failure(); + Location loc = insertSliceOp.getLoc(); 
tensor::CollapseShapeOp reshapedSource; { diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -489,6 +489,18 @@ // ----- +func.func @rank_reduced_extract_slice(%arg0: tensor<1x1x3x1x3xf32>) -> tensor<1x3x3xf32> { + %0 = tensor.extract_slice %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x1x3x1x3xf32> to tensor<1x3x3xf32> + return %0 : tensor<1x3x3xf32> +} +// CHECK-LABEL: func @rank_reduced_extract_slice +// CHECK: %[[SLICE:.+]] = tensor.extract_slice +// CHECK-SAME: tensor<1x1x3x1x3xf32> to tensor<3x3xf32> +// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[SLICE]] {{\[}}[0, 1], [2]] +// CHECK: return %[[RESULT]] + +// ----- + func.func @insert_slice_unit_dims(%arg0: tensor<1x3xf32>, %arg1: tensor<1x1xf32>) -> tensor<1x3xf32> { %0 = tensor.insert_slice %arg1 into %arg0[0, 2] [1, 1] [1, 1] : tensor<1x1xf32> into tensor<1x3xf32> return %0 : tensor<1x3xf32> @@ -501,6 +513,18 @@ // ----- +func.func @rank_reduced_insert_slice(%arg0: tensor<1x1x3x1x3xf32>, %arg1: tensor<1x3x3xf32>) -> tensor<1x1x3x1x3xf32> { + %0 = tensor.insert_slice %arg1 into %arg0[0, 0, 0, 0, 0] [1, 1, 3, 1, 3] [1, 1, 1, 1, 1] : tensor<1x3x3xf32> into tensor<1x1x3x1x3xf32> + return %0 : tensor<1x1x3x1x3xf32> +} +// CHECK-LABEL: func @rank_reduced_insert_slice +// CHECK: %[[RESHAPE:.+]] = tensor.collapse_shape %{{.+}} {{\[}}[0, 1], [2]] +// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[RESHAPE]] +// CHECK-SAME: tensor<3x3xf32> into tensor<1x1x3x1x3xf32> +// CHECK: return %[[RESULT]] + +// ----- + #accesses = [ affine_map<(i, j, k, l, m) -> (i, k, m)>, affine_map<(i, j, k, l, m) -> ()>,