diff --git a/mlir/include/mlir/Bytecode/BytecodeImplementation.h b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
--- a/mlir/include/mlir/Bytecode/BytecodeImplementation.h
+++ b/mlir/include/mlir/Bytecode/BytecodeImplementation.h
@@ -162,6 +162,9 @@
   /// Read a blob from the bytecode.
   virtual LogicalResult readBlob(ArrayRef<char> &result) = 0;
 
+  /// Read a bool from the bytecode.
+  virtual LogicalResult readBool(bool &result) = 0;
+
 private:
   /// Read a handle to a dialect resource.
   virtual FailureOr<AsmDialectResourceHandle> readResourceHandle() = 0;
@@ -251,6 +254,9 @@
   /// written as-is, with no additional compression or compaction.
   virtual void writeOwnedBlob(ArrayRef<char> blob) = 0;
 
+  /// Write a bool to the output stream.
+  virtual void writeOwnedBool(bool value) = 0;
+
   /// Return the bytecode version being emitted for.
   virtual int64_t getBytecodeVersion() const = 0;
 };
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
--- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td
+++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td
@@ -279,13 +279,14 @@
 }
 
 def VectorTypeWithScalableDims : DialectType<(type
+  Array<BoolList>:$scalableDims,
   VarInt:$numScalableDims,
   Array<SignedVarIntList>:$shape,
   Type:$elementType
 )> {
   let printerPredicate = "$_val.getNumScalableDims()";
   // Note: order of serialization does not match order of builder.
-  let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims)";
+  let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims, scalableDims)";
 }
 }
 
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -306,17 +306,24 @@
   /// Build from another VectorType.
   explicit Builder(VectorType other)
       : shape(other.getShape()), elementType(other.getElementType()),
-        numScalableDims(other.getNumScalableDims()) {}
+        numScalableDims(other.getNumScalableDims()),
+        scalableDims(other.getScalableDims()) {}
 
   /// Build from scratch.
+  /// An empty `scalableDims` means that no dimension is scalable.
   Builder(ArrayRef<int64_t> shape, Type elementType,
-          unsigned numScalableDims = 0)
+          unsigned numScalableDims = 0, ArrayRef<bool> scalableDims = {})
       : shape(shape), elementType(elementType),
-        numScalableDims(numScalableDims) {}
+        numScalableDims(numScalableDims), scalableDims(scalableDims) {}
 
-  Builder &setShape(ArrayRef<int64_t> newShape,
-                    unsigned newNumScalableDims = 0) {
+  Builder &setShape(ArrayRef<int64_t> newShape, unsigned newNumScalableDims = 0,
+                    ArrayRef<bool> newIsScalableDim = {}) {
     numScalableDims = newNumScalableDims;
+    // An empty `newIsScalableDim` means that no dim is scalable. Owned flags
+    // describe the previous shape and must not be reused by `dropDim`.
+    scalableDims = newIsScalableDim;
+    storageScalableDims.clear();
+
     shape = newShape;
     return *this;
   }
@@ -333,8 +340,18 @@
       numScalableDims--;
     if (storage.empty())
       storage.append(shape.begin(), shape.end());
     storage.erase(storage.begin() + pos);
     shape = {storage.data(), storage.size()};
+    // Erase the flag at `pos` too, copying the flags into owned storage first.
+    // An empty list means that no dim is scalable and needs no update.
+    if (!scalableDims.empty()) {
+      if (storageScalableDims.empty())
+        storageScalableDims.append(scalableDims.begin(), scalableDims.end());
+      storageScalableDims.erase(storageScalableDims.begin() + pos);
+      scalableDims = {storageScalableDims.data(), storageScalableDims.size()};
+      // Scalable dims may appear at any position, so recount from the flags.
+      numScalableDims = static_cast<unsigned>(llvm::count(scalableDims, true));
+    }
     return *this;
   }
 
@@ -344,7 +361,7 @@
   operator Type() {
     if (shape.empty())
       return elementType;
-    return VectorType::get(shape, elementType, numScalableDims);
+    return VectorType::get(shape, elementType, numScalableDims, scalableDims);
   }
 
 private:
@@ -353,6 +370,10 @@
   SmallVector<int64_t> storage;
   Type elementType;
   unsigned numScalableDims;
+  // One scalable flag per dim of `shape`; empty means no dim is scalable.
+  ArrayRef<bool> scalableDims;
+  // Owning scalableDims data for copy-on-write operations.
+  SmallVector<bool> storageScalableDims;
 };
 
 /// Given an `originalShape` and a `reducedShape` assumed to be a subset of
diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td
--- a/mlir/include/mlir/IR/BuiltinTypes.td
+++ b/mlir/include/mlir/IR/BuiltinTypes.td
@@ -1024,8 +1024,9 @@
     ```
     vector-type ::= `vector` `<` vector-dim-list vector-element-type `>`
     vector-element-type ::= float-type | integer-type | index-type
-    vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
-    static-dim-list ::= decimal-literal (`x` decimal-literal)*
+    vector-dim-list := (static-dim-list `x`)?
+    static-dim-list ::= static-dim (`x` static-dim)*
+    static-dim ::= (decimal-literal | `[` decimal-literal `]`)
     ```
 
     The vector type represents a SIMD style vector used by target-specific
@@ -1033,10 +1034,7 @@
     vectors (e.g. vector<16 x f32>) we also support multidimensional registers
     on targets that support them (like TPUs). The dimensions of a vector type
     can be fixed-length, scalable, or a combination of the two. The scalable
-    dimensions in a vector are indicated between square brackets ([ ]), and
-    all fixed-length dimensions, if present, must precede the set of scalable
-    dimensions. That is, a `vector<2x[4]xf32>` is valid, but `vector<[4]x2xf32>`
-    is not.
+    dimensions in a vector are indicated between square brackets ([ ]).
 
     Vector shapes must be positive decimal integers. 0D vectors are allowed by
     omitting the dimension: `vector<f32>`.
@@ -1055,24 +1053,37 @@
     vector<[4]xf32>
 
     // A 2D scalable-length vector that contains a multiple of 2x8 f32 elements.
-    vector<[2x8]xf32>
+    vector<[2]x[8]xf32>
 
     // A 2D mixed fixed/scalable vector that contains 4 scalable vectors of 4 f32 elements.
     vector<4x[4]xf32>
+
+    // A 3D mixed fixed/scalable vector in which only the middle dimension is
+    // scalable.
+    vector<2x[4]x8xf32>
     ```
   }];
   let parameters = (ins
     ArrayRefParameter<"int64_t">:$shape,
     "Type":$elementType,
-    "unsigned":$numScalableDims
+    "unsigned":$numScalableDims,
+    ArrayRefParameter<"bool">:$scalableDims
   );
   let builders = [
     TypeBuilderWithInferredContext<(ins
       "ArrayRef<int64_t>":$shape, "Type":$elementType,
-      CArg<"unsigned", "0">:$numScalableDims
+      CArg<"unsigned", "0">:$numScalableDims,
+      CArg<"ArrayRef<bool>", "{}">:$scalableDims
     ), [{
+      // While `scalableDims` is optional, its default value should be
+      // `false` for every dim in `shape`.
+      SmallVector<bool> isScalableVec;
+      if (scalableDims.empty()) {
+        isScalableVec.resize(shape.size(), false);
+        scalableDims = isScalableVec;
+      }
       return $_get(elementType.getContext(), shape, elementType,
-                   numScalableDims);
+                   numScalableDims, scalableDims);
     }]>
   ];
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/IR/BytecodeBase.td b/mlir/include/mlir/IR/BytecodeBase.td
--- a/mlir/include/mlir/IR/BytecodeBase.td
+++ b/mlir/include/mlir/IR/BytecodeBase.td
@@ -92,6 +92,11 @@
   WithBuilder<"$_args",
   WithPrinter<"$_writer.writeOwnedBlob($_getter)",
   WithType   <"ArrayRef<char>">>>>;
+def Bool :
+  WithParser <"succeeded($_reader.readBool($_var))",
+  WithBuilder<"$_args",
+  WithPrinter<"$_writer.writeOwnedBool($_getter)",
+  WithType   <"bool">>>>;
 class KnownWidthAPInt<string s> :
   WithParser <"succeeded(readAPIntWithKnownWidth($_reader, " # s # ", $_var))",
   WithBuilder<"$_args",
@@ -125,6 +130,7 @@
 // for the list print/parsing.
 class List<Bytecode t> : WithGetter<"$_member", t>;
 def SignedVarIntList : List<SignedVarInt>;
+def BoolList : List<Bool>;
 
 // Define dialect attribute or type.
 class DialectAttrOrType<dag d> {
diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h
--- a/mlir/lib/AsmParser/Parser.h
+++ b/mlir/lib/AsmParser/Parser.h
@@ -211,7 +211,8 @@
   /// Parse a vector type.
   VectorType parseVectorType();
   ParseResult parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
-                                       unsigned &numScalableDims);
+                                       unsigned &numScalableDims,
+                                       SmallVectorImpl<bool> &scalableDims);
   ParseResult parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
                                        bool allowDynamic = true,
                                        bool withTrailingX = true);
diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp
--- a/mlir/lib/AsmParser/TypeParser.cpp
+++ b/mlir/lib/AsmParser/TypeParser.cpp
@@ -440,8 +440,9 @@
     return nullptr;
 
   SmallVector<int64_t, 4> dimensions;
+  SmallVector<bool, 4> scalableDims;
   unsigned numScalableDims;
-  if (parseVectorDimensionList(dimensions, numScalableDims))
+  if (parseVectorDimensionList(dimensions, numScalableDims, scalableDims))
     return nullptr;
   if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
     return emitError(getToken().getLoc(),
@@ -458,51 +459,43 @@
     return emitError(typeLoc, "vector elements must be int/index/float type"),
            nullptr;
 
-  return VectorType::get(dimensions, elementType, numScalableDims);
+  return VectorType::get(dimensions, elementType, numScalableDims,
+                         scalableDims);
 }
 
-/// Parse a dimension list in a vector type. This populates the dimension list,
-/// and returns the number of scalable dimensions in `numScalableDims`.
+/// Parse a dimension list in a vector type. This populates the dimension list.
+/// For i-th dimension, `scalableDims[i]` contains either:
+///   * `false` for a non-scalable dimension (e.g. `4`),
+///   * `true` for a scalable dimension (e.g. `[4]`).
+/// This method also returns the number of scalable dimensions in
+/// `numScalableDims`.
 ///
-/// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)?
-/// static-dim-list ::= decimal-literal (`x` decimal-literal)*
+/// vector-dim-list := (static-dim-list `x`)?
+/// static-dim-list ::= static-dim (`x` static-dim)*
+/// static-dim ::= (decimal-literal | `[` decimal-literal `]`)
 ///
 ParseResult
 Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions,
-                                 unsigned &numScalableDims) {
+                                 unsigned &numScalableDims,
+                                 SmallVectorImpl<bool> &scalableDims) {
   numScalableDims = 0;
-  // If there is a set of fixed-length dimensions, consume it
-  while (getToken().is(Token::integer)) {
+  // Consume each dimension, either fixed-length (`4`) or scalable (`[4]`).
+  while (getToken().is(Token::integer) || getToken().is(Token::l_square)) {
     int64_t value;
+    bool scalable = consumeIf(Token::l_square);
     if (parseIntegerInDimensionList(value))
       return failure();
     dimensions.push_back(value);
+    if (scalable) {
+      if (!consumeIf(Token::r_square))
+        return emitWrongTokenError("missing ']' closing scalable dimension");
+      numScalableDims++;
+    }
+    scalableDims.push_back(scalable);
     // Make sure we have an 'x' or something like 'xbf32'.
     if (parseXInDimensionList())
       return failure();
   }
-  // If there is a set of scalable dimensions, consume it
-  if (consumeIf(Token::l_square)) {
-    while (getToken().is(Token::integer)) {
-      int64_t value;
-      if (parseIntegerInDimensionList(value))
-        return failure();
-      dimensions.push_back(value);
-      numScalableDims++;
-      // Check if we have reached the end of the scalable dimension list
-      if (consumeIf(Token::r_square)) {
-        // Make sure we have something like 'xbf32'.
-        return parseXInDimensionList();
-      }
-      // Make sure we have an 'x'
-      if (parseXInDimensionList())
-        return failure();
-    }
-    // If we make it here, we've finished parsing the dimension list
-    // without finding ']' closing the set of scalable dimensions
-    return emitWrongTokenError(
-        "missing ']' closing set of scalable dimensions");
-  }
 
   return success();
 }
diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
--- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
+++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp
@@ -994,6 +994,10 @@
     return success();
   }
 
+  LogicalResult readBool(bool &result) override {
+    return reader.parseByte(result);
+  }
+
 private:
   AttrTypeReader &attrTypeReader;
   StringSectionReader &stringReader;
diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
--- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
+++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp
@@ -396,6 +396,8 @@
         reinterpret_cast<const uint8_t *>(blob.data()), blob.size()));
   }
 
+  void writeOwnedBool(bool value) override { emitter.emitByte(value); }
+
   int64_t getBytecodeVersion() const override { return bytecodeVersion; }
 
 private:
diff --git a/mlir/lib/Bytecode/Writer/IRNumbering.cpp b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
--- a/mlir/lib/Bytecode/Writer/IRNumbering.cpp
+++ b/mlir/lib/Bytecode/Writer/IRNumbering.cpp
@@ -45,6 +45,7 @@
     // file locations.
   }
   void writeOwnedBlob(ArrayRef<char> blob) override {}
+  void writeOwnedBool(bool value) override {}
 
   int64_t getBytecodeVersion() const override {
     llvm_unreachable("unexpected querying of version in IRNumbering");
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -463,8 +463,12 @@
     return {};
   if (type.getShape().empty())
     return VectorType::get({1}, elementType);
-  Type vectorType = VectorType::get(type.getShape().back(), elementType,
-                                    type.getNumScalableDims());
+  // Only the trailing dim is turned into a 1-D vector here, so at most one
+  // scalable dim is carried over and the count must match that single flag.
+  bool isScalableTrailingDim = type.getScalableDims().back();
+  Type vectorType = VectorType::get(
+      type.getShape().back(), elementType,
+      static_cast<unsigned>(isScalableTrailingDim), isScalableTrailingDim);
   assert(LLVM::isCompatibleVectorType(vectorType) &&
          "expected vector type compatible with the LLVM dialect");
   auto shape = type.getShape();
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -123,7 +123,8 @@
     return UnrankedTensorType::get(i1Type);
   if (auto vectorType = llvm::dyn_cast<VectorType>(type))
     return VectorType::get(vectorType.getShape(), i1Type,
-                           vectorType.getNumScalableDims());
+                           vectorType.getNumScalableDims(),
+                           vectorType.getScalableDims());
   return i1Type;
 }
 
diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
--- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
+++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp
@@ -30,7 +30,8 @@
   auto i1Type = IntegerType::get(type.getContext(), 1);
   if (auto sVectorType = llvm::dyn_cast<VectorType>(type))
     return VectorType::get(sVectorType.getShape(), i1Type,
-                           sVectorType.getNumScalableDims());
+                           sVectorType.getNumScalableDims(),
+                           sVectorType.getScalableDims());
   return nullptr;
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp
@@ -992,7 +992,13 @@
       return LLVMScalableVectorType::get(elementType, numElements);
     return LLVMFixedVectorType::get(elementType, numElements);
   }
-  return VectorType::get(numElements, elementType, (unsigned)isScalable);
+
+  // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as
+  // scalable/non-scalable.
+  SmallVector<bool> scalableDims(1, isScalable);
+
+  return VectorType::get(numElements, elementType,
+                         static_cast<unsigned>(isScalable), scalableDims);
 }
 
 Type mlir::LLVM::getVectorType(Type elementType,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -225,7 +225,9 @@
 
-  // TODO: Extend scalable vector type to support a bit map.
-  bool numScalableDims = !scalableVecDims.empty() && scalableVecDims.back();
-  return VectorType::get(vectorShape, elementType, numScalableDims);
+  // Derive the count from the flags that are actually passed so that both
+  // stay consistent, as `VectorType::verify` requires.
+  unsigned numScalableDims = llvm::count(scalableDims, true);
+  return VectorType::get(vectorShape, elementType, numScalableDims,
+                         scalableDims);
 }
 
 /// Masks an operation with the canonical vector mask if the operation needs
@@ -1227,7 +1229,8 @@
     if (firstMaxRankedType) {
       auto vecType = VectorType::get(firstMaxRankedType.getShape(),
                                      getElementTypeOrSelf(vecOperand.getType()),
-                                     firstMaxRankedType.getNumScalableDims());
+                                     firstMaxRankedType.getNumScalableDims(),
+                                     firstMaxRankedType.getScalableDims());
       vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType));
     } else {
       vecOperands.push_back(vecOperand);
@@ -1239,7 +1242,8 @@
     resultTypes.push_back(
         firstMaxRankedType
             ? VectorType::get(firstMaxRankedType.getShape(), resultType,
-                              firstMaxRankedType.getNumScalableDims())
+                              firstMaxRankedType.getNumScalableDims(),
+                              firstMaxRankedType.getScalableDims())
             : resultType);
   }
   // d. Build and return the new op.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -57,7 +57,8 @@
 /// Constructs vector type for element type.
 static VectorType vectorType(VL vl, Type etp) {
   unsigned numScalableDims = vl.enableVLAVectorization;
-  return VectorType::get(vl.vectorLength, etp, numScalableDims);
+  return VectorType::get(vl.vectorLength, etp, numScalableDims,
+                         vl.enableVLAVectorization);
 }
 
 /// Constructs vector type from a memref value.
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -345,9 +345,9 @@
 /// Returns the mask type expected by this operation.
 Type MultiDimReductionOp::getExpectedMaskType() {
   auto vecType = getSourceVectorType();
-  return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getNumScalableDims());
+  return VectorType::get(
+      vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
+      vecType.getNumScalableDims(), vecType.getScalableDims());
 }
 
 namespace {
@@ -484,9 +484,9 @@
 /// Returns the mask type expected by this operation.
 Type ReductionOp::getExpectedMaskType() {
   auto vecType = getSourceVectorType();
-  return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getNumScalableDims());
+  return VectorType::get(
+      vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
+      vecType.getNumScalableDims(), vecType.getScalableDims());
 }
 
 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
@@ -2788,16 +2788,22 @@
     return parser.emitError(parser.getNameLoc(),
                             "expected vector type for operand #1");
 
-  unsigned numScalableDims = vLHS.getNumScalableDims();
   VectorType resType;
   if (vRHS) {
-    numScalableDims += vRHS.getNumScalableDims();
+    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0],
+                                      vRHS.getScalableDims()[0]};
+    auto numScalableDims =
+        count_if(scalableDimsRes, [](bool isScalable) { return isScalable; });
     resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
-                              vLHS.getElementType(), numScalableDims);
+                              vLHS.getElementType(), numScalableDims,
+                              scalableDimsRes);
   } else {
     // Scalar RHS operand
+    SmallVector<bool> scalableDimsRes{vLHS.getScalableDims()[0]};
+    auto numScalableDims =
+        count_if(scalableDimsRes, [](bool isScalable) { return isScalable; });
     resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(),
-                              numScalableDims);
+                              numScalableDims, scalableDimsRes);
   }
 
   if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
@@ -2861,9 +2867,9 @@
 /// verification purposes. It requires the operation to be vectorized."
 Type OuterProductOp::getExpectedMaskType() {
   auto vecType = this->getResultVectorType();
-  return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getNumScalableDims());
+  return VectorType::get(
+      vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
+      vecType.getNumScalableDims(), vecType.getScalableDims());
 }
 
 //===----------------------------------------------------------------------===//
@@ -3516,12 +3522,16 @@
                                             AffineMap permMap) {
   auto i1Type = IntegerType::get(permMap.getContext(), 1);
   AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
-  // TODO: Extend the scalable vector type representation with a bit map.
-  assert((permMap.isMinorIdentity() || vecType.getNumScalableDims() == 0) &&
-         "Scalable vectors are not supported yet");
   assert(invPermMap && "Inversed permutation map couldn't be computed");
   SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
-  return VectorType::get(maskShape, i1Type, vecType.getNumScalableDims());
+
+  SmallVector<bool> scalableDims =
+      applyPermutationMap(invPermMap, vecType.getScalableDims());
+  // The mask need not have the same dims as `vecType`, so count the permuted
+  // flags rather than reusing `vecType.getNumScalableDims()`.
+  unsigned numScalableDims = llvm::count(scalableDims, true);
+
+  return VectorType::get(maskShape, i1Type, numScalableDims, scalableDims);
 }
 
 ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -4479,9 +4489,9 @@
 /// verification purposes. It requires the operation to be vectorized."
 Type GatherOp::getExpectedMaskType() {
   auto vecType = this->getIndexVectorType();
-  return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1),
-                         vecType.getNumScalableDims());
+  return VectorType::get(
+      vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1),
+      vecType.getNumScalableDims(), vecType.getScalableDims());
 }
 
 std::optional<SmallVector<int64_t, 4>> GatherOp::getShapeForUnroll() {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -2458,19 +2458,18 @@
         }
       })
       .Case<VectorType>([&](VectorType vectorTy) {
+        auto scalableDims = vectorTy.getScalableDims();
         os << "vector<";
         auto vShape = vectorTy.getShape();
         unsigned lastDim = vShape.size();
-        unsigned lastFixedDim = lastDim - vectorTy.getNumScalableDims();
         unsigned dimIdx = 0;
-        for (dimIdx = 0; dimIdx < lastFixedDim; dimIdx++)
-          os << vShape[dimIdx] << 'x';
-        if (vectorTy.isScalable()) {
-          os << '[';
-          unsigned secondToLastDim = lastDim - 1;
-          for (; dimIdx < secondToLastDim; dimIdx++)
-            os << vShape[dimIdx] << 'x';
-          os << vShape[dimIdx] << "]x";
+        for (dimIdx = 0; dimIdx < lastDim; dimIdx++) {
+          if (!scalableDims.empty() && scalableDims[dimIdx])
+            os << '[';
+          os << vShape[dimIdx];
+          if (!scalableDims.empty() && scalableDims[dimIdx])
+            os << ']';
+          os << 'x';
         }
         printType(vectorTy.getElementType());
         os << '>';
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -227,7 +227,8 @@
 LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
                                  ArrayRef<int64_t> shape, Type elementType,
-                                 unsigned numScalableDims) {
+                                 unsigned numScalableDims,
+                                 ArrayRef<bool> scalableDims) {
   if (!isValidElementType(elementType))
     return emitError()
            << "vector elements must be int/index/float type but got "
@@ -238,6 +239,20 @@
            << "vector types must have positive constant sizes but got "
            << shape;
 
+  if (numScalableDims > shape.size())
+    return emitError()
+           << "number of scalable dims cannot exceed the number of dims"
+           << " (" << numScalableDims << " vs " << shape.size() << ")";
+
+  if (scalableDims.size() != shape.size())
+    return emitError() << "number of dims must match, got "
+                       << scalableDims.size() << " and " << shape.size();
+
+  unsigned numScale = llvm::count(scalableDims, true);
+  if (numScale != numScalableDims)
+    return emitError() << "number of scalable dims must match, explicit: "
+                       << numScalableDims << ", and bools: " << numScale;
+
   return success();
 }
 
diff --git a/mlir/test/Dialect/Builtin/invalid.mlir b/mlir/test/Dialect/Builtin/invalid.mlir
--- a/mlir/test/Dialect/Builtin/invalid.mlir
+++ b/mlir/test/Dialect/Builtin/invalid.mlir
@@ -13,7 +13,10 @@
 // VectorType
 //===----------------------------------------------------------------------===//
 
-// expected-error@+1 {{missing ']' closing set of scalable dimensions}}
+// expected-error@+1 {{missing ']' closing scalable dimension}}
 func.func @scalable_vector_arg(%arg0: vector<[4xf32>) { }
 
 // -----
+
+// expected-error@+1 {{missing ']' closing scalable dimension}}
+func.func @scalable_vector_arg(%arg0: vector<[4x4]xf32>) { }
diff --git a/mlir/test/Dialect/Builtin/ops.mlir b/mlir/test/Dialect/Builtin/ops.mlir
--- a/mlir/test/Dialect/Builtin/ops.mlir
+++ b/mlir/test/Dialect/Builtin/ops.mlir
@@ -27,10 +27,13 @@
 %scalable_vector_1d = "foo.op"() : () -> vector<[4]xi32>
 
 // A 2D scalable vector
-%scalable_vector_2d = "foo.op"() : () -> vector<[2x2]xf64>
+%scalable_vector_2d = "foo.op"() : () -> vector<[2]x[2]xf64>
 
 // A 2D scalable vector with fixed-length dimensions
 %scalable_vector_2d_mixed = "foo.op"() : () -> vector<2x[4]xbf16>
 
 // A multi-dimensional vector with mixed scalable and fixed-length dimensions
-%scalable_vector_multi_mixed = "foo.op"() : () -> vector<2x2x[4x4]xi8>
+%scalable_vector_multi_mixed = "foo.op"() : () -> vector<2x2x[4]x[4]xi8>
+
+// A vector whose scalable dimension is not the trailing one
+%scalable_vector_non_trailing = "foo.op"() : () -> vector<[4]x2xf32>
diff --git a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
--- a/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
+++ b/mlir/test/Dialect/Vector/vector-scalable-outerproduct.mlir
@@ -7,7 +7,7 @@
   %1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
 
   %op = vector.outerproduct %0, %1 : vector<[4]xf32>, vector<[4]xf32>
-  vector.store %op, %src[%idx] : memref<?xf32>, vector<[4x4]xf32>
+  vector.store %op, %src[%idx] : memref<?xf32>, vector<[4]x[4]xf32>
 
   %op2 = vector.outerproduct %0, %cst : vector<[4]xf32>, f32
   vector.store %op2, %src[%idx] : memref<?xf32>, vector<[4]xf32>
@@ -28,9 +28,9 @@
 
 func.func @invalid_outerproduct1(%src : memref<?xf32>) {
   %idx = arith.constant 0 : index
-  %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4x4]xf32>
+  %0 = vector.load %src[%idx] : memref<?xf32>, vector<[4]x[4]xf32>
   %1 = vector.load %src[%idx] : memref<?xf32>, vector<[4]xf32>
 
-  // expected-error @+1 {{expected 1-d vector for operand #1}}
-  %op = vector.outerproduct %0, %1 : vector<[4x4]xf32>, vector<[4]xf32>
+  // expected-error @+1 {{'vector.outerproduct' op expected 1-d vector for operand #1}}
+  %op = vector.outerproduct %0, %1 : vector<[4]x[4]xf32>, vector<[4]xf32>
 }