diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -81,7 +81,8 @@ // Collect a set of patterns to convert vector.multi_reduction op into // a sequence of vector.reduction ops. -void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns); +void populateVectorMultiReductionLoweringPatterns( + RewritePatternSet &patterns, bool useInnerDimsForReduction = false); /// Collect a set of patterns to propagate insert_map/extract_map in the ssa /// chain. diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -3490,12 +3490,18 @@ const bool enableIndexOptimizations; }; -// Converts vector.multi_reduction into inner-most reduction form by inserting -// vector.transpose -struct InnerDimReductionConversion +// Converts vector.multi_reduction into inner-most/outer-most reduction form +// by using vector.transpose +class InnerOuterDimReductionConversion : public OpRewritePattern { +public: using OpRewritePattern::OpRewritePattern; + explicit InnerOuterDimReductionConversion(MLIRContext *context, + bool useInnerDimsForReduction) + : mlir::OpRewritePattern(context), + useInnerDimsForReduction(useInnerDimsForReduction) {} + LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { auto src = multiReductionOp.source(); @@ -3516,87 +3522,116 @@ parallelDims.push_back(i); } - // Add transpose only if inner-most dimensions are not reductions - if (parallelDims == - llvm::to_vector<4>(llvm::seq(0, parallelDims.size()))) + // Add transpose only if not already in inner-most/outer-most reduction form + if (useInnerDimsForReduction && + (parallelDims == + llvm::to_vector<4>(llvm::seq(0, 
parallelDims.size())))) + return failure(); + + if (!useInnerDimsForReduction && + (parallelDims == + llvm::to_vector<4>(llvm::seq<int64_t>(reductionSize, srcRank)))) return failure(); SmallVector indices; - indices.append(parallelDims.begin(), parallelDims.end()); - indices.append(reductionDims.begin(), reductionDims.end()); + if (useInnerDimsForReduction) { + indices.append(parallelDims.begin(), parallelDims.end()); + indices.append(reductionDims.begin(), reductionDims.end()); + } else { + indices.append(reductionDims.begin(), reductionDims.end()); + indices.append(parallelDims.begin(), parallelDims.end()); + } auto transposeOp = rewriter.create(loc, src, indices); SmallVector reductionMask(srcRank, false); for (int i = 0; i < reductionSize; ++i) { - reductionMask[srcRank - i - 1] = true; + if (useInnerDimsForReduction) + reductionMask[srcRank - i - 1] = true; + else + reductionMask[i] = true; } rewriter.replaceOpWithNewOp( multiReductionOp, transposeOp.result(), reductionMask, multiReductionOp.kind()); return success(); } + +private: + const bool useInnerDimsForReduction; }; // Reduces the rank of vector.mult_reduction nd -> 2d given all reduction -// dimensions are inner most. -struct ReduceMultiDimReductionRank +// dimensions are either inner most or outer most. 
+class ReduceMultiDimReductionRank : public OpRewritePattern { +public: using OpRewritePattern::OpRewritePattern; + explicit ReduceMultiDimReductionRank(MLIRContext *context, + bool useInnerDimsForReduction) + : mlir::OpRewritePattern(context), + useInnerDimsForReduction(useInnerDimsForReduction) {} + LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { auto srcRank = multiReductionOp.getSourceVectorType().getRank(); auto srcShape = multiReductionOp.getSourceVectorType().getShape(); + auto loc = multiReductionOp.getLoc(); if (srcRank == 2) return failure(); - auto loc = multiReductionOp.getLoc(); - auto reductionDims = llvm::to_vector<4>( - llvm::map_range(multiReductionOp.reduction_dims().cast(), - [](Attribute attr) -> int64_t { - return attr.cast().getInt(); - })); - llvm::sort(reductionDims); - - // Fails if not inner most reduction. - int64_t reductionSize = reductionDims.size(); - bool innerMostReduction = true; - for (int i = 0; i < reductionSize; ++i) { - if (reductionDims[reductionSize - i - 1] != srcRank - i - 1) { - innerMostReduction = false; + // Separate reduction and parallel dims + auto reductionDimsRange = + multiReductionOp.reduction_dims().getAsValueRange(); + auto reductionDims = llvm::to_vector<4>(llvm::map_range( + reductionDimsRange, [](APInt a) { return a.getZExtValue(); })); + llvm::SmallDenseSet reductionDimsSet(reductionDims.begin(), + reductionDims.end()); + SmallVector parallelDims, parallelShapes; + int64_t canonicalReductionDim = 1; + int64_t canonicalParallelDim = 1; + for (int64_t i = 0; i < srcRank; i++) { + if (!reductionDimsSet.contains(i)) { + parallelDims.push_back(i); + parallelShapes.push_back(srcShape[i]); + canonicalParallelDim *= srcShape[i]; + } else { + canonicalReductionDim *= srcShape[i]; } } - if (!innerMostReduction) + + // Fail if reduction dims are not either inner-most or outer-most + if (useInnerDimsForReduction && + (parallelDims != + 
llvm::to_vector<4>(llvm::seq(0, parallelDims.size())))) return failure(); - // Extracts 2d rank reduction shape. - int innerDims = 1; - int outterDims = 1; - SmallVector innerDimsShape; - for (int i = 0; i < srcRank; ++i) { - if (i < (srcRank - reductionSize)) { - innerDims *= srcShape[i]; - innerDimsShape.push_back(srcShape[i]); - } else { - outterDims *= srcShape[i]; - } - } + if (!useInnerDimsForReduction && + (parallelDims.empty() || + parallelDims != llvm::to_vector<4>(llvm::seq<int64_t>(reductionDims.size(), srcRank)))) + return failure(); // Creates shape cast for the inputs n_d -> 2d + int64_t outerDim = + useInnerDimsForReduction ? canonicalParallelDim : canonicalReductionDim; + int64_t innerDim = + useInnerDimsForReduction ? canonicalReductionDim : canonicalParallelDim; + auto castedType = VectorType::get( - {innerDims, outterDims}, + ArrayRef{outerDim, innerDim}, multiReductionOp.getSourceVectorType().getElementType()); auto castedOp = rewriter.create( loc, castedType, multiReductionOp.source()); - // Creates the canonical form of 2d vector.multi_reduction with inner most - // dim as reduction. + // Creates the canonical form of 2d vector.multi_reduction with inner/outer + // most dim as reduction. 
+ SmallVector mask{!useInnerDimsForReduction, + useInnerDimsForReduction}; auto newOp = rewriter.create( - loc, castedOp.result(), ArrayRef{false, true}, - multiReductionOp.kind()); + loc, castedOp.result(), mask, multiReductionOp.kind()); // Creates shape cast for the output 2d -> nd - auto outputCastedType = VectorType::get( - innerDimsShape, + VectorType outputCastedType = VectorType::get( + parallelShapes, multiReductionOp.getSourceVectorType().getElementType()); Value castedOutputOp = rewriter.create( loc, outputCastedType, newOp.dest()); @@ -3604,6 +3639,88 @@ rewriter.replaceOp(multiReductionOp, castedOutputOp); return success(); } + +private: + const bool useInnerDimsForReduction; +}; + +// Unrolls vector.multi_reduction with outermost reductions +// and combines results +struct UnrollOuterMultiReduction + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, + PatternRewriter &rewriter) const override { + auto srcRank = multiReductionOp.getSourceVectorType().getRank(); + if (srcRank != 2) + return failure(); + + if (multiReductionOp.getReductionMask()[1] || + !multiReductionOp.getReductionMask()[0]) + return failure(); + + auto loc = multiReductionOp.getLoc(); + ArrayRef srcShape = + multiReductionOp.getSourceVectorType().getShape(); + + Type elementType = multiReductionOp.getDestVectorType().getElementType(); + if (!elementType.isIntOrIndexOrFloat()) + return failure(); + + Value condition; + Value result = + rewriter.create(loc, multiReductionOp.source(), 0) + .getResult(); + for (int64_t i = 1; i < srcShape[0]; i++) { + auto operand = + rewriter.create(loc, multiReductionOp.source(), i); + switch (multiReductionOp.kind()) { + case vector::CombiningKind::ADD: + if (elementType.isIntOrIndex()) + result = rewriter.create(loc, operand, result); + else + result = rewriter.create(loc, operand, result); + break; + case vector::CombiningKind::MUL: + if 
(elementType.isIntOrIndex()) + result = rewriter.create(loc, operand, result); + else + result = rewriter.create(loc, operand, result); + break; + case vector::CombiningKind::MIN: + if (elementType.isIntOrIndex()) + condition = + rewriter.create(loc, CmpIPredicate::slt, operand, result); + else + condition = + rewriter.create(loc, CmpFPredicate::OLT, operand, result); + result = rewriter.create(loc, condition, operand, result); + break; + case vector::CombiningKind::MAX: + if (elementType.isIntOrIndex()) + condition = + rewriter.create(loc, CmpIPredicate::sge, operand, result); + else + condition = + rewriter.create(loc, CmpFPredicate::OGE, operand, result); + result = rewriter.create(loc, condition, operand, result); + break; + case vector::CombiningKind::AND: + result = rewriter.create(loc, operand, result); + break; + case vector::CombiningKind::OR: + result = rewriter.create(loc, operand, result); + break; + case vector::CombiningKind::XOR: + result = rewriter.create(loc, operand, result); + break; + } + } + + rewriter.replaceOp(multiReductionOp, result); + return success(); + } }; // Converts 2d vector.multi_reduction with inner most reduction dimension into a @@ -3747,9 +3864,13 @@ } void mlir::vector::populateVectorMultiReductionLoweringPatterns( - RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + RewritePatternSet &patterns, bool useInnerDimsForReduction) { + patterns.add( + patterns.getContext(), useInnerDimsForReduction); + if (useInnerDimsForReduction) + patterns.add(patterns.getContext()); + else + patterns.add(patterns.getContext()); } void mlir::vector::populateVectorUnrollPatterns( diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -0,0 +1,161 @@ +// RUN: mlir-opt %s
-test-vector-multi-reduction-lowering-patterns="use-outer-reductions" | FileCheck %s + +func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> { + %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func @vector_multi_reduction +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32> +// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32> +// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32> +// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32> +// CHECK: %[[RV01:.+]] = mulf %[[V1]], %[[V0]] : vector<2xf32> +// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32> +// CHECK: %[[RV012:.+]] = mulf %[[V2]], %[[RV01]] : vector<2xf32> +// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32> +// CHECK: %[[RESULT_VEC:.+]] = mulf %[[V3]], %[[RV012]] : vector<2xf32> +// CHECK: return %[[RESULT_VEC]] : vector<2xf32> + +func @vector_multi_reduction_min(%arg0: vector<2x4xf32>) -> vector<2xf32> { + %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func @vector_multi_reduction_min +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32> +// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32> +// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32> +// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32> +// CHECK: %[[C0:.+]] = cmpf olt, %[[V1]], %[[V0]] : vector<2xf32> +// CHECK: %[[RV01:.+]] = select %[[C0]], %[[V1]], %[[V0]] : vector<2xi1>, vector<2xf32> +// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32> +// CHECK: %[[C1:.+]] = cmpf olt, %[[V2]], %[[RV01]] : vector<2xf32> +// CHECK: %[[RV012:.+]] = select %[[C1]], %[[V2]], %[[RV01]] : vector<2xi1>, vector<2xf32> +// CHECK: 
%[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32> +// CHECK: %[[C2:.+]] = cmpf olt, %[[V3]], %[[RV012]] : vector<2xf32> +// CHECK: %[[RESULT_VEC:.+]] = select %[[C2]], %[[V3]], %[[RV012]] : vector<2xi1>, vector<2xf32> +// CHECK: return %[[RESULT_VEC]] : vector<2xf32> + +func @vector_multi_reduction_max(%arg0: vector<2x4xf32>) -> vector<2xf32> { + %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xf32> to vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func @vector_multi_reduction_max +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32> +// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xf32> to vector<4x2xf32> +// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xf32> +// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xf32> +// CHECK: %[[C0:.+]] = cmpf oge, %[[V1]], %[[V0]] : vector<2xf32> +// CHECK: %[[RV01:.+]] = select %[[C0]], %[[V1]], %[[V0]] : vector<2xi1>, vector<2xf32> +// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xf32> +// CHECK: %[[C1:.+]] = cmpf oge, %[[V2]], %[[RV01]] : vector<2xf32> +// CHECK: %[[RV012:.+]] = select %[[C1]], %[[V2]], %[[RV01]] : vector<2xi1>, vector<2xf32> +// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xf32> +// CHECK: %[[C2:.+]] = cmpf oge, %[[V3]], %[[RV012]] : vector<2xf32> +// CHECK: %[[RESULT_VEC:.+]] = select %[[C2]], %[[V3]], %[[RV012]] : vector<2xi1>, vector<2xf32> +// CHECK: return %[[RESULT_VEC]] : vector<2xf32> + +func @vector_multi_reduction_and(%arg0: vector<2x4xi32>) -> vector<2xi32> { + %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xi32> to vector<2xi32> + return %0 : vector<2xi32> +} + +// CHECK-LABEL: func @vector_multi_reduction_and +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32> +// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32> +// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : 
vector<4x2xi32> +// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32> +// CHECK: %[[RV01:.+]] = and %[[V1]], %[[V0]] : vector<2xi32> +// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32> +// CHECK: %[[RV012:.+]] = and %[[V2]], %[[RV01]] : vector<2xi32> +// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32> +// CHECK: %[[RESULT_VEC:.+]] = and %[[V3]], %[[RV012]] : vector<2xi32> +// CHECK: return %[[RESULT_VEC]] : vector<2xi32> + +func @vector_multi_reduction_or(%arg0: vector<2x4xi32>) -> vector<2xi32> { + %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xi32> to vector<2xi32> + return %0 : vector<2xi32> +} + +// CHECK-LABEL: func @vector_multi_reduction_or +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32> +// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32> +// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32> +// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : vector<4x2xi32> +// CHECK: %[[RV01:.+]] = or %[[V1]], %[[V0]] : vector<2xi32> +// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32> +// CHECK: %[[RV012:.+]] = or %[[V2]], %[[RV01]] : vector<2xi32> +// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32> +// CHECK: %[[RESULT_VEC:.+]] = or %[[V3]], %[[RV012]] : vector<2xi32> +// CHECK: return %[[RESULT_VEC]] : vector<2xi32> + +func @vector_multi_reduction_xor(%arg0: vector<2x4xi32>) -> vector<2xi32> { + %0 = vector.multi_reduction #vector.kind, %arg0 [1] : vector<2x4xi32> to vector<2xi32> + return %0 : vector<2xi32> +} + +// CHECK-LABEL: func @vector_multi_reduction_xor +// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xi32> +// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [1, 0] : vector<2x4xi32> to vector<4x2xi32> +// CHECK: %[[V0:.+]] = vector.extract %[[TRANSPOSED]][0] : vector<4x2xi32> +// CHECK: %[[V1:.+]] = vector.extract %[[TRANSPOSED]][1] : 
vector<4x2xi32> +// CHECK: %[[RV01:.+]] = xor %[[V1]], %[[V0]] : vector<2xi32> +// CHECK: %[[V2:.+]] = vector.extract %[[TRANSPOSED]][2] : vector<4x2xi32> +// CHECK: %[[RV012:.+]] = xor %[[V2]], %[[RV01]] : vector<2xi32> +// CHECK: %[[V3:.+]] = vector.extract %[[TRANSPOSED]][3] : vector<4x2xi32> +// CHECK: %[[RESULT_VEC:.+]] = xor %[[V3]], %[[RV012]] : vector<2xi32> +// CHECK: return %[[RESULT_VEC]] : vector<2xi32> + +// Reduction over the two inner-most dims [2, 3]: expected to be transposed to the outer-most position, flattened to 2-D with vector.shape_cast, then unrolled into an addi chain. +func @vector_reduction_outer(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> { + %0 = vector.multi_reduction #vector.kind, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> + return %0 : vector<2x3xi32> +} + +// CHECK-LABEL: func @vector_reduction_outer +// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32> +// CHECK: %[[TRANSPOSED:.+]] = vector.transpose %[[INPUT]], [2, 3, 0, 1] : vector<2x3x4x5xi32> to vector<4x5x2x3xi32> +// CHECK: %[[RESHAPED:.+]] = vector.shape_cast %[[TRANSPOSED]] : vector<4x5x2x3xi32> to vector<20x6xi32> +// CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED]][0] : vector<20x6xi32> +// CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED]][1] : vector<20x6xi32> +// CHECK: %[[R0:.+]] = addi %[[V1]], %[[V0]] : vector<6xi32> +// CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED]][2] : vector<20x6xi32> +// CHECK: %[[R1:.+]] = addi %[[V2]], %[[R0]] : vector<6xi32> +// CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED]][3] : vector<20x6xi32> +// CHECK: %[[R2:.+]] = addi %[[V3]], %[[R1]] : vector<6xi32> +// CHECK: %[[V4:.+]] = vector.extract %[[RESHAPED]][4] : vector<20x6xi32> +// CHECK: %[[R3:.+]] = addi %[[V4]], %[[R2]] : vector<6xi32> +// CHECK: %[[V5:.+]] = vector.extract %[[RESHAPED]][5] : vector<20x6xi32> +// CHECK: %[[R4:.+]] = addi %[[V5]], %[[R3]] : vector<6xi32> +// CHECK: %[[V6:.+]] = vector.extract %[[RESHAPED]][6] : vector<20x6xi32> +// CHECK: %[[R5:.+]] = addi 
%[[V6]], %[[R4]] : vector<6xi32> +// CHECK: %[[V7:.+]] = vector.extract %[[RESHAPED]][7] : vector<20x6xi32> +// CHECK: %[[R6:.+]] = addi %[[V7]], %[[R5]] : vector<6xi32> +// 
CHECK: %[[V8:.+]] = vector.extract %[[RESHAPED]][8] : vector<20x6xi32> +// CHECK: %[[R7:.+]] = addi %[[V8]], %[[R6]] : vector<6xi32> +// CHECK: %[[V9:.+]] = vector.extract %[[RESHAPED]][9] : vector<20x6xi32> +// CHECK: %[[R8:.+]] = addi %[[V9]], %[[R7]] : vector<6xi32> +// CHECK: %[[V10:.+]] = vector.extract %[[RESHAPED]][10] : vector<20x6xi32> +// CHECK: %[[R9:.+]] = addi %[[V10]], %[[R8]] : vector<6xi32> +// CHECK: %[[V11:.+]] = vector.extract %[[RESHAPED]][11] : vector<20x6xi32> +// CHECK: %[[R10:.+]] = addi %[[V11]], %[[R9]] : vector<6xi32> +// CHECK: %[[V12:.+]] = vector.extract %[[RESHAPED]][12] : vector<20x6xi32> +// CHECK: %[[R11:.+]] = addi %[[V12]], %[[R10]] : vector<6xi32> +// CHECK: %[[V13:.+]] = vector.extract %[[RESHAPED]][13] : vector<20x6xi32> +// CHECK: %[[R12:.+]] = addi %[[V13]], %[[R11]] : vector<6xi32> +// CHECK: %[[V14:.+]] = vector.extract %[[RESHAPED]][14] : vector<20x6xi32> +// CHECK: %[[R13:.+]] = addi %[[V14]], %[[R12]] : vector<6xi32> +// CHECK: %[[V15:.+]] = vector.extract %[[RESHAPED]][15] : vector<20x6xi32> +// CHECK: %[[R14:.+]] = addi %[[V15]], %[[R13]] : vector<6xi32> +// CHECK: %[[V16:.+]] = vector.extract %[[RESHAPED]][16] : vector<20x6xi32> +// CHECK: %[[R15:.+]] = addi %[[V16]], %[[R14]] : vector<6xi32> +// CHECK: %[[V17:.+]] = vector.extract %[[RESHAPED]][17] : vector<20x6xi32> +// CHECK: %[[R16:.+]] = addi %[[V17]], %[[R15]] : vector<6xi32> +// CHECK: %[[V18:.+]] = vector.extract %[[RESHAPED]][18] : vector<20x6xi32> +// CHECK: %[[R17:.+]] = addi %[[V18]], %[[R16]] : vector<6xi32> +// CHECK: %[[V19:.+]] = vector.extract %[[RESHAPED]][19] : vector<20x6xi32> +// CHECK: %[[R18:.+]] = addi %[[V19]], %[[R17]] : vector<6xi32> +// CHECK: %[[RESULT_VEC:.+]] = vector.shape_cast %[[R18]] : vector<6xi32> to vector<2x3xi32> +// CHECK: return %[[RESULT_VEC]] : vector<2x3xi32> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp ---
a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -444,6 +444,9 @@ struct TestVectorMultiReductionLoweringPatterns : public PassWrapper { + TestVectorMultiReductionLoweringPatterns() = default; + TestVectorMultiReductionLoweringPatterns( + const TestVectorMultiReductionLoweringPatterns &pass) : PassWrapper(pass) {} void getDependentDialects(DialectRegistry &registry) const override { registry.insert(); } @@ -454,9 +457,13 @@ return "Test conversion patterns to lower vector.multi_reduction to other " "vector ops"; } + Option useOuterReductions{ + *this, "use-outer-reductions", + llvm::cl::desc("Move reductions to outer most dimensions"), + llvm::cl::init(false)}; void runOnFunction() override { RewritePatternSet patterns(&getContext()); - populateVectorMultiReductionLoweringPatterns(patterns); + populateVectorMultiReductionLoweringPatterns(patterns, !useOuterReductions); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } };