Tighten the tensor.insert_slice source-cast rewrite and add a regression test.

The rewrite now returns failure() unless the recomputed source type is
different from the current source type, is "more static" than it
(preservesStaticInformation), and is cast-compatible with it
(tensor::CastOp::areCastCompatible). The new test feeds an insert_slice whose
static sizes (12345 x 67890) disagree with its 4x4 source; the test name
indicates that this input used to make folding loop forever.

NOTE(review): this patch was recovered from a copy that had been flattened
onto one line and had lost every `<...>` argument starting with a letter or
`?`. Line breaks, indentation, `<tensor::CastOp>`, `<InsertSliceOp>` and the
`tensor<?x?xf32>` types in the new test are reconstructed -- confirm them
against the tree before applying. The type on the context line
`return %1 : tensor` could not be recovered and is left exactly as found.

diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1279,12 +1279,19 @@
               getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
         newSrcShape[i] = *constInt;
     }
+
     RankedTensorType newSrcType =
         RankedTensorType::get(newSrcShape, srcType.getElementType());
-    if (srcType == newSrcType)
+    if (srcType == newSrcType ||
+        !preservesStaticInformation(srcType, newSrcType) ||
+        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
       return failure();
 
-    // srcType and newSrcType are different. Insert a cast.
+    // newSrcType is:
+    //   1) Different from srcType.
+    //   2) "More static" than srcType.
+    //   3) Cast-compatible with srcType.
+    // Insert the cast.
     Value cast = rewriter.create<tensor::CastOp>(
         insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
     rewriter.replaceOpWithNewOp<InsertSliceOp>(
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -559,3 +559,13 @@
   // CHECK: return %[[INSERT]]
   return %1 : tensor
 }
+
+// -----
+
+// CHECK-LABEL: func @folding_incorrect_ir_triggers_infinite_loop
+func @folding_incorrect_ir_triggers_infinite_loop(
+  %A : tensor<4x4xf32>, %C : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %rC = tensor.insert_slice %A into %C[0, 0][12345, 67890][1, 1] :
+    tensor<4x4xf32> into tensor<?x?xf32>
+  return %rC: tensor<?x?xf32>
+}