diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -899,6 +899,12 @@ contractionLowering = val; return *this; } + /// Enable lowering of vector.multi_reduction. + bool multiReductionLowering = false; + LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) { + multiReductionLowering = val; + return *this; + } /// Enable lowering of vector.transfer to scf. bool transferToSCFConversion = false; LinalgVectorLoweringOptions &enableTransferToSCFConversion(bool val = true) { diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -40,6 +40,76 @@ struct BitmaskEnumStorage; } // namespace detail +/// Enum to control the lowering of `vector.contract` operations. +enum class VectorContractLowering { + /// Progressively lower to finer grained `vector.contract` and dot-products. + Dot = 0, + /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics. + Matmul = 1, + /// Lower to `vector.outerproduct`. + OuterProduct = 2, +}; +/// Enum to control the lowering of `vector.multi_reduction` operations. +enum class VectorMultiReductionLowering { + /// Lower multi_reduction into outer-reduction and inner-parallel ops. + InnerParallel = 0, + /// Lower multi_reduction into outer-parallel and inner-reduction ops. + InnerReduction = 1, +}; +/// Enum to control the lowering of `vector.transpose` operations. +enum class VectorTransposeLowering { + /// Lower transpose into element-wise extract and inserts. + EltWise = 0, + /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix + /// intrinsics.
+ Flat = 1, +}; +/// Enum to control the splitting of `vector.transfer` operations into +/// in-bounds and out-of-bounds variants. +enum class VectorTransferSplit { + /// Do not split vector transfer operations. + None = 0, + /// Split using in-bounds + out-of-bounds vector.transfer operations. + VectorTransfer = 1, + /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy + /// operations. + LinalgCopy = 2, + /// Do not split vector transfer operation but instead mark it as "in-bounds". + ForceInBounds = 3 +}; +/// Structure to control the behavior of vector transform patterns. +struct VectorTransformsOptions { + /// Option to control the lowering of vector.contract. + VectorContractLowering vectorContractLowering = VectorContractLowering::Dot; + VectorTransformsOptions & + setVectorTransformsOptions(VectorContractLowering opt) { + vectorContractLowering = opt; + return *this; + } + /// Option to control the lowering of vector.multi_reduction. + VectorMultiReductionLowering vectorMultiReductionLowering = + VectorMultiReductionLowering::InnerParallel; + VectorTransformsOptions & + setVectorMultiReductionLowering(VectorMultiReductionLowering opt) { + vectorMultiReductionLowering = opt; + return *this; + } + /// Option to control the lowering of vector.transpose. + VectorTransposeLowering vectorTransposeLowering = + VectorTransposeLowering::EltWise; + VectorTransformsOptions & + setVectorTransposeLowering(VectorTransposeLowering opt) { + vectorTransposeLowering = opt; + return *this; + } + /// Option to control the splitting of vector transfers. + VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None; + VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) { + vectorTransferSplit = opt; + return *this; + } +}; + /// Return whether `srcType` can be broadcast to `dstVectorType` under the /// semantics of the `vector.broadcast` op.
enum class BroadcastableToResult { @@ -114,7 +184,9 @@ /// the other patterns can kick in, thus fully exiting out of the /// vector.multi_reduction abstraction. void populateVectorMultiReductionLoweringPatterns( - RewritePatternSet &patterns, bool useInnerDimsForReduction = false); + RewritePatternSet &patterns, + VectorMultiReductionLowering lowering = + vector::VectorMultiReductionLowering::InnerParallel); /// Collect a set of patterns to propagate insert_map/extract_map in the ssa /// chain. @@ -136,61 +208,6 @@ static Attribute parse(DialectAsmParser &parser); }; -/// Enum to control the lowering of `vector.contract` operations. -enum class VectorContractLowering { - /// Progressively lower to finer grained `vector.contract` and dot-products. - Dot = 0, - /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics. - Matmul = 1, - /// Lower to `vector.outerproduct`. - OuterProduct = 2, -}; -/// Enum to control the lowering of `vector.transpose` operations. -enum class VectorTransposeLowering { - /// Lower transpose into element-wise extract and inserts. - EltWise = 0, - /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix - /// intrinsics. - Flat = 1, -}; -/// Enum to control the splitting of `vector.transfer` operations into -/// in-bounds and out-of-bounds variants. -enum class VectorTransferSplit { - /// Do not split vector transfer operations. - None = 0, - /// Split using in-bounds + out-of-bounds vector.transfer operations. - VectorTransfer = 1, - /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy - /// operations. - LinalgCopy = 2, - /// Do not split vector transfer operation but instead mark it as "in-bounds". - ForceInBounds = 3 -}; -/// Structure to control the behavior of vector transform patterns. -struct VectorTransformsOptions { - /// Option to control the lowering of vector.contract.
- VectorContractLowering vectorContractLowering = VectorContractLowering::Dot; - VectorTransformsOptions & - setVectorTransformsOptions(VectorContractLowering opt) { - vectorContractLowering = opt; - return *this; - } - /// Option to control the lowering of vector.transpose. - VectorTransposeLowering vectorTransposeLowering = - VectorTransposeLowering::EltWise; - VectorTransformsOptions & - setVectorTransposeLowering(VectorTransposeLowering opt) { - vectorTransposeLowering = opt; - return *this; - } - /// Option to control the splitting of vector transfers. - VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None; - VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) { - vectorTransferSplit = opt; - return *this; - } -}; - /// Collects patterns to progressively lower vector.broadcast ops on high-D /// vectors to low-D vector ops. void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -263,6 +263,7 @@ MLIRContext *context = funcOp.getContext(); RewritePatternSet patterns(context); + vector::populateVectorToVectorCanonicalizationPatterns(patterns); if (options.transferLowering) { vector::populateVectorTransferLoweringPatterns(patterns, options.maxTransferRank); @@ -277,6 +278,11 @@ options.vectorTransformOptions, context); vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); } + if (options.multiReductionLowering) { + vector::populateVectorMultiReductionLoweringPatterns( + patterns, + options.vectorTransformOptions.vectorMultiReductionLowering); + } if (options.transferToSCFConversion) { populateVectorToSCFConversionPatterns(patterns, options.vectorTransferToSCFOptions); diff --git
a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp @@ -35,10 +35,11 @@ public: using OpRewritePattern::OpRewritePattern; - explicit InnerOuterDimReductionConversion(MLIRContext *context, - bool useInnerDimsForReduction) + explicit InnerOuterDimReductionConversion( + MLIRContext *context, vector::VectorMultiReductionLowering lowering) : mlir::OpRewritePattern(context), - useInnerDimsForReduction(useInnerDimsForReduction) {} + useInnerDimsForReduction( + lowering == vector::VectorMultiReductionLowering::InnerReduction) {} LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -103,10 +104,11 @@ public: using OpRewritePattern::OpRewritePattern; - explicit ReduceMultiDimReductionRank(MLIRContext *context, - bool useInnerDimsForReduction) + explicit ReduceMultiDimReductionRank( + MLIRContext *context, vector::VectorMultiReductionLowering lowering) : mlir::OpRewritePattern(context), - useInnerDimsForReduction(useInnerDimsForReduction) {} + useInnerDimsForReduction( + lowering == vector::VectorMultiReductionLowering::InnerReduction) {} LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -398,11 +400,11 @@ }; void mlir::vector::populateVectorMultiReductionLoweringPatterns( - RewritePatternSet &patterns, bool useInnerDimsForReduction) { - patterns.add(patterns.getContext(), - useInnerDimsForReduction); - if (useInnerDimsForReduction) + RewritePatternSet &patterns, VectorMultiReductionLowering lowering) { + patterns.add( + patterns.getContext(), lowering); + patterns.add(patterns.getContext()); + if (lowering == VectorMultiReductionLowering::InnerReduction) patterns.add(patterns.getContext()); else patterns.add(patterns.getContext());
diff --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp --- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp @@ -94,7 +94,8 @@ //===--------------------------------------------------------------------===// VectorTransformsOptions vectorTransformOptions{ - VectorContractLowering::Dot, VectorTransposeLowering::EltWise}; + VectorContractLowering::Dot, VectorMultiReductionLowering::InnerParallel, + VectorTransposeLowering::EltWise}; RewritePatternSet vectorTransferPatterns(context); // Pattern is not applied because rank-reducing vector transfer is not yet diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -159,11 +159,14 @@ VectorContractLowering contractLowering = VectorContractLowering::Dot; if (lowerToFlatMatrix) contractLowering = VectorContractLowering::Matmul; + VectorMultiReductionLowering multiReductionLowering = + VectorMultiReductionLowering::InnerParallel; VectorTransposeLowering transposeLowering = VectorTransposeLowering::EltWise; if (lowerToFlatTranspose) transposeLowering = VectorTransposeLowering::Flat; - VectorTransformsOptions options{contractLowering, transposeLowering}; + VectorTransformsOptions options{ + contractLowering, multiReductionLowering, transposeLowering}; populateVectorBroadcastLoweringPatterns(patterns); populateVectorContractLoweringPatterns(patterns, options); populateVectorMaskOpLoweringPatterns(patterns); @@ -461,7 +464,10 @@ llvm::cl::init(false)}; void runOnFunction() override { RewritePatternSet patterns(&getContext()); - populateVectorMultiReductionLoweringPatterns(patterns, !useOuterReductions); + populateVectorMultiReductionLoweringPatterns( + patterns, useOuterReductions + ?
vector::VectorMultiReductionLowering::InnerParallel : vector::VectorMultiReductionLowering::InnerReduction); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } };