diff --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp @@ -41,6 +41,9 @@ auto src = multiReductionOp.source(); auto loc = multiReductionOp.getLoc(); auto srcRank = multiReductionOp.getSourceVectorType().getRank(); + // With rank <= 1 there is at most one dimension to split into reduction/parallel dims, so there is nothing to do; returning failure() here also avoids the infinite rewrite loop exercised by the vector_reduction_1D test. + if (srcRank <= 1) + return failure(); // Separate reduction and parallel dims auto reductionDimsRange = diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -153,3 +153,13 @@ // CHECK: %[[R18:.+]] = arith.addi %[[V19]], %[[R17]] : vector<6xi32> // CHECK: %[[RESULT_VEC:.+]] = vector.shape_cast %[[R18]] : vector<6xi32> to vector<2x3xi32> // CHECK: return %[[RESULT_VEC]] : vector<2x3xi32> + +// Regression test: running `InnerOuterDimReductionConversion` on this +// 1-D reduction used to loop forever. The CHECK lines below only verify +// that the pass terminates and the function still returns a value. +func @vector_reduction_1D(%arg0 : vector<2xf32>) -> f32 { + %0 = vector.multi_reduction #vector.kind, %arg0 [0] : vector<2xf32> to f32 + return %0 : f32 +} +// CHECK-LABEL: func @vector_reduction_1D +// CHECK: return %{{.+}}