diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -94,6 +94,11 @@ bufferization::getBufferType(castOp.getResult(), options); if (failed(resultMemRefType)) return failure(); + if (resultBuffer->getType() == *resultMemRefType) { + // This cast is a no-op. + replaceOpWithBufferizedValues(rewriter, op, *resultBuffer); + return success(); + } // Replace the op with a memref.cast. assert(memref::CastOp::areCastCompatible(resultBuffer->getType(), diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize.mlir @@ -199,3 +199,12 @@ %3 = tensor.extract %0[%pos3] : tensor<100xf32> return %2, %3 : f32, f32 } + +// ----- + +// CHECK-LABEL: func @from_unranked_to_unranked +func.func @from_unranked_to_unranked(%arg0: tensor<*xi32>) -> tensor<*xi32> { + // CHECK: return %arg{{.*}} : tensor<*xi32> + %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32> + return %0 : tensor<*xi32> +}