diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -897,7 +897,7 @@ TCresVTEtIsSameAsOpBase<0, 1>>, DeclareOpInterfaceMethods]>, Arguments<(ins AnyVector:$lhs, AnyType:$rhs, - Variadic:$acc, + Optional:$acc, DefaultValuedAttr:$kind)>, Results<(outs AnyVector)> { let summary = "vector outerproduct with optional fused add"; @@ -961,9 +961,9 @@ return getRhs().getType(); } VectorType getOperandVectorTypeACC() { - return getAcc().empty() - ? VectorType() - : ::llvm::cast((*getAcc().begin()).getType()); + return getAcc() + ? ::llvm::cast(getAcc().getType()) + : VectorType(); } VectorType getResultVectorType() { return ::llvm::cast(getResult().getType()); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2756,7 +2756,7 @@ void OuterProductOp::print(OpAsmPrinter &p) { p << " " << getLhs() << ", " << getRhs(); - if (!getAcc().empty()) { + if (getAcc()) { p << ", " << getAcc(); p.printOptionalAttrDict((*this)->getAttrs()); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -1128,7 +1128,7 @@ VectorType resType = op.getResultVectorType(); Type eltType = resType.getElementType(); bool isInt = isa(eltType); - Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0]; + Value acc = op.getAcc(); vector::CombiningKind kind = op.getKind(); // Vector mask setup.