diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp @@ -1913,15 +1913,14 @@ VectorType lhsType = op.getLhsType(); Value lhs = op.lhs(), rhs = op.rhs(), res = op.acc(); - // Set up the parallel/reduction structure in right form. - AffineExpr m, n, k; - bindDims(rewriter.getContext(), m, n, k); - // // Two outer parallel, one inner reduction (matmat flavor). // UnrolledOuterProductEmitter e(rewriter, op); if (e.iters({Par(), Par(), Red()})) { + // Set up the parallel/reduction structure in right form. + AffineExpr m, n, k; + bindDims(rewriter.getContext(), m, n, k); // Classical row-major matmul: Just permute the lhs. if (e.layout({{m, k}, {k, n}, {m, n}})) return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1)); @@ -1956,17 +1955,42 @@ // One outer parallel, one inner reduction (matvec flavor) // if (e.iters({Par(), Red()})) { + AffineExpr m, k; + bindDims(rewriter.getContext(), m, k); + + // Case mat-vec: transpose. + if (e.layout({{m, k}, {k}, {m}})) + return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1)); + // Case mat-trans-vec: ready to go. + if (e.layout({{k, m}, {k}, {m}})) + return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); + // Case vec-mat: swap and transpose. + if (e.layout({{k}, {m, k}, {m}})) + return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0)); + // Case vec-mat-trans: swap and ready to go. + if (e.layout({{k}, {k, m}, {m}})) + return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); + return failure(); + } + + // + // One outer reduction, one inner parallel (tmatvec flavor) + // + if (e.iters({Red(), Par()})) { + AffineExpr k, m; + bindDims(rewriter.getContext(), k, m); + // Case mat-vec: transpose. - if (e.layout({{m, n}, {n}, {m}})) + if (e.layout({{m, k}, {k}, {m}})) return e.outer_prod(e.t(lhs), rhs, res, lhsType.getDimSize(1)); // Case mat-trans-vec: ready to go. - if (e.layout({{n, m}, {n}, {m}})) + if (e.layout({{k, m}, {k}, {m}})) return e.outer_prod(lhs, rhs, res, lhsType.getDimSize(0)); // Case vec-mat: swap and transpose. - if (e.layout({{n}, {m, n}, {m}})) + if (e.layout({{k}, {m, k}, {m}})) return e.outer_prod(e.t(rhs), lhs, res, lhsType.getDimSize(0)); // Case vec-mat-trans: swap and ready to go. - if (e.layout({{n}, {n, m}, {m}})) + if (e.layout({{k}, {k, m}, {m}})) return e.outer_prod(rhs, lhs, res, lhsType.getDimSize(0)); return failure(); }