diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -37,6 +37,7 @@ #include "mlir/IR/Types.h" #include "mlir/Interfaces/VectorInterfaces.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -3915,42 +3916,33 @@ auto loc = multiReductionOp.getLoc(); auto srcRank = multiReductionOp.getSourceVectorType().getRank(); - auto reductionDims = llvm::to_vector<4>( - llvm::map_range(multiReductionOp.reduction_dims().cast(), - [](Attribute attr) -> int64_t { - return attr.cast().getInt(); - })); - llvm::sort(reductionDims); - - int64_t reductionSize = multiReductionOp.reduction_dims().size(); - - // Fails if already inner most reduction. - bool innerMostReduction = true; - for (int i = 0; i < reductionSize; ++i) { - if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) { - innerMostReduction = false; - } + // Separate reduction and parallel dims + auto reductionDimsRange = + multiReductionOp.reduction_dims().getAsValueRange(); + auto reductionDims = llvm::to_vector<4>(llvm::map_range( + reductionDimsRange, [](APInt a) { return a.getZExtValue(); })); + llvm::SmallDenseSet reductionDimsSet(reductionDims.begin(), + reductionDims.end()); + int64_t reductionSize = reductionDims.size(); + SmallVector parallelDims; + for (int64_t i = 0; i < srcRank; i++) { + if (!reductionDimsSet.contains(i)) + parallelDims.push_back(i); } - if (innerMostReduction) - return failure(); - // Permutes the indices so reduction dims are inner most dims. - SmallVector indices; - for (int i = 0; i < srcRank; ++i) { - indices.push_back(i); - } - int ir = reductionSize - 1; - int id = srcRank - 1; - while (ir >= 0) { - std::swap(indices[reductionDims[ir--]], indices[id--]); - } + // Add transpose only if inner-most dimensions are not reductions + if (parallelDims == + llvm::to_vector<4>(llvm::seq(0, parallelDims.size()))) + return failure(); - // Sets inner most dims as reduction. + SmallVector indices; + indices.append(parallelDims.begin(), parallelDims.end()); + indices.append(reductionDims.begin(), reductionDims.end()); + auto transposeOp = rewriter.create(loc, src, indices); SmallVector reductionMask(srcRank, false); for (int i = 0; i < reductionSize; ++i) { reductionMask[srcRank - i - 1] = true; } - auto transposeOp = rewriter.create(loc, src, indices); rewriter.replaceOpWithNewOp( multiReductionOp, transposeOp.result(), reductionMask, multiReductionOp.kind()); diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -61,6 +61,49 @@ // CHECK-LABEL: func @vector_multi_reduction_transposed // CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xf32> // CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32> -// CHEKC: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32> +// CHECK: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32> // CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32> // CHECK: return %[[RESULT]] + +func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> { + %0 = vector.multi_reduction #vector.kind, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32> + return %0 : vector<2x4xf32> +} +// CHECK-LABEL: func @vector_multi_reduction_ordering +// CHECK-SAME: %[[INPUT:.+]]: vector<3x2x4xf32> +// CHECK: %[[RESULT_VEC_0:.+]] = constant dense<{{.*}}> : vector<8xf32> +// CHECK: %[[C0:.+]] = constant 0 : i32 +// CHECK: %[[C1:.+]] = constant 1 : i32 +// CHECK: %[[C2:.+]] = constant 2 : i32 +// CHECK: %[[C3:.+]] = constant 3 : i32 +// CHECK: %[[C4:.+]] = constant 4 : i32 +// CHECK: %[[C5:.+]] = constant 5 : i32 +// CHECK: %[[C6:.+]] = constant 6 : i32 +// CHECK: %[[C7:.+]] = constant 7 : i32 +// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 2, 0] : vector<3x2x4xf32> to vector<2x4x3xf32> +// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0] +// CHECK: %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0:.+]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<8xf32> +// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1] +// CHECK: %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<8xf32> +// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 2] +// CHECK: %[[RV2:.+]] = vector.reduction "mul", %[[V2]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2:.+]], %[[RESULT_VEC_2]][%[[C2]] : i32] : vector<8xf32> +// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 3] +// CHECK: %[[RV3:.+]] = vector.reduction "mul", %[[V3]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3:.+]], %[[RESULT_VEC_3]][%[[C3]] : i32] : vector<8xf32> +// CHECK: %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0] +// CHECK: %[[RV4:.+]] = vector.reduction "mul", %[[V4]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4:.+]], %[[RESULT_VEC_4]][%[[C4]] : i32] : vector<8xf32> +// CHECK: %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1] +// CHECK: %[[RV5:.+]] = vector.reduction "mul", %[[V5]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5:.+]], %[[RESULT_VEC_5]][%[[C5]] : i32] : vector<8xf32> +// CHECK: %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 2] +// CHECK: %[[RV6:.+]] = vector.reduction "mul", %[[V6]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6:.+]], %[[RESULT_VEC_6]][%[[C6]] : i32] : vector<8xf32> +// CHECK: %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 3] +// CHECK: %[[RV7:.+]] = vector.reduction "mul", %[[V7]] : vector<3xf32> into f32 +// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : i32] : vector<8xf32> +// CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32> +// CHECK: return %[[RESHAPED_VEC]]