diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -309,6 +309,56 @@
     return success();
   }
 };
+
+/// Rewrites a `vector.transfer_read` so that it reads from the type-converted
+/// source operand supplied by the conversion framework.
+class VectorTransferReadOpConverter
+    : public OpConversionPattern<vector::TransferReadOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferReadOp readOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // A read whose source is already a memref needs no rewriting.
+    if (readOp.getShapedType().isa<MemRefType>())
+      return failure();
+    vector::TransferReadOp::Adaptor adaptor(operands,
+                                            readOp->getAttrDictionary());
+    rewriter.replaceOpWithNewOp<vector::TransferReadOp>(
+        readOp, readOp.getType(), adaptor.source(), adaptor.indices(),
+        adaptor.permutation_map(), adaptor.padding(), adaptor.mask(),
+        adaptor.in_bounds());
+    return success();
+  }
+};
+
+/// Rewrites a `vector.transfer_write` so that it writes into the
+/// type-converted source operand, which then replaces the op's result.
+class VectorTransferWriteOpConverter
+    : public OpConversionPattern<vector::TransferWriteOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::TransferWriteOp writeOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    // A write whose target is already a memref needs no rewriting.
+    if (writeOp.getShapedType().isa<MemRefType>())
+      return failure();
+    vector::TransferWriteOp::Adaptor adaptor(operands,
+                                             writeOp->getAttrDictionary());
+    // NOTE(review): unlike the read pattern above, no mask operand is
+    // forwarded here -- confirm that masked writes cannot reach this pattern.
+    rewriter.create<vector::TransferWriteOp>(
+        writeOp.getLoc(), adaptor.vector(), adaptor.source(), adaptor.indices(),
+        adaptor.permutation_map(),
+        adaptor.in_bounds() ? adaptor.in_bounds() : ArrayAttr());
+    // Remaining uses of the old result now refer to the converted source.
+    rewriter.replaceOp(writeOp, adaptor.source());
+    return success();
+  }
+};
 } // namespace
 
 namespace {
@@ -332,10 +382,10 @@
       return typeConverter.isLegal(op);
     };
     target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
-    target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
+    target.addDynamicallyLegalOp<ConstantOp, vector::TransferReadOp,
+                                 vector::TransferWriteOp>(isLegalOperation);
 
     RewritePatternSet patterns(&context);
-    patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
     populateLinalgBufferizePatterns(typeConverter, patterns);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
@@ -359,7 +409,10 @@
       BufferizeTensorReshapeOp<TensorExpandShapeOp>,
       BufferizeTensorReshapeOp<TensorCollapseShapeOp>,
       ExtractSliceOpConverter,
-      InsertSliceOpConverter
+      InsertSliceOpConverter,
+      VectorTransferReadOpConverter,
+      VectorTransferWriteOpConverter
     >(typeConverter, patterns.getContext());
   // clang-format on
+  patterns.add<GeneralizePadTensorOpPattern>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -299,3 +299,20 @@
 // CHECK: %[[OUT_TENSOR:.*]] = memref.tensor_load %[[OUT]] : memref<4x?x?x?xf32>
 // CHECK: return %[[OUT_TENSOR]] : tensor<4x?x?x?xf32>
 // CHECK: }
+
+
+// -----
+
+// CHECK-LABEL: func @vector_transfer
+func @vector_transfer(%in: tensor<4xf32>, %out: tensor<4xf32>) {
+  %c0 = constant 0 : index
+  %cst = constant 0.000000e+00 : f32
+  %read = vector.transfer_read %in[%c0], %cst {in_bounds = [true]}
+      : tensor<4xf32>, vector<4xf32>
+  %tanh = math.tanh %read : vector<4xf32>
+  %write = vector.transfer_write %tanh, %out[%c0] {in_bounds = [true]}
+      : vector<4xf32>, tensor<4xf32>
+  return
+  // CHECK: vector.transfer_read {{.*}} : memref<4xf32>, vector<4xf32>
+  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, memref<4xf32>
+}