diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -992,8 +992,8 @@
         getBuffer(rewriter, reshapeOp.getShape(), options);
     if (failed(srcBuffer) || failed(shapeBuffer))
       return failure();
-    auto resultMemRefType = getMemRefType(
-        reshapeOp.getResult(), options, /*layout=*/{},
+    auto resultMemRefType = getMemRefTypeWithStaticIdentityLayout(
+        reshapeOp.getResult().getType(),
         cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
         rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -398,3 +398,21 @@
   // CHECK: }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor.reshape(
+func.func @tensor.reshape() -> tensor<2x2x5xf32> {
+  // CHECK-DAG: %[[M1:.*]] = memref.cast %{{.*}} : memref<2x10xf32> to memref<?x10xf32>
+  %t1_static = arith.constant dense<0.> : tensor<2x10xf32>
+  %t1 = tensor.cast %t1_static : tensor<2x10xf32> to tensor<?x10xf32>
+
+  // CHECK: %[[SHAPE:.*]] = memref.get_global @{{.*}} : memref<3xi64>
+  %shape = arith.constant dense<[2, 2, 5]> : tensor<3xi64>
+
+  // CHECK: %[[RESHAPED:.*]] = memref.reshape %[[M1]](%[[SHAPE]]) : (memref<?x10xf32>, memref<3xi64>) -> memref<2x2x5xf32>
+  %reshaped = tensor.reshape %t1(%shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<2x2x5xf32>
+
+  // CHECK: return %[[RESHAPED]]
+  return %reshaped : tensor<2x2x5xf32>
+}