Reviewer summary (text before the first "diff --git" line is ignored by patch tools).

When reductions are lowered on the outer dimensions (!useInnerDimsForReduction),
the rewrite now bails out only if the parallel dims already occupy the trailing
positions [reductionDims.size(), reductionDims.size() + parallelDims.size()).
The previous check bailed out whenever the parallel dims were NOT the leading
positions, so a reduction over dims [0, 2] of a 3-D vector never got its
transpose. A "parallel dim in the middle" test is added for both lowerings.

NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Indentation of context lines is a best effort, so apply it with a
whitespace-tolerant option (e.g. `git apply --ignore-whitespace`, `patch -l`).
Template arguments that had been stripped (`llvm::seq<int64_t>`,
`SmallVector<int64_t, 4>`) and the `<add>` combining kind were restored;
confirm them against the tree before committing.

diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
@@ -77,8 +77,9 @@
       return failure();
 
     if (!useInnerDimsForReduction &&
-        (parallelDims !=
-         llvm::to_vector<4>(llvm::seq<int64_t>(0, parallelDims.size()))))
+        (parallelDims == llvm::to_vector<4>(llvm::seq<int64_t>(
+                             reductionDims.size(),
+                             parallelDims.size() + reductionDims.size()))))
       return failure();
 
     SmallVector<int64_t, 4> indices;
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -234,3 +234,13 @@
 // CHECK: %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction
 // CHECK: %[[VAL_160:.*]] = vector.insertelement %[[VAL_159]]
 
+// -----
+
+func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
+    %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
+// CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
+// CHECK: vector.transpose %[[INPUT]], [1, 0, 2] : vector<3x4x5xf32> to vector<4x3x5xf32>
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir
@@ -162,6 +162,15 @@
 // CHECK: %[[RESULT_VEC:.+]] = vector.shape_cast %[[R18]] : vector<6xi32> to vector<2x3xi32>
 // CHECK: return %[[RESULT_VEC]] : vector<2x3xi32>
 
+func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc: vector<4xf32>) -> vector<4xf32> {
+    %0 = vector.multi_reduction <add>, %arg0, %acc [0, 2] : vector<3x4x5xf32> to vector<4xf32>
+    return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_parallel_middle
+// CHECK-SAME: %[[INPUT:.+]]: vector<3x4x5xf32>, %[[ACC:.+]]: vector<4xf32>
+// CHECK: vector.transpose %[[INPUT]], [0, 2, 1] : vector<3x4x5xf32> to vector<3x5x4xf32>
+
 // This test is mainly to catch a bug that running
 // `InnerOuterDimReductionConversion` on this function results in an
 // infinite loop. So just check that some value is returned.