diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -110,43 +110,11 @@ void populateBubbleVectorBitCastOpPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); -/// Collect a set of transfer read/write lowering patterns. -/// -/// These patterns lower transfer ops to simpler ops like `vector.load`, -/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank -/// of a most `maxTransferRank` are lowered. This is useful when combined with -/// VectorToSCF, which reduces the rank of vector transfer ops. -void populateVectorTransferLoweringPatterns( - RewritePatternSet &patterns, - std::optional maxTransferRank = std::nullopt, - PatternBenefit benefit = 1); - /// These patterns materialize masks for various vector ops such as transfers. void populateVectorMaskMaterializationPatterns(RewritePatternSet &patterns, bool force32BitVectorIndices, PatternBenefit benefit = 1); -/// Collects patterns to progressively lower vector.broadcast ops on high-D -/// vectors to low-D vector ops. -void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - -/// Collects patterns to progressively lower vector mask ops into elementary -/// selection and insertion ops. -void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - -/// Collects patterns to progressively lower vector.shape_cast ops on high-D -/// vectors into 1-D/2-D vector ops by generating data movement extract/insert -/// ops. -void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - -/// Collects patterns that lower scalar vector transfer ops to memref loads and -/// stores when beneficial. 
-void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - /// Returns the integer type required for subscripts in the vector dialect. IntegerType getVectorSubscriptType(Builder &builder); @@ -214,8 +182,8 @@ /// Creates a vector.mask operation around a maskable operation. Returns the /// vector.mask operation if the mask provided is valid. Otherwise, returns the /// maskable operation itself. -Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, - Value mask, Value passthru = Value()); +Operation *maskOperation(OpBuilder &builder, Operation *maskableOp, Value mask, + Value passthru = Value()); /// Creates a vector select operation that picks values from `newValue` or /// `passthru` for each result vector lane based on `mask`. This utility is used diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h @@ -0,0 +1,248 @@ +//===- LoweringPatterns.h - Vector rewrite patterns --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H +#define MLIR_DIALECT_VECTOR_TRANSFORMS_LOWERINGPATTERNS_H + +#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" + +namespace mlir { +class RewritePatternSet; + +namespace vector { + +//===----------------------------------------------------------------------===// +// Lowering pattern populate functions +//===----------------------------------------------------------------------===// + +/// Populate the pattern set with the following patterns: +/// +/// [OuterProductOpLowering] +/// Progressively lower a `vector.outerproduct` to linearized +/// `vector.extract` + `vector.fma` + `vector.insert`. +/// +/// [ContractionOpLowering] +/// Progressive lowering of ContractionOp. +/// One: +/// %x = vector.contract with at least one free/batch dimension +/// is replaced by: +/// %a = vector.contract with one less free/batch dimension +/// %b = vector.contract with one less free/batch dimension +/// +/// [ContractionOpToMatmulOpLowering] +/// Progressively lower a `vector.contract` with row-major matmul semantics to +/// linearized `vector.shape_cast` + `vector.matmul` on the way to +/// `llvm.matrix.multiply`. +/// +/// [ContractionOpToDotLowering] +/// Progressively lower a `vector.contract` with row-major matmul semantics to +/// linearized `vector.extract` + `vector.reduce` + `vector.insert`. +/// +/// [ContractionOpToOuterProductOpLowering] +/// Progressively lower a `vector.contract` with row-major matmul semantics to +/// linearized `vector.extract` + `vector.outerproduct` + `vector.insert`. 
+void populateVectorContractLoweringPatterns( + RewritePatternSet &patterns, VectorTransformsOptions options, + PatternBenefit benefit = 1, bool disableOuterProductLowering = false); + +/// Collect a set of patterns to convert vector.multi_reduction op into +/// a sequence of vector.reduction ops. The patterns comprise: +/// +/// [InnerOuterDimReductionConversion] +/// Rewrites vector.multi_reduction such that all reduction dimensions are +/// either innermost or outermost, by adding the proper vector.transpose +/// operations. +/// +/// [ReduceMultiDimReductionRank] +/// Once in innermost or outermost reduction +/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction, +/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand +/// back. +/// +/// [TwoDimMultiReductionToElementWise] +/// Once in 2-D vector.multi_reduction form, with an **outermost** reduction +/// dimension, unroll the outer dimension to obtain a sequence of 1-D vector +/// ops. This also has an opportunity for tree-reduction (in the future). +/// +/// [TwoDimMultiReductionToReduction] +/// Once in 2-D vector.multi_reduction form, with an **innermost** reduction +/// dimension, unroll the outer dimension to obtain a sequence of extract + +/// vector.reduction + insert. This can further lower to horizontal reduction +/// ops. +/// +/// [OneDimMultiReductionToTwoDim] +/// For cases that reduce to 1-D vector reduction (and are thus missing +/// either a parallel or a reduction), we lift them back up to 2-D with a simple +/// vector.shape_cast to vector<1xk> so that the other patterns can kick in, +/// thus fully exiting out of the vector.multi_reduction abstraction. 
+void populateVectorMultiReductionLoweringPatterns( + RewritePatternSet &patterns, VectorMultiReductionLowering options, + PatternBenefit benefit = 1); + +/// Populate the pattern set with the following patterns: +/// +/// [BroadcastOpLowering] +/// Progressive lowering of BroadcastOp to ExtractOp + InsertOp + lower-D +/// BroadcastOp until dim 1. +void populateVectorBroadcastLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Populate the pattern set with the following patterns: +/// +/// [CreateMaskOp] +/// Progressive lowering of CreateMaskOp to lower-D CreateMaskOp until dim 1. +/// +/// [ConstantMaskOp] +/// Progressive lowering of ConstantMaskOp to lower-D ConstantMaskOp until +/// dim 1. +void populateVectorMaskOpLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Collects patterns that lower scalar vector transfer ops to memref loads and +/// stores when beneficial. +void populateScalarVectorTransferLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Populate the pattern set with the following patterns: +/// +/// [ShapeCastOp2DDownCastRewritePattern] +/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D +/// vectors progressively. +/// +/// [ShapeCastOp2DUpCastRewritePattern] +/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D +/// vectors progressively. +/// +/// [ShapeCastOpRewritePattern] +/// Reference lowering to fully unrolled sequences of single element ExtractOp + +/// InsertOp. Note that applying this pattern can almost always be considered a +/// performance bug.
+void populateVectorShapeCastLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Populate the pattern set with the following patterns: +/// +/// [TransposeOpLowering] +/// +/// [TransposeOp2DToShuffleLowering] +/// +void populateVectorTransposeLoweringPatterns(RewritePatternSet &patterns, + VectorTransformsOptions options, + PatternBenefit benefit = 1); + +/// Populate the pattern set with the following patterns: +/// +/// [TransferReadToVectorLoadLowering] +/// Progressive lowering of transfer_read. This pattern supports lowering of +/// `vector.transfer_read` to a combination of `vector.load` and +/// `vector.broadcast`. +/// +/// [TransferWriteToVectorStoreLowering] +/// Progressive lowering of transfer_write. This pattern supports lowering of +/// `vector.transfer_write` to `vector.store`. +/// +/// [VectorLoadToMemrefLoadLowering] +/// Replace a 0-d vector.load with a memref.load + vector.broadcast. +/// +/// [VectorStoreToMemrefStoreLowering] +/// Replace a 0-d vector.store with a vector.extractelement + memref.store. +/// +/// These patterns lower transfer ops to simpler ops like `vector.load`, +/// `vector.store` and `vector.broadcast`. Only transfers with a transfer rank +/// of at most `maxTransferRank` are lowered. This is useful when combined with +/// VectorToSCF, which reduces the rank of vector transfer ops. +void populateVectorTransferLoweringPatterns( + RewritePatternSet &patterns, + std::optional maxTransferRank = std::nullopt, + PatternBenefit benefit = 1); + +/// Collect a set of transfer read/write lowering patterns that simplify the +/// permutation map (e.g., converting it to a minor identity map) by inserting +/// broadcasts and transposes. More specifically: +/// +/// [TransferReadPermutationLowering] +/// Lower transfer_read op with permutation into a transfer_read with a +/// permutation map composed of leading zeros followed by a minor identity + +/// vector.transpose op. +/// Ex: +/// vector.transfer_read ...
+/// permutation_map: (d0, d1, d2) -> (0, d1) +/// into: +/// %v = vector.transfer_read ... +/// permutation_map: (d0, d1, d2) -> (d1, 0) +/// vector.transpose %v, [1, 0] +/// +/// vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3) +/// into: +/// %v = vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3) +/// vector.transpose %v, [0, 1, 3, 2, 4] +/// Note that an alternative is to transform it to linalg.transpose + +/// vector.transfer_read to do the transpose in memory instead. +/// +/// [TransferWritePermutationLowering] +/// Lower transfer_write op with permutation into a transfer_write with a +/// minor identity permutation map. (transfer_write ops cannot have broadcasts.) +/// Ex: +/// vector.transfer_write %v ... +/// permutation_map: (d0, d1, d2) -> (d2, d0, d1) +/// into: +/// %tmp = vector.transpose %v, [2, 0, 1] +/// vector.transfer_write %tmp ... +/// permutation_map: (d0, d1, d2) -> (d0, d1, d2) +/// +/// vector.transfer_write %v ... +/// permutation_map: (d0, d1, d2, d3) -> (d3, d2) +/// into: +/// %tmp = vector.transpose %v, [1, 0] +/// %v = vector.transfer_write %tmp ... +/// permutation_map: (d0, d1, d2, d3) -> (d2, d3) +/// +/// [TransferOpReduceRank] +/// Lower transfer_read op with broadcast in the leading dimensions into +/// transfer_read of lower rank + vector.broadcast. +/// Ex: vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3) +/// into: +/// %v = vector.transfer_read ... +/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) +/// vector.broadcast %v +void populateVectorTransferPermutationMapLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit = 1); + +/// Populate the pattern set with the following patterns: +/// +/// [ScanToArithOps] +/// Convert vector.scan op into arith ops and vector.insert_strided_slice / +/// vector.extract_strided_slice. 
+void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Populate the pattern set with the following patterns: +/// +/// [FlattenGather] +/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the +/// outermost dimension. +/// +/// [Gather1DToConditionalLoads] +/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or +/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these +/// loads/extracts are made conditional using `scf.if` ops. +void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, + PatternBenefit benefit = 1); + +/// Populates instances of `MaskOpRewritePattern` to lower masked operations +/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and +/// not its nested `MaskableOpInterface`.
-void populateVectorMaskLoweringPatternsForSideEffectingOps( - RewritePatternSet &patterns); - //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h @@ -9,8 +9,8 @@ #ifndef MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H #define MLIR_DIALECT_VECTOR_TRANSFORMS_VECTORREWRITEPATTERNS_H -#include #include +#include #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Transforms/VectorTransformsEnums.h.inc" @@ -23,42 +23,7 @@ class RewritePatternSet; namespace vector { - -//===----------------------------------------------------------------------===// -// Vector transformation options exposed as auxiliary structs. -//===----------------------------------------------------------------------===// -/// Structure to control the behavior of vector transform patterns. -struct VectorTransformsOptions { - /// Option to control the lowering of vector.contract. - VectorContractLowering vectorContractLowering = VectorContractLowering::Dot; - VectorTransformsOptions & - setVectorTransformsOptions(VectorContractLowering opt) { - vectorContractLowering = opt; - return *this; - } - /// Option to control the lowering of vector.multi_reduction. - VectorMultiReductionLowering vectorMultiReductionLowering = - VectorMultiReductionLowering::InnerParallel; - VectorTransformsOptions & - setVectorMultiReductionLowering(VectorMultiReductionLowering opt) { - vectorMultiReductionLowering = opt; - return *this; - } - /// Option to control the lowering of vector.transpose. 
- VectorTransposeLowering vectorTransposeLowering = - VectorTransposeLowering::EltWise; - VectorTransformsOptions & - setVectorTransposeLowering(VectorTransposeLowering opt) { - vectorTransposeLowering = opt; - return *this; - } - /// Option to control the splitting of vector transfers. - VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None; - VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) { - vectorTransferSplit = opt; - return *this; - } -}; +struct VectorTransformsOptions; /// Options that control the vector unrolling. struct UnrollVectorOptions { @@ -109,45 +74,6 @@ // Vector transformation exposed as populate functions over rewrite patterns. //===----------------------------------------------------------------------===// -/// Insert TransposeLowering patterns into extraction/insertion. -void populateVectorTransposeLoweringPatterns( - RewritePatternSet &patterns, - VectorTransformsOptions options = VectorTransformsOptions(), - PatternBenefit benefit = 1); - -/// Collect a set of patterns to convert vector.multi_reduction op into -/// a sequence of vector.reduction ops. The patterns comprise: -/// - InnerOuterDimReductionConversion: rewrites vector.multi_reduction such -/// that all reduction dimensions are either innermost or outermost, by adding -/// the proper vector.transpose operations. -/// - ReduceMultiDimReductionRank: once in innermost or outermost reduction -/// form, rewrites n-D vector.multi_reduction into 2-D vector.multi_reduction, -/// by introducing vector.shape_cast ops to collapse + multi-reduce + expand -/// back. -/// - TwoDimMultiReductionToElementWise: once in 2-D vector.multi_reduction -/// form, with an **outermost** reduction dimension, unroll the outer dimension -/// to obtain a sequence of 1-D vector ops. This also has an opportunity for -/// tree-reduction (in the future). 
-/// - TwoDimMultiReductionToReduction: once in 2-D vector.multi_reduction form, -/// with an **innermost** reduction dimension, unroll the outer dimension to -/// obtain a sequence of extract + vector.reduction + insert. This can further -/// lower to horizontal reduction ops. -/// - OneDimMultiReductionToTwoDim: for cases that reduce to 1-D vector -/// reduction (and are thus missing either a parallel or a reduction), we lift -/// them back up to 2-D with a simple vector.shape_cast to vector<1xk> so that -/// the other patterns can kick in, thus fully exiting out of the -/// vector.multi_reduction abstraction. -void populateVectorMultiReductionLoweringPatterns( - RewritePatternSet &patterns, VectorMultiReductionLowering options, - PatternBenefit benefit = 1); - -/// Collects patterns to progressively lower vector contraction ops on high-D -/// into low-D reduction and product ops. -void populateVectorContractLoweringPatterns( - RewritePatternSet &patterns, - VectorTransformsOptions options = VectorTransformsOptions(), - PatternBenefit benefit = 1); - /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul /// semantics to a contraction with MMT semantics (matrix matrix multiplication /// with the RHS transposed). This specific form is meant to have the vector @@ -174,67 +100,43 @@ void populateVectorReductionToContractPatterns(RewritePatternSet &patterns, PatternBenefit benefit = 1); -/// Collect patterns to convert scan op -void populateVectorScanLoweringPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - -//===----------------------------------------------------------------------===// -// Vector.transfer patterns. -//===----------------------------------------------------------------------===// -/// Collect a set of transfer read/write lowering patterns that simplify the -/// permutation map (e.g., converting it to a minor identity map) by inserting -/// broadcasts and transposes. 
More specifically: -/// -/// [TransferReadPermutationLowering] -/// Lower transfer_read op with permutation into a transfer_read with a -/// permutation map composed of leading zeros followed by a minor identity + -/// vector.transpose op. -/// Ex: -/// vector.transfer_read ... -/// permutation_map: (d0, d1, d2) -> (0, d1) -/// into: -/// %v = vector.transfer_read ... -/// permutation_map: (d0, d1, d2) -> (d1, 0) -/// vector.transpose %v, [1, 0] +/// Populate `patterns` with the following patterns. /// -/// vector.transfer_read ... -/// permutation_map: (d0, d1, d2, d3) -> (0, 0, 0, d1, d3) -/// into: -/// %v = vector.transfer_read ... -/// permutation_map: (d0, d1, d2, d3) -> (0, 0, d1, 0, d3) -/// vector.transpose %v, [0, 1, 3, 2, 4] -/// Note that an alternative is to transform it to linalg.transpose + -/// vector.transfer_read to do the transpose in memory instead. +/// - VectorTransferFullPartialRewriter /// -/// [TransferWritePermutationLowering] -/// Lower transfer_write op with permutation into a transfer_write with a -/// minor identity permutation map. (transfer_write ops cannot have broadcasts.) -/// Ex: -/// vector.transfer_write %v ... -/// permutation_map: (d0, d1, d2) -> (d2, d0, d1) -/// into: -/// %tmp = vector.transpose %v, [2, 0, 1] -/// vector.transfer_write %tmp ... -/// permutation_map: (d0, d1, d2) -> (d0, d1, d2) +/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds +/// masking) fast path and a slow path. /// -/// vector.transfer_write %v ... -/// permutation_map: (d0, d1, d2, d3) -> (d3, d2) -/// into: -/// %tmp = vector.transpose %v, [1, 0] -/// %v = vector.transfer_write %tmp ... 
-/// permutation_map: (d0, d1, d2, d3) -> (d2, d3) +/// Example (a 2-D vector.transfer_read): +/// ``` +/// %1 = vector.transfer_read %0[...], %pad : memref, vector<...> +/// ``` +/// is transformed into: +/// ``` +/// %1:3 = scf.if (%inBounds) { +/// // fast path, direct cast +/// memref.cast %A: memref to compatibleMemRefType +/// scf.yield %view : compatibleMemRefType, index, index +/// } else { +/// // slow path, not in-bounds vector.transfer or linalg.copy. +/// memref.cast %alloc: memref to compatibleMemRefType +/// scf.yield %4 : compatibleMemRefType, index, index +/// } +/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]} +/// ``` +/// where `alloc` is a top of the function alloca'ed buffer of one vector. /// -/// [TransferOpReduceRank] -/// Lower transfer_read op with broadcast in the leading dimensions into -/// transfer_read of lower rank + vector.broadcast. -/// Ex: vector.transfer_read ... -/// permutation_map: (d0, d1, d2, d3) -> (0, d1, 0, d3) -/// into: -/// %v = vector.transfer_read ... -/// permutation_map: (d0, d1, d2, d3) -> (d1, 0, d3) -/// vector.broadcast %v -void populateVectorTransferPermutationMapLoweringPatterns( - RewritePatternSet &patterns, PatternBenefit benefit = 1); +/// Preconditions: +/// 1. `xferOp.permutation_map()` must be a minor identity map +/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()` +/// must be equal. This will be relaxed in the future but requires +/// rank-reducing subviews. +void populateVectorTransferFullPartialPatterns( + RewritePatternSet &patterns, const VectorTransformsOptions &options); + +//===----------------------------------------------------------------------===// +// Vector.transfer patterns. +//===----------------------------------------------------------------------===// /// Collect a set of patterns to reduce the rank of the operands of vector /// transfer ops to operate on the largest contigious vector.
@@ -334,220 +236,6 @@ const UnrollVectorOptions &options, PatternBenefit benefit = 1); -/// Expands `vector.gather` ops into a series of conditional scalar loads -/// (`vector.load` for memrefs or `tensor.extract` for tensors). These loads are -/// conditional to avoid out-of-bounds memory accesses and guarded with `scf.if` -/// ops. This lowering path is intended for targets that do not feature -/// dedicated gather ops. -void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns, - PatternBenefit benefit = 1); - -//===----------------------------------------------------------------------===// -// Finer-grained patterns exposed for more control over individual lowerings. -//===----------------------------------------------------------------------===// -/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern -/// may take an extra filter to perform selection at a finer granularity. -struct VectorTransferFullPartialRewriter : public RewritePattern { - using FilterConstraintType = - std::function; - - explicit VectorTransferFullPartialRewriter( - MLIRContext *context, - VectorTransformsOptions options = VectorTransformsOptions(), - FilterConstraintType filter = - [](VectorTransferOpInterface op) { return success(); }, - PatternBenefit benefit = 1) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options), - filter(std::move(filter)) {} - - /// Performs the rewrite. 
- LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override; - -private: - VectorTransformsOptions options; - FilterConstraintType filter; -}; - -/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul -/// semantics to: -/// ``` -/// %flattened_a = vector.shape_cast %a -/// %flattened_b = vector.shape_cast %b -/// %flattened_d = vector.matmul %flattened_a, %flattened_b -/// %d = vector.shape_cast %%flattened_d -/// %e = add %c, %d -/// ``` -/// `vector.matmul` later lowers to `llvm.matrix.multiply`. -// -/// This only kicks in when VectorTransformsOptions is set to OuterProduct and -/// the vector.contract op is a row-major matrix multiply. -class ContractionOpToMatmulOpLowering - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - using FilterConstraintType = - std::function; - - static LogicalResult defaultFilter(vector::ContractionOp op) { - return success(); - } - - ContractionOpToMatmulOpLowering( - vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, PatternBenefit benefit = 1, - FilterConstraintType constraint = defaultFilter) - : OpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions), - filter(std::move(constraint)) {} - - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; - -private: - /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; - FilterConstraintType filter; -}; - -/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul -/// semantics to a reduction_size-unrolled sequence: -/// ``` -/// %at = vector.transpose %a, [1, 0] -/// %bRow0 = vector.extract %b[0] -/// %atRow0 = vector.extract %at[0] -/// %c0 = vector.outerproduct %atRow0, %bRow0, %c -/// ... 
-/// %bRowK = vector.extract %b[K] -/// %atRowK = vector.extract %at[K] -/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 -/// ``` -/// -/// This only kicks in when VectorTransformsOptions is set to OuterProduct and -/// the vector.contract op is a row-major matrix multiply. -class ContractionOpToOuterProductOpLowering - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - using FilterConstraintType = - std::function; - - static LogicalResult defaultFilter(vector::ContractionOp op) { - return success(); - } - - ContractionOpToOuterProductOpLowering( - vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, PatternBenefit benefit = 1, - FilterConstraintType constraint = defaultFilter) - : OpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions), - filter(std::move(constraint)) {} - - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; - -private: - /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; - FilterConstraintType filter; -}; - -/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul -/// semantics to an output-size-unrolled sequence: -/// ``` -/// %out = arith.constant ... : vector -/// %bt = vector.transpose %b, [1, 0] -/// %aRow0 = vector.extract %a[0] -/// %btRow0 = vector.extract %bt[0] -/// %c00 = vector.reduce %atRow0, %bRow0 -/// %out00 = vector.insert %c00, %out[0, 0] -/// ... -/// %aRowLast = vector.extract %at[M-1] -/// %btRowLast = vector.extract %b[N-1] -/// %cLastLast = vector.reduce %atRowLast, %bRowLast -/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1] -/// ``` -/// -/// This only kicks in when VectorTransformsOptions is set to Dot and -/// the vector.contract op is a row-major matmul or matvec. 
-class ContractionOpToDotLowering - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - using FilterConstraintType = - std::function; - - static LogicalResult defaultFilter(vector::ContractionOp op) { - return success(); - } - - ContractionOpToDotLowering( - vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, PatternBenefit benefit = 1, - const FilterConstraintType &constraint = defaultFilter) - : OpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} - - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; - -private: - /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; - FilterConstraintType filter; -}; - -/// Progressive lowering of ContractionOp. -/// -/// One: -/// %x = vector.contract with at least one free/batch dimension -/// is replaced by: -/// %a = vector.contract with one less free/batch dimension -/// %b = vector.contract with one less free/batch dimension -/// .. -/// %x = combine %a %b .. -/// until a pure contraction is reached (no free/batch dimensions), -/// which is replaced by a dot-product. -/// -/// This only kicks in when either VectorTransformsOptions is set -/// to Dot or when other contraction patterns fail. 
-class ContractionOpLowering : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - using FilterConstraintType = - std::function; - - static LogicalResult defaultFilter(vector::ContractionOp op) { - return success(); - } - - ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, PatternBenefit benefit = 1, - FilterConstraintType constraint = defaultFilter) - : OpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions), - filter(std::move(constraint)) {} - - LogicalResult matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; - -private: - /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; - FilterConstraintType filter; - // Lower one parallel dimension. - FailureOr lowerParallel(PatternRewriter &rewriter, - vector::ContractionOp op, int64_t lhsIndex, - int64_t rhsIndex, Value mask) const; - // Lower one reduction dimension. - FailureOr lowerReduction(PatternRewriter &rewriter, - vector::ContractionOp op, Value mask) const; -}; - } // namespace vector } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorTransforms.h @@ -24,17 +24,53 @@ namespace vector { +//===----------------------------------------------------------------------===// +// Vector transformation options exposed as auxiliary structs. +//===----------------------------------------------------------------------===// +/// Structure to control the behavior of vector transform patterns. +struct VectorTransformsOptions { + /// Option to control the lowering of vector.contract. 
+ VectorContractLowering vectorContractLowering = VectorContractLowering::Dot; + VectorTransformsOptions & + setVectorTransformsOptions(VectorContractLowering opt) { + vectorContractLowering = opt; + return *this; + } + /// Option to control the lowering of vector.multi_reduction. + VectorMultiReductionLowering vectorMultiReductionLowering = + VectorMultiReductionLowering::InnerParallel; + VectorTransformsOptions & + setVectorMultiReductionLowering(VectorMultiReductionLowering opt) { + vectorMultiReductionLowering = opt; + return *this; + } + /// Option to control the lowering of vector.transpose. + VectorTransposeLowering vectorTransposeLowering = + VectorTransposeLowering::EltWise; + VectorTransformsOptions & + setVectorTransposeLowering(VectorTransposeLowering opt) { + vectorTransposeLowering = opt; + return *this; + } + /// Option to control the splitting of vector transfers. + VectorTransferSplit vectorTransferSplit = VectorTransferSplit::None; + VectorTransformsOptions &setVectorTransferSplit(VectorTransferSplit opt) { + vectorTransferSplit = opt; + return *this; + } +}; + //===----------------------------------------------------------------------===// // Standalone transformations and helpers. //===----------------------------------------------------------------------===// -/// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds -/// masking) fastpath and a slowpath. -/// If `ifOp` is not null and the result is `success, the `ifOp` points to the -/// newly created conditional upon function return. -/// To accomodate for the fact that the original vector.transfer indexing may be -/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the -/// scf.if op returns a view and values of type index. -/// At this time, only vector.transfer_read case is implemented. +/// Split a vector.transfer operation into an in-bounds (i.e., no +/// out-of-bounds masking) fastpath and a slowpath. 
If `ifOp` is not null and +/// the result is `success`, the `ifOp` points to the newly created conditional +/// upon function return. To accommodate for the fact that the original +/// vector.transfer indexing may be arbitrary and the slow path indexes +/// @[0...0] in the temporary buffer, the scf.if op returns a view and values +/// of type index. At this time, only vector.transfer_read case is +/// implemented. /// /// Example (a 2-D vector.transfer_read): /// ``` @@ -51,15 +87,16 @@ /// memref.cast %alloc: memref to compatibleMemRefType /// scf.yield %4 : compatibleMemRefType, index, index // } -/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... true]} +/// %0 = vector.transfer_read %1#0[%1#1, %1#2] {in_bounds = [true ... +/// true]} /// ``` /// where `alloc` is a top of the function alloca'ed buffer of one vector. /// /// Preconditions: /// 1. `xferOp.permutation_map()` must be a minor identity map -/// 2. the rank of the `xferOp.memref()` and the rank of the `xferOp.vector()` -/// must be equal. This will be relaxed in the future but requires -/// rank-reducing subviews. +/// 2. the rank of the `xferOp.memref()` and the rank of the +/// `xferOp.vector()` must be equal. This will be relaxed in the future but +/// requires rank-reducing subviews.
LogicalResult splitFullAndPartialTransfer( RewriterBase &b, VectorTransferOpInterface xferOp, VectorTransformsOptions options = VectorTransformsOptions(), diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Dialect/X86Vector/X86VectorDialect.h" @@ -64,10 +65,11 @@ RewritePatternSet patterns(&getContext()); populateVectorToVectorCanonicalizationPatterns(patterns); populateVectorBroadcastLoweringPatterns(patterns); - populateVectorContractLoweringPatterns(patterns); + populateVectorContractLoweringPatterns(patterns, VectorTransformsOptions()); populateVectorMaskOpLoweringPatterns(patterns); populateVectorShapeCastLoweringPatterns(patterns); - populateVectorTransposeLoweringPatterns(patterns); + populateVectorTransposeLoweringPatterns(patterns, + VectorTransformsOptions()); // Vector transfer ops with rank > 1 should
be lowered with VectorToSCF. populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -10,8 +10,8 @@ // //===----------------------------------------------------------------------===// -#include #include +#include #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" @@ -20,6 +20,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/Builders.h" #include "mlir/IR/ImplicitLocOpBuilder.h" diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -26,6 +26,7 @@ #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Utils/Utils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -7,13 +7,14 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" - #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" #include
"mlir/Dialect/PDL/IR/PDL.h" #include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/Parser/Parser.h" @@ -82,10 +83,9 @@ // In the future we may want to more finely select particular stages. // Stage 1: contraction lowerings. - patterns.add(vectorTransformOptions, - ctx); + populateVectorContractLoweringPatterns( + patterns, vectorTransformOptions, /*benefit=*/1, + /*disableOuterProductLowering=*/true); vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); // Stage 2: multi-reduction lowerings. @@ -93,8 +93,7 @@ patterns, vectorTransformOptions.vectorMultiReductionLowering); // Stage 3: Rewrite vector.transfer into full and partial parts. - patterns.add( - ctx, vectorTransformOptions); + populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); // Stage 4: Lower vector transfers. vector::populateVectorTransferLoweringPatterns(patterns, maxTransferRank); @@ -107,8 +106,8 @@ vector::populateVectorShapeCastLoweringPatterns(patterns); // Stage 7: Lower vector.transpose.
- vector::populateVectorTransposeLoweringPatterns(patterns, - vectorTransformOptions); + vector::populateVectorTransposeLoweringPatterns( + patterns, vectorTransformOptions, /*benefit=*/1); if (getTransposeAvx2Lowering()) x86vector::avx2::populateSpecializedTransposeLoweringPatterns( patterns, avx2LoweringOptions, /*benefit=*/10); diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -1,14 +1,20 @@ add_mlir_dialect_library(MLIRVectorTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp + LowerVectorBroadcast.cpp + LowerVectorContract.cpp + LowerVectorGather.cpp LowerVectorMask.cpp + LowerVectorMultiReduction.cpp + LowerVectorScan.cpp + LowerVectorShapeCast.cpp + LowerVectorTransfer.cpp + LowerVectorTranspose.cpp VectorDistribute.cpp VectorDropLeadUnitDim.cpp VectorInsertExtractStridedSliceRewritePatterns.cpp - VectorMultiDimReductionTransforms.cpp VectorTransferOpTransforms.cpp VectorTransferSplitRewritePatterns.cpp - VectorTransferPermutationMapRewritePatterns.cpp VectorTransforms.cpp VectorUnroll.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -0,0 +1,156 @@ +//===- LowerVectorBroadcast.cpp - Lower 'vector.broadcast' operation ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.broadcast' operation.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/VectorInterfaces.h" +#include "mlir/Support/LogicalResult.h" + +#define DEBUG_TYPE "vector-broadcast-lowering" + +using namespace mlir; +using namespace mlir::vector; + +namespace { +/// Progressive lowering of BroadcastOp: each rewrite peels off one leading dimension of the result. +class BroadcastOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::BroadcastOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + VectorType dstType = op.getResultVectorType(); + VectorType srcType = op.getSourceType().dyn_cast(); + Type eltType = dstType.getElementType(); + + // Scalar to any vector can use splat. + if (!srcType) { + rewriter.replaceOpWithNewOp(op, dstType, op.getSource()); + return success(); + } + + // Determine rank of source and destination. + int64_t srcRank = srcType.getRank(); + int64_t dstRank = dstType.getRank(); + + // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
+ if (srcRank <= 1 && dstRank == 1) { + Value ext; + if (srcRank == 0) + ext = rewriter.create(loc, op.getSource()); + else + ext = rewriter.create(loc, op.getSource(), 0); + rewriter.replaceOpWithNewOp(op, dstType, ext); + return success(); + } + + // Duplicate this rank. + // For example: + // %x = broadcast %y : k-D to n-D, k < n + // becomes: + // %b = broadcast %y : k-D to (n-1)-D + // %x = [%b,%b,%b,%b] : n-D + // becomes: + // %b = [%y,%y] : (n-1)-D + // %x = [%b,%b,%b,%b] : n-D + if (srcRank < dstRank) { + // Duplication. + VectorType resType = + VectorType::get(dstType.getShape().drop_front(), eltType); + Value bcst = + rewriter.create(loc, resType, op.getSource()); + Value result = rewriter.create( + loc, dstType, rewriter.getZeroAttr(dstType)); + for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) + result = rewriter.create(loc, bcst, result, d); + rewriter.replaceOp(op, result); + return success(); + } + + // Find non-matching dimension, if any. + assert(srcRank == dstRank); + int64_t m = -1; + for (int64_t r = 0; r < dstRank; r++) + if (srcType.getDimSize(r) != dstType.getDimSize(r)) { + m = r; + break; + } + + // All dimensions are the same. Simply pass through. + if (m == -1) { + rewriter.replaceOp(op, op.getSource()); + return success(); + } + + // Any non-matching dimension forces a stretch along this rank. + // For example: + // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32> + // becomes: + // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32> + // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32> + // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32> + // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32> + // %x = [%a,%b,%c,%d] + // becomes: + // %u = broadcast %y[0][0] : vector<2xf32> to vector<2xf32> + // %v = broadcast %y[1][0] : vector<2xf32> to vector<2xf32> + // %a = [%u, %u] + // %b = [%v, %v] ..
+ // %x = [%a,%b,%c,%d] + VectorType resType = + VectorType::get(dstType.getShape().drop_front(), eltType); + Value result = rewriter.create( + loc, dstType, rewriter.getZeroAttr(dstType)); + if (m == 0) { + // Stretch at start: broadcast source[0] into every slice of dim 0. + Value ext = rewriter.create(loc, op.getSource(), 0); + Value bcst = rewriter.create(loc, resType, ext); + for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) + result = rewriter.create(loc, bcst, result, d); + } else { + // Stretch not at start: broadcast each source[d] into slice d. + for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) { + Value ext = rewriter.create(loc, op.getSource(), d); + Value bcst = rewriter.create(loc, resType, ext); + result = rewriter.create(loc, bcst, result, d); + } + } + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +void mlir::vector::populateVectorBroadcastLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -0,0 +1,1329 @@ +//===- LowerVectorContract.cpp - Lower 'vector.contract' operation --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.contract' operation.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/VectorInterfaces.h" +#include "mlir/Support/LogicalResult.h" + +#define DEBUG_TYPE "vector-contract-lowering" + +using namespace mlir; +using namespace mlir::vector; + +//===----------------------------------------------------------------------===// +// Helper functions +//===----------------------------------------------------------------------===// + +// Helper returning the position of the result of `map` that is dimension `index`, or std::nullopt if there is none. +static std::optional getResultIndex(AffineMap map, int64_t index) { + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + int64_t idx = map.getDimPosition(i); + if (idx == index) + return i; + } + return std::nullopt; +} + +// Helper returning a copy of `iteratorTypes` with the entry at position `index` removed. +static SmallVector adjustIter(ArrayAttr iteratorTypes, + int64_t index) { + SmallVector results; + for (const auto &it : llvm::enumerate(iteratorTypes)) { + int64_t idx = it.index(); + if (idx == index) + continue; + results.push_back(it.value()); + } + return results; +} + +// Helper to construct an affine map with dimension `index` removed and later dimensions renumbered.
+static AffineMap adjustMap(AffineMap map, int64_t index, + PatternRewriter &rewriter) { + auto *ctx = rewriter.getContext(); + SmallVector results; + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + int64_t idx = map.getDimPosition(i); + if (idx == index) + continue; + // Re-insert remaining indices, but renamed when occurring + // after the removed index. + auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx); + results.push_back(targetExpr); + } + return AffineMap::get(map.getNumDims() - 1, 0, results, ctx); +} + +// Extracts position `pos` along dimension `index` of `val` (of type `type`), dropping that dimension; returns `val` unchanged when `index` is -1. +// TODO +static Value reshapeLoad(Location loc, Value val, VectorType type, + int64_t index, int64_t pos, + PatternRewriter &rewriter) { + if (index == -1) + return val; + Type lowType = VectorType::Builder(type).dropDim(0); + // At extraction dimension? + if (index == 0) { + auto posAttr = rewriter.getI64ArrayAttr(pos); + return rewriter.create(loc, lowType, val, posAttr); + } + // Unroll leading dimensions. + VectorType vType = lowType.cast(); + Type resType = VectorType::Builder(type).dropDim(index); + auto resVectorType = resType.cast(); + Value result = rewriter.create( + loc, resVectorType, rewriter.getZeroAttr(resVectorType)); + for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) { + auto posAttr = rewriter.getI64ArrayAttr(d); + Value ext = rewriter.create(loc, vType, val, posAttr); + Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); + result = rewriter.create(loc, resVectorType, load, result, + posAttr); + } + return result; +} + +// Inserts `val` at position `pos` along dimension `index` of `result` (of type `type`); returns `val` unchanged when `index` is -1. +// TODO +static Value reshapeStore(Location loc, Value val, Value result, + VectorType type, int64_t index, int64_t pos, + PatternRewriter &rewriter) { + // Unmodified? + if (index == -1) + return val; + // At insertion dimension?
+ if (index == 0) { + auto posAttr = rewriter.getI64ArrayAttr(pos); + return rewriter.create(loc, type, val, result, posAttr); + } + // Unroll leading dimensions. + Type lowType = VectorType::Builder(type).dropDim(0); + VectorType vType = lowType.cast(); + Type insType = VectorType::Builder(vType).dropDim(0); + for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { + auto posAttr = rewriter.getI64ArrayAttr(d); + Value ext = rewriter.create(loc, vType, result, posAttr); + Value ins = rewriter.create(loc, insType, val, posAttr); + Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); + result = rewriter.create(loc, type, sto, result, posAttr); + } + return result; +} + +/// Helper to create arithmetic operation associated with a kind of contraction; returns std::nullopt when `kind` is invalid for the element type selected by `isInt`. +static std::optional +createContractArithOp(Location loc, Value x, Value y, Value acc, + vector::CombiningKind kind, PatternRewriter &rewriter, + bool isInt, Value mask = Value()) { + using vector::CombiningKind; + Value mul; + + if (isInt) { + if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF) + // Only valid for floating point types. + return std::nullopt; + mul = rewriter.create(loc, x, y); + } else { + // Float case. + if (kind == CombiningKind::AND || kind == CombiningKind::MINUI || + kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI || + kind == CombiningKind::MAXSI || kind == CombiningKind::OR || + kind == CombiningKind::XOR) + // Only valid for integer types. + return std::nullopt; + // Special case for fused multiply-add. + if (acc && acc.getType().isa() && kind == CombiningKind::ADD) { + Value fma = rewriter.create(loc, x, y, acc); + if (mask) + // The fma op doesn't need explicit masking. However, fma ops used in + // reductions must preserve previous 'acc' values for masked-out lanes.
+ fma = selectPassthru(rewriter, mask, fma, acc); + return fma; + } + mul = rewriter.create(loc, x, y); + } + + if (!acc) + return std::optional(mul); + + return makeArithReduction(rewriter, loc, kind, mul, acc, mask); +} + +/// Return the result positions in `map` whose dimensions are reduction iterators. +static SmallVector getReductionIndex(AffineMap map, + ArrayAttr iteratorTypes) { + SmallVector dimsIdx; + for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { + if (isReductionIterator(iteratorTypes[map.getDimPosition(i)])) + dimsIdx.push_back(i); + } + return dimsIdx; +} + +/// Look for a given dimension in an affine map and return its position. Return +/// std::nullopt if the dimension is not in the map results. +static std::optional getDimPosition(AffineMap map, unsigned dim) { + for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { + if (map.getDimPosition(i) == dim) + return i; + } + return std::nullopt; +} + +/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using +/// operands `x` and `y`. +static Value createAdd(Location loc, Value x, Value y, bool isInt, + PatternRewriter &rewriter) { + if (isInt) + return rewriter.create(loc, x, y); + return rewriter.create(loc, x, y); +} + +/// Creates a MulIOp if `isInt` is true otherwise create a MulFOp using +/// operands `x` and `y`. +static Value createMul(Location loc, Value x, Value y, bool isInt, + PatternRewriter &rewriter) { + if (isInt) + return rewriter.create(loc, x, y); + return rewriter.create(loc, x, y); +} + +namespace { + +/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to: +/// ``` +/// %flattened_a = vector.shape_cast %a +/// %flattened_b = vector.shape_cast %b +/// %flattened_d = vector.matmul %flattened_a, %flattened_b +/// %d = vector.shape_cast %flattened_d +/// %e = add %c, %d +/// ``` +/// `vector.matmul` later lowers to `llvm.matrix.multiply`.
+// +/// This presumably only kicks in when VectorTransformsOptions is set to Matmul (TODO: confirm in matchAndRewrite) and +/// the vector.contract op is a row-major matrix multiply. +class ContractionOpToMatmulOpLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + using FilterConstraintType = + std::function; + + static LogicalResult defaultFilter(vector::ContractionOp op) { + return success(); + } + + ContractionOpToMatmulOpLowering( + vector::VectorTransformsOptions vectorTransformOptions, + MLIRContext *context, PatternBenefit benefit = 1, + FilterConstraintType constraint = defaultFilter) + : OpRewritePattern(context, benefit), + vectorTransformOptions(vectorTransformOptions), + filter(std::move(constraint)) {} + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; + +private: + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformOptions; + FilterConstraintType filter; // Caller-provided match constraint; `defaultFilter` always succeeds. +}; + +/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to a reduction_size-unrolled sequence: +/// ``` +/// %at = vector.transpose %a, [1, 0] +/// %bRow0 = vector.extract %b[0] +/// %atRow0 = vector.extract %at[0] +/// %c0 = vector.outerproduct %atRow0, %bRow0, %c +/// ... +/// %bRowK = vector.extract %b[K] +/// %atRowK = vector.extract %at[K] +/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 +/// ``` +/// +/// This only kicks in when VectorTransformsOptions is set to OuterProduct and +/// the vector.contract op is a matmat, matvec or tmatvec with a supported layout.
+class ContractionOpToOuterProductOpLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + using FilterConstraintType = + std::function; + + static LogicalResult defaultFilter(vector::ContractionOp op) { + return success(); + } + + ContractionOpToOuterProductOpLowering( + vector::VectorTransformsOptions vectorTransformOptions, + MLIRContext *context, PatternBenefit benefit = 1, + FilterConstraintType constraint = defaultFilter) + : OpRewritePattern(context, benefit), + vectorTransformOptions(vectorTransformOptions), + filter(std::move(constraint)) {} + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; + +private: + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformOptions; + FilterConstraintType filter; // Caller-provided match constraint, checked in matchAndRewrite. +}; + +/// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to an output-size-unrolled sequence: +/// ``` +/// %out = arith.constant ... : vector +/// %bt = vector.transpose %b, [1, 0] +/// %aRow0 = vector.extract %a[0] +/// %btRow0 = vector.extract %bt[0] +/// %c00 = vector.reduce %aRow0, %btRow0 +/// %out00 = vector.insert %c00, %out[0, 0] +/// ... +/// %aRowLast = vector.extract %a[M-1] +/// %btRowLast = vector.extract %bt[N-1] +/// %cLastLast = vector.reduce %aRowLast, %btRowLast +/// %outcLastLast = vector.insert %cLastLast, %out[M-1, N-1] +/// ``` +/// +/// This only kicks in when VectorTransformsOptions is set to Dot and +/// the vector.contract op is a row-major matmul or matvec.
+class ContractionOpToDotLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + using FilterConstraintType = + std::function; + + static LogicalResult defaultFilter(vector::ContractionOp op) { + return success(); + } + + ContractionOpToDotLowering( + vector::VectorTransformsOptions vectorTransformOptions, + MLIRContext *context, PatternBenefit benefit = 1, + FilterConstraintType constraint = defaultFilter) + : OpRewritePattern(context, benefit), + vectorTransformOptions(vectorTransformOptions), filter(std::move(constraint)) {} + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; + +private: + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformOptions; + FilterConstraintType filter; // Caller-provided match constraint, checked in matchAndRewrite. +}; + +/// Progressive lowering of ContractionOp. +/// +/// One: +/// %x = vector.contract with at least one free/batch dimension +/// is replaced by: +/// %a = vector.contract with one less free/batch dimension +/// %b = vector.contract with one less free/batch dimension +/// .. +/// %x = combine %a %b .. +/// until a pure contraction is reached (no free/batch dimensions), +/// which is replaced by a dot-product. +/// +/// This only kicks in when either VectorTransformsOptions is set +/// to Dot or when other contraction patterns fail.
+class ContractionOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + using FilterConstraintType = + std::function; + + static LogicalResult defaultFilter(vector::ContractionOp op) { + return success(); + } + + ContractionOpLowering(vector::VectorTransformsOptions vectorTransformOptions, + MLIRContext *context, PatternBenefit benefit = 1, + FilterConstraintType constraint = defaultFilter) + : OpRewritePattern(context, benefit), + vectorTransformOptions(vectorTransformOptions), + filter(std::move(constraint)) {} + + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; + +private: + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformOptions; + FilterConstraintType filter; // Caller-provided match constraint; `defaultFilter` always succeeds. + // Lower one parallel dimension. + FailureOr lowerParallel(PatternRewriter &rewriter, + vector::ContractionOp op, int64_t lhsIndex, + int64_t rhsIndex, Value mask) const; + // Lower one reduction dimension. + FailureOr lowerReduction(PatternRewriter &rewriter, + vector::ContractionOp op, Value mask) const; +}; + +/// Generate a vector implementation for matmat, matvec and tmatvec. +/// This unrolls outer-products along the reduction dimension.
+struct UnrolledOuterProductGenerator + : public StructuredGenerator { + UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op) + : StructuredGenerator(b, op), + kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()), + res(op.getAcc()), lhsType(op.getLhsType()) { + auto maskableOp = cast(op.getOperation()); + if (maskableOp.isMasked()) + mask = maskableOp.getMaskingOp().getMask(); + } + + Value t(Value v, ArrayRef perm = {1, 0}) { + if (!v) + return v; + return rewriter.create(loc, v, perm); + } + + Value promote(Value v, Type dstElementType) { + Type elementType = v.getType(); + auto vecType = elementType.dyn_cast(); + if (vecType) + elementType = vecType.getElementType(); + if (elementType == dstElementType) + return v; + Type promotedType = dstElementType; + if (vecType) + promotedType = VectorType::get(vecType.getShape(), promotedType); + if (dstElementType.isa()) + return rewriter.create(loc, promotedType, v); + return rewriter.create(loc, promotedType, v); + } + + FailureOr outerProd(Value lhs, Value rhs, Value res, int reductionSize, + std::optional maybeMask = std::nullopt) { + assert(reductionSize > 0); + // Incremental support for masking. + if (mask && !maybeMask.has_value()) + return failure(); + + Type resElementType = res.getType().cast().getElementType(); + for (int64_t k = 0; k < reductionSize; ++k) { + Value extractA = rewriter.create(loc, lhs, k); + Value extractB = rewriter.create(loc, rhs, k); + extractA = promote(extractA, resElementType); + extractB = promote(extractB, resElementType); + Value extractMask; + if (maybeMask.has_value() && maybeMask.value()) + extractMask = + rewriter.create(loc, maybeMask.value(), k); + + Operation *outerProdOp = rewriter.create( + loc, res.getType(), extractA, extractB, res, kind); + res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0); + } + return res; + } + + /// Two outer parallel, one inner reduction (matmat flavor). 
+ FailureOr matmat() { + if (!iters({Par(), Par(), Red()})) + return failure(); + // Set up the parallel/reduction structure in the right form. + AffineExpr m, n, k; + bindDims(rewriter.getContext(), m, n, k); + // Classical row-major matmul: Just permute the lhs. + if (layout({{m, k}, {k, n}, {m, n}})) + return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), + t(mask, {2, 0, 1})); + // TODO: may be better to fail and use some vector -> scalar reduction. + if (layout({{m, k}, {n, k}, {m, n}})) { + Value tlhs = t(lhs); + return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1)); + } + // No need to permute anything. + if (layout({{k, m}, {k, n}, {m, n}})) + return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); + // Just permute the rhs. + if (layout({{k, m}, {n, k}, {m, n}})) + return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0)); + // Transposed output: swap RHS and LHS. + // Classical row-major matmul: permute the lhs. + if (layout({{m, k}, {k, n}, {n, m}})) + return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1)); + // TODO: may be better to fail and use some vector -> scalar reduction. + if (layout({{m, k}, {n, k}, {n, m}})) { + Value trhs = t(rhs); + return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1)); + } + if (layout({{k, m}, {k, n}, {n, m}})) + return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); + if (layout({{k, m}, {n, k}, {n, m}})) + return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); + return failure(); + } + + /// One outer parallel, one inner reduction (matvec flavor) + FailureOr matvec() { + if (!iters({Par(), Red()})) + return failure(); + AffineExpr m, k; + bindDims(rewriter.getContext(), m, k); + + // Case mat-vec: transpose. + if (layout({{m, k}, {k}, {m}})) + return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask)); + // Case mat-trans-vec: ready to go. + if (layout({{k, m}, {k}, {m}})) + return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); + // Case vec-mat: swap and transpose. 
+ if (layout({{k}, {m, k}, {m}})) + return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); + // Case vec-mat-trans: swap and ready to go. + if (layout({{k}, {k, m}, {m}})) + return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); + return failure(); + } + + // + // One outer reduction, one inner parallel (tmatvec flavor) + // + FailureOr tmatvec() { + if (!iters({Red(), Par()})) + return failure(); + AffineExpr k, m; + bindDims(rewriter.getContext(), k, m); + + // Case mat-vec: transpose. + if (layout({{m, k}, {k}, {m}})) + return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); + // Case mat-trans-vec: ready to go. + if (layout({{k, m}, {k}, {m}})) + return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); + // Case vec-mat: swap and transpose. + if (layout({{k}, {m, k}, {m}})) + return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); + // Case vec-mat-trans: swap and ready to go. + if (layout({{k}, {k, m}, {m}})) + return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); + return failure(); + } + +private: + vector::CombiningKind kind; + Value lhs, rhs, res, mask; + VectorType lhsType; +}; + +/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to a reduction_size-unrolled sequence: +/// ``` +/// %at = vector.transpose %a, [1, 0] +/// %bRow0 = vector.extract %b[0] +/// %atRow0 = vector.extract %at[0] +/// %c0 = vector.outerproduct %atRow0, %bRow0, %c +/// ... +/// %bRowK = vector.extract %b[K] +/// %atRowK = vector.extract %at[K] +/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 +/// ``` +/// +/// This only kicks in when VectorTransformsOptions is set to OuterProduct but +/// otherwise supports any layout permutation of the matrix-multiply. +LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( + vector::ContractionOp op, PatternRewriter &rewriter) const { + // TODO: Remove native masks from contraction op? 
+ if (!op.getMasks().empty()) + return failure(); + + if (vectorTransformOptions.vectorContractLowering != + vector::VectorContractLowering::OuterProduct) + return failure(); + + if (failed(filter(op))) + return failure(); + + // Vector mask setup. + OpBuilder::InsertionGuard guard(rewriter); + auto maskableOp = cast(op.getOperation()); + Operation *rootOp; + if (maskableOp.isMasked()) { + rewriter.setInsertionPoint(maskableOp.getMaskingOp()); + rootOp = maskableOp.getMaskingOp(); + } else { + rootOp = op; + } + + UnrolledOuterProductGenerator e(rewriter, op); + FailureOr matmatRes = e.matmat(); + if (succeeded(matmatRes)) { + rewriter.replaceOp(rootOp, *matmatRes); + return success(); + } + FailureOr matvecRes = e.matvec(); + if (succeeded(matvecRes)) { + rewriter.replaceOp(rootOp, *matvecRes); + return success(); + } + FailureOr tmatvecRes = e.tmatvec(); + if (succeeded(tmatvecRes)) { + rewriter.replaceOp(rootOp, *tmatvecRes); + return success(); + } + + return failure(); +} + +LogicalResult +ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const { + // TODO: Support vector.mask. + auto maskableOp = cast(op.getOperation()); + if (maskableOp.isMasked()) + return failure(); + + // TODO: Remove native masks from contraction op? 
+ if (!op.getMasks().empty()) + return failure(); + + if (failed(filter(op))) + return failure(); + + if (vectorTransformOptions.vectorContractLowering != + vector::VectorContractLowering::Dot) + return failure(); + + auto iteratorTypes = op.getIteratorTypes().getValue(); + static constexpr std::array perm = {1, 0}; + Location loc = op.getLoc(); + Value lhs = op.getLhs(), rhs = op.getRhs(); + + using MapList = ArrayRef>; + auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; + AffineExpr m, n, k; + bindDims(rewriter.getContext(), m, n, k); + SmallVector maps = op.getIndexingMapsArray(); + // + // In the following we wish to make the reduction dimension innermost so we + // can load vectors and just fmul + reduce into a scalar. + // + if (isParallelIterator(iteratorTypes[0]) && + isParallelIterator(iteratorTypes[1]) && + isReductionIterator(iteratorTypes[2])) { + // + // Two outer parallel, one inner reduction (matmat flavor). + // + if (maps == infer({{m, k}, {k, n}, {m, n}})) { + rhs = rewriter.create(loc, rhs, perm); + } else if (maps == infer({{m, k}, {n, k}, {m, n}})) { + // No need to permute anything. + } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { + lhs = rewriter.create(loc, lhs, perm); + rhs = rewriter.create(loc, rhs, perm); + } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { + // This is the classical row-major matmul. Just permute the lhs. 
+ Value tmp = lhs; + lhs = rewriter.create(loc, rhs, perm); + rhs = tmp; + } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { + std::swap(lhs, rhs); + } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { + Value tmp = lhs; + lhs = rewriter.create(loc, rhs, perm); + rhs = rewriter.create(loc, tmp, perm); + } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { + Value tmp = rhs; + rhs = rewriter.create(loc, lhs, perm); + lhs = tmp; + } else { + return failure(); + } + } else if (isParallelIterator(iteratorTypes[0]) && + isReductionIterator(iteratorTypes[1])) { + // + // One outer parallel, one inner reduction (matvec flavor) + // + if (maps == infer({{m, n}, {n}, {m}})) { + // No need to permute anything. + } else if (maps == infer({{n, m}, {n}, {m}})) { + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{n}, {m, n}, {m}})) { + std::swap(lhs, rhs); + } else if (maps == infer({{n}, {n, m}, {m}})) { + std::swap(lhs, rhs); + lhs = rewriter.create(loc, lhs, perm); + } else { + return failure(); + } + } else { + return failure(); + } + + VectorType dstType = op.getResultType().cast(); + assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 && + "Expected dst type of rank 1 or 2"); + + unsigned rank = dstType.getRank(); + unsigned dstRows = dstType.getShape()[0]; + unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1]; + + // ExtractOp does not allow dynamic indexing, we must unroll explicitly. + Value res = rewriter.create(loc, dstType, + rewriter.getZeroAttr(dstType)); + bool isInt = dstType.getElementType().isa(); + for (unsigned r = 0; r < dstRows; ++r) { + Value a = rewriter.create(op.getLoc(), lhs, r); + for (unsigned c = 0; c < dstColumns; ++c) { + Value b = rank == 1 + ? rhs + : rewriter.create(op.getLoc(), rhs, c); + Value m = createMul(op.getLoc(), a, b, isInt, rewriter); + Value reduced = rewriter.create( + op.getLoc(), vector::CombiningKind::ADD, m); + + SmallVector pos = rank == 1 ? 
SmallVector{r} + : SmallVector{r, c}; + res = rewriter.create(op.getLoc(), reduced, res, pos); + } + } + if (auto acc = op.getAcc()) + res = createAdd(op.getLoc(), res, acc, isInt, rewriter); + rewriter.replaceOp(op, res); + return success(); +} + +/// Lower vector.contract with all size one reduction dimensions to +/// elementwise ops when possible. +struct ContractOpToElementwise + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + using FilterConstraintType = + std::function; + static LogicalResult defaultFilter(vector::ContractionOp op) { + return success(); + } + ContractOpToElementwise( + vector::VectorTransformsOptions vectorTransformOptions, + MLIRContext *context, PatternBenefit benefit = 1, + const FilterConstraintType &constraint = defaultFilter) + : OpRewritePattern(context, benefit), + vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + // TODO: Support vector.mask. + auto maskableOp = cast(contractOp.getOperation()); + if (maskableOp.isMasked()) + return failure(); + + // TODO: Remove native masks from contraction op? + if (!contractOp.getMasks().empty()) + return failure(); + + if (failed(filter(contractOp))) + return failure(); + + if (vectorTransformOptions.vectorContractLowering != + vector::VectorContractLowering::ParallelArith) + return failure(); + + ArrayRef lhsShape = contractOp.getLhsType().getShape(); + ArrayRef rhsShape = contractOp.getRhsType().getShape(); + AffineMap lhsMap = contractOp.getIndexingMapsArray()[0]; + AffineMap rhsMap = contractOp.getIndexingMapsArray()[1]; + SmallVector lhsReductionDims = + getReductionIndex(lhsMap, contractOp.getIteratorTypes()); + SmallVector rhsReductionDims = + getReductionIndex(rhsMap, contractOp.getIteratorTypes()); + // All the reduction dimensions must be a size 1. 
+ for (int64_t dim : lhsReductionDims) { + if (lhsShape[dim] != 1) + return failure(); + } + for (int64_t dim : rhsReductionDims) { + if (rhsShape[dim] != 1) + return failure(); + } + AffineMap accMap = contractOp.getIndexingMapsArray()[2]; + unsigned numParallelDims = accMap.getNumResults(); + unsigned numLhsDimToBroadcast = + numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size()); + unsigned numRhsDimToBroadcast = + numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size()); + SmallVector lhsDims; + SmallVector lhsTranspose; + SmallVector rhsDims; + SmallVector rhsTranspose; + for (int64_t dim : lhsReductionDims) + lhsTranspose.push_back(numLhsDimToBroadcast + dim); + for (int64_t dim : rhsReductionDims) + rhsTranspose.push_back(numRhsDimToBroadcast + dim); + // Loop through the parallel dimensions to calculate the dimensions to + // broadcast and to permute in order to extract only parallel dimensions. + for (unsigned i = 0; i < numParallelDims; i++) { + std::optional lhsDim = + getDimPosition(lhsMap, accMap.getDimPosition(i)); + if (lhsDim) { + lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim); + } else { + // If the parallel dimension doesn't exist we will have to broadcast it. + lhsDims.push_back( + contractOp.getResultType().cast().getDimSize(i)); + lhsTranspose.push_back(lhsDims.size() - 1); + } + std::optional rhsDim = + getDimPosition(rhsMap, accMap.getDimPosition(i)); + if (rhsDim) { + rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim); + } else { + // If the parallel dimension doesn't exist we will have to broadcast it. 
+ rhsDims.push_back( + contractOp.getResultType().cast().getDimSize(i)); + rhsTranspose.push_back(rhsDims.size() - 1); + } + } + Value newLhs = contractOp.getLhs(); + Value newRhs = contractOp.getRhs(); + Location loc = contractOp.getLoc(); + if (!lhsDims.empty()) { + lhsDims.append(lhsShape.begin(), lhsShape.end()); + auto expandedType = + VectorType::get(lhsDims, contractOp.getLhsType().getElementType()); + newLhs = rewriter.create(loc, expandedType, newLhs); + } + if (!rhsDims.empty()) { + rhsDims.append(rhsShape.begin(), rhsShape.end()); + auto expandedType = + VectorType::get(rhsDims, contractOp.getRhsType().getElementType()); + newRhs = rewriter.create(loc, expandedType, newRhs); + } + bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex(); + newLhs = rewriter.create(loc, newLhs, lhsTranspose); + newRhs = rewriter.create(loc, newRhs, rhsTranspose); + SmallVector lhsOffsets(lhsReductionDims.size(), 0); + SmallVector rhsOffsets(rhsReductionDims.size(), 0); + newLhs = rewriter.create( + loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets)); + newRhs = rewriter.create( + loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets)); + std::optional result = + createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(), + contractOp.getKind(), rewriter, isInt); + rewriter.replaceOp(contractOp, {*result}); + return success(); + } + +private: + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformOptions; + FilterConstraintType filter; +}; + +/// Progressive lowering of ContractionOp. +/// One: +/// %x = vector.contract with at least one free/batch dimension +/// is replaced by: +/// %a = vector.contract with one less free/batch dimension +/// %b = vector.contract with one less free/batch dimension +/// .. +/// %x = combine %a %b .. +/// until a pure contraction is reached (no free/batch dimensions), +/// which is replaced by a dot-product. 
+/// +/// This only kicks in when either VectorTransformsOptions is set +/// to DOT or when other contraction patterns fail. +// +// TODO: break down into transpose/reshape/cast ops +// when they become available to avoid code dup +// TODO: investigate lowering order impact on performance +LogicalResult +ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const { + // TODO: Remove native masks from contraction op? + if (!op.getMasks().empty()) + return failure(); + + if (failed(filter(op))) + return failure(); + + // TODO: support mixed mode contract lowering. + if (op.getLhsType().getElementType() != + getElementTypeOrSelf(op.getAccType()) || + op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) + return failure(); + + // TODO: the code below assumes the default contraction, make sure it supports + // other kinds before enabling this lowering. + if (op.getKind() != vector::CombiningKind::ADD) { + return rewriter.notifyMatchFailure( + op, "contractions other than 'add' not supported"); + } + + // TODO: implement benefits, cost models. + MLIRContext *ctx = op.getContext(); + ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); + if (succeeded(pat1.matchAndRewrite(op, rewriter))) + return success(); + ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); + if (succeeded(pat2.matchAndRewrite(op, rewriter))) + return success(); + ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); + if (succeeded(pat3.matchAndRewrite(op, rewriter))) + return success(); + ContractOpToElementwise pat4(vectorTransformOptions, ctx); + if (succeeded(pat4.matchAndRewrite(op, rewriter))) + return success(); + + // Vector mask setup. 
+ OpBuilder::InsertionGuard guard(rewriter); + Operation *rootOp = op; + Value mask; + if (op.isMasked()) { + rewriter.setInsertionPoint(op.getMaskingOp()); + rootOp = op.getMaskingOp(); + mask = op.getMaskingOp().getMask(); + } + + // Find first batch dimension in LHS/RHS, and lower when found. + std::vector> batchDimMap = op.getBatchDimMap(); + if (!batchDimMap.empty()) { + int64_t lhsIndex = batchDimMap[0].first; + int64_t rhsIndex = batchDimMap[0].second; + auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask); + if (failed(newOp)) + return failure(); + rewriter.replaceOp(rootOp, *newOp); + return success(); + } + + // Collect contracting dimensions. + std::vector> contractingDimMap = + op.getContractingDimMap(); + DenseSet lhsContractingDimSet; + DenseSet rhsContractingDimSet; + for (auto &dimPair : contractingDimMap) { + lhsContractingDimSet.insert(dimPair.first); + rhsContractingDimSet.insert(dimPair.second); + } + + // Find first free dimension in LHS, and lower when found. + VectorType lhsType = op.getLhsType(); + for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { + if (lhsContractingDimSet.count(lhsIndex) == 0) { + auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask); + if (failed(newOp)) + return failure(); + rewriter.replaceOp(rootOp, *newOp); + return success(); + } + } + + // Find first free dimension in RHS, and lower when found. + VectorType rhsType = op.getRhsType(); + for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { + if (rhsContractingDimSet.count(rhsIndex) == 0) { + auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask); + if (failed(newOp)) + return failure(); + rewriter.replaceOp(rootOp, *newOp); + return success(); + } + } + + // Lower the first remaining reduction dimension. 
+ if (!contractingDimMap.empty()) { + auto newOp = lowerReduction(rewriter, op, mask); + if (failed(newOp)) + return failure(); + rewriter.replaceOp(rootOp, *newOp); + return success(); + } + + return failure(); +} + +// Lower one parallel dimension. +// Incidentally also tolerates unit-size (hence trivial) reduction dimensions. +// TODO: consider reusing existing contract unrolling +FailureOr ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, + vector::ContractionOp op, + int64_t lhsIndex, + int64_t rhsIndex, + Value mask) const { + VectorType lhsType = op.getLhsType(); + VectorType rhsType = op.getRhsType(); + VectorType resType = op.getResultType().cast(); + // Find the iterator type index and result index. + SmallVector iMap = op.getIndexingMapsArray(); + int64_t iterIndex = -1; + int64_t dimSize = -1; + if (lhsIndex >= 0) { + iterIndex = iMap[0].getDimPosition(lhsIndex); + if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex)) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex + << " to map to the same dimension"; + }); + dimSize = lhsType.getDimSize(lhsIndex); + } else if (rhsIndex >= 0) { + iterIndex = iMap[1].getDimPosition(rhsIndex); + dimSize = rhsType.getDimSize(rhsIndex); + } + if (iterIndex < 0) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "expected either lhsIndex=" << lhsIndex + << " or rhsIndex=" << rhsIndex << " to be nonnegative"; + }); + // value_or(-1) means that we tolerate a dimension not appearing + // in the result map. That can't happen for actual parallel iterators, but + // the caller ContractionOpLowering::matchAndRewrite is currently calling + // lowerParallel also for the case of unit-size reduction dims appearing only + // on one of LHS or RHS, not both. 
At the moment, such cases are created by + // CastAwayContractionLeadingOneDim, so we need to either support that or + // modify that pattern. + int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1); + if (resIndex == -1 && dimSize != 1) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "expected the dimension for iterIndex=" << iterIndex + << " to either appear in the result map, or to be a unit dimension"; + }); + + // Construct new iterator types and affine map array attribute. + std::array lowIndexingMaps = { + adjustMap(iMap[0], iterIndex, rewriter), + adjustMap(iMap[1], iterIndex, rewriter), + adjustMap(iMap[2], iterIndex, rewriter)}; + auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); + auto lowIter = + rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); + // Unroll into a series of lower dimensional vector.contract ops. + Location loc = op.getLoc(); + Value result = rewriter.create( + loc, resType, rewriter.getZeroAttr(resType)); + + for (int64_t d = 0; d < dimSize; ++d) { + auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); + auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); + auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter); + + Value lowMask; + if (mask) + lowMask = reshapeLoad(loc, mask, cast(mask.getType()), + iterIndex, d, rewriter); + + Operation *lowContract = rewriter.create( + loc, lhs, rhs, acc, lowAffine, lowIter); + lowContract = maskOperation(rewriter, lowContract, lowMask); + result = reshapeStore(loc, lowContract->getResult(0), result, resType, + resIndex, d, rewriter); + } + return result; +} + +// Lower one reduction dimension. 
+FailureOr ContractionOpLowering::lowerReduction( + PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const { + auto loc = op.getLoc(); + VectorType lhsType = op.getLhsType(); + VectorType rhsType = op.getRhsType(); + Type resType = op.getResultType(); + if (resType.isa()) + return rewriter.notifyMatchFailure(op, + "did not expect a VectorType result"); + bool isInt = resType.isa(); + // Use iterator index 0. + int64_t iterIndex = 0; + SmallVector iMap = op.getIndexingMapsArray(); + std::optional lookupLhs = getResultIndex(iMap[0], iterIndex); + std::optional lookupRhs = getResultIndex(iMap[1], iterIndex); + if (!lookupLhs.has_value()) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension"; + }); + if (!lookupRhs.has_value()) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension"; + }); + int64_t lhsIndex = *lookupLhs; + int64_t rhsIndex = *lookupRhs; + int64_t dimSize = lhsType.getDimSize(lhsIndex); + if (dimSize != rhsType.getDimSize(rhsIndex)) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "expect LHS dimension " << lhsIndex + << " to have the same size as RHS dimension " << rhsIndex; + }); + // Base case. + if (lhsType.getRank() == 1) { + if (rhsType.getRank() != 1) + return rewriter.notifyMatchFailure( + op, "When LHS has rank 1, expected also RHS to have rank 1"); + Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter); + auto kind = vector::CombiningKind::ADD; + + Value acc = op.getAcc(); + Operation *reductionOp = + acc ? rewriter.create(loc, kind, m, acc) + : rewriter.create(loc, kind, m); + return maskOperation(rewriter, reductionOp, mask)->getResult(0); + } + // Construct new iterator types and affine map array attribute. 
+ std::array lowIndexingMaps = { + adjustMap(iMap[0], iterIndex, rewriter), + adjustMap(iMap[1], iterIndex, rewriter), + adjustMap(iMap[2], iterIndex, rewriter)}; + auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); + auto lowIter = + rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); + // Unroll into a series of lower dimensional vector.contract ops. + // By feeding the initial accumulator into the first contraction, + // and the result of each contraction into the next, eventually + // the sum of all reductions is computed. + Value result = op.getAcc(); + for (int64_t d = 0; d < dimSize; ++d) { + auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); + auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); + Value newMask; + if (mask) + newMask = reshapeLoad(loc, mask, cast(mask.getType()), + iterIndex, d, rewriter); + + Operation *newContract = rewriter.create( + loc, lhs, rhs, result, lowAffine, lowIter); + result = maskOperation(rewriter, newContract, newMask)->getResult(0); + } + return result; +} + +/// Progressive lowering of OuterProductOp. +/// One: +/// %x = vector.outerproduct %lhs, %rhs, %acc +/// is replaced by: +/// %z = zero-result +/// %0 = vector.extract %lhs[0] +/// %1 = vector.broadcast %0 +/// %2 = vector.extract %acc[0] +/// %3 = vector.fma %1, %rhs, %2 +/// %4 = vector.insert %3, %z[0] +/// .. +/// %x = vector.insert %.., %..[N-1] +/// +class OuterProductOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::OuterProductOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + VectorType lhsType = op.getOperandVectorTypeLHS(); + VectorType rhsType = op.getOperandTypeRHS().dyn_cast(); + VectorType resType = op.getResultVectorType(); + Type eltType = resType.getElementType(); + bool isInt = eltType.isa(); + Value acc = (op.getAcc().empty()) ? 
nullptr : op.getAcc()[0]; + vector::CombiningKind kind = op.getKind(); + + // Vector mask setup. + OpBuilder::InsertionGuard guard(rewriter); + auto maskableOp = cast(op.getOperation()); + Operation *rootOp; + Value mask; + if (maskableOp.isMasked()) { + rewriter.setInsertionPoint(maskableOp.getMaskingOp()); + rootOp = maskableOp.getMaskingOp(); + mask = maskableOp.getMaskingOp().getMask(); + } else { + rootOp = op; + } + + if (!rhsType) { + // Special case: AXPY operation. + Value b = rewriter.create(loc, lhsType, op.getRhs()); + std::optional mult = createContractArithOp( + loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask); + if (!mult.has_value()) + return failure(); + rewriter.replaceOp(rootOp, *mult); + return success(); + } + + Value result = rewriter.create( + loc, resType, rewriter.getZeroAttr(resType)); + for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { + auto pos = rewriter.getI64ArrayAttr(d); + Value x = rewriter.create(loc, op.getLhs(), pos); + Value a = rewriter.create(loc, rhsType, x); + Value r = nullptr; + if (acc) + r = rewriter.create(loc, acc, pos); + Value extrMask; + if (mask) + extrMask = rewriter.create(loc, mask, pos); + + std::optional m = createContractArithOp( + loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); + if (!m.has_value()) + return failure(); + result = rewriter.create(loc, resType, *m, result, pos); + } + + rewriter.replaceOp(rootOp, result); + return success(); + } +}; + +/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to: +/// ``` +/// %mta = maybe_transpose +/// %mtb = maybe_transpose +/// %flattened_a = vector.shape_cast %mta +/// %flattened_b = vector.shape_cast %mtb +/// %flattened_d = vector.matmul %flattened_a, %flattened_b +/// %mtd = vector.shape_cast %flattened_d +/// %d = maybe_untranspose %mtd +/// %e = add %c, %d +/// ``` +/// `vector.matmul` later lowers to `llvm.matrix.multiply`. 
+// +/// This only kicks in when VectorTransformsOptions is set to `Matmul`. +/// vector.transpose operations are inserted if the vector.contract op is not a +/// row-major matrix multiply. +LogicalResult +ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rew) const { + // TODO: Support vector.mask. + auto maskableOp = cast(op.getOperation()); + if (maskableOp.isMasked()) + return failure(); + + // TODO: Remove native masks from contraction op? + if (!op.getMasks().empty()) + return failure(); + if (vectorTransformOptions.vectorContractLowering != + vector::VectorContractLowering::Matmul) + return failure(); + if (failed(filter(op))) + return failure(); + + auto iteratorTypes = op.getIteratorTypes().getValue(); + if (!isParallelIterator(iteratorTypes[0]) || + !isParallelIterator(iteratorTypes[1]) || + !isReductionIterator(iteratorTypes[2])) + return failure(); + + Type elementType = op.getLhsType().getElementType(); + if (!elementType.isIntOrFloat()) + return failure(); + + Type dstElementType = op.getType(); + if (auto vecType = dstElementType.dyn_cast()) + dstElementType = vecType.getElementType(); + if (elementType != dstElementType) + return failure(); + + // Perform lhs + rhs transpositions to conform to matmul row-major semantics. + // Bail out if the contraction cannot be put in this form. + MLIRContext *ctx = op.getContext(); + Location loc = op.getLoc(); + AffineExpr m, n, k; + bindDims(rew.getContext(), m, n, k); + // LHS must be A(m, k) or A(k, m). + Value lhs = op.getLhs(); + auto lhsMap = op.getIndexingMapsArray()[0]; + if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) + lhs = rew.create(loc, lhs, ArrayRef{1, 0}); + else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) + return failure(); + + // RHS must be B(k, n) or B(n, k). 
+ Value rhs = op.getRhs(); + auto rhsMap = op.getIndexingMapsArray()[1]; + if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) + rhs = rew.create(loc, rhs, ArrayRef{1, 0}); + else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) + return failure(); + + // At this point lhs and rhs are in row-major. + VectorType lhsType = lhs.getType().cast(); + VectorType rhsType = rhs.getType().cast(); + int64_t lhsRows = lhsType.getDimSize(0); + int64_t lhsColumns = lhsType.getDimSize(1); + int64_t rhsColumns = rhsType.getDimSize(1); + + Type flattenedLHSType = + VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); + lhs = rew.create(loc, flattenedLHSType, lhs); + + Type flattenedRHSType = + VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); + rhs = rew.create(loc, flattenedRHSType, rhs); + + Value mul = rew.create(loc, lhs, rhs, lhsRows, lhsColumns, + rhsColumns); + mul = rew.create( + loc, + VectorType::get({lhsRows, rhsColumns}, + getElementTypeOrSelf(op.getAcc().getType())), + mul); + + // ACC must be C(m, n) or C(n, m). + auto accMap = op.getIndexingMapsArray()[2]; + if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) + mul = rew.create(loc, mul, ArrayRef{1, 0}); + else if (accMap != AffineMap::get(3, 0, {m, n}, ctx)) + llvm_unreachable("invalid contraction semantics"); + + Value res = + elementType.isa() + ? 
static_cast(rew.create(loc, op.getAcc(), mul)) + : static_cast( + rew.create(loc, op.getAcc(), mul)); + + rew.replaceOp(op, res); + return success(); +} +} // namespace + +void mlir::vector::populateVectorContractLoweringPatterns( + RewritePatternSet &patterns, VectorTransformsOptions options, + PatternBenefit benefit, bool disableOuterProductLowering) { + if (!disableOuterProductLowering) + patterns.add(patterns.getContext(), benefit); + patterns.add( + options, patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -0,0 +1,173 @@ +//===- LowerVectorScam.cpp - Lower 'vector.scan' operation ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.scan' operation. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/VectorInterfaces.h" +#include "mlir/Support/LogicalResult.h" + +#define DEBUG_TYPE "vector-broadcast-lowering" + +using namespace mlir; +using namespace mlir::vector; + +namespace { +/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the +/// outermost dimension. For example: +/// ``` +/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru : +/// ... into vector<2x3xf32> +/// +/// ==> +/// +/// %0 = arith.constant dense<0.0> : vector<2x3xf32> +/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ... +/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32> +/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ... +/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32> +/// ``` +/// +/// When applied exhaustively, this will produce a sequence of 1-d gather ops. 
+struct FlattenGather : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::GatherOp op, + PatternRewriter &rewriter) const override { + VectorType resultTy = op.getType(); + if (resultTy.getRank() < 2) + return rewriter.notifyMatchFailure(op, "already flat"); + + Location loc = op.getLoc(); + Value indexVec = op.getIndexVec(); + Value maskVec = op.getMask(); + Value passThruVec = op.getPassThru(); + + Value result = rewriter.create( + loc, resultTy, rewriter.getZeroAttr(resultTy)); + + Type subTy = VectorType::get(resultTy.getShape().drop_front(), + resultTy.getElementType()); + + for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { + int64_t thisIdx[1] = {i}; + + Value indexSubVec = + rewriter.create(loc, indexVec, thisIdx); + Value maskSubVec = + rewriter.create(loc, maskVec, thisIdx); + Value passThruSubVec = + rewriter.create(loc, passThruVec, thisIdx); + Value subGather = rewriter.create( + loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec, + passThruSubVec); + result = + rewriter.create(loc, subGather, result, thisIdx); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; + +/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or +/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these +/// loads/extracts are made conditional using `scf.if` ops. +struct Gather1DToConditionalLoads : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::GatherOp op, + PatternRewriter &rewriter) const override { + VectorType resultTy = op.getType(); + if (resultTy.getRank() != 1) + return rewriter.notifyMatchFailure(op, "unsupported rank"); + + Location loc = op.getLoc(); + Type elemTy = resultTy.getElementType(); + // Vector type with a single element. Used to generate `vector.loads`. 
+ VectorType elemVecTy = VectorType::get({1}, elemTy); + + Value condMask = op.getMask(); + Value base = op.getBase(); + Value indexVec = rewriter.createOrFold( + loc, op.getIndexVectorType().clone(rewriter.getIndexType()), + op.getIndexVec()); + auto baseOffsets = llvm::to_vector(op.getIndices()); + Value lastBaseOffset = baseOffsets.back(); + + Value result = op.getPassThru(); + + // Emit a conditional access for each vector element. + for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) { + int64_t thisIdx[1] = {i}; + Value condition = + rewriter.create(loc, condMask, thisIdx); + Value index = rewriter.create(loc, indexVec, thisIdx); + baseOffsets.back() = + rewriter.createOrFold(loc, lastBaseOffset, index); + + auto loadBuilder = [&](OpBuilder &b, Location loc) { + Value extracted; + if (isa(base.getType())) { + // `vector.load` does not support scalar result; emit a vector load + // and extract the single result instead. + Value load = + b.create(loc, elemVecTy, base, baseOffsets); + int64_t zeroIdx[1] = {0}; + extracted = b.create(loc, load, zeroIdx); + } else { + extracted = b.create(loc, base, baseOffsets); + } + + Value newResult = + b.create(loc, extracted, result, thisIdx); + b.create(loc, newResult); + }; + auto passThruBuilder = [result](OpBuilder &b, Location loc) { + b.create(loc, result); + }; + + result = + rewriter + .create(loc, condition, /*thenBuilder=*/loadBuilder, + /*elseBuilder=*/passThruBuilder) + .getResult(0); + } + + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +void mlir::vector::populateVectorGatherLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), + benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -6,7 +6,7 @@ // 
//===----------------------------------------------------------------------===// // -// This file implements target-independent rewrites and utilitites to lower the +// This file implements target-independent rewrites and utilities to lower the // 'vector.mask' operation. // //===----------------------------------------------------------------------===// @@ -14,6 +14,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -30,6 +31,147 @@ using namespace mlir; using namespace mlir::vector; +//===----------------------------------------------------------------------===// +// populateVectorMaskOpLoweringPatterns +//===----------------------------------------------------------------------===// + +namespace { +/// Progressive lowering of CreateMaskOp. +/// One: +/// %x = vector.create_mask %a, ... : vector +/// is replaced by: +/// %l = vector.create_mask ... : vector<...> ; one lower rank +/// %0 = arith.cmpi "slt", %ci, %a | +/// %1 = select %0, %l, %zeroes | +/// %r = vector.insert %1, %pr [i] | d-times +/// %x = .... +/// until a one-dimensional vector is reached. 
+class CreateMaskOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::CreateMaskOp op, + PatternRewriter &rewriter) const override { + auto dstType = op.getResult().getType().cast(); + int64_t rank = dstType.getRank(); + if (rank <= 1) + return rewriter.notifyMatchFailure( + op, "0-D and 1-D vectors are handled separately"); + + auto loc = op.getLoc(); + auto eltType = dstType.getElementType(); + int64_t dim = dstType.getDimSize(0); + Value idx = op.getOperand(0); + + VectorType lowType = + VectorType::get(dstType.getShape().drop_front(), eltType); + Value trueVal = rewriter.create( + loc, lowType, op.getOperands().drop_front()); + Value falseVal = rewriter.create( + loc, lowType, rewriter.getZeroAttr(lowType)); + Value result = rewriter.create( + loc, dstType, rewriter.getZeroAttr(dstType)); + for (int64_t d = 0; d < dim; d++) { + Value bnd = + rewriter.create(loc, rewriter.getIndexAttr(d)); + Value val = rewriter.create(loc, arith::CmpIPredicate::slt, + bnd, idx); + Value sel = rewriter.create(loc, val, trueVal, falseVal); + auto pos = rewriter.getI64ArrayAttr(d); + result = + rewriter.create(loc, dstType, sel, result, pos); + } + rewriter.replaceOp(op, result); + return success(); + } +}; + +/// Progressive lowering of ConstantMaskOp. +/// One: +/// %x = vector.constant_mask [a,b] +/// is replaced by: +/// %z = zero-result (zero-initialized vector of the result type) +/// %l = vector.constant_mask [b] +/// %4 = vector.insert %l, %z[0] +/// .. +/// %x = vector.insert %l, %..[a-1] +/// until a one-dimensional vector is reached. All these operations +/// will be folded at LLVM IR level.
+class ConstantMaskOpLowering : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ConstantMaskOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto dstType = op.getType(); + auto eltType = dstType.getElementType(); + auto dimSizes = op.getMaskDimSizes(); + int64_t rank = dstType.getRank(); + + if (rank == 0) { + assert(dimSizes.size() == 1 && + "Expected exactly one dim size for a 0-D vector"); + bool value = dimSizes[0].cast().getInt() == 1; + rewriter.replaceOpWithNewOp( + op, dstType, + DenseIntElementsAttr::get( + VectorType::get(ArrayRef{}, rewriter.getI1Type()), + ArrayRef{value})); + return success(); + } + + // Scalable constant masks can only be lowered for the "none set" case, so every scalable mask is replaced by an all-false constant here (NOTE(review): assumes the op verifier rejects non-zero scalable mask sizes; confirm). + if (dstType.cast().isScalable()) { + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(dstType, false)); + return success(); + } + + int64_t trueDim = std::min(dstType.getDimSize(0), + dimSizes[0].cast().getInt()); + + if (rank == 1) { + // Express constant 1-D case in explicit vector form: + // [T,..,T,F,..,F] with trueDim = min(dim 0 size, a) leading T's.
+ SmallVector values(dstType.getDimSize(0)); + for (int64_t d = 0; d < trueDim; d++) + values[d] = true; + rewriter.replaceOpWithNewOp( + op, dstType, rewriter.getBoolVectorAttr(values)); + return success(); + } + + VectorType lowType = + VectorType::get(dstType.getShape().drop_front(), eltType); + SmallVector newDimSizes; + for (int64_t r = 1; r < rank; r++) + newDimSizes.push_back(dimSizes[r].cast().getInt()); + Value trueVal = rewriter.create( + loc, lowType, rewriter.getI64ArrayAttr(newDimSizes)); + Value result = rewriter.create( + loc, dstType, rewriter.getZeroAttr(dstType)); + for (int64_t d = 0; d < trueDim; d++) { + auto pos = rewriter.getI64ArrayAttr(d); + result = + rewriter.create(loc, dstType, trueVal, result, pos); + } + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + +void mlir::vector::populateVectorMaskOpLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add( + patterns.getContext(), benefit); +} + +//===----------------------------------------------------------------------===// +// populateVectorMaskLoweringPatternsForSideEffectingOps +//===----------------------------------------------------------------------===// + namespace { /// The `MaskOpRewritePattern` implements a pattern that follows a two-fold diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp rename from mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp rename to mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -1,4 +1,4 @@ -//===- VectorMultiDimReductionTransforms.cpp - Multi-Reduction Transforms -===// +//===- LowerVectorMultiReduction.cpp - Lower `vector.multi_reduction` op --===// // /// Part of the LLVM Project, under the Apache
License v2.0 with LLVM /// Exceptions. See https://llvm.org/LICENSE.txt for license information. @@ -6,12 +6,13 @@ // //===----------------------------------------------------------------------===// // -/// This file implements target-independent rewrites of MultiDimReductionOp. +// This file implements target-independent rewrites and utilities to lower the +// 'vector.multi_reduction' operation. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/IR/Builders.h" #include "mlir/IR/TypeUtilities.h" @@ -19,6 +20,7 @@ using namespace mlir; +namespace { /// This file implements the following transformations as composable atomic /// patterns. @@ -441,6 +443,7 @@ return success(); } }; +} // namespace void mlir::vector::populateVectorMultiReductionLoweringPatterns( RewritePatternSet &patterns, VectorMultiReductionLowering options, diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -0,0 +1,251 @@ +//===- LowerVectorScan.cpp - Lower 'vector.scan' operation ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.scan' operation.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/VectorInterfaces.h" +#include "mlir/Support/LogicalResult.h" + +#define DEBUG_TYPE "vector-scan-lowering" + +using namespace mlir; +using namespace mlir::vector; + +/// This function constructs the appropriate integer or float +/// operation given the vector combining kind and operands. The +/// supported int operations are: add, mul, min (signed/unsigned), +/// max(signed/unsigned), and, or, xor. The supported float +/// operations are: add, mul, min and max.
+static Value genOperator(Location loc, Value x, Value y, + vector::CombiningKind kind, + PatternRewriter &rewriter) { + using vector::CombiningKind; + + auto elType = x.getType().cast().getElementType(); + bool isInt = elType.isIntOrIndex(); + + Value combinedResult{nullptr}; // Left null if `kind` matches none of the cases below. + switch (kind) { + case CombiningKind::ADD: + if (isInt) + combinedResult = rewriter.create(loc, x, y); + else + combinedResult = rewriter.create(loc, x, y); + break; + case CombiningKind::MUL: + if (isInt) + combinedResult = rewriter.create(loc, x, y); + else + combinedResult = rewriter.create(loc, x, y); + break; + case CombiningKind::MINUI: + combinedResult = rewriter.create(loc, x, y); + break; + case CombiningKind::MINSI: + combinedResult = rewriter.create(loc, x, y); + break; + case CombiningKind::MAXUI: + combinedResult = rewriter.create(loc, x, y); + break; + case CombiningKind::MAXSI: + combinedResult = rewriter.create(loc, x, y); + break; + case CombiningKind::AND: + combinedResult = rewriter.create(loc, x, y); + break; + case CombiningKind::OR: + combinedResult = rewriter.create(loc, x, y); + break; + case CombiningKind::XOR: + combinedResult = rewriter.create(loc, x, y); + break; + case CombiningKind::MINF: + combinedResult = rewriter.create(loc, x, y); + break; + case CombiningKind::MAXF: + combinedResult = rewriter.create(loc, x, y); + break; + } + return combinedResult; +} + +/// This function checks to see if the vector combining kind +/// is consistent with the integer or float element type; ADD and MUL are accepted for both.
+static bool isValidKind(bool isInt, vector::CombiningKind kind) { + using vector::CombiningKind; + enum class KindType { FLOAT, INT, INVALID }; + KindType type{KindType::INVALID}; // INVALID never validates for either element type. + switch (kind) { + case CombiningKind::MINF: + case CombiningKind::MAXF: + type = KindType::FLOAT; + break; + case CombiningKind::MINUI: + case CombiningKind::MINSI: + case CombiningKind::MAXUI: + case CombiningKind::MAXSI: + case CombiningKind::AND: + case CombiningKind::OR: + case CombiningKind::XOR: + type = KindType::INT; + break; + case CombiningKind::ADD: + case CombiningKind::MUL: + type = isInt ? KindType::INT : KindType::FLOAT; + break; + } + bool isValidIntKind = (type == KindType::INT) && isInt; + bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt); + return (isValidIntKind || isValidFloatKind); +} + +namespace { +/// Convert vector.scan op into arith ops and vector.insert_strided_slice / +/// vector.extract_strided_slice. +/// +/// Example: +/// +/// ``` +/// %0:2 = vector.scan <mul>, %arg0, %arg1 +/// {inclusive = true, reduction_dim = 1} : +/// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>) +/// ``` +/// +/// is converted to: +/// +/// ``` +/// %cst = arith.constant dense<0> : vector<2x3xi32> +/// %0 = vector.extract_strided_slice %arg0 +/// {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} +/// : vector<2x3xi32> to vector<2x1xi32> +/// %1 = vector.insert_strided_slice %0, %cst +/// {offsets = [0, 0], strides = [1, 1]} +/// : vector<2x1xi32> into vector<2x3xi32> +/// %2 = vector.extract_strided_slice %arg0 +/// {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} +/// : vector<2x3xi32> to vector<2x1xi32> +/// %3 = arith.muli %0, %2 : vector<2x1xi32> +/// %4 = vector.insert_strided_slice %3, %1 +/// {offsets = [0, 1], strides = [1, 1]} +/// : vector<2x1xi32> into vector<2x3xi32> +/// %5 = vector.extract_strided_slice %arg0 +/// {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]} +/// : vector<2x3xi32> to vector<2x1xi32> +/// %6 = arith.muli
%3, %5 : vector<2x1xi32> +/// %7 = vector.insert_strided_slice %6, %4 +/// {offsets = [0, 2], strides = [1, 1]} +/// : vector<2x1xi32> into vector<2x3xi32> +/// %8 = vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> +/// return %7, %8 : vector<2x3xi32>, vector<2xi32> +/// ``` +struct ScanToArithOps : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ScanOp scanOp, + PatternRewriter &rewriter) const override { + auto loc = scanOp.getLoc(); + VectorType destType = scanOp.getDestType(); + ArrayRef destShape = destType.getShape(); + auto elType = destType.getElementType(); + bool isInt = elType.isIntOrIndex(); + if (!isValidKind(isInt, scanOp.getKind())) + return failure(); + + VectorType resType = VectorType::get(destShape, elType); + Value result = rewriter.create( + loc, resType, rewriter.getZeroAttr(resType)); + int64_t reductionDim = scanOp.getReductionDim(); + bool inclusive = scanOp.getInclusive(); + int64_t destRank = destType.getRank(); + VectorType initialValueType = scanOp.getInitialValueType(); + int64_t initialValueRank = initialValueType.getRank(); + + SmallVector reductionShape(destShape.begin(), destShape.end()); + reductionShape[reductionDim] = 1; + VectorType reductionType = VectorType::get(reductionShape, elType); + SmallVector offsets(destRank, 0); + SmallVector strides(destRank, 1); + SmallVector sizes(destShape.begin(), destShape.end()); + sizes[reductionDim] = 1; + ArrayAttr scanSizes = rewriter.getI64ArrayAttr(sizes); + ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides); + + Value lastOutput, lastInput; // Previous output/input slice along reductionDim. + for (int i = 0; i < destShape[reductionDim]; i++) { + offsets[reductionDim] = i; + ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets); + Value input = rewriter.create( + loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes, + scanStrides); + Value output; + if (i == 0) { + if (inclusive) { + output = input; + } else { + if (initialValueRank == 0) { + //
ShapeCastOp cannot handle 0-D vectors, so rank-0 initial values take a separate path + output = rewriter.create( + loc, input.getType(), scanOp.getInitialValue()); + } else { + output = rewriter.create( + loc, input.getType(), scanOp.getInitialValue()); + } + } + } else { + Value y = inclusive ? input : lastInput; + output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter); + assert(output != nullptr); + } + result = rewriter.create( + loc, output, result, offsets, strides); + lastOutput = output; + lastInput = input; + } + + Value reduction; // Last carried slice, reshaped to the initial value type. + if (initialValueRank == 0) { + Value v = rewriter.create(loc, lastOutput, 0); + reduction = + rewriter.create(loc, initialValueType, v); + } else { + reduction = rewriter.create(loc, initialValueType, + lastOutput); + } + + rewriter.replaceOp(scanOp, {result, reduction}); + return success(); + } +}; +} // namespace + +void mlir::vector::populateVectorScanLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp @@ -0,0 +1,177 @@ +//===- LowerVectorShapeCast.cpp - Lower 'vector.shape_cast' operation -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.shape_cast' operation.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" +#include "mlir/Dialect/Vector/Utils/VectorUtils.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/VectorInterfaces.h" +#include "mlir/Support/LogicalResult.h" + +#define DEBUG_TYPE "vector-shape-cast-lowering" + +using namespace mlir; +using namespace mlir::vector; + +namespace { +/// ShapeCastOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D
+/// This iterates over the most major dimension of the 2-D vector and performs +/// rewrites into: +/// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D +class ShapeCastOp2DDownCastRewritePattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ShapeCastOp op, + PatternRewriter &rewriter) const override { + auto sourceVectorType = op.getSourceVectorType(); + auto resultVectorType = op.getResultVectorType(); + if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1) + return failure(); + + auto loc = op.getLoc(); + Value desc = rewriter.create( + loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); + unsigned mostMinorVectorSize = sourceVectorType.getShape()[1]; + for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) { + Value vec = rewriter.create(loc, op.getSource(), i); + desc = rewriter.create( + loc, vec, desc, + /*offsets=*/i * mostMinorVectorSize, /*strides=*/1); + } + rewriter.replaceOp(op, desc); + return success(); + } +}; + +/// ShapeCastOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D +/// vectors progressively. +/// This iterates over the most major dimension of the 2-D vector and performs +/// rewrites into: +/// vector.extract_strided_slice from 1-D + vector.insert into 2-D +/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle.
+class ShapeCastOp2DUpCastRewritePattern + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ShapeCastOp op, + PatternRewriter &rewriter) const override { + auto sourceVectorType = op.getSourceVectorType(); + auto resultVectorType = op.getResultVectorType(); + if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2) + return failure(); + + auto loc = op.getLoc(); + Value desc = rewriter.create( + loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); + unsigned mostMinorVectorSize = resultVectorType.getShape()[1]; + for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) { + Value vec = rewriter.create( + loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize, + /*sizes=*/mostMinorVectorSize, + /*strides=*/1); + desc = rewriter.create(loc, vec, desc, i); + } + rewriter.replaceOp(op, desc); + return success(); + } +}; + +// We typically should not lower general shape cast operations into data +// movement instructions, since the assumption is that these casts are +// optimized away during progressive lowering. For completeness, however, +// we fall back to a reference implementation that moves all elements +// into the right place if we get here. +class ShapeCastOpRewritePattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ShapeCastOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto sourceVectorType = op.getSourceVectorType(); + auto resultVectorType = op.getResultVectorType(); + + // Special case 2D / 1D lowerings with better implementations. + // TODO: make it ND / 1D to allow generic ND -> 1D -> MD.
+ int64_t srcRank = sourceVectorType.getRank(); + int64_t resRank = resultVectorType.getRank(); + if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2)) + return failure(); + + // Generic ShapeCast lowering path goes all the way down to unrolled scalar + // extract/insert chains. + // TODO: consider evolving the semantics to only allow 1D source or dest and + // drop this potentially very expensive lowering. + // Compute number of elements involved in the reshape. + int64_t numElts = 1; + for (int64_t r = 0; r < srcRank; r++) + numElts *= sourceVectorType.getDimSize(r); + // Replace with data movement operations: + // x[0,0,0] = y[0,0] + // x[0,0,1] = y[0,1] + // x[0,1,0] = y[0,2] + // etc., incrementing the two index vectors "row-major" + // within the source and result shape. + SmallVector srcIdx(srcRank); + SmallVector resIdx(resRank); + Value result = rewriter.create( + loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); + for (int64_t i = 0; i < numElts; i++) { + if (i != 0) { + incIdx(srcIdx, sourceVectorType, srcRank - 1); + incIdx(resIdx, resultVectorType, resRank - 1); + } + Value e = rewriter.create(loc, op.getSource(), srcIdx); + result = rewriter.create(loc, e, result, resIdx); + } + rewriter.replaceOp(op, result); + return success(); + } + +private: + static void incIdx(SmallVector &idx, VectorType tp, int64_t r) { // Row-major (odometer-style) increment of idx at dimension r, carrying overflow into the next outer dimension. + assert(0 <= r && r < tp.getRank()); + if (++idx[r] == tp.getDimSize(r)) { + idx[r] = 0; + incIdx(idx, tp, r - 1); + } + } +}; +} // namespace + +void mlir::vector::populateVectorShapeCastLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add( + patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp rename from mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp rename to
mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferPermutationMapRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -14,7 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Interfaces/VectorInterfaces.h" using namespace mlir; @@ -46,6 +46,11 @@ return builder.create(loc, newVecType, vec); } +//===----------------------------------------------------------------------===// +// populateVectorTransferPermutationMapLoweringPatterns +//===----------------------------------------------------------------------===// + +namespace { /// Lower transfer_read op with permutation into a transfer_read with a /// permutation map composed of leading zeros followed by a minor identiy + /// vector.transpose op. @@ -332,6 +337,8 @@ } }; +} // namespace + void mlir::vector::populateVectorTransferPermutationMapLoweringPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns @@ -339,3 +346,239 @@ TransferOpReduceRank, TransferWriteNonPermutationLowering>( patterns.getContext(), benefit); } + +//===----------------------------------------------------------------------===// +// populateVectorTransferLoweringPatterns +//===----------------------------------------------------------------------===// + +namespace { +/// Progressive lowering of transfer_read. This pattern supports lowering of +/// `vector.transfer_read` to a combination of `vector.load` and +/// `vector.broadcast` if all of the following hold: +/// - Stride of most minor memref dimension must be 1. +/// - Out-of-bounds masking is not required. +/// - If the memref's element type is a vector type then it coincides with the +/// result type. 
+/// - The permutation map doesn't perform permutation (broadcasting is allowed). +struct TransferReadToVectorLoadLowering + : public OpRewritePattern { + TransferReadToVectorLoadLowering(MLIRContext *context, + std::optional maxRank, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + maxTransferRank(maxRank) {} + + LogicalResult matchAndRewrite(vector::TransferReadOp read, + PatternRewriter &rewriter) const override { + if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) + return failure(); + + SmallVector broadcastedDims; + // Permutations are handled by VectorToSCF or + // populateVectorTransferPermutationMapLoweringPatterns. + // We let the 0-d corner case pass-through as it is supported. + if (!read.getPermutationMap().isMinorIdentityWithBroadcasting( + &broadcastedDims)) + return failure(); + + auto memRefType = read.getShapedType().dyn_cast(); + if (!memRefType) + return failure(); + + // Non-unit strides are handled by VectorToSCF. + if (!vector::isLastMemrefDimUnitStride(memRefType)) + return failure(); + + // If there is broadcasting involved then we first load the unbroadcasted + // vector, and then broadcast it with `vector.broadcast`. + ArrayRef vectorShape = read.getVectorType().getShape(); + SmallVector unbroadcastedVectorShape(vectorShape.begin(), + vectorShape.end()); + for (unsigned i : broadcastedDims) + unbroadcastedVectorShape[i] = 1; + VectorType unbroadcastedVectorType = VectorType::get( + unbroadcastedVectorShape, read.getVectorType().getElementType()); + + // `vector.load` supports vector types as memref's elements only when the + // resulting vector type is the same as the element type. + auto memrefElTy = memRefType.getElementType(); + if (memrefElTy.isa() && memrefElTy != unbroadcastedVectorType) + return failure(); + + // Otherwise, element types of the memref and the vector must match.
+ if (!memrefElTy.isa() && + memrefElTy != read.getVectorType().getElementType()) + return failure(); + + // Out-of-bounds dims are handled by MaterializeTransferMask. + if (read.hasOutOfBoundsDim()) + return failure(); + + // Create vector load op (masked, with a fill vector built from the padding value, when the transfer has a mask). + Operation *loadOp; + if (read.getMask()) { + Value fill = rewriter.create( + read.getLoc(), unbroadcastedVectorType, read.getPadding()); + loadOp = rewriter.create( + read.getLoc(), unbroadcastedVectorType, read.getSource(), + read.getIndices(), read.getMask(), fill); + } else { + loadOp = rewriter.create( + read.getLoc(), unbroadcastedVectorType, read.getSource(), + read.getIndices()); + } + + // Insert a broadcasting op if required. + if (!broadcastedDims.empty()) { + rewriter.replaceOpWithNewOp( + read, read.getVectorType(), loadOp->getResult(0)); + } else { + rewriter.replaceOp(read, loadOp->getResult(0)); + } + + return success(); + } + + std::optional maxTransferRank; +}; + +/// Replace a single-element (e.g. 0-d) vector.load with a memref.load + vector.broadcast. +// TODO: we shouldn't cross the vector/scalar domains just for this +// but atm we lack the infra to avoid it. Possible solutions include: +// - go directly to LLVM + bitcast +// - introduce a bitcast op and likely a new pointer dialect +// - let memref.load/store additionally support the 0-d vector case +// There are still deeper data layout issues lingering even in this +// trivial case (for architectures for which this matters).
+struct VectorLoadToMemrefLoadLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::LoadOp loadOp, + PatternRewriter &rewriter) const override { + auto vecType = loadOp.getVectorType(); + if (vecType.getNumElements() != 1) + return failure(); + auto memrefLoad = rewriter.create( + loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices()); + rewriter.replaceOpWithNewOp(loadOp, vecType, + memrefLoad); + return success(); + } +}; + +/// Replace a single-element (e.g. 0-d) vector.store with a vector.extractelement (0-d) or vector.extract (n-d) followed by a memref.store. +struct VectorStoreToMemrefStoreLowering + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::StoreOp storeOp, + PatternRewriter &rewriter) const override { + auto vecType = storeOp.getVectorType(); + if (vecType.getNumElements() != 1) + return failure(); + Value extracted; + if (vecType.getRank() == 0) { + // TODO: Unify once ExtractOp supports 0-d vectors. + extracted = rewriter.create( + storeOp.getLoc(), storeOp.getValueToStore()); + } else { + SmallVector indices(vecType.getRank(), 0); + extracted = rewriter.create( + storeOp.getLoc(), storeOp.getValueToStore(), indices); + } + + rewriter.replaceOpWithNewOp( + storeOp, extracted, storeOp.getBase(), storeOp.getIndices()); + return success(); + } +}; + +/// Progressive lowering of transfer_write. This pattern supports lowering of +/// `vector.transfer_write` to `vector.store` if all of the following hold: +/// - Stride of most minor memref dimension must be 1. +/// - Out-of-bounds masking is not required. +/// - If the memref's element type is a vector type then it coincides with the +/// type of the written value. +/// - The permutation map is the minor identity map (neither permutation nor +/// broadcasting is allowed).
+struct TransferWriteToVectorStoreLowering + : public OpRewritePattern { + TransferWriteToVectorStoreLowering(MLIRContext *context, + std::optional maxRank, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), + maxTransferRank(maxRank) {} + + LogicalResult matchAndRewrite(vector::TransferWriteOp write, + PatternRewriter &rewriter) const override { + if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "rank exceeds maxTransferRank: " << write; + }); + + // Permutations are handled by VectorToSCF or + // populateVectorTransferPermutationMapLoweringPatterns. + if ( // pass-through for the 0-d corner case. + !write.getPermutationMap().isMinorIdentity()) + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "permutation map is not minor identity: " << write; + }); + + auto memRefType = write.getShapedType().dyn_cast(); + if (!memRefType) + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "not a memref type: " << write; + }); + + // Non-unit strides are handled by VectorToSCF. + if (!vector::isLastMemrefDimUnitStride(memRefType)) + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "most minor stride is not 1: " << write; + }); + + // `vector.store` supports vector types as memref's elements only when the + // type of the vector value being written is the same as the element type. + auto memrefElTy = memRefType.getElementType(); + if (memrefElTy.isa() && memrefElTy != write.getVectorType()) + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "elemental type mismatch: " << write; + }); + + // Otherwise, element types of the memref and the vector must match (this check reuses the "elemental type mismatch" diagnostic of the check above).
+ if (!memrefElTy.isa() && + memrefElTy != write.getVectorType().getElementType()) + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "elemental type mismatch: " << write; + }); + + // Out-of-bounds dims are handled by MaterializeTransferMask. + if (write.hasOutOfBoundsDim()) + return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { + diag << "out of bounds dim: " << write; + }); + if (write.getMask()) { + rewriter.replaceOpWithNewOp( + write, write.getSource(), write.getIndices(), write.getMask(), + write.getVector()); + } else { + rewriter.replaceOpWithNewOp( + write, write.getVector(), write.getSource(), write.getIndices()); + } + return success(); + } + + std::optional maxTransferRank; +}; +} // namespace + +void mlir::vector::populateVectorTransferLoweringPatterns( + RewritePatternSet &patterns, std::optional maxTransferRank, + PatternBenefit benefit) { + patterns.add(patterns.getContext(), + maxTransferRank, benefit); + patterns + .add( + patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -0,0 +1,210 @@ +//===- LowerVectorTranspose.cpp - Lower 'vector.transpose' operation ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.transpose' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/VectorInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+
+#define DEBUG_TYPE "vector-transpose-lowering"
+
+using namespace mlir;
+using namespace mlir::vector;
+
+/// Given a 'transpose' pattern, prune the rightmost dimensions that are not
+/// transposed.
+static void pruneNonTransposedDims(ArrayRef<int64_t> transpose,
+                                   SmallVectorImpl<int64_t> &result) {
+  size_t numTransposedDims = transpose.size();
+  for (size_t transpDim : llvm::reverse(transpose)) {
+    if (transpDim != numTransposedDims - 1)
+      break;
+    numTransposedDims--;
+  }
+
+  result.append(transpose.begin(), transpose.begin() + numTransposedDims);
+}
+
+namespace {
+/// Progressive lowering of TransposeOp.
+/// One:
+///   %x = vector.transpose %y, [1, 0]
+/// is replaced by:
+///   %z = arith.constant dense<0.000000e+00>
+///   %0 = vector.extract %y[0, 0]
+///   %1 = vector.insert %0, %z [0, 0]
+///   ..
+///   %x = vector.insert .., .. [.., ..]
+class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions,
+                      MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit),
+        vectorTransformOptions(vectorTransformOptions) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    Value input = op.getVector();
+    VectorType inputType = op.getSourceVectorType();
+    VectorType resType = op.getResultVectorType();
+
+    // Set up convenience transposition table.
+    SmallVector<int64_t> transp;
+    for (auto attr : op.getTransp())
+      transp.push_back(attr.cast<IntegerAttr>().getInt());
+
+    if (vectorTransformOptions.vectorTransposeLowering ==
+            vector::VectorTransposeLowering::Shuffle &&
+        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0)
+      return rewriter.notifyMatchFailure(
+          op, "Options specifies lowering to shuffle");
+
+    // Handle a true 2-D matrix transpose differently when requested.
+    if (vectorTransformOptions.vectorTransposeLowering ==
+            vector::VectorTransposeLowering::Flat &&
+        resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) {
+      Type flattenedType =
+          VectorType::get(resType.getNumElements(), resType.getElementType());
+      auto matrix =
+          rewriter.create<vector::ShapeCastOp>(loc, flattenedType, input);
+      auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]);
+      auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]);
+      Value trans = rewriter.create<vector::FlatTransposeOp>(
+          loc, flattenedType, matrix, rows, columns);
+      rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, resType, trans);
+      return success();
+    }
+
+    // Generate unrolled extract/insert ops. We do not unroll the rightmost
+    // (i.e., highest-order) dimensions that are not transposed and leave them
+    // in vector form to improve performance. Therefore, we prune those
+    // dimensions from the shape/transpose data structures used to generate the
+    // extract/insert ops.
+    SmallVector<int64_t> prunedTransp;
+    pruneNonTransposedDims(transp, prunedTransp);
+    size_t numPrunedDims = transp.size() - prunedTransp.size();
+    auto prunedInShape = inputType.getShape().drop_back(numPrunedDims);
+    auto prunedInStrides = computeStrides(prunedInShape);
+
+    // Generates the extract/insert operations for every scalar/vector element
+    // of the leftmost transposed dimensions. We traverse every transpose
+    // element using a linearized index that we delinearize to generate the
+    // appropriate indices for the extract/insert operations.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, resType, rewriter.getZeroAttr(resType));
+    int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape);
+
+    for (int64_t linearIdx = 0; linearIdx < numTransposedElements;
+         ++linearIdx) {
+      auto extractIdxs = delinearize(linearIdx, prunedInStrides);
+      SmallVector<int64_t> insertIdxs(extractIdxs);
+      applyPermutationToVector(insertIdxs, prunedTransp);
+      Value extractOp =
+          rewriter.create<vector::ExtractOp>(loc, input, extractIdxs);
+      result =
+          rewriter.create<vector::InsertOp>(loc, extractOp, result, insertIdxs);
+    }
+
+    rewriter.replaceOp(op, result);
+    return success();
+  }
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+};
+
+/// Rewrite a 2-D vector.transpose as a sequence of:
+///   vector.shape_cast 2D -> 1D
+///   vector.shuffle
+///   vector.shape_cast 1D -> 2D
+class TransposeOp2DToShuffleLowering
+    : public OpRewritePattern<vector::TransposeOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  TransposeOp2DToShuffleLowering(
+      vector::VectorTransformsOptions vectorTransformOptions,
+      MLIRContext *context, PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::TransposeOp>(context, benefit),
+        vectorTransformOptions(vectorTransformOptions) {}
+
+  LogicalResult matchAndRewrite(vector::TransposeOp op,
+                                PatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+
+    VectorType srcType = op.getSourceVectorType();
+    if (srcType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Not a 2D transpose");
+
+    SmallVector<int64_t> transp;
+    for (auto attr : op.getTransp())
+      transp.push_back(attr.cast<IntegerAttr>().getInt());
+    if (transp[0] != 1 || transp[1] != 0)
+      return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation");
+
+    if (vectorTransformOptions.vectorTransposeLowering !=
+        VectorTransposeLowering::Shuffle)
+      return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle");
+
+    int64_t m = srcType.getShape().front(), n = srcType.getShape().back();
+    Value casted = rewriter.create<vector::ShapeCastOp>(
+        loc, VectorType::get({m * n}, srcType.getElementType()),
+        op.getVector());
+    SmallVector<int64_t> mask;
+    mask.reserve(m * n);
+    for (int64_t j = 0; j < n; ++j)
+      for (int64_t i = 0; i < m; ++i)
+        mask.push_back(i * n + j);
+
+    Value shuffled =
+        rewriter.create<vector::ShuffleOp>(loc, casted, casted, mask);
+    rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
+        op, op.getResultVectorType(), shuffled);
+
+    return success();
+  }
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformOptions;
+};
+} // namespace
+
+void mlir::vector::populateVectorTransposeLoweringPatterns(
+    RewritePatternSet &patterns, VectorTransformsOptions options,
+    PatternBenefit benefit) {
+  patterns.add<TransposeOpLowering, TransposeOp2DToShuffleLowering>(
+      options, patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/BuiltinOps.h"
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp
@@ -11,8 +11,8 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <type_traits>
 #include <optional>
+#include <type_traits>
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
@@ -92,11 +92,11 @@
 }
 
 /// Split a vector.transfer operation into an in-bounds (i.e., no out-of-bounds
-/// masking) fastpath and a slowpath.
+/// masking) fast path and a slow path.
 /// If `ifOp` is not null and the result is `success, the `ifOp` points to the
 /// newly created conditional upon function return.
-/// To accomodate for the fact that the original vector.transfer indexing may be
-/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
+/// To accommodate the fact that the original vector.transfer indexing may be
+/// arbitrary and the slow path indexes @[0...0] in the temporary buffer, the
 /// scf.if op returns a view and values of type index.
 /// At this time, only vector.transfer_read case is implemented.
 ///
@@ -107,11 +107,11 @@
 /// is transformed into:
 /// ```
 ///    %1:3 = scf.if (%inBounds) {
-///      // fastpath, direct cast
+///      // fast path, direct cast
 ///      memref.cast %A: memref<A...> to compatibleMemRefType
 ///      scf.yield %view : compatibleMemRefType, index, index
 ///    } else {
-///      // slowpath, not in-bounds vector.transfer or linalg.copy.
+///      // slow path, not in-bounds vector.transfer or linalg.copy.
 ///      memref.cast %alloc: memref<B...> to compatibleMemRefType
 ///      scf.yield %4 : compatibleMemRefType, index, index
 //     }
@@ -172,12 +172,10 @@
   for (int64_t idx = 0, e = aT.getRank(); idx < e; ++idx) {
     resShape[idx] =
         (aShape[idx] == bShape[idx]) ? aShape[idx] : ShapedType::kDynamic;
-    resStrides[idx] = (aStrides[idx] == bStrides[idx])
-                          ? aStrides[idx]
-                          : ShapedType::kDynamic;
+    resStrides[idx] =
+        (aStrides[idx] == bStrides[idx]) ? aStrides[idx] : ShapedType::kDynamic;
   }
-  resOffset =
-      (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
+  resOffset = (aOffset == bOffset) ? aOffset : ShapedType::kDynamic;
   return MemRefType::get(
       resShape, aT.getElementType(),
       StridedLayoutAttr::get(aT.getContext(), resOffset, resStrides));
@@ -634,7 +632,34 @@
   return success();
 }
 
-LogicalResult mlir::vector::VectorTransferFullPartialRewriter::matchAndRewrite(
+namespace {
+/// Apply `splitFullAndPartialTransfer` selectively via a pattern. This pattern
+/// may take an extra filter to perform selection at a finer granularity.
+struct VectorTransferFullPartialRewriter : public RewritePattern {
+  using FilterConstraintType =
+      std::function<LogicalResult(VectorTransferOpInterface op)>;
+
+  explicit VectorTransferFullPartialRewriter(
+      MLIRContext *context,
+      VectorTransformsOptions options = VectorTransformsOptions(),
+      FilterConstraintType filter =
+          [](VectorTransferOpInterface op) { return success(); },
+      PatternBenefit benefit = 1)
+      : RewritePattern(MatchAnyOpTypeTag(), benefit, context), options(options),
+        filter(std::move(filter)) {}
+
+  /// Performs the rewrite.
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  VectorTransformsOptions options;
+  FilterConstraintType filter;
+};
+
+} // namespace
+
+LogicalResult VectorTransferFullPartialRewriter::matchAndRewrite(
     Operation *op, PatternRewriter &rewriter) const {
   auto xferOp = dyn_cast<VectorTransferOpInterface>(op);
   if (!xferOp || failed(splitFullAndPartialTransferPrecondition(xferOp)) ||
@@ -642,3 +667,9 @@
     return failure();
   return splitFullAndPartialTransfer(rewriter, xferOp, options);
 }
+
+void mlir::vector::populateVectorTransferFullPartialPatterns(
+    RewritePatternSet &patterns, const VectorTransformsOptions &options) {
+  patterns.add<VectorTransferFullPartialRewriter>(patterns.getContext(),
+                                                  options);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -51,102 +51,6 @@
 using namespace mlir;
 using namespace mlir::vector;
 
-// Helper to find an index in an affine map.
-static std::optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
-  for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
-    int64_t idx = map.getDimPosition(i);
-    if (idx == index)
-      return i;
-  }
-  return std::nullopt;
-}
-
-// Helper to construct iterator types with one index removed.
-static SmallVector adjustIter(ArrayAttr iteratorTypes, - int64_t index) { - SmallVector results; - for (const auto &it : llvm::enumerate(iteratorTypes)) { - int64_t idx = it.index(); - if (idx == index) - continue; - results.push_back(it.value()); - } - return results; -} - -// Helper to construct an affine map with one index removed. -static AffineMap adjustMap(AffineMap map, int64_t index, - PatternRewriter &rewriter) { - auto *ctx = rewriter.getContext(); - SmallVector results; - for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { - int64_t idx = map.getDimPosition(i); - if (idx == index) - continue; - // Re-insert remaining indices, but renamed when occurring - // after the removed index. - auto targetExpr = getAffineDimExpr(idx < index ? idx : idx - 1, ctx); - results.push_back(targetExpr); - } - return AffineMap::get(map.getNumDims() - 1, 0, results, ctx); -} - -// Helper method to possibly drop a dimension in a load. -// TODO -static Value reshapeLoad(Location loc, Value val, VectorType type, - int64_t index, int64_t pos, - PatternRewriter &rewriter) { - if (index == -1) - return val; - Type lowType = VectorType::Builder(type).dropDim(0); - // At extraction dimension? - if (index == 0) { - auto posAttr = rewriter.getI64ArrayAttr(pos); - return rewriter.create(loc, lowType, val, posAttr); - } - // Unroll leading dimensions. - VectorType vType = lowType.cast(); - Type resType = VectorType::Builder(type).dropDim(index); - auto resVectorType = resType.cast(); - Value result = rewriter.create( - loc, resVectorType, rewriter.getZeroAttr(resVectorType)); - for (int64_t d = 0, e = resVectorType.getDimSize(0); d < e; d++) { - auto posAttr = rewriter.getI64ArrayAttr(d); - Value ext = rewriter.create(loc, vType, val, posAttr); - Value load = reshapeLoad(loc, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, resVectorType, load, result, - posAttr); - } - return result; -} - -// Helper method to possibly drop a dimension in a store. 
-// TODO -static Value reshapeStore(Location loc, Value val, Value result, - VectorType type, int64_t index, int64_t pos, - PatternRewriter &rewriter) { - // Unmodified? - if (index == -1) - return val; - // At insertion dimension? - if (index == 0) { - auto posAttr = rewriter.getI64ArrayAttr(pos); - return rewriter.create(loc, type, val, result, posAttr); - } - // Unroll leading dimensions. - Type lowType = VectorType::Builder(type).dropDim(0); - VectorType vType = lowType.cast(); - Type insType = VectorType::Builder(vType).dropDim(0); - for (int64_t d = 0, e = type.getDimSize(0); d < e; d++) { - auto posAttr = rewriter.getI64ArrayAttr(d); - Value ext = rewriter.create(loc, vType, result, posAttr); - Value ins = rewriter.create(loc, insType, val, posAttr); - Value sto = reshapeStore(loc, ins, ext, vType, index - 1, pos, rewriter); - result = rewriter.create(loc, type, sto, result, posAttr); - } - return result; -} - template static SmallVector extractVector(ArrayAttr arrayAttr) { return llvm::to_vector<4>(llvm::map_range( @@ -154,61 +58,11 @@ [](IntegerAttr attr) { return static_cast(attr.getInt()); })); } -/// Helper to create arithmetic operation associated with a kind of contraction. -static std::optional -createContractArithOp(Location loc, Value x, Value y, Value acc, - vector::CombiningKind kind, PatternRewriter &rewriter, - bool isInt, Value mask = Value()) { - using vector::CombiningKind; - Value mul; - - if (isInt) { - if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF) - // Only valid for floating point types. - return std::nullopt; - mul = rewriter.create(loc, x, y); - } else { - // Float case. - if (kind == CombiningKind::AND || kind == CombiningKind::MINUI || - kind == CombiningKind::MINSI || kind == CombiningKind::MAXUI || - kind == CombiningKind::MAXSI || kind == CombiningKind::OR || - kind == CombiningKind::XOR) - // Only valid for integer types. - return std::nullopt; - // Special case for fused multiply-add. 
- if (acc && acc.getType().isa() && kind == CombiningKind::ADD) { - Value fma = rewriter.create(loc, x, y, acc); - if (mask) - // The fma op doesn't need explicit masking. However, fma ops used in - // reductions must preserve previous 'acc' values for masked-out lanes. - fma = selectPassthru(rewriter, mask, fma, acc); - return fma; - } - mul = rewriter.create(loc, x, y); - } - - if (!acc) - return std::optional(mul); - - return makeArithReduction(rewriter, loc, kind, mul, acc, mask); -} - -/// Return the positions of the reductions in the given map. -static SmallVector getReductionIndex(AffineMap map, - ArrayAttr iteratorTypes) { - SmallVector dimsIdx; - for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { - if (isReductionIterator(iteratorTypes[map.getDimPosition(i)])) - dimsIdx.push_back(i); - } - return dimsIdx; -} - -/// Look for a given dimension in an affine map and return its position. Return -/// std::nullopt if the dimension is not in the map results. -static std::optional getDimPosition(AffineMap map, unsigned dim) { - for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { - if (map.getDimPosition(i) == dim) +// Helper to find an index in an affine map. +static std::optional getResultIndex(AffineMap map, int64_t index) { + for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) { + int64_t idx = map.getDimPosition(i); + if (idx == index) return i; } return std::nullopt; @@ -264,735 +118,6 @@ } }; -/// Progressive lowering of BroadcastOp. -class BroadcastOpLowering : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::BroadcastOp op, - PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - VectorType dstType = op.getResultVectorType(); - VectorType srcType = op.getSourceType().dyn_cast(); - Type eltType = dstType.getElementType(); - - // Scalar to any vector can use splat. 
- if (!srcType) { - rewriter.replaceOpWithNewOp(op, dstType, op.getSource()); - return success(); - } - - // Determine rank of source and destination. - int64_t srcRank = srcType.getRank(); - int64_t dstRank = dstType.getRank(); - - // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. - if (srcRank <= 1 && dstRank == 1) { - Value ext; - if (srcRank == 0) - ext = rewriter.create(loc, op.getSource()); - else - ext = rewriter.create(loc, op.getSource(), 0); - rewriter.replaceOpWithNewOp(op, dstType, ext); - return success(); - } - - // Duplicate this rank. - // For example: - // %x = broadcast %y : k-D to n-D, k < n - // becomes: - // %b = broadcast %y : k-D to (n-1)-D - // %x = [%b,%b,%b,%b] : n-D - // becomes: - // %b = [%y,%y] : (n-1)-D - // %x = [%b,%b,%b,%b] : n-D - if (srcRank < dstRank) { - // Duplication. - VectorType resType = - VectorType::get(dstType.getShape().drop_front(), eltType); - Value bcst = - rewriter.create(loc, resType, op.getSource()); - Value result = rewriter.create( - loc, dstType, rewriter.getZeroAttr(dstType)); - for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) - result = rewriter.create(loc, bcst, result, d); - rewriter.replaceOp(op, result); - return success(); - } - - // Find non-matching dimension, if any. - assert(srcRank == dstRank); - int64_t m = -1; - for (int64_t r = 0; r < dstRank; r++) - if (srcType.getDimSize(r) != dstType.getDimSize(r)) { - m = r; - break; - } - - // All trailing dimensions are the same. Simply pass through. - if (m == -1) { - rewriter.replaceOp(op, op.getSource()); - return success(); - } - - // Any non-matching dimension forces a stretch along this rank. 
- // For example: - // %x = broadcast %y : vector<4x1x2xf32> to vector<4x2x2xf32> - // becomes: - // %a = broadcast %y[0] : vector<1x2xf32> to vector<2x2xf32> - // %b = broadcast %y[1] : vector<1x2xf32> to vector<2x2xf32> - // %c = broadcast %y[2] : vector<1x2xf32> to vector<2x2xf32> - // %d = broadcast %y[3] : vector<1x2xf32> to vector<2x2xf32> - // %x = [%a,%b,%c,%d] - // becomes: - // %u = broadcast %y[0][0] : vector<2xf32> to vector <2x2xf32> - // %v = broadcast %y[1][0] : vector<2xf32> to vector <2x2xf32> - // %a = [%u, %v] - // .. - // %x = [%a,%b,%c,%d] - VectorType resType = - VectorType::get(dstType.getShape().drop_front(), eltType); - Value result = rewriter.create( - loc, dstType, rewriter.getZeroAttr(dstType)); - if (m == 0) { - // Stetch at start. - Value ext = rewriter.create(loc, op.getSource(), 0); - Value bcst = rewriter.create(loc, resType, ext); - for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) - result = rewriter.create(loc, bcst, result, d); - } else { - // Stetch not at start. - for (int64_t d = 0, dim = dstType.getDimSize(0); d < dim; ++d) { - Value ext = rewriter.create(loc, op.getSource(), d); - Value bcst = rewriter.create(loc, resType, ext); - result = rewriter.create(loc, bcst, result, d); - } - } - rewriter.replaceOp(op, result); - return success(); - } -}; - -/// Given a 'transpose' pattern, prune the rightmost dimensions that are not -/// transposed. -void pruneNonTransposedDims(ArrayRef transpose, - SmallVectorImpl &result) { - size_t numTransposedDims = transpose.size(); - for (size_t transpDim : llvm::reverse(transpose)) { - if (transpDim != numTransposedDims - 1) - break; - numTransposedDims--; - } - - result.append(transpose.begin(), transpose.begin() + numTransposedDims); -} - -/// Progressive lowering of TransposeOp. -/// One: -/// %x = vector.transpose %y, [1, 0] -/// is replaced by: -/// %z = arith.constant dense<0.000000e+00> -/// %0 = vector.extract %y[0, 0] -/// %1 = vector.insert %0, %z [0, 0] -/// .. 
-/// %x = vector.insert .., .. [.., ..] -class TransposeOpLowering : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - TransposeOpLowering(vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions) {} - - LogicalResult matchAndRewrite(vector::TransposeOp op, - PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - Value input = op.getVector(); - VectorType inputType = op.getSourceVectorType(); - VectorType resType = op.getResultVectorType(); - - // Set up convenience transposition table. - SmallVector transp; - for (auto attr : op.getTransp()) - transp.push_back(attr.cast().getInt()); - - if (vectorTransformOptions.vectorTransposeLowering == - vector::VectorTransposeLowering::Shuffle && - resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) - return rewriter.notifyMatchFailure( - op, "Options specifies lowering to shuffle"); - - // Handle a true 2-D matrix transpose differently when requested. - if (vectorTransformOptions.vectorTransposeLowering == - vector::VectorTransposeLowering::Flat && - resType.getRank() == 2 && transp[0] == 1 && transp[1] == 0) { - Type flattenedType = - VectorType::get(resType.getNumElements(), resType.getElementType()); - auto matrix = - rewriter.create(loc, flattenedType, input); - auto rows = rewriter.getI32IntegerAttr(resType.getShape()[0]); - auto columns = rewriter.getI32IntegerAttr(resType.getShape()[1]); - Value trans = rewriter.create( - loc, flattenedType, matrix, rows, columns); - rewriter.replaceOpWithNewOp(op, resType, trans); - return success(); - } - - // Generate unrolled extract/insert ops. We do not unroll the rightmost - // (i.e., highest-order) dimensions that are not transposed and leave them - // in vector form to improve performance. 
Therefore, we prune those - // dimensions from the shape/transpose data structures used to generate the - // extract/insert ops. - SmallVector prunedTransp; - pruneNonTransposedDims(transp, prunedTransp); - size_t numPrunedDims = transp.size() - prunedTransp.size(); - auto prunedInShape = inputType.getShape().drop_back(numPrunedDims); - auto prunedInStrides = computeStrides(prunedInShape); - - // Generates the extract/insert operations for every scalar/vector element - // of the leftmost transposed dimensions. We traverse every transpose - // element using a linearized index that we delinearize to generate the - // appropriate indices for the extract/insert operations. - Value result = rewriter.create( - loc, resType, rewriter.getZeroAttr(resType)); - int64_t numTransposedElements = ShapedType::getNumElements(prunedInShape); - - for (int64_t linearIdx = 0; linearIdx < numTransposedElements; - ++linearIdx) { - auto extractIdxs = delinearize(linearIdx, prunedInStrides); - SmallVector insertIdxs(extractIdxs); - applyPermutationToVector(insertIdxs, prunedTransp); - Value extractOp = - rewriter.create(loc, input, extractIdxs); - result = - rewriter.create(loc, extractOp, result, insertIdxs); - } - - rewriter.replaceOp(op, result); - return success(); - } - -private: - /// Options to control the vector patterns. 
- vector::VectorTransformsOptions vectorTransformOptions; -}; - -/// Rewrite a 2-D vector.transpose as a sequence of: -/// vector.shape_cast 2D -> 1D -/// vector.shuffle -/// vector.shape_cast 1D -> 2D -class TransposeOp2DToShuffleLowering - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - TransposeOp2DToShuffleLowering( - vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions) {} - - LogicalResult matchAndRewrite(vector::TransposeOp op, - PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - VectorType srcType = op.getSourceVectorType(); - if (srcType.getRank() != 2) - return rewriter.notifyMatchFailure(op, "Not a 2D transpose"); - - SmallVector transp; - for (auto attr : op.getTransp()) - transp.push_back(attr.cast().getInt()); - if (transp[0] != 1 && transp[1] != 0) - return rewriter.notifyMatchFailure(op, "Not a 2D transpose permutation"); - - if (vectorTransformOptions.vectorTransposeLowering != - VectorTransposeLowering::Shuffle) - return rewriter.notifyMatchFailure(op, "Options do not ask for Shuffle"); - - int64_t m = srcType.getShape().front(), n = srcType.getShape().back(); - Value casted = rewriter.create( - loc, VectorType::get({m * n}, srcType.getElementType()), - op.getVector()); - SmallVector mask; - mask.reserve(m * n); - for (int64_t j = 0; j < n; ++j) - for (int64_t i = 0; i < m; ++i) - mask.push_back(i * n + j); - - Value shuffled = - rewriter.create(loc, casted, casted, mask); - rewriter.replaceOpWithNewOp( - op, op.getResultVectorType(), shuffled); - - return success(); - } - -private: - /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; -}; - -/// Progressive lowering of OuterProductOp. 
-/// One: -/// %x = vector.outerproduct %lhs, %rhs, %acc -/// is replaced by: -/// %z = zero-result -/// %0 = vector.extract %lhs[0] -/// %1 = vector.broadcast %0 -/// %2 = vector.extract %acc[0] -/// %3 = vector.fma %1, %rhs, %2 -/// %4 = vector.insert %3, %z[0] -/// .. -/// %x = vector.insert %.., %..[N-1] -/// -class OuterProductOpLowering : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::OuterProductOp op, - PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - - VectorType lhsType = op.getOperandVectorTypeLHS(); - VectorType rhsType = op.getOperandTypeRHS().dyn_cast(); - VectorType resType = op.getResultVectorType(); - Type eltType = resType.getElementType(); - bool isInt = eltType.isa(); - Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0]; - vector::CombiningKind kind = op.getKind(); - - // Vector mask setup. - OpBuilder::InsertionGuard guard(rewriter); - auto maskableOp = cast(op.getOperation()); - Operation *rootOp; - Value mask; - if (maskableOp.isMasked()) { - rewriter.setInsertionPoint(maskableOp.getMaskingOp()); - rootOp = maskableOp.getMaskingOp(); - mask = maskableOp.getMaskingOp().getMask(); - } else { - rootOp = op; - } - - if (!rhsType) { - // Special case: AXPY operation. 
- Value b = rewriter.create(loc, lhsType, op.getRhs()); - std::optional mult = createContractArithOp( - loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask); - if (!mult.has_value()) - return failure(); - rewriter.replaceOp(rootOp, *mult); - return success(); - } - - Value result = rewriter.create( - loc, resType, rewriter.getZeroAttr(resType)); - for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { - auto pos = rewriter.getI64ArrayAttr(d); - Value x = rewriter.create(loc, op.getLhs(), pos); - Value a = rewriter.create(loc, rhsType, x); - Value r = nullptr; - if (acc) - r = rewriter.create(loc, acc, pos); - Value extrMask; - if (mask) - extrMask = rewriter.create(loc, mask, pos); - - std::optional m = createContractArithOp( - loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); - if (!m.has_value()) - return failure(); - result = rewriter.create(loc, resType, *m, result, pos); - } - - rewriter.replaceOp(rootOp, result); - return success(); - } -}; - -/// Lower vector.contract with all size one reduction dimensions to -/// elementwise ops when possible. -struct ContractOpToElementwise - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - using FilterConstraintType = - std::function; - static LogicalResult defaultFilter(vector::ContractionOp op) { - return success(); - } - ContractOpToElementwise( - vector::VectorTransformsOptions vectorTransformOptions, - MLIRContext *context, PatternBenefit benefit = 1, - const FilterConstraintType &constraint = defaultFilter) - : OpRewritePattern(context, benefit), - vectorTransformOptions(vectorTransformOptions), filter(defaultFilter) {} - - LogicalResult matchAndRewrite(vector::ContractionOp contractOp, - PatternRewriter &rewriter) const override { - // TODO: Support vector.mask. - auto maskableOp = cast(contractOp.getOperation()); - if (maskableOp.isMasked()) - return failure(); - - // TODO: Remove native masks from contraction op? 
- if (!contractOp.getMasks().empty()) - return failure(); - - if (failed(filter(contractOp))) - return failure(); - - if (vectorTransformOptions.vectorContractLowering != - vector::VectorContractLowering::ParallelArith) - return failure(); - - ArrayRef lhsShape = contractOp.getLhsType().getShape(); - ArrayRef rhsShape = contractOp.getRhsType().getShape(); - AffineMap lhsMap = contractOp.getIndexingMapsArray()[0]; - AffineMap rhsMap = contractOp.getIndexingMapsArray()[1]; - SmallVector lhsReductionDims = - getReductionIndex(lhsMap, contractOp.getIteratorTypes()); - SmallVector rhsReductionDims = - getReductionIndex(rhsMap, contractOp.getIteratorTypes()); - // All the reduction dimensions must be a size 1. - for (int64_t dim : lhsReductionDims) { - if (lhsShape[dim] != 1) - return failure(); - } - for (int64_t dim : rhsReductionDims) { - if (rhsShape[dim] != 1) - return failure(); - } - AffineMap accMap = contractOp.getIndexingMapsArray()[2]; - unsigned numParallelDims = accMap.getNumResults(); - unsigned numLhsDimToBroadcast = - numParallelDims - (lhsMap.getNumResults() - lhsReductionDims.size()); - unsigned numRhsDimToBroadcast = - numParallelDims - (rhsMap.getNumResults() - rhsReductionDims.size()); - SmallVector lhsDims; - SmallVector lhsTranspose; - SmallVector rhsDims; - SmallVector rhsTranspose; - for (int64_t dim : lhsReductionDims) - lhsTranspose.push_back(numLhsDimToBroadcast + dim); - for (int64_t dim : rhsReductionDims) - rhsTranspose.push_back(numRhsDimToBroadcast + dim); - // Loop through the parallel dimensions to calculate the dimensions to - // broadcast and to permute in order to extract only parallel dimensions. - for (unsigned i = 0; i < numParallelDims; i++) { - std::optional lhsDim = - getDimPosition(lhsMap, accMap.getDimPosition(i)); - if (lhsDim) { - lhsTranspose.push_back(numLhsDimToBroadcast + *lhsDim); - } else { - // If the parallel dimension doesn't exist we will have to broadcast it. 
- lhsDims.push_back( - contractOp.getResultType().cast().getDimSize(i)); - lhsTranspose.push_back(lhsDims.size() - 1); - } - std::optional rhsDim = - getDimPosition(rhsMap, accMap.getDimPosition(i)); - if (rhsDim) { - rhsTranspose.push_back(numRhsDimToBroadcast + *rhsDim); - } else { - // If the parallel dimension doesn't exist we will have to broadcast it. - rhsDims.push_back( - contractOp.getResultType().cast().getDimSize(i)); - rhsTranspose.push_back(rhsDims.size() - 1); - } - } - Value newLhs = contractOp.getLhs(); - Value newRhs = contractOp.getRhs(); - Location loc = contractOp.getLoc(); - if (!lhsDims.empty()) { - lhsDims.append(lhsShape.begin(), lhsShape.end()); - auto expandedType = - VectorType::get(lhsDims, contractOp.getLhsType().getElementType()); - newLhs = rewriter.create(loc, expandedType, newLhs); - } - if (!rhsDims.empty()) { - rhsDims.append(rhsShape.begin(), rhsShape.end()); - auto expandedType = - VectorType::get(rhsDims, contractOp.getRhsType().getElementType()); - newRhs = rewriter.create(loc, expandedType, newRhs); - } - bool isInt = contractOp.getLhsType().getElementType().isIntOrIndex(); - newLhs = rewriter.create(loc, newLhs, lhsTranspose); - newRhs = rewriter.create(loc, newRhs, rhsTranspose); - SmallVector lhsOffsets(lhsReductionDims.size(), 0); - SmallVector rhsOffsets(rhsReductionDims.size(), 0); - newLhs = rewriter.create( - loc, newLhs, rewriter.getI64ArrayAttr(lhsOffsets)); - newRhs = rewriter.create( - loc, newRhs, rewriter.getI64ArrayAttr(rhsOffsets)); - std::optional result = - createContractArithOp(loc, newLhs, newRhs, contractOp.getAcc(), - contractOp.getKind(), rewriter, isInt); - rewriter.replaceOp(contractOp, {*result}); - return success(); - } - -private: - /// Options to control the vector patterns. - vector::VectorTransformsOptions vectorTransformOptions; - FilterConstraintType filter; -}; - -/// Progressive lowering of ConstantMaskOp. 
-/// One: -/// %x = vector.constant_mask [a,b] -/// is replaced by: -/// %z = zero-result -/// %l = vector.constant_mask [b] -/// %4 = vector.insert %l, %z[0] -/// .. -/// %x = vector.insert %l, %..[a-1] -/// until a one-dimensional vector is reached. All these operations -/// will be folded at LLVM IR level. -class ConstantMaskOpLowering : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ConstantMaskOp op, - PatternRewriter &rewriter) const override { - auto loc = op.getLoc(); - auto dstType = op.getType(); - auto eltType = dstType.getElementType(); - auto dimSizes = op.getMaskDimSizes(); - int64_t rank = dstType.getRank(); - - if (rank == 0) { - assert(dimSizes.size() == 1 && - "Expected exactly one dim size for a 0-D vector"); - bool value = dimSizes[0].cast().getInt() == 1; - rewriter.replaceOpWithNewOp( - op, dstType, - DenseIntElementsAttr::get( - VectorType::get(ArrayRef{}, rewriter.getI1Type()), - ArrayRef{value})); - return success(); - } - - // Scalable constant masks can only be lowered for the "none set" case. - if (dstType.cast().isScalable()) { - rewriter.replaceOpWithNewOp( - op, DenseElementsAttr::get(dstType, false)); - return success(); - } - - int64_t trueDim = std::min(dstType.getDimSize(0), - dimSizes[0].cast().getInt()); - - if (rank == 1) { - // Express constant 1-D case in explicit vector form: - // [T,..,T,F,..,F]. 
- SmallVector values(dstType.getDimSize(0)); - for (int64_t d = 0; d < trueDim; d++) - values[d] = true; - rewriter.replaceOpWithNewOp( - op, dstType, rewriter.getBoolVectorAttr(values)); - return success(); - } - - VectorType lowType = - VectorType::get(dstType.getShape().drop_front(), eltType); - SmallVector newDimSizes; - for (int64_t r = 1; r < rank; r++) - newDimSizes.push_back(dimSizes[r].cast().getInt()); - Value trueVal = rewriter.create( - loc, lowType, rewriter.getI64ArrayAttr(newDimSizes)); - Value result = rewriter.create( - loc, dstType, rewriter.getZeroAttr(dstType)); - for (int64_t d = 0; d < trueDim; d++) { - auto pos = rewriter.getI64ArrayAttr(d); - result = - rewriter.create(loc, dstType, trueVal, result, pos); - } - rewriter.replaceOp(op, result); - return success(); - } -}; - -/// Progressive lowering of CreateMaskOp. -/// One: -/// %x = vector.create_mask %a, ... : vector -/// is replaced by: -/// %l = vector.create_mask ... : vector<...> ; one lower rank -/// %0 = arith.cmpi "slt", %ci, %a | -/// %1 = select %0, %l, %zeroes | -/// %r = vector.insert %1, %pr [i] | d-times -/// %x = .... -/// until a one-dimensional vector is reached. 
-class CreateMaskOpLowering : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::CreateMaskOp op, - PatternRewriter &rewriter) const override { - auto dstType = op.getResult().getType().cast(); - int64_t rank = dstType.getRank(); - if (rank <= 1) - return rewriter.notifyMatchFailure( - op, "0-D and 1-D vectors are handled separately"); - - auto loc = op.getLoc(); - auto eltType = dstType.getElementType(); - int64_t dim = dstType.getDimSize(0); - Value idx = op.getOperand(0); - - VectorType lowType = - VectorType::get(dstType.getShape().drop_front(), eltType); - Value trueVal = rewriter.create( - loc, lowType, op.getOperands().drop_front()); - Value falseVal = rewriter.create( - loc, lowType, rewriter.getZeroAttr(lowType)); - Value result = rewriter.create( - loc, dstType, rewriter.getZeroAttr(dstType)); - for (int64_t d = 0; d < dim; d++) { - Value bnd = - rewriter.create(loc, rewriter.getIndexAttr(d)); - Value val = rewriter.create(loc, arith::CmpIPredicate::slt, - bnd, idx); - Value sel = rewriter.create(loc, val, trueVal, falseVal); - auto pos = rewriter.getI64ArrayAttr(d); - result = - rewriter.create(loc, dstType, sel, result, pos); - } - rewriter.replaceOp(op, result); - return success(); - } -}; - -/// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D -/// vectors progressively on the way to target llvm.matrix intrinsics. 
-/// This iterates over the most major dimension of the 2-D vector and performs -/// rewrites into: -/// vector.extract from 2-D + vector.insert_strided_slice offset into 1-D -class ShapeCastOp2DDownCastRewritePattern - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ShapeCastOp op, - PatternRewriter &rewriter) const override { - auto sourceVectorType = op.getSourceVectorType(); - auto resultVectorType = op.getResultVectorType(); - if (sourceVectorType.getRank() != 2 || resultVectorType.getRank() != 1) - return failure(); - - auto loc = op.getLoc(); - Value desc = rewriter.create( - loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); - unsigned mostMinorVectorSize = sourceVectorType.getShape()[1]; - for (int64_t i = 0, e = sourceVectorType.getShape().front(); i != e; ++i) { - Value vec = rewriter.create(loc, op.getSource(), i); - desc = rewriter.create( - loc, vec, desc, - /*offsets=*/i * mostMinorVectorSize, /*strides=*/1); - } - rewriter.replaceOp(op, desc); - return success(); - } -}; - -/// ShapeOp 1D -> 2D upcast serves the purpose of unflattening 2-D from 1-D -/// vectors progressively. -/// This iterates over the most major dimension of the 2-D vector and performs -/// rewrites into: -/// vector.extract_strided_slice from 1-D + vector.insert into 2-D -/// Note that 1-D extract_strided_slice are lowered to efficient vector.shuffle. 
-class ShapeCastOp2DUpCastRewritePattern - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ShapeCastOp op, - PatternRewriter &rewriter) const override { - auto sourceVectorType = op.getSourceVectorType(); - auto resultVectorType = op.getResultVectorType(); - if (sourceVectorType.getRank() != 1 || resultVectorType.getRank() != 2) - return failure(); - - auto loc = op.getLoc(); - Value desc = rewriter.create( - loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); - unsigned mostMinorVectorSize = resultVectorType.getShape()[1]; - for (int64_t i = 0, e = resultVectorType.getShape().front(); i != e; ++i) { - Value vec = rewriter.create( - loc, op.getSource(), /*offsets=*/i * mostMinorVectorSize, - /*sizes=*/mostMinorVectorSize, - /*strides=*/1); - desc = rewriter.create(loc, vec, desc, i); - } - rewriter.replaceOp(op, desc); - return success(); - } -}; - -// We typically should not lower general shape cast operations into data -// movement instructions, since the assumption is that these casts are -// optimized away during progressive lowering. For completeness, however, -// we fall back to a reference implementation that moves all elements -// into the right place if we get here. -class ShapeCastOpRewritePattern : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ShapeCastOp op, - PatternRewriter &rewriter) const override { - Location loc = op.getLoc(); - auto sourceVectorType = op.getSourceVectorType(); - auto resultVectorType = op.getResultVectorType(); - - // Special case 2D/1D lowerings with better implementations. - // TODO: make is ND/1D to allow generic ND->1D->MD. 
- int64_t srcRank = sourceVectorType.getRank(); - int64_t resRank = resultVectorType.getRank(); - if ((srcRank == 2 && resRank == 1) || (srcRank == 1 && resRank == 2)) - return failure(); - - // Generic ShapeCast lowering path goes all the way down to unrolled scalar - // extract/insert chains. - // TODO: consider evolving the semantics to only allow 1D source or dest and - // drop this potentially very expensive lowering. - // Compute number of elements involved in the reshape. - int64_t numElts = 1; - for (int64_t r = 0; r < srcRank; r++) - numElts *= sourceVectorType.getDimSize(r); - // Replace with data movement operations: - // x[0,0,0] = y[0,0] - // x[0,0,1] = y[0,1] - // x[0,1,0] = y[0,2] - // etc., incrementing the two index vectors "row-major" - // within the source and result shape. - SmallVector srcIdx(srcRank); - SmallVector resIdx(resRank); - Value result = rewriter.create( - loc, resultVectorType, rewriter.getZeroAttr(resultVectorType)); - for (int64_t i = 0; i < numElts; i++) { - if (i != 0) { - incIdx(srcIdx, sourceVectorType, srcRank - 1); - incIdx(resIdx, resultVectorType, resRank - 1); - } - Value e = rewriter.create(loc, op.getSource(), srcIdx); - result = rewriter.create(loc, e, result, resIdx); - } - rewriter.replaceOp(op, result); - return success(); - } - -private: - static void incIdx(SmallVector &idx, VectorType tp, int64_t r) { - assert(0 <= r && r < tp.getRank()); - if (++idx[r] == tp.getDimSize(r)) { - idx[r] = 0; - incIdx(idx, tp, r - 1); - } - } -}; - /// Convert MulIOp/MulFOp + MultiDimReductionOp into ContractionOp. /// Ex: /// ``` @@ -1425,967 +550,6 @@ } }; -} // namespace - -/// Creates an AddIOp if `isInt` is true otherwise create an arith::AddFOp using -/// operands `x` and `y`. 
-static Value createAdd(Location loc, Value x, Value y, bool isInt, - PatternRewriter &rewriter) { - if (isInt) - return rewriter.create(loc, x, y); - return rewriter.create(loc, x, y); -} - -/// Creates a MulIOp if `isInt` is true otherwise create an MulFOp using -/// operands `x and `y`. -static Value createMul(Location loc, Value x, Value y, bool isInt, - PatternRewriter &rewriter) { - if (isInt) - return rewriter.create(loc, x, y); - return rewriter.create(loc, x, y); -} - -namespace mlir { - -/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul -/// semantics to: -/// ``` -/// %mta = maybe_transpose -/// %mtb = maybe_transpose -/// %flattened_a = vector.shape_cast %mta -/// %flattened_b = vector.shape_cast %mtb -/// %flattened_d = vector.matmul %flattened_a, %flattened_b -/// %mtd = vector.shape_cast %flattened_d -/// %d = maybe_untranspose %mtd -/// %e = add %c, %d -/// ``` -/// `vector.matmul` later lowers to `llvm.matrix.multiply`. -// -/// This only kicks in when VectorTransformsOptions is set to `Matmul`. -/// vector.transpose operations are inserted if the vector.contract op is not a -/// row-major matrix multiply. -LogicalResult -ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rew) const { - // TODO: Support vector.mask. - auto maskableOp = cast(op.getOperation()); - if (maskableOp.isMasked()) - return failure(); - - // TODO: Remove native masks from contraction op? 
- if (!op.getMasks().empty()) - return failure(); - if (vectorTransformOptions.vectorContractLowering != - vector::VectorContractLowering::Matmul) - return failure(); - if (failed(filter(op))) - return failure(); - - auto iteratorTypes = op.getIteratorTypes().getValue(); - if (!isParallelIterator(iteratorTypes[0]) || - !isParallelIterator(iteratorTypes[1]) || - !isReductionIterator(iteratorTypes[2])) - return failure(); - - Type elementType = op.getLhsType().getElementType(); - if (!elementType.isIntOrFloat()) - return failure(); - - Type dstElementType = op.getType(); - if (auto vecType = dstElementType.dyn_cast()) - dstElementType = vecType.getElementType(); - if (elementType != dstElementType) - return failure(); - - // Perform lhs + rhs transpositions to conform to matmul row-major semantics. - // Bail out if the contraction cannot be put in this form. - MLIRContext *ctx = op.getContext(); - Location loc = op.getLoc(); - AffineExpr m, n, k; - bindDims(rew.getContext(), m, n, k); - // LHS must be A(m, k) or A(k, m). - Value lhs = op.getLhs(); - auto lhsMap = op.getIndexingMapsArray()[0]; - if (lhsMap == AffineMap::get(3, 0, {k, m}, ctx)) - lhs = rew.create(loc, lhs, ArrayRef{1, 0}); - else if (lhsMap != AffineMap::get(3, 0, {m, k}, ctx)) - return failure(); - - // RHS must be B(k, n) or B(n, k). - Value rhs = op.getRhs(); - auto rhsMap = op.getIndexingMapsArray()[1]; - if (rhsMap == AffineMap::get(3, 0, {n, k}, ctx)) - rhs = rew.create(loc, rhs, ArrayRef{1, 0}); - else if (rhsMap != AffineMap::get(3, 0, {k, n}, ctx)) - return failure(); - - // At this point lhs and rhs are in row-major. 
- VectorType lhsType = lhs.getType().cast(); - VectorType rhsType = rhs.getType().cast(); - int64_t lhsRows = lhsType.getDimSize(0); - int64_t lhsColumns = lhsType.getDimSize(1); - int64_t rhsColumns = rhsType.getDimSize(1); - - Type flattenedLHSType = - VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); - lhs = rew.create(loc, flattenedLHSType, lhs); - - Type flattenedRHSType = - VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); - rhs = rew.create(loc, flattenedRHSType, rhs); - - Value mul = rew.create(loc, lhs, rhs, lhsRows, lhsColumns, - rhsColumns); - mul = rew.create( - loc, - VectorType::get({lhsRows, rhsColumns}, - getElementTypeOrSelf(op.getAcc().getType())), - mul); - - // ACC must be C(m, n) or C(n, m). - auto accMap = op.getIndexingMapsArray()[2]; - if (accMap == AffineMap::get(3, 0, {n, m}, ctx)) - mul = rew.create(loc, mul, ArrayRef{1, 0}); - else if (accMap != AffineMap::get(3, 0, {m, n}, ctx)) - llvm_unreachable("invalid contraction semantics"); - - Value res = - elementType.isa() - ? static_cast(rew.create(loc, op.getAcc(), mul)) - : static_cast( - rew.create(loc, op.getAcc(), mul)); - - rew.replaceOp(op, res); - return success(); -} - -namespace { - -/// Generate a vector implementation for matmat, matvec and tmatvec. -/// This unrolls outer-products along the reduction dimension. 
-struct UnrolledOuterProductGenerator - : public StructuredGenerator { - UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op) - : StructuredGenerator(b, op), - kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()), - res(op.getAcc()), lhsType(op.getLhsType()) { - auto maskableOp = cast(op.getOperation()); - if (maskableOp.isMasked()) - mask = maskableOp.getMaskingOp().getMask(); - } - - Value t(Value v, ArrayRef perm = {1, 0}) { - if (!v) - return v; - return rewriter.create(loc, v, perm); - } - - Value promote(Value v, Type dstElementType) { - Type elementType = v.getType(); - auto vecType = elementType.dyn_cast(); - if (vecType) - elementType = vecType.getElementType(); - if (elementType == dstElementType) - return v; - Type promotedType = dstElementType; - if (vecType) - promotedType = VectorType::get(vecType.getShape(), promotedType); - if (dstElementType.isa()) - return rewriter.create(loc, promotedType, v); - return rewriter.create(loc, promotedType, v); - } - - FailureOr outerProd(Value lhs, Value rhs, Value res, int reductionSize, - std::optional maybeMask = std::nullopt) { - assert(reductionSize > 0); - // Incremental support for masking. - if (mask && !maybeMask.has_value()) - return failure(); - - Type resElementType = res.getType().cast().getElementType(); - for (int64_t k = 0; k < reductionSize; ++k) { - Value extractA = rewriter.create(loc, lhs, k); - Value extractB = rewriter.create(loc, rhs, k); - extractA = promote(extractA, resElementType); - extractB = promote(extractB, resElementType); - Value extractMask; - if (maybeMask.has_value() && maybeMask.value()) - extractMask = - rewriter.create(loc, maybeMask.value(), k); - - Operation *outerProdOp = rewriter.create( - loc, res.getType(), extractA, extractB, res, kind); - res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0); - } - return res; - } - - /// Two outer parallel, one inner reduction (matmat flavor). 
- FailureOr matmat() { - if (!iters({Par(), Par(), Red()})) - return failure(); - // Set up the parallel/reduction structure in the right form. - AffineExpr m, n, k; - bindDims(rewriter.getContext(), m, n, k); - // Classical row-major matmul: Just permute the lhs. - if (layout({{m, k}, {k, n}, {m, n}})) - return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), - t(mask, {2, 0, 1})); - // TODO: may be better to fail and use some vector -> scalar reduction. - if (layout({{m, k}, {n, k}, {m, n}})) { - Value tlhs = t(lhs); - return outerProd(tlhs, t(rhs), res, lhsType.getDimSize(1)); - } - // No need to permute anything. - if (layout({{k, m}, {k, n}, {m, n}})) - return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); - // Just permute the rhs. - if (layout({{k, m}, {n, k}, {m, n}})) - return outerProd(lhs, t(rhs), res, lhsType.getDimSize(0)); - // Transposed output: swap RHS and LHS. - // Classical row-major matmul: permute the lhs. - if (layout({{m, k}, {k, n}, {n, m}})) - return outerProd(rhs, t(lhs), res, lhsType.getDimSize(1)); - // TODO: may be better to fail and use some vector -> scalar reduction. - if (layout({{m, k}, {n, k}, {n, m}})) { - Value trhs = t(rhs); - return outerProd(trhs, t(lhs), res, lhsType.getDimSize(1)); - } - if (layout({{k, m}, {k, n}, {n, m}})) - return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); - if (layout({{k, m}, {n, k}, {n, m}})) - return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); - return failure(); - } - - /// One outer parallel, one inner reduction (matvec flavor) - FailureOr matvec() { - if (!iters({Par(), Red()})) - return failure(); - AffineExpr m, k; - bindDims(rewriter.getContext(), m, k); - - // Case mat-vec: transpose. - if (layout({{m, k}, {k}, {m}})) - return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask)); - // Case mat-trans-vec: ready to go. - if (layout({{k, m}, {k}, {m}})) - return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); - // Case vec-mat: swap and transpose. 
- if (layout({{k}, {m, k}, {m}})) - return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); - // Case vec-mat-trans: swap and ready to go. - if (layout({{k}, {k, m}, {m}})) - return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); - return failure(); - } - - // - // One outer reduction, one inner parallel (tmatvec flavor) - // - FailureOr tmatvec() { - if (!iters({Red(), Par()})) - return failure(); - AffineExpr k, m; - bindDims(rewriter.getContext(), k, m); - - // Case mat-vec: transpose. - if (layout({{m, k}, {k}, {m}})) - return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); - // Case mat-trans-vec: ready to go. - if (layout({{k, m}, {k}, {m}})) - return outerProd(lhs, rhs, res, lhsType.getDimSize(0)); - // Case vec-mat: swap and transpose. - if (layout({{k}, {m, k}, {m}})) - return outerProd(t(rhs), lhs, res, lhsType.getDimSize(0)); - // Case vec-mat-trans: swap and ready to go. - if (layout({{k}, {k, m}, {m}})) - return outerProd(rhs, lhs, res, lhsType.getDimSize(0)); - return failure(); - } - -private: - vector::CombiningKind kind; - Value lhs, rhs, res, mask; - VectorType lhsType; -}; -} // namespace - -/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul -/// semantics to a reduction_size-unrolled sequence: -/// ``` -/// %at = vector.transpose %a, [1, 0] -/// %bRow0 = vector.extract %b[0] -/// %atRow0 = vector.extract %at[0] -/// %c0 = vector.outerproduct %atRow0, %bRow0, %c -/// ... -/// %bRowK = vector.extract %b[K] -/// %atRowK = vector.extract %at[K] -/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 -/// ``` -/// -/// This only kicks in when VectorTransformsOptions is set to OuterProduct but -/// otherwise supports any layout permutation of the matrix-multiply. -LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( - vector::ContractionOp op, PatternRewriter &rewriter) const { - // TODO: Remove native masks from contraction op? 
- if (!op.getMasks().empty()) - return failure(); - - if (vectorTransformOptions.vectorContractLowering != - vector::VectorContractLowering::OuterProduct) - return failure(); - - if (failed(filter(op))) - return failure(); - - // Vector mask setup. - OpBuilder::InsertionGuard guard(rewriter); - auto maskableOp = cast(op.getOperation()); - Operation *rootOp; - if (maskableOp.isMasked()) { - rewriter.setInsertionPoint(maskableOp.getMaskingOp()); - rootOp = maskableOp.getMaskingOp(); - } else { - rootOp = op; - } - - UnrolledOuterProductGenerator e(rewriter, op); - FailureOr matmatRes = e.matmat(); - if (succeeded(matmatRes)) { - rewriter.replaceOp(rootOp, *matmatRes); - return success(); - } - FailureOr matvecRes = e.matvec(); - if (succeeded(matvecRes)) { - rewriter.replaceOp(rootOp, *matvecRes); - return success(); - } - FailureOr tmatvecRes = e.tmatvec(); - if (succeeded(tmatvecRes)) { - rewriter.replaceOp(rootOp, *tmatvecRes); - return success(); - } - - return failure(); -} - -LogicalResult -ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const { - // TODO: Support vector.mask. - auto maskableOp = cast(op.getOperation()); - if (maskableOp.isMasked()) - return failure(); - - // TODO: Remove native masks from contraction op? 
- if (!op.getMasks().empty()) - return failure(); - - if (failed(filter(op))) - return failure(); - - if (vectorTransformOptions.vectorContractLowering != - vector::VectorContractLowering::Dot) - return failure(); - - auto iteratorTypes = op.getIteratorTypes().getValue(); - static constexpr std::array perm = {1, 0}; - Location loc = op.getLoc(); - Value lhs = op.getLhs(), rhs = op.getRhs(); - - using MapList = ArrayRef>; - auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; - AffineExpr m, n, k; - bindDims(rewriter.getContext(), m, n, k); - SmallVector maps = op.getIndexingMapsArray(); - // - // In the following we wish to make the reduction dimension innermost so we - // can load vectors and just fmul + reduce into a scalar. - // - if (isParallelIterator(iteratorTypes[0]) && - isParallelIterator(iteratorTypes[1]) && - isReductionIterator(iteratorTypes[2])) { - // - // Two outer parallel, one inner reduction (matmat flavor). - // - if (maps == infer({{m, k}, {k, n}, {m, n}})) { - rhs = rewriter.create(loc, rhs, perm); - } else if (maps == infer({{m, k}, {n, k}, {m, n}})) { - // No need to permute anything. - } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { - lhs = rewriter.create(loc, lhs, perm); - rhs = rewriter.create(loc, rhs, perm); - } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { - lhs = rewriter.create(loc, lhs, perm); - } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { - // This is the classical row-major matmul. Just permute the lhs. 
- Value tmp = lhs; - lhs = rewriter.create(loc, rhs, perm); - rhs = tmp; - } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { - std::swap(lhs, rhs); - } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { - Value tmp = lhs; - lhs = rewriter.create(loc, rhs, perm); - rhs = rewriter.create(loc, tmp, perm); - } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { - Value tmp = rhs; - rhs = rewriter.create(loc, lhs, perm); - lhs = tmp; - } else { - return failure(); - } - } else if (isParallelIterator(iteratorTypes[0]) && - isReductionIterator(iteratorTypes[1])) { - // - // One outer parallel, one inner reduction (matvec flavor) - // - if (maps == infer({{m, n}, {n}, {m}})) { - // No need to permute anything. - } else if (maps == infer({{n, m}, {n}, {m}})) { - lhs = rewriter.create(loc, lhs, perm); - } else if (maps == infer({{n}, {m, n}, {m}})) { - std::swap(lhs, rhs); - } else if (maps == infer({{n}, {n, m}, {m}})) { - std::swap(lhs, rhs); - lhs = rewriter.create(loc, lhs, perm); - } else { - return failure(); - } - } else { - return failure(); - } - - VectorType dstType = op.getResultType().cast(); - assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 && - "Expected dst type of rank 1 or 2"); - - unsigned rank = dstType.getRank(); - unsigned dstRows = dstType.getShape()[0]; - unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1]; - - // ExtractOp does not allow dynamic indexing, we must unroll explicitly. - Value res = rewriter.create(loc, dstType, - rewriter.getZeroAttr(dstType)); - bool isInt = dstType.getElementType().isa(); - for (unsigned r = 0; r < dstRows; ++r) { - Value a = rewriter.create(op.getLoc(), lhs, r); - for (unsigned c = 0; c < dstColumns; ++c) { - Value b = rank == 1 - ? rhs - : rewriter.create(op.getLoc(), rhs, c); - Value m = createMul(op.getLoc(), a, b, isInt, rewriter); - Value reduced = rewriter.create( - op.getLoc(), vector::CombiningKind::ADD, m); - - SmallVector pos = rank == 1 ? 
SmallVector{r} - : SmallVector{r, c}; - res = rewriter.create(op.getLoc(), reduced, res, pos); - } - } - if (auto acc = op.getAcc()) - res = createAdd(op.getLoc(), res, acc, isInt, rewriter); - rewriter.replaceOp(op, res); - return success(); -} - -/// Progressive lowering of ContractionOp. -/// One: -/// %x = vector.contract with at least one free/batch dimension -/// is replaced by: -/// %a = vector.contract with one less free/batch dimension -/// %b = vector.contract with one less free/batch dimension -/// .. -/// %x = combine %a %b .. -/// until a pure contraction is reached (no free/batch dimensions), -/// which is replaced by a dot-product. -/// -/// This only kicks in when either VectorTransformsOptions is set -/// to DOT or when other contraction patterns fail. -// -// TODO: break down into transpose/reshape/cast ops -// when they become available to avoid code dup -// TODO: investigate lowering order impact on performance -LogicalResult -ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const { - // TODO: Remove native masks from contraction op? - if (!op.getMasks().empty()) - return failure(); - - if (failed(filter(op))) - return failure(); - - // TODO: support mixed mode contract lowering. - if (op.getLhsType().getElementType() != - getElementTypeOrSelf(op.getAccType()) || - op.getRhsType().getElementType() != getElementTypeOrSelf(op.getAccType())) - return failure(); - - // TODO: the code below assumes the default contraction, make sure it supports - // other kinds before enabling this lowering. - if (op.getKind() != vector::CombiningKind::ADD) { - return rewriter.notifyMatchFailure( - op, "contractions other than 'add' not supported"); - } - - // TODO: implement benefits, cost models. 
- MLIRContext *ctx = op.getContext(); - ContractionOpToMatmulOpLowering pat1(vectorTransformOptions, ctx); - if (succeeded(pat1.matchAndRewrite(op, rewriter))) - return success(); - ContractionOpToOuterProductOpLowering pat2(vectorTransformOptions, ctx); - if (succeeded(pat2.matchAndRewrite(op, rewriter))) - return success(); - ContractionOpToDotLowering pat3(vectorTransformOptions, ctx); - if (succeeded(pat3.matchAndRewrite(op, rewriter))) - return success(); - ContractOpToElementwise pat4(vectorTransformOptions, ctx); - if (succeeded(pat4.matchAndRewrite(op, rewriter))) - return success(); - - // Vector mask setup. - OpBuilder::InsertionGuard guard(rewriter); - Operation *rootOp = op; - Value mask; - if (op.isMasked()) { - rewriter.setInsertionPoint(op.getMaskingOp()); - rootOp = op.getMaskingOp(); - mask = op.getMaskingOp().getMask(); - } - - // Find first batch dimension in LHS/RHS, and lower when found. - std::vector> batchDimMap = op.getBatchDimMap(); - if (!batchDimMap.empty()) { - int64_t lhsIndex = batchDimMap[0].first; - int64_t rhsIndex = batchDimMap[0].second; - auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask); - if (failed(newOp)) - return failure(); - rewriter.replaceOp(rootOp, *newOp); - return success(); - } - - // Collect contracting dimensions. - std::vector> contractingDimMap = - op.getContractingDimMap(); - DenseSet lhsContractingDimSet; - DenseSet rhsContractingDimSet; - for (auto &dimPair : contractingDimMap) { - lhsContractingDimSet.insert(dimPair.first); - rhsContractingDimSet.insert(dimPair.second); - } - - // Find first free dimension in LHS, and lower when found. 
- VectorType lhsType = op.getLhsType(); - for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) { - if (lhsContractingDimSet.count(lhsIndex) == 0) { - auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask); - if (failed(newOp)) - return failure(); - rewriter.replaceOp(rootOp, *newOp); - return success(); - } - } - - // Find first free dimension in RHS, and lower when found. - VectorType rhsType = op.getRhsType(); - for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) { - if (rhsContractingDimSet.count(rhsIndex) == 0) { - auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask); - if (failed(newOp)) - return failure(); - rewriter.replaceOp(rootOp, *newOp); - return success(); - } - } - - // Lower the first remaining reduction dimension. - if (!contractingDimMap.empty()) { - auto newOp = lowerReduction(rewriter, op, mask); - if (failed(newOp)) - return failure(); - rewriter.replaceOp(rootOp, *newOp); - return success(); - } - - return failure(); -} - -// Lower one parallel dimension. -// Incidentally also tolerates unit-size (hence trivial) reduction dimensions. -// TODO: consider reusing existing contract unrolling -FailureOr ContractionOpLowering::lowerParallel(PatternRewriter &rewriter, - vector::ContractionOp op, - int64_t lhsIndex, - int64_t rhsIndex, - Value mask) const { - VectorType lhsType = op.getLhsType(); - VectorType rhsType = op.getRhsType(); - VectorType resType = op.getResultType().cast(); - // Find the iterator type index and result index. 
- SmallVector iMap = op.getIndexingMapsArray(); - int64_t iterIndex = -1; - int64_t dimSize = -1; - if (lhsIndex >= 0) { - iterIndex = iMap[0].getDimPosition(lhsIndex); - if (rhsIndex >= 0 && iterIndex != iMap[1].getDimPosition(rhsIndex)) - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << "expected lhsIndex=" << lhsIndex << " and rhsIndex=" << rhsIndex - << " to map to the same dimension"; - }); - dimSize = lhsType.getDimSize(lhsIndex); - } else if (rhsIndex >= 0) { - iterIndex = iMap[1].getDimPosition(rhsIndex); - dimSize = rhsType.getDimSize(rhsIndex); - } - if (iterIndex < 0) - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << "expected either lhsIndex=" << lhsIndex - << " or rhsIndex=" << rhsIndex << " to be nonnegative"; - }); - // value_or(-1) means that we tolerate a dimension not appearing - // in the result map. That can't happen for actual parallel iterators, but - // the caller ContractionOpLowering::matchAndRewrite is currently calling - // lowerParallel also for the case of unit-size reduction dims appearing only - // on one of LHS or RHS, not both. At the moment, such cases are created by - // CastAwayContractionLeadingOneDim, so we need to either support that or - // modify that pattern. - int64_t resIndex = getResultIndex(iMap[2], iterIndex).value_or(-1); - if (resIndex == -1 && dimSize != 1) - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << "expected the dimension for iterIndex=" << iterIndex - << " to either appear in the result map, or to be a unit dimension"; - }); - - // Construct new iterator types and affine map array attribute. 
- std::array lowIndexingMaps = { - adjustMap(iMap[0], iterIndex, rewriter), - adjustMap(iMap[1], iterIndex, rewriter), - adjustMap(iMap[2], iterIndex, rewriter)}; - auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); - auto lowIter = - rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); - // Unroll into a series of lower dimensional vector.contract ops. - Location loc = op.getLoc(); - Value result = rewriter.create( - loc, resType, rewriter.getZeroAttr(resType)); - - for (int64_t d = 0; d < dimSize; ++d) { - auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); - auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); - auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter); - - Value lowMask; - if (mask) - lowMask = reshapeLoad(loc, mask, cast(mask.getType()), - iterIndex, d, rewriter); - - Operation *lowContract = rewriter.create( - loc, lhs, rhs, acc, lowAffine, lowIter); - lowContract = maskOperation(rewriter, lowContract, lowMask); - result = reshapeStore(loc, lowContract->getResult(0), result, resType, - resIndex, d, rewriter); - } - return result; -} - -// Lower one reduction dimension. -FailureOr ContractionOpLowering::lowerReduction( - PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const { - auto loc = op.getLoc(); - VectorType lhsType = op.getLhsType(); - VectorType rhsType = op.getRhsType(); - Type resType = op.getResultType(); - if (resType.isa()) - return rewriter.notifyMatchFailure(op, - "did not expect a VectorType result"); - bool isInt = resType.isa(); - // Use iterator index 0. 
- int64_t iterIndex = 0; - SmallVector iMap = op.getIndexingMapsArray(); - std::optional lookupLhs = getResultIndex(iMap[0], iterIndex); - std::optional lookupRhs = getResultIndex(iMap[1], iterIndex); - if (!lookupLhs.has_value()) - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << "expected iterIndex=" << iterIndex << "to map to a LHS dimension"; - }); - if (!lookupRhs.has_value()) - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << "expected iterIndex=" << iterIndex << "to map to a RHS dimension"; - }); - int64_t lhsIndex = *lookupLhs; - int64_t rhsIndex = *lookupRhs; - int64_t dimSize = lhsType.getDimSize(lhsIndex); - if (dimSize != rhsType.getDimSize(rhsIndex)) - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << "expect LHS dimension " << lhsIndex - << " to have the same size as RHS dimension " << rhsIndex; - }); - // Base case. - if (lhsType.getRank() == 1) { - if (rhsType.getRank() != 1) - return rewriter.notifyMatchFailure( - op, "When LHS has rank 1, expected also RHS to have rank 1"); - Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter); - auto kind = vector::CombiningKind::ADD; - - Value acc = op.getAcc(); - Operation *reductionOp = - acc ? rewriter.create(loc, kind, m, acc) - : rewriter.create(loc, kind, m); - return maskOperation(rewriter, reductionOp, mask)->getResult(0); - } - // Construct new iterator types and affine map array attribute. - std::array lowIndexingMaps = { - adjustMap(iMap[0], iterIndex, rewriter), - adjustMap(iMap[1], iterIndex, rewriter), - adjustMap(iMap[2], iterIndex, rewriter)}; - auto lowAffine = rewriter.getAffineMapArrayAttr(lowIndexingMaps); - auto lowIter = - rewriter.getArrayAttr(adjustIter(op.getIteratorTypes(), iterIndex)); - // Unroll into a series of lower dimensional vector.contract ops. 
- // By feeding the initial accumulator into the first contraction, - // and the result of each contraction into the next, eventually - // the sum of all reductions is computed. - Value result = op.getAcc(); - for (int64_t d = 0; d < dimSize; ++d) { - auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter); - auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter); - Value newMask; - if (mask) - newMask = reshapeLoad(loc, mask, cast(mask.getType()), - iterIndex, d, rewriter); - - Operation *newContract = rewriter.create( - loc, lhs, rhs, result, lowAffine, lowIter); - result = maskOperation(rewriter, newContract, newMask)->getResult(0); - } - return result; -} - -} // namespace mlir - -/// Progressive lowering of transfer_read. This pattern supports lowering of -/// `vector.transfer_read` to a combination of `vector.load` and -/// `vector.broadcast` if all of the following hold: -/// - Stride of most minor memref dimension must be 1. -/// - Out-of-bounds masking is not required. -/// - If the memref's element type is a vector type then it coincides with the -/// result type. -/// - The permutation map doesn't perform permutation (broadcasting is allowed). -struct TransferReadToVectorLoadLowering - : public OpRewritePattern { - TransferReadToVectorLoadLowering(MLIRContext *context, - std::optional maxRank, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - maxTransferRank(maxRank) {} - - LogicalResult matchAndRewrite(vector::TransferReadOp read, - PatternRewriter &rewriter) const override { - if (maxTransferRank && read.getVectorType().getRank() > *maxTransferRank) - return failure(); - - SmallVector broadcastedDims; - // Permutations are handled by VectorToSCF or - // populateVectorTransferPermutationMapLoweringPatterns. - // We let the 0-d corner case pass-through as it is supported. 
- if (!read.getPermutationMap().isMinorIdentityWithBroadcasting( - &broadcastedDims)) - return failure(); - - auto memRefType = read.getShapedType().dyn_cast(); - if (!memRefType) - return failure(); - - // Non-unit strides are handled by VectorToSCF. - if (!vector::isLastMemrefDimUnitStride(memRefType)) - return failure(); - - // If there is broadcasting involved then we first load the unbroadcasted - // vector, and then broadcast it with `vector.broadcast`. - ArrayRef vectorShape = read.getVectorType().getShape(); - SmallVector unbroadcastedVectorShape(vectorShape.begin(), - vectorShape.end()); - for (unsigned i : broadcastedDims) - unbroadcastedVectorShape[i] = 1; - VectorType unbroadcastedVectorType = VectorType::get( - unbroadcastedVectorShape, read.getVectorType().getElementType()); - - // `vector.load` supports vector types as memref's elements only when the - // resulting vector type is the same as the element type. - auto memrefElTy = memRefType.getElementType(); - if (memrefElTy.isa() && memrefElTy != unbroadcastedVectorType) - return failure(); - - // Otherwise, element types of the memref and the vector must match. - if (!memrefElTy.isa() && - memrefElTy != read.getVectorType().getElementType()) - return failure(); - - // Out-of-bounds dims are handled by MaterializeTransferMask. - if (read.hasOutOfBoundsDim()) - return failure(); - - // Create vector load op. - Operation *loadOp; - if (read.getMask()) { - Value fill = rewriter.create( - read.getLoc(), unbroadcastedVectorType, read.getPadding()); - loadOp = rewriter.create( - read.getLoc(), unbroadcastedVectorType, read.getSource(), - read.getIndices(), read.getMask(), fill); - } else { - loadOp = rewriter.create( - read.getLoc(), unbroadcastedVectorType, read.getSource(), - read.getIndices()); - } - - // Insert a broadcasting op if required. 
- if (!broadcastedDims.empty()) { - rewriter.replaceOpWithNewOp( - read, read.getVectorType(), loadOp->getResult(0)); - } else { - rewriter.replaceOp(read, loadOp->getResult(0)); - } - - return success(); - } - - std::optional maxTransferRank; -}; - -/// Replace a 0-d vector.load with a memref.load + vector.broadcast. -// TODO: we shouldn't cross the vector/scalar domains just for this -// but atm we lack the infra to avoid it. Possible solutions include: -// - go directly to LLVM + bitcast -// - introduce a bitcast op and likely a new pointer dialect -// - let memref.load/store additionally support the 0-d vector case -// There are still deeper data layout issues lingering even in this -// trivial case (for architectures for which this matters). -struct VectorLoadToMemrefLoadLowering - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::LoadOp loadOp, - PatternRewriter &rewriter) const override { - auto vecType = loadOp.getVectorType(); - if (vecType.getNumElements() != 1) - return failure(); - auto memrefLoad = rewriter.create( - loadOp.getLoc(), loadOp.getBase(), loadOp.getIndices()); - rewriter.replaceOpWithNewOp(loadOp, vecType, - memrefLoad); - return success(); - } -}; - -/// Replace a 0-d vector.store with a vector.extractelement + memref.store. -struct VectorStoreToMemrefStoreLowering - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::StoreOp storeOp, - PatternRewriter &rewriter) const override { - auto vecType = storeOp.getVectorType(); - if (vecType.getNumElements() != 1) - return failure(); - Value extracted; - if (vecType.getRank() == 0) { - // TODO: Unifiy once ExtractOp supports 0-d vectors. 
- extracted = rewriter.create( - storeOp.getLoc(), storeOp.getValueToStore()); - } else { - SmallVector indices(vecType.getRank(), 0); - extracted = rewriter.create( - storeOp.getLoc(), storeOp.getValueToStore(), indices); - } - - rewriter.replaceOpWithNewOp( - storeOp, extracted, storeOp.getBase(), storeOp.getIndices()); - return success(); - } -}; - -/// Progressive lowering of transfer_write. This pattern supports lowering of -/// `vector.transfer_write` to `vector.store` if all of the following hold: -/// - Stride of most minor memref dimension must be 1. -/// - Out-of-bounds masking is not required. -/// - If the memref's element type is a vector type then it coincides with the -/// type of the written value. -/// - The permutation map is the minor identity map (neither permutation nor -/// broadcasting is allowed). -struct TransferWriteToVectorStoreLowering - : public OpRewritePattern { - TransferWriteToVectorStoreLowering(MLIRContext *context, - std::optional maxRank, - PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), - maxTransferRank(maxRank) {} - - LogicalResult matchAndRewrite(vector::TransferWriteOp write, - PatternRewriter &rewriter) const override { - if (maxTransferRank && write.getVectorType().getRank() > *maxTransferRank) - return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { - diag << "rank exceeds maxTransferRank: " << write; - }); - - // Permutations are handled by VectorToSCF or - // populateVectorTransferPermutationMapLoweringPatterns. - if ( // pass-through for the 0-d corner case. 
- !write.getPermutationMap().isMinorIdentity()) - return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { - diag << "permutation map is not minor identity: " << write; - }); - - auto memRefType = write.getShapedType().dyn_cast(); - if (!memRefType) - return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { - diag << "not a memref type: " << write; - }); - - // Non-unit strides are handled by VectorToSCF. - if (!vector::isLastMemrefDimUnitStride(memRefType)) - return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { - diag << "most minor stride is not 1: " << write; - }); - - // `vector.store` supports vector types as memref's elements only when the - // type of the vector value being written is the same as the element type. - auto memrefElTy = memRefType.getElementType(); - if (memrefElTy.isa() && memrefElTy != write.getVectorType()) - return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { - diag << "elemental type mismatch: " << write; - }); - - // Otherwise, element types of the memref and the vector must match. - if (!memrefElTy.isa() && - memrefElTy != write.getVectorType().getElementType()) - return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { - diag << "elemental type mismatch: " << write; - }); - - // Out-of-bounds dims are handled by MaterializeTransferMask. - if (write.hasOutOfBoundsDim()) - return rewriter.notifyMatchFailure(write.getLoc(), [=](Diagnostic &diag) { - diag << "out of bounds dim: " << write; - }); - if (write.getMask()) { - rewriter.replaceOpWithNewOp( - write, write.getSource(), write.getIndices(), write.getMask(), - write.getVector()); - } else { - rewriter.replaceOpWithNewOp( - write, write.getVector(), write.getSource(), write.getIndices()); - } - return success(); - } - - std::optional maxTransferRank; -}; - // Returns the values in `arrayAttr` as an integer vector. 
static SmallVector getIntValueVector(ArrayAttr arrayAttr) { return llvm::to_vector<4>( @@ -2863,202 +1027,6 @@ } }; -namespace { - -/// This function checks to see if the vector combining kind -/// is consistent with the integer or float element type. -static bool isValidKind(bool isInt, vector::CombiningKind kind) { - using vector::CombiningKind; - enum class KindType { FLOAT, INT, INVALID }; - KindType type{KindType::INVALID}; - switch (kind) { - case CombiningKind::MINF: - case CombiningKind::MAXF: - type = KindType::FLOAT; - break; - case CombiningKind::MINUI: - case CombiningKind::MINSI: - case CombiningKind::MAXUI: - case CombiningKind::MAXSI: - case CombiningKind::AND: - case CombiningKind::OR: - case CombiningKind::XOR: - type = KindType::INT; - break; - case CombiningKind::ADD: - case CombiningKind::MUL: - type = isInt ? KindType::INT : KindType::FLOAT; - break; - } - bool isValidIntKind = (type == KindType::INT) && isInt; - bool isValidFloatKind = (type == KindType::FLOAT) && (!isInt); - return (isValidIntKind || isValidFloatKind); -} - -/// This function constructs the appropriate integer or float -/// operation given the vector combining kind and operands. The -/// supported int operations are : add, mul, min (signed/unsigned), -/// max(signed/unsigned), and, or, xor. The supported float -/// operations are : add, mul, min and max. 
-static Value genOperator(Location loc, Value x, Value y, - vector::CombiningKind kind, - PatternRewriter &rewriter) { - using vector::CombiningKind; - - auto elType = x.getType().cast().getElementType(); - bool isInt = elType.isIntOrIndex(); - - Value combinedResult{nullptr}; - switch (kind) { - case CombiningKind::ADD: - if (isInt) - combinedResult = rewriter.create(loc, x, y); - else - combinedResult = rewriter.create(loc, x, y); - break; - case CombiningKind::MUL: - if (isInt) - combinedResult = rewriter.create(loc, x, y); - else - combinedResult = rewriter.create(loc, x, y); - break; - case CombiningKind::MINUI: - combinedResult = rewriter.create(loc, x, y); - break; - case CombiningKind::MINSI: - combinedResult = rewriter.create(loc, x, y); - break; - case CombiningKind::MAXUI: - combinedResult = rewriter.create(loc, x, y); - break; - case CombiningKind::MAXSI: - combinedResult = rewriter.create(loc, x, y); - break; - case CombiningKind::AND: - combinedResult = rewriter.create(loc, x, y); - break; - case CombiningKind::OR: - combinedResult = rewriter.create(loc, x, y); - break; - case CombiningKind::XOR: - combinedResult = rewriter.create(loc, x, y); - break; - case CombiningKind::MINF: - combinedResult = rewriter.create(loc, x, y); - break; - case CombiningKind::MAXF: - combinedResult = rewriter.create(loc, x, y); - break; - } - return combinedResult; -} - -/// Convert vector.scan op into arith ops and -/// vector.insert_strided_slice/extract_strided_slice -/// -/// Ex: -/// ``` -/// %0:2 = vector.scan , %arg0, %arg1 {inclusive = true, reduction_dim = -/// 1} : -/// (vector<2x3xi32>, vector<2xi32>) to (vector<2x3xi32>, vector<2xi32>) -/// ``` -/// Gets converted to: -/// ``` -/// %cst = arith.constant dense<0> : vector<2x3xi32> -/// %0 = vector.extract_strided_slice %arg0 {offsets = [0, 0], sizes = [2, 1], -/// strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %1 = -/// vector.insert_strided_slice %0, %cst {offsets = [0, 0], strides = [1, 1]} -/// : 
vector<2x1xi32> into vector<2x3xi32> %2 = vector.extract_strided_slice -/// %arg0 {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : -/// vector<2x3xi32> to vector<2x1xi32> %3 = arith.muli %0, %2 : -/// vector<2x1xi32> %4 = vector.insert_strided_slice %3, %1 {offsets = [0, 1], -/// strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %5 = -/// vector.extract_strided_slice %arg0 {offsets = [0, 2], sizes = [2, 1], -/// strides = [1, 1]} : vector<2x3xi32> to vector<2x1xi32> %6 = arith.muli %3, -/// %5 : vector<2x1xi32> %7 = vector.insert_strided_slice %6, %4 {offsets = -/// [0, 2], strides = [1, 1]} : vector<2x1xi32> into vector<2x3xi32> %8 = -/// vector.shape_cast %6 : vector<2x1xi32> to vector<2xi32> return %7, %8 : -/// vector<2x3xi32>, vector<2xi32> -/// ``` -struct ScanToArithOps : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::ScanOp scanOp, - PatternRewriter &rewriter) const override { - auto loc = scanOp.getLoc(); - VectorType destType = scanOp.getDestType(); - ArrayRef destShape = destType.getShape(); - auto elType = destType.getElementType(); - bool isInt = elType.isIntOrIndex(); - if (!isValidKind(isInt, scanOp.getKind())) - return failure(); - - VectorType resType = VectorType::get(destShape, elType); - Value result = rewriter.create( - loc, resType, rewriter.getZeroAttr(resType)); - int64_t reductionDim = scanOp.getReductionDim(); - bool inclusive = scanOp.getInclusive(); - int64_t destRank = destType.getRank(); - VectorType initialValueType = scanOp.getInitialValueType(); - int64_t initialValueRank = initialValueType.getRank(); - - SmallVector reductionShape(destShape.begin(), destShape.end()); - reductionShape[reductionDim] = 1; - VectorType reductionType = VectorType::get(reductionShape, elType); - SmallVector offsets(destRank, 0); - SmallVector strides(destRank, 1); - SmallVector sizes(destShape.begin(), destShape.end()); - sizes[reductionDim] = 1; - ArrayAttr scanSizes = 
rewriter.getI64ArrayAttr(sizes); - ArrayAttr scanStrides = rewriter.getI64ArrayAttr(strides); - - Value lastOutput, lastInput; - for (int i = 0; i < destShape[reductionDim]; i++) { - offsets[reductionDim] = i; - ArrayAttr scanOffsets = rewriter.getI64ArrayAttr(offsets); - Value input = rewriter.create( - loc, reductionType, scanOp.getSource(), scanOffsets, scanSizes, - scanStrides); - Value output; - if (i == 0) { - if (inclusive) { - output = input; - } else { - if (initialValueRank == 0) { - // ShapeCastOp cannot handle 0-D vectors - output = rewriter.create( - loc, input.getType(), scanOp.getInitialValue()); - } else { - output = rewriter.create( - loc, input.getType(), scanOp.getInitialValue()); - } - } - } else { - Value y = inclusive ? input : lastInput; - output = genOperator(loc, lastOutput, y, scanOp.getKind(), rewriter); - assert(output != nullptr); - } - result = rewriter.create( - loc, output, result, offsets, strides); - lastOutput = output; - lastInput = input; - } - - Value reduction; - if (initialValueRank == 0) { - Value v = rewriter.create(loc, lastOutput, 0); - reduction = - rewriter.create(loc, initialValueType, v); - } else { - reduction = rewriter.create(loc, initialValueType, - lastOutput); - } - - rewriter.replaceOp(scanOp, {result, reduction}); - return success(); - } -}; - /// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul /// semantics to a contraction suitable for MMT (matrix matrix multiplication /// with the RHS transposed) lowering. @@ -3157,132 +1125,6 @@ FilterConstraintType filter; }; -/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the -/// outermost dimension. For example: -/// ``` -/// %g = vector.gather %base[%c0][%v], %mask, %pass_thru : -/// ... into vector<2x3xf32> -/// -/// ==> -/// -/// %0 = arith.constant dense<0.0> : vector<2x3xf32> -/// %g0 = vector.gather %base[%c0][%v0], %mask0, %pass_thru0 : ... 
-/// %1 = vector.insert %g0, %0 [0] : vector<3xf32> into vector<2x3xf32> -/// %g1 = vector.gather %base[%c0][%v1], %mask1, %pass_thru1 : ... -/// %g = vector.insert %g1, %1 [1] : vector<3xf32> into vector<2x3xf32> -/// ``` -/// -/// When applied exhaustively, this will produce a sequence of 1-d gather ops. -struct FlattenGather : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::GatherOp op, - PatternRewriter &rewriter) const override { - VectorType resultTy = op.getType(); - if (resultTy.getRank() < 2) - return rewriter.notifyMatchFailure(op, "already flat"); - - Location loc = op.getLoc(); - Value indexVec = op.getIndexVec(); - Value maskVec = op.getMask(); - Value passThruVec = op.getPassThru(); - - Value result = rewriter.create( - loc, resultTy, rewriter.getZeroAttr(resultTy)); - - Type subTy = VectorType::get(resultTy.getShape().drop_front(), - resultTy.getElementType()); - - for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { - int64_t thisIdx[1] = {i}; - - Value indexSubVec = - rewriter.create(loc, indexVec, thisIdx); - Value maskSubVec = - rewriter.create(loc, maskVec, thisIdx); - Value passThruSubVec = - rewriter.create(loc, passThruVec, thisIdx); - Value subGather = rewriter.create( - loc, subTy, op.getBase(), op.getIndices(), indexSubVec, maskSubVec, - passThruSubVec); - result = - rewriter.create(loc, subGather, result, thisIdx); - } - - rewriter.replaceOp(op, result); - return success(); - } -}; - -/// Turns 1-d `vector.gather` into a scalarized sequence of `vector.loads` or -/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these -/// loads/extracts are made conditional using `scf.if` ops. 
-struct Gather1DToConditionalLoads : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(vector::GatherOp op, - PatternRewriter &rewriter) const override { - VectorType resultTy = op.getType(); - if (resultTy.getRank() != 1) - return rewriter.notifyMatchFailure(op, "unsupported rank"); - - Location loc = op.getLoc(); - Type elemTy = resultTy.getElementType(); - // Vector type with a single element. Used to generate `vector.loads`. - VectorType elemVecTy = VectorType::get({1}, elemTy); - - Value condMask = op.getMask(); - Value base = op.getBase(); - Value indexVec = rewriter.createOrFold( - loc, op.getIndexVectorType().clone(rewriter.getIndexType()), - op.getIndexVec()); - auto baseOffsets = llvm::to_vector(op.getIndices()); - Value lastBaseOffset = baseOffsets.back(); - - Value result = op.getPassThru(); - - // Emit a conditional access for each vector element. - for (int64_t i = 0, e = resultTy.getNumElements(); i < e; ++i) { - int64_t thisIdx[1] = {i}; - Value condition = - rewriter.create(loc, condMask, thisIdx); - Value index = rewriter.create(loc, indexVec, thisIdx); - baseOffsets.back() = - rewriter.createOrFold(loc, lastBaseOffset, index); - - auto loadBuilder = [&](OpBuilder &b, Location loc) { - Value extracted; - if (isa(base.getType())) { - // `vector.load` does not support scalar result; emit a vector load - // and extract the single result instead. 
- Value load = - b.create(loc, elemVecTy, base, baseOffsets); - int64_t zeroIdx[1] = {0}; - extracted = b.create(loc, load, zeroIdx); - } else { - extracted = b.create(loc, base, baseOffsets); - } - - Value newResult = - b.create(loc, extracted, result, thisIdx); - b.create(loc, newResult); - }; - auto passThruBuilder = [result](OpBuilder &b, Location loc) { - b.create(loc, result); - }; - - result = - rewriter - .create(loc, condition, /*thenBuilder=*/loadBuilder, - /*elseBuilder=*/passThruBuilder) - .getResult(0); - } - - rewriter.replaceOp(op, result); - return success(); - } -}; - } // namespace void mlir::vector::populateVectorMaskMaterializationPatterns( @@ -3307,33 +1149,6 @@ benefit); } -void mlir::vector::populateVectorBroadcastLoweringPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); -} - -void mlir::vector::populateVectorMaskOpLoweringPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add( - patterns.getContext(), benefit); -} - -void mlir::vector::populateVectorShapeCastLoweringPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add( - patterns.getContext(), benefit); -} - -void mlir::vector::populateVectorContractLoweringPatterns( - RewritePatternSet &patterns, VectorTransformsOptions options, - PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); - patterns.add( - options, patterns.getContext(), benefit); -} - void mlir::vector::populateVectorContractCanonicalizeMatmulToMMT( RewritePatternSet &patterns, std::function constraint, @@ -3342,13 +1157,6 @@ std::move(constraint)); } -void mlir::vector::populateVectorTransposeLoweringPatterns( - RewritePatternSet &patterns, VectorTransformsOptions options, - PatternBenefit benefit) { - patterns.add( - options, patterns.getContext(), benefit); -} - void mlir::vector::populateVectorReductionToContractPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { 
patterns.add(patterns.getContext(), benefit); } -void mlir::vector::populateVectorTransferLoweringPatterns( - RewritePatternSet &patterns, std::optional maxTransferRank, - PatternBenefit benefit) { - patterns.add(patterns.getContext(), - maxTransferRank, benefit); - patterns - .add( - patterns.getContext(), benefit); -} - -void mlir::vector::populateVectorScanLoweringPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), benefit); -} - -void mlir::vector::populateVectorGatherLoweringPatterns( - RewritePatternSet &patterns, PatternBenefit benefit) { - patterns.add(patterns.getContext(), - benefit); -} - //===----------------------------------------------------------------------===// // TableGen'd enum attribute definitions //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -1,7 +1,6 @@ // RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT -// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-parallel-arith=1 | FileCheck %s --check-prefix=PARALLEL #dotp_accesses = [ @@ -1182,32 +1181,6 @@ return %0 : vector<3x2xf32> } -// FILTEROUTERPRODUCT-LABEL: func @matmul_4_filtered -// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x4xf32>, -// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>, -// FILTEROUTERPRODUCT-SAME: 
%[[C:[a-zA-Z0-9]*]]: vector<4x4xf32> -// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]] -func.func @matmul_4_filtered(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<4x4xf32>) --> vector<4x4xf32> -{ - %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 - : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32> - return %0 : vector<4x4xf32> -} - -// FILTEROUTERPRODUCT-LABEL: func @matmul_4_not_filtered -// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4xf32>, -// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>, -// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x4xf32> -// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]] -func.func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<3x4xf32>) --> vector<3x4xf32> -{ - %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 - : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32> - return %0 : vector<3x4xf32> -} - // PARALLEL-LABEL: func @parrallel_contract_lowering // PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> // PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include #include +#include #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -22,6 +22,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorDistribution.h" #include 
"mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" @@ -136,11 +137,6 @@ *this, "vector-outerproduct", llvm::cl::desc("Lower vector.contract to vector.outerproduct"), llvm::cl::init(false)}; - Option lowerToFilterOuterProduct{ - *this, "vector-filter-outerproduct", - llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " - "vectors of size 4."), - llvm::cl::init(false)}; Option lowerToParallelArith{ *this, "vector-parallel-arith", llvm::cl::desc("Lower vector.contract to elementwise vector ops."), @@ -153,24 +149,9 @@ if (lowerToOuterProduct) { VectorContractLowering lowering = VectorContractLowering::OuterProduct; VectorTransformsOptions options{lowering}; - patterns.add(options, - &getContext()); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - return; - } - - // Test on one pattern in isolation. - if (lowerToFilterOuterProduct) { - VectorContractLowering lowering = VectorContractLowering::OuterProduct; - VectorTransformsOptions options{lowering}; - patterns.add( - options, &getContext(), /*benefit=*/1, [](vector::ContractionOp op) { - // Only lowers vector.contract where the lhs as a type vector - // where M is not 4. - if (op.getRhsType().getShape()[0] == 4) - return failure(); - return success(); - }); + populateVectorContractLoweringPatterns( + patterns, options, /*benefit=*/1, + /*disableOuterProductlowering=*/true); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); return; } @@ -490,7 +471,7 @@ options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy); else options.setVectorTransferSplit(VectorTransferSplit::VectorTransfer); - patterns.add(ctx, options); + populateVectorTransferFullPartialPatterns(patterns, options); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } };