diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1085,7 +1085,24 @@
   }
 };
 
-/// Fold tensor_casts with insert_slice operations.
+/// Fold tensor_casts with insert_slice operations. If the source or
+/// destination tensor is produced by a tensor_cast that removes static type
+/// information, the cast is folded into the insert_slice operation. E.g.:
+///
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
+///   %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
+/// ```
+///
+/// Note: When folding a cast on the destination tensor, the result of the
+/// insert_slice operation is cast so that the type of the result does not
+/// change.
 struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertSliceOp> {
   using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
 
@@ -1123,12 +1140,63 @@
     return success();
   }
 };
+
+/// If additional static type information can be deduced from an insert_slice's
+/// size operands, insert an explicit cast of the op's source operand. This
+/// enables other canonicalization patterns that match tensor_cast ops, such as
+/// `ForOpTensorCastFolder` in SCF.
+///
+/// Example:
+///
+/// ```mlir
+///   %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
+///       : tensor<?x?xf32> into ...
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
+///   %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
+///       : tensor<64x64xf32> into ...
+/// ```
+struct InsertSliceOpSourceCastInserter final
+    : public OpRewritePattern<InsertSliceOp> {
+  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(InsertSliceOp insertSliceOp,
+                                PatternRewriter &rewriter) const override {
+    RankedTensorType srcType = insertSliceOp.getSourceType();
+    if (srcType.getRank() != insertSliceOp.getType().getRank())
+      return failure();
+    SmallVector<int64_t> newSrcShape(srcType.getShape().begin(),
+                                     srcType.getShape().end());
+    for (int64_t i = 0; i < srcType.getRank(); ++i) {
+      if (Optional<int64_t> constInt =
+              getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
+        newSrcShape[i] = *constInt;
+    }
+    RankedTensorType newSrcType =
+        RankedTensorType::get(newSrcShape, srcType.getElementType());
+    if (srcType == newSrcType)
+      return failure();
+
+    // srcType and newSrcType are different. Insert a cast.
+    Value cast = rewriter.create<tensor::CastOp>(
+        insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
+    rewriter.replaceOpWithNewOp<InsertSliceOp>(
+        insertSliceOp, cast, insertSliceOp.dest(),
+        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
+        insertSliceOp.getMixedStrides());
+    return success();
+  }
+};
 } // namespace
 
 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder>(
-      context);
+  results.add<InsertSliceOpConstantArgumentFolder, InsertSliceOpCastFolder,
+              InsertSliceOpSourceCastInserter>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -666,7 +666,7 @@
   return %res : tensor<1024x1024xf32>
 }
 
-
+// -----
 
 // CHECK-LABEL: @cond_prop
 func @cond_prop(%arg0 : i1) -> index {
@@ -707,6 +707,8 @@
 // CHECK-NEXT:   return %[[if]] : index
 // CHECK-NEXT:}
 
+// -----
+
 // CHECK-LABEL: @replace_if_with_cond1
 func @replace_if_with_cond1(%arg0 : i1) -> (i32, i1) {
   %true = constant true
@@ -729,6 +731,8 @@
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[if]], %arg0 : i32, i1
 
+// -----
+
 // CHECK-LABEL: @replace_if_with_cond2
 func @replace_if_with_cond2(%arg0 : i1) -> (i32, i1) {
   %true = constant true
@@ -753,6 +757,7 @@
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[if]], %[[toret]] : i32, i1
 
+// -----
 
 // CHECK-LABEL: @replace_if_with_cond3
 func @replace_if_with_cond3(%arg0 : i1, %arg2: i64) -> (i32, i64) {
@@ -774,6 +779,7 @@
 // CHECK-NEXT: }
 // CHECK-NEXT: return %[[if]], %arg1 : i32, i64
 
+// -----
 
 // CHECK-LABEL: @while_cond_true
 func @while_cond_true() {
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -366,10 +366,11 @@
 }
 // CHECK-LABEL: func @insert_slice_canonicalize
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
-// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[ARG0]]
+// CHECK: %[[CAST:.+]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<4x1x?xf32>
+// CHECK: %[[RESULT:.+]] = tensor.insert_slice %[[CAST]]
 // CHECK-SAME: [0, %{{.+}}, 1] [4, 1, %{{.+}}] [1, 1, 1]
-// CHECK-SAME: : tensor<?x?x?xf32> into tensor<?x?x?xf32>
-// CHEKC: return %[[RESULT]]
+// CHECK-SAME: : tensor<4x1x?xf32> into tensor<?x?x?xf32>
+// CHECK: return %[[RESULT]]
 
 // -----
 
@@ -517,3 +518,17 @@
   %2 = tensor.dim %0, %c1 : tensor<?x?xf32>
   return %1, %2: index, index
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_tensor_cast_on_insert_slice_src(
+// CHECK-SAME: %[[arg0:.*]]: tensor<?x5x?xf32>, %[[arg1:.*]]: tensor<?x?x?xf32>
+// CHECK: %[[cast:.*]] = tensor.cast %[[arg0]] : tensor<?x5x?xf32> to tensor<64x5x64xf32>
+// CHECK: %[[r:.*]] = tensor.insert_slice %[[cast]] into %[[arg1]][0, 1, 2] [64, 5, 64] [1, 1, 1] : tensor<64x5x64xf32> into tensor<?x?x?xf32>
+// CHECK: return %[[r]]
+func @insert_tensor_cast_on_insert_slice_src(
+    %arg0 : tensor<?x5x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
+  %r = tensor.insert_slice %arg0 into %arg1[0, 1, 2] [64, 5, 64] [1, 1, 1]
+      : tensor<?x5x?xf32> into tensor<?x?x?xf32>
+  return %r : tensor<?x?x?xf32>
+}
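Not part of the patch itself, but a minimal sketch of how the patterns registered in `getCanonicalizationPatterns` can be exercised programmatically (the helper name `runInsertSliceCanonicalization` is hypothetical; it assumes the standard greedy pattern rewrite driver). The updated tests instead rely on the generic `-canonicalize` pass, which runs the same patterns.

```c++
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper (not in the patch): collect the tensor.insert_slice
// canonicalization patterns, including the new InsertSliceOpSourceCastInserter,
// and apply them to all ops nested under `op` with the greedy driver.
static LogicalResult runInsertSliceCanonicalization(Operation *op) {
  MLIRContext *ctx = op->getContext();
  RewritePatternSet patterns(ctx);
  tensor::InsertSliceOp::getCanonicalizationPatterns(patterns, ctx);
  return applyPatternsAndFoldGreedily(op, std::move(patterns));
}
```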