diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -992,13 +992,39 @@
         getBuffer(rewriter, reshapeOp.getShape(), options);
     if (failed(srcBuffer) || failed(shapeBuffer))
       return failure();
-    auto resultMemRefType = getMemRefTypeWithStaticIdentityLayout(
-        reshapeOp.getResult().getType(),
-        cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace());
+    // Query the result type through getBufferType so that bufferize() and the
+    // type-only getBufferType() override below agree on the result buffer type.
+    auto maybeResultMemRefType =
+        bufferization::getBufferType(reshapeOp.getResult(), options);
+    if (failed(maybeResultMemRefType))
+      return failure();
+    // NOTE(review): memref.reshape is documented to expect an identity-layout
+    // source; *srcBuffer is passed through as-is, so a source buffer with a
+    // non-identity layout would need a copy first. Confirm this cannot occur.
     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
-        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
+        rewriter, op, *maybeResultMemRefType, *srcBuffer, *shapeBuffer);
     return success();
   }
+
+  /// Compute the buffer type of the tensor.reshape result (this only computes
+  /// a type; no IR is created): a memref with a static identity layout for the
+  /// result tensor type, placed in the same memory space as the buffer type of
+  /// the source operand. `fixedTypes` is forwarded unchanged to the source
+  /// query. Returns failure if the source buffer type cannot be determined.
+  FailureOr<BaseMemRefType>
+  getBufferType(Operation *op, Value value, const BufferizationOptions &options,
+                const DenseMap<Value, BaseMemRefType> &fixedTypes) const {
+    auto reshapeOp = cast<tensor::ReshapeOp>(op);
+    // Only the op's single result is expected to be queried here.
+    assert(value == reshapeOp.getResult() && "unexpected value provided");
+    auto maybeSourceBufferType = bufferization::getBufferType(
+        reshapeOp.getSource(), options, fixedTypes);
+    if (failed(maybeSourceBufferType))
+      return failure();
+    return getMemRefTypeWithStaticIdentityLayout(
+        reshapeOp.getResult().getType(),
+        maybeSourceBufferType->getMemorySpace());
+  }
 };
 
 /// Analysis of ParallelInsertSliceOp.