diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h --- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h +++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h @@ -162,6 +162,9 @@ /// Read a blob from the bytecode. virtual LogicalResult readBlob(ArrayRef &result) = 0; + /// Read a bool from the bytecode. + virtual LogicalResult readBool(bool &result) = 0; + private: /// Read a handle to a dialect resource. virtual FailureOr readResourceHandle() = 0; @@ -251,6 +254,9 @@ /// written as-is, with no additional compression or compaction. virtual void writeOwnedBlob(ArrayRef blob) = 0; + /// Write a bool to the output stream. + virtual void writeOwnedBool(bool value) = 0; + /// Return the bytecode version being emitted for. virtual int64_t getBytecodeVersion() const = 0; }; diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td --- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td +++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td @@ -279,13 +279,14 @@ } def VectorTypeWithScalableDims : DialectType<(type + Array:$isScalableDim, VarInt:$numScalableDims, Array:$shape, Type:$elementType )> { let printerPredicate = "$_val.getNumScalableDims()"; // Note: order of serialization does not match order of builder. - let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims)"; + let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims, isScalableDim)"; } } diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -1024,8 +1024,9 @@ ``` vector-type ::= `vector` `<` vector-dim-list vector-element-type `>` vector-element-type ::= float-type | integer-type | index-type - vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)? 
- static-dim-list ::= decimal-literal (`x` decimal-literal)* + vector-dim-list := (static-dim-list `x`)? + static-dim-list ::= static-dim (`x` static-dim)* + static-dim ::= decimal-literal | `[` decimal-literal `]` ``` The vector type represents a SIMD style vector used by target-specific @@ -1033,10 +1034,7 @@ vectors (e.g. vector<16 x f32>) we also support multidimensional registers on targets that support them (like TPUs). The dimensions of a vector type can be fixed-length, scalable, or a combination of the two. The scalable - dimensions in a vector are indicated between square brackets ([ ]), and - all fixed-length dimensions, if present, must precede the set of scalable - dimensions. That is, a `vector<2x[4]xf32>` is valid, but `vector<[4]x2xf32>` - is not. + dimensions in a vector are indicated between square brackets ([ ]). Vector shapes must be positive decimal integers. 0D vectors are allowed by omitting the dimension: `vector`. @@ -1055,24 +1053,41 @@ vector<[4]xf32> // A 2D scalable-length vector that contains a multiple of 2x8 f32 elements. - vector<[2x8]xf32> + vector<[2]x[8]xf32> // A 2D mixed fixed/scalable vector that contains 4 scalable vectors of 4 f32 elements. vector<4x[4]xf32> + + // A 3D mixed fixed/scalable vector in which only the middle dimension is + // scalable. + vector<2x[4]x8xf32> ``` }]; let parameters = (ins ArrayRefParameter<"int64_t">:$shape, "Type":$elementType, - "unsigned":$numScalableDims + "unsigned":$numScalableDims, + ArrayRefParameter<"bool", "list of bools">:$isScalableDim ); let builders = [ TypeBuilderWithInferredContext<(ins "ArrayRef":$shape, "Type":$elementType, - CArg<"unsigned", "0">:$numScalableDims + CArg<"unsigned", "0">:$numScalableDims, + CArg<"ArrayRef", "{}">:$isScalableDim ), [{ + // `isScalableDim` is optional. When it is omitted, mark the trailing + // `numScalableDims` dims as scalable (the meaning `numScalableDims` had + // before per-dim flags existed), so count-only callers still verify. 
+ SmallVector<bool> isScalableVec; + if (isScalableDim.empty()) { + unsigned numDims = shape.size(); + unsigned numFixedDims = numDims - std::min(numScalableDims, numDims); + isScalableVec.assign(numFixedDims, false); + isScalableVec.append(numDims - numFixedDims, true); + isScalableDim = isScalableVec; + } return $_get(elementType.getContext(), shape, elementType, - numScalableDims); + numScalableDims, isScalableDim); }]> ]; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td --- a/mlir/include/mlir/IR/BytecodeBase.td +++ b/mlir/include/mlir/IR/BytecodeBase.td @@ -92,6 +92,11 @@ WithBuilder<"$_args", WithPrinter<"$_writer.writeOwnedBlob($_getter)", WithType <"ArrayRef">>>>; +def Bool : + WithParser <"succeeded($_reader.readBool($_var))", + WithBuilder<"$_args", + WithPrinter<"$_writer.writeOwnedBool($_getter)", + WithType <"bool">>>>; class KnownWidthAPInt : WithParser <"succeeded(readAPIntWithKnownWidth($_reader, " # s # ", $_var))", WithBuilder<"$_args", @@ -125,6 +130,7 @@ // for the list print/parsing. class List : WithGetter<"$_member", t>; def SignedVarIntList : List; +def BoolList : List; // Define dialect attribute or type. class DialectAttrOrType { diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -211,7 +211,8 @@ /// Parse a vector type. 
VectorType parseVectorType(); ParseResult parseVectorDimensionList(SmallVectorImpl &dimensions, - unsigned &numScalableDims); + unsigned &numScalableDims, + SmallVectorImpl &isScalableDim); ParseResult parseDimensionListRanked(SmallVectorImpl &dimensions, bool allowDynamic = true, bool withTrailingX = true); diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -440,8 +440,9 @@ return nullptr; SmallVector dimensions; + SmallVector isScalableDim; unsigned numScalableDims; - if (parseVectorDimensionList(dimensions, numScalableDims)) + if (parseVectorDimensionList(dimensions, numScalableDims, isScalableDim)) return nullptr; if (any_of(dimensions, [](int64_t i) { return i <= 0; })) return emitError(getToken().getLoc(), @@ -458,51 +459,43 @@ return emitError(typeLoc, "vector elements must be int/index/float type"), nullptr; - return VectorType::get(dimensions, elementType, numScalableDims); + return VectorType::get(dimensions, elementType, numScalableDims, + isScalableDim); } -/// Parse a dimension list in a vector type. This populates the dimension list, -/// and returns the number of scalable dimensions in `numScalableDims`. +/// Parse a dimension list in a vector type. This populates the dimension list. +/// For the i-th dimension, `isScalableDim[i]` contains either: +/// * `false` for a non-scalable dimension (e.g. `4`), +/// * `true` for a scalable dimension (e.g. `[4]`). +/// This method also returns the number of scalable dimensions in +/// `numScalableDims`. /// -/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)? -/// static-dim-list ::= decimal-literal (`x` decimal-literal)* +/// vector-dim-list := (static-dim-list `x`)? 
+/// static-dim-list ::= static-dim (`x` static-dim)* +/// static-dim ::= decimal-literal | `[` decimal-literal `]` /// ParseResult Parser::parseVectorDimensionList(SmallVectorImpl &dimensions, - unsigned &numScalableDims) { + unsigned &numScalableDims, + SmallVectorImpl &isScalableDim) { numScalableDims = 0; - // If there is a set of fixed-length dimensions, consume it - while (getToken().is(Token::integer)) { + // Consume each dimension: fixed-length (e.g. `4`) or scalable (e.g. `[4]`). + while (getToken().is(Token::integer) || getToken().is(Token::l_square)) { int64_t value; + bool scalable = consumeIf(Token::l_square); if (parseIntegerInDimensionList(value)) return failure(); dimensions.push_back(value); + if (scalable) { + if (!consumeIf(Token::r_square)) + return emitWrongTokenError("missing ']' closing scalable dimension"); + numScalableDims++; + } + isScalableDim.push_back(scalable); // Make sure we have an 'x' or something like 'xbf32'. if (parseXInDimensionList()) return failure(); } - // If there is a set of scalable dimensions, consume it - if (consumeIf(Token::l_square)) { - while (getToken().is(Token::integer)) { - int64_t value; - if (parseIntegerInDimensionList(value)) - return failure(); - dimensions.push_back(value); - numScalableDims++; - // Check if we have reached the end of the scalable dimension list - if (consumeIf(Token::r_square)) { - // Make sure we have something like 'xbf32'. 
- return parseXInDimensionList(); - } - // Make sure we have an 'x' - if (parseXInDimensionList()) - return failure(); - } - // If we make it here, we've finished parsing the dimension list - // without finding ']' closing the set of scalable dimensions - return emitWrongTokenError( - "missing ']' closing set of scalable dimensions"); - } return success(); } diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -988,6 +988,10 @@ return success(); } + LogicalResult readBool(bool &result) override { + return reader.parseByte(result); + } + private: AttrTypeReader &attrTypeReader; StringSectionReader &stringReader; diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -395,6 +395,8 @@ reinterpret_cast(blob.data()), blob.size())); } + void writeOwnedBool(bool value) override { emitter.emitByte(value); } + int64_t getBytecodeVersion() const override { return bytecodeVersion; } private: diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp --- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp +++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp @@ -45,6 +45,7 @@ // file locations. 
} void writeOwnedBlob(ArrayRef blob) override {} + void writeOwnedBool(bool value) override {} int64_t getBytecodeVersion() const override { llvm_unreachable("unexpected querying of version in IRNumbering"); diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -123,7 +123,8 @@ return UnrankedTensorType::get(i1Type); if (auto vectorType = llvm::dyn_cast(type)) return VectorType::get(vectorType.getShape(), i1Type, - vectorType.getNumScalableDims()); + vectorType.getNumScalableDims(), + vectorType.getIsScalableDim()); return i1Type; } diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -30,7 +30,8 @@ auto i1Type = IntegerType::get(type.getContext(), 1); if (auto sVectorType = llvm::dyn_cast(type)) return VectorType::get(sVectorType.getShape(), i1Type, - sVectorType.getNumScalableDims()); + sVectorType.getNumScalableDims(), + sVectorType.getIsScalableDim()); return nullptr; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -992,7 +992,13 @@ return LLVMScalableVectorType::get(elementType, numElements); return LLVMFixedVectorType::get(elementType, numElements); } - return VectorType::get(numElements, elementType, (unsigned)isScalable); + + // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as + // scalable/non-scalable. 
+ SmallVector<bool> scalableDims(1, isScalable); + + return VectorType::get(numElements, elementType, (unsigned)isScalable, + scalableDims); } Type mlir::LLVM::getVectorType(Type elementType, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1216,7 +1216,8 @@ if (firstMaxRankedType) { auto vecType = VectorType::get(firstMaxRankedType.getShape(), getElementTypeOrSelf(vecOperand.getType()), - firstMaxRankedType.getNumScalableDims()); + firstMaxRankedType.getNumScalableDims(), + firstMaxRankedType.getIsScalableDim()); vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType)); } else { vecOperands.push_back(vecOperand); @@ -1228,7 +1229,8 @@ resultTypes.push_back( firstMaxRankedType ? VectorType::get(firstMaxRankedType.getShape(), resultType, - firstMaxRankedType.getNumScalableDims()) + firstMaxRankedType.getNumScalableDims(), + firstMaxRankedType.getIsScalableDim()) : resultType); } // d. Build and return the new op. diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -57,7 +57,8 @@ /// Constructs vector type for element type. 
static VectorType vectorType(VL vl, Type etp) { unsigned numScalableDims = vl.enableVLAVectorization; - return VectorType::get(vl.vectorLength, etp, numScalableDims); + SmallVector isScalableDim(1, vl.enableVLAVectorization); + return VectorType::get(vl.vectorLength, etp, numScalableDims, isScalableDim); } /// Constructs vector type from a memref value. diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -345,9 +345,9 @@ /// Returns the mask type expected by this operation. Type MultiDimReductionOp::getExpectedMaskType() { auto vecType = getSourceVectorType(); - return VectorType::get(vecType.getShape(), - IntegerType::get(vecType.getContext(), /*width=*/1), - vecType.getNumScalableDims()); + return VectorType::get( + vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getNumScalableDims(), vecType.getIsScalableDim()); } namespace { @@ -484,9 +484,9 @@ /// Returns the mask type expected by this operation. 
Type ReductionOp::getExpectedMaskType() { auto vecType = getSourceVectorType(); - return VectorType::get(vecType.getShape(), - IntegerType::get(vecType.getContext(), /*width=*/1), - vecType.getNumScalableDims()); + return VectorType::get( + vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getNumScalableDims(), vecType.getIsScalableDim()); } Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op, @@ -2788,16 +2788,22 @@ return parser.emitError(parser.getNameLoc(), "expected vector type for operand #1"); - unsigned numScalableDims = vLHS.getNumScalableDims(); VectorType resType; if (vRHS) { - numScalableDims += vRHS.getNumScalableDims(); + SmallVector isScalableDimRes{vLHS.getIsScalableDim()[0], + vRHS.getIsScalableDim()[0]}; + auto numScalableDims = + count_if(isScalableDimRes, [](bool isScalable) { return isScalable; }); resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, - vLHS.getElementType(), numScalableDims); + vLHS.getElementType(), numScalableDims, + isScalableDimRes); } else { // Scalar RHS operand + SmallVector isScalableDimRes{vLHS.getIsScalableDim()[0]}; + auto numScalableDims = + count_if(isScalableDimRes, [](bool isScalable) { return isScalable; }); resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(), - numScalableDims); + numScalableDims, isScalableDimRes); } if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) { @@ -2861,9 +2867,9 @@ /// verification purposes. It requires the operation to be vectorized." 
Type OuterProductOp::getExpectedMaskType() { auto vecType = this->getResultVectorType(); - return VectorType::get(vecType.getShape(), - IntegerType::get(vecType.getContext(), /*width=*/1), - vecType.getNumScalableDims()); + return VectorType::get( + vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getNumScalableDims(), vecType.getIsScalableDim()); } //===----------------------------------------------------------------------===// @@ -3521,7 +3527,12 @@ "Scalable vectors are not supported yet"); assert(invPermMap && "Inversed permutation map couldn't be computed"); SmallVector maskShape = invPermMap.compose(vecType.getShape()); - return VectorType::get(maskShape, i1Type, vecType.getNumScalableDims()); + + SmallVector isScalableDim = + applyPermutationMap(invPermMap, vecType.getIsScalableDim()); + + return VectorType::get(maskShape, i1Type, vecType.getNumScalableDims(), + isScalableDim); } ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) { @@ -4479,9 +4490,9 @@ /// verification purposes. It requires the operation to be vectorized." 
Type GatherOp::getExpectedMaskType() { auto vecType = this->getIndexVectorType(); - return VectorType::get(vecType.getShape(), - IntegerType::get(vecType.getContext(), /*width=*/1), - vecType.getNumScalableDims()); + return VectorType::get( + vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getNumScalableDims(), vecType.getIsScalableDim()); } std::optional> GatherOp::getShapeForUnroll() { diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2458,19 +2458,18 @@ } }) .Case([&](VectorType vectorTy) { + auto scalableDims = vectorTy.getIsScalableDim(); os << "vector<"; auto vShape = vectorTy.getShape(); unsigned lastDim = vShape.size(); - unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims(); unsigned dimIdx = 0; - for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++) - os << vShape[dimIdx] << 'x'; - if (vectorTy.isScalable()) { - os << '['; - unsigned secondToLastDim = lastDim - 1; - for (; dimIdx < secondToLastDim; dimIdx++) - os << vShape[dimIdx] << 'x'; - os << vShape[dimIdx] << "]x"; + for (dimIdx = 0; dimIdx < lastDim; dimIdx++) { + if (!scalableDims.empty() && scalableDims[dimIdx]) + os << '['; + os << vShape[dimIdx]; + if (!scalableDims.empty() && scalableDims[dimIdx]) + os << ']'; + os << 'x'; } printType(vectorTy.getElementType()); os << '>'; diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -227,7 +227,8 @@ LogicalResult VectorType::verify(function_ref emitError, ArrayRef shape, Type elementType, - unsigned numScalableDims) { + unsigned numScalableDims, + ArrayRef isScalableDim) { if (!isValidElementType(elementType)) return emitError() << "vector elements must be int/index/float type but got " @@ -238,6 +239,21 @@ << "vector types must have positive constant sizes but got " << shape; + if (numScalableDims > shape.size()) + return 
emitError() << "number of scalable dims cannot exceed the number of dims" << " (" << numScalableDims << " vs " << shape.size() << ")"; + + if (isScalableDim.size() != shape.size()) + return emitError() << "number of dims must match, got " + << isScalableDim.size() << " and " << shape.size(); + + auto numScale = + count_if(isScalableDim, [](bool isScalable) { return isScalable; }); + if (numScale != numScalableDims) + return emitError() << "number of scalable dims must match, explicit: " + << numScalableDims << ", and bools: " << numScale; + return success(); } diff --git a/mlir/test/Dialect/Builtin/invalid.mlir b/mlir/test/Dialect/Builtin/invalid.mlir --- a/mlir/test/Dialect/Builtin/invalid.mlir +++ b/mlir/test/Dialect/Builtin/invalid.mlir @@ -13,7 +13,10 @@ // VectorType //===----------------------------------------------------------------------===// -// expected-error@+1 {{missing ']' closing set of scalable dimensions}} +// expected-error@+1 {{missing ']' closing scalable dimension}} func.func @scalable_vector_arg(%arg0: vector<[4xf32>) { } // ----- + +// expected-error@+1 {{missing ']' closing scalable dimension}} +func.func @scalable_vector_arg(%arg0: vector<[4x4]xf32>) { } diff --git a/mlir/test/Dialect/Builtin/ops.mlir b/mlir/test/Dialect/Builtin/ops.mlir --- a/mlir/test/Dialect/Builtin/ops.mlir +++ b/mlir/test/Dialect/Builtin/ops.mlir @@ -27,10 +27,13 @@ %scalable_vector_1d = "foo.op"() : () -> vector<[4]xi32> // A 2D scalable vector -%scalable_vector_2d = "foo.op"() : () -> vector<[2x2]xf64> +%scalable_vector_2d = "foo.op"() : () -> vector<[2]x[2]xf64> // A 2D scalable vector with fixed-length dimensions %scalable_vector_2d_mixed = "foo.op"() : () -> vector<2x[4]xbf16> // A multi-dimensional vector with mixed scalable and fixed-length dimensions -%scalable_vector_multi_mixed = "foo.op"() : () -> vector<2x2x[4x4]xi8> +%scalable_vector_multi_mixed = "foo.op"() : () -> vector<2x2x[4]x[4]xi8> + +// A 2D vector with a scalable leading dimension 
+%scalable_vector_2d_leading = "foo.op"() : () -> vector<[4]x2xf32> diff --git 
a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir --- a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir +++ b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir @@ -7,7 +7,7 @@ %1 = vector.load %src[%idx] : memref, vector<[4]xf32> %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<[4]xf32> - vector.store %op, %src[%idx] : memref, vector<[4x4]xf32> + vector.store %op, %src[%idx] : memref, vector<[4]x[4]xf32> %op2 = vector.outerproduct %0, %cst : vector<[4]xf32>, f32 vector.store %op2, %src[%idx] : memref, vector<[4]xf32> @@ -28,9 +28,9 @@ func.func @invalid_outerproduct1(%src : memref) { %idx = arith.constant 0 : index - %0 = vector.load %src[%idx] : memref, vector<[4x4]xf32> + %0 = vector.load %src[%idx] : memref, vector<[4]x[4]xf32> %1 = vector.load %src[%idx] : memref, vector<[4]xf32> - // expected-error @+1 {{expected 1-d vector for operand #1}} - %op = vector.outerproduct %0, %1 : vector<[4x4]xf32>, vector<[4]xf32> + // expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}} + %op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32> }