diff --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp
@@ -41,9 +41,6 @@
     auto src = multiReductionOp.source();
     auto loc = multiReductionOp.getLoc();
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
-    // If the rank is less than or equal to 1, there is nothing to do.
-    if (srcRank <= 1)
-      return failure();
 
     // Separate reduction and parallel dims
     auto reductionDimsRange =
@@ -59,6 +56,11 @@
         parallelDims.push_back(i);
 
+    // With no parallel dims (e.g. a full reduction to a scalar) there is
+    // nothing to transpose.
+    if (parallelDims.empty())
+      return failure();
+
     // Add transpose only if inner-most/outer-most dimensions are not parallel
     if (useInnerDimsForReduction &&
         (parallelDims ==
          llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -163,3 +163,10 @@
 }
 // CHECK-LABEL: func @vector_reduction_1D
 // CHECK: return %{{.+}}
+
+func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>) -> f32 {
+  %0 = vector.multi_reduction <add>, %arg0 [0, 1] : vector<2x3xf32> to f32
+  return %0 : f32
+}
+// CHECK-LABEL: func @vector_multi_reduction_to_scalar
+// CHECK: return %{{.+}}