diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -127,12 +127,20 @@
     : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  /// Predicate on the op; the pattern does not apply when it returns failure.
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  /// Default filter that accepts every op.
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
 
   ContractionOpToMatmulOpLowering(
       vector::VectorTransformsOptions vectorTransformsOptions,
-      MLIRContext *context)
+      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
       : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions) {}
+        vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
 
   LogicalResult match(vector::ContractionOp op) const override;
   void rewrite(vector::ContractionOp op,
@@ -141,6 +149,8 @@
 private:
   /// Options to control the vector patterns.
   vector::VectorTransformsOptions vectorTransformsOptions;
+  /// Defaults to defaultFilter: the inherited constructors do not set it.
+  FilterConstraintType filter = defaultFilter;
 };
 
 /// Progressive lowering of a `vector.contract %a, %b, %c` with row-major matmul
@@ -162,11 +172,20 @@
     : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  /// Predicate on the op; the pattern does not apply when it returns failure.
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  /// Default filter that accepts every op.
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+
   ContractionOpToOuterProductOpLowering(
       vector::VectorTransformsOptions vectorTransformsOptions,
-      MLIRContext *context)
+      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
       : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions) {}
+        vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
 
   LogicalResult match(vector::ContractionOp op) const override;
   void rewrite(vector::ContractionOp op,
@@ -175,6 +194,8 @@
 private:
   /// Options to control the vector patterns.
   vector::VectorTransformsOptions vectorTransformsOptions;
+  /// Defaults to defaultFilter: the inherited constructors do not set it.
+  FilterConstraintType filter = defaultFilter;
 };
 
 /// Progressive lowering of ContractionOp.
@@ -194,11 +215,20 @@
 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  /// Predicate on the op; the pattern does not apply when it returns failure.
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  /// Default filter that accepts every op.
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
 
   ContractionOpLowering(vector::VectorTransformsOptions vectorTransformsOptions,
-                        MLIRContext *context)
+                        MLIRContext *context,
+                        FilterConstraintType constraint = defaultFilter)
       : OpRewritePattern<vector::ContractionOp>(context),
-        vectorTransformsOptions(vectorTransformsOptions) {}
+        vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {}
 
   LogicalResult matchAndRewrite(vector::ContractionOp op,
                                 PatternRewriter &rewriter) const override;
@@ -206,6 +236,8 @@
 private:
   /// Options to control the vector patterns.
   vector::VectorTransformsOptions vectorTransformsOptions;
+  /// Defaults to defaultFilter: the inherited constructors do not set it.
+  FilterConstraintType filter = defaultFilter;
   // Lower one parallel dimension.
   Value lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
                       int64_t rhsIndex, PatternRewriter &rewriter) const;
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1581,6 +1581,9 @@
       vector::VectorContractLowering::Matmul)
     return failure();
 
+  if (failed(filter(op)))
+    return failure();
+
   auto iteratorTypes = op.iterator_types().getValue();
   if (!isParallelIterator(iteratorTypes[0]) ||
       !isParallelIterator(iteratorTypes[1]) ||
@@ -1647,6 +1650,9 @@
       vector::VectorContractLowering::OuterProduct)
     return failure();
 
+  if (failed(filter(op)))
+    return failure();
+
   // Determine if the parallel/reduction structure matches something
   // that can be expressed a reduction_size unrolled sequence.
   using MapList = ArrayRef<ArrayRef<AffineExpr>>;
@@ -1808,6 +1814,10 @@
   // TODO: implement masks.
   if (llvm::size(op.masks()) != 0)
     return failure();
+
+  if (failed(filter(op)))
+    return failure();
+
   // TODO: support mixed mode contract lowering.
   if (op.getLhsType().getElementType() !=
           getElementTypeOrSelf(op.getAccType()) ||
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s
 // RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX
 // RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-filter-outerproduct=1 | FileCheck %s --check-prefix=FILTEROUTERPRODUCT
 
 #dotp_accesses = [
   affine_map<(i) -> (i)>,
@@ -1029,3 +1030,33 @@
     : vector<2x1xf32>, vector<1x3xf32> into vector<3x2xf32>
   return %0 : vector<3x2xf32>
 }
+
+// The vector-filter-outerproduct test filter only lowers contractions whose
+// lhs leading dimension is not 4; the others must stay as vector.contract.
+
+// FILTEROUTERPRODUCT-LABEL: func @matmul_4_filtered
+// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<4x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<4x4xf32>
+// FILTEROUTERPRODUCT: %[[c0:.*]] = vector.contract {{{.*}}} %[[A]], %[[B]], %[[C]]
+func @matmul_4_filtered(%arg0: vector<4x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<4x4xf32>)
+-> vector<4x4xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<4x4xf32>, vector<4x4xf32> into vector<4x4xf32>
+  return %0 : vector<4x4xf32>
+}
+
+// FILTEROUTERPRODUCT-LABEL: func @matmul_4_not_filtered
+// FILTEROUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<3x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x4xf32>,
+// FILTEROUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<3x4xf32>
+// FILTEROUTERPRODUCT-NOT: vector.contract
+// FILTEROUTERPRODUCT: vector.outerproduct
+func @matmul_4_not_filtered(%arg0: vector<3x4xf32>, %arg1: vector<4x4xf32>, %arg2: vector<3x4xf32>)
+-> vector<3x4xf32>
+{
+  %0 = vector.contract #matmat_trait_0 %arg0, %arg1, %arg2
+    : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32>
+  return %0 : vector<3x4xf32>
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -59,6 +59,11 @@
       *this, "vector-outerproduct",
       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
       llvm::cl::init(false)};
+  Option<bool> lowerToFilterOuterProduct{
+      *this, "vector-filter-outerproduct",
+      llvm::cl::desc("Lower vector.contract to vector.outerproduct, except "
+                     "when the lhs leading dimension is 4."),
+      llvm::cl::init(false)};
 
   void runOnFunction() override {
     OwningRewritePatternList patterns;
@@ -73,6 +78,22 @@
       return;
     }
 
+    // Test on one pattern in isolation.
+    if (lowerToFilterOuterProduct) {
+      VectorContractLowering lowering = VectorContractLowering::OuterProduct;
+      VectorTransformsOptions options{lowering};
+      patterns.insert<ContractionOpToOuterProductOpLowering>(
+          options, &getContext(), [](vector::ContractionOp op) {
+            // Only lower vector.contract ops whose lhs leading dimension
+            // (M for a row-major matmul) is not 4; skip the others.
+            if (op.getLhsType().getShape()[0] == 4)
+              return failure();
+            return success();
+          });
+      applyPatternsAndFoldGreedily(getFunction(), patterns);
+      return;
+    }
+
     // Test on all contract lowering patterns.
     VectorContractLowering contractLowering = VectorContractLowering::Dot;
     if (lowerToFlatMatrix)