diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -31,8 +31,8 @@
 };

 /// Populates patterns for vectorization of all ConvN-D ops.
-void populateConvVectorizationPatterns(MLIRContext *context,
-                                       OwningRewritePatternList &patterns);
+void populateConvVectorizationPatterns(
+    MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns);

 /// Performs standalone tiling of a single LinalgOp by `tileSizes`.
 /// and permute the loop nest according to `interchangeVector`
@@ -589,6 +589,10 @@
   LogicalResult matchAndRewrite(ConvOp minOp,
                                 PatternRewriter &rewriter) const override;
+
+  // TODO: Make these pass arguments.
+  static const int tileSize = 3;
+  static const int noTile = 1;
 };

 //===----------------------------------------------------------------------===//
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-call.mlir
@@ -9,17 +9,13 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s

-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1" -test-conv-vectorization \
-// RUN:   -convert-linalg-to-loops -test-vector-contraction-conversion=vector-outerproduct=0 \
-// RUN:   -convert-vector-to-scf -convert-linalg-to-llvm | \
+// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s

-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=4" -linalg-tile="linalg-tile-sizes=1" \
-// RUN:   -test-conv-vectorization -convert-linalg-to-loops \
-// RUN:   -test-vector-contraction-conversion=vector-outerproduct=0 \
-// RUN:   -convert-vector-to-scf -convert-linalg-to-llvm | \
+// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=4" \
+// RUN:   -test-conv-vectorization -convert-linalg-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s
diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
--- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
+++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-ncw-call.mlir
@@ -9,17 +9,13 @@
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s

-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,1,1" -test-conv-vectorization \
-// RUN:   -convert-linalg-to-loops -test-vector-contraction-conversion=vector-outerproduct=0 \
-// RUN:   -convert-vector-to-scf -convert-linalg-to-llvm | \
+// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \
 // RUN: mlir-cpu-runner -e main -entry-point-result=void \
 // RUN:   -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \
 // RUN: | FileCheck %s

-// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,4" -linalg-tile="linalg-tile-sizes=1,1,1" \
-// RUN:   -test-conv-vectorization -convert-linalg-to-loops \
-// RUN:   -test-vector-contraction-conversion=vector-outerproduct=0 \
-// RUN:
-convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,4" \ +// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-1d-nwc-call.mlir @@ -9,17 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,1,1" -test-conv-vectorization \ -// RUN: -convert-linalg-to-loops -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,4" -linalg-tile="linalg-tile-sizes=1,1,1" \ -// RUN: -test-conv-vectorization -convert-linalg-to-loops \ -// RUN: -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,4" \ +// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-call.mlir @@ -9,17 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,1" -test-conv-vectorization \ -// RUN: -convert-linalg-to-loops -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,2" -linalg-tile="linalg-tile-sizes=1,1" \ -// RUN: -test-conv-vectorization -convert-linalg-to-loops \ -// RUN: -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,2" \ +// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir +++ 
b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nchw-call.mlir @@ -9,17 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,1,1,1" -test-conv-vectorization \ -// RUN: -convert-linalg-to-loops -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,0,4,4" -linalg-tile="linalg-tile-sizes=1,1,1,1" \ -// RUN: -test-conv-vectorization -convert-linalg-to-loops \ -// RUN: -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,0,4,4" \ +// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-2d-nhwc-call.mlir @@ -9,17 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,1,1,1" -test-conv-vectorization \ -// RUN: -convert-linalg-to-loops -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,3,2" -linalg-tile="linalg-tile-sizes=1,1,1,1" \ -// RUN: -test-conv-vectorization -convert-linalg-to-loops \ -// RUN: -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,3,2" \ +// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-call.mlir @@ -9,17 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,1,1" -test-conv-vectorization \ -// RUN: -convert-linalg-to-loops -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: 
-shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,2,2" -linalg-tile="linalg-tile-sizes=1,1,1" \ -// RUN: -test-conv-vectorization -convert-linalg-to-loops \ -// RUN: -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,2,2" \ +// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ncdhw-call.mlir @@ -9,17 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,1,1,1,1" -test-conv-vectorization \ -// RUN: -convert-linalg-to-loops -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,5,5,5" -linalg-tile="linalg-tile-sizes=1,1,1,1,1" \ -// RUN: -test-conv-vectorization -convert-linalg-to-loops \ -// RUN: -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,0,5,5,5" \ +// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir --- a/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir +++ b/mlir/integration_test/Dialect/Linalg/CPU/test-conv-3d-ndhwc-call.mlir @@ -9,17 +9,13 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=1,1,1,1,1" -test-conv-vectorization \ -// RUN: -convert-linalg-to-loops -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,5,5,5" -linalg-tile="linalg-tile-sizes=1,1,1,1,1" \ -// RUN: -test-conv-vectorization -convert-linalg-to-loops \ -// RUN: -test-vector-contraction-conversion=vector-outerproduct=0 \ -// RUN: -convert-vector-to-scf -convert-linalg-to-llvm | \ +// RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=0,5,5,5" \ +// RUN: -test-conv-vectorization -convert-linalg-to-llvm | \ // RUN: mlir-cpu-runner -e main 
-entry-point-result=void \ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -371,7 +371,6 @@ template LogicalResult ConvOpVectorization::matchAndRewrite( ConvOp op, PatternRewriter &rewriter) const { - unsigned dimSize = 3; Location loc = op.getLoc(); MLIRContext *context = op.getContext(); edsc::ScopedContext scope(rewriter, loc); @@ -391,7 +390,7 @@ for (unsigned i = 0; i < N; i++) { if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1)) return failure(); - if (mask[i] && (inShape[i] != dimSize || kShape[i] != dimSize)) + if (mask[i] && (inShape[i] != tileSize || kShape[i] != tileSize)) return failure(); if (mask[i]) @@ -409,7 +408,7 @@ auto map = AffineMap::get(rank, 0, mapping, context); SmallVector zeros(rank, std_constant_index(0)); auto vecType = - VectorType::get(SmallVector(numDims, dimSize), elemType); + VectorType::get(SmallVector(numDims, tileSize), elemType); auto inputVec = vector_transfer_read(vecType, input, zeros, map); auto kernelVec = vector_transfer_read(vecType, kernel, zeros, map); @@ -433,32 +432,76 @@ return success(); } +using ConvOpConst = ConvOpVectorization; + +/// Inserts tiling, promotion and vectorization pattern for ConvOp +/// conversion into corresponding pattern lists. +template +static void +populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns, + OwningRewritePatternList &promotionPatterns, + OwningRewritePatternList &vectorizationPatterns, + ArrayRef tileSizes, + MLIRContext *context) { + constexpr static StringRef kTiledMarker = "TILED"; + constexpr static StringRef kPromotedMarker = "PROMOTED"; + tilingPatterns.insert>( + context, LinalgTilingOptions().setTileSizes(tileSizes), + LinalgMarker({}, Identifier::get(kTiledMarker, context))); + + promotionPatterns.insert>( + context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), + LinalgMarker(Identifier::get(kTiledMarker, context), + Identifier::get(kPromotedMarker, context))); + + SmallVector mask(N); + int offset = tileSizes.size() - N; + std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(), + [](int64_t i) -> bool { return i != ConvOpConst::noTile; }); + + vectorizationPatterns.insert>(context, mask); +} + void mlir::linalg::populateConvVectorizationPatterns( - MLIRContext *context, OwningRewritePatternList &patterns) { - patterns.insert>( - context, SmallVector{true}); + MLIRContext *context, SmallVectorImpl &patterns) { + const int64_t tileSize = ConvOpConst::tileSize; + const int64_t noTile = ConvOpConst::noTile; + auto makeTileSizes = [&](unsigned numNoTile, unsigned numTile) { + SmallVector result(numNoTile, noTile); + result.append(numTile, tileSize); + return result; + }; + + OwningRewritePatternList tiling, promotion, vectorization; + populateVectorizationPatterns( + tiling, promotion, vectorization, + makeTileSizes(/*numNoTile=*/1, /*numTile*/ 1), context); + + populateVectorizationPatterns(tiling, promotion, vectorization, + makeTileSizes(3, 2), context); - patterns.insert>( - context, SmallVector{false, true, true}); + populateVectorizationPatterns(tiling, promotion, vectorization, + makeTileSizes(3, 2), context); - patterns.insert>( - context, SmallVector{false, true, true}); + populateVectorizationPatterns(tiling, promotion, vectorization, + 
makeTileSizes(2, 2), context); - patterns.insert>( - context, SmallVector{true, true}); + populateVectorizationPatterns(tiling, promotion, vectorization, + makeTileSizes(4, 3), context); - patterns.insert>( - context, SmallVector{false, true, true, true}); + populateVectorizationPatterns(tiling, promotion, vectorization, + makeTileSizes(4, 3), context); - patterns.insert>( - context, SmallVector{false, true, true, true}); + populateVectorizationPatterns(tiling, promotion, vectorization, + makeTileSizes(3, 3), context); - patterns.insert>( - context, SmallVector{true, true, true}); + populateVectorizationPatterns( + tiling, promotion, vectorization, makeTileSizes(5, 4), context); - patterns.insert>( - context, SmallVector{false, true, true, true, true}); + populateVectorizationPatterns( + tiling, promotion, vectorization, makeTileSizes(5, 4), context); - patterns.insert>( - context, SmallVector{false, true, true, true, true}); + patterns.push_back(std::move(tiling)); + patterns.push_back(std::move(promotion)); + patterns.push_back(std::move(vectorization)); } diff --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir --- a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir +++ b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir @@ -1,167 +1,52 @@ // RUN: mlir-opt %s -test-conv-vectorization --cse | FileCheck %s -// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0) -> (d0)> -// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0) -> ()> -// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)> -// CHECK-DAG: #[[$map3:.*]] = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-DAG: #[[$map4:.*]] = affine_map<(d0, d1) -> ()> -// CHECK-DAG: #[[$map5:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> -// CHECK-DAG: #[[$map6:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -// CHECK-DAG: #[[$map7:.*]] = affine_map<(d0, d1, d2) -> ()> -// CHECK-DAG: #[[$map8:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2, d3, d4)> -// CHECK-DAG: #[[$map9:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> -// CHECK-DAG: #[[$map10:.*]] = affine_map<(d0, d1, d2, d3) -> ()> +// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0)[s0] -> (1, -d0 + s0)> +// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0 + d1)> +// CHECK-DAG: #[[$map3:.*]] = affine_map<(d0, d1)[s0] -> (3, -d0 - d1 + s0)> +// CHECK-DAG: #[[$map4:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)> +// CHECK-DAG: #[[$map5:.*]] = affine_map<(d0) -> (d0)> -func @conv_1d(%arg0: memref<3xf32>, %arg1: memref<3xf32>, %arg2: memref) { - linalg.conv_1d %arg0, %arg1, %arg2 : (memref<3xf32>, memref<3xf32>, memref) +func @conv_1d(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.conv_1d %arg0, %arg1, %arg2 : (memref, memref, memref) return } // CHECK-LABEL: @conv_1d -// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<3xf32> -// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<3xf32> +// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref +// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref // CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3xf32> -// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]]], %[[cst]] : memref<3xf32>, vector<3xf32> -// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map0]], #[[$map0]], #[[$map1]]], iterator_types = ["reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3xf32>, vector<3xf32> into f32 -// CHECK: store %[[v2]], %[[arg2]][%[[c0]]] : memref -// CHECK: return - -func @conv_1d_ncw(%arg0: 
memref<1x3x3xf32>, %arg1: memref<1x3x3xf32>, %arg2: memref) { - linalg.conv_1d_ncw %arg0, %arg1, %arg2 : (memref<1x3x3xf32>, memref<1x3x3xf32>, memref) - return -} - -// CHECK-LABEL: @conv_1d_ncw -// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3xf32> -// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3xf32> -// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3xf32> -// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3xf32>, vector<3x3xf32> -// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map3]], #[[$map3]], #[[$map4]]], iterator_types = ["reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3xf32>, vector<3x3xf32> into f32 -// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]]] : memref -// CHECK: return - - -func @conv_1d_nwc(%arg0: memref<1x3x3xf32>, %arg1: memref<1x3x3xf32>, %arg2: memref) { - linalg.conv_1d_nwc %arg0, %arg1, %arg2 : (memref<1x3x3xf32>, memref<1x3x3xf32>, memref) - return -} - -// CHECK-LABEL: @conv_1d_nwc -// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3xf32> -// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3xf32> -// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3xf32> -// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3xf32>, vector<3x3xf32> -// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map3]], #[[$map3]], #[[$map4]]], iterator_types = ["reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3xf32>, vector<3x3xf32> into f32 -// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]]] : memref -// CHECK: return - -func @conv_2d(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>, %arg2: memref) { - linalg.conv_2d %arg0, %arg1, %arg2 : (memref<3x3xf32>, memref<3x3xf32>, memref) - return -} - -// CHECK-LABEL: @conv_2d -// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<3x3xf32> -// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<3x3xf32> -// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3xf32> -// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]]], %[[cst]] : memref<3x3xf32>, vector<3x3xf32> -// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map3]], #[[$map3]], #[[$map4]]], iterator_types = ["reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3xf32>, vector<3x3xf32> into f32 -// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]]] : memref -// CHECK: return - -func @conv_2d_nchw(%arg0: memref<1x3x3x3xf32>, %arg1: memref<1x3x3x3xf32>, %arg2: memref) { - linalg.conv_2d_nchw %arg0, %arg1, %arg2 : (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref) - return -} - -// CHECK-LABEL: @conv_2d_nchw -// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3x3xf32> -// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3x3xf32> -// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3x3xf32> -// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3xf32>, vector<3x3x3xf32> -// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map6]], #[[$map6]], #[[$map7]]], iterator_types = ["reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3xf32>, vector<3x3x3xf32> into f32 -// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]] : memref -// CHECK: return - -func @conv_2d_nhwc(%arg0: memref<1x3x3x3xf32>, %arg1: memref<1x3x3x3xf32>, %arg2: memref) { - linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref) - return -} - -// CHECK-LABEL: 
@conv_2d_nhwc -// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3x3xf32> -// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3x3xf32> -// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3x3xf32> -// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3xf32>, vector<3x3x3xf32> -// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map6]], #[[$map6]], #[[$map7]]], iterator_types = ["reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3xf32>, vector<3x3x3xf32> into f32 -// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]], %[[c0]]] : memref -// CHECK: return - -func @conv_3d(%arg0: memref<3x3x3xf32>, %arg1: memref<3x3x3xf32>, %arg2: memref) { - linalg.conv_3d %arg0, %arg1, %arg2 : (memref<3x3x3xf32>, memref<3x3x3xf32>, memref) - return -} - -// CHECK-LABEL: @conv_3d -// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<3x3x3xf32> -// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<3x3x3xf32> -// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3x3xf32> -// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<3x3x3xf32>, vector<3x3x3xf32> -// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map6]], #[[$map6]], #[[$map7]]], iterator_types = ["reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3xf32>, vector<3x3x3xf32> into f32 -// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]]] : memref -// CHECK: return - -func @conv_3d_ncdhw(%arg0: memref<1x3x3x3x3xf32>, %arg1: memref<1x3x3x3x3xf32>, %arg2: memref) { - linalg.conv_3d_ncdhw %arg0, %arg1, %arg2 : (memref<1x3x3x3x3xf32>, memref<1x3x3x3x3xf32>, memref) - return -} - -// CHECK-LABEL: @conv_3d_ncdhw -// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3x3x3xf32> -// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3x3x3xf32> -// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3x3x3xf32> -// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3x3xf32>, vector<3x3x3x3xf32> -// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map9]], #[[$map9]], #[[$map10]]], iterator_types = ["reduction", "reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3x3xf32>, vector<3x3x3x3xf32> into f32 -// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]] : memref -// CHECK: return - -func @conv_3d_ndhwc(%arg0: memref<1x3x3x3x3xf32>, %arg1: memref<1x3x3x3x3xf32>, %arg2: memref) { - linalg.conv_3d_ndhwc %arg0, %arg1, %arg2 : (memref<1x3x3x3x3xf32>, memref<1x3x3x3x3xf32>, memref) - return -} - -// CHECK-LABEL: @conv_3d_ndhwc -// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref<1x3x3x3x3xf32> -// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref<1x3x3x3x3xf32> -// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, vector<3x3x3x3xf32> -// CHECK: %[[v1:.*]] = vector.transfer_read %[[arg1]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]], %[[cst]] : memref<1x3x3x3x3xf32>, vector<3x3x3x3xf32> -// CHECK: %[[v2:.*]] = vector.contract {indexing_maps = [#[[$map9]], #[[$map9]], #[[$map10]]], iterator_types = ["reduction", "reduction", "reduction", "reduction"]} %[[v0]], %[[v1]], %[[cst]] : vector<3x3x3x3xf32>, vector<3x3x3x3xf32> into f32 -// CHECK: store %[[v2]], %[[arg2]][%[[c0]], %[[c0]], %[[c0]], %[[c0]], %[[c0]]] : memref -// CHECK: return +// CHECK-DAG: %[[c12:.*]] = constant 12 : index +// CHECK-DAG: %[[c4:.*]] = constant 4 : index +// CHECK-DAG: %[[cst:.*]] = constant 0.000000e+00 : 
f32
+// CHECK-DAG: %[[c3:.*]] = constant 3 : index
+// CHECK-DAG: %[[c0:.*]] = constant 0 : index
+// CHECK-DAG: %[[c1:.*]] = constant 1 : index
+// CHECK: %[[v0:.*]] = dim %[[arg1]], %[[c0]] : memref
+// CHECK: %[[v1:.*]] = dim %[[arg2]], %[[c0]] : memref
+// CHECK: %[[v2:.*]] = dim %[[arg0]], %[[c0]] : memref
+// CHECK: %[[v3:.*]] = alloc(%[[c12]]) : memref
+// CHECK: %[[v4:.*]] = alloc(%[[c12]]) : memref
+// CHECK: %[[v5:.*]] = alloc(%[[c4]]) : memref
+// CHECK: %[[v6:.*]] = std.view %[[v3]][%[[c0]]][] : memref to memref<3xf32>
+// CHECK: %[[v7:.*]] = std.view %[[v4]][%[[c0]]][] : memref to memref<3xf32>
+// CHECK: %[[v8:.*]] = std.view %[[v5]][%[[c0]]][] : memref to memref<1xf32>
+// CHECK: scf.for %[[arg3:.*]] = %[[c0]] to %[[v1]] step %[[c1]] {
+// CHECK: %[[v9:.*]] = affine.min #[[$map0]](%[[arg3]])[%[[v1]]]
+// CHECK: %[[v10:.*]] = subview %[[arg2]][%[[arg3]]] [%[[v9]]] [1] : memref to memref
+// CHECK: %[[v11:.*]] = subview %[[v8]][0] [%[[v9]]] [1] : memref<1xf32> to memref
+// CHECK: scf.for %[[arg4:.*]] = %[[c0]] to %[[v0]] step %[[c3]] {
+// CHECK: %[[v12:.*]] = affine.apply #[[$map2]](%[[arg3]], %[[arg4]])
+// CHECK: %[[v13:.*]] = affine.min #[[$map3]](%[[arg3]], %[[arg4]])[%[[v2]]]
+// CHECK: %[[v14:.*]] = subview %arg0[%12] [%13] [1] : memref to memref
+// CHECK: %[[v15:.*]] = affine.min #[[$map4]](%arg4)[%0]
+// CHECK: %[[v16:.*]] = subview %[[arg1]][%[[arg4]]] [%[[v15]]] [1] : memref to memref
+// CHECK: %[[v17:.*]] = subview %[[v6]][0] [%[[v13]]] [1] : memref<3xf32> to memref
+// CHECK: %[[v19:.*]] = vector.transfer_read %[[v6]][%[[c0]]], %[[cst]] {masked = [false]} : memref<3xf32>, vector<3xf32>
+// CHECK: %[[v20:.*]] = vector.transfer_read %[[v7]][%[[c0]]], %[[cst]] {masked = [false]} : memref<3xf32>, vector<3xf32>
+// CHECK: %[[v21:.*]] = mulf %[[v19]], %[[v20]] : vector<3xf32>
+// CHECK: %[[v22:.*]] = vector.reduction "add", %[[v21]], %[[cst]] : vector<3xf32> into f32
+// CHECK: store %[[v22]], %[[v8]][%[[c0]]] : memref<1xf32>
+// CHECK: scf.for %[[arg5:.*]] = %[[c0]] to %[[v9]] step %[[c1]] {
+// CHECK: %[[v23:.*]] = load %[[v11]][%[[arg5]]] : memref
+// CHECK: store %[[v23]], %[[v10]][%[[arg5]]] : memref
diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp
--- a/mlir/test/lib/Transforms/TestConvVectorization.cpp
+++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp
@@ -1,4 +1,4 @@
-//===- TestConvVectorization.cpp - Linalg to Vector dialect conversion ----===//
+//===- TestConvVectorization.cpp - Vectorization of Conv ops --------------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
@@ -6,11 +6,19 @@
 //
 //===----------------------------------------------------------------------===//

+#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Vector/VectorTransforms.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/LoopUtils.h"
+#include "mlir/Transforms/Passes.h"

 using namespace mlir;
+using namespace vector;

 namespace {
 /// A pass converting MLIR Linalg ops into Vector ops.
@@ -19,8 +27,10 @@
   void runOnOperation() override;

   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert();
+    registry.insert();
     registry.insert();
+    registry.insert();
+    registry.insert();
     registry.insert();
   }
 };
@@ -32,15 +42,70 @@
   ConversionTarget target(*context);
   target.addLegalDialect();
+                         VectorDialect>();
   target.addLegalOp();
   target.addLegalOp();

-  OwningRewritePatternList patterns;
-  linalg::populateConvVectorizationPatterns(context, patterns);
+  SmallVector stage1Patterns;
+  linalg::populateConvVectorizationPatterns(context, stage1Patterns);

-  if (failed(applyPartialConversion(module, target, patterns)))
-    return signalPassFailure();
+  OwningRewritePatternList stage2Patterns =
+      linalg::getLinalgTilingCanonicalizationPatterns(context);
+  stage2Patterns.insert(context);
+
+  auto stage3Transforms = [](Operation *op) {
+    PassManager pm(op->getContext());
+    pm.addPass(createLoopInvariantCodeMotionPass());
+    if (failed(pm.run(cast(op))))
+      llvm_unreachable("Unexpected failure in cleanup pass pipeline.");
+    op->walk([](FuncOp func) {
+      promoteSingleIterationLoops(func);
+      linalg::hoistViewAllocOps(func);
+      linalg::hoistRedundantVectorTransfers(func);
+    });
+    return success();
+  };
+
+  linalg::applyStagedPatterns(module, stage1Patterns, stage2Patterns,
+                              stage3Transforms);
+
+  //===--------------------------------------------------------------------===//
+  // Post staged patterns transforms
+  //===--------------------------------------------------------------------===//
+
+  VectorTransformsOptions vectorTransformsOptions{
+      VectorContractLowering::Dot, VectorTransposeLowering::EltWise};
+
+  OwningRewritePatternList vectorTransferPatterns;
+  // Pattern is not applied because rank-reducing vector transfer is not yet
+  // supported as can be seen in splitFullAndPartialTransferPrecondition,
+  // VectorTransforms.cpp
+  vectorTransferPatterns.insert(
+      context, vectorTransformsOptions);
+  applyPatternsAndFoldGreedily(module, vectorTransferPatterns);
+
+  // Programmatic controlled lowering of linalg.copy and linalg.fill.
+  PassManager pm(context);
+  pm.addPass(createConvertLinalgToLoopsPass());
+  if (failed(pm.run(module)))
+    llvm_unreachable("Unexpected failure in linalg to loops pass.");
+
+  // Programmatic controlled lowering of vector.contract only.
+  OwningRewritePatternList vectorContractLoweringPatterns;
+  populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
+                                         context, vectorTransformsOptions);
+  applyPatternsAndFoldGreedily(module, vectorContractLoweringPatterns);
+
+  // Programmatic controlled lowering of vector.transfer only.
+  OwningRewritePatternList vectorToLoopsPatterns;
+  populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
+                                        VectorTransferToSCFOptions());
+  applyPatternsAndFoldGreedily(module, vectorToLoopsPatterns);
+
+  // Ensure we drop the marker in the end.
+  module.walk([](linalg::LinalgOp op) {
+    op.removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker);
+  });
 }

 namespace mlir {
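Usage note (a sketch, not part of the patch): with the new signature in Transforms.h, populateConvVectorizationPatterns fills one OwningRewritePatternList per stage (tiling, promotion, vectorization), each covering all ConvN-D flavors, and the caller drives the stages through linalg::applyStagedPatterns the way the updated TestConvVectorization pass does. The vectorization mask of each op is derived from its tile sizes: entries equal to noTile (1) mark dimensions that stay scalar, so makeTileSizes(3, 2) yields {1, 1, 1, 3, 3}, and for a rank-3 operand the last three entries give the mask {false, true, true}. A minimal consumer sketch follows; the function name vectorizeConvOps is hypothetical, includes are reduced to the essentials, and the stage-3 cleanup is intentionally a no-op, unlike the full pipeline in the pass above.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/Module.h"

using namespace mlir;

// Drive the staged conv vectorization patterns over a module. This mirrors
// the structure of TestConvVectorization::runOnOperation(), minus the
// post-staged lowering to loops, vector.contract lowering, and marker cleanup.
static void vectorizeConvOps(ModuleOp module, MLIRContext *context) {
  // Stage 1: one pattern list per stage (tiling, promotion, vectorization),
  // produced by the API introduced in this patch.
  SmallVector<OwningRewritePatternList, 4> stage1Patterns;
  linalg::populateConvVectorizationPatterns(context, stage1Patterns);

  // Stage 2: canonicalizations applied between the stage-1 lists so that
  // tiled loops and affine.min ops fold before the next stage fires.
  OwningRewritePatternList stage2Patterns =
      linalg::getLinalgTilingCanonicalizationPatterns(context);

  // Stage 3: optional cleanup run after each round; a no-op in this sketch.
  auto stage3Transforms = [](Operation *) { return success(); };

  linalg::applyStagedPatterns(module, stage1Patterns, stage2Patterns,
                              stage3Transforms);
}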