# Review notes on this patch (leading text; ignored by `git apply`/`patch`).
# NOTE: this copy has lost its line breaks and its `<...>` template
# arguments (e.g. `ArrayRef shape`), so it cannot be applied as-is.
#
# 1. BuiltinTypes.h, VectorType::Builder(shape, elementType,
#    numScalableDims = 0, scalableDims = {}): `numScalableDims` is still
#    accepted but is never used now that the member has been removed.
#    Callers that want to pass `scalableDims` must still supply a dummy
#    count first. Consider dropping the parameter (source-breaking for any
#    caller that still passes it; confirm none remain).
# 2. Same constructor: in the empty case, `scalableDims = SmallVector(...)`
#    assigns to the *parameter* (it shadows the member), binding it to a
#    temporary that dies at the end of the statement. The member stays
#    empty. operator Type() appears to cope, because the VectorType::get
#    builder resizes an empty list with `false`. dropDim(), however, then
#    copies that empty list into storageScalableDims. Presumably the erase
#    at `pos` that follows (not visible in this hunk) then indexes an empty
#    vector. TODO confirm.
# 3. Builder::setShape(): in the empty case the *member* `scalableDims`
#    (an ArrayRef) is bound to a temporary SmallVector and dangles once the
#    statement ends. It is also sized from `shape.size()`, not
#    `newShape.size()`; whether `shape` was already updated at that point is
#    not visible in this hunk. Suggest materialising the default into the
#    owning `storageScalableDims` and pointing `scalableDims` at it.
# 4. TypeConverter.cpp (vector conversion): the new assert requires
#    `isScalable() == allDimsScalable()`, i.e. a scalable vector must be
#    scalable in *every* dim. The loop below, however, wraps the trailing
#    1-D vector in LLVMArrayType using the fixed sizes `shape[i]`. As a
#    result, a scalable leading dim (e.g. vector<[2]x[4]xf32>) passes the
#    assert and silently loses its scalability. Meanwhile vector<2x[4]xf32>,
#    which should map to an array of scalable vectors, trips the assert.
#    The intended check looks like "no leading dim is scalable"; confirm.
# 5. Merger.cpp: `VectorType::get(vtp.getNumElements(), dtp,
#    vtp.getScalableDims())` builds a rank-1 shape but passes every scalable
#    flag of `vtp`. VectorType::verify rejects
#    `scalableDims.size() != shape.size()`, so this is only valid when `vtp`
#    is 1-D (presumably always true for sparse vectorization; confirm).
# 6. BuiltinDialectBytecode.td: removing `VarInt:$numScalableDims` from
#    VectorTypeWithScalableDims changes the serialized encoding. Bytecode
#    written before this patch would be misread unless versioning is
#    handled elsewhere; confirm, or guard/bump the format.
# 7. VectorOps.cpp (mask-shape helper asserting on lhs/rhs types): the
#    comment "TODO: Extend the scalable vector type representation with a
#    bit map." is stale; after this patch the type stores one bool per dim.
# 8. BuiltinTypes.td: the new `allDimsScalable()` has no doc comment
#    (unlike `isScalable()`). Note that it deliberately returns false for
#    0-d vectors.
diff --git a/mlir/include/mlir/IR/BuiltinDialectBytecode.td b/mlir/include/mlir/IR/BuiltinDialectBytecode.td --- a/mlir/include/mlir/IR/BuiltinDialectBytecode.td +++ b/mlir/include/mlir/IR/BuiltinDialectBytecode.td @@ -275,18 +275,17 @@ Array:$shape, Type:$elementType )> { - let printerPredicate = "!$_val.getNumScalableDims()"; + let printerPredicate = "!$_val.isScalable()"; } def VectorTypeWithScalableDims : DialectType<(type Array:$scalableDims, - VarInt:$numScalableDims, Array:$shape, Type:$elementType )> { - let printerPredicate = "$_val.getNumScalableDims()"; + let printerPredicate = "$_val.isScalable()"; // Note: order of serialization does not match order of builder. - let cBuilder = "get<$_resultType>(context, shape, elementType, numScalableDims, scalableDims)"; + let cBuilder = "get<$_resultType>(context, shape, elementType, scalableDims)"; } } diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h --- a/mlir/include/mlir/IR/BuiltinTypes.h +++ b/mlir/include/mlir/IR/BuiltinTypes.h @@ -306,23 +306,20 @@ /// Build from another VectorType. explicit Builder(VectorType other) : shape(other.getShape()), elementType(other.getElementType()), - numScalableDims(other.getNumScalableDims()), scalableDims(other.getScalableDims()) {} /// Build from scratch. Builder(ArrayRef shape, Type elementType, unsigned numScalableDims = 0, ArrayRef scalableDims = {}) - : shape(shape), elementType(elementType), - numScalableDims(numScalableDims) { + : shape(shape), elementType(elementType) { if (scalableDims.empty()) scalableDims = SmallVector(shape.size(), false); else this->scalableDims = scalableDims; } - Builder &setShape(ArrayRef newShape, unsigned newNumScalableDims = 0, + Builder &setShape(ArrayRef newShape, ArrayRef newIsScalableDim = {}) { - numScalableDims = newNumScalableDims; if (newIsScalableDim.empty()) scalableDims = SmallVector(shape.size(), false); else @@ -340,8 +337,6 @@ /// Erase a dim from shape @pos.
Builder &dropDim(unsigned pos) { assert(pos < shape.size() && "overflow"); - if (pos >= shape.size() - numScalableDims) - numScalableDims--; if (storage.empty()) storage.append(shape.begin(), shape.end()); if (storageScalableDims.empty()) @@ -360,7 +355,7 @@ operator Type() { if (shape.empty()) return elementType; - return VectorType::get(shape, elementType, numScalableDims, scalableDims); + return VectorType::get(shape, elementType, scalableDims); } private: @@ -368,7 +363,6 @@ // Owning shape data for copy-on-write operations. SmallVector storage; Type elementType; - unsigned numScalableDims; ArrayRef scalableDims; // Owning scalableDims data for copy-on-write operations. SmallVector storageScalableDims; diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -1066,13 +1066,11 @@ let parameters = (ins ArrayRefParameter<"int64_t">:$shape, "Type":$elementType, - "unsigned":$numScalableDims, ArrayRefParameter<"bool">:$scalableDims ); let builders = [ TypeBuilderWithInferredContext<(ins "ArrayRef":$shape, "Type":$elementType, - CArg<"unsigned", "0">:$numScalableDims, CArg<"ArrayRef", "{}">:$scalableDims ), [{ // While `scalableDims` is optional, its default value should be @@ -1082,8 +1080,7 @@ isScalableVec.resize(shape.size(), false); scalableDims = isScalableVec; } - return $_get(elementType.getContext(), shape, elementType, - numScalableDims, scalableDims); + return $_get(elementType.getContext(), shape, elementType, scalableDims); }]> ]; let extraClassDeclaration = [{ @@ -1100,7 +1097,13 @@ /// Returns true if the vector contains scalable dimensions. bool isScalable() const { - return getNumScalableDims() > 0; + return llvm::is_contained(getScalableDims(), true); + } + bool allDimsScalable() const { + // Treat 0-d vectors as fixed size.
+ if (getRank() == 0) + return false; + return !llvm::is_contained(getScalableDims(), false); } /// Get or create a new VectorType with the same shape as `this` and an diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -211,7 +211,6 @@ /// Parse a vector type. VectorType parseVectorType(); ParseResult parseVectorDimensionList(SmallVectorImpl &dimensions, - unsigned &numScalableDims, SmallVectorImpl &scalableDims); ParseResult parseDimensionListRanked(SmallVectorImpl &dimensions, bool allowDynamic = true, diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -441,8 +441,7 @@ SmallVector dimensions; SmallVector scalableDims; - unsigned numScalableDims; - if (parseVectorDimensionList(dimensions, numScalableDims, scalableDims)) + if (parseVectorDimensionList(dimensions, scalableDims)) return nullptr; if (any_of(dimensions, [](int64_t i) { return i <= 0; })) return emitError(getToken().getLoc(), @@ -459,16 +458,13 @@ return emitError(typeLoc, "vector elements must be int/index/float type"), nullptr; - return VectorType::get(dimensions, elementType, numScalableDims, - scalableDims); + return VectorType::get(dimensions, elementType, scalableDims); } /// Parse a dimension list in a vector type. This populates the dimension list. /// For i-th dimension, `scalableDims[i]` contains either: /// * `false` for a non-scalable dimension (e.g. `4`), /// * `true` for a scalable dimension (e.g. `[4]`). -/// This method also returns the number of scalable dimensions in -/// `numScalableDims`. /// /// vector-dim-list := (static-dim-list `x`)?
/// static-dim-list ::= static-dim (`x` static-dim)* @@ -476,9 +472,7 @@ /// ParseResult Parser::parseVectorDimensionList(SmallVectorImpl &dimensions, - unsigned &numScalableDims, SmallVectorImpl &scalableDims) { - numScalableDims = 0; // If there is a set of fixed-length dimensions, consume it while (getToken().is(Token::integer) || getToken().is(Token::l_square)) { int64_t value; @@ -489,7 +483,6 @@ if (scalable) { if (!consumeIf(Token::r_square)) return emitWrongTokenError("missing ']' closing scalable dimension"); - numScalableDims++; } scalableDims.push_back(scalable); // Make sure we have an 'x' or something like 'xbf32'. diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -463,11 +463,12 @@ return {}; if (type.getShape().empty()) return VectorType::get({1}, elementType); - Type vectorType = - VectorType::get(type.getShape().back(), elementType, - type.getNumScalableDims(), type.getScalableDims().back()); + Type vectorType = VectorType::get(type.getShape().back(), elementType, + type.getScalableDims().back()); assert(LLVM::isCompatibleVectorType(vectorType) && "expected vector type compatible with the LLVM dialect"); + assert((type.isScalable() == type.allDimsScalable()) && + "expected scalable vector with all dims scalable"); auto shape = type.getShape(); for (int i = shape.size() - 2; i >= 0; --i) vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -31,21 +31,15 @@ // Helper to reduce vector type by one rank at front.
static VectorType reducedVectorTypeFront(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); - unsigned numScalableDims = tp.getNumScalableDims(); - if (tp.getShape().size() == numScalableDims) - --numScalableDims; return VectorType::get(tp.getShape().drop_front(), tp.getElementType(), - numScalableDims); + tp.getScalableDims().drop_front()); } // Helper to reduce vector type by *all* but one rank at back. static VectorType reducedVectorTypeBack(VectorType tp) { assert((tp.getRank() > 1) && "unlowerable vector type"); - unsigned numScalableDims = tp.getNumScalableDims(); - if (numScalableDims > 0) - --numScalableDims; return VectorType::get(tp.getShape().take_back(), tp.getElementType(), - numScalableDims); + tp.getScalableDims().take_back()); } // Helper that picks the proper sequence for inserting. diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -123,7 +123,6 @@ return UnrankedTensorType::get(i1Type); if (auto vectorType = llvm::dyn_cast(type)) return VectorType::get(vectorType.getShape(), i1Type, - vectorType.getNumScalableDims(), vectorType.getScalableDims()); return i1Type; } diff --git a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp --- a/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp +++ b/mlir/lib/Dialect/ArmSVE/IR/ArmSVEDialect.cpp @@ -30,7 +30,6 @@ auto i1Type = IntegerType::get(type.getContext(), 1); if (auto sVectorType = llvm::dyn_cast(type)) return VectorType::get(sVectorType.getShape(), i1Type, - sVectorType.getNumScalableDims(), sVectorType.getScalableDims()); return nullptr; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -995,10 +995,7 @@ // LLVM vectors are always 1-D, hence only 1 bool is required to
mark it as // scalable/non-scalable. - SmallVector scalableDims(1, isScalable); - - return VectorType::get(numElements, elementType, - static_cast(isScalable), scalableDims); + return VectorType::get(numElements, elementType, {isScalable}); } Type mlir::LLVM::getVectorType(Type elementType, @@ -1030,7 +1027,10 @@ "type"); if (useLLVM) return LLVMScalableVectorType::get(elementType, numElements); - return VectorType::get(numElements, elementType, /*numScalableDims=*/1); + + // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as + // scalable/non-scalable. + return VectorType::get(numElements, elementType, /*scalableDims=*/true); } llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -223,10 +223,7 @@ assert(areValidScalableVecDims(scalableDims) && "Permuted scalable vector dimensions are not supported"); - // TODO: Extend scalable vector type to support a bit map. - bool numScalableDims = !scalableVecDims.empty() && scalableVecDims.back(); - return VectorType::get(vectorShape, elementType, numScalableDims, - scalableDims); + return VectorType::get(vectorShape, elementType, scalableDims); } /// Masks an operation with the canonical vector mask if the operation needs @@ -1228,7 +1225,6 @@ if (firstMaxRankedType) { auto vecType = VectorType::get(firstMaxRankedType.getShape(), getElementTypeOrSelf(vecOperand.getType()), - firstMaxRankedType.getNumScalableDims(), firstMaxRankedType.getScalableDims()); vecOperands.push_back(broadcastIfNeeded(rewriter, vecOperand, vecType)); } else { @@ -1241,7 +1237,6 @@ resultTypes.push_back( firstMaxRankedType ?
VectorType::get(firstMaxRankedType.getShape(), resultType, - firstMaxRankedType.getNumScalableDims(), firstMaxRankedType.getScalableDims()) : resultType); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -56,9 +56,7 @@ /// Constructs vector type for element type. static VectorType vectorType(VL vl, Type etp) { - unsigned numScalableDims = vl.enableVLAVectorization; - return VectorType::get(vl.vectorLength, etp, numScalableDims, - vl.enableVLAVectorization); + return VectorType::get(vl.vectorLength, etp, vl.enableVLAVectorization); } /// Constructs vector type from a memref value. diff --git a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp --- a/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp +++ b/mlir/lib/Dialect/SparseTensor/Utils/Merger.cpp @@ -1176,7 +1176,7 @@ // Inspect source type. For vector types, apply the same // vectorization to the destination type. if (auto vtp = dyn_cast(src.getType())) - return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims()); + return VectorType::get(vtp.getNumElements(), dtp, vtp.getScalableDims()); return dtp; } diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -345,9 +345,9 @@ /// Returns the mask type expected by this operation.
Type MultiDimReductionOp::getExpectedMaskType() { auto vecType = getSourceVectorType(); - return VectorType::get( - vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1), - vecType.getNumScalableDims(), vecType.getScalableDims()); + return VectorType::get(vecType.getShape(), + IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getScalableDims()); } namespace { @@ -484,9 +484,9 @@ /// Returns the mask type expected by this operation. Type ReductionOp::getExpectedMaskType() { auto vecType = getSourceVectorType(); - return VectorType::get( - vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1), - vecType.getNumScalableDims(), vecType.getScalableDims()); + return VectorType::get(vecType.getShape(), + IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getScalableDims()); } Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op, @@ -929,8 +929,7 @@ assert(!ShapedType::isDynamicShape(maskShape) && "Mask shape couldn't be computed"); // TODO: Extend the scalable vector type representation with a bit map.
- assert(lhsType.getNumScalableDims() == 0 && - rhsType.getNumScalableDims() == 0 && + assert(!lhsType.isScalable() && !rhsType.isScalable() && "Scalable vectors are not supported yet"); return VectorType::get(maskShape, @@ -2792,18 +2791,13 @@ if (vRHS) { SmallVector scalableDimsRes{vLHS.getScalableDims()[0], vRHS.getScalableDims()[0]}; - auto numScalableDims = - count_if(scalableDimsRes, [](bool isScalable) { return isScalable; }); resType = VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)}, - vLHS.getElementType(), numScalableDims, - scalableDimsRes); + vLHS.getElementType(), scalableDimsRes); } else { // Scalar RHS operand SmallVector scalableDimsRes{vLHS.getScalableDims()[0]}; - auto numScalableDims = - count_if(scalableDimsRes, [](bool isScalable) { return isScalable; }); resType = VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType(), - numScalableDims, scalableDimsRes); + scalableDimsRes); } if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) { @@ -2867,9 +2861,9 @@ /// verification purposes. It requires the operation to be vectorized." Type OuterProductOp::getExpectedMaskType() { auto vecType = this->getResultVectorType(); - return VectorType::get( - vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1), - vecType.getNumScalableDims(), vecType.getScalableDims()); + return VectorType::get(vecType.getShape(), + IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getScalableDims()); } //===----------------------------------------------------------------------===// @@ -3528,8 +3522,7 @@ SmallVector scalableDims = applyPermutationMap(invPermMap, vecType.getScalableDims()); - return VectorType::get(maskShape, i1Type, vecType.getNumScalableDims(), - scalableDims); + return VectorType::get(maskShape, i1Type, scalableDims); } ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) { @@ -4487,9 +4480,9 @@ /// verification purposes. It requires the operation to be vectorized."
Type GatherOp::getExpectedMaskType() { auto vecType = this->getIndexVectorType(); - return VectorType::get( - vecType.getShape(), IntegerType::get(vecType.getContext(), /*width=*/1), - vecType.getNumScalableDims(), vecType.getScalableDims()); + return VectorType::get(vecType.getShape(), + IntegerType::get(vecType.getContext(), /*width=*/1), + vecType.getScalableDims()); } std::optional> GatherOp::getShapeForUnroll() { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -1024,7 +1024,7 @@ Value mask = rewriter.create( loc, VectorType::get(vtp.getShape(), rewriter.getI1Type(), - vtp.getNumScalableDims()), + vtp.getScalableDims()), b); if (xferOp.getMask()) { // Intersect the in-bounds with the mask specified as an op parameter. diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -227,7 +227,6 @@ LogicalResult VectorType::verify(function_ref emitError, ArrayRef shape, Type elementType, - unsigned numScalableDims, ArrayRef scalableDims) { if (!isValidElementType(elementType)) return emitError() @@ -239,21 +238,10 @@ << "vector types must have positive constant sizes but got " << shape; - if (numScalableDims > shape.size()) - return emitError() - << "number of scalable dims cannot exceed the number of dims" - << " (" << numScalableDims << " vs " << shape.size() << ")"; - if (scalableDims.size() != shape.size()) return emitError() << "number of dims must match, got " << scalableDims.size() << " and " << shape.size(); - auto numScale = - count_if(scalableDims, [](bool isScalable) { return isScalable; }); - if (numScale != numScalableDims) - return emitError() << "number of scalable dims must match, explicit: " - << numScalableDims << ", and bools:" << numScale; - return success();
} @@ -262,17 +250,17 @@ return VectorType(); if (auto et = llvm::dyn_cast(getElementType())) if (auto scaledEt = et.scaleElementBitwidth(scale)) - return VectorType::get(getShape(), scaledEt, getNumScalableDims()); + return VectorType::get(getShape(), scaledEt, getScalableDims()); if (auto et = llvm::dyn_cast(getElementType())) if (auto scaledEt = et.scaleElementBitwidth(scale)) - return VectorType::get(getShape(), scaledEt, getNumScalableDims()); + return VectorType::get(getShape(), scaledEt, getScalableDims()); return VectorType(); } VectorType VectorType::cloneWith(std::optional> shape, Type elementType) const { return VectorType::get(shape.value_or(getShape()), elementType, - getNumScalableDims()); + getScalableDims()); } //===----------------------------------------------------------------------===//