diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h @@ -25,11 +25,18 @@ class OwningRewritePatternList; namespace vector { +/// Enum to control the lowering of `vector.contract` operations. +enum class VectorContractLowering { + /// Progressively lower to finer-grained `vector.contract` and `vector.fma`. + FMA = 0, + /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics. + Matmul = 1, + /// Lower to `vector.outerproduct`. + OuterProduct = 2, +}; /// Structure to control the behavior of vector transform patterns. struct VectorTransformsOptions { - /// Let vector.contract lower to vector.matrix_multiply and LLVM matrix - /// intrinsics. - bool lowerToLLVMMatrixIntrinsics = false; + VectorContractLowering vectorContractLowering = VectorContractLowering::FMA; }; /// Collect a set of vector-to-vector canonicalization patterns. diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.td b/mlir/include/mlir/Dialect/Vector/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/VectorOps.td @@ -685,6 +685,11 @@ return %3: vector<4x8xf32> ``` }]; + let builders = [ + // Build an op without mask, use the type of `acc` as the return type. 
+ OpBuilder< + "OpBuilder &builder, OperationState &result, Value lhs, Value rhs, " + "Value acc">]; let extraClassDeclaration = [{ VectorType getOperandVectorTypeLHS() { return lhs().getType().cast(); diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h --- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h +++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h @@ -9,6 +9,7 @@ #ifndef DIALECT_VECTOR_VECTORTRANSFORMS_H_ #define DIALECT_VECTOR_VECTORTRANSFORMS_H_ +#include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/PatternMatch.h" namespace mlir { @@ -22,13 +23,6 @@ ArrayRef coarseVectorShape = {}, ArrayRef fineVectorShape = {}); -//////////////////////////////////////////////////////////////////////////////// -// The following Declarative Rewrite Rule (DRR) helpers are used in rewrite -// patterns. As such, they must not call into `rewriter.erase/replace` APIs and -// it is the responsibility of the enclosing PatternRewriter to erase on -// success. -//////////////////////////////////////////////////////////////////////////////// - namespace vector { // Entry point for unrolling declarative pattern rewrites. @@ -69,6 +63,70 @@ ArrayRef targetShape); } // namespace vector + +//===----------------------------------------------------------------------===// +// Finer-grained patterns exposed for more control over individual lowerings. +//===----------------------------------------------------------------------===// + +/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to: +/// ``` +/// %flattened_a = vector.shape_cast %a +/// %flattened_b = vector.shape_cast %b +/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b +/// %d = vector.shape_cast %flattened_d +/// %e = add %c, %d +/// ``` +/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`. 
+class ContractionOpToMatmulOpLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + ContractionOpToMatmulOpLowering( + vector::VectorTransformsOptions vectorTransformsOptions, + MLIRContext *context) + : OpRewritePattern(context), + vectorTransformsOptions(vectorTransformsOptions) {} + + /// Options consulted by `match` to decide whether this lowering applies. + vector::VectorTransformsOptions vectorTransformsOptions; + + LogicalResult match(vector::ContractionOp op) const override; + void rewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; +}; + +/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to a reduction_size-unrolled sequence: +/// ``` +/// %at = vector.transpose %a, [1, 0] +/// %atRow0 = vector.extract %at[0] +/// %bRow0 = vector.extract %b[0] +/// %c0 = vector.outerproduct %atRow0, %bRow0, %c +/// ... +/// %atRowK = vector.extract %at[K] +/// %bRowK = vector.extract %b[K] +/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 +/// ``` +class ContractionOpToOuterProductOpLowering + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + ContractionOpToOuterProductOpLowering( + vector::VectorTransformsOptions vectorTransformsOptions, + MLIRContext *context) + : OpRewritePattern(context), + vectorTransformsOptions(vectorTransformsOptions) {} + + LogicalResult match(vector::ContractionOp op) const override; + void rewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const override; + + /// Options consulted by `match` to decide whether this lowering applies. 
+ vector::VectorTransformsOptions vectorTransformsOptions; +}; + } // namespace mlir #endif // DIALECT_VECTOR_VECTORTRANSFORMS_H_ diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -957,6 +957,13 @@ // OuterProductOp //===----------------------------------------------------------------------===// +/// Build an op without mask, use the type of `acc` as the return type. +void OuterProductOp::build(OpBuilder &builder, OperationState &result, + Value lhs, Value rhs, Value acc) { + result.addOperands({lhs, rhs, acc}); + result.addTypes(acc.getType()); +} + static void print(OpAsmPrinter &p, OuterProductOp op) { p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs(); if (!op.acc().empty()) diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -497,6 +497,113 @@ return true; } +namespace mlir { + +LogicalResult +ContractionOpToMatmulOpLowering::match(vector::ContractionOp op) const { + // TODO(ajcbik): implement masks. + if (llvm::size(op.masks()) != 0) + return failure(); + + if (vectorTransformsOptions.vectorContractLowering != + vector::VectorContractLowering::Matmul || + !isRowMajorMatmul(op.indexing_maps())) + return failure(); + return success(); +} + +/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to: +/// ``` +/// %flattened_a = vector.shape_cast %a +/// %flattened_b = vector.shape_cast %b +/// %flattened_d = vector.matrix_multiply %flattened_a, %flattened_b +/// %d = vector.shape_cast %flattened_d +/// %e = add %c, %d +/// ``` +/// `vector.matrix_multiply` later lowers to `llvm.matrix.multiply`. 
+void ContractionOpToMatmulOpLowering::rewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const { + VectorType lhsType = op.getLhsType(); + VectorType rhsType = op.getRhsType(); + unsigned lhsRows = op.getLhsType().getShape()[0]; + unsigned lhsColumns = op.getLhsType().getShape()[1]; + unsigned rhsColumns = op.getRhsType().getShape()[1]; + + Type flattenedLHSType = + VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); + Type flattenedRHSType = + VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); + auto lhs = rewriter.create(op.getLoc(), flattenedLHSType, + op.lhs()); + auto rhs = rewriter.create(op.getLoc(), flattenedRHSType, + op.rhs()); + + Value mul = rewriter.create(op.getLoc(), lhs, rhs, lhsRows, + lhsColumns, rhsColumns); + mul = rewriter.create(op.getLoc(), op.acc().getType(), + mul); + Type elementType = op.getLhsType().getElementType(); + assert(elementType.isIntOrFloat()); + if (elementType.isa()) + rewriter.replaceOpWithNewOp(op, op.acc(), mul); + else + rewriter.replaceOpWithNewOp(op, op.acc(), mul); +} + +LogicalResult +ContractionOpToOuterProductOpLowering ::match(vector::ContractionOp op) const { + // TODO(ajcbik): implement masks. + if (llvm::size(op.masks()) != 0) + return failure(); + + if (vectorTransformsOptions.vectorContractLowering != + vector::VectorContractLowering::OuterProduct || + !isRowMajorMatmul(op.indexing_maps())) + return failure(); + return success(); +} + +/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to a reduction_size-unrolled sequence: +/// ``` +/// %at = vector.transpose %a, [1, 0] +/// %atRow0 = vector.extract %at[0] +/// %bRow0 = vector.extract %b[0] +/// %c0 = vector.outerproduct %atRow0, %bRow0, %c +/// ... 
+/// %atRowK = vector.extract %at[K] +/// %bRowK = vector.extract %b[K] +/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 +/// ``` +void ContractionOpToOuterProductOpLowering::rewrite( + vector::ContractionOp op, PatternRewriter &rewriter) const { + VectorType lhsType = op.getLhsType(); + // TODO(ntv): support other modes. + // We know we are in row-major. + bool transposeLhs = false; + unsigned reductionSize = + transposeLhs ? lhsType.getShape()[0] : lhsType.getShape()[1]; + + // If transposeLhs == false (i.e. lhs(m, reductionSize)), we need to + // transpose it to extract the proper vector. Otherwise, just take + // the lhs. + Value lhs = transposeLhs + ? op.lhs() + : rewriter.create( + op.getLoc(), op.lhs(), ArrayRef{1, 0}); + Value res = op.acc(); + // ExtractOp does not allow dynamic indexing; we must unroll explicitly. + for (unsigned k = 0; k < reductionSize; ++k) { + Value a = rewriter.create(op.getLoc(), lhs, k); + Value b = rewriter.create(op.getLoc(), op.rhs(), k); + res = rewriter.create(op.getLoc(), a, b, res); + } + rewriter.replaceOp(op, res); +} + +} // namespace mlir + namespace { // Splits vector TransferReadOp into smaller TransferReadOps based on slicing @@ -1275,43 +1382,24 @@ : OpRewritePattern(context), vectorTransformsOptions(vectorTransformsOptions) {} + /// Options to control the vector patterns. + vector::VectorTransformsOptions vectorTransformsOptions; + LogicalResult matchAndRewrite(vector::ContractionOp op, PatternRewriter &rewriter) const override { - // TODO(ajcbik): implement masks + + // TODO(ajcbik): implement masks. if (llvm::size(op.masks()) != 0) return failure(); - // TODO(ntv, ajcbik): implement benefits, cost models, separate this out in - // a new pattern. 
- if (vectorTransformsOptions.lowerToLLVMMatrixIntrinsics && - isRowMajorMatmul(op.indexing_maps())) { - VectorType lhsType = op.getLhsType(); - VectorType rhsType = op.getRhsType(); - unsigned lhsRows = op.getLhsType().getShape()[0]; - unsigned lhsColumns = op.getLhsType().getShape()[1]; - unsigned rhsColumns = op.getRhsType().getShape()[1]; - - Type flattenedLHSType = - VectorType::get(lhsType.getNumElements(), lhsType.getElementType()); - Type flattenedRHSType = - VectorType::get(rhsType.getNumElements(), rhsType.getElementType()); - auto lhs = rewriter.create( - op.getLoc(), flattenedLHSType, op.lhs()); - auto rhs = rewriter.create( - op.getLoc(), flattenedRHSType, op.rhs()); - - Value mul = rewriter.create( - op.getLoc(), lhs, rhs, lhsRows, lhsColumns, rhsColumns); - mul = rewriter.create(op.getLoc(), - op.acc().getType(), mul); - Type elementType = op.getLhsType().getElementType(); - assert(elementType.isIntOrFloat()); - if (elementType.isa()) - rewriter.replaceOpWithNewOp(op, op.acc(), mul); - else - rewriter.replaceOpWithNewOp(op, op.acc(), mul); - return success(); - } + // TODO(ntv, ajcbik): implement benefits, cost models. + MLIRContext *ctx = op.getContext(); + ContractionOpToMatmulOpLowering pat1(vectorTransformsOptions, ctx); + if (succeeded(pat1.match(op))) + return failure(); + ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx); + if (succeeded(pat2.match(op))) + return failure(); // Find first batch dimension in LHS/RHS, and lower when found. 
std::vector> batchDimMap = op.getBatchDimMap(); @@ -1585,8 +1673,6 @@ } return result; } - - vector::VectorTransformsOptions vectorTransformsOptions; }; /// ShapeOp 2D -> 1D downcast serves the purpose of flattening 2-D to 1-D @@ -1685,6 +1771,8 @@ ShapeCastOp2DDownCastRewritePattern, ShapeCastOp2DUpCastRewritePattern, TransposeOpLowering>(context); + patterns.insert(parameters, context); // clang-format on - patterns.insert(parameters, context); } diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -1,5 +1,6 @@ -// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s -// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX +// RUN: mlir-opt %s -test-vector-contraction-conversion | FileCheck %s --dump-input-on-failure +// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-lower-matrix-intrinsics=1 | FileCheck %s --check-prefix=MATRIX --dump-input-on-failure +// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-outerproduct=1 | FileCheck %s --check-prefix=OUTERPRODUCT --dump-input-on-failure #dotp_accesses = [ affine_map<(i) -> (i)>, @@ -382,6 +383,35 @@ // MATRIX: %[[mm4:.*]] = vector.extract_strided_slice %[[mm1]] {offsets = [3], sizes = [3], strides = [1]} : vector<6xf32> to vector<3xf32> // MATRIX: %[[mm5:.*]] = vector.insert %[[mm4]], %[[mm3]] [1] : vector<3xf32> into vector<2x3xf32> // MATRIX: %[[mm6:.*]] = addf %[[C]], %[[mm5]] : vector<2x3xf32> + +// OUTERPRODUCT-LABEL: func @matmul +// OUTERPRODUCT-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, +// OUTERPRODUCT-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, +// OUTERPRODUCT-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// OUTERPRODUCT: %[[At:.*]] = vector.transpose %[[A]], [1, 0] +// OUTERPRODUCT-SAME: : 
vector<2x4xf32> to vector<4x2xf32> +// +// OUTERPRODUCT: %[[a0:.*]] = vector.extract %[[At]][0] : vector<4x2xf32> +// OUTERPRODUCT: %[[b0:.*]] = vector.extract %[[B]][0] : vector<4x3xf32> +// OUTERPRODUCT: %[[c0:.*]] = vector.outerproduct %[[a0]], %[[b0]], %[[C]] +// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> +// +// OUTERPRODUCT: %[[a1:.*]] = vector.extract %[[At]][1] : vector<4x2xf32> +// OUTERPRODUCT: %[[b1:.*]] = vector.extract %[[B]][1] : vector<4x3xf32> +// OUTERPRODUCT: %[[c1:.*]] = vector.outerproduct %[[a1]], %[[b1]], %[[c0]] +// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> +// +// OUTERPRODUCT: %[[a2:.*]] = vector.extract %[[At]][2] : vector<4x2xf32> +// OUTERPRODUCT: %[[b2:.*]] = vector.extract %[[B]][2] : vector<4x3xf32> +// OUTERPRODUCT: %[[c2:.*]] = vector.outerproduct %[[a2]], %[[b2]], %[[c1]] +// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> +// +// OUTERPRODUCT: %[[a3:.*]] = vector.extract %[[At]][3] : vector<4x2xf32> +// OUTERPRODUCT: %[[b3:.*]] = vector.extract %[[B]][3] : vector<4x3xf32> +// OUTERPRODUCT: %[[c3:.*]] = vector.outerproduct %[[a3]], %[[b3]], %[[c2]] +// OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> +// +// OUTERPRODUCT: return %[[c3]] : vector<2x3xf32> func @matmul(%arg0: vector<2x4xf32>, %arg1: vector<4x3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> { diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp --- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp +++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp @@ -51,11 +51,26 @@ *this, "vector-lower-matrix-intrinsics", llvm::cl::desc("Lower vector.contract to llvm.intr.matrix.multiply"), llvm::cl::init(false)}; + Option lowerToOuterProduct{ + *this, "vector-outerproduct", + llvm::cl::desc("Lower vector.contract to vector.outerproduct"), + llvm::cl::init(false)}; void runOnFunction() override { OwningRewritePatternList patterns; - VectorTransformsOptions options{ - 
/*lowerToLLVMMatrixIntrinsics=*/lowerToLLVMMatrixIntrinsics}; + if (lowerToOuterProduct) { + VectorContractLowering lowering = VectorContractLowering::OuterProduct; + VectorTransformsOptions options{lowering}; + patterns.insert(options, + &getContext()); + applyPatternsAndFoldGreedily(getFunction(), patterns); + return; + } + + VectorContractLowering lowering = VectorContractLowering::FMA; + if (lowerToLLVMMatrixIntrinsics) + lowering = VectorContractLowering::Matmul; + VectorTransformsOptions options{lowering}; populateVectorContractLoweringPatterns(patterns, &getContext(), options); applyPatternsAndFoldGreedily(getFunction(), patterns); }