diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -76,6 +76,9 @@
 
   /// Support method to enable LLVM-style type casting.
   static bool kindof(unsigned kind) { return kind == StandardTypes::Index; }
+
+  /// Storage bit width used for IndexType by internal compiler data structures.
+  static constexpr unsigned kInternalStorageBitWidth = 64;
 };
 
 /// Integer types can have arbitrary bitwidth up to a large fixed limit.
diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h
--- a/mlir/include/mlir/IR/Types.h
+++ b/mlir/include/mlir/IR/Types.h
@@ -169,6 +169,8 @@
   /// Return true of this is a signless integer or a float type.
   bool isSignlessIntOrFloat();
 
+  /// Return true if this is an integer (of any signedness) or an index type.
+  bool isIntOrIndex();
   /// Return true if this is an integer (of any signedness) or a float type.
   bool isIntOrFloat();
   /// Return true if this is an integer (of any signedness), index, or float
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1462,7 +1462,7 @@
   bool isSigned = !type.getElementType().isUnsignedInteger();
 
   // The function used to print elements of this attribute.
-  auto printEltFn = type.getElementType().isa<IntegerType>()
+  auto printEltFn = type.getElementType().isIntOrIndex()
                         ? printDenseIntElement
                         : printDenseFloatElement;
 
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -372,6 +372,17 @@
 // Elements Attributes
 //===----------------------------------------------------------------------===//
 
+/// Return the bit width which DenseElementsAttr should use for this type.
+inline size_t getDenseElementBitWidth(Type eltType) {
+  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
+  // with double semantics.
+  if (eltType.isBF16())
+    return 64;
+  if (eltType.isIndex())
+    return IndexType::kInternalStorageBitWidth;
+  return eltType.getIntOrFloatBitWidth();
+}
+
 /// An attribute representing a reference to a dense vector or tensor object.
 struct DenseElementsAttributeStorage : public AttributeStorage {
   struct KeyTy {
@@ -405,7 +416,7 @@
     // same. Boolean values are packed at the bit level, and even though a splat
     // is detected the rest of the bits in the first byte may differ from the
     // splat value.
-    if (key.type.getElementTypeBitWidth() == 1) {
+    if (key.type.getElementType().isInteger(1)) {
       if (key.isSplat != isSplat)
         return false;
       if (isSplat)
@@ -437,15 +448,10 @@
     assert(numElements != 1 && "splat of 1 element should already be detected");
 
     // Handle boolean values directly as they are packed to 1-bit.
-    size_t elementWidth = ty.getElementTypeBitWidth();
-    if (elementWidth == 1)
+    if (ty.getElementType().isInteger(1))
       return getKeyForBoolData(ty, data, numElements);
 
-    // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
-    // with double semantics.
-    if (ty.getElementType().isBF16())
-      elementWidth = 64;
-
+    size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
     // Non 1-bit dense elements are padded to 8-bits.
     size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
     assert(((data.size() / storageSize) == numElements) &&
@@ -517,7 +523,7 @@
       std::memcpy(rawData, data.data(), data.size());
 
       // If this is a boolean splat, make sure only the first bit is used.
-      if (key.isSplat && key.type.getElementTypeBitWidth() == 1)
+      if (key.isSplat && key.type.getElementType().isInteger(1))
         rawData[0] &= 1;
       copy = ArrayRef<char>(rawData, data.size());
     }
diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp
--- a/mlir/lib/IR/Attributes.cpp
+++ b/mlir/lib/IR/Attributes.cpp
@@ -275,7 +275,7 @@
 IntegerAttr IntegerAttr::get(Type type, int64_t value) {
   // This uses 64 bit APInts by default for index type.
   if (type.isIndex())
-    return get(type, APInt(64, value));
+    return get(type, APInt(IndexType::kInternalStorageBitWidth, value));
 
   auto intType = type.cast<IntegerType>();
   return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger()));
@@ -483,12 +483,6 @@
 // DenseElementAttr Utilities
 //===----------------------------------------------------------------------===//
 
-static size_t getDenseElementBitwidth(Type eltType) {
-  // FIXME(b/121118307): using 64 bits for BF16 because it is currently stored
-  // with double semantics.
-  return eltType.isBF16() ? 64 : eltType.getIntOrFloatBitWidth();
-}
-
 /// Get the bitwidth of a dense element type within the buffer.
 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
 static size_t getDenseElementStorageWidth(size_t origWidth) {
@@ -592,7 +586,7 @@
     DenseElementsAttr attr, size_t dataIndex)
     : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
           attr.getRawData().data(), attr.isSplat(), dataIndex),
-      bitWidth(getDenseElementBitwidth(attr.getType().getElementType())) {}
+      bitWidth(getDenseElementBitWidth(attr.getType().getElementType())) {}
 
 /// Accesses the raw APInt value at this iterator position.
 APInt DenseElementsAttr::IntElementIterator::operator*() const {
@@ -613,12 +607,12 @@
 
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<Attribute> values) {
-  assert(type.getElementType().isIntOrFloat() &&
-         "expected int or float element type");
+  assert(type.getElementType().isIntOrIndexOrFloat() &&
+         "expected int or index or float element type");
   assert(hasSameElementsOrSplat(type, values));
 
   auto eltType = type.getElementType();
-  size_t bitWidth = getDenseElementBitwidth(eltType);
+  size_t bitWidth = getDenseElementBitWidth(eltType);
   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
 
   // Compress the attribute values into a character buffer.
@@ -637,6 +631,7 @@
       intVal = values[i].cast<FloatAttr>().getValue().bitcastToAPInt();
       break;
     case StandardTypes::Integer:
+    case StandardTypes::Index:
       intVal = values[i].isa<BoolAttr>()
                    ? APInt(1, values[i].cast<BoolAttr>().getValue() ? 1 : 0)
                    : values[i].cast<IntegerAttr>().getValue();
@@ -667,7 +662,7 @@
 /// element type of 'type'.
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<APInt> values) {
-  assert(type.getElementType().isa<IntegerType>());
+  assert(type.getElementType().isIntOrIndex());
   return getRaw(type, values);
 }
 
@@ -701,7 +696,7 @@
                                             ArrayRef<APInt> values) {
   assert(hasSameElementsOrSplat(type, values));
 
-  size_t bitWidth = getDenseElementBitwidth(type.getElementType());
+  size_t bitWidth = getDenseElementBitWidth(type.getElementType());
   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
   std::vector<char> elementData(llvm::divideCeil(storageBitWidth, CHAR_BIT) *
                                 values.size());
@@ -727,14 +722,17 @@
 static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt,
                               bool isSigned) {
   // Make sure that the data element size is the same as the type element width.
-  if (getDenseElementBitwidth(type.getElementType()) !=
+  if (getDenseElementBitWidth(type.getElementType()) !=
       static_cast<size_t>(dataEltSize * CHAR_BIT))
     return false;
 
-  // Check that the element type is either float or integer.
+  // Check that the element type is either float or integer or index.
   if (!isInt)
     return type.getElementType().isa<FloatType>();
 
+  if (type.getElementType().isIndex())
+    return true;
+
   auto intType = type.getElementType().dyn_cast<IntegerType>();
   if (!intType)
     return false;
@@ -798,18 +796,15 @@
 /// this attribute must be of integer type.
 auto DenseElementsAttr::getIntValues() const
     -> llvm::iterator_range<IntElementIterator> {
-  assert(getType().getElementType().isa<IntegerType>() &&
-         "expected integer type");
+  assert(getType().getElementType().isIntOrIndex() && "expected integral type");
   return {raw_int_begin(), raw_int_end()};
 }
 auto DenseElementsAttr::int_value_begin() const -> IntElementIterator {
-  assert(getType().getElementType().isa<IntegerType>() &&
-         "expected integer type");
+  assert(getType().getElementType().isIntOrIndex() && "expected integral type");
   return raw_int_begin();
 }
 auto DenseElementsAttr::int_value_end() const -> IntElementIterator {
-  assert(getType().getElementType().isa<IntegerType>() &&
-         "expected integer type");
+  assert(getType().getElementType().isIntOrIndex() && "expected integral type");
   return raw_int_end();
 }
 
@@ -870,7 +865,7 @@
 static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
                                 Type newElementType,
                                 llvm::SmallVectorImpl<char> &data) {
-  size_t bitWidth = getDenseElementBitwidth(newElementType);
+  size_t bitWidth = getDenseElementBitWidth(newElementType);
   size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
 
   ShapedType newArrayType;
@@ -937,7 +932,7 @@
 /// Method for supporting type inquiry through isa, cast and dyn_cast.
 bool DenseIntElementsAttr::classof(Attribute attr) {
   return attr.isa<DenseElementsAttr>() &&
-         attr.getType().cast<ShapedType>().getElementType().isa<IntegerType>();
+         attr.getType().cast<ShapedType>().getElementType().isIntOrIndex();
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp
--- a/mlir/lib/IR/StandardTypes.cpp
+++ b/mlir/lib/IR/StandardTypes.cpp
@@ -83,6 +83,8 @@
   return isSignlessInteger() || isa<FloatType>();
 }
 
+bool Type::isIntOrIndex() { return isa<IntegerType>() || isIndex(); }
+
 bool Type::isIntOrFloat() { return isa<IntegerType>() || isa<FloatType>(); }
 
 bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -1797,7 +1797,8 @@
     return llvm::None;
 
   // Extend or truncate the bitwidth to the right size.
-  unsigned width = type.isIndex() ? 64 : type.getIntOrFloatBitWidth();
+  unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
+                                  : type.getIntOrFloatBitWidth();
   if (width > result.getBitWidth()) {
     result = result.zext(width);
   } else if (width < result.getBitWidth()) {
@@ -1968,8 +1969,7 @@
   }
 
   /// Build a Dense Integer attribute for the given type.
-  DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type,
-                               IntegerType eltTy);
+  DenseElementsAttr getIntAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
 
   /// Build a Dense Float attribute for the given type.
   DenseElementsAttr getFloatAttr(llvm::SMLoc loc, ShapedType type,
@@ -2044,14 +2044,17 @@
 
   // If the type is an integer, build a set of APInt values from the storage
   // with the correct bitwidth.
-  if (auto intTy = type.getElementType().dyn_cast<IntegerType>())
+  Type eltType = type.getElementType();
+  if (auto intTy = eltType.dyn_cast<IntegerType>())
     return getIntAttr(loc, type, intTy);
+  if (auto indexTy = eltType.dyn_cast<IndexType>())
+    return getIntAttr(loc, type, indexTy);
 
   // Otherwise, this must be a floating point type.
-  auto floatTy = type.getElementType().dyn_cast<FloatType>();
+  auto floatTy = eltType.dyn_cast<FloatType>();
   if (!floatTy) {
     p.emitError(loc) << "expected floating-point or integer element type, got "
-                     << type.getElementType();
+                     << eltType;
     return nullptr;
   }
   return getFloatAttr(loc, type, floatTy);
@@ -2059,8 +2062,7 @@
 
 /// Build a Dense Integer attribute for the given type.
 DenseElementsAttr TensorLiteralParser::getIntAttr(llvm::SMLoc loc,
-                                                  ShapedType type,
-                                                  IntegerType eltTy) {
+                                                  ShapedType type, Type eltTy) {
   std::vector<APInt> intElements;
   intElements.reserve(storage.size());
   auto isUintType = type.getElementType().isUnsignedInteger();
@@ -2085,11 +2087,12 @@
     assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
            "unexpected token type");
     if (token.isAny(Token::kw_true, Token::kw_false)) {
-      if (!eltTy.isInteger(1))
+      if (!eltTy.isInteger(1)) {
         p.emitError(tokenLoc)
             << "expected i1 type for 'true' or 'false' values";
-      APInt apInt(eltTy.getWidth(), token.is(Token::kw_true),
-                  /*isSigned=*/false);
+        return nullptr;
+      }
+      APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
       intElements.push_back(apInt);
       continue;
     }
diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir
--- a/mlir/test/IR/parser.mlir
+++ b/mlir/test/IR/parser.mlir
@@ -697,6 +697,11 @@
   "intscalar"(){bar = dense<1> : tensor<i32>} : () -> ()
 // CHECK: "floatscalar"() {bar = dense<5.000000e+00> : tensor<f32>} : () -> ()
   "floatscalar"(){bar = dense<5.0> : tensor<f32>} : () -> ()
+
+// CHECK: "index"() {bar = dense<1> : tensor<index>} : () -> ()
+  "index"(){bar = dense<1> : tensor<index>} : () -> ()
+// CHECK: "index"() {bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()
+  "index"(){bar = dense<[1, 2]> : tensor<2xindex>} : () -> ()
   return
 }
 