diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -743,6 +743,54 @@
   }
 };
 
+/// Bufferization of tensor.reshape. Replace with memref.reshape.
+struct ReshapeOpInterface
+    : public BufferizableOpInterface::ExternalModel<ReshapeOpInterface,
+                                                    tensor::ReshapeOp> {
+  bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
+                              const AnalysisState &state) const {
+    if (&opOperand == &op->getOpOperand(1) /* shape */)
+      return true;
+    return false;
+  }
+
+  bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
+                               const AnalysisState &state) const {
+    return false;
+  }
+
+  SmallVector<OpResult> getAliasingOpResult(Operation *op, OpOperand &opOperand,
+                                            const AnalysisState &state) const {
+    return {op->getOpResult(0)};
+  }
+
+  BufferRelation bufferRelation(Operation *op, OpResult opResult,
+                                const AnalysisState &state) const {
+    return BufferRelation::Equivalent;
+  }
+
+  LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
+                          BufferizationState &state) const {
+    auto reshapeOp = cast<tensor::ReshapeOp>(op);
+    auto &srcOperand = reshapeOp->getOpOperand(0);
+    auto srcBuffer = state.getBuffer(rewriter, srcOperand);
+    if (failed(srcBuffer))
+      return failure();
+
+    auto &shapeOperand = reshapeOp->getOpOperand(1);
+    auto shapeBuffer = state.getBuffer(rewriter, shapeOperand);
+    if (failed(shapeBuffer))
+      return failure();
+
+    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
+    auto resultMemRefType = getMemRefType(resultTensorType, state.getOptions());
+
+    replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
+        rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
+    return success();
+  }
+};
+
 } // namespace
 } // namespace tensor
 } // namespace mlir
@@ -761,5 +809,6 @@
     InsertOp::attachInterface<InsertOpInterface>(*ctx);
     InsertSliceOp::attachInterface<InsertSliceOpInterface>(*ctx);
     RankOp::attachInterface<RankOpInterface>(*ctx);
+    ReshapeOp::attachInterface<ReshapeOpInterface>(*ctx);
   });
 }
diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir
--- a/mlir/test/Dialect/Tensor/bufferize.mlir
+++ b/mlir/test/Dialect/Tensor/bufferize.mlir
@@ -408,3 +408,30 @@
   %ret = tensor.collapse_shape %0 [[0, 1, 2]] : tensor<4x2x1xf32> into tensor<8xf32>
   return %ret: tensor<8xf32>
 }
+
+// CHECK-LABEL: func @tensor.reshape(
+// CHECK-SAME: %[[t1:.*]]: tensor<?x10xf32>
+func.func @tensor.reshape(%t1: tensor<?x10xf32>) -> tensor<2x2x5xf32> {
+  // CHECK: %[[m1:.*]] = bufferization.to_memref %[[t1]] : memref<?x10xf32>
+
+  // CHECK: %[[two:.*]] = arith.constant 2 : i64
+  %two = arith.constant 2 : i64
+  // CHECK: %[[five:.*]] = arith.constant 5 : i64
+  %five = arith.constant 5 : i64
+
+  // CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 128 : i64} : memref<3xi64>
+  // CHECK: %[[zero_idx:.*]] = arith.constant 0 : index
+  // CHECK: %[[one_idx:.*]] = arith.constant 1 : index
+  // CHECK: %[[two_idx:.*]] = arith.constant 2 : index
+  // CHECK: memref.store %[[two]], %[[alloc]][%[[zero_idx]]] : memref<3xi64>
+  // CHECK: memref.store %[[two]], %[[alloc]][%[[one_idx]]] : memref<3xi64>
+  // CHECK: memref.store %[[five]], %[[alloc]][%[[two_idx]]] : memref<3xi64>
+  %shape = tensor.from_elements %two, %two, %five : tensor<3xi64>
+
+  // CHECK: %[[reshaped:.*]] = memref.reshape %[[m1]](%[[alloc]]) : (memref<?x10xf32>, memref<3xi64>) -> memref<2x2x5xf32>
+  %reshaped = tensor.reshape %t1(%shape) : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<2x2x5xf32>
+
+  // CHECK: %[[r:.*]] = bufferization.to_tensor %[[reshaped]]
+  // CHECK: return %[[r]]
+  return %reshaped : tensor<2x2x5xf32>
+}
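
For reference, the core rewrite performed by the new ReshapeOpInterface, distilled from the CHECK lines in the test above (a minimal sketch; the %t1/%shape names and the ?x10xf32 source shape are just the test's example, and the shape operand is materialized however its producer bufferizes, e.g. tensor.from_elements lowers to a memref.alloc plus memref.store ops as checked above):

  // Tensor-level op handled by the new pattern.
  %reshaped = tensor.reshape %t1(%shape)
      : (tensor<?x10xf32>, tensor<3xi64>) -> tensor<2x2x5xf32>

  // Buffer-level form after bufferization: the same reshape on the operands'
  // buffers, with the result wrapped back into a tensor.
  %m1 = bufferization.to_memref %t1 : memref<?x10xf32>
  %mshape = bufferization.to_memref %shape : memref<3xi64>
  %mr = memref.reshape %m1(%mshape)
      : (memref<?x10xf32>, memref<3xi64>) -> memref<2x2x5xf32>
  %r = bufferization.to_tensor %mr : memref<2x2x5xf32>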