diff --git a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
--- a/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorTransforms.h
@@ -258,6 +258,62 @@
   FilterConstraintType filter;
 };
 
+/// Progressive lowering of a `vector.contract %a, %b, %c` with matmul or
+/// matvec semantics to an output-size-unrolled sequence of dot products. For
+/// a row-major floating-point matmul this produces:
+/// ```
+///    %out = constant ... : vector<MxN>
+///    %bt = vector.transpose %b, [1, 0]
+///    %aRow0 = vector.extract %a[0]
+///    %btRow0 = vector.extract %bt[0]
+///    %p00 = mulf %aRow0, %btRow0
+///    %c00 = vector.reduction "add", %p00
+///    %out00 = vector.insert %c00, %out[0, 0]
+///    ...
+///    %aRowLast = vector.extract %a[M-1]
+///    %btRowLast = vector.extract %bt[N-1]
+///    %pLastLast = mulf %aRowLast, %btRowLast
+///    %cLastLast = vector.reduction "add", %pLastLast
+///    %outLastLast = vector.insert %cLastLast, %out[M-1, N-1]
+///    %res = addf %outLastLast, %c
+/// ```
+///
+/// Operands are transposed and/or swapped as needed so that the reduction
+/// dimension is innermost in both of them.
+///
+/// This only kicks in when VectorTransformsOptions is set to Dot, the
+/// vector.contract op has no masks and is accepted by the filter, and its
+/// indexing maps describe a matmul or matvec in a supported layout.
+class ContractionOpToDotLowering
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
+  using FilterConstraintType =
+      std::function<LogicalResult(vector::ContractionOp op)>;
+
+  static LogicalResult defaultFilter(vector::ContractionOp op) {
+    return success();
+  }
+
+  /// `constraint` is stored as the pattern's filter and restricts which ops
+  /// it may rewrite; by default every op is accepted.
+  ContractionOpToDotLowering(
+      vector::VectorTransformsOptions vectorTransformsOptions,
+      MLIRContext *context, FilterConstraintType constraint = defaultFilter)
+      : OpRewritePattern<vector::ContractionOp>(context),
+        vectorTransformsOptions(vectorTransformsOptions),
+        filter(std::move(constraint)) {}
+
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// Options to control the vector patterns.
+  vector::VectorTransformsOptions vectorTransformsOptions;
+  /// Predicate selecting the ops this pattern is allowed to rewrite.
+  FilterConstraintType filter;
+};
+
 /// Progressive lowering of ContractionOp.
/// /// One: diff --git a/mlir/include/mlir/Interfaces/VectorInterfaces.td b/mlir/include/mlir/Interfaces/VectorInterfaces.td --- a/mlir/include/mlir/Interfaces/VectorInterfaces.td +++ b/mlir/include/mlir/Interfaces/VectorInterfaces.td @@ -137,8 +137,9 @@ /*methodName=*/"getVectorType", /*args=*/(ins), /*methodBody=*/"", - /*defaultImplementation=*/ - "return $_op.vector().getType().template cast();" + /*defaultImplementation=*/[{ + return $_op.vector().getType().template dyn_cast(); + }] >, InterfaceMethod< /*desc=*/[{ Return the number of dimensions that participate in the diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1754,6 +1754,121 @@ return success(); } +LogicalResult +ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op, + PatternRewriter &rewriter) const { + // TODO(ajcbik): implement masks + if (llvm::size(op.masks()) != 0) + return failure(); + + if (failed(filter(op))) + return failure(); + + if (vectorTransformsOptions.vectorContractLowering != + vector::VectorContractLowering::Dot) + return failure(); + + auto iteratorTypes = op.iterator_types().getValue(); + static constexpr std::array perm = {1, 0}; + Location loc = op.getLoc(); + Value lhs = op.lhs(), rhs = op.rhs(); + + using MapList = ArrayRef>; + auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); }; + AffineExpr m, n, k; + bindDims(rewriter.getContext(), m, n, k); + SmallVector maps = op.getIndexingMaps(); + // + // In the following we wish to make the reduction dimension innermost so we + // can load vectors and just fmul + reduce into a scalar. + // + if (isParallelIterator(iteratorTypes[0]) && + isParallelIterator(iteratorTypes[1]) && + isReductionIterator(iteratorTypes[2])) { + // + // Two outer parallel, one inner reduction (matmat flavor). 
+ // + if (maps == infer({{m, k}, {k, n}, {m, n}})) { + rhs = rewriter.create(loc, rhs, perm); + } else if (maps == infer({{m, k}, {n, k}, {m, n}})) { + // No need to permute anything. + } else if (maps == infer({{k, m}, {k, n}, {m, n}})) { + lhs = rewriter.create(loc, lhs, perm); + rhs = rewriter.create(loc, rhs, perm); + } else if (maps == infer({{k, m}, {n, k}, {m, n}})) { + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{m, k}, {k, n}, {n, m}})) { + // This is the classical row-major matmul. Just permute the lhs. + Value tmp = lhs; + lhs = rewriter.create(loc, rhs, perm); + rhs = tmp; + } else if (maps == infer({{m, k}, {n, k}, {n, m}})) { + std::swap(lhs, rhs); + } else if (maps == infer({{k, m}, {k, n}, {n, m}})) { + Value tmp = lhs; + lhs = rewriter.create(loc, rhs, perm); + rhs = rewriter.create(loc, tmp, perm); + } else if (maps == infer({{k, m}, {n, k}, {n, m}})) { + Value tmp = rhs; + rhs = rewriter.create(loc, lhs, perm); + lhs = tmp; + } else { + return failure(); + } + } else if (isParallelIterator(iteratorTypes[0]) && + isReductionIterator(iteratorTypes[1])) { + // + // One outer parallel, one inner reduction (matvec flavor) + // + if (maps == infer({{m, n}, {n}, {m}})) { + // No need to permute anything. + } else if (maps == infer({{n, m}, {n}, {m}})) { + lhs = rewriter.create(loc, lhs, perm); + } else if (maps == infer({{n}, {m, n}, {m}})) { + std::swap(lhs, rhs); + } else if (maps == infer({{n}, {n, m}, {m}})) { + std::swap(lhs, rhs); + lhs = rewriter.create(loc, lhs, perm); + } else { + return failure(); + } + } else { + return failure(); + } + + VectorType dstType = op.getResultType().cast(); + assert(dstType.getRank() >= 1 && dstType.getRank() <= 2 && + "Expected dst type of rank 1 or 2"); + + unsigned rank = dstType.getRank(); + unsigned dstRows = dstType.getShape()[0]; + unsigned dstColumns = rank == 1 ? 1 : dstType.getShape()[1]; + + // ExtractOp does not allow dynamic indexing, we must unroll explicitly. 
+ Value res = + rewriter.create(loc, dstType, rewriter.getZeroAttr(dstType)); + for (unsigned r = 0; r < dstRows; ++r) { + Value a = rewriter.create(op.getLoc(), lhs, r); + for (unsigned c = 0; c < dstColumns; ++c) { + Value b = rank == 1 + ? rhs + : rewriter.create(op.getLoc(), rhs, c); + Value m = rewriter.create(op.getLoc(), a, b); + Value reduced = rewriter.create( + op.getLoc(), dstType.getElementType(), rewriter.getStringAttr("add"), + m, ValueRange{}); + + SmallVector pos = + rank == 1 ? ArrayRef{r} : ArrayRef{r, c}; + res = rewriter.create(op.getLoc(), reduced, res, pos); + } + } + if (auto acc = op.acc()) + res = rewriter.create(op.getLoc(), res, op.acc()); + rewriter.replaceOp(op, res); + return success(); +} + /// Progressive lowering of ContractionOp. /// One: /// %x = vector.contract with at least one free/batch dimension @@ -1795,6 +1910,9 @@ ContractionOpToOuterProductOpLowering pat2(vectorTransformsOptions, ctx); if (succeeded(pat2.matchAndRewrite(op, rewriter))) return success(); + ContractionOpToDotLowering pat3(vectorTransformsOptions, ctx); + if (succeeded(pat3.matchAndRewrite(op, rewriter))) + return success(); // Find first batch dimension in LHS/RHS, and lower when found. 
std::vector> batchDimMap = op.getBatchDimMap(); diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -43,16 +43,15 @@ // CHECK-SAME: %[[C:.*2]]: vector<2xf32> // CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x3xf32> -// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32> // CHECK: %[[T2:.*]] = mulf %[[T0]], %[[B]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32 +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32 // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> // CHECK: %[[T5:.*]] = vector.extract %[[A]][1] : vector<2x3xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32> // CHECK: %[[T7:.*]] = mulf %[[T5]], %[[B]] : vector<3xf32> -// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32 +// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32 // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> -// CHECK: return %[[T9]] : vector<2xf32> +// CHECK: %[[T10:.*]] = addf %[[T9]], %[[C]] : vector<2xf32> +// CHECK: return %[[T10]] : vector<2xf32> func @extract_contract2(%arg0: vector<2x3xf32>, %arg1: vector<3xf32>, @@ -78,16 +77,15 @@ // CHECK-SAME: %[[C:.*2]]: vector<2xf32> // CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2xf32> // CHECK: %[[T0:.*]] = vector.extract %[[B]][0] : vector<2x3xf32> -// CHECK: %[[T1:.*]] = vector.extract %[[C]][0] : vector<2xf32> -// CHECK: %[[T2:.*]] = mulf %[[A]], %[[T0]] : vector<3xf32> -// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]], %[[T1]] : vector<3xf32> into f32 +// CHECK: %[[T2:.*]] = mulf %[[T0]], %[[A]] : 
vector<3xf32> +// CHECK: %[[T3:.*]] = vector.reduction "add", %[[T2]] : vector<3xf32> into f32 // CHECK: %[[T4:.*]] = vector.insert %[[T3]], %[[R]] [0] : f32 into vector<2xf32> // CHECK: %[[T5:.*]] = vector.extract %[[B]][1] : vector<2x3xf32> -// CHECK: %[[T6:.*]] = vector.extract %[[C]][1] : vector<2xf32> -// CHECK: %[[T7:.*]] = mulf %[[A]], %[[T5]] : vector<3xf32> -// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]], %[[T6]] : vector<3xf32> into f32 +// CHECK: %[[T7:.*]] = mulf %[[T5]], %[[A]] : vector<3xf32> +// CHECK: %[[T8:.*]] = vector.reduction "add", %[[T7]] : vector<3xf32> into f32 // CHECK: %[[T9:.*]] = vector.insert %[[T8]], %[[T4]] [1] : f32 into vector<2xf32> -// CHECK: return %[[T9]] : vector<2xf32> +// CHECK: %[[T10:.*]] = addf %[[T9]], %[[C]] : vector<2xf32> +// CHECK: return %[[T10]] : vector<2xf32> func @extract_contract3(%arg0: vector<3xf32>, %arg1: vector<2x3xf32>, @@ -112,47 +110,31 @@ // CHECK-SAME: %[[B:.*1]]: vector<2x2xf32>, // CHECK-SAME: %[[C:.*2]]: vector<2x2xf32> // CHECK: %[[R:.*]] = constant dense<0.000000e+00> : vector<2x2xf32> -// CHECK: %[[Z:.*]] = constant dense<0.000000e+00> : vector<2xf32> +// ... 
bunch of extract insert to transpose B into Bt +// CHECK: %[[Bt:.*]] = vector.insert %{{.*}}, %{{.*}} [1, 1] : f32 into vector<2x2xf32> // CHECK: %[[T0:.*]] = vector.extract %[[A]][0] : vector<2x2xf32> -// CHECK: %[[T2:.*]] = vector.extract %[[B]][0, 0] : vector<2x2xf32> -// CHECK: %[[T4:.*]] = vector.insert %[[T2]], %[[Z]] [0] : f32 into vector<2xf32> -// CHECK: %[[T5:.*]] = vector.extract %[[B]][1, 0] : vector<2x2xf32> -// CHECK: %[[T7:.*]] = vector.insert %[[T5]], %[[T4]] [1] : f32 into vector<2xf32> -// CHECK: %[[T8:.*]] = vector.extract %[[C]][0, 0] : vector<2x2xf32> -// CHECK: %[[T9:.*]] = mulf %[[T0]], %[[T7]] : vector<2xf32> -// CHECK: %[[T10:.*]] = vector.reduction "add", %[[T9]], %[[T8]] : vector<2xf32> into f32 -// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[Z]] [0] : f32 into vector<2xf32> +// CHECK: %[[T2:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32> +// CHECK: %[[T9:.*]] = mulf %[[T0]], %[[T2]] : vector<2xf32> +// CHECK: %[[T10:.*]] = vector.reduction "add", %[[T9]] : vector<2xf32> into f32 +// CHECK: %[[T11:.*]] = vector.insert %[[T10]], %[[R]] [0, 0] : f32 into vector<2x2xf32> // -// CHECK: %[[T12:.*]] = vector.extract %[[B]][0, 1] : vector<2x2xf32> -// CHECK: %[[T14:.*]] = vector.insert %[[T12]], %[[Z]] [0] : f32 into vector<2xf32> -// CHECK: %[[T15:.*]] = vector.extract %[[B]][1, 1] : vector<2x2xf32> -// CHECK: %[[T17:.*]] = vector.insert %[[T15]], %[[T14]] [1] : f32 into vector<2xf32> -// CHECK: %[[T18:.*]] = vector.extract %[[C]][0, 1] : vector<2x2xf32> -// CHECK: %[[T19:.*]] = mulf %[[T0]], %[[T17]] : vector<2xf32> -// CHECK: %[[T20:.*]] = vector.reduction "add", %[[T19]], %[[T18]] : vector<2xf32> into f32 -// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [1] : f32 into vector<2xf32> -// CHECK: %[[T22:.*]] = vector.insert %[[T21]], %[[R]] [0] : vector<2xf32> into vector<2x2xf32> +// CHECK: %[[T12:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32> +// CHECK: %[[T19:.*]] = mulf %[[T0]], %[[T12]] : vector<2xf32> +// CHECK: 
%[[T20:.*]] = vector.reduction "add", %[[T19]] : vector<2xf32> into f32 +// CHECK: %[[T21:.*]] = vector.insert %[[T20]], %[[T11]] [0, 1] : f32 into vector<2x2xf32> // // CHECK: %[[T23:.*]] = vector.extract %[[A]][1] : vector<2x2xf32> -// CHECK: %[[T22b:.*]] = vector.extract %[[B]][0, 0] : vector<2x2xf32> -// CHECK: %[[T24:.*]] = vector.insert %[[T22b]], %[[Z]] [0] : f32 into vector<2xf32> -// CHECK: %[[T25:.*]] = vector.extract %[[B]][1, 0] : vector<2x2xf32> -// CHECK: %[[T27:.*]] = vector.insert %[[T25]], %[[T24]] [1] : f32 into vector<2xf32> -// CHECK: %[[T28:.*]] = vector.extract %[[C]][1, 0] : vector<2x2xf32> -// CHECK: %[[T32:.*]] = mulf %[[T23]], %[[T27]] : vector<2xf32> -// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]], %[[T28]] : vector<2xf32> into f32 -// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[Z]] [0] : f32 into vector<2xf32> +// CHECK: %[[T24:.*]] = vector.extract %[[Bt]][0] : vector<2x2xf32> +// CHECK: %[[T32:.*]] = mulf %[[T23]], %[[T24]] : vector<2xf32> +// CHECK: %[[T33:.*]] = vector.reduction "add", %[[T32]] : vector<2xf32> into f32 +// CHECK: %[[T34:.*]] = vector.insert %[[T33]], %[[T21]] [1, 0] : f32 into vector<2x2xf32> // -// CHECK: %[[T42:.*]] = vector.extract %[[B]][0, 1] : vector<2x2xf32> -// CHECK: %[[T44:.*]] = vector.insert %[[T42]], %[[Z]] [0] : f32 into vector<2xf32> -// CHECK: %[[T45:.*]] = vector.extract %[[B]][1, 1] : vector<2x2xf32> -// CHECK: %[[T47:.*]] = vector.insert %[[T45]], %[[T44]] [1] : f32 into vector<2xf32> -// CHECK: %[[T48:.*]] = vector.extract %[[C]][1, 1] : vector<2x2xf32> -// CHECK: %[[T49:.*]] = mulf %[[T23]], %[[T47]] : vector<2xf32> -// CHECK: %[[T50:.*]] = vector.reduction "add", %[[T49]], %[[T48]] : vector<2xf32> into f32 +// CHECK: %[[T40:.*]] = vector.extract %[[Bt]][1] : vector<2x2xf32> +// CHECK: %[[T41:.*]] = mulf %[[T23]], %[[T40]] : vector<2xf32> +// CHECK: %[[T42:.*]] = vector.reduction "add", %[[T41]] : vector<2xf32> into f32 +// CHECK: %[[T43:.*]] = vector.insert %[[T42]], %[[T34]] [1, 
1] : f32 into vector<2x2xf32> // -// CHECK: %[[T51:.*]] = vector.insert %[[T50]], %[[T34]] [1] : f32 into vector<2xf32> -// CHECK: %[[T52:.*]] = vector.insert %[[T51]], %[[T22]] [1] : vector<2xf32> into vector<2x2xf32> +// CHECK: %[[T52:.*]] = addf %[[T43]], %[[C]] : vector<2x2xf32> // CHECK: return %[[T52]] : vector<2x2xf32> func @extract_contract4(%arg0: vector<2x2xf32>, @@ -574,6 +556,31 @@ // OUTERPRODUCT-SAME: : vector<2xf32>, vector<3xf32> // // OUTERPRODUCT: return %[[c3]] : vector<2x3xf32> + +// REDUCE-LABEL: func @matmul +// REDUCE-SAME: %[[A:[a-zA-Z0-9]*]]: vector<2x4xf32>, +// REDUCE-SAME: %[[B:[a-zA-Z0-9]*]]: vector<4x3xf32>, +// REDUCE-SAME: %[[C:[a-zA-Z0-9]*]]: vector<2x3xf32> +// +// REDUCE: %[[RES:.*]] = constant dense<0.000000e+00> : vector<2x3xf32> +// REDUCE: %[[Bt:.*]] = vector.transpose %[[B]], [1, 0] +// REDUCE-SAME: : vector<4x3f32> to vector<3x4xf32> +// +// REDUCE: %[[a0:.*]] = vector.extract %[[A]][0] : vector<2x4xf32> +// REDUCE-NEXT: %[[b0:.*]] = vector.extract %[[Bt]][0] : vector<3x4xf32> +// REDUCE-NEXT: %[[ab00:.*]] = mul %[[a0]], %[[b0]] : vector<4xf32> +// REDUCE-NEXT: %[[s00:.*]] = vector.reduction "add", %[[ab00]] : vector<4xf32> into f32 +// REDUCE-NEXT: %[[r00:.*]] = vector.insert %[[s00]], %[[RES]] [0, 0] : f32 into vector<2x3xf32> +// +// ... +// +// REDUCE: %[[a1:.*]] = vector.extract %[[A]][1] : vector<2x4xf32> +// REDUCE-NEXT: %[[b2:.*]] = vector.extract %[[Bt]][2] : vector<3x4xf32> +// REDUCE-NEXT: %[[ab12:.*]] = mul %[[a1]], %[[b02]] : vector<4xf32> +// REDUCE-NEXT: %[[s12:.*]] = vector.reduction "add", %[[ab12]] : vector<4xf32> into f32 +// REDUCE-NEXT: %[[r12:.*]] = vector.insert %[[s12]], %{{.*}} [1, 2] : f32 into vector<2x3xf32> +// +// REDUCE: return %[[c3]] : vector<2x3xf32> func @matmul(%arg0: vector<2x4xf32>, %arg1: vector<4x3xf32>, %arg2: vector<2x3xf32>) -> vector<2x3xf32> { @@ -1056,7 +1063,3 @@ : vector<3x4xf32>, vector<4x4xf32> into vector<3x4xf32> return %0 : vector<3x4xf32> } - - - -