Index: mlir/lib/Dialect/Tensor/IR/TensorOps.cpp =================================================================== --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -229,11 +229,58 @@ } }; +/// Fold tensor.cast into tensor.extract_slice producer. +/// Example: +/// ``` +/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] : +/// tensor<128x512xf32> to tensor<?x512xf32> +/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32> +/// ``` +/// -> +/// ``` +/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] : +/// tensor<128x512xf32> to tensor<16x512xf32> +/// ``` +struct TensorCastExtractSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CastOp tensorCast, + PatternRewriter &rewriter) const final { + auto extractOperand = + tensorCast.getOperand().getDefiningOp(); + + if (!extractOperand || !canFoldIntoProducerOp(tensorCast) || + tensorCast.getType().getShape() == + tensorCast.source().getType().cast().getShape()) + return failure(); + + SmallVector sizes = extractOperand.getMixedSizes(); + auto dimMask = computeRankReductionMask( + extractFromI64ArrayAttr(extractOperand.static_sizes()), + extractOperand.getType().getShape()); + size_t dimIndex = 0; + for (size_t i = 0, e = sizes.size(); i < e; i++) { + if (dimMask && dimMask->count(i)) + continue; + int64_t dim = tensorCast.getType().getShape()[dimIndex++]; + if (ShapedType::isDynamic(dim)) + continue; + sizes[i] = rewriter.getIndexAttr(dim); + } + + rewriter.replaceOpWithNewOp( + tensorCast, tensorCast.getType().cast(), + extractOperand.source(), extractOperand.getMixedOffsets(), sizes, + extractOperand.getMixedStrides()); + return success(); + } +}; + } // namespace void CastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// Index: 
mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir =================================================================== --- mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir +++ mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir @@ -30,9 +30,6 @@ return %3 : tensor } -// CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)> -// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)> - // CHECK: func @matmul_tensors( // CHECK-SAME: %[[A:[0-9a-z]*]]: tensor // CHECK-SAME: %[[B:[0-9a-z]*]]: tensor @@ -40,26 +37,20 @@ // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[dA0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor // CHECK-DAG: %[[dA1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor // CHECK-DAG: %[[dB0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor // CHECK-DAG: %[[dB1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor // CHECK: scf.for %[[I:[0-9a-z]*]] -// CHECK: %[[sizeA0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dA0]]] -// CHECK: %[[stA:.*]] = tensor.extract_slice %[[A]][%[[I]], 0] [%[[sizeA0]], %[[dA1]]] [1, 1] : tensor to tensor -// CHECK: %[[castA:.*]] = tensor.cast %[[stA]] : tensor to tensor<2x?xf32> +// CHECK: %[[stA:.*]] = tensor.extract_slice %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1] : tensor to tensor<2x?xf32> // CHECK: scf.for %[[J:[0-9a-z]*]] // CHECK-NEXT: scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]] // CHECK-DAG: %[[stB1:.*]] = tensor.extract_slice %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor to tensor<4x3xf32> // CHECK-DAG: %[[stF:.*]] = tensor.extract_slice %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1] : tensor to tensor<2x3xf32> // // slices of the producing matmul. 
-// CHECK: %[[sizeB1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dB1]]] -// CHECK: %[[stB2:.*]] = tensor.extract_slice %[[B]][0, %[[K]]] [%[[dB0]], %[[sizeB1]]] [1, 1] : tensor to tensor -// CHECK: %[[stC:.*]] = tensor.extract_slice %[[C]][%[[I]], %[[K]]] [%[[sizeA0]], %[[sizeB1]]] [1, 1] : tensor to tensor -// CHECK-DAG: %[[castC:.+]] = tensor.cast %[[stC]] : tensor to tensor<2x4xf32> -// CHECK-DAG: %[[castB:.+]] = tensor.cast %[[stB2]] : tensor to tensor -// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[castA]], %[[castB]] : tensor<2x?xf32>, tensor) outs(%[[castC]] : tensor<2x4xf32>) -> tensor<2x4xf32> +// CHECK-DAG: %[[stB2:.*]] = tensor.extract_slice %[[B]][0, %[[K]]] [%[[dB0]], 4] [1, 1] : tensor to tensor +// CHECK-DAG: %[[stC:.*]] = tensor.extract_slice %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1] : tensor to tensor<2x4xf32> +// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor) outs(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32> // CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: tensor.insert_slice %[[stG]] into %[[RES]][%[[I]], %[[J]]] Index: mlir/test/Dialect/Tensor/canonicalize.mlir =================================================================== --- mlir/test/Dialect/Tensor/canonicalize.mlir +++ mlir/test/Dialect/Tensor/canonicalize.mlir @@ -1401,3 +1401,27 @@ // CHECK: return %[[RES]] : tensor return %1 : tensor } + +// ----- + +// CHECK-LABEL: func @cast_extract_slice +func.func @cast_extract_slice(%arg0 : tensor<128x512xf32>, %s : index, %o : index) + -> tensor<16x512xf32> { +// CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 512] [1, 1] : tensor<128x512xf32> to tensor<16x512xf32> + %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] : tensor<128x512xf32> to tensor + %1 = tensor.cast %0 : tensor to tensor<16x512xf32> +// CHECK: return %[[E]] : tensor<16x512xf32> + return %1 : 
tensor<16x512xf32> +} + +// ----- + +// CHECK-LABEL: func @cast_extract_slice_rank_reduce +func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : index, %o : index) + -> tensor<16xf32> { +// CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 1] [1, 1] : tensor<128x512xf32> to tensor<16xf32> + %0 = tensor.extract_slice %arg0[%o, 0] [%s, 1] [1, 1] : tensor<128x512xf32> to tensor + %1 = tensor.cast %0 : tensor to tensor<16xf32> +// CHECK: return %[[E]] : tensor<16xf32> + return %1 : tensor<16xf32> +}