diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2590,6 +2590,24 @@
   }
 };
 
+/// Replace a single-element vector.store with a memref.store of element 0.
+struct VectorStoreToMemrefStoreLowering
+    : public OpRewritePattern<vector::StoreOp> {
+  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
+                                PatternRewriter &rewriter) const override {
+    auto vecType = storeOp.getVectorType();
+    if (vecType.getNumElements() != 1)
+      return failure();
+    Value extracted = rewriter.create<vector::ExtractOp>(storeOp.getLoc(),
+        storeOp.valueToStore(), SmallVector<int64_t>(vecType.getRank(), 0));
+    rewriter.replaceOpWithNewOp<memref::StoreOp>(
+        storeOp, extracted, storeOp.base(), storeOp.indices());
+    return success();
+  }
+};
+
 /// Progressive lowering of transfer_write. This pattern supports lowering of
 /// `vector.transfer_write` to `vector.store` if all of the following hold:
 /// - Stride of most minor memref dimension must be 1.
@@ -2611,7 +2629,7 @@
       return failure();
     // Permutations are handled by VectorToSCF or
     // populateVectorTransferPermutationMapLoweringPatterns.
-    if (!write.permutation_map().isMinorIdentity())
+    if (!write.isZeroD() && !write.permutation_map().isMinorIdentity())
       return failure();
     auto memRefType = write.getShapedType().dyn_cast<MemRefType>();
     if (!memRefType)
@@ -2766,6 +2784,9 @@
 
   LogicalResult matchAndRewrite(vector::TransferWriteOp op,
                                 PatternRewriter &rewriter) const override {
+    if (op.isZeroD())
+      return failure();
+
     SmallVector<unsigned> permutation;
     AffineMap map = op.permutation_map();
     if (map.isMinorIdentity())
@@ -3582,7 +3603,9 @@
   patterns.add<TransferReadToVectorLoadLowering,
                TransferWriteToVectorStoreLowering>(patterns.getContext(),
                                                    maxTransferRank);
-  patterns.add<VectorLoadToMemrefLoadLowering>(patterns.getContext());
+  patterns
+      .add<VectorLoadToMemrefLoadLowering, VectorStoreToMemrefStoreLowering>(
+          patterns.getContext());
 }
 
 void mlir::vector::populateVectorUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir
@@ -1,5 +1,23 @@
 // RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
 
+// CHECK-LABEL: func @vector_transfer_ops_0d(
+// CHECK-SAME: %[[MEM:.*]]: memref<f32>) {
+func @vector_transfer_ops_0d(%M: memref<f32>) {
+  %f0 = constant 0.0 : f32
+
+// CHECK-NEXT: %[[V:.*]] = memref.load %[[MEM]][] : memref<f32>
+  %0 = vector.transfer_read %M[], %f0 {permutation_map = affine_map<()->(0)>} :
+    memref<f32>, vector<1xf32>
+
+// CHECK-NEXT: memref.store %[[V]], %[[MEM]][] : memref<f32>
+  vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
+    vector<1xf32>, memref<f32>
+
+  return
+}
+
+// -----
+
 // transfer_read/write are lowered to vector.load/store
 // CHECK-LABEL: func @transfer_to_load(
 // CHECK-SAME: %[[MEM:.*]]: memref<8x8xf32>,