Please use GitHub pull requests for new patches. Avoid migrating existing patches. Phabricator shutdown timeline
Changeset View
Changeset View
Standalone View
Standalone View
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Show All 26 Lines | |||||
struct TiledLinalgOp { | struct TiledLinalgOp { | ||||
LinalgOp op; | LinalgOp op; | ||||
SmallVector<Operation *, 8> loops; | SmallVector<Operation *, 8> loops; | ||||
}; | }; | ||||
/// Populates patterns for vectorization of all ConvN-D ops. | /// Populates patterns for vectorization of all ConvN-D ops. | ||||
void populateConvVectorizationPatterns( | void populateConvVectorizationPatterns( | ||||
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns); | MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns, | ||||
ArrayRef<int64_t> tileSizes); | |||||
/// Performs standalone tiling of a single LinalgOp by `tileSizes`. | /// Performs standalone tiling of a single LinalgOp by `tileSizes`. | ||||
/// and permute the loop nest according to `interchangeVector` | /// and permute the loop nest according to `interchangeVector` | ||||
/// The permutation is expressed as a list of integers that specify | /// The permutation is expressed as a list of integers that specify | ||||
/// the new ordering of the loop nest. The length of `interchangeVector` | /// the new ordering of the loop nest. The length of `interchangeVector` | ||||
/// must be equal to the length of `tileSizes`. | /// must be equal to the length of `tileSizes`. | ||||
/// An empty vector is interpreted as the identity permutation and the | /// An empty vector is interpreted as the identity permutation and the | ||||
/// transformation returns early. | /// transformation returns early. | ||||
▲ Show 20 Lines • Show All 500 Lines • ▼ Show 20 Lines | |||||
}; | }; | ||||
/// Converts Convolution op into vector contraction. | /// Converts Convolution op into vector contraction. | ||||
/// | /// | ||||
/// Conversion expects ConvOp to have dimensions marked in the *mask* as | /// Conversion expects ConvOp to have dimensions marked in the *mask* as | ||||
/// false of size 1. This ensures that the ConvOp can be lowered to vector | /// false of size 1. This ensures that the ConvOp can be lowered to vector | ||||
/// contraction of dimensions marked in the *mask* as true. | /// contraction of dimensions marked in the *mask* as true. | ||||
/// | /// | ||||
/// A good example is ConvNHWCOp which is 2D Conv op with channels as the last | /// A good example for vectorization is ConvNHWCOp which is 2D Conv op | ||||
/// dimension. For this op we contract last 3 dimensions. | /// with channels as the last dimension. Let's vectorize last 3 dimensions. | ||||
/// The initial op definition looks like this: | /// The initial op definition looks like this: | ||||
/// ``` | /// ``` | ||||
/// linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : | /// linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : | ||||
/// (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref<?x?x?x?xf32>) | /// (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref<?x?x?x?xf32>) | ||||
/// ``` | /// ``` | ||||
/// This op can be expressed as a dot product between %arg0 (input) and | /// This op can be expressed as a dot product between %arg0 (input) and | ||||
/// %arg1 (kernel) which is written into first entry of %arg2 (output). This is | /// %arg1 (kernel) which is written into first entry of %arg2 (output). This is | ||||
/// the ConvOp this pass expects and converts into: | /// the ConvOp this pass expects and converts into: | ||||
Show All 22 Lines | public: | ||||
ConvOpVectorization(MLIRContext *context, SmallVector<bool, 4> msk) | ConvOpVectorization(MLIRContext *context, SmallVector<bool, 4> msk) | ||||
: OpRewritePattern<ConvOp>(context) { | : OpRewritePattern<ConvOp>(context) { | ||||
assert(msk.size() == N && "Mask size does not match rank"); | assert(msk.size() == N && "Mask size does not match rank"); | ||||
this->mask = msk; | this->mask = msk; | ||||
} | } | ||||
LogicalResult matchAndRewrite(ConvOp minOp, | LogicalResult matchAndRewrite(ConvOp minOp, | ||||
PatternRewriter &rewriter) const override; | PatternRewriter &rewriter) const override; | ||||
// TODO: Make these pass arguments. | |||||
static const int tileSize = 3; | |||||
static const int noTile = 1; | |||||
}; | }; | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
// Support for staged pattern application. | // Support for staged pattern application. | ||||
//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||
/// Helper function to allow applying rewrite patterns, interleaved with more | /// Helper function to allow applying rewrite patterns, interleaved with more | ||||
/// global transformations, in a staged fashion: | /// global transformations, in a staged fashion: | ||||
/// 1. the first stage consists of a list of OwningRewritePatternList. Each | /// 1. the first stage consists of a list of OwningRewritePatternList. Each | ||||
Show All 16 Lines |