# Please use GitHub pull requests for new patches. Avoid migrating existing patches. Phabricator shutdown timeline

# Changeset View

Changeset View

# Standalone View

Standalone View

# mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Show All 26 Lines | |||||

struct TiledLinalgOp { | struct TiledLinalgOp { | ||||

LinalgOp op; | LinalgOp op; | ||||

SmallVector<Operation *, 8> loops; | SmallVector<Operation *, 8> loops; | ||||

}; | }; | ||||

/// Populates patterns for vectorization of all ConvN-D ops. | /// Populates patterns for vectorization of all ConvN-D ops. | ||||

void populateConvVectorizationPatterns( | void populateConvVectorizationPatterns( | ||||

MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns); | MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns, | ||||

ArrayRef<int64_t> tileSizes); | |||||

/// Performs standalone tiling of a single LinalgOp by `tileSizes`. | /// Performs standalone tiling of a single LinalgOp by `tileSizes`. | ||||

/// and permute the loop nest according to `interchangeVector` | /// and permute the loop nest according to `interchangeVector` | ||||

/// The permutation is expressed as a list of integers that specify | /// The permutation is expressed as a list of integers that specify | ||||

/// the new ordering of the loop nest. The length of `interchangeVector` | /// the new ordering of the loop nest. The length of `interchangeVector` | ||||

/// must be equal to the length of `tileSizes`. | /// must be equal to the length of `tileSizes`. | ||||

/// An empty vector is interpreted as the identity permutation and the | /// An empty vector is interpreted as the identity permutation and the | ||||

/// transformation returns early. | /// transformation returns early. | ||||

▲ Show 20 Lines • Show All 500 Lines • ▼ Show 20 Lines | |||||

}; | }; | ||||

/// Converts Convolution op into vector contraction. | /// Converts Convolution op into vector contraction. | ||||

/// | /// | ||||

/// Conversion expects ConvOp to have dimensions marked in the *mask* as | /// Conversion expects ConvOp to have dimensions marked in the *mask* as | ||||

/// false of size 1. This ensures that the ConvOp can be lowered to vector | /// false of size 1. This ensures that the ConvOp can be lowered to vector | ||||

/// contraction of dimensions marked in the *mask* as true. | /// contraction of dimensions marked in the *mask* as true. | ||||

/// | /// | ||||

/// A good example is ConvNHWCOp which is 2D Conv op with channels as the last | /// A good example for vectorization is ConvNHWCOp which is 2D Conv op | ||||

/// dimension. For this op we contract last 3 dimensions. | /// with channels as the last dimension. Let's vectorize last 3 dimensions. | ||||

/// The initial op definition looks like this: | /// The initial op definition looks like this: | ||||

/// ``` | /// ``` | ||||

/// linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : | /// linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : | ||||

/// (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref<?x?x?x?xf32>) | /// (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref<?x?x?x?xf32>) | ||||

/// ``` | /// ``` | ||||

/// This op can be expressed as a dot product between %arg0 (input) and | /// This op can be expressed as a dot product between %arg0 (input) and | ||||

/// %arg1 (kernel) which is written into first entry of %arg2 (output). This is | /// %arg1 (kernel) which is written into first entry of %arg2 (output). This is | ||||

/// the ConvOp this pass expects and converts into: | /// the ConvOp this pass expects and converts into: | ||||

Show All 22 Lines | public: | ||||

ConvOpVectorization(MLIRContext *context, SmallVector<bool, 4> msk) | ConvOpVectorization(MLIRContext *context, SmallVector<bool, 4> msk) | ||||

: OpRewritePattern<ConvOp>(context) { | : OpRewritePattern<ConvOp>(context) { | ||||

assert(msk.size() == N && "Mask size does not match rank"); | assert(msk.size() == N && "Mask size does not match rank"); | ||||

this->mask = msk; | this->mask = msk; | ||||

} | } | ||||

LogicalResult matchAndRewrite(ConvOp minOp, | LogicalResult matchAndRewrite(ConvOp minOp, | ||||

PatternRewriter &rewriter) const override; | PatternRewriter &rewriter) const override; | ||||

// TODO: Make these pass arguments. | |||||

static const int tileSize = 3; | |||||

static const int noTile = 1; | |||||

}; | }; | ||||

//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||

// Support for staged pattern application. | // Support for staged pattern application. | ||||

//===----------------------------------------------------------------------===// | //===----------------------------------------------------------------------===// | ||||

/// Helper function to allow applying rewrite patterns, interleaved with more | /// Helper function to allow applying rewrite patterns, interleaved with more | ||||

/// global transformations, in a staged fashion: | /// global transformations, in a staged fashion: | ||||

/// 1. the first stage consists of a list of OwningRewritePatternList. Each | /// 1. the first stage consists of a list of OwningRewritePatternList. Each | ||||

Show All 16 Lines |