diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -93,9 +93,8 @@ return false; auto type = op->getResult(0).getType(); - if (type.isSignlessIntOrIndex()) { + if (type.isa() || type.isa()) return attr_value_binder(bind_value).match(attr); - } if (type.isa() || type.isa()) { if (auto splatAttr = attr.dyn_cast()) { return attr_value_binder(bind_value) diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -339,6 +339,30 @@ def I32 : I<32>; def I64 : I<64>; +// Unsigned integer types. +// Any unsigned integer type irrespective of its width. +def AnyUnsignedInteger : Type< + CPred<"$_self.isUnsignedInteger()">, "unsigned integer">; + +// Unsigned integer type of a specific width. +class UI + : Type, + width # "-bit unsigned integer">, + BuildableType<"$_builder.getIntegerType(" # width # + ", /*isSigned=*/false)"> { + int bitwidth = width; +} + +class UnsignedIntOfWidths widths> : + AnyTypeOf), + StrJoinInt.result # "-bit unsigned integer">; + +def UI1 : UI<1>; +def UI8 : UI<8>; +def UI16 : UI<16>; +def UI32 : UI<32>; +def UI64 : UI<64>; + // Floating point types. // Any float type irrespective of its width. diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h --- a/mlir/include/mlir/IR/StandardTypes.h +++ b/mlir/include/mlir/IR/StandardTypes.h @@ -328,8 +328,9 @@ // Note: Non standard/builtin types are allowed to exist within tensor // types. Dialects are expected to verify that tensor types have a valid // element type within that dialect. 
- return type.isSignlessIntOrFloat() || type.isa() || - type.isa() || type.isa() || + return type.isa() || type.isa() || + type.isa() || type.isa() || + type.isa() || (type.getKind() > Type::Kind::LAST_STANDARD_TYPE); } diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -169,6 +169,9 @@ /// Return true of this is a signless integer or a float type. bool isSignlessIntOrFloat(); + /// Return true if this is an integer (of any signedness) or a float type. + bool isIntOrFloat(); + /// Print the current type. void print(raw_ostream &os); void dump(); diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -314,7 +314,7 @@ auto elementType = memRefType.getElementType(); unsigned sizeInBits; - if (elementType.isSignlessIntOrFloat()) { + if (elementType.isIntOrFloat()) { sizeInBits = elementType.getIntOrFloatBitWidth(); } else { auto vectorType = elementType.cast(); @@ -358,7 +358,7 @@ if (!memRefType.hasStaticShape()) return None; auto elementType = memRefType.getElementType(); - if (!elementType.isSignlessIntOrFloat() && !elementType.isa()) + if (!elementType.isIntOrFloat() && !elementType.isa()) return None; uint64_t sizeInBytes = getMemRefEltSizeInBytes(memRefType); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1372,17 +1372,18 @@ /// Print the integer element of the given DenseElementsAttr at 'index'. static void printDenseIntElement(DenseElementsAttr attr, raw_ostream &os, - unsigned index) { + unsigned index, bool isSigned) { APInt value = *std::next(attr.int_value_begin(), index); if (value.getBitWidth() == 1) os << (value.getBoolValue() ? 
"true" : "false"); else - value.print(os, /*isSigned=*/true); + value.print(os, isSigned); } /// Print the float element of the given DenseElementsAttr at 'index'. static void printDenseFloatElement(DenseElementsAttr attr, raw_ostream &os, - unsigned index) { + unsigned index, bool isSigned) { + assert(isSigned && "floating point values are always signed"); APFloat value = *std::next(attr.float_value_begin(), index); printFloatValue(value, os); } @@ -1392,6 +1393,7 @@ auto type = attr.getType(); auto shape = type.getShape(); auto rank = type.getRank(); + bool isSigned = !type.getElementType().isUnsignedInteger(); // The function used to print elements of this attribute. auto printEltFn = type.getElementType().isa() @@ -1400,7 +1402,7 @@ // Special case for 0-d and splat tensors. if (attr.isSplat()) { - printEltFn(attr, os, 0); + printEltFn(attr, os, 0, isSigned); return; } @@ -1452,7 +1454,7 @@ while (openBrackets++ < rank) os << '['; openBrackets = rank; - printEltFn(attr, os, idx); + printEltFn(attr, os, idx, isSigned); bumpCounter(); } while (openBrackets-- > 0) diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -608,7 +608,7 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef values) { - assert(type.getElementType().isSignlessIntOrFloat() && + assert(type.getElementType().isIntOrFloat() && "expected int or float element type"); assert(hasSameElementsOrSplat(type, values)); diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -84,6 +84,8 @@ return isSignlessInteger() || isa(); } +bool Type::isIntOrFloat() { return isa() || isa(); } + //===----------------------------------------------------------------------===// // Integer Type //===----------------------------------------------------------------------===// @@ -147,13 +149,10 @@ } unsigned 
Type::getIntOrFloatBitWidth() { - assert(isSignlessIntOrFloat() && "only ints and floats have a bitwidth"); - if (auto intType = dyn_cast()) { + assert(isIntOrFloat() && "only integers and floats have a bitwidth"); + if (auto intType = dyn_cast()) return intType.getWidth(); - } - - auto floatType = cast(); - return floatType.getWidth(); + return cast().getWidth(); } //===----------------------------------------------------------------------===// @@ -202,7 +201,7 @@ "cannot get the bit size of an aggregate with a dynamic shape"); auto elementType = getElementType(); - if (elementType.isSignlessIntOrFloat()) + if (elementType.isIntOrFloat()) return elementType.getIntOrFloatBitWidth() * getNumElements(); // Tensors can have vectors and other tensors as elements, other shaped types @@ -373,7 +372,7 @@ auto *context = elementType.getContext(); // Check that memref is formed from allowed types. - if (!elementType.isSignlessIntOrFloat() && !elementType.isa() && + if (!elementType.isIntOrFloat() && !elementType.isa() && !elementType.isa()) return emitOptionalError(location, "invalid memref element type"), MemRefType(); @@ -451,7 +450,7 @@ UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType, unsigned memorySpace) { // Check that memref is formed from allowed types. - if (!elementType.isSignlessIntOrFloat() && !elementType.isa() && + if (!elementType.isIntOrFloat() && !elementType.isa() && !elementType.isa()) return emitError(loc, "invalid memref element type"); return success(); diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -1102,7 +1102,7 @@ return nullptr; // Check that memref is formed from allowed types. 
- if (!elementType.isSignlessIntOrFloat() && !elementType.isa() && + if (!elementType.isIntOrFloat() && !elementType.isa() && !elementType.isa()) return emitError(typeLoc, "invalid memref element type"), nullptr; diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -869,7 +869,7 @@ auto elementType = memRefType.getElementType(); unsigned sizeInBits; - if (elementType.isSignlessIntOrFloat()) { + if (elementType.isIntOrFloat()) { sizeInBits = elementType.getIntOrFloatBitWidth(); } else { auto vectorType = elementType.cast(); diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -616,6 +616,9 @@ // CHECK: "splatBoolTensor"() {bar = dense : tensor} : () -> () "splatBoolTensor"(){bar = dense : tensor} : () -> () + // CHECK: "splatUIntTensor"() {bar = dense<222> : tensor<2x1x4xui8>} : () -> () + "splatUIntTensor"(){bar = dense<222> : tensor<2x1x4xui8>} : () -> () + // CHECK: "splatIntTensor"() {bar = dense<5> : tensor<2x1x4xi32>} : () -> () "splatIntTensor"(){bar = dense<5> : tensor<2x1x4xi32>} : () -> ()