diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -966,6 +966,10 @@
 struct ReshapeOpInterface
     : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
                                                     tensor::ReshapeOp> {
+  bool bufferizesToAllocation(Operation *op, OpResult opResult) const {
+    return true;
+  }
+
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
                               const AnalysisState &state) const {
     if (&opOperand == &op->getOpOperand(1) /* shape */)
@@ -986,12 +990,39 @@
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto reshapeOp = cast<tensor::ReshapeOp>(op);
+    Location loc = op->getLoc();
+
+    // Get bufferized src and shape operands.
     FailureOr<Value> srcBuffer =
         getBuffer(rewriter, reshapeOp.getSource(), options);
     FailureOr<Value> shapeBuffer =
         getBuffer(rewriter, reshapeOp.getShape(), options);
     if (failed(srcBuffer) || failed(shapeBuffer))
       return failure();
+
+    // Copy src buffer into new allocation if it doesn't have identity layout.
+    // TODO: This could be extended to layouts that are not the identity layout
+    // but still allow zero-copy reshape (such as a 1D input with an
+    // offset).
+    auto bufferType = cast<MemRefType>(srcBuffer->getType());
+    if (!bufferType.getLayout().isIdentity()) {
+      bool dealloc = shouldDeallocateOpResult(
+          cast<OpResult>(reshapeOp.getResult()), options);
+      FailureOr<Value> tensorAlloc =
+          allocateTensorForShapedValue(rewriter, loc, reshapeOp.getSource(),
+                                       /*escape=*/!dealloc, options,
+                                       /*copy=*/true);
+      if (failed(tensorAlloc))
+        return failure();
+      auto memrefType =
+          MemRefType::get(bufferType.getShape(), bufferType.getElementType(),
+                          AffineMap(), bufferType.getMemorySpace());
+      auto toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+          op->getLoc(), memrefType, *tensorAlloc);
+      srcBuffer = toMemrefOp.getResult();
+    }
+
+    // Convert this op to memref equivalent on buffers.
     auto maybeResultMemRefType =
         bufferization::getBufferType(reshapeOp.getResult(), options);
     if (failed(maybeResultMemRefType))
diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
--- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir
@@ -416,3 +416,20 @@
   // CHECK: return %[[RESHAPED]]
   return %reshaped : tensor<2x2x5xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @tensor.reshape_non_identity(
+// CHECK-SAME: %[[ARG0:.*]]: memref<?x10xf32, strided<[?, ?], offset: ?>>) -> memref<2x2x5xf32> {
+func.func @tensor.reshape_non_identity(%t1 : tensor<?x10xf32>) -> tensor<2x2x5xf32> {
+  // CHECK: %[[SHAPE:.*]] = memref.get_global @{{.*}} : memref<3xi64>
+  %shape = arith.constant dense<[2, 2, 5]> : tensor<3xi64>
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc(%{{.*}}) {alignment = 64 : i64} : memref<?x10xf32>
+  // CHECK: memref.copy %[[ARG0]], %[[ALLOC]] : memref<?x10xf32, strided<[?, ?], offset: ?>> to memref<?x10xf32>
+  // CHECK: %[[RESHAPED:.*]] = memref.reshape %[[ALLOC]](%[[SHAPE]]) : (memref<?x10xf32>, memref<3xi64>) -> memref<2x2x5xf32>
+  %reshaped = tensor.reshape %t1(%shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<2x2x5xf32>
+
+  // CHECK: return %[[RESHAPED]]
+  return %reshaped : tensor<2x2x5xf32>
+}