diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -49,11 +49,6 @@ void populatePadTensorTilingPatterns(RewritePatternSet &patterns, const LinalgTilingOptions &options); -/// [DEPRECATED] Populate patterns for vectorization of all ConvN-D ops. -void populateConvVectorizationPatterns( - MLIRContext *context, SmallVectorImpl &patterns, - ArrayRef tileSizes); - /// Populate patterns for vectorizing low-D convolution ops. This is a step in /// progressive lowering for convolution ops, it assume high-D convolution ops /// were decomposed previously. @@ -1246,54 +1241,6 @@ PatternRewriter &rewriter) const override; }; -/// Converts Convolution op into vector contraction. -/// -/// Conversion expects ConvOp to have dimensions marked in the *mask* as -/// false of size 1. This ensures that the ConvOp can be lowered to vector -/// contraction of dimensions marked in the *mask* as true. -/// -/// A good example for vectorization is ConvNHWCOp which is 2D Conv op -/// with channels as the last dimension. Let's vectorize last 3 dimensions. -/// The initial op definition looks like this: -/// ``` -/// linalg.conv_2d_nhwc %arg0, %arg1, %arg2 : -/// (memref<1x3x3x3xf32>, memref<1x3x3x3xf32>, memref) -/// ``` -/// This op can be expressed as a dot product between %arg0 (input) and -/// %arg1 (kernel) which is written into first entry of %arg2 (output). This is -/// the ConvOp this pass expects and converts into: -/// ``` -/// #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> -/// #map1 = affine_map<(d0, d1, d2) -> ()> -/// ..... -/// %0 = vector.transfer_read %arg0[%c0, %c0, %c0, %c0], %c0_f32 -/// : memref<1x3x3x3xf32>, vector<3x3x3xf32> -/// %1 = vector.transfer_read %arg1[%c0, %c0, %c0, %c0], %c0_f32 -/// : memref<1x3x3x3xf32>, vector<3x3x3xf32> -/// %2 = vector.contract {indexing_maps = [#map0, #map0, #map1], -/// iterator_types = ["reduction", "reduction", "reduction"]} %0, %1, -/// %c0_f32 : vector<3x3x3xf32>, vector<3x3x3xf32> into f32 -/// store %2, %arg2[%c0, %c0, %c0, %c0] : memref -/// ``` -/// where first 2 operations read input and kernel memory buffers into vectors. -/// Subsequently, they are contracted together and the result is written to -/// the first entry of the output buffer. -template -class ConvOpVectorization : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - SmallVector mask; - -public: - ConvOpVectorization(MLIRContext *context, const SmallVector &msk) - : OpRewritePattern(context) { - assert(msk.size() == N && "Mask size does not match rank"); - this->mask = msk; - } - - LogicalResult matchAndRewrite(ConvOp minOp, - PatternRewriter &rewriter) const override; -}; - /// Rewrite a TiledLoopOp with bounds/step that potentially do not divide evenly /// into a TiledLoopOp where the step divides the iteration space evenly, /// followed by another TiledLoopOp for the last (partial) iteration (if any). diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1098,134 +1098,6 @@ patterns.getContext(), baseBenefit.getBenefit() + 1); } -// TODO: cleanup all the convolution vectorization patterns. -template -LogicalResult ConvOpVectorization::matchAndRewrite( - ConvOp op, PatternRewriter &rewriter) const { - Location loc = op.getLoc(); - MLIRContext *context = op.getContext(); - - OpOperand *input = op.getInputOperand(0); - OpOperand *kernel = op.getInputOperand(1); - OpOperand *output = op.getOutputOperand(0); - ArrayRef inShape = op.getShape(input); - ArrayRef kShape = op.getShape(kernel); - - if (llvm::any_of(inShape, ShapedType::isDynamic) || - llvm::any_of(kShape, ShapedType::isDynamic)) - return failure(); - - SmallVector mapping; - SmallVector vectorDims; - // Fail to apply when the size of not vectorized dimension is not 1. - for (unsigned i = 0; i < N; i++) { - if (!mask[i] && (inShape[i] != 1 || kShape[i] != 1)) - return failure(); - - if (mask[i] && inShape[i] != kShape[i]) - return failure(); - - if (mask[i]) { - mapping.push_back(getAffineDimExpr(i, context)); - vectorDims.push_back(inShape[i]); - } - } - - int64_t rank = op.getRank(input); - int64_t numDims = mapping.size(); - Type elemType = getElementTypeOrSelf(input->get()); - - auto map = AffineMap::get(rank, 0, mapping, context); - SmallVector zeros(rank, - rewriter.create(loc, 0)); - auto vecType = VectorType::get(vectorDims, elemType); - - auto inputVec = rewriter.create( - loc, vecType, input->get(), zeros, map); - auto kernelVec = rewriter.create( - loc, vecType, kernel->get(), zeros, map); - - auto acc = rewriter.create(loc, elemType, - rewriter.getZeroAttr(elemType)); - - std::array indexingMaps{ - AffineMap::getMultiDimIdentityMap(numDims, context), - AffineMap::getMultiDimIdentityMap(numDims, context), - AffineMap::get(numDims, 0, {}, context)}; - - std::vector iteratorTypes(numDims, "reduction"); - - auto result = rewriter.create( - loc, inputVec, kernelVec, acc, - rewriter.getAffineMapArrayAttr(indexingMaps), - rewriter.getStrArrayAttr(iteratorTypes)); - - rewriter.create(loc, result, output->get(), - ValueRange(zeros)); - rewriter.eraseOp(op); - return success(); -} - -/// Inserts tiling, promotion and vectorization pattern for ConvOp -/// conversion into corresponding pattern lists. -template -static void populateVectorizationPatterns( - RewritePatternSet &tilingPatterns, RewritePatternSet &promotionPatterns, - RewritePatternSet &vectorizationPatterns, ArrayRef tileSizes) { - auto *context = tilingPatterns.getContext(); - if (tileSizes.size() < N) - return; - - constexpr static StringRef kTiledMarker = "TILED"; - constexpr static StringRef kPromotedMarker = "PROMOTED"; - tilingPatterns.add( - ConvOp::getOperationName(), context, - LinalgTilingOptions().setTileSizes(tileSizes), - LinalgTransformationFilter(ArrayRef{}, - StringAttr::get(context, kTiledMarker))); - - promotionPatterns.add>( - context, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true), - LinalgTransformationFilter(StringAttr::get(context, kTiledMarker), - StringAttr::get(context, kPromotedMarker))); - - SmallVector mask(N); - int offset = tileSizes.size() - N; - std::transform(tileSizes.begin() + offset, tileSizes.end(), mask.begin(), - [](int64_t i) -> bool { return i > 1; }); - - vectorizationPatterns.add>(context, mask); -} - -void mlir::linalg::populateConvVectorizationPatterns( - MLIRContext *context, SmallVectorImpl &patterns, - ArrayRef tileSizes) { - RewritePatternSet tiling(context); - RewritePatternSet promotion(context); - RewritePatternSet vectorization(context); - populateVectorizationPatterns(tiling, promotion, vectorization, - tileSizes); - - populateVectorizationPatterns(tiling, promotion, vectorization, - tileSizes); - - populateVectorizationPatterns(tiling, promotion, vectorization, - tileSizes); - - populateVectorizationPatterns(tiling, promotion, - vectorization, tileSizes); - - populateVectorizationPatterns(tiling, promotion, - vectorization, tileSizes); - - populateVectorizationPatterns( - tiling, promotion, vectorization, tileSizes); - - patterns.push_back(std::move(tiling)); - patterns.push_back(std::move(promotion)); - patterns.push_back(std::move(vectorization)); -} - //----------------------------------------------------------------------------// // Forwarding patterns //----------------------------------------------------------------------------// diff --git a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir b/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir deleted file mode 100644 --- a/mlir/test/Conversion/LinalgToVector/linalg-to-vector.mlir +++ /dev/null @@ -1,53 +0,0 @@ -// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" --cse -split-input-file -// | FileCheck %s - -// CHECK-DAG: #[[$map0:.*]] = affine_map<(d0)[s0] -> (1, -d0 + s0)> -// CHECK-DAG: #[[$map1:.*]] = affine_map<(d0)[s0] -> (d0 + s0)> -// CHECK-DAG: #[[$map2:.*]] = affine_map<(d0, d1) -> (d0 + d1)> -// CHECK-DAG: #[[$map3:.*]] = affine_map<(d0, d1)[s0] -> (3, -d0 - d1 + s0)> -// CHECK-DAG: #[[$map4:.*]] = affine_map<(d0)[s0] -> (3, -d0 + s0)> - -// CHECK-LABEL: @conv_1d -// CHECK-SAME: %[[arg0:[a-zA-Z0-9]+]]: memref -// CHECK-SAME: %[[arg1:[a-zA-Z0-9]+]]: memref -// CHECK-SAME: %[[arg2:[a-zA-Z0-9]+]]: memref, %arg1: memref, %arg2: memref) { -// CHECK-DAG: %[[c12:.*]] = arith.constant 12 : index -// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index -// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32 -// CHECK-DAG: %[[c3:.*]] = arith.constant 3 : index -// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK: %[[v0:.*]] = memref.dim %[[arg1]], %[[c0]] : memref -// CHECK: %[[v1:.*]] = memref.dim %[[arg2]], %[[c0]] : memref -// CHECK: %[[v2:.*]] = memref.dim %[[arg0]], %[[c0]] : memref -// CHECK: %[[v3:.*]] = memref.alloc(%[[c12]]) : memref -// CHECK: %[[v4:.*]] = memref.alloc(%[[c12]]) : memref -// CHECK: %[[v5:.*]] = memref.alloc(%[[c4]]) : memref -// CHECK: %[[v6:.*]] = memref.view %[[v3]][%[[c0]]][] : memref to memref<3xf32> -// CHECK: %[[v7:.*]] = memref.view %[[v4]][%[[c0]]][] : memref to memref<3xf32> -// CHECK: %[[v8:.*]] = memref.view %[[v5]][%[[c0]]][] : memref to memref<1xf32> -// CHECK: scf.for %[[arg3:.*]] = %[[c0]] to %[[v1]] step %[[c1]] { -// CHECK: %[[v9:.*]] = affine.min #[[$map0]](%[[arg3]])[%[[v1]]] -// CHECK: %[[v10:.*]] = subview %[[arg2]][%[[arg3]]] [%[[v9]]] [1] : memref to memref -// CHECK: %[[v11:.*]] = subview %[[v8]][0] [%[[v9]]] [1] : memref<1xf32> to memref -// CHECK: scf.for %[[arg4:.*]] = %[[c0]] to %[[v0]] step %[[c3]] { -// CHECK: %[[v12:.*]] = affine.apply #[[$map2]](%[[arg3]], %[[arg4]]) -// CHECK: %[[v13:.*]] = affine.min #[[$map3]](%[[arg3]], %[[arg4]])[%[[v2]]] -// CHECK: %[[v14:.*]] = subview %arg0[%12] [%13] [1] : memref to memref -// CHECK: %[[v15:.*]] = affine.min #[[$map4]](%arg4)[%0] -// CHECK: %[[v16:.*]] = subview %[[arg1]][%[[arg4]]] [%[[v15]]] [1] : memref to memref -// CHECK: %[[v17:.*]] = subview %[[v6]][0] [%[[v13]]] [1] : memref<3xf32> to memref -// CHECK: %[[v19:.*]] = vector.transfer_read %[[v6]][%[[c0]]], %[[cst]] {in_bounds = [true]} : memref<3xf32>, vector<3xf32> -// CHECK: %[[v20:.*]] = vector.transfer_read %[[v7]][%[[c0]]], %[[cst]] {in_bounds = [true]} : memref<3xf32>, vector<3xf32> -// CHECK: %[[v21:.*]] = arith.mulf %[[v19]], %[[v20]] : vector<3xf32> -// CHECK: %[[v22:.*]] = vector.reduction "add", %[[v21]], %[[cst]] : vector<3xf32> into f32 -// CHECK: store %[[v22]], %[[v8]][%[[c0]]] : memref<1xf32> -// CHECK: scf.for %[[arg5:.*]] = %[[c0]] to %[[v9]] step %[[c1]] { -// CHECK: %[[v23:.*]] = load %[[v11]][%[[arg5]]] : memref -// CHECK: store %[[v23]], %[[v10]][%[[arg5]]] : memref - linalg.conv_1d ins(%arg0, %arg1 : memref, memref) - outs(%arg2 : memref) - return -} - diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-call.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-call.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-call.mlir @@ -9,17 +9,6 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: | FileCheck %s - -// RUN: mlir-opt %s -linalg-tile="tile-sizes=4" \ -// RUN: -test-conv-vectorization="tile-sizes=1,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: | FileCheck %s - func private @print_memref_f32(memref<*xf32>) // Creates and returns a 1-D buffer of size %s1 filled with the value %f diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-nwc-wcf-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-nwc-wcf-call.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-nwc-wcf-call.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-1d-nwc-wcf-call.mlir @@ -9,17 +9,6 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: | FileCheck %s - -// RUN: mlir-opt %s -linalg-tile="tile-sizes=2,4" \ -// RUN: -test-conv-vectorization="tile-sizes=1,1,1,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: | FileCheck %s - func private @print_memref_f32(memref<*xf32>) // Creates and returns 3-D buffer of size (%s1, %s2, %s3) filled with the value %f diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-call.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-call.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-call.mlir @@ -9,17 +9,6 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: | FileCheck %s - -// RUN: mlir-opt %s -linalg-tile="tile-sizes=2,2" \ -// RUN: -test-conv-vectorization="tile-sizes=1,1,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: | FileCheck %s - func private @print_memref_f32(memref<*xf32>) // Creates and returns a 2-D buffer of size (%s1, %s2) filled with the value %f diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-2d-nhwc-hwcf-call.mlir @@ -9,17 +9,6 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: | FileCheck %s - -// RUN: mlir-opt %s -linalg-tile="tile-sizes=2,3,3,2" \ -// RUN: -test-conv-vectorization="tile-sizes=1,1,1,1,3,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: | FileCheck %s - func private @print_memref_f32(memref<*xf32>) // Creates and returns 4-D buffer of size (%s1, %s2, %s3, %s4) filled with the value %f diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-call.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-call.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-call.mlir @@ -9,17 +9,6 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,3,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: | FileCheck %s - -// RUN: mlir-opt %s -linalg-tile="tile-sizes=2,2,2" \ -// RUN: -test-conv-vectorization="tile-sizes=1,1,1,3,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: | FileCheck %s - func private @print_memref_f32(memref<*xf32>) // Creates and returns 3-D buffer of size (%s1, %s2, %s3) filled with the value %f diff --git a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-ndhwc-dhwcf-call.mlir b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-ndhwc-dhwcf-call.mlir --- a/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-ndhwc-dhwcf-call.mlir +++ b/mlir/test/Integration/Dialect/Linalg/CPU/test-conv-3d-ndhwc-dhwcf-call.mlir @@ -9,17 +9,6 @@ // RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ // RUN: | FileCheck %s -// RUN: mlir-opt %s -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: | FileCheck %s - -// RUN: mlir-opt %s -linalg-tile="tile-sizes=0,5,5,5" \ -// RUN: -test-conv-vectorization="tile-sizes=1,1,1,1,1,3,3,3,3" -convert-linalg-to-llvm -lower-affine -convert-scf-to-std -convert-vector-to-llvm --convert-memref-to-llvm -convert-std-to-llvm -reconcile-unrealized-casts | \ -// RUN: mlir-cpu-runner -e main -entry-point-result=void \ -// RUN: -shared-libs=%mlir_integration_test_dir/libmlir_runner_utils%shlibext \ -// RUN: | FileCheck %s - func private @print_memref_f32(memref<*xf32>) // Creates and returns 5-D buffer of size (%s1, %s2, %s3, %s4, %s5) filled with the value %f diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt @@ -1,7 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRLinalgTestPasses TestComprehensiveBufferize.cpp - TestConvVectorization.cpp TestLinalgCodegenStrategy.cpp TestLinalgDistribution.cpp TestLinalgElementwiseFusion.cpp diff --git a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp b/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp deleted file mode 100644 --- a/mlir/test/lib/Dialect/Linalg/TestConvVectorization.cpp +++ /dev/null @@ -1,143 +0,0 @@ -//===- TestConvVectorization.cpp - Vectorization of Conv ops --------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" -#include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/SCF/Transforms.h" -#include "mlir/Dialect/Vector/VectorTransforms.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "mlir/Transforms/LoopUtils.h" -#include "mlir/Transforms/Passes.h" - -using namespace mlir; -using namespace vector; - -namespace { -/// A pass converting MLIR Linalg ops into Vector ops. -class TestConvVectorization - : public PassWrapper> { -public: - StringRef getArgument() const final { return "test-conv-vectorization"; } - StringRef getDescription() const final { - return "Test vectorization of convolutions"; - } - TestConvVectorization() = default; - TestConvVectorization(const TestConvVectorization &) {} - explicit TestConvVectorization(ArrayRef tileSizesParam) { - tileSizes = tileSizesParam; - } - - void runOnOperation() override; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - registry.insert(); - } - - ListOption tileSizes{ - *this, "tile-sizes", llvm::cl::desc("Vectorization sizes."), - llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated}; -}; -} // namespace - -void TestConvVectorization::runOnOperation() { - MLIRContext *context = &getContext(); - ModuleOp module = getOperation(); - - ConversionTarget target(*context); - target.addLegalDialect(); - target.addLegalOp(); - target.addLegalOp(); - - SmallVector stage1Patterns; - linalg::populateConvVectorizationPatterns(context, stage1Patterns, tileSizes); - SmallVector frozenStage1Patterns; - llvm::move(stage1Patterns, std::back_inserter(frozenStage1Patterns)); - - RewritePatternSet stage2Patterns = - linalg::getLinalgTilingCanonicalizationPatterns(context); - scf::populateSCFForLoopCanonicalizationPatterns(stage2Patterns); - - auto stage3Transforms = [](Operation *op) { - PassManager pm(op->getContext()); - pm.addPass(createLoopInvariantCodeMotionPass()); - if (failed(pm.run(cast(op)))) - llvm_unreachable("Unexpected failure in cleanup pass pipeline."); - op->walk([](FuncOp func) { - promoteSingleIterationLoops(func); - linalg::hoistRedundantVectorTransfers(func); - }); - return success(); - }; - - (void)linalg::applyStagedPatterns(module, frozenStage1Patterns, - std::move(stage2Patterns), - stage3Transforms); - - //===--------------------------------------------------------------------===// - // Post staged patterns transforms - //===--------------------------------------------------------------------===// - - VectorTransformsOptions vectorTransformOptions{ - VectorContractLowering::Dot, VectorMultiReductionLowering::InnerParallel, - VectorTransposeLowering::EltWise}; - - RewritePatternSet vectorTransferPatterns(context); - // Pattern is not applied: rank-reducing vector transfer is not yet supported - // (see: splitFullAndPartialTransferPrecondition in VectorTransforms.cpp). - vectorTransferPatterns.add( - context, vectorTransformOptions); - (void)applyPatternsAndFoldGreedily(module, std::move(vectorTransferPatterns)); - - // Programmatic controlled lowering of linalg.copy and linalg.fill. - PassManager pm(context); - pm.addNestedPass(createConvertLinalgToLoopsPass()); - if (failed(pm.run(module))) - llvm_unreachable("Unexpected failure in linalg to loops pass."); - - // Programmatic controlled lowering of vector.contract only. - RewritePatternSet vectorContractLoweringPatterns(context); - populateVectorBroadcastLoweringPatterns(vectorContractLoweringPatterns); - populateVectorContractLoweringPatterns(vectorContractLoweringPatterns, - vectorTransformOptions); - populateVectorMaskOpLoweringPatterns(vectorContractLoweringPatterns); - populateVectorShapeCastLoweringPatterns(vectorContractLoweringPatterns); - populateVectorTransposeLoweringPatterns(vectorContractLoweringPatterns, - vectorTransformOptions); - (void)applyPatternsAndFoldGreedily(module, - std::move(vectorContractLoweringPatterns)); - - // Programmatic controlled lowering of vector.transfer only. - RewritePatternSet vectorToLoopsPatterns(context); - populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, - VectorTransferToSCFOptions()); - (void)applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns)); - - // Ensure we drop the marker in the end. - module.walk([](linalg::LinalgOp op) { - op->removeAttr(linalg::LinalgTransforms::kLinalgTransformMarker); - }); -} - -namespace mlir { -namespace test { -void registerTestConvVectorization() { - PassRegistration(); -} -} // namespace test -} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -66,7 +66,6 @@ void registerTestCallGraphPass(); void registerTestComprehensiveFunctionBufferize(); void registerTestConstantFold(); -void registerTestConvVectorization(); void registerTestGpuSerializeToCubinPass(); void registerTestGpuSerializeToHsacoPass(); void registerTestDataLayoutQuery(); @@ -162,7 +161,6 @@ mlir::test::registerTestGpuSerializeToHsacoPass(); #endif mlir::test::registerTestComprehensiveFunctionBufferize(); - mlir::test::registerTestConvVectorization(); mlir::test::registerTestDecomposeCallGraphTypes(); mlir::test::registerTestDataLayoutQuery(); mlir::test::registerTestDominancePass();