diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -149,6 +149,25 @@
   }
 };
 
+/// Conversion pattern that replaces `linalg.tensor_reshape` with
+/// `linalg.reshape`.
+class BufferizeTensorReshapeOp : public OpConversionPattern<TensorReshapeOp> {
+public:
+  using OpConversionPattern<TensorReshapeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TensorReshapeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // `operands` are the already type-converted (buffer) values, so the
+    // adaptor's `src()` is a memref rather than the original tensor.
+    linalg::TensorReshapeOpAdaptor adaptor(operands, op->getAttrDictionary());
+    rewriter.replaceOpWithNewOp<linalg::ReshapeOp>(
+        op, getTypeConverter()->convertType(op.getType()).cast<MemRefType>(),
+        adaptor.src(), adaptor.reassociation());
+    return success();
+  }
+};
+
 /// Conversion pattern that bufferizes `linalg.fill` operation.
 class BufferizeFillOp : public OpConversionPattern<FillOp> {
 public:
@@ -336,6 +355,7 @@
       BufferizeAnyLinalgOp,
       BufferizeFillOp,
       BufferizeInitTensorOp,
+      BufferizeTensorReshapeOp,
       SubTensorOpConverter,
       SubTensorInsertOpConverter
     >(typeConverter, patterns.getContext());
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -278,3 +278,18 @@
   %0 = linalg.fill(%arg0, %c0) : tensor<?xf32>, f32 -> tensor<?xf32>
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @bufferize_tensor_reshape(
+// CHECK-SAME:    %[[IN:.*]]: tensor<4x5xf32>
+func @bufferize_tensor_reshape(%arg0: tensor<4x5xf32>) -> tensor<20xf32> {
+  %out = linalg.tensor_reshape %arg0 [[0, 1]] :
+    tensor<4x5xf32> into tensor<20xf32>
+  return %out : tensor<20xf32>
+}
+// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x5xf32>
+// CHECK: %[[RESHAPE:.*]] = linalg.reshape %[[MEMREF]] {{\[}}[0, 1]]
+// CHECK-SAME: : memref<4x5xf32> into memref<20xf32>
+// CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[RESHAPE]] : memref<20xf32>
+// CHECK: return %[[TENSOR]]