[mlir][Vector] Fix scalar vector.store lowering for 1-element vectors

The vector.store -> memref.store lowering for vectors with exactly one
element extracted position [1]. That is out of bounds: the only valid
index along a size-1 dimension is 0. It also supplied only one index,
so for a rank-N vector (e.g. vector<1x1x1xf32>) the extract would
produce a sub-vector rather than a scalar. Extract at [0, ..., 0], with
one index per dimension, instead.

diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2600,8 +2600,10 @@
     auto vecType = storeOp.getVectorType();
     if (vecType.getNumElements() != 1)
       return failure();
+    // The single element lives at position [0, ..., 0], one index per dim.
+    SmallVector<int64_t> indices(vecType.getRank(), 0);
     Value extracted = rewriter.create<vector::ExtractOp>(
-        storeOp.getLoc(), storeOp.valueToStore(), ArrayRef<int64_t>{1});
+        storeOp.getLoc(), storeOp.valueToStore(), indices);
     rewriter.replaceOpWithNewOp<memref::StoreOp>(
         storeOp, extracted, storeOp.base(), storeOp.indices());
     return success();
diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
--- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir
@@ -1,8 +1,9 @@
 // RUN: mlir-opt %s -test-vector-transfer-lowering-patterns -canonicalize -split-input-file | FileCheck %s
 
 // CHECK-LABEL: func @vector_transfer_ops_0d(
-// CHECK-SAME: %[[MEM:.*]]: memref<f32>) {
-func @vector_transfer_ops_0d(%M: memref<f32>) {
+// CHECK-SAME: %[[MEM:.*]]: memref<f32>
+// CHECK-SAME: %[[VV:.*]]: vector<1x1x1xf32>
+func @vector_transfer_ops_0d(%M: memref<f32>, %v: vector<1x1x1xf32>) {
   %f0 = constant 0.0 : f32
 
 // CHECK-NEXT: %[[V:.*]] = memref.load %[[MEM]][] : memref<f32>
@@ -13,6 +14,10 @@
   vector.transfer_write %0, %M[] {permutation_map = affine_map<()->(0)>} :
     vector<1xf32>, memref<f32>
 
+// CHECK-NEXT: %[[S:.*]] = vector.extract %[[VV]][0, 0, 0] : vector<1x1x1xf32>
+// CHECK-NEXT: memref.store %[[S]], %[[MEM]][] : memref<f32>
+  vector.store %v, %M[] : memref<f32>, vector<1x1x1xf32>
+
   return
 }
 