[mlir][Vector] Parse and verify scalable vectors in vector.outerproduct.

OuterProductOp::parse: the inferred result type now carries the operands'
scalable-dim counts. It uses the LHS count plus the RHS count when the RHS
is a vector, and only the LHS count when the RHS is a scalar.

OuterProductOp::verify: in the vector-RHS (outer product) branch, reject
operands whose scalability differs. Either both operands #1 and #2 must
be scalable, or neither may be.

NOTE(review): the summed count assumes VectorType treats the trailing
numScalableDims dimensions as the scalable ones; confirm against
VectorType::get. If only the LHS is scalable, the parser would then infer
a result with the RHS dim marked scalable. The new verifier check is
presumably what rejects that case.
NOTE(review): the verifier compares operand scalability with each other
but not with the result type; confirm whether that is checked elsewhere.

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2657,10 +2657,18 @@
   if (!vLHS)
     return parser.emitError(parser.getNameLoc(),
                             "expected vector type for operand #1");
-  VectorType resType =
-      vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
-                             vLHS.getElementType())
-           : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
+  // Infer the result type, carrying over the operands' scalable-dim counts.
+  unsigned numScalableDims = vLHS.getNumScalableDims();
+  VectorType resType;
+  if (vRHS) {
+    numScalableDims += vRHS.getNumScalableDims();
+    resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
+                              vLHS.getElementType(), numScalableDims);
+  } else {
+    // Scalar RHS operand: the result takes the LHS shape and scalability.
+    resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
+                              numScalableDims);
+  }
 
   if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
     result.attributes.append(
@@ -2696,6 +2704,9 @@
       return emitOpError("expected #1 operand dim to match result dim #1");
     if (vRHS.getDimSize(0) != vRES.getDimSize(1))
       return emitOpError("expected #2 operand dim to match result dim #2");
+    if (vRHS.isScalable() != vLHS.isScalable())
+      return emitOpError("expected either all or none of vector operands #1 "
+                         "and #2 to be scalable");
   } else {
     // An AXPY operation.
     if (vRES.getRank() != 1)