diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -1816,6 +1816,72 @@
   return success();
 }
 
+namespace {
+struct IteratorType {
+  IteratorType(StringRef strRef) : strRef(strRef) {}
+  bool isOfType(Attribute attr) const {
+    auto sAttr = attr.dyn_cast<StringAttr>();
+    return sAttr && sAttr.getValue() == strRef;
+  }
+  StringRef strRef;
+};
+struct Par : public IteratorType {
+  Par() : IteratorType(getParallelIteratorTypeName()) {}
+};
+struct Red : public IteratorType {
+  Red() : IteratorType(getReductionIteratorTypeName()) {}
+};
+
+// Unroll outer-products along reduction.
+struct UnrolledOuterProductEmitter {
+  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
+
+  UnrolledOuterProductEmitter(PatternRewriter &rewriter,
+                              vector::ContractionOp op)
+      : rewriter(rewriter), loc(op.getLoc()), kind(op.kind()),
+        iterators(op.iterator_types()), maps(op.getIndexingMaps()), op(op) {}
+
+  Value t(Value v) {
+    static constexpr std::array<int64_t, 2> perm = {1, 0};
+    return rewriter.create<vector::TransposeOp>(loc, v, perm);
+  }
+
+  bool iters(ArrayRef<IteratorType> its) {
+    if (its.size() != iterators.size())
+      return false;
+    for (int i = 0, e = its.size(); i != e; ++i) {
+      if (!its[i].isOfType(iterators[i]))
+        return false;
+    }
+    return true;
+  }
+
+  bool layout(MapList l) {
+    auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
+    return maps == infer(l);
+  }
+
+  LogicalResult outer_prod(Value lhs, Value rhs, Value res, int reductionSize) {
+    assert(reductionSize > 0);
+    for (int64_t k = 0; k < reductionSize; ++k) {
+      Value a = rewriter.create<vector::ExtractOp>(loc, lhs, k);
+      Value b = rewriter.create<vector::ExtractOp>(loc, rhs, k);
+      res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), a, b,
+                                                    res, kind);
+    }
+    rewriter.replaceOp(op, res);
+    return success();
+  }
+
+  PatternRewriter &rewriter;
+  Location loc;
+  vector::CombiningKind kind;
+  ArrayAttr iterators;
+  SmallVector<AffineMap, 4> maps;
+  Operation *op;
+};
+} // namespace
+
 /// Progressively lower a `vector.contract %a, %b, %c` with row-major matmul
 /// semantics to a reduction_size-unrolled sequence:
 /// ```
@@ -1844,104 +1910,64 @@
   if (failed(filter(op)))
     return failure();
 
-  Location loc = op.getLoc();
-  int64_t reductionSize = 0;
   VectorType lhsType = op.getLhsType();
   Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc();
 
   // Set up the parallel/reduction structure in right form.
-  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
-  auto infer = [](MapList m) { return AffineMap::inferFromExprList(m); };
   AffineExpr m, n, k;
   bindDims(rewriter.getContext(), m, n, k);
-  static constexpr std::array<int64_t, 2> perm = {1, 0};
-  auto iteratorTypes = op.iterator_types().getValue();
-  SmallVector<AffineMap, 4> maps = op.getIndexingMaps();
-  if (isParallelIterator(iteratorTypes[0]) &&
-      isParallelIterator(iteratorTypes[1]) &&
-      isReductionIterator(iteratorTypes[2])) {
-    //
-    // Two outer parallel, one inner reduction (matmat flavor).
-    //
-    if (maps == infer({{m, k}, {k, n}, {m, n}})) {
-      // This is the classical row-major matmul. Just permute the lhs.
-      reductionSize = lhsType.getDimSize(1);
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    } else if (maps == infer({{m, k}, {n, k}, {m, n}})) {
-      // TODO: may be better to fail and use some vector -> scalar reduction.
-      reductionSize = lhsType.getDimSize(1);
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
-      // No need to permute anything.
-      reductionSize = lhsType.getDimSize(0);
-    } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
-      // Just permute the rhs.
-      reductionSize = lhsType.getDimSize(0);
-      rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-    } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
-      // This is the classical row-major matmul. Just permute the lhs.
-      reductionSize = lhsType.getDimSize(1);
-      Value tmp = rhs;
-      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-      lhs = tmp;
-    } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
-      // TODO: may be better to fail and use some vector -> scalar reduction.
-      reductionSize = lhsType.getDimSize(1);
-      Value tmp = rhs;
-      rhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-      lhs = rewriter.create<vector::TransposeOp>(loc, tmp, perm);
-    } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
-      // No need to permute anything, but still swap lhs and rhs.
-      reductionSize = lhsType.getDimSize(0);
-      std::swap(lhs, rhs);
-    } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
-      // Just permute the rhs.
-      reductionSize = lhsType.getDimSize(0);
-      Value tmp = lhs;
-      lhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
-      rhs = tmp;
-    } else {
-      return failure();
-    }
-  } else if (isParallelIterator(iteratorTypes[0]) &&
-             isReductionIterator(iteratorTypes[1])) {
-    //
-    // One outer parallel, one inner reduction (matvec flavor).
-    //
-    if (maps == infer({{m, n}, {n}, {m}})) {
-      // Case mat-vec: transpose.
-      reductionSize = lhsType.getDimSize(1);
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    } else if (maps == infer({{n, m}, {n}, {m}})) {
-      // Case mat-trans-vec: ready to go.
-      reductionSize = lhsType.getDimSize(0);
-    } else if (maps == infer({{n}, {m, n}, {m}})) {
-      // Case vec-mat: swap and transpose.
-      reductionSize = lhsType.getDimSize(0);
-      std::swap(lhs, rhs);
-      lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
-    } else if (maps == infer({{n}, {n, m}, {m}})) {
-      // Case vec-mat-trans: swap and ready to go.
-      reductionSize = lhsType.getDimSize(0);
-      std::swap(lhs, rhs);
-    } else {
-      return failure();
-    }
-  } else {
+
+  //
+  // Two outer parallel, one inner reduction (matmat flavor).
+  //
+  UnrolledOuterProductEmitter e(rewriter, op);
+  if (e.iters({Par(), Par(), Red()})) {
+    // Classical row-major matmul: just permute the lhs.
+    if (e.layout({{m, k}, {k, n}, {m, n}}))
+      return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+    // TODO: may be better to fail and use some vector -> scalar reduction.
+    if (e.layout({{m, k}, {n, k}, {m, n}}))
+      return e.outer_prod(e.t(lhs), e.t(rhs), res, lhsType.getDimSize(1));
+    // No need to permute anything.
+    if (e.layout({{k, m}, {k, n}, {m, n}}))
+      return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+    // Just permute the rhs.
+    if (e.layout({{k, m}, {n, k}, {m, n}}))
+      return e.outer_prod(lhs, e.t(rhs), res, lhsType.getDimSize(0));
+    // Transposed output: swap RHS and LHS.
+    // Classical row-major matmul: permute the lhs.
+    if (e.layout({{m, k}, {k, n}, {n, m}}))
+      return e.outer_prod(rhs, e.t(lhs), res, lhsType.getDimSize(1));
+    // TODO: may be better to fail and use some vector -> scalar reduction.
+    if (e.layout({{m, k}, {n, k}, {n, m}}))
+      return e.outer_prod(e.t(rhs), e.t(lhs), res, lhsType.getDimSize(1));
+    // No need to permute anything, but still swap lhs and rhs.
+    if (e.layout({{k, m}, {k, n}, {n, m}}))
+      return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+    // Just permute the rhs and swap.
+    if (e.layout({{k, m}, {n, k}, {n, m}}))
+      return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
     return failure();
   }
-  assert(reductionSize > 0);
-
-  // Unroll outer-products along reduction.
-  for (int64_t k = 0; k < reductionSize; ++k) {
-    Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
-    Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
-    res = rewriter.create<vector::OuterProductOp>(op.getLoc(), res.getType(), a,
-                                                  b, res, op.kind());
+
+  //
+  // One outer parallel, one inner reduction (matvec flavor).
+  //
+  if (e.iters({Par(), Red()})) {
+    // Case mat-vec: transpose.
+    if (e.layout({{m, n}, {n}, {m}}))
+      return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1));
+    // Case mat-trans-vec: ready to go.
+    if (e.layout({{n, m}, {n}, {m}}))
+      return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0));
+    // Case vec-mat: swap and transpose.
+    if (e.layout({{n}, {m, n}, {m}}))
+      return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0));
+    // Case vec-mat-trans: swap and ready to go.
+    if (e.layout({{n}, {n, m}, {m}}))
+      return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0));
+    return failure();
   }
-  rewriter.replaceOp(op, res);
-  return success();
+
+  return failure();
 }
 
 LogicalResult
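For reference, here is the shape of IR this rewrite produces. A minimal sketch for the classical row-major case (`{m, k}, {k, n}, {m, n}`), assuming hypothetical `vector<2x4xf32> x vector<4x3xf32>` operands; the SSA names and static shapes are illustrative, not taken from the patch:

```mlir
// Input: row-major matmul-flavored contraction.
%res = vector.contract {
         indexing_maps = [affine_map<(m, n, k) -> (m, k)>,
                          affine_map<(m, n, k) -> (k, n)>,
                          affine_map<(m, n, k) -> (m, n)>],
         iterator_types = ["parallel", "parallel", "reduction"]}
  %a, %b, %c : vector<2x4xf32>, vector<4x3xf32> into vector<2x3xf32>

// After the rewrite: e.t(lhs) transposes %a, and outer_prod unrolls
// reductionSize = lhsType.getDimSize(1) = 4 outer products along k.
%at = vector.transpose %a, [1, 0] : vector<2x4xf32> to vector<4x2xf32>
%a0 = vector.extract %at[0] : vector<4x2xf32>
%b0 = vector.extract %b[0] : vector<4x3xf32>
%c0 = vector.outerproduct %a0, %b0, %c : vector<2xf32>, vector<3xf32>
%a1 = vector.extract %at[1] : vector<4x2xf32>
%b1 = vector.extract %b[1] : vector<4x3xf32>
%c1 = vector.outerproduct %a1, %b1, %c0 : vector<2xf32>, vector<3xf32>
// ... two more k-steps; the final value replaces the vector.contract.
```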
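Similarly for the matvec flavor: with maps `{m, n}, {n}, {m}` each extraction from the vector operand yields a scalar, so every unrolled step is effectively an AXPY update through the scalar form of `vector.outerproduct`. Again a hedged sketch with assumed shapes and names:

```mlir
// Input: mat-vec contraction.
%res = vector.contract {
         indexing_maps = [affine_map<(m, n) -> (m, n)>,
                          affine_map<(m, n) -> (n)>,
                          affine_map<(m, n) -> (m)>],
         iterator_types = ["parallel", "reduction"]}
  %A, %x, %acc : vector<2x3xf32>, vector<3xf32> into vector<2xf32>

// After the rewrite (case mat-vec): transpose the lhs, then one
// scalar outer product per reduction step.
%At = vector.transpose %A, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
%r0 = vector.extract %At[0] : vector<3x2xf32>
%s0 = vector.extract %x[0] : vector<3xf32>
%y0 = vector.outerproduct %r0, %s0, %acc : vector<2xf32>, f32
// ... repeated for k = 1, 2; the last result replaces the contraction.
```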