diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -46,12 +46,14 @@
 /// Enum to control the lowering of `vector.contract` operations.
 enum class VectorContractLowering {
-  /// Progressively lower to finer grained `vector.contract` and `vector.fma`.
-  FMA = 0,
+  /// Progressively lower to finer grained `vector.contract` and dot-products.
+  Dot = 0,
   /// Lower to `vector.matrix_multiply`, maps 1-1 to LLVM matrix intrinsics.
   Matmul = 1,
   /// Lower to `vector.outerproduct`.
   OuterProduct = 2,
+  /// Lower to a series of AXPY operations chained through FMA.
+  AXPY = 3,
 };
 
 /// Enum to control the lowering of `vector.transpose` operations.
 enum class VectorTransposeLowering {
@@ -63,7 +65,7 @@
 };
 /// Structure to control the behavior of vector transform patterns.
 struct VectorTransformsOptions {
-  VectorContractLowering vectorContractLowering = VectorContractLowering::FMA;
+  VectorContractLowering vectorContractLowering = VectorContractLowering::Dot;
   VectorTransposeLowering vectorTransposeLowering =
       VectorTransposeLowering::EltWise;
   VectorTransformsOptions &
diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -135,6 +135,34 @@
   vector::VectorTransformsOptions vectorTransformsOptions;
 };
 
+/// Progressive lowering of a `vector.contract %a, %b, %c` with
+/// matvec semantics to a series of AXPY operations that are chained
+/// through FMA operations.
+///
+/// This only kicks in when either VectorTransformsOptions is set
+/// to AXPY or when other contraction patterns fail.
+///
+/// TODO(ajcbik): this is very similar, but not quite the same as
+/// the outerproduct lowering above; merge the two?
+class ContractionOpToAXPYLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  ContractionOpToAXPYLowering(
+      vector::VectorTransformsOptions vectorTransformsOptions,
+      MLIRContext *context)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformsOptions(vectorTransformsOptions) {}
+
+  LogicalResult match(vector::ContractionOp op) const override;
+  void rewrite(vector::ContractionOp op,
+               PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformsOptions;
+};
+
 /// Progressive lowering of ContractionOp.
 ///
 /// One:
@@ -145,10 +173,10 @@
 ///   ..
 ///   %x = combine %a %b ..
 /// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a fma/reduction op.
+/// which is replaced by a dot-product.
 ///
-/// This only kicks in when either VectorTransformsOptions is set to FMA or when
-/// other contraction patterns fail.
+/// This only kicks in when either VectorTransformsOptions is set
+/// to Dot or when other contraction patterns fail.
 class ContractionOpLowering : public OpRewritePattern<vector::ContractionOp> {
 public:
   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1560,15 +1560,17 @@
   if (llvm::size(op.masks()) != 0)
     return failure();
 
+  if (vectorTransformsOptions.vectorContractLowering !=
+      vector::VectorContractLowering::Matmul)
+    return failure();
+
   auto iteratorTypes = op.iterator_types().getValue();
   if (!isParallelIterator(iteratorTypes[0]) ||
       !isParallelIterator(iteratorTypes[1]) ||
       !isReductionIterator(iteratorTypes[2]))
     return failure();
 
-  if (vectorTransformsOptions.vectorContractLowering !=
-          vector::VectorContractLowering::Matmul ||
-      !isRowMajorMatmul(op.indexing_maps()))
+  if (!isRowMajorMatmul(op.indexing_maps()))
     return failure();
 
   return success();
@@ -1723,6 +1725,89 @@
   rewriter.replaceOp(op, res);
 }
 
+/// Progressive lowering of a `vector.contract %a, %b, %c` with
+/// matvec semantics to a series of AXPY operations that are chained
+/// through FMA operations.
+///
+/// This only kicks in when either VectorTransformsOptions is set
+/// to AXPY or when other contraction patterns fail.
+///
+/// TODO(ajcbik): this is very similar, but not quite the same as
+/// the outerproduct lowering above; merge the two?
+LogicalResult
+ContractionOpToAXPYLowering::match(vector::ContractionOp op) const {
+  // TODO(ajcbik): implement masks
+  if (llvm::size(op.masks()) != 0)
+    return failure();
+
+  if (vectorTransformsOptions.vectorContractLowering !=
+      vector::VectorContractLowering::AXPY)
+    return failure();
+
+  auto iteratorTypes = op.iterator_types().getValue();
+  if (!isParallelIterator(iteratorTypes[0]) ||
+      !isReductionIterator(iteratorTypes[1]))
+    return failure();
+
+  // See if a series of AXPY operations chained through FMA operations
+  // could replace the default DOT implementation.
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m, n;
+  bindDims(op.getContext(), m, n);
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  if (maps != infer({{m, n}, {n}, {m}}) && // mat-vec
+      maps != infer({{n, m}, {n}, {m}}) && // mat-trans-vec
+      maps != infer({{n}, {m, n}, {m}}) && // vec-mat
+      maps != infer({{n}, {n, m}, {m}}))   // vec-mat-trans
+    return failure();
+  return success();
+}
+
+void ContractionOpToAXPYLowering::rewrite(vector::ContractionOp op,
+                                          PatternRewriter &rewriter) const {
+  Location loc = op.getLoc();
+  int64_t reductionSize = 0;
+  VectorType lhsType = op.getLhsType();
+  Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
+
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+  AffineExpr m, n;
+  bindDims(op.getContext(), m, n);
+  SmallVector<int64_t, 2> perm{1, 0};
+  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
+  if (maps == infer({{m, n}, {n}, {m}})) {
+    // Case mat-vec: transpose.
+    reductionSize = lhsType.getDimSize(1);
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{n, m}, {n}, {m}})) {
+    // Case mat-trans-vec: ready to go.
+    reductionSize = lhsType.getShape()[0];
+  } else if (maps == infer({{n}, {m, n}, {m}})) {
+    // Case vec-mat: swap and transpose.
+    reductionSize = lhsType.getShape()[0];
+    std::swap(lhs, rhs);
+    lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
+  } else if (maps == infer({{n}, {n, m}, {m}})) {
+    // Case vec-mat-trans: swap and ready to go.
+    reductionSize = lhsType.getShape()[0];
+    std::swap(lhs, rhs);
+  }
+  assert(reductionSize > 0);
+
+  // A direct series of AXPY operations, chained through FMA.
+  // VectorType resType = op.getResultType().cast<VectorType>();
+  Type resType = op.getResultType();
+  for (int64_t k = 0; k < reductionSize; ++k) {
+    Value a = rewriter.create<vector::ExtractOp>(loc, lhs, k);
+    Value s = rewriter.create<vector::ExtractOp>(loc, rhs, k);
+    Value b = rewriter.create<vector::BroadcastOp>(loc, resType, s);
+    res = rewriter.create<vector::FMAOp>(loc, a, b, res);
+  }
+  rewriter.replaceOp(op, res);
+}
+
 /// Progressive lowering of ContractionOp.
 /// One:
 ///   %x = vector.contract with at least one free/batch dimension
@@ -1732,7 +1817,10 @@
 ///   ..
 ///   %x = combine %a %b ..
 /// until a pure contraction is reached (no free/batch dimensions),
-/// which is replaced by a dot-product/reduction pair.
+/// which is replaced by a dot-product.
+///
+/// This only kicks in when either VectorTransformsOptions is set
+/// to Dot or when other contraction patterns fail.
 ///
 /// TODO(ajcbik): break down into transpose/reshape/cast ops
 ///     when they become available to avoid code dup
@@ -1758,6 +1846,9 @@
   ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx);
   if (succeeded(pat2.match(op)))
     return failure();
+  ContractionOpToAXPYLowering pat3(vectorTransformsOptions, ctx);
+  if (succeeded(pat3.match(op)))
+    return failure();
 
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
@@ -1943,6 +2034,7 @@
   patterns.insert<ContractionOpLowering,
                   ContractionOpToMatmulOpLowering,
-                  ContractionOpToOuterProductOpLowering>(parameters, context);
+                  ContractionOpToOuterProductOpLowering,
+                  ContractionOpToAXPYLowering>(parameters, context);
   // clang-format on
 }
 
diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir
@@ -0,0 +1,163 @@
+// RUN: mlir-opt %s -test-vector-contraction-conversion=vector-axpy=1 | FileCheck %s
+
+#matvec_accesses = [
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (i)>
+]
+#matvec_trait = {
+  indexing_maps = #matvec_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+#mattransvec_accesses = [
+  affine_map<(i, j) -> (j, i)>,
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (i)>
+]
+#mattransvec_trait = {
+  indexing_maps = #mattransvec_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+#vecmat_accesses = [
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (i, j)>,
+  affine_map<(i, j) -> (i)>
+]
+#vecmat_trait = {
+  indexing_maps = #vecmat_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+#vecmattrans_accesses = [
+  affine_map<(i, j) -> (j)>,
+  affine_map<(i, j) -> (j, i)>,
+  affine_map<(i, j) -> (i)>
+]
+#vecmattrans_trait = {
+  indexing_maps = #vecmattrans_accesses,
+  iterator_types = ["parallel", "reduction"]
+}
+
+// CHECK-LABEL: func @matvec2x2
+// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
+// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
+// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
+// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
+// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0, 0] : vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[T0]][1, 0] : vector<2x2xf32>
+// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T4]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[T0]][0, 1] : vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T6]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[T9:.*]] = vector.extract %[[T0]][1, 1] : vector<2x2xf32>
+// CHECK: %[[T10:.*]] = vector.insert %[[T9]], %[[T8]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[T11:.*]] = vector.extract %[[T10]][0] : vector<2x2xf32>
+// CHECK: %[[T12:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<2xf32>
+// CHECK: %[[T14:.*]] = vector.fma %[[T11]], %[[T13]], %[[T2]] : vector<2xf32>
+// CHECK: %[[T15:.*]] = vector.extract %[[T10]][1] : vector<2x2xf32>
+// CHECK: %[[T16:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T17:.*]] = splat %[[T16]] : vector<2xf32>
+// CHECK: %[[T18:.*]] = vector.fma %[[T15]], %[[T17]], %[[T14]] : vector<2xf32>
+// CHECK: store %[[T18]], %[[C]][] : memref<vector<2xf32>>
+func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
+                %arg2: memref<vector<2xf32>>) {
+  %A = load %arg0[] : memref<vector<2x2xf32>>
+  %x = load %arg1[] : memref<vector<2xf32>>
+  %b = load %arg2[] : memref<vector<2xf32>>
+  %0 = vector.contract #matvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  store %0, %arg2[] : memref<vector<2xf32>>
+  return
+}
+
+// CHECK-LABEL: func @mattransvec2x2
+// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
+// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
+// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
+// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
+// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
+// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.fma %[[T3]], %[[T5]], %[[T2]] : vector<2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T9:.*]] = splat %[[T8]] : vector<2xf32>
+// CHECK: %[[T10:.*]] = vector.fma %[[T7]], %[[T9]], %[[T6]] : vector<2xf32>
+// CHECK: store %[[T10]], %[[C]][] : memref<vector<2xf32>>
+func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
+                     %arg2: memref<vector<2xf32>>) {
+  %A = load %arg0[] : memref<vector<2x2xf32>>
+  %x = load %arg1[] : memref<vector<2xf32>>
+  %b = load %arg2[] : memref<vector<2xf32>>
+  %0 = vector.contract #mattransvec_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
+  store %0, %arg2[] : memref<vector<2xf32>>
+  return
+}
+
+// CHECK-LABEL: func @vecmat2x2
+// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
+// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
+// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
+// CHECK: %[[C0:.*]] = constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
+// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
+// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0, 0] : vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[C0]] [0, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[T5:.*]] = vector.extract %[[T0]][1, 0] : vector<2x2xf32>
+// CHECK: %[[T6:.*]] = vector.insert %[[T5]], %[[T4]] [0, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[T0]][0, 1] : vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.insert %[[T7]], %[[T6]] [1, 0] : f32 into vector<2x2xf32>
+// CHECK: %[[T9:.*]] = vector.extract %[[T0]][1, 1] : vector<2x2xf32>
+// CHECK: %[[T10:.*]] = vector.insert %[[T9]], %[[T8]] [1, 1] : f32 into vector<2x2xf32>
+// CHECK: %[[T11:.*]] = vector.extract %[[T10]][0] : vector<2x2xf32>
+// CHECK: %[[T12:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T13:.*]] = splat %[[T12]] : vector<2xf32>
+// CHECK: %[[T14:.*]] = vector.fma %[[T11]], %[[T13]], %[[T2]] : vector<2xf32>
+// CHECK: %[[T15:.*]] = vector.extract %[[T10]][1] : vector<2x2xf32>
+// CHECK: %[[T16:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T17:.*]] = splat %[[T16]] : vector<2xf32>
+// CHECK: %[[T18:.*]] = vector.fma %[[T15]], %[[T17]], %[[T14]] : vector<2xf32>
+// CHECK: store %[[T18]], %[[C]][] : memref<vector<2xf32>>
+func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
+                %arg2: memref<vector<2xf32>>) {
+  %A = load %arg0[] : memref<vector<2x2xf32>>
+  %x = load %arg1[] : memref<vector<2xf32>>
+  %b = load %arg2[] : memref<vector<2xf32>>
+  %0 = vector.contract #vecmat_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+  store %0, %arg2[] : memref<vector<2xf32>>
+  return
+}
+
+// CHECK-LABEL: func @vecmattrans2x2
+// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
+// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
+// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
+// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
+// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
+// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
+// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
+// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
+// CHECK: %[[T5:.*]] = splat %[[T4]] : vector<2xf32>
+// CHECK: %[[T6:.*]] = vector.fma %[[T3]], %[[T5]], %[[T2]] : vector<2xf32>
+// CHECK: %[[T7:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
+// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
+// CHECK: %[[T9:.*]] = splat %[[T8]] : vector<2xf32>
+// CHECK: %[[T10:.*]] = vector.fma %[[T7]], %[[T9]], %[[T6]] : vector<2xf32>
+// CHECK: store %[[T10]], %[[C]][] : memref<vector<2xf32>>
+func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
+                     %arg2: memref<vector<2xf32>>) {
+  %A = load %arg0[] : memref<vector<2x2xf32>>
+  %x = load %arg1[] : memref<vector<2xf32>>
+  %b = load %arg2[] : memref<vector<2xf32>>
+  %0 = vector.contract #vecmattrans_trait %x, %A, %b : vector<2xf32>, vector<2x2xf32> into vector<2xf32>
+  store %0, %arg2[] : memref<vector<2xf32>>
+  return
+}
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -59,9 +59,14 @@
       *this, "vector-outerproduct",
       llvm::cl::desc("Lower vector.contract to vector.outerproduct"),
       llvm::cl::init(false)};
+  Option<bool> lowerToAXPY{*this, "vector-axpy",
+                           llvm::cl::desc("Lower vector.contract to AXPY"),
+                           llvm::cl::init(false)};
 
   void runOnFunction() override {
     OwningRewritePatternList patterns;
+
+    // Test on individual lowering patterns.
     if (lowerToOuterProduct) {
       VectorContractLowering lowering = VectorContractLowering::OuterProduct;
       VectorTransformsOptions options{lowering};
@@ -71,9 +76,11 @@
       return;
     }
 
-    VectorContractLowering contractLowering = VectorContractLowering::FMA;
+    VectorContractLowering contractLowering = VectorContractLowering::Dot;
     if (lowerToFlatMatrix)
       contractLowering = VectorContractLowering::Matmul;
+    else if (lowerToAXPY)
+      contractLowering = VectorContractLowering::AXPY;
     VectorTransposeLowering transposeLowering =
         VectorTransposeLowering::EltWise;
     if (lowerToFlatTranspose)