diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1279,12 +1279,23 @@
               getConstantIntValue(insertSliceOp.getMixedSizes()[i]))
         newSrcShape[i] = *constInt;
     }
+
     RankedTensorType newSrcType =
         RankedTensorType::get(newSrcShape, srcType.getElementType());
-    if (srcType == newSrcType)
+    // The constant sizes may contradict static dimensions of the source (e.g.
+    // a 4x4 source inserted with sizes [12345, 67890]). Casting to such a type
+    // is invalid and previously sent canonicalization into an infinite loop,
+    // so only proceed when newSrcType is a valid, more static refinement.
+    if (srcType == newSrcType ||
+        !preservesStaticInformation(srcType, newSrcType) ||
+        !tensor::CastOp::areCastCompatible(srcType, newSrcType))
       return failure();
 
-    // srcType and newSrcType are different. Insert a cast.
+    // newSrcType is:
+    //   1) Different from srcType.
+    //   2) "More static" than srcType.
+    //   3) Cast-compatible with srcType.
+    // Insert the cast.
     Value cast = rewriter.create<tensor::CastOp>(
         insertSliceOp.getLoc(), newSrcType, insertSliceOp.source());
     rewriter.replaceOpWithNewOp<InsertSliceOp>(
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -561 +561,17 @@
 }
+
+// -----
+
+// The insert_slice sizes contradict the static 4x4 source shape, so the
+// source-cast inserter must bail out: no tensor.cast may be created and the
+// op must survive canonicalization (this used to loop forever).
+// CHECK-LABEL: func @folding_incorrect_ir_triggers_infinite_loop
+//   CHECK-NOT:   tensor.cast
+//       CHECK:   %[[R:.*]] = tensor.insert_slice %{{.*}} into %{{.*}}[0, 0] [12345, 67890] [1, 1] : tensor<4x4xf32> into tensor<?x?xf32>
+//       CHECK:   return %[[R]]
+func @folding_incorrect_ir_triggers_infinite_loop(
+  %A : tensor<4x4xf32>, %C : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %rC = tensor.insert_slice %A into %C[0, 0][12345, 67890][1, 1] :
+    tensor<4x4xf32> into tensor<?x?xf32>
+  return %rC: tensor<?x?xf32>
+}