Index: mlir/lib/Dialect/Vector/IR/VectorOps.cpp =================================================================== --- mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2650,10 +2650,18 @@ if (!vLHS) return parser.emitError(parser.getNameLoc(), "expected vector type for operand #1"); - VectorType resType = - vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, - vLHS.getElementType()) - : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); + + unsigned numScalableDims = vLHS.getNumScalableDims(); + VectorType resType; + if (vRHS) { + numScalableDims += vRHS.getNumScalableDims(); + resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, + vLHS.getElementType(), numScalableDims); + } else { + // Scalar RHS operand + resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(), + numScalableDims); + } if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) { result.attributes.append( @@ -2689,6 +2697,9 @@ return emitOpError("expected #1 operand dim to match result dim #1"); if (vRHS.getDimSize(0) != vRES.getDimSize(1)) return emitOpError("expected #2 operand dim to match result dim #2"); + if (vRHS.isScalable() != vLHS.isScalable()) + return emitOpError("expected either all or none of vector operands #1 " + "and #2 to be scalable"); } else { // An AXPY operation. if (vRES.getRank() != 1) Index: mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir @@ -0,0 +1,36 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt + +func.func @scalable_outerproduct(%src : memref) { + %idx = arith.constant 0 : index + %cst = arith.constant 1.0 : f32 + %0 = vector.load %src[%idx] : memref, vector<[4]xf32> + %1 = vector.load %src[%idx] : memref, vector<[4]xf32> + + %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<[4]xf32> + vector.store %op, %src[%idx] : memref, vector<[4x4]xf32> + + %op2 = vector.outerproduct %0, %cst : vector<[4]xf32>, f32 + vector.store %op2, %src[%idx] : memref, vector<[4]xf32> + return +} + +// ----- + +func.func @invalid_outerproduct(%src : memref) { + %idx = arith.constant 0 : index + %0 = vector.load %src[%idx] : memref, vector<[4]xf32> + %1 = vector.load %src[%idx] : memref, vector<4xf32> + + // expected-error @+1 {{expected either all or none of vector operands #1 and #2 to be scalable}} + %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32> +} +// ----- + +func.func @invalid_outerproduct1(%src : memref) { + %idx = arith.constant 0 : index + %0 = vector.load %src[%idx] : memref, vector<[4x4]xf32> + %1 = vector.load %src[%idx] : memref, vector<[4]xf32> + + // expected-error @+1 {{expected 1-d vector for operand #1}} + %op = vector.outerproduct %0, %1 : vector<[4x4]xf32>, vector<[4]xf32> +}