Index: mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
===================================================================
--- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -229,11 +229,64 @@
   }
 };
 
+/// Fold tensor.cast into tensor.extract_slice producer.
+/// Example:
+/// ```
+///  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
+///    tensor<128x512xf32> to tensor<?x512xf32>
+///  %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
+/// ```
+/// ->
+/// ```
+/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
+///   tensor<128x512xf32> to tensor<16x512xf32>
+/// ```
+struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
+  using OpRewritePattern<CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(CastOp tensorCast,
+                                PatternRewriter &rewriter) const final {
+    auto extractOperand =
+        tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
+
+    if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
+        tensorCast.getType().getShape() ==
+            tensorCast.source().getType().cast<RankedTensorType>().getShape())
+      return failure();
+
+    SmallVector<OpFoldResult> sizes = extractOperand.getMixedSizes();
+    auto dimMask = computeRankReductionMask(
+        extractFromI64ArrayAttr(extractOperand.static_sizes()),
+        extractOperand.getType().getShape());
+    // The mask identifies sizes dropped by a rank-reducing extract_slice;
+    // those have no matching dimension in the cast result type. Bail out
+    // rather than index the result shape out of bounds if it is unknown.
+    if (!dimMask)
+      return failure();
+    size_t dimIndex = 0;
+    for (size_t i = 0, e = sizes.size(); i < e; i++) {
+      if (dimMask->count(i))
+        continue;
+      int64_t dim = tensorCast.getType().getShape()[dimIndex++];
+      // Keep the original size where the cast result is still dynamic.
+      if (ShapedType::isDynamic(dim))
+        continue;
+      sizes[i] = rewriter.getIndexAttr(dim);
+    }
+
+    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
+        tensorCast, tensorCast.getType().cast<RankedTensorType>(),
+        extractOperand.source(), extractOperand.getMixedOffsets(), sizes,
+        extractOperand.getMixedStrides());
+    return success();
+  }
+};
+
 } // namespace
 
 void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
-  results.add<ChainedTensorCast>(context);
+  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
 }
 
 //===----------------------------------------------------------------------===//
Index: 
mlir/test/Dialect/Tensor/canonicalize.mlir
===================================================================
--- mlir/test/Dialect/Tensor/canonicalize.mlir
+++ mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1403,1 +1403,25 @@
 }
+
+// -----
+
+// CHECK-LABEL: func @cast_extract_slice
+func.func @cast_extract_slice(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
+  -> tensor<16x512xf32> {
+// CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 512] [1, 1] : tensor<128x512xf32> to tensor<16x512xf32>
+  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] : tensor<128x512xf32> to tensor<?x512xf32>
+  %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
+// CHECK: return %[[E]] : tensor<16x512xf32>
+  return %1 : tensor<16x512xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @cast_extract_slice_rank_reduce
+func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
+  -> tensor<16xf32> {
+// CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 1] [1, 1] : tensor<128x512xf32> to tensor<16xf32>
+  %0 = tensor.extract_slice %arg0[%o, 0] [%s, 1] [1, 1] : tensor<128x512xf32> to tensor<?xf32>
+  %1 = tensor.cast %0 : tensor<?xf32> to tensor<16xf32>
+// CHECK: return %[[E]] : tensor<16xf32>
+  return %1 : tensor<16xf32>
+}