diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -15,7 +15,7 @@ #include "mlir/Dialect/SCF/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorTransforms.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/Bufferize.h" @@ -846,6 +846,9 @@ : LinalgBaseVectorizationPattern(opName, context, filter, benefit) {} }; +//===----------------------------------------------------------------------===// +// Transformation and lowering options exposed as auxiliary structs. +//===----------------------------------------------------------------------===// /// Options to control the application of enabling transformations. /// Hoisting transformations are always deemed beneficial and must be disabled /// explicitly. @@ -887,10 +890,16 @@ transferLowering = val; return *this; } - /// Trigger full / partial vector.transfer splits. - bool transferPartialRewrite = false; - LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) { - transferPartialRewrite = val; + /// Enable lowering of vector.transpose. + bool transposeLowering = false; + LinalgVectorLoweringOptions &enableVectorTransposeLowering(bool val = true) { + transposeLowering = val; + return *this; + } + /// Enable lowering of vector.multi_reduction. + bool multiReductionLowering = false; + LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) { + multiReductionLowering = val; return *this; } /// Enable lowering of vector.contract. @@ -899,10 +908,10 @@ contractionLowering = val; return *this; } - /// Enable lowering of vector.multi_reduce. 
- bool multiReductionLowering = false; - LinalgVectorLoweringOptions &enableMultiReductionLowering(bool val = true) { - multiReductionLowering = val; + /// Trigger full / partial vector.transfer splits. + bool transferPartialRewrite = false; + LinalgVectorLoweringOptions &enableTransferPartialRewrite(bool val = true) { + transferPartialRewrite = val; return *this; } /// Enable lowering of vector.transfer to scf. @@ -911,13 +920,6 @@ transferToSCFConversion = val; return *this; } - /// Configure late vector transformations. - vector::VectorTransformsOptions vectorTransformOptions; - LinalgVectorLoweringOptions & - setVectorTransformsOptions(vector::VectorTransformsOptions options) { - vectorTransformOptions = options; - return *this; - } /// Configure the post staged-patterns late vector.transfer to scf /// conversion. VectorTransferToSCFOptions vectorTransferToSCFOptions; @@ -926,8 +928,18 @@ vectorTransferToSCFOptions = options; return *this; } + /// Configure late vector transformations. + vector::VectorTransformsOptions vectorTransformOptions; + LinalgVectorLoweringOptions & + setVectorTransformsOptions(vector::VectorTransformsOptions options) { + vectorTransformOptions = options; + return *this; + } }; +//===----------------------------------------------------------------------===// +// Transformations exposed as rewrite patterns. +//===----------------------------------------------------------------------===// /// Trait to check if T provides a `getOperationName` method. template using has_get_operation_name = decltype(T::getOperationName()); diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -40,76 +40,6 @@ struct BitmaskEnumStorage; } // namespace detail -/// Enum to control the lowering of `vector.contract` operations. 
-enum class VectorContractLowering { - /// Progressively lower to finer grained `vector.contract` and dot-products. - Dot = 0, - /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics. - Matmul = 1, - /// Lower to `vector.outerproduct`. - OuterProduct = 2, -}; -/// Enum to control the lowering of `vector.multi_reduction` operations. -enum class VectorMultiReductionLowering { - /// Lower multi_reduction into outer-reduction and inner-parallel ops. - InnerParallel = 0, - /// Lower multi_reduction into outer-parallel and inner-reduction ops. - InnerReduction = 1, -}; -/// Enum to control the lowering of `vector.transpose` operations. -enum class VectorTransposeLowering { - /// Lower transpose into element-wise extract and inserts. - EltWise = 0, - /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix - /// intrinsics. - Flat = 1, -}; -/// Enum to control the splitting of `vector.transfer` operations into -/// in-bounds and out-of-bounds variants. -enum class VectorTransferSplit { - /// Do not split vector transfer operations. - None = 0, - /// Split using in-bounds + out-of-bounds vector.transfer operations. - VectorTransfer = 1, - /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy - /// operations. - LinalgCopy = 2, - /// Do not split vector transfer operation but instead mark it as "in-bounds". - ForceInBounds = 3 -}; -/// Structure to control the behavior of vector transform patterns. -struct VectorTransformsOptions { - /// Option to control the lowering of vector.contract. - VectorContractLowering vectorContractLowering = VectorContractLowering::Dot; - VectorTransformsOptions & - setVectorTransformsOptions(VectorContractLowering opt) { - vectorContractLowering = opt; - return *this; - } - /// Option to control the lowering of vector.multi_reduction. 
- VectorMultiReductionLowering vectorMultiReductionLowering = - VectorMultiReductionLowering::InnerParallel; - VectorTransformsOptions & - setVectorMultiReductionLowering(VectorMultiReductionLowering opt) { - vectorMultiReductionLowering = opt; - return *this; - } - /// Option to control the lowering of vector.transpose. - VectorTransposeLowering vectorTransposeLowering = - VectorTransposeLowering::EltWise; - VectorTransformsOptions & - setVectorTransposeLowering(VectorTransposeLowering opt) { - vectorTransposeLowering = opt; - return *this; - } - /// Option to control the splitting of vector transfers. - VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None; - VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) { - vectorTransferSplit = opt; - return *this; - } -}; - /// Return whether `srcType` can be broadcast to `dstVectorType` under the /// semantics of the `vector.broadcast` op. enum class BroadcastableToResult { @@ -161,33 +91,6 @@ void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool enableIndexOptimizations); -/// Collect a set of patterns to convert vector.multi_reduction op into -/// a sequence of vector.reduction ops. The patterns comprise: -/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such -/// that all reduction dimensions are either innermost or outermost, by adding -/// the proper vector.transpose operations. -/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction -/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction, -/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand -/// back. -/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction -/// form, with an **outermost** reduction dimension, unroll the outer dimension -/// to obtain a sequence of 1-D vector ops. This also has an opportunity for -/// tree-reduction (in the future). 
-/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form, -/// with an **innermost** reduction dimension, unroll the outer dimension to -/// obtain a sequence of extract + vector.reduction + insert. This can further -/// lower to horizontal reduction ops. -/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector -/// reduction (and are thus missing either a parallel or a reduction), we lift -/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that -/// the other patterns can kick in, thus fully exiting out of the -/// vector.multi_reduction abstraction. -void populateVectorMultiReductionLoweringPatterns( - RewritePatternSet &patterns, - VectorMultiReductionLowering options = - vector::VectorMultiReductionLowering::InnerParallel); - /// Collect a set of patterns to propagate insert_map/extract_map in the ssa /// chain. void populatePropagateVectorDistributionPatterns(RewritePatternSet &patterns); @@ -212,12 +115,6 @@ /// vectors to low-D vector ops. void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns); -/// Collects patterns to progressively lower vector contraction ops on high-D -/// into low-D reduction and product ops. -void populateVectorContractLoweringPatterns( - RewritePatternSet &patterns, - VectorTransformsOptions options = VectorTransformsOptions()); - /// Collects patterns to progressively lower vector mask ops into elementary /// selection and insertion ops. void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns); @@ -227,15 +124,6 @@ /// ops. void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns); -/// Insert TransposeLowering patterns into extraction/insertion. -void populateVectorTransposeLoweringPatterns( - RewritePatternSet &patterns, - VectorTransformsOptions options = VectorTransformsOptions()); - -/// Collect patterns to convert reduction op to vector.contract and fold -/// transpose/broadcast ops into the contract. 
-void populateVetorReductionToContractPatterns(RewritePatternSet &patterns); - /// Returns the integer type required for subscripts in the vector dialect. IntegerType getVectorSubscriptType(Builder &builder); diff --git a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h --- a/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/VectorRewritePatterns.h @@ -9,11 +9,173 @@ #ifndef DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_ #define DIALECT_VECTOR_VECTORREWRITEPATTERNS_H_ +#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorUtils.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/PatternMatch.h" + namespace mlir { class RewritePatternSet; namespace vector { +//===----------------------------------------------------------------------===// +// Vector transformation options exposed as auxiliary structs. +//===----------------------------------------------------------------------===// +/// Enum to control the lowering of `vector.transpose` operations. +enum class VectorTransposeLowering { + /// Lower transpose into element-wise extract and inserts. + EltWise = 0, + /// Lower 2-D transpose to `vector.flat_transpose`, maps 1-1 to LLVM matrix + /// intrinsics. + Flat = 1, +}; +/// Enum to control the lowering of `vector.multi_reduction` operations. +enum class VectorMultiReductionLowering { + /// Lower multi_reduction into outer-reduction and inner-parallel ops. + InnerParallel = 0, + /// Lower multi_reduction into outer-parallel and inner-reduction ops. + InnerReduction = 1, +}; +/// Enum to control the lowering of `vector.contract` operations. +enum class VectorContractLowering { + /// Progressively lower to finer grained `vector.contract` and dot-products. + Dot = 0, + /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics. + Matmul = 1, + /// Lower to `vector.outerproduct`. 
+ OuterProduct = 2, +}; +/// Enum to control the splitting of `vector.transfer` operations into +/// in-bounds and out-of-bounds variants. +enum class VectorTransferSplit { + /// Do not split vector transfer operations. + None = 0, + /// Split using in-bounds + out-of-bounds vector.transfer operations. + VectorTransfer = 1, + /// Split using an in-bounds vector.transfer + linalg.fill + linalg.copy + /// operations. + LinalgCopy = 2, + /// Do not split vector transfer operation but instead mark it as "in-bounds". + ForceInBounds = 3 +}; +/// Structure to control the behavior of vector transform patterns. +struct VectorTransformsOptions { + /// Option to control the lowering of vector.contract. + VectorContractLowering vectorContractLowering = VectorContractLowering::Dot; + VectorTransformsOptions & + setVectorTransformsOptions(VectorContractLowering opt) { + vectorContractLowering = opt; + return *this; + } + /// Option to control the lowering of vector.multi_reduction. + VectorMultiReductionLowering vectorMultiReductionLowering = + VectorMultiReductionLowering::InnerParallel; + VectorTransformsOptions & + setVectorMultiReductionLowering(VectorMultiReductionLowering opt) { + vectorMultiReductionLowering = opt; + return *this; + } + /// Option to control the lowering of vector.transpose. + VectorTransposeLowering vectorTransposeLowering = + VectorTransposeLowering::EltWise; + VectorTransformsOptions & + setVectorTransposeLowering(VectorTransposeLowering opt) { + vectorTransposeLowering = opt; + return *this; + } + /// Option to control the splitting of vector transfers. + VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None; + VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) { + vectorTransferSplit = opt; + return *this; + } +}; + +/// Options that control the vector unrolling. 
+struct UnrollVectorOptions { + using FilterConstraintFnType = std::function; + /// Callback function that indicates whether vector unrolling should be + /// attempted on the operation. + FilterConstraintFnType filterConstraint = nullptr; + UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) { + filterConstraint = constraint; + return *this; + } + + using NativeShapeFnType = + std::function>(Operation *op)>; + /// Function that returns the shape of the vector to unroll to for a given + /// operation. The unrolling is aborted if the function returns `llvm::None`. + NativeShapeFnType nativeShape = nullptr; + UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) { + nativeShape = fn; + return *this; + } + + /// Set the native shape to use for unrolling. + UnrollVectorOptions &setNativeShape(ArrayRef shape) { + SmallVector tsShape(shape.begin(), shape.end()); + nativeShape = [=](Operation *) -> Optional> { + return tsShape; + }; + return *this; + } +}; + +//===----------------------------------------------------------------------===// +// Vector transformations exposed as populate functions over rewrite patterns. +//===----------------------------------------------------------------------===// + +/// Insert TransposeLowering patterns into extraction/insertion. +void populateVectorTransposeLoweringPatterns( + RewritePatternSet &patterns, + VectorTransformsOptions options = VectorTransformsOptions()); + +/// Collect a set of patterns to convert vector.multi_reduction op into +/// a sequence of vector.reduction ops. The patterns comprise: +/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such +/// that all reduction dimensions are either innermost or outermost, by adding +/// the proper vector.transpose operations. 
+/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction +/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction, +/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand +/// back. +/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction +/// form, with an **outermost** reduction dimension, unroll the outer dimension +/// to obtain a sequence of 1-D vector ops. This also has an opportunity for +/// tree-reduction (in the future). +/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form, +/// with an **innermost** reduction dimension, unroll the outer dimension to +/// obtain a sequence of extract + vector.reduction + insert. This can further +/// lower to horizontal reduction ops. +/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector +/// reduction (and are thus missing either a parallel or a reduction), we lift +/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that +/// the other patterns can kick in, thus fully exiting out of the +/// vector.multi_reduction abstraction. +void populateVectorMultiReductionLoweringPatterns( + RewritePatternSet &patterns, + VectorMultiReductionLowering options = + VectorMultiReductionLowering::InnerParallel); + +/// Collects patterns to progressively lower vector contraction ops on high-D +/// into low-D reduction and product ops. +void populateVectorContractLoweringPatterns( + RewritePatternSet &patterns, + VectorTransformsOptions options = VectorTransformsOptions()); + +/// Collect patterns to convert reduction op to vector.contract and fold +/// transpose/broadcast ops into the contract. +void populateVectorReductionToContractPatterns(RewritePatternSet &patterns); + +/// Collect a set of patterns to reduce the rank of the operands of vector +/// transfer ops to operate on the largest contiguous vector. 
+/// These patterns are useful when lowering to dialects with 1d vector type +/// such as llvm and it will result in fewer memory reads. +void populateVectorTransferCollapseInnerMostContiguousDimsPatterns( + RewritePatternSet &patterns); + /// Populate `patterns` with the following patterns. /// /// [VectorInsertStridedSliceOpDifferentRankRewritePattern] @@ -52,6 +214,235 @@ void populateVectorInsertExtractStridedSliceTransforms( RewritePatternSet &patterns); +/// Collect a set of patterns to unroll vector operations to smaller shapes. +/// `options` structure controls which operations are unrolled and the target +/// shape. +/// `op` is unrolled to the `targetShape` as follows, for each of its operands: +/// 1. the unrolled type `unrolledVectorType` and number of unrolled instances +/// `numUnrolledInstances` are computed from the `targetShape`. For now it is +/// assumed the unrolling factors divide the vector sizes. +/// 2. ExtractStridedSlice are created to break-up the vector operands. +/// 3. the original op is cloned `numUnrolledInstances` times, once for each +/// result. +/// 4. InsertStridedSlice are inserted to re-assemble the slices into the +/// original vector shape. +/// +/// Example: +/// +/// opA(operand0, operand1) // numUnrolledInstances = 3 +/// +/// operand0 operand1 +/// | | +/// fork fork +/// <----------gather all fork ops ---------> +/// /|\ /|\ +/// f00 f01 f02 f10 f11 f12 +/// <---------- clone op 3 times ---------> +/// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12) +/// \ | / +/// <-------------------- join -------------------------> +/// +/// Other local patterns then kick in iteratively (including DCE) and compose +/// to combine the ExtractStridedSlice/InsertStridedSlice. +void populateVectorUnrollPatterns(RewritePatternSet &patterns, + const UnrollVectorOptions &options); + +//===----------------------------------------------------------------------===// +// Finer-grained patterns exposed for more control over individual lowerings. 
+//===----------------------------------------------------------------------===// +/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern +/// may take an extra filter to perform selection at a finer granularity. +struct VectorTransferFullPartialRewriter : public RewritePattern { + using FilterConstraintType = + std::function; + + explicit VectorTransferFullPartialRewriter( + MLIRContext *context, + VectorTransformsOptions options = VectorTransformsOptions(), + FilterConstraintType filter = + [](VectorTransferOpInterface op) { return success(); }, + PatternBenefit benefit = 1) + : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options), + filter(filter) {} + + /// Performs the rewrite. + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override; + +private: + VectorTransformsOptions options; + FilterConstraintType filter; +}; + +/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to: +/// ``` +/// %flattened_a = vector.shape_cast %a +/// %flattened_b = vector.shape_cast %b +/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b +/// %d = vector.shape_cast %flattened_d +/// %e = add %c, %d +/// ``` +/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`. +/// +/// This only kicks in when VectorTransformsOptions is set to Matmul and +/// the vector.contract op is a row-major matrix multiply. 
+class ContractionOpToMatmulOpLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + using FilterConstraintType = + std::function; + + static LogicalResult defaultFilter(vector::ContractionOp op) { + return success(); + } + + ContractionOpToMatmulOpLowering( + vector::VectorTransformsOptions vectorTransformOptions, + MLIRContext *context, FilterConstraintType constraint = defaultFilter) + : OpRewritePattern(context), + vectorTransformOptions(vectorTransformOptions), filter(constraint) {} + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; + +private: + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformOptions; + FilterConstraintType filter; +}; + +/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to a reduction_size-unrolled sequence: +/// ``` +/// %at = vector.transpose %a, [1, 0] +/// %bRow0 = vector.extract %b[0] +/// %atRow0 = vector.extract %at[0] +/// %c0 = vector.outerproduct %atRow0, %bRow0, %c +/// ... +/// %bRowK = vector.extract %b[K] +/// %atRowK = vector.extract %at[K] +/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 +/// ``` +/// +/// This only kicks in when VectorTransformsOptions is set to OuterProduct and +/// the vector.contract op is a row-major matrix multiply. 
+class ContractionOpToOuterProductOpLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + using FilterConstraintType = + std::function; + + static LogicalResult defaultFilter(vector::ContractionOp op) { + return success(); + } + + ContractionOpToOuterProductOpLowering( + vector::VectorTransformsOptions vectorTransformOptions, + MLIRContext *context, FilterConstraintType constraint = defaultFilter) + : OpRewritePattern(context), + vectorTransformOptions(vectorTransformOptions), filter(constraint) {} + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; + +private: + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformOptions; + FilterConstraintType filter; +}; + +/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to an output-size-unrolled sequence: +/// ``` +/// %out = arith.constant ... : vector +/// %bt = vector.transpose %b, [1, 0] +/// %aRow0 = vector.extract %a[0] +/// %btRow0 = vector.extract %bt[0] +/// %c00 = vector.reduce %aRow0, %btRow0 +/// %out00 = vector.insert %c00, %out[0, 0] +/// ... +/// %aRowLast = vector.extract %a[M-1] +/// %btRowLast = vector.extract %bt[N-1] +/// %cLastLast = vector.reduce %aRowLast, %btRowLast +/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1] +/// ``` +/// +/// This only kicks in when VectorTransformsOptions is set to Dot and +/// the vector.contract op is a row-major matmul or matvec. 
+class ContractionOpToDotLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + using FilterConstraintType = + std::function; + + static LogicalResult defaultFilter(vector::ContractionOp op) { + return success(); + } + + ContractionOpToDotLowering( + vector::VectorTransformsOptions vectorTransformOptions, + MLIRContext *context, FilterConstraintType constraint = defaultFilter) + : OpRewritePattern(context), + vectorTransformOptions(vectorTransformOptions), filter(constraint) {} + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; + +private: + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformOptions; + FilterConstraintType filter; +}; + +/// Progressive lowering of ContractionOp. +/// +/// One: +/// %x = vector.contract with at least one free/batch dimension +/// is replaced by: +/// %a = vector.contract with one less free/batch dimension +/// %b = vector.contract with one less free/batch dimension +/// .. +/// %x = combine %a %b .. +/// until a pure contraction is reached (no free/batch dimensions), +/// which is replaced by a dot-product. +/// +/// This only kicks in when either VectorTransformsOptions is set +/// to Dot or when other contraction patterns fail. +class ContractionOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + using FilterConstraintType = + std::function; + + static LogicalResult defaultFilter(vector::ContractionOp op) { + return success(); + } + + ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions, + MLIRContext *context, + FilterConstraintType constraint = defaultFilter) + : OpRewritePattern(context), + vectorTransformOptions(vectorTransformOptions), filter(constraint) {} + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; + +private: + /// Options to control the vector patterns. 
+ vector::VectorTransformsOptions vectorTransformOptions; + FilterConstraintType filter; + // Lower one parallel dimension. + Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex, + int64_t rhsIndex, PatternRewriter &rewriter) const; + // Lower one reduction dimension. + Value lowerReduction(vector::ContractionOp op, + PatternRewriter &rewriter) const; +}; + } // namespace vector } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -9,10 +9,8 @@ #ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_ #define DIALECT_VECTOR_VECTORTRANSFORMS_H_ -#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/VectorUtils.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/PatternMatch.h" namespace mlir { class MLIRContext; @@ -26,77 +24,9 @@ namespace vector { -/// Options that control the vector unrolling. -struct UnrollVectorOptions { - using FilterConstraintFnType = std::function; - /// Callback function that indicates whether vector unrolling should be - /// attempted on the operation. - FilterConstraintFnType filterConstraint = nullptr; - UnrollVectorOptions &setFilterConstraint(FilterConstraintFnType constraint) { - filterConstraint = constraint; - return *this; - } - - using NativeShapeFnType = - std::function>(Operation *op)>; - /// Function that returns the shape of the vector to unroll to for a given - /// operation. The unrolling is aborted if the function returns `llvm::None`. - NativeShapeFnType nativeShape = nullptr; - UnrollVectorOptions &setNativeShapeFn(NativeShapeFnType fn) { - nativeShape = fn; - return *this; - } - - /// Set the native shape to use for unrolling. 
- UnrollVectorOptions &setNativeShape(ArrayRef shape) { - SmallVector tsShape(shape.begin(), shape.end()); - nativeShape = [=](Operation *) -> Optional> { - return tsShape; - }; - return *this; - } -}; - -/// Collect a set of pattern to unroll vector operations to a smaller shapes. -/// `options` structure controls which operations are unrolled and the target -/// shape. -/// `op` is unrolled to the `targetShape` as follows, for each of its operands: -/// 1. the unrolled type `unrolledVectorType` and number of unrolled instances -/// `numUnrolledInstances` are computed from the `targetShape`. For now it is -/// assumed the unrolling factors divide the vector sizes. -/// 2. ExtractStridedSlice are created to break-up the vector operands. -/// 3. the original op is cloned `numUnrolledInstances` times, once for each -/// result. -/// 4. InsertStridedSlice are inserted to re-assemble the slices into the -/// original vectore shape. -/// -/// Example: -/// -/// opA(operand0, operand1) // numUnrolledInstances = 3 -/// -/// operand0 operand1 -/// | | -/// fork fork -/// <----------gather all fork ops ---------> -/// /|\ /|\ -/// f00 f01 f02 f10 f11 f12 -/// <---------- clone op 3 times ---------> -/// opA0(f00, f10), opA1(f01, f11), opA2(f02, f12) -/// \ | / -/// <-------------------- join -------------------------> -/// -/// Other local patterns then kick in iteratively (including DCE) and compose -/// to combine the ExtractStridedSlice/InsertStridedSlice. -void populateVectorUnrollPatterns(RewritePatternSet &patterns, - const UnrollVectorOptions &options); - -/// Collect a set of patterns to reduce the rank of the operands of vector -/// transfer ops to operate on the largest contigious vector. -/// These patterns are useful when lowering to dialects with 1d vector type -/// such as llvm and it will result fewer memory reads. 
-void populateVectorTransferCollapseInnerMostContiguousDimsPatterns( - RewritePatternSet &patterns); - +//===----------------------------------------------------------------------===// +// Standalone transformations and helpers. +//===----------------------------------------------------------------------===// /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds /// masking) fastpath and a slowpath. /// If `ifOp` is not null and the result is `success, the `ifOp` points to the @@ -130,37 +60,11 @@ /// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()` /// must be equal. This will be relaxed in the future but requires /// rank-reducing subviews. -LogicalResult -splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp); LogicalResult splitFullAndPartialTransfer( OpBuilder &b, VectorTransferOpInterface xferOp, VectorTransformsOptions options = VectorTransformsOptions(), scf::IfOp *ifOp = nullptr); -/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern -/// may take an extra filter to perform selection at a finer granularity. -struct VectorTransferFullPartialRewriter : public RewritePattern { - using FilterConstraintType = - std::function; - - explicit VectorTransferFullPartialRewriter( - MLIRContext *context, - VectorTransformsOptions options = VectorTransformsOptions(), - FilterConstraintType filter = - [](VectorTransferOpInterface op) { return success(); }, - PatternBenefit benefit = 1) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options), - filter(filter) {} - - /// Performs the rewrite. 
- LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; - -private: - VectorTransformsOptions options; - FilterConstraintType filter; -}; - struct DistributeOps { ExtractMapOp extract; InsertMapOp insert; @@ -188,180 +92,6 @@ void transferOpflowOpt(FuncOp func); } // namespace vector - -//===----------------------------------------------------------------------===// -// Finer-grained patterns exposed for more control over individual lowerings. -//===----------------------------------------------------------------------===// - -/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul -/// semantics to: -/// ``` -/// %flattened_a = vector.shape_cast %a -/// %flattened_b = vector.shape_cast %b -/// %flattened_d = vector.matmul %flattened_a, %flattened_b -/// %d = vector.shape_cast %%flattened_d -/// %e = add %c, %d -/// ``` -/// `vector.matmul` later lowers to `llvm.matrix.multiply`. -// -/// This only kicks in when VectorTransformsOptions is set to OuterProduct and -/// the vector.contract op is a row-major matrix multiply. -class ContractionOpToMatmulOpLowering - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - using FilterConstraintType = - std::function; - - static LogicalResult defaultFilter(vector::ContractionOp op) { - return success(); - } - - ContractionOpToMatmulOpLowering( - vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, FilterConstraintType constraint = defaultFilter) - : OpRewritePattern(context), - vectorTransformOptions(vectorTransformOptions), filter(constraint) {} - - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; - -private: - /// Options to control the vector patterns. 
- vector::VectorTransformsOptions vectorTransformOptions; - FilterConstraintType filter; -}; - -/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul -/// semantics to a reduction_size-unrolled sequence: -/// ``` -/// %at = vector.transpose %a, [1, 0] -/// %bRow0 = vector.extract %b[0] -/// %atRow0 = vector.extract %at[0] -/// %c0 = vector.outerproduct %atRow0, %bRow0, %c -/// ... -/// %bRowK = vector.extract %b[K] -/// %atRowK = vector.extract %at[K] -/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 -/// ``` -/// -/// This only kicks in when VectorTransformsOptions is set to OuterProduct and -/// the vector.contract op is a row-major matrix multiply. -class ContractionOpToOuterProductOpLowering - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - using FilterConstraintType = - std::function; - - static LogicalResult defaultFilter(vector::ContractionOp op) { - return success(); - } - - ContractionOpToOuterProductOpLowering( - vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, FilterConstraintType constraint = defaultFilter) - : OpRewritePattern(context), - vectorTransformOptions(vectorTransformOptions), filter(constraint) {} - - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; - -private: - /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; - FilterConstraintType filter; -}; - -/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul -/// semantics to an output-size-unrolled sequence: -/// ``` -/// %out = arith.constant ... : vector -/// %bt = vector.transpose %b, [1, 0] -/// %aRow0 = vector.extract %a[0] -/// %btRow0 = vector.extract %bt[0] -/// %c00 = vector.reduce %atRow0, %bRow0 -/// %out00 = vector.insert %c00, %out[0, 0] -/// ... 
-/// %aRowLast = vector.extract %at[M-1] -/// %btRowLast = vector.extract %b[N-1] -/// %cLastLast = vector.reduce %atRowLast, %bRowLast -/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1] -/// ``` -/// -/// This only kicks in when VectorTransformsOptions is set to Dot and -/// the vector.contract op is a row-major matmul or matvec. -class ContractionOpToDotLowering - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - using FilterConstraintType = - std::function; - - static LogicalResult defaultFilter(vector::ContractionOp op) { - return success(); - } - - ContractionOpToDotLowering( - vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, FilterConstraintType constraint = defaultFilter) - : OpRewritePattern(context), - vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} - - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; - -private: - /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; - FilterConstraintType filter; -}; - -/// Progressive lowering of ContractionOp. -/// -/// One: -/// %x = vector.contract with at least one free/batch dimension -/// is replaced by: -/// %a = vector.contract with one less free/batch dimension -/// %b = vector.contract with one less free/batch dimension -/// .. -/// %x = combine %a %b .. -/// until a pure contraction is reached (no free/batch dimensions), -/// which is replaced by a dot-product. -/// -/// This only kicks in when either VectorTransformsOptions is set -/// to Dot or when other contraction patterns fail.
-class ContractionOpLowering : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - using FilterConstraintType = - std::function; - - static LogicalResult defaultFilter(vector::ContractionOp op) { - return success(); - } - - ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, - FilterConstraintType constraint = defaultFilter) - : OpRewritePattern(context), - vectorTransformOptions(vectorTransformOptions), filter(constraint) {} - - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; - -private: - /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; - FilterConstraintType filter; - // Lower one parallel dimension. - Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex, - int64_t rhsIndex, PatternRewriter &rewriter) const; - // Lower one reduction dimension. - Value lowerReduction(vector::ContractionOp op, - PatternRewriter &rewriter) const; -}; - } // namespace mlir #endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_ diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -14,8 +14,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Vector/VectorOps.h" -#include "mlir/Dialect/Vector/VectorRewritePatterns.h" +#include "mlir/Dialect/Vector/VectorTransforms.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/MathExtras.h" #include "mlir/Target/LLVMIR/TypeToLLVM.h" diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ 
b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -21,7 +21,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Vector/VectorOps.h" +#include "mlir/Dialect/Vector/VectorRewritePatterns.h" #include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86Vector/X86VectorDialect.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -22,7 +22,6 @@ #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/SCF/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorTransforms.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -32,6 +31,7 @@ #include "mlir/Transforms/Utils.h" using namespace mlir; +using namespace mlir::vector; using namespace linalg; namespace { @@ -191,7 +191,7 @@ } vector::populateVectorTransferPermutationMapLoweringPatterns( vectorizationPatterns); - vector::populateVetorReductionToContractPatterns(vectorizationPatterns); + vector::populateVectorReductionToContractPatterns(vectorizationPatterns); vectorizationPatterns.add( funcOp.getContext(), /*benefit=*/2); @@ -268,9 +268,14 @@ vector::populateVectorTransferLoweringPatterns(patterns, options.maxTransferRank); } - if (options.transferPartialRewrite) { - patterns.add( - context, options.vectorTransformOptions); + if (options.transposeLowering) { + vector::populateVectorTransposeLoweringPatterns( + patterns, options.vectorTransformOptions); + } + if (options.multiReductionLowering) { + vector::populateVectorMultiReductionLoweringPatterns( + patterns, + 
options.vectorTransformOptions.vectorMultiReductionLowering); } if (options.contractionLowering) { patterns.add( + context, options.vectorTransformOptions); } if (options.transferToSCFConversion) { populateVectorToSCFConversionPatterns(patterns, options.vectorTransferToSCFOptions); } + vector::populateVectorShapeCastLoweringPatterns(patterns); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } diff --git a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorMultiDimReductionTransforms.cpp @@ -10,14 +10,9 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Vector/VectorOps.h" -#include "mlir/Dialect/Vector/VectorTransforms.h" +#include "mlir/Dialect/Vector/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/VectorUtils.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -21,21 +21,10 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorTransforms.h" -#include "mlir/Dialect/Vector/VectorUtils.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/AffineMap.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/ImplicitLocOpBuilder.h" -#include "mlir/IR/Location.h" #include "mlir/IR/Matchers.h" -#include "mlir/IR/OperationSupport.h" 
#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "llvm/ADT/DenseSet.h" @@ -48,6 +37,7 @@ #define DEBUG_TYPE "vector-to-vector" using namespace mlir; +using namespace mlir::vector; // Helper to find an index in an affine map. static Optional getResultIndex(AffineMap map, int64_t index) { @@ -1978,9 +1968,41 @@ }); return inBoundsCond; } - -LogicalResult mlir::vector::splitFullAndPartialTransferPrecondition( - VectorTransferOpInterface xferOp) { +/// Check the preconditions of `splitFullAndPartialTransfer`, which splits a +/// vector.transfer op into an in-bounds (unmasked) fastpath and a slowpath. +/// If `ifOp` is not null and the result is `success`, the `ifOp` points to the +/// newly created conditional upon function return. +/// To accommodate the fact that the original vector.transfer indexing may be +/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the +/// scf.if op returns a view and values of type index. +/// At this time, only vector.transfer_read case is implemented. +/// +/// Example (a 2-D vector.transfer_read): +/// ``` +/// %1 = vector.transfer_read %0[...], %pad : memref, vector<...> +/// ``` +/// is transformed into: +/// ``` +/// %1:3 = scf.if (%inBounds) { +/// // fastpath, direct cast +/// memref.cast %A: memref to compatibleMemRefType +/// scf.yield %view : compatibleMemRefType, index, index +/// } else { +/// // slowpath, not in-bounds vector.transfer or linalg.copy. +/// memref.cast %alloc: memref to compatibleMemRefType +/// scf.yield %4 : compatibleMemRefType, index, index +/// } +/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]} +/// ``` +/// where `alloc` is a top of the function alloca'ed buffer of one vector. +/// +/// Preconditions: +/// 1. `xferOp.permutation_map()` must be a minor identity map +/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()` +/// must be equal.
This will be relaxed in the future but requires +/// rank-reducing subviews. +static LogicalResult +splitFullAndPartialTransferPrecondition(VectorTransferOpInterface xferOp) { // TODO: expand support to these 2 cases. if (!xferOp.permutation_map().isMinorIdentity()) return failure(); @@ -3863,7 +3885,7 @@ patterns.add(options, patterns.getContext()); } -void mlir::vector::populateVetorReductionToContractPatterns( +void mlir::vector::populateVectorReductionToContractPatterns( RewritePatternSet &patterns) { patterns.add(patterns.getContext()); diff --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp --- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp @@ -98,9 +98,8 @@ VectorTransposeLowering::EltWise}; RewritePatternSet vectorTransferPatterns(context); - // Pattern is not applied because rank-reducing vector transfer is not yet - // supported as can be seen in splitFullAndPartialTransferPrecondition, - // VectorTransforms.cpp + // Pattern is not applied: rank-reducing vector transfer is not yet supported + // (see: splitFullAndPartialTransferPrecondition in VectorTransforms.cpp). 
vectorTransferPatterns.add( context, vectorTransformOptions); (void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns)); diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -536,7 +536,7 @@ RewritePatternSet canonicalizationPatterns(funcOp.getContext()); vector::populateVectorTransferPermutationMapLoweringPatterns( canonicalizationPatterns); - vector::populateVetorReductionToContractPatterns(canonicalizationPatterns); + vector::populateVectorReductionToContractPatterns(canonicalizationPatterns); stage1Patterns.push_back(std::move(canonicalizationPatterns)); } SmallVector frozenStage1Patterns; diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -14,13 +14,13 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" -#include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/Dialect/Vector/VectorTransforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::vector; + namespace { struct TestVectorToVectorConversion @@ -511,7 +511,7 @@ } void runOnFunction() override { RewritePatternSet patterns(&getContext()); - populateVetorReductionToContractPatterns(patterns); + populateVectorReductionToContractPatterns(patterns); (void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns)); } };