Lower rank-0 vector.transfer_read (after dropping leading broadcast dims) to a scalar load.

In the rank-reducing transfer_read rewrite (the context shows `reducedShapeRank`
and `newMap.isMinorIdentityWithBroadcasting()`), a read whose permutation map
reduces to zero results, e.g. `(d0, d1) -> (0)`, is rewritten into a scalar
load at `op.indices()` followed by a broadcast to the original vector type,
instead of producing a zero-dimension vector. The test adds `#map6` and an extra
vector<8xf32> result to @transfer_read_permutations and expects a `memref.load`
followed by `vector.broadcast ... : f32 to vector<8xf32>`.

NOTE(review): this copy of the patch has lost its line breaks and appears to have
lost its template/type arguments (e.g. `rewriter.create(`, `replaceOpWithNewOp(`,
`SmallVector newShape`, and the `memref` argument types in the test), so it will
not apply as-is. Regenerate it from the source tree.
NOTE(review): the new code passes only `op.source()` and `op.indices()` to the
scalar load; the transfer_read padding operand (`%cst` in the test) is not used.
Confirm that this is intended.
NOTE(review): the lowering emits `memref.load`. If a transfer_read with a tensor
source can reach this rewrite, an extract on the tensor would be needed instead.
TODO confirm.
NOTE(review): the unchanged context comment "Vector rank cannot be zero. Handled by
TransferReadToVectorLoadLowering." appears stale now that the rank-0 case returns
early just above it. The context line "with broadasting" also has a typo.

diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -2828,6 +2828,18 @@ // with broadasting. Otherwise we first want to permute the map. if (!newMap.isMinorIdentityWithBroadcasting()) return failure(); + + // TODO: support zero-dimension vectors natively. See: + // https://llvm.discourse.group/t/should-we-have-0-d-vectors/3097. + // In the meantime, lower these to a scalar load when they pop up. + if (reducedShapeRank == 0) { + Value newRead = rewriter.create( + op.getLoc(), originalVecType.getElementType(), op.source(), + op.indices()); + rewriter.replaceOpWithNewOp(op, originalVecType, + newRead); + return success(); + } SmallVector newShape = llvm::to_vector<4>( originalVecType.getShape().take_back(reducedShapeRank)); // Vector rank cannot be zero. Handled by TransferReadToVectorLoadLowering. diff --git a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-lowering.mlir @@ -228,6 +228,7 @@ #map3 = affine_map<(d0, d1) -> (d1, d0, 0, 0)> #map4 = affine_map<(d0, d1) -> (0, d1, 0, d0)> #map5 = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3, d0)> +#map6 = affine_map<(d0, d1) -> (0)> // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, 0, 0)> // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, 0, d3)> @@ -235,7 +236,7 @@ // CHECK-LABEL: func @transfer_read_permutations func @transfer_read_permutations(%arg0 : memref, %arg1 : memref) -> (vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, - vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>) { + vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<8xf32>) { // CHECK-DAG: %[[CF0:.*]] = constant 0.000000e+00 : f32 // 
CHECK-DAG: %[[C0:.*]] = constant 0 : index %cst = constant 0.000000e+00 : f32 @@ -275,9 +276,13 @@ // CHECK: vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]], %[[C0]]], %[[CF0]] : memref, vector<16x14x7x8xf32> // CHECK: vector.transpose %{{.*}}, [2, 1, 3, 0] : vector<16x14x7x8xf32> to vector<7x14x8x16xf32> - return %0, %1, %2, %3, %4, %5 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, + %6 = vector.transfer_read %arg0[%c0, %c0], %cst {permutation_map = #map6} : memref, vector<8xf32> +// CHECK: memref.load %{{.*}}[%[[C0]], %[[C0]]] : memref +// CHECK: vector.broadcast %{{.*}} : f32 to vector<8xf32> + + return %0, %1, %2, %3, %4, %5, %6 : vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, vector<7x14x8x16xf32>, - vector<7x14x8x16xf32> + vector<7x14x8x16xf32>, vector<8xf32> } // -----