diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -19,11 +19,14 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Location.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/StringRef.h"
 
 namespace mlir {
 
+class PatternRewriter;
+
 /// Tests whether the given maps describe a row major matmul. The test is
 /// permutation-invariant. Note that this only checks the affine maps from an
 /// operation, so does not perform any checks on the math being performed within
@@ -132,6 +135,60 @@
   llvm_unreachable("Unsupported IteratorType");
 }
 
+/// Helper StructuredGenerator class to manipulate and rewrite ops with
+/// `StructuredOpInterface`. This is templated for now because VectorOps do not
+/// yet implement the StructuredOpInterface itself.
+template <typename StructuredOpInterface>
+class StructuredGenerator {
+public:
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+
+  struct IteratorType {
+    IteratorType(StringRef strRef) : strRef(strRef) {}
+    bool isOfType(Attribute attr) const {
+      auto sAttr = attr.dyn_cast<StringAttr>();
+      return sAttr && sAttr.getValue() == strRef;
+    }
+    StringRef strRef;
+  };
+  struct Par : public IteratorType {
+    Par() : IteratorType(getParallelIteratorTypeName()) {}
+  };
+  struct Red : public IteratorType {
+    Red() : IteratorType(getReductionIteratorTypeName()) {}
+  };
+  struct Win : public IteratorType {
+    Win() : IteratorType(getWindowIteratorTypeName()) {}
+  };
+
+  StructuredGenerator(PatternRewriter &rewriter, StructuredOpInterface op)
+      : rewriter(rewriter), ctx(op.getContext()), loc(op.getLoc()),
+        iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
+
+  bool iters(ArrayRef<IteratorType> its) {
+    if (its.size() != iterators.size())
+      return false;
+    for (int i = 0, e = its.size(); i != e; ++i) {
+      if (!its[i].isOfType(iterators[i]))
+        return false;
+    }
+    return true;
+  }
+
+  bool layout(MapList l) {
+    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+    return maps == infer(l);
+  }
+
+protected:
+  PatternRewriter &rewriter;
+  MLIRContext *ctx;
+  Location loc;
+  ArrayAttr iterators;
+  SmallVector<AffineMap, 4> maps;
+  Operation *op;
+};
+
 } // end namespace mlir
 
 #endif // MLIR_UTILS_STRUCTUREDOPSUTILS_H
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1252,35 +1252,22 @@
   Red() : IteratorType(getReductionIteratorTypeName()) {}
 };
 
-// Unroll outer-products along reduction.
-struct UnrolledOuterProductEmitter {
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+/// Generate a vector implementation for matmat, matvec and tmatvec.
+/// This unrolls outer-products along the reduction dimension.
+struct UnrolledOuterProductGenerator
+    : public StructuredGenerator<vector::ContractionOp> {
-  UnrolledOuterProductEmitter(PatternRewriter &rewriter,
-                              vector::ContractionOp op)
-      : rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
-        iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
+  UnrolledOuterProductGenerator(PatternRewriter &rewriter,
+                                vector::ContractionOp op)
+      : StructuredGenerator<vector::ContractionOp>(rewriter, op),
+        kind(op.kind()), lhs(op.lhs()), rhs(op.rhs()), res(op.acc()),
+        lhsType(op.getLhsType()) {}
 
   Value t(Value v) {
     static constexpr std::array<int64_t, 2> perm = {1, 0};
     return rewriter.create<vector::TransposeOp>(loc, v, perm);
   }
 
-  bool iters(ArrayRef<IteratorType> its) {
-    if (its.size() != iterators.size())
-      return false;
-    for (int i = 0, e = its.size(); i != e; ++i) {
-      if (!its[i].isOfType(iterators[i]))
-        return false;
-    }
-    return true;
-  }
-
-  bool layout(MapList l) {
-    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
-    return maps == infer(l);
-  }
-
   LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
     assert(reductionSize > 0);
     for (int64_t k = 0; k < reductionSize; ++k) {
@@ -1293,128 
+1280,132 @@
     return success();
   }
 
-  PatternRewriter &rewriter;
-  Location loc;
-  vector::CombiningKind kind;
-  ArrayAttr iterators;
-  SmallVector<AffineMap, 4> maps;
-  Operation *op;
-};
-} // namespace
-
-/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
-/// semantics to a reduction_size-unrolled sequence:
-/// ```
-/// %at = vector.transpose %a, [1, 0]
-/// %bRow0 = vector.extract %b[0]
-/// %atRow0 = vector.extract %at[0]
-/// %c0 = vector.outerproduct %atRow0, %bRow0, %c
-/// ...
-/// %bRowK = vector.extract %b[K]
-/// %atRowK = vector.extract %at[K]
-/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1
-/// ```
-///
-/// This only kicks in when VectorTransformsOptions is set to OuterProduct but
-/// otherwise supports any layout permutation of the matrix-multiply.
-LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
-    vector::ContractionOp op, PatternRewriter &rewriter) const {
-  // TODO: implement masks
-  if (llvm::size(op.masks()) != 0)
-    return failure();
-
-  if (vectorTransformOptions.vectorContractLowering !=
-      vector::VectorContractLowering::OuterProduct)
-    return failure();
-
-  if (failed(filter(op)))
-    return failure();
-
-  VectorType lhsType = op.getLhsType();
-  Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
-
-  //
-  // Two outer parallel, one inner reduction (matmat flavor).
-  //
-  UnrolledOuterProductEmitter e(rewriter, op);
-  if (e.iters({Par(), Par(), Red()})) {
-    // Set up the parallel/reduction structure in right form.
+  /// Two outer parallel, one inner reduction (matmat flavor).
+  LogicalResult matmat() {
+    if (!iters({Par(), Par(), Red()}))
+      return failure();
+    // Set up the parallel/reduction structure in the right form.
     AffineExpr m, n, k;
     bindDims(rewriter.getContext(), m, n, k);
     // Classical row-major matmul: Just permute the lhs.
- if (e.layout({{m, k}, {k, n}, {m, n}})) - return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1)); + if (layout({{m, k}, {k, n}, {m, n}})) + return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1)); // TODO: may be better to fail and use some vector -> scalar reduction. - if (e.layout({{m, k}, {n, k}, {m, n}})) { - Value tlhs = e.t(lhs); - return e.outer_prod(tlhs, e.t(rhs), res, lhsType.getDimSize(1)); + if (layout({{m, k}, {n, k}, {m, n}})) { + Value tlhs = t(lhs); + return outer_prod(tlhs, t(rhs), res, lhsType.getDimSize(1)); } // No need to permute anything. - if (e.layout({{k, m}, {k, n}, {m, n}})) - return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); + if (layout({{k, m}, {k, n}, {m, n}})) + return outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); // Just permute the rhs. - if (e.layout({{k, m}, {n, k}, {m, n}})) - return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0)); + if (layout({{k, m}, {n, k}, {m, n}})) + return outer_prod(lhs, t(rhs), res, lhsType.getDimSize(0)); // Transposed output: swap RHS and LHS. // Classical row-major matmul: permute the lhs. - if (e.layout({{m, k}, {k, n}, {n, m}})) - return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1)); + if (layout({{m, k}, {k, n}, {n, m}})) + return outer_prod(rhs, t(lhs), res, lhsType.getDimSize(1)); // TODO: may be better to fail and use some vector -> scalar reduction. 
- if (e.layout({{m, k}, {n, k}, {n, m}})) { - Value trhs = e.t(rhs); - return e.outer_prod(trhs, e.t(lhs), res, lhsType.getDimSize(1)); + if (layout({{m, k}, {n, k}, {n, m}})) { + Value trhs = t(rhs); + return outer_prod(trhs, t(lhs), res, lhsType.getDimSize(1)); } - if (e.layout({{k, m}, {k, n}, {n, m}})) - return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); - if (e.layout({{k, m}, {n, k}, {n, m}})) - return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0)); + if (layout({{k, m}, {k, n}, {n, m}})) + return outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); + if (layout({{k, m}, {n, k}, {n, m}})) + return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0)); return failure(); } - // - // One outer parallel, one inner reduction (matvec flavor) - // - if (e.iters({Par(), Red()})) { + /// One outer parallel, one inner reduction (matvec flavor) + LogicalResult matvec() { + if (!iters({Par(), Red()})) + return failure(); AffineExpr m, k; bindDims(rewriter.getContext(), m, k); // Case mat-vec: transpose. - if (e.layout({{m, k}, {k}, {m}})) - return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1)); + if (layout({{m, k}, {k}, {m}})) + return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1)); // Case mat-trans-vec: ready to go. - if (e.layout({{k, m}, {k}, {m}})) - return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); + if (layout({{k, m}, {k}, {m}})) + return outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); // Case vec-mat: swap and transpose. - if (e.layout({{k}, {m, k}, {m}})) - return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0)); + if (layout({{k}, {m, k}, {m}})) + return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0)); // Case vec-mat-trans: swap and ready to go. 
- if (e.layout({{k}, {k, m}, {m}})) - return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); + if (layout({{k}, {k, m}, {m}})) + return outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); return failure(); } // // One outer reduction, one inner parallel (tmatvec flavor) // - if (e.iters({Red(), Par()})) { + LogicalResult tmatvec() { + if (!iters({Red(), Par()})) + return failure(); AffineExpr k, m; bindDims(rewriter.getContext(), k, m); // Case mat-vec: transpose. - if (e.layout({{m, k}, {k}, {m}})) - return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1)); + if (layout({{m, k}, {k}, {m}})) + return outer_prod(t(lhs), rhs, res, lhsType.getDimSize(1)); // Case mat-trans-vec: ready to go. - if (e.layout({{k, m}, {k}, {m}})) - return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); + if (layout({{k, m}, {k}, {m}})) + return outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); // Case vec-mat: swap and transpose. - if (e.layout({{k}, {m, k}, {m}})) - return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0)); + if (layout({{k}, {m, k}, {m}})) + return outer_prod(t(rhs), lhs, res, lhsType.getDimSize(0)); // Case vec-mat-trans: swap and ready to go. - if (e.layout({{k}, {k, m}, {m}})) - return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); + if (layout({{k}, {k, m}, {m}})) + return outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); return failure(); } +private: + vector::CombiningKind kind; + Value lhs, rhs, res; + VectorType lhsType; +}; +} // namespace + +/// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul +/// semantics to a reduction_size-unrolled sequence: +/// ``` +/// %at = vector.transpose %a, [1, 0] +/// %bRow0 = vector.extract %b[0] +/// %atRow0 = vector.extract %at[0] +/// %c0 = vector.outerproduct %atRow0, %bRow0, %c +/// ... 
+/// %bRowK = vector.extract %b[K] +/// %atRowK = vector.extract %at[K] +/// %cK = vector.outerproduct %atRowK, %bRowK, %cK-1 +/// ``` +/// +/// This only kicks in when VectorTransformsOptions is set to OuterProduct but +/// otherwise supports any layout permutation of the matrix-multiply. +LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite( + vector::ContractionOp op, PatternRewriter &rewriter) const { + // TODO: implement masks + if (llvm::size(op.masks()) != 0) + return failure(); + + if (vectorTransformOptions.vectorContractLowering != + vector::VectorContractLowering::OuterProduct) + return failure(); + + if (failed(filter(op))) + return failure(); + + UnrolledOuterProductGenerator e(rewriter, op); + if (succeeded(e.matmat())) + return success(); + if (succeeded(e.matvec())) + return success(); + if (succeeded(e.tmatvec())) + return success(); + return failure(); }