diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1570,6 +1570,39 @@
     return success();
   }
 };
+
+/// Fold expand_shape(extract_slice) ops that cancel each other out.
+struct FoldExpandOfRankReducingExtract
+    : public OpRewritePattern<ExpandShapeOp> {
+  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(ExpandShapeOp expandShapeOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType resultType = expandShapeOp.getResultType();
+    auto extractSliceOp =
+        expandShapeOp.getSrc().getDefiningOp<ExtractSliceOp>();
+    if (!extractSliceOp)
+      return failure();
+    RankedTensorType srcType = extractSliceOp.getSourceType();
+
+    // Only cases where the ExpandShapeOp can be folded away entirely are
+    // supported. Moreover, only simple cases where the resulting ExtractSliceOp
+    // has no rank-reduction anymore are supported at the moment.
+    RankedTensorType nonReducingExtractType = ExtractSliceOp::inferResultType(
+        srcType, extractSliceOp.getStaticOffsets(),
+        extractSliceOp.getStaticSizes(), extractSliceOp.getStaticStrides());
+    if (nonReducingExtractType != resultType)
+      return failure();
+
+    SmallVector<OpFoldResult> mixedOffsets = extractSliceOp.getMixedOffsets();
+    SmallVector<OpFoldResult> mixedSizes = extractSliceOp.getMixedSizes();
+    SmallVector<OpFoldResult> mixedStrides = extractSliceOp.getMixedStrides();
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
+        expandShapeOp, extractSliceOp.getSource(), mixedOffsets, mixedSizes,
+        mixedStrides);
+    return success();
+  }
+};
 } // namespace
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -1578,7 +1611,7 @@
               ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
               FoldReshapeWithConstant<ExpandShapeOp>,
               FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
-              FoldDimOfCollapseShape>(context);
+              FoldDimOfCollapseShape, FoldExpandOfRankReducingExtract>(context);
 }
 
 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1678,3 +1678,23 @@
   %1 = tensor.collapse_shape %0 [[0, 1]] : tensor<1x?xf32> into tensor<?xf32>
   return %1 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @expand_shape_of_rank_reducing_extract(
+//  CHECK-SAME:     %[[t:.*]]: tensor<?x?x?x?xf32>
+//   CHECK-DAG:     %[[extract1:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
+//   CHECK-DAG:     %[[extract2:.*]] = tensor.extract_slice %{{.*}}[0, 0, 0, 0] [%{{.*}}, 1, 1, 5] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x1x1x5xf32>
+//       CHECK:     return %[[extract1]], %[[extract2]]
+func.func @expand_shape_of_rank_reducing_extract(
+    %t: tensor<?x?x?x?xf32>, %idx: index)
+  -> (tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>)
+{
+  %0 = tensor.extract_slice %t[0, 0, 0, 0][%idx, 1, 1, 5][1, 1, 1, 1]
+      : tensor<?x?x?x?xf32> to tensor<?x1x5xf32>
+  %1 = tensor.expand_shape %0 [[0], [1, 2], [3]]
+      : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
+  %2 = tensor.expand_shape %0 [[0, 1], [2], [3]]
+      : tensor<?x1x5xf32> into tensor<?x1x1x5xf32>
+  return %1, %2 : tensor<?x1x1x5xf32>, tensor<?x1x1x5xf32>
+}