diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2650,10 +2650,26 @@ if (!vLHS) return parser.emitError(parser.getNameLoc(), "expected vector type for operand #1"); - VectorType resType = - vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, - vLHS.getElementType()) - : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType()); + + unsigned numScalableDims = vLHS.isScalable() ? 1 : 0; + VectorType resType; + if (vRHS) { + if (vRHS.isScalable()) + numScalableDims++; + else if (vLHS.isScalable()) { + // Currently VectorType can only have the scalable dimensions as + // innermost; this can be changed if the limitation is removed. + return parser.emitError( + parser.getNameLoc(), + "expected operand #2 to be scalable if operand #1 is scalable"); + } + resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, + vLHS.getElementType(), numScalableDims); + } else { + // Scalar RHS operand + resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(), + numScalableDims); + } if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) { result.attributes.append( diff --git a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt -split-input-file -verify-diagnostics %s | mlir-opt + +func.func @scalable_outerproduct(%src : memref) { + %idx = arith.constant 0 : index + %cst = arith.constant 1.0 : f32 + %0 = vector.load %src[%idx] : memref, vector<[4]xf32> + %1 = vector.load %src[%idx] : memref, vector<[4]xf32> + + %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<[4]xf32> + vector.store %op, %src[%idx] : memref, vector<[4x4]xf32> + + %2 = vector.load %src[%idx] : memref, vector<4xf32> + %op1 = vector.outerproduct %2, %1 : vector<4xf32>, vector<[4]xf32> + vector.store %op1, %src[%idx] : memref, vector<4x[4]xf32> + + %op2 = vector.outerproduct %0, %cst : vector<[4]xf32>, f32 + vector.store %op2, %src[%idx] : memref, vector<[4]xf32> + return +} + +// ----- + +func.func @invalid_outerproduct(%src : memref) { + %idx = arith.constant 0 : index + %0 = vector.load %src[%idx] : memref, vector<[4]xf32> + %1 = vector.load %src[%idx] : memref, vector<4xf32> + + // expected-error @+1 {{expected operand #2 to be scalable if operand #1 is scalable}} + %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<4xf32> +} +// ----- + +func.func @invalid_outerproduct1(%src : memref) { + %idx = arith.constant 0 : index + %0 = vector.load %src[%idx] : memref, vector<[4x4]xf32> + %1 = vector.load %src[%idx] : memref, vector<[4]xf32> + + // expected-error @+1 {{expected 1-d vector for operand #1}} + %op = vector.outerproduct %0, %1 : vector<[4x4]xf32>, vector<[4]xf32> +}