diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -92,6 +92,10 @@
 void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns,
                                                bool enableIndexOptimizations);
 
+/// Collect a set of patterns to convert vector.multi_reduction op into
+/// a sequence of vector.reduction ops.
+void populateVectorMultiReductionLoweringPatterns(RewritePatternSet &patterns);
+
 /// An attribute that specifies the combining function for `vector.contract`,
 /// and `vector.reduction`.
 class CombiningKindAttr
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -3575,5 +3575,193 @@
   const bool enableIndexOptimizations;
 };
 
+// Converts vector.multi_reduction into inner-most reduction form by inserting
+// vector.transpose.
+struct InnerDimReductionConversion
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto src = multiReductionOp.source();
+    auto loc = multiReductionOp.getLoc();
+    int64_t srcRank = multiReductionOp.getSourceVectorType().getRank();
+
+    // Splits the source dims into parallel and reduction dims, each list in
+    // increasing dim order.
+    auto srcReductionMask = multiReductionOp.getReductionMask();
+    SmallVector<int64_t> parallelDims;
+    SmallVector<int64_t> reductionDims;
+    for (int64_t i = 0; i < srcRank; ++i) {
+      if (srcReductionMask[i])
+        reductionDims.push_back(i);
+      else
+        parallelDims.push_back(i);
+    }
+
+    // Fails if already inner most reduction, i.e. no parallel dim follows a
+    // reduction dim.
+    if (parallelDims.empty() || reductionDims.empty() ||
+        parallelDims.back() < reductionDims.front())
+      return failure();
+
+    // Permutes the dims so reduction dims are inner most dims. The parallel
+    // dims keep their original relative order so that the rewritten op has
+    // the same result type as the original one.
+    SmallVector<int64_t> indices(parallelDims.begin(), parallelDims.end());
+    indices.append(reductionDims.begin(), reductionDims.end());
+
+    // Sets inner most dims as reduction.
+    SmallVector<bool> reductionMask(parallelDims.size(), false);
+    reductionMask.append(reductionDims.size(), true);
+
+    auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
+    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
+        multiReductionOp, transposeOp.result(), reductionMask,
+        multiReductionOp.kind());
+    return success();
+  }
+};
+
+// Reduces the rank of vector.multi_reduction nd -> 2d given all reduction
+// dimensions are inner most.
+struct ReduceMultiDimReductionRank
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType srcType = multiReductionOp.getSourceVectorType();
+    int64_t srcRank = srcType.getRank();
+    ArrayRef<int64_t> srcShape = srcType.getShape();
+    if (srcRank == 2)
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    auto reductionMask = multiReductionOp.getReductionMask();
+
+    // The dims before the first reduction dim are the parallel dims.
+    int64_t parallelSize = 0;
+    while (parallelSize < srcRank && !reductionMask[parallelSize])
+      ++parallelSize;
+
+    // Fails if not inner most reduction: every dim after the first reduction
+    // dim must be a reduction dim too.
+    for (int64_t i = parallelSize; i < srcRank; ++i)
+      if (!reductionMask[i])
+        return failure();
+
+    // Fails if every dim is reduced: the 2d [parallel, reduction] form below
+    // needs at least one parallel dim to build its result vector type.
+    if (parallelSize == 0)
+      return failure();
+
+    // Extracts the 2d shape: product of the parallel dims by product of the
+    // reduction dims. The products are accumulated in int64_t, the vector
+    // shape element type, so large shapes do not overflow an `int`.
+    int64_t parallelProduct = 1;
+    int64_t reductionProduct = 1;
+    for (int64_t i = 0; i < srcRank; ++i) {
+      if (i < parallelSize)
+        parallelProduct *= srcShape[i];
+      else
+        reductionProduct *= srcShape[i];
+    }
+    ArrayRef<int64_t> parallelShape = srcShape.take_front(parallelSize);
+
+    // Creates shape cast for the inputs n_d -> 2d
+    Type elementType = srcType.getElementType();
+    auto castedType =
+        VectorType::get({parallelProduct, reductionProduct}, elementType);
+    auto castedOp = rewriter.create<vector::ShapeCastOp>(
+        loc, castedType, multiReductionOp.source());
+
+    // Creates the canonical form of 2d vector.multi_reduction with inner most
+    // dim as reduction.
+    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+        loc, castedOp.result(), ArrayRef<bool>{false, true},
+        multiReductionOp.kind());
+
+    // Creates shape cast for the output 2d -> nd
+    auto outputCastedType = VectorType::get(parallelShape, elementType);
+    Value castedOutputOp = rewriter.create<vector::ShapeCastOp>(
+        loc, outputCastedType, newOp.dest());
+
+    rewriter.replaceOp(multiReductionOp, castedOutputOp);
+    return success();
+  }
+};
+
+// Converts 2d vector.multi_reduction with inner most reduction dimension into a
+// sequence of vector.reduction ops.
+struct TwoDimMultiReductionToReduction
+    : public OpRewritePattern<vector::MultiDimReductionOp> {
+  using OpRewritePattern<vector::MultiDimReductionOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
+                                PatternRewriter &rewriter) const override {
+    auto srcRank = multiReductionOp.getSourceVectorType().getRank();
+    if (srcRank != 2)
+      return failure();
+
+    auto reductionMask = multiReductionOp.getReductionMask();
+    if (reductionMask[0] || !reductionMask[1])
+      return failure();
+
+    auto loc = multiReductionOp.getLoc();
+    VectorType destType = multiReductionOp.getDestVectorType();
+
+    // Zero-initialized accumulator that the per-row results are inserted
+    // into. getZeroAttr builds the zero splat for the destination's element
+    // type, whatever its integer, index or float width.
+    Value result = rewriter.create<ConstantOp>(loc, destType,
+                                               rewriter.getZeroAttr(destType));
+
+    int64_t outerDim = multiReductionOp.getSourceVectorType().getShape()[0];
+
+    // TODO: Add vector::CombiningKind attribute instead of string to
+    // vector.reduction.
+    auto getKindStr = [](vector::CombiningKind kind) {
+      switch (kind) {
+      case vector::CombiningKind::ADD:
+        return "add";
+      case vector::CombiningKind::MUL:
+        return "mul";
+      case vector::CombiningKind::MIN:
+        return "min";
+      case vector::CombiningKind::MAX:
+        return "max";
+      case vector::CombiningKind::AND:
+        return "and";
+      case vector::CombiningKind::OR:
+        return "or";
+      case vector::CombiningKind::XOR:
+        return "xor";
+      }
+      llvm_unreachable("unknown vector::CombiningKind");
+    };
+    // The kind attribute is the same for every row; build it once.
+    StringAttr kindAttr =
+        rewriter.getStringAttr(getKindStr(multiReductionOp.kind()));
+
+    for (int64_t i = 0; i < outerDim; ++i) {
+      auto v = rewriter.create<vector::ExtractOp>(
+          loc, multiReductionOp.source(), ArrayRef<int64_t>{i});
+      auto reducedValue = rewriter.create<vector::ReductionOp>(
+          loc, destType.getElementType(), kindAttr, v, ValueRange{});
+      result = rewriter.create<vector::InsertElementOp>(loc, reducedValue,
+                                                        result, i);
+    }
+    rewriter.replaceOp(multiReductionOp, result);
+    return success();
+  }
+};
+
+void mlir::vector::populateVectorMultiReductionLoweringPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<InnerDimReductionConversion, ReduceMultiDimReductionRank,
+               TwoDimMultiReductionToReduction>(patterns.getContext());
+}
+
 void mlir::vector::populateVectorMaskMaterializationPatterns(
     RewritePatternSet &patterns, bool enableIndexOptimizations) {
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -0,0 +1,80 @@
+// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
+
+func @vector_multi_reduction(%arg0: vector<2x4xf32>) -> vector<2xf32> {
+    %0 = vector.multi_reduction #vector.kind<mul>, %arg0 [1] : vector<2x4xf32> to vector<2xf32>
+    return %0 : vector<2xf32>
+}
+// CHECK-LABEL: func @vector_multi_reduction
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x4xf32>
+// CHECK: %[[RESULT_VEC_0:.+]] = constant dense<{{.*}}> : vector<2xf32>
+// CHECK: %[[C0:.+]] = constant 0 : i32
+// CHECK: %[[C1:.+]] = constant 1 : i32
+// CHECK: %[[V0:.+]] = vector.extract %[[INPUT]][0]
+// CHECK: %[[RV0:.+]] = vector.reduction "mul", %[[V0]] : vector<4xf32> into f32
+// CHECK: %[[RESULT_VEC_1:.+]] = vector.insertelement %[[RV0]], %[[RESULT_VEC_0]][%[[C0]] : i32] : vector<2xf32>
+// CHECK: %[[V1:.+]] = vector.extract %[[INPUT]][1]
+// CHECK: %[[RV1:.+]] = vector.reduction "mul", %[[V1]] : vector<4xf32> into f32
+// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1]], %[[RESULT_VEC_1]][%[[C1]] : i32] : vector<2xf32>
+// CHECK: return %[[RESULT_VEC]]
+
+func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>) -> vector<2x3xi32> {
+    %0 = vector.multi_reduction #vector.kind<add>, %arg0 [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
+    return %0 : vector<2x3xi32>
+}
+// CHECK-LABEL: func @vector_reduction_inner
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xi32>
+// CHECK: %[[FLAT_RESULT_VEC_0:.+]] = constant dense<0> : vector<6xi32>
+// CHECK-DAG: %[[C0:.+]] = constant 0 : i32
+// CHECK-DAG: %[[C1:.+]] = constant 1 : i32
+// CHECK-DAG: %[[C2:.+]] = constant 2 : i32
+// CHECK-DAG: %[[C3:.+]] = constant 3 : i32
+// CHECK-DAG: %[[C4:.+]] = constant 4 : i32
+// CHECK-DAG: %[[C5:.+]] = constant 5 : i32
+// CHECK: %[[RESHAPED_INPUT:.+]] = vector.shape_cast %[[INPUT]] : vector<2x3x4x5xi32> to vector<6x20xi32>
+// CHECK: %[[V0:.+]] = vector.extract %[[RESHAPED_INPUT]][0] : vector<6x20xi32>
+// CHECK: %[[V0R:.+]] = vector.reduction "add", %[[V0]] : vector<20xi32> into i32
+// CHECK: %[[FLAT_RESULT_VEC_1:.+]] = vector.insertelement %[[V0R]], %[[FLAT_RESULT_VEC_0]][%[[C0]] : i32] : vector<6xi32>
+// CHECK: %[[V1:.+]] = vector.extract %[[RESHAPED_INPUT]][1] : vector<6x20xi32>
+// CHECK: %[[V1R:.+]] = vector.reduction "add", %[[V1]] : vector<20xi32> into i32
+// CHECK: %[[FLAT_RESULT_VEC_2:.+]] = vector.insertelement %[[V1R]], %[[FLAT_RESULT_VEC_1]][%[[C1]] : i32] : vector<6xi32>
+// CHECK: %[[V2:.+]] = vector.extract %[[RESHAPED_INPUT]][2] : vector<6x20xi32>
+// CHECK: %[[V2R:.+]] = vector.reduction "add", %[[V2]] : vector<20xi32> into i32
+// CHECK: %[[FLAT_RESULT_VEC_3:.+]] = vector.insertelement %[[V2R]], %[[FLAT_RESULT_VEC_2]][%[[C2]] : i32] : vector<6xi32>
+// CHECK: %[[V3:.+]] = vector.extract %[[RESHAPED_INPUT]][3] : vector<6x20xi32>
+// CHECK: %[[V3R:.+]] = vector.reduction "add", %[[V3]] : vector<20xi32> into i32
+// CHECK: %[[FLAT_RESULT_VEC_4:.+]] = vector.insertelement %[[V3R]], %[[FLAT_RESULT_VEC_3]][%[[C3]] : i32] : vector<6xi32>
+// CHECK: %[[V4:.+]] = vector.extract %[[RESHAPED_INPUT]][4] : vector<6x20xi32>
+// CHECK: %[[V4R:.+]] = vector.reduction "add", %[[V4]] : vector<20xi32> into i32
+// CHECK: %[[FLAT_RESULT_VEC_5:.+]] = vector.insertelement %[[V4R]], %[[FLAT_RESULT_VEC_4]][%[[C4]] : i32] : vector<6xi32>
+// CHECK: %[[V5:.+]] = vector.extract %[[RESHAPED_INPUT]][5] : vector<6x20xi32>
+// CHECK: %[[V5R:.+]] = vector.reduction "add", %[[V5]] : vector<20xi32> into i32
+// CHECK: %[[FLAT_RESULT_VEC:.+]] = vector.insertelement %[[V5R]], %[[FLAT_RESULT_VEC_5]][%[[C5]] : i32] : vector<6xi32>
+// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
+// CHECK: return %[[RESULT]]
+
+
+func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>) -> vector<2x5xf32> {
+    %0 = vector.multi_reduction #vector.kind<add>, %arg0 [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
+    return %0 : vector<2x5xf32>
+}
+
+// CHECK-LABEL: func @vector_multi_reduction_transposed
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xf32>
+// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [0, 3, 1, 2] : vector<2x3x4x5xf32> to vector<2x5x3x4xf32>
+// CHECK: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<2x5x3x4xf32> to vector<10x12xf32>
+// CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
+// CHECK: return %[[RESULT]]
+
+func @vector_multi_reduction_non_contiguous(%arg0: vector<2x3x4x5xf32>) -> vector<3x5xf32> {
+    %0 = vector.multi_reduction #vector.kind<add>, %arg0 [0, 2] : vector<2x3x4x5xf32> to vector<3x5xf32>
+    return %0 : vector<3x5xf32>
+}
+
+// Parallel dims (1, 3) must keep their relative order so that the lowered
+// result type stays vector<3x5xf32>.
+// CHECK-LABEL: func @vector_multi_reduction_non_contiguous
+// CHECK-SAME: %[[INPUT:.+]]: vector<2x3x4x5xf32>
+// CHECK: %[[TRANSPOSED_INPUT:.+]] = vector.transpose %[[INPUT]], [1, 3, 0, 2] : vector<2x3x4x5xf32> to vector<3x5x2x4xf32>
+// CHECK: vector.shape_cast %[[TRANSPOSED_INPUT]] : vector<3x5x2x4xf32> to vector<15x8xf32>
+// CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<15xf32> to vector<3x5xf32>
+// CHECK: return %[[RESULT]]
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -376,6 +376,22 @@
   }
 };
 
+/// Test pass that greedily applies the vector.multi_reduction lowering
+/// patterns to a function.
+struct TestVectorMultiReductionLoweringPatterns
+    : public PassWrapper<TestVectorMultiReductionLoweringPatterns,
+                         FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // The lowering materializes std.constant and vector ops.
+    registry.insert<StandardOpsDialect, vector::VectorDialect>();
+  }
+  void runOnFunction() override {
+    RewritePatternSet patterns(&getContext());
+    populateVectorMultiReductionLoweringPatterns(patterns);
+    (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
+  }
+};
+
 struct TestProgressiveVectorToSCFLoweringPatterns
     : public PassWrapper<TestProgressiveVectorToSCFLoweringPatterns,
                          FunctionPass> {
@@ -439,6 +455,12 @@
   PassRegistration<TestProgressiveVectorToSCFLoweringPatterns> transferOpToSCF(
       "test-progressive-convert-vector-to-scf",
       "Test conversion patterns to progressively lower transfer ops to SCF");
+
+  PassRegistration<TestVectorMultiReductionLoweringPatterns>
+      multiDimReductionOpLoweringPass(
+          "test-vector-multi-reduction-lowering-patterns",
+          "Test conversion patterns to lower vector.multi_reduction to other "
+          "vector ops");
 }
 } // namespace test
 } // namespace mlir