diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1305,16 +1305,33 @@
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }
 
+/// Verifies that `srcType` agrees, up to rank reduction, with the slice type
+/// inferred for `dstType` from the given static offsets, sizes and strides.
+/// If `expectedType` is non-null, the inferred type is stored there so callers
+/// can use it in diagnostics.
+static SliceVerificationResult
+verifyInsertSliceOp(ShapedType srcType, ShapedType dstType,
+                    ArrayAttr staticOffsets, ArrayAttr staticSizes,
+                    ArrayAttr staticStrides,
+                    ShapedType *expectedType = nullptr) {
+  // insert_slice is the inverse of extract_slice, use the same type inference.
+  auto expected = ExtractSliceOp::inferRankReducedResultType(
+                      srcType.getRank(), dstType.cast<RankedTensorType>(),
+                      extractFromI64ArrayAttr(staticOffsets),
+                      extractFromI64ArrayAttr(staticSizes),
+                      extractFromI64ArrayAttr(staticStrides))
+                      .cast<ShapedType>();
+  if (expectedType)
+    *expectedType = expected;
+  return isRankReducedType(expected, srcType);
+}
+
 /// Verifier for InsertSliceOp.
 LogicalResult InsertSliceOp::verify() {
-  // insert_slice is the inverse of extract_slice, use the same type inference.
-  auto expectedType = ExtractSliceOp::inferRankReducedResultType(
-      getSourceType().getRank(), getType(),
-      extractFromI64ArrayAttr(static_offsets()),
-      extractFromI64ArrayAttr(static_sizes()),
-      extractFromI64ArrayAttr(static_strides()));
+  ShapedType expectedType;
   auto result =
-      isRankReducedType(expectedType.cast<ShapedType>(), getSourceType());
+      verifyInsertSliceOp(getSourceType(), getType(), static_offsets(),
+                          static_sizes(), static_strides(), &expectedType);
   return produceSliceErrorMsg(result, *this, expectedType);
 }
 
@@ -1446,12 +1463,22 @@
     if (!sourceCastSource && !destCastSource)
       return failure();
 
+    auto src = (sourceCastSource ? *sourceCastSource : insertSliceOp.source());
+    auto dst = (destCastSource ? *destCastSource : insertSliceOp.dest());
+
+    // Folding the casts changes the operand types of the insert_slice. Only
+    // rewrite when the op built from the new types would still verify.
+    auto srcType = src.getType().cast<ShapedType>();
+    auto dstType = dst.getType().cast<ShapedType>();
+    if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.static_offsets(),
+                            insertSliceOp.static_sizes(),
+                            insertSliceOp.static_strides()) !=
+        SliceVerificationResult::Success)
+      return failure();
+
     Value replacement = rewriter.create<InsertSliceOp>(
-        insertSliceOp.getLoc(),
-        (sourceCastSource ? *sourceCastSource : insertSliceOp.source()),
-        (destCastSource ? *destCastSource : insertSliceOp.dest()),
-        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
-        insertSliceOp.getMixedStrides());
+        insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
+        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
 
     if (replacement.getType() != insertSliceOp.getType()) {
       replacement = rewriter.create<tensor::CastOp>(
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1231,3 +1231,18 @@
   // CHECK-NEXT: [[T:%.*]] = arith.constant dense<1.000000e+00> : tensor<4xf32>
   // CHECK-NEXT: return [[T]] : tensor<4xf32>
 }
+
+// -----
+
+// There was an issue in cast + insert_slice folding generating invalid ir.
+// https://github.com/llvm/llvm-project/issues/53099
+// CHECK-LABEL: func @insert_slice_cast
+func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index, %arg7 : index) -> tensor<?x?xf32> {
+  // CHECK: %[[CAST:.*]] = tensor.cast %{{.*}} : tensor<1x?xf32> to tensor<?x?xf32>
+  %0 = tensor.cast %arg0 : tensor<1x?xf32> to tensor<?x?xf32>
+  // CHECK: %[[RES:.*]] = tensor.insert_slice %[[CAST]]
+  // CHECK-SAME: : tensor<?x?xf32> into tensor<?x?xf32>
+  %1 = tensor.insert_slice %0 into %arg1[%arg2, %arg3] [%arg4, %arg5] [%arg6, %arg7] : tensor<?x?xf32> into tensor<?x?xf32>
+  // CHECK: return %[[RES]] : tensor<?x?xf32>
+  return %1 : tensor<?x?xf32>
+}