diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -1,7 +1,6 @@ // RUN: mlir-opt %s -test-vector-contraction-lowering | FileCheck %s // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT -// RUN: mlir-opt %s -test-vector-contraction-lowering=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT // RUN: mlir-opt %s -test-vector-contraction-lowering=vector-parallel-arith=1 | FileCheck %s --check-prefix=PARALLEL #dotp_accesses = [ @@ -1182,32 +1181,6 @@ return %0 : vector<3x2xf32> } -// FILTEROUTERPRODUCT-LABEL: func @matmul_4_filtered -// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x4xf32>, -// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>, -// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<4x4xf32> -// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]] -func.func @matmul_4_filtered(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<4x4xf32>) --> vector<4x4xf32> -{ - %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 - : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32> - return %0 : vector<4x4xf32> -} - -// FILTEROUTERPRODUCT-LABEL: func @matmul_4_not_filtered -// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4xf32>, -// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>, -// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x4xf32> -// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]] -func.func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<3x4xf32>) --> vector<3x4xf32> -{ - %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2 - : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32> - return %0 : vector<3x4xf32> -} - // PARALLEL-LABEL: func @parrallel_contract_lowering // PARALLEL: %[[E0:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> // PARALLEL: %[[E1:.*]] = vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32> diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include #include +#include #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" @@ -136,11 +136,6 @@ *this, "vector-outerproduct", llvm::cl::desc("Lower vector.contract to vector.outerproduct"), llvm::cl::init(false)}; - Option lowerToFilterOuterProduct{ - *this, "vector-filter-outerproduct", - llvm::cl::desc("Lower vector.contract to vector.outerproduct but not for " - "vectors of size 4."), - llvm::cl::init(false)}; Option lowerToParallelArith{ *this, "vector-parallel-arith", llvm::cl::desc("Lower vector.contract to elementwise vector ops."), @@ -159,22 +154,6 @@ return; } - // Test on one pattern in isolation. - if (lowerToFilterOuterProduct) { - VectorContractLowering lowering = VectorContractLowering::OuterProduct; - VectorTransformsOptions options{lowering}; - patterns.add( - options, &getContext(), /*benefit=*/1, [](vector::ContractionOp op) { - // Only lowers vector.contract where the lhs as a type vector - // where M is not 4. - if (op.getRhsType().getShape()[0] == 4) - return failure(); - return success(); - }); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - return; - } - if (lowerToParallelArith) { vector::populateVectorContractLoweringPatterns( patterns,