diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -32,7 +32,8 @@ /// Populates patterns for vectorization of all ConvN-D ops. void populateConvVectorizationPatterns( - MLIRContext *context, SmallVectorImpl &patterns); + MLIRContext *context, SmallVectorImpl &patterns, + ArrayRef tileSizes); /// Performs standalone tiling of a single LinalgOp by `tileSizes`. /// and permute the loop nest according to `interchangeVector` @@ -549,8 +550,8 @@ /// false of size 1. This ensures that the ConvOp can be lowered to vector /// contraction of dimensions marked in the *mask* as true. /// -/// A good example is ConvNHWCOp which is 2D Conv op with channels as the last -/// dimension. For this op we contract last 3 dimensions. +/// A good example for vectorization is ConvNHWCOp, a 2D Conv op with +/// channels as the last dimension; here the last 3 dimensions are vectorized. /// The initial op definition looks like this: /// ``` /// linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : @@ -589,10 +590,6 @@ LogicalResult matchAndRewrite(ConvOp minOp, PatternRewriter &rewriter) const override; - - // TODO: Make these pass arguments.
- static const int tileSize = 3; - static const int noTile = 1; }; //===----------------------------------------------------------------------===// diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir @@ -9,13 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=4" \ -// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: -test-conv-vectorization="tile-sizes=1,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir @@ -9,13 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s // RUN:
mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,4" \ -// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir @@ -9,13 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,4" \ -// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir @@ -9,13 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s
-test-conv-vectorization="tile-sizes=1,1,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,2" \ -// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: -test-conv-vectorization="tile-sizes=1,1,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir @@ -9,13 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,0,4,4" \ -// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir +++
b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir @@ -9,13 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,3,2" \ -// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir @@ -9,13 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,3,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,2,2" \ -// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: -test-conv-vectorization="tile-sizes=1,1,1,3,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ //
RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir @@ -9,13 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,5,5,5" \ -// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir @@ -9,13 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s // RUN: mlir-opt %s
-linalg-tile="linalg-tile-sizes=0,5,5,5" \ -// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ +// RUN: -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -385,16 +385,19 @@ return failure(); SmallVector mapping; - // Fail to apply when the size of not vectorized dimension is not 1 or - // when the size of vectorized dimension is not dimSize. + SmallVector vectorDims; + // Fail when a non-vectorized dim is not 1 or input/kernel sizes differ. for (unsigned i = 0; i < N; i++) { if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1)) return failure(); - if (mask[i] && (inShape[i] != tileSize || kShape[i] != tileSize)) + + if (mask[i] && inShape[i] != kShape[i]) return failure(); - if (mask[i]) + if (mask[i]) { mapping.push_back(getAffineDimExpr(i, context)); + vectorDims.push_back(inShape[i]); + } } Value input = op.getInput(0); @@ -407,8 +410,7 @@ auto map = AffineMap::get(rank, 0, mapping, context); SmallVector zeros(rank, std_constant_index(0)); - auto vecType = - VectorType::get(SmallVector(numDims, tileSize), elemType); + auto vecType = VectorType::get(vectorDims, elemType); auto inputVec = vector_transfer_read(vecType, input, zeros, map); auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map); @@ -443,6 +445,9 @@ OwningRewritePatternList &vectorizationPatterns, ArrayRef tileSizes, MLIRContext *context) { + if (tileSizes.size() < N) + return; + constexpr static StringRef kTiledMarker = "TILED"; constexpr static StringRef kPromotedMarker = "PROMOTED"; tilingPatterns.insert>( @@ -457,49 +462,41 @@ SmallVector mask(N); int
offset = tileSizes.size() - N; std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(), - [](int64_t i) -> bool { return i != ConvOpConst::noTile; }); + [](int64_t i) -> bool { return i > 1; }); vectorizationPatterns.insert>(context, mask); } void mlir::linalg::populateConvVectorizationPatterns( - MLIRContext *context, SmallVectorImpl &patterns) { - const int64_t tileSize = ConvOpConst::tileSize; - const int64_t noTile = ConvOpConst::noTile; - auto makeTileSizes = [&](unsigned numNoTile, unsigned numTile) { - SmallVector result(numNoTile, noTile); - result.append(numTile, tileSize); - return result; - }; - + MLIRContext *context, SmallVectorImpl &patterns, + ArrayRef tileSizes) { OwningRewritePatternList tiling, promotion, vectorization; - populateVectorizationPatterns( - tiling, promotion, vectorization, - makeTileSizes(/*numNoTile=*/1, /*numTile*/ 1), context); + populateVectorizationPatterns(tiling, promotion, vectorization, + tileSizes, context); populateVectorizationPatterns(tiling, promotion, vectorization, - makeTileSizes(3, 2), context); + tileSizes, context); populateVectorizationPatterns(tiling, promotion, vectorization, - makeTileSizes(3, 2), context); + tileSizes, context); populateVectorizationPatterns(tiling, promotion, vectorization, - makeTileSizes(2, 2), context); + tileSizes, context); populateVectorizationPatterns(tiling, promotion, vectorization, - makeTileSizes(4, 3), context); + tileSizes, context); populateVectorizationPatterns(tiling, promotion, vectorization, - makeTileSizes(4, 3), context); + tileSizes, context); populateVectorizationPatterns(tiling, promotion, vectorization, - makeTileSizes(3, 3), context); + tileSizes, context); populateVectorizationPatterns( - tiling, promotion, vectorization, makeTileSizes(5, 4), context); + tiling, promotion, vectorization, tileSizes, context); populateVectorizationPatterns( - tiling, promotion, vectorization, makeTileSizes(5, 4), context); + tiling, promotion, vectorization,
tileSizes, context); patterns.push_back(std::move(tiling)); patterns.push_back(std::move(promotion)); diff --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir --- a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir +++ b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-conv-vectorization --cse | FileCheck %s +// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" --cse | FileCheck %s // CHECK-DAG: #[[$map0:.*]] = affine_map<(d0)[s0] -> (1, -d0 + s0)> // CHECK-DAG: #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp --- a/mlir/test/lib/Transforms/TestConvVectorization.cpp +++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp @@ -24,6 +24,13 @@ /// A pass converting MLIR Linalg ops into Vector ops. class TestConvVectorization : public PassWrapper> { +public: + TestConvVectorization() = default; + TestConvVectorization(const TestConvVectorization &) {} + explicit TestConvVectorization(ArrayRef tileSizesParam) { + tileSizes = tileSizesParam; + } + void runOnOperation() override; void getDependentDialects(DialectRegistry &registry) const override { @@ -33,6 +40,10 @@ registry.insert(); registry.insert(); } + + ListOption tileSizes{ + *this, "tile-sizes", llvm::cl::desc("Vectorization sizes."), + llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; }; } // namespace @@ -47,7 +58,7 @@ target.addLegalOp(); SmallVector stage1Patterns; - linalg::populateConvVectorizationPatterns(context, stage1Patterns); + linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes); OwningRewritePatternList stage2Patterns = linalg::getLinalgTilingCanonicalizationPatterns(context);