diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -208,9 +208,8 @@ : OpRewritePattern<vector::ContractionOp>(context), vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {} - LogicalResult match(vector::ContractionOp op) const override; - void rewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; private: /// Options to control the vector patterns. @@ -250,9 +249,8 @@ : OpRewritePattern<vector::ContractionOp>(context), vectorTransformsOptions(vectorTransformsOptions), filter(constraint) {} - LogicalResult match(vector::ContractionOp op) const override; - void rewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const override; + LogicalResult matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; private: /// Options to control the vector patterns. diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1576,16 +1576,14 @@ // /// This only kicks in when VectorTransformsOptions is set to OuterProduct and /// the vector.contract op is a row-major matrix multiply. 
-LogicalResult -ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const { +LogicalResult ContractionOpToMatmulOpLowering::matchAndRewrite( + vector::ContractionOp op, PatternRewriter &rewriter) const { // TODO: implement masks if (llvm::size(op.masks()) != 0) return failure(); - if (vectorTransformsOptions.vectorContractLowering != vector::VectorContractLowering::Matmul) return failure(); - if (failed(filter(op))) return failure(); @@ -1598,11 +1596,10 @@ if (!isRowMajorMatmul(op.indexing_maps())) return failure(); - return success(); -} + Type elementType = op.getLhsType().getElementType(); + if (!elementType.isIntOrFloat()) + return failure(); -void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op, - PatternRewriter &rewriter) const { VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); int64_t lhsRows = lhsType.getDimSize(0); @@ -1622,12 +1619,12 @@ lhsColumns, rhsColumns); mul = rewriter.create<vector::ShapeCastOp>(op.getLoc(), op.acc().getType(), mul); - Type elementType = op.getLhsType().getElementType(); - assert(elementType.isIntOrFloat()); if (elementType.isa<IntegerType>()) rewriter.replaceOpWithNewOp<AddIOp>(op, op.acc(), mul); else rewriter.replaceOpWithNewOp<AddFOp>(op, op.acc(), mul); + + return success(); } /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul /// /// This only kicks in when VectorTransformsOptions is set to OuterProduct but /// otherwise supports any layout permutation of the matrix-multiply. 
-LogicalResult -ContractionOpToOuterProductOpLowering::match(vector::ContractionOp op) const { +LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( + vector::ContractionOp op, PatternRewriter &rewriter) const { // TODO: implement masks if (llvm::size(op.masks()) != 0) return failure(); @@ -1658,50 +1655,6 @@ if (failed(filter(op))) return failure(); - // Determine if the parallel/reduction structure matches something - // that can be expressed a reduction_size unrolled sequence. - using MapList = ArrayRef<ArrayRef<AffineExpr>>; - auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; - AffineExpr m, n, k; - bindDims(op.getContext(), m, n, k); - auto iteratorTypes = op.iterator_types().getValue(); - SmallVector<AffineMap, 4> maps = op.getIndexingMaps(); - if (isParallelIterator(iteratorTypes[0]) && - isParallelIterator(iteratorTypes[1]) && - isReductionIterator(iteratorTypes[2])) { - // - // Two outer parallel, one inner reduction (matmat flavor). - // When lowering to outerproduct we can support all permutations. - // - if (maps != infer({{m, k}, {k, n}, {m, n}}) && - maps != infer({{m, k}, {n, k}, {m, n}}) && - maps != infer({{k, m}, {k, n}, {m, n}}) && - maps != infer({{k, m}, {n, k}, {m, n}}) && - maps != infer({{m, k}, {k, n}, {n, m}}) && - maps != infer({{m, k}, {n, k}, {n, m}}) && - maps != infer({{k, m}, {k, n}, {n, m}}) && - maps != infer({{k, m}, {n, k}, {n, m}})) - return failure(); - return success(); - } else if (isParallelIterator(iteratorTypes[0]) && - isReductionIterator(iteratorTypes[1])) { - // - // One outer parallel, one inner reduction (matvec flavor) - // See if a series of AXPY operations chained through FMA operations - // could replace the default DOT implementation. 
- // - if (maps != infer({{m, n}, {n}, {m}}) && // mat-vec - maps != infer({{n, m}, {n}, {m}}) && // mat-trans-vec - maps != infer({{n}, {m, n}, {m}}) && // vec-mat - maps != infer({{n}, {n, m}, {m}})) // vec-mat-trans - return failure(); - return success(); - } - return failure(); -} - -void ContractionOpToOuterProductOpLowering::rewrite( - vector::ContractionOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); int64_t reductionSize = 0; VectorType lhsType = op.getLhsType(); @@ -1759,13 +1712,14 @@ Value tmp = lhs; lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm); rhs = tmp; + } else { + return failure(); } - } else { + } else if (isParallelIterator(iteratorTypes[0]) && + isReductionIterator(iteratorTypes[1])) { // // One outer parallel, one inner reduction (matvec flavor) // - assert(isParallelIterator(iteratorTypes[0]) && - isReductionIterator(iteratorTypes[1])); if (maps == infer({{m, n}, {n}, {m}})) { // Case mat-vec: transpose. reductionSize = lhsType.getDimSize(1); @@ -1782,7 +1736,11 @@ // Case vec-mat-trans: swap and ready to go. reductionSize = lhsType.getDimSize(0); std::swap(lhs, rhs); + } else { + return failure(); } + } else { + return failure(); } assert(reductionSize > 0); @@ -1793,6 +1751,7 @@ res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res); } rewriter.replaceOp(op, res); + return success(); } /// Progressive lowering of ContractionOp. @@ -1815,7 +1774,6 @@ LogicalResult ContractionOpLowering::matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const { - // TODO: implement masks. if (llvm::size(op.masks()) != 0) return failure(); @@ -1832,11 +1790,11 @@ // TODO: implement benefits, cost models. 
MLIRContext *ctx = op.getContext(); ContractionOpToMatmulOpLowering pat1(vectorTransformsOptions, ctx); - if (succeeded(pat1.match(op))) - return failure(); + if (succeeded(pat1.matchAndRewrite(op, rewriter))) + return success(); ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx); - if (succeeded(pat2.match(op))) - return failure(); + if (succeeded(pat2.matchAndRewrite(op, rewriter))) + return success(); // Find first batch dimension in LHS/RHS, and lower when found. std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();