diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md --- a/mlir/docs/LangRef.md +++ b/mlir/docs/LangRef.md @@ -745,11 +745,16 @@ ``` // Sized integers like i1, i4, i8, i16, i32. -integer-type ::= `i` [1-9][0-9]* +signed-integer-type ::= `si` [1-9][0-9]* +unsigned-integer-type ::= `ui` [1-9][0-9]* +signless-integer-type ::= `i` [1-9][0-9]* +integer-type ::= signed-integer-type | + unsigned-integer-type | + signless-integer-type ``` -MLIR supports arbitrary precision integer types. Integer types are signless, but -have a designated width. +MLIR supports arbitrary precision integer types. Integer types have a designated +width and may have signedness semantics. **Rationale:** low precision integers (like `i2`, `i4` etc) are useful for low-precision inference chips, and arbitrary precision integers are useful for diff --git a/mlir/docs/Rationale.md b/mlir/docs/Rationale.md --- a/mlir/docs/Rationale.md +++ b/mlir/docs/Rationale.md @@ -244,13 +244,22 @@ The bit width is not defined for dialect-specific types at MLIR level. Dialects are free to define their own quantities for type sizes. -### Signless types +### Integer signedness semantics Integers in the builtin MLIR type system have a bitwidth (note that the `index` -type has a symbolic width equal to the machine word size), but they do not have -an intrinsic sign. This means that the "standard ops" operation set has things -like `addi` and `muli` which do two's complement arithmetic, but some other -operations get a sign, e.g. `divis` vs `diviu`. +type has a symbolic width equal to the machine word size), and they *may* +additionally have signedness semantics. The purpose is to satisfy the needs of +different dialects, which can model different levels of abstraction. Certain +abstractions, especially those closer to the source language, might want to differentiate +signedness in integer types, while others, especially those closer to machine +instructions, might want signless integers.
Instead of forcing each abstraction +to adopt the same integer modelling or develop its own in-house, integer +types provide this as an option to help code reuse and consistency. + +For the standard dialect, the choice is to have signless integer types. An +integer value does not have an intrinsic sign, and it is up to the specific op +to interpret it. For example, ops like `addi` and `muli` do two's complement +arithmetic, but some other operations get a sign, e.g. `divis` vs `diviu`. LLVM uses the [same design](http://llvm.org/docs/LangRef.html#integer-type), which was introduced in a revamp rolled out diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -685,7 +685,7 @@ const char *data = reinterpret_cast(values.data()); return getRawIntOrFloat( type, ArrayRef(data, values.size() * sizeof(T)), sizeof(T), - /*isInt=*/std::numeric_limits::is_integer); + std::numeric_limits::is_integer, std::numeric_limits::is_signed); } /// Constructs a dense integer elements attribute from a single element. @@ -853,7 +853,8 @@ std::numeric_limits::is_integer) || llvm::is_one_of::value>::type> llvm::iterator_range> getValues() const { - assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer)); + assert(isValidIntOrFloat(sizeof(T), std::numeric_limits::is_integer, + std::numeric_limits::is_signed)); auto rawData = getRawData().data(); bool splat = isSplat(); return {ElementIterator(rawData, splat, 0), @@ -964,12 +965,13 @@ /// invariants that the templatized 'get' method cannot. static DenseElementsAttr getRawIntOrFloat(ShapedType type, ArrayRef data, - int64_t dataEltSize, bool isInt); + int64_t dataEltSize, bool isInt, + bool isSigned); - /// Check the information for a c++ data type, check if this type is valid for + /// Check the information for a C++ data type, check if this type is valid for /// the current attribute.
This method is used to verify specific type /// invariants that the templatized 'getValues' method cannot. - bool isValidIntOrFloat(int64_t dataEltSize, bool isInt) const; + bool isValidIntOrFloat(int64_t dataEltSize, bool isInt, bool isSigned) const; }; /// An attribute that represents a reference to a dense float vector or tensor diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -70,6 +70,7 @@ IntegerType getI1Type(); IntegerType getIntegerType(unsigned width); + IntegerType getIntegerType(unsigned width, bool isSigned); FunctionType getFunctionType(ArrayRef inputs, ArrayRef results); TupleType getTupleType(ArrayRef elementTypes); NoneType getNoneType(); @@ -111,6 +112,10 @@ IntegerAttr getI32IntegerAttr(int32_t value); IntegerAttr getI64IntegerAttr(int64_t value); + /// Signed and unsigned integer attribute getters. + IntegerAttr getSI32IntegerAttr(int32_t value); + IntegerAttr getUI32IntegerAttr(uint32_t value); + DenseIntElementsAttr getI32VectorAttr(ArrayRef values); ArrayAttr getAffineMapArrayAttr(ArrayRef values); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -316,7 +316,7 @@ // Integer type of a specific width. class I - : Type, + : Type, width # "-bit integer">, BuildableType<"getIntegerType(" # width # ")"> { int bitwidth = width; diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -84,25 +84,57 @@ public: using Base::Base; + /// Signedness semantics. + enum SignednessSemantics { + Signless, ///< No signedness semantics + Signed, ///< Signed integer + Unsigned, ///< Unsigned integer + }; + /// Get or create a new IntegerType of the given width within the context.
- /// Assume the width is within the allowed range and assert on failures. - /// Use getChecked to handle failures gracefully. + /// The created IntegerType is signless (i.e., no signedness semantics). + /// Assume the width is within the allowed range and assert on failures. Use + /// getChecked to handle failures gracefully. static IntegerType get(unsigned width, MLIRContext *context); + /// Get or create a new IntegerType of the given width within the context. + /// The created IntegerType has signedness semantics as indicated via + /// `signedness`. Assume the width is within the allowed range and assert on + /// failures. Use getChecked to handle failures gracefully. + static IntegerType get(unsigned width, SignednessSemantics signedness, + MLIRContext *context); + /// Get or create a new IntegerType of the given width within the context, - /// defined at the given, potentially unknown, location. If the width is + /// defined at the given, potentially unknown, location. The created + /// IntegerType is signless (i.e., no signedness semantics). If the width is /// outside the allowed range, emit errors and return a null type. - static IntegerType getChecked(unsigned width, MLIRContext *context, + static IntegerType getChecked(unsigned width, Location location); + + /// Get or create a new IntegerType of the given width within the context, + /// defined at the given, potentially unknown, location. The created + /// IntegerType has signedness semantics as indicated via `signedness`. If the + /// width is outside the allowed range, emit errors and return a null type. + static IntegerType getChecked(unsigned width, SignednessSemantics signedness, Location location); /// Verify the construction of an integer type.
- static LogicalResult verifyConstructionInvariants(Optional loc, - MLIRContext *context, - unsigned width); + static LogicalResult + verifyConstructionInvariants(Optional loc, MLIRContext *context, + unsigned width, SignednessSemantics signedness); /// Return the bitwidth of this integer type. unsigned getWidth() const; + /// Return the signedness semantics of this integer type. + SignednessSemantics getSignedness() const; + + /// Return true if this is a signless integer type. + bool isSignless() const { return getSignedness() == Signless; } + /// Return true if this is a signed integer type. + bool isSigned() const { return getSignedness() == Signed; } + /// Return true if this is an unsigned integer type. + bool isUnsigned() const { return getSignedness() == Unsigned; } + /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool kindof(unsigned kind) { return kind == StandardTypes::Integer; } diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -148,6 +148,16 @@ /// Return true if this is an integer type with the specified width. bool isInteger(unsigned width); + /// Return true if this is a signless integer type (with the specified width). + bool isSignlessInteger(); + bool isSignlessInteger(unsigned width); + /// Return true if this is a signed integer type (with the specified width). + bool isSignedInteger(); + bool isSignedInteger(unsigned width); + /// Return true if this is an unsigned integer type (with the specified + /// width). + bool isUnsignedInteger(); + bool isUnsignedInteger(unsigned width); /// Return the bit width of an integer or a float type, assert failure on /// other types.
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1420,6 +1420,10 @@ case StandardTypes::Integer: { auto integer = type.cast(); + if (integer.isSigned()) + os << 's'; + else if (integer.isUnsigned()) + os << 'u'; os << 'i' << integer.getWidth(); return; } diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -274,12 +274,16 @@ return get(type, APInt(64, value)); auto intType = type.cast(); - return get(type, APInt(intType.getWidth(), value)); + return get(type, APInt(intType.getWidth(), value, intType.isSignedInteger())); } APInt IntegerAttr::getValue() const { return getImpl()->getValue(); } -int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); } +int64_t IntegerAttr::getInt() const { + assert(!getImpl()->getType().isUnsignedInteger() && + "integer signedness unsupported"); + return getValue().getSExtValue(); +} //===----------------------------------------------------------------------===// // IntegerSetAttr @@ -670,19 +674,28 @@ data, isSplat); } -/// Check the information for a c++ data type, check if this type is valid for +/// Check the information for a C++ data type, check if this type is valid for /// the current attribute. This method is used to verify specific type /// invariants that the templatized 'getValues' method cannot. -static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, - bool isInt) { +static bool isValidIntOrFloat(ShapedType type, int64_t dataEltSize, bool isInt, + bool isSigned) { // Make sure that the data element size is the same as the type element width. if (getDenseElementBitwidth(type.getElementType()) != static_cast(dataEltSize * CHAR_BIT)) return false; - // Check that the element type is valid. - return isInt ? 
type.getElementType().isa() - : type.getElementType().isa(); + // Check that the element type is either float or integer. + if (!isInt) + return type.getElementType().isa(); + + auto intType = type.getElementType().dyn_cast(); + if (!intType) + return false; + + // Make sure signedness semantics is consistent. + if (intType.isSignless()) + return true; + return intType.isSigned() ? isSigned : !isSigned; } /// Overload of the 'getRaw' method that asserts that the given type is of @@ -691,8 +704,9 @@ DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef data, int64_t dataEltSize, - bool isInt) { - assert(::isValidIntOrFloat(type, dataEltSize, isInt)); + bool isInt, + bool isSigned) { + assert(::isValidIntOrFloat(type, dataEltSize, isInt, isSigned)); int64_t numElements = data.size() / dataEltSize; assert(numElements == 1 || numElements == type.getNumElements()); @@ -701,9 +715,9 @@ /// A method used to verify specific type invariants that the templatized 'get' /// method cannot. -bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, - bool isInt) const { - return ::isValidIntOrFloat(getType(), dataEltSize, isInt); +bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, + bool isSigned) const { + return ::isValidIntOrFloat(getType(), dataEltSize, isInt, isSigned); } /// Return the raw storage data held by this attribute. diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -59,6 +59,11 @@ return IntegerType::get(width, context); } +IntegerType Builder::getIntegerType(unsigned width, bool isSigned) { + return IntegerType::get( + width, isSigned ? 
IntegerType::Signed : IntegerType::Unsigned, context); +} + FunctionType Builder::getFunctionType(ArrayRef inputs, ArrayRef results) { return FunctionType::get(inputs, results, context); @@ -104,6 +109,16 @@ return IntegerAttr::get(getIntegerType(32), APInt(32, value)); } +IntegerAttr Builder::getSI32IntegerAttr(int32_t value) { + return IntegerAttr::get(getIntegerType(32, /*isSigned=*/true), + APInt(32, value, /*isSigned=*/true)); +} + +IntegerAttr Builder::getUI32IntegerAttr(uint32_t value) { + return IntegerAttr::get(getIntegerType(32, /*isSigned=*/false), + APInt(32, value, /*isSigned=*/false)); +} + IntegerAttr Builder::getI16IntegerAttr(int16_t value) { return IntegerAttr::get(getIntegerType(16), APInt(16, value)); } @@ -115,7 +130,8 @@ IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) { if (type.isIndex()) return IntegerAttr::get(type, APInt(64, value)); - return IntegerAttr::get(type, APInt(type.getIntOrFloatBitWidth(), value)); + return IntegerAttr::get( + type, APInt(type.getIntOrFloatBitWidth(), value, type.isSignedInteger())); } IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) { diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -245,16 +245,18 @@ /// Index Type. impl->indexTy = TypeUniquer::get(this, StandardTypes::Index); /// Integer Types. 
- impl->int1Ty = TypeUniquer::get(this, StandardTypes::Integer, 1); - impl->int8Ty = TypeUniquer::get(this, StandardTypes::Integer, 8); - impl->int16Ty = - TypeUniquer::get(this, StandardTypes::Integer, 16); - impl->int32Ty = - TypeUniquer::get(this, StandardTypes::Integer, 32); - impl->int64Ty = - TypeUniquer::get(this, StandardTypes::Integer, 64); - impl->int128Ty = - TypeUniquer::get(this, StandardTypes::Integer, 128); + impl->int1Ty = TypeUniquer::get(this, StandardTypes::Integer, 1, + IntegerType::Signless); + impl->int8Ty = TypeUniquer::get(this, StandardTypes::Integer, 8, + IntegerType::Signless); + impl->int16Ty = TypeUniquer::get(this, StandardTypes::Integer, + 16, IntegerType::Signless); + impl->int32Ty = TypeUniquer::get(this, StandardTypes::Integer, + 32, IntegerType::Signless); + impl->int64Ty = TypeUniquer::get(this, StandardTypes::Integer, + 64, IntegerType::Signless); + impl->int128Ty = TypeUniquer::get(this, StandardTypes::Integer, + 128, IntegerType::Signless); /// None Type. impl->noneType = TypeUniquer::get(this, StandardTypes::None); @@ -489,7 +491,13 @@ /// Return an existing integer type instance if one is cached within the /// context. 
-static IntegerType getCachedIntegerType(unsigned width, MLIRContext *context) { +static IntegerType +getCachedIntegerType(unsigned width, + IntegerType::SignednessSemantics signedness, + MLIRContext *context) { + if (signedness != IntegerType::Signless) + return IntegerType(); + switch (width) { case 1: return context->getImpl().int1Ty; @@ -509,16 +517,29 @@ } IntegerType IntegerType::get(unsigned width, MLIRContext *context) { - if (auto cached = getCachedIntegerType(width, context)) + return get(width, IntegerType::Signless, context); +} + +IntegerType IntegerType::get(unsigned width, + IntegerType::SignednessSemantics signedness, + MLIRContext *context) { + if (auto cached = getCachedIntegerType(width, signedness, context)) return cached; - return Base::get(context, StandardTypes::Integer, width); + return Base::get(context, StandardTypes::Integer, width, signedness); +} + +IntegerType IntegerType::getChecked(unsigned width, Location location) { + return getChecked(width, IntegerType::Signless, location); } -IntegerType IntegerType::getChecked(unsigned width, MLIRContext *context, +IntegerType IntegerType::getChecked(unsigned width, + SignednessSemantics signedness, Location location) { - if (auto cached = getCachedIntegerType(width, context)) + if (auto cached = + getCachedIntegerType(width, signedness, location->getContext())) return cached; - return Base::getChecked(location, context, StandardTypes::Integer, width); + return Base::getChecked(location, location->getContext(), + StandardTypes::Integer, width, signedness); } /// Get an instance of the NoneType. 
diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -36,6 +36,42 @@ return false; } +bool Type::isSignlessInteger() { + if (auto intTy = dyn_cast()) + return intTy.isSignless(); + return false; +} + +bool Type::isSignlessInteger(unsigned width) { + if (auto intTy = dyn_cast()) + return intTy.isSignless() && intTy.getWidth() == width; + return false; +} + +bool Type::isSignedInteger() { + if (auto intTy = dyn_cast()) + return intTy.isSigned(); + return false; +} + +bool Type::isSignedInteger(unsigned width) { + if (auto intTy = dyn_cast()) + return intTy.isSigned() && intTy.getWidth() == width; + return false; +} + +bool Type::isUnsignedInteger() { + if (auto intTy = dyn_cast()) + return intTy.isUnsigned(); + return false; +} + +bool Type::isUnsignedInteger(unsigned width) { + if (auto intTy = dyn_cast()) + return intTy.isUnsigned() && intTy.getWidth() == width; + return false; +} + bool Type::isIntOrIndex() { return isa() || isa(); } bool Type::isIntOrIndexOrFloat() { @@ -52,17 +88,24 @@ constexpr unsigned IntegerType::kMaxWidth; /// Verify the construction of an integer type. 
-LogicalResult IntegerType::verifyConstructionInvariants(Optional loc, - MLIRContext *context, - unsigned width) { +LogicalResult +IntegerType::verifyConstructionInvariants(Optional loc, + MLIRContext *context, unsigned width, + SignednessSemantics signedness) { if (width > IntegerType::kMaxWidth) { return emitOptionalError(loc, "integer bitwidth is limited to ", IntegerType::kMaxWidth, " bits"); } + if (width == 1 && signedness != IntegerType::Signless) + return emitOptionalError(loc, "cannot have signedness semantics for i1"); return success(); } -unsigned IntegerType::getWidth() const { return getImpl()->width; } +unsigned IntegerType::getWidth() const { return getImpl()->getWidth(); } + +IntegerType::SignednessSemantics IntegerType::getSignedness() const { + return getImpl()->getSignedness(); +} //===----------------------------------------------------------------------===// // Float Type diff --git a/mlir/lib/IR/TypeDetail.h b/mlir/lib/IR/TypeDetail.h --- a/mlir/lib/IR/TypeDetail.h +++ b/mlir/lib/IR/TypeDetail.h @@ -15,8 +15,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/Identifier.h" #include "mlir/IR/MLIRContext.h" -#include "mlir/IR/TypeSupport.h" -#include "mlir/IR/Types.h" +#include "mlir/IR/StandardTypes.h" #include "llvm/Support/TrailingObjects.h" namespace mlir { @@ -52,19 +51,44 @@ /// Integer Type Storage and Uniquing. struct IntegerTypeStorage : public TypeStorage { - IntegerTypeStorage(unsigned width) : width(width) {} + IntegerTypeStorage(unsigned width, + IntegerType::SignednessSemantics signedness) + : TypeStorage(packWithAndSignedness(width, signedness)) {} /// The hash key used for uniquing. 
- using KeyTy = unsigned; - bool operator==(const KeyTy &key) const { return key == width; } + using KeyTy = std::pair; + + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(packWithAndSignedness(key.first, key.second)); + } + + bool operator==(const KeyTy &key) const { + return getSubclassData() == packWithAndSignedness(key.first, key.second); + } static IntegerTypeStorage *construct(TypeStorageAllocator &allocator, - KeyTy bitwidth) { + KeyTy key) { return new (allocator.allocate()) - IntegerTypeStorage(bitwidth); + IntegerTypeStorage(key.first, key.second); + } + + static const unsigned numWidthBits = sizeof(unsigned) * CHAR_BIT - 2; + + /// Packs the given `signedness` into the upper two bits in `width`. + static unsigned + packWithAndSignedness(unsigned width, + IntegerType::SignednessSemantics signedness) { + return width | (static_cast(signedness) << numWidthBits); + } + + unsigned getWidth() { + return getSubclassData() & ((1u << numWidthBits) - 1u); } - unsigned width; + IntegerType::SignednessSemantics getSignedness() { + return static_cast(getSubclassData() >> + numWidthBits); + } }; /// Function Type Storage and Uniquing. @@ -321,4 +345,5 @@ } // namespace detail } // namespace mlir + #endif // TYPEDETAIL_H_ diff --git a/mlir/lib/Parser/Lexer.cpp b/mlir/lib/Parser/Lexer.cpp --- a/mlir/lib/Parser/Lexer.cpp +++ b/mlir/lib/Parser/Lexer.cpp @@ -187,7 +187,7 @@ /// Lex a bare identifier or keyword that starts with a letter. /// /// bare-id ::= (letter|[_]) (letter|digit|[_$.])* -/// integer-type ::= `i[1-9][0-9]*` +/// integer-type ::= `[su]?i[1-9][0-9]*` /// Token Lexer::lexBareIdentifierOrKeyword(const char *tokStart) { // Match the rest of the identifier regex: [0-9a-zA-Z_.$]* @@ -198,14 +198,17 @@ // Check to see if this identifier is a keyword. StringRef spelling(tokStart, curPtr - tokStart); - // Check for i123. 
- if (tokStart[0] == 'i') { - bool allDigits = true; - for (auto c : spelling.drop_front()) - allDigits &= isdigit(c) != 0; - if (allDigits && spelling.size() != 1) - return Token(Token::inttype, spelling); - } + auto isAllDigit = [](StringRef str) { + return llvm::all_of(str, [](char c) { return llvm::isDigit(c); }); + }; + + // Check for i123, si456, ui789. + if ((spelling.size() > 1 && tokStart[0] == 'i' && + isAllDigit(spelling.drop_front())) || + ((spelling.size() > 2 && tokStart[1] == 'i' && + (tokStart[0] == 's' || tokStart[0] == 'u')) && + isAllDigit(spelling.drop_front(2)))) + return Token(Token::inttype, spelling); Token::Kind kind = llvm::StringSwitch(spelling) #define TOK_KEYWORD(SPELLING) .Case(#SPELLING, Token::kw_##SPELLING) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1197,9 +1197,14 @@ auto width = getToken().getIntTypeBitwidth(); if (!width.hasValue()) return (emitError("invalid integer width"), nullptr); + + IntegerType::SignednessSemantics signSemantics = IntegerType::Signless; + if (Optional signedness = getToken().getIntTypeSignedness()) + signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned; + auto loc = getEncodedSourceLocation(getToken().getLoc()); consumeToken(Token::inttype); - return IntegerType::getChecked(width.getValue(), builder.getContext(), loc); + return IntegerType::getChecked(width.getValue(), signSemantics, loc); } // float-type @@ -1767,6 +1772,12 @@ return emitError(loc, "integer literal not valid for specified type"), nullptr; + if (isNegative && type.isUnsignedInteger()) { + emitError(loc, + "negative integer literal not valid for unsigned integer type"); + return nullptr; + } + // Parse the integer literal. int width = type.isIndex() ? 
64 : type.getIntOrFloatBitWidth(); APInt apInt(width, *val, isNegative); @@ -1928,13 +1939,22 @@ IntegerType eltTy) { std::vector intElements; intElements.reserve(storage.size()); + auto isUintType = type.getElementType().isUnsignedInteger(); for (const auto &signAndToken : storage) { bool isNegative = signAndToken.first; const Token &token = signAndToken.second; + auto tokenLoc = token.getLoc(); + + if (isNegative && isUintType) { + p.emitError(tokenLoc) + << "expected unsigned integer elements, but parsed negative value"; + return nullptr; + } // Check to see if floating point values were parsed. if (token.is(Token::floatliteral)) { - p.emitError() << "expected integer elements, but parsed floating-point"; + p.emitError(tokenLoc) + << "expected integer elements, but parsed floating-point"; return nullptr; } @@ -1942,7 +1962,8 @@ "unexpected token type"); if (token.isAny(Token::kw_true, Token::kw_false)) { if (!eltTy.isInteger(1)) - p.emitError() << "expected i1 type for 'true' or 'false' values"; + p.emitError(tokenLoc) + << "expected i1 type for 'true' or 'false' values"; APInt apInt(eltTy.getWidth(), token.is(Token::kw_true), /*isSigned=*/false); intElements.push_back(apInt); @@ -1953,13 +1974,13 @@ auto val = token.getUInt64IntegerValue(); if (!val.hasValue() || (isNegative ? (int64_t)-val.getValue() >= 0 : (int64_t)val.getValue() < 0)) { - p.emitError(token.getLoc(), - "integer constant out of range for attribute"); + p.emitError(tokenLoc, "integer constant out of range for attribute"); return nullptr; } APInt apInt(eltTy.getWidth(), val.getValue(), isNegative); if (apInt != val.getValue()) - return (p.emitError("integer constant out of range for type"), nullptr); + return (p.emitError(tokenLoc, "integer constant out of range for type"), + nullptr); intElements.push_back(isNegative ? 
-apInt : apInt); } diff --git a/mlir/lib/Parser/Token.h b/mlir/lib/Parser/Token.h --- a/mlir/lib/Parser/Token.h +++ b/mlir/lib/Parser/Token.h @@ -74,6 +74,11 @@ /// For an inttype token, return its bitwidth. Optional getIntTypeBitwidth() const; + /// For an inttype token, return its signedness semantics: llvm::None means no + /// signedness semantics; true means signed integer type; false means unsigned + /// integer type. + Optional getIntTypeSignedness() const; + /// Given a hash_identifier token like #123, try to parse the number out of /// the identifier, returning None if it is a named identifier like #x or /// if the integer doesn't fit. diff --git a/mlir/lib/Parser/Token.cpp b/mlir/lib/Parser/Token.cpp --- a/mlir/lib/Parser/Token.cpp +++ b/mlir/lib/Parser/Token.cpp @@ -57,13 +57,26 @@ /// For an inttype token, return its bitwidth. Optional Token::getIntTypeBitwidth() const { + assert(getKind() == inttype); + unsigned bitwidthStart = (spelling[0] == 'i' ? 1 : 2); unsigned result = 0; - if (spelling[1] == '0' || spelling.drop_front().getAsInteger(10, result) || + if (spelling[bitwidthStart] == '0' || + spelling.drop_front(bitwidthStart).getAsInteger(10, result) || result == 0) return None; return result; } +Optional Token::getIntTypeSignedness() const { + assert(getKind() == inttype); + if (spelling[0] == 'i') + return llvm::None; + if (spelling[0] == 's') + return true; + assert(spelling[0] == 'u'); + return false; +} + /// Given a token containing a string literal, return its value, including /// removing the quote characters and unescaping the contents of the string. The /// lexer has already verified that this token is valid. 
diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def --- a/mlir/lib/Parser/TokenKinds.def +++ b/mlir/lib/Parser/TokenKinds.def @@ -52,7 +52,7 @@ TOK_LITERAL(floatliteral) // 2.0 TOK_LITERAL(integer) // 42 TOK_LITERAL(string) // "foo" -TOK_LITERAL(inttype) // i421 +TOK_LITERAL(inttype) // i4, si8, ui16 // Punctuation. TOK_PUNCTUATION(arrow, "->") diff --git a/mlir/test/IR/invalid.mlir b/mlir/test/IR/invalid.mlir --- a/mlir/test/IR/invalid.mlir +++ b/mlir/test/IR/invalid.mlir @@ -204,6 +204,14 @@ // ----- +func @illegaltype(ui1) // expected-error {{cannot have signedness semantics for i1}} + +// ----- + +func @illegaltype(si1) // expected-error {{cannot have signedness semantics for i1}} + +// ----- + func @malformed_for_percent() { affine.for i = 1 to 10 { // expected-error {{expected SSA operand}} @@ -1206,5 +1214,21 @@ "foo"() {bar = dense : tensor<2xi16>} : () -> () } +// ----- + // expected-error @+1 {{unbalanced ')' character in pretty dialect name}} func @bad_arrow(%arg : !unreg.ptr<(i32)->) + +// ----- + +func @negative_value_in_unsigned_int_attr() { + // expected-error @+1 {{negative integer literal not valid for unsigned integer type}} + "foo"() {bar = -5 : ui32} : () -> () +} + +// ----- + +func @negative_value_in_unsigned_vector_attr() { + // expected-error @+1 {{expected unsigned integer elements, but parsed negative value}} + "foo"() {bar = dense<[5, -5]> : vector<2xui32>} : () -> () +} diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -61,6 +61,12 @@ // CHECK: func @int_types(i1, i2, i4, i7, i87) -> (i1, index, i19) func @int_types(i1, i2, i4, i7, i87) -> (i1, index, i19) +// CHECK: func @sint_types(si2, si4) -> (si7, si1023) +func @sint_types(si2, si4) -> (si7, si1023) + +// CHECK: func @uint_types(ui2, ui4) -> (ui7, ui1023) +func @uint_types(ui2, ui4) -> (ui7, ui1023) + // CHECK: func @vectors(vector<1xf32>, vector<2x4xf32>) func @vectors(vector<1 
x f32>, vector<2x4xf32>) diff --git a/mlir/test/mlir-tblgen/predicate.td b/mlir/test/mlir-tblgen/predicate.td --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -91,4 +91,4 @@ // CHECK-LABEL: OpK::verify // CHECK: for (Value v : getODSOperands(0)) { -// CHECK: if (!(((v.getType().isa())) && (((v.getType().cast().getElementType().isF32())) || ((v.getType().cast().getElementType().isInteger(32)))))) +// CHECK: if (!(((v.getType().isa())) && (((v.getType().cast().getElementType().isF32())) || ((v.getType().cast().getElementType().isSignlessInteger(32))))))