diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include <numeric>
 #include <type_traits>
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -3951,9 +3952,30 @@
       reductionMask[srcRank - i - 1] = true;
     }
     auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
-    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
-        multiReductionOp, transposeOp.result(), reductionMask,
-        multiReductionOp.kind());
+    auto newMultiReductionOp = rewriter.create<vector::MultiDimReductionOp>(
+        loc, transposeOp.result(), reductionMask, multiReductionOp.kind());
+    // The leading `srcRank - reductionSize` entries of `indices` are the
+    // parallel dims, and they are not necessarily in ascending order. When
+    // they are not, the reduced vector's dims are a permutation of the
+    // original result's dims, so transpose them back into the original order.
+    // NOTE(review): emitting the parallel dims in ascending order where
+    // `indices` is built would make this extra transpose unnecessary.
+    if (!std::is_sorted(indices.begin(), indices.end() - reductionSize)) {
+      // restoreIndices is the argsort of the parallel dims: entry k is the
+      // position, in the reduced vector, of the k-th smallest original parallel
+      // dim, which is the permutation vector.transpose needs to restore order.
+      SmallVector<int64_t> restoreIndices(srcRank - reductionSize);
+      std::iota(restoreIndices.begin(), restoreIndices.end(), 0);
+      std::sort(restoreIndices.begin(), restoreIndices.end(),
+                [&](int64_t lhs, int64_t rhs) {
+                  return indices[lhs] < indices[rhs];
+                });
+      auto restoreTransposeOp = rewriter.create<vector::TransposeOp>(
+          loc, newMultiReductionOp.getResult(), restoreIndices);
+      rewriter.replaceOp(multiReductionOp, restoreTransposeOp.result());
+    } else {
+      rewriter.replaceOp(multiReductionOp, newMultiReductionOp.getResult());
+    }
     return success();
   }
 };
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -61,6 +61,50 @@
 // CHECK-LABEL: func @vector_multi_reduction_transposed
 // CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xf32>
 // CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
-// CHEKC: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
+// CHECK: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
 // CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
 // CHECK: return %[[RESULT]]
+
+func @vector_multi_reduction_additional_transpose(%arg0: vector<3x2x4xf32>) -> vector<2x4xf32> {
+    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [0] : vector<3x2x4xf32> to vector<2x4xf32>
+    return %0 : vector<2x4xf32>
+}
+// CHECK-LABEL: func @vector_multi_reduction_additional_transpose
+// CHECK-SAME: %[[INPUT:.+]]: vector<3x2x4xf32>
+// CHECK: %[[RESULT_VEC_0:.+]] = constant dense<{{.*}}> : vector<8xf32>
+// CHECK: %[[C0:.+]] = constant 0 : i32
+// CHECK: %[[C1:.+]] = constant 1 : i32
+// CHECK: %[[C2:.+]] = constant 2 : i32
+// CHECK: %[[C3:.+]] = constant 3 : i32
+// CHECK: %[[C4:.+]] = constant 4 : i32
+// CHECK: %[[C5:.+]] = constant 5 : i32
+// CHECK: %[[C6:.+]] = constant 6 : i32
+// CHECK: %[[C7:.+]] = constant 7 : i32
+// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [2, 1, 0] : vector<3x2x4xf32> to vector<4x2x3xf32>
+// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 0]
+// CHECK: %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<3xf32> into f32
+// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<8xf32>
+// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED_INPUT]][0, 1]
+// CHECK: %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<3xf32> into f32
+// CHECK: %[[RESULT_VEC_2:.+]] = vector.insertelement %[[RV1]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<8xf32>
+// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 0]
+// CHECK: %[[RV2:.+]] = vector.reduction "mul", %[[V2]] : vector<3xf32> into f32
+// CHECK: %[[RESULT_VEC_3:.+]] = vector.insertelement %[[RV2]], %[[RESULT_VEC_2]][%[[C2]] : i32] : vector<8xf32>
+// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED_INPUT]][1, 1]
+// CHECK: %[[RV3:.+]] = vector.reduction "mul", %[[V3]] : vector<3xf32> into f32
+// CHECK: %[[RESULT_VEC_4:.+]] = vector.insertelement %[[RV3]], %[[RESULT_VEC_3]][%[[C3]] : i32] : vector<8xf32>
+// CHECK: %[[V4:.+]] = vector.extract %[[TRANSPOSED_INPUT]][2, 0]
+// CHECK: %[[RV4:.+]] = vector.reduction "mul", %[[V4]] : vector<3xf32> into f32
+// CHECK: %[[RESULT_VEC_5:.+]] = vector.insertelement %[[RV4]], %[[RESULT_VEC_4]][%[[C4]] : i32] : vector<8xf32>
+// CHECK: %[[V5:.+]] = vector.extract %[[TRANSPOSED_INPUT]][2, 1]
+// CHECK: %[[RV5:.+]] = vector.reduction "mul", %[[V5]] : vector<3xf32> into f32
+// CHECK: %[[RESULT_VEC_6:.+]] = vector.insertelement %[[RV5]], %[[RESULT_VEC_5]][%[[C5]] : i32] : vector<8xf32>
+// CHECK: %[[V6:.+]] = vector.extract %[[TRANSPOSED_INPUT]][3, 0]
+// CHECK: %[[RV6:.+]] = vector.reduction "mul", %[[V6]] : vector<3xf32> into f32
+// CHECK: %[[RESULT_VEC_7:.+]] = vector.insertelement %[[RV6]], %[[RESULT_VEC_6]][%[[C6]] : i32] : vector<8xf32>
+// CHECK: %[[V7:.+]] = vector.extract %[[TRANSPOSED_INPUT]][3, 1]
+// CHECK: %[[RV7:.+]] = vector.reduction "mul", %[[V7]] : vector<3xf32> into f32
+// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7]], %[[RESULT_VEC_7]][%[[C7]] : i32] : vector<8xf32>
+// CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<4x2xf32>
+// CHECK: %[[TRANSPOSED_VEC:.+]] = vector.transpose %[[RESHAPED_VEC]], [1, 0] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK: return %[[TRANSPOSED_VEC]]