diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h @@ -165,9 +165,8 @@ if (*rawConstantIter == GEPOp::kDynamicIndex) return *valuesIter; - return IntegerAttr::get( - ElementsAttr::getElementType(base->rawConstantIndices), - *rawConstantIter); + return IntegerAttr::get(base->rawConstantIndices.getElementType(), + *rawConstantIter); } iterator &operator++() { diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -155,8 +155,7 @@ }]; } -def Builtin_DenseArray : Builtin_Attr< - "DenseArray", [ElementsAttrInterface, TypedAttrInterface]> { +def Builtin_DenseArray : Builtin_Attr<"DenseArray"> { let summary = "A dense array of integer or floating point elements."; let description = [{ A dense array attribute is an attribute that represents a dense array of @@ -195,43 +194,26 @@ }]; let parameters = (ins - AttributeSelfTypeParameter<"", "RankedTensorType">:$type, + "Type":$elementType, + "int64_t":$size, Builtin_DenseArrayRawDataParameter:$rawData ); let builders = [ - AttrBuilderWithInferredContext<(ins "RankedTensorType":$type, + AttrBuilderWithInferredContext<(ins "Type":$elementType, "unsigned":$size, "ArrayRef":$rawData), [{ - return $_get(type.getContext(), type, rawData); + return $_get(elementType.getContext(), elementType, size, rawData); }]>, ]; - let extraClassDeclaration = [{ - /// Allow implicit conversion to ElementsAttr. - operator ElementsAttr() const { - return *this ? cast() : nullptr; - } + let genVerifyDecl = 1; - /// ElementsAttr implementation. - using ContiguousIterableTypesT = - std::tuple; - FailureOr - try_value_begin_impl(OverloadToken) const; - FailureOr - try_value_begin_impl(OverloadToken) const; - FailureOr - try_value_begin_impl(OverloadToken) const; - FailureOr - try_value_begin_impl(OverloadToken) const; - FailureOr - try_value_begin_impl(OverloadToken) const; - FailureOr - try_value_begin_impl(OverloadToken) const; - FailureOr - try_value_begin_impl(OverloadToken) const; + let extraClassDeclaration = [{ + /// Get the number of elements in the array. + int64_t size() const { return getSize(); } + /// Return true if there are no elements in the dense array. + bool empty() const { return !size(); } }]; - - let genVerifyDecl = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -844,9 +844,7 @@ ParseResult parseFloatElement(Parser &p); /// Convert the current contents to a dense array. - DenseArrayAttr getAttr() { - return DenseArrayAttr::get(RankedTensorType::get(size, type), rawData); - } + DenseArrayAttr getAttr() { return DenseArrayAttr::get(type, size, rawData); } private: /// Append the raw data of an APInt to the result. @@ -934,18 +932,9 @@ return {}; SMLoc typeLoc = getToken().getLoc(); - Type eltType; - // If an attribute type was provided, use its element type. - if (attrType) { - auto tensorType = attrType.dyn_cast(); - if (!tensorType) { - emitError(typeLoc, "dense array attribute expected ranked tensor type"); - return {}; - } - eltType = tensorType.getElementType(); - - // Otherwise, parse a type. - } else if (!(eltType = parseType())) { + Type eltType = parseType(); + if (!eltType) { + emitError(typeLoc, "expected an integer or floating point type"); return {}; } @@ -960,23 +949,11 @@ return {}; } - // If a type was provided, check that it matches the parsed type. - auto checkProvidedType = [&](DenseArrayAttr result) -> Attribute { - if (attrType && result.getType() != attrType) { - emitError(typeLoc, "expected attribute type ") - << attrType << " does not match parsed type " << result.getType(); - return {}; - } - return result; - }; - // Check for empty list. - if (consumeIf(Token::greater)) { - return checkProvidedType( - DenseArrayAttr::get(RankedTensorType::get(0, eltType), {})); - } - if (!attrType && - parseToken(Token::colon, "expected ':' after dense array type")) + if (consumeIf(Token::greater)) + return DenseArrayAttr::get(eltType, 0, {}); + + if (parseToken(Token::colon, "expected ':' after dense array type")) return {}; DenseArrayElementParser eltParser(eltType); @@ -991,7 +968,7 @@ } if (parseToken(Token::greater, "expected '>' to close an array attribute")) return {}; - return checkProvidedType(eltParser.getAttr()); + return eltParser.getAttr(); } /// Parse a dense elements attribute. diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -2028,11 +2028,9 @@ stridedLayoutAttr.print(os); } else if (auto denseArrayAttr = attr.dyn_cast()) { os << "array<"; - if (typeElision != AttrTypeElision::Must) - printType(denseArrayAttr.getType().getElementType()); + printType(denseArrayAttr.getElementType()); if (!denseArrayAttr.empty()) { - if (typeElision != AttrTypeElision::Must) - os << ": "; + os << ": "; printDenseArrayAttr(denseArrayAttr); } os << ">"; diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -690,69 +690,21 @@ LogicalResult DenseArrayAttr::verify(function_ref emitError, - RankedTensorType type, ArrayRef rawData) { - if (type.getRank() != 1) - return emitError() << "expected rank 1 tensor type"; - if (!type.getElementType().isIntOrIndexOrFloat()) + Type elementType, int64_t size, ArrayRef rawData) { + if (!elementType.isIntOrIndexOrFloat()) return emitError() << "expected integer or floating point element type"; int64_t dataSize = rawData.size(); - int64_t size = type.getShape().front(); - if (type.getElementType().isInteger(1)) { - if (size != dataSize) - return emitError() << "expected " << size - << " bytes for i1 array but got " << dataSize; - } else if (size * type.getElementTypeBitWidth() != dataSize * 8) { + int64_t elementSize = + llvm::divideCeil(elementType.getIntOrFloatBitWidth(), CHAR_BIT); + if (size * elementSize != dataSize) { return emitError() << "expected data size (" << size << " elements, " - << type.getElementTypeBitWidth() - << " bits each) does not match: " << dataSize + << elementSize + << " bytes each) does not match: " << dataSize << " bytes"; } return success(); } -FailureOr -DenseArrayAttr::try_value_begin_impl(OverloadToken) const { - if (auto attr = dyn_cast()) - return attr.asArrayRef().begin(); - return failure(); -} -FailureOr -DenseArrayAttr::try_value_begin_impl(OverloadToken) const { - if (auto attr = dyn_cast()) - return attr.asArrayRef().begin(); - return failure(); -} -FailureOr -DenseArrayAttr::try_value_begin_impl(OverloadToken) const { - if (auto attr = dyn_cast()) - return attr.asArrayRef().begin(); - return failure(); -} -FailureOr -DenseArrayAttr::try_value_begin_impl(OverloadToken) const { - if (auto attr = dyn_cast()) - return attr.asArrayRef().begin(); - return failure(); -} -FailureOr -DenseArrayAttr::try_value_begin_impl(OverloadToken) const { - if (auto attr = dyn_cast()) - return attr.asArrayRef().begin(); - return failure(); -} -FailureOr -DenseArrayAttr::try_value_begin_impl(OverloadToken) const { - if (auto attr = dyn_cast()) - return attr.asArrayRef().begin(); - return failure(); -} -FailureOr -DenseArrayAttr::try_value_begin_impl(OverloadToken) const { - if (auto attr = dyn_cast()) - return attr.asArrayRef().begin(); - return failure(); -} - namespace { /// Instantiations of this class provide utilities for interacting with native /// data types in the context of DenseArrayAttr. @@ -898,12 +850,11 @@ template DenseArrayAttrImpl DenseArrayAttrImpl::get(MLIRContext *context, ArrayRef content) { - auto shapedType = RankedTensorType::get( - content.size(), DenseArrayAttrUtil::getElementType(context)); + Type elementType = DenseArrayAttrUtil::getElementType(context); auto rawArray = ArrayRef(reinterpret_cast(content.data()), content.size() * sizeof(T)); - return Base::get(context, shapedType, rawArray) - .template cast>(); + return cast>( + Base::get(context, elementType, content.size(), rawArray)); } template diff --git a/mlir/lib/IR/BuiltinDialectBytecode.cpp b/mlir/lib/IR/BuiltinDialectBytecode.cpp --- a/mlir/lib/IR/BuiltinDialectBytecode.cpp +++ b/mlir/lib/IR/BuiltinDialectBytecode.cpp @@ -494,17 +494,20 @@ DenseArrayAttr BuiltinDialectBytecodeInterface::readDenseArrayAttr( DialectBytecodeReader &reader) const { - RankedTensorType type; + Type elementType; + uint64_t size; ArrayRef blob; - if (failed(reader.readType(type)) || failed(reader.readBlob(blob))) + if (failed(reader.readType(elementType)) || failed(reader.readVarInt(size)) || + failed(reader.readBlob(blob))) return DenseArrayAttr(); - return DenseArrayAttr::get(type, blob); + return DenseArrayAttr::get(elementType, size, blob); } void BuiltinDialectBytecodeInterface::write( DenseArrayAttr attr, DialectBytecodeWriter &writer) const { writer.writeVarInt(builtin_encoding::kDenseArrayAttr); - writer.writeType(attr.getType()); + writer.writeType(attr.getElementType()); + writer.writeVarInt(attr.getSize()); writer.writeOwnedBlob(attr.getRawData()); } diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -626,8 +626,6 @@ x7_f16 = array }: () -> () - // CHECK: test.typed_attr tensor<4xi32> = array<1, 2, 3, 4> - test.typed_attr tensor<4xi32> = array<1, 2, 3, 4> return } diff --git a/mlir/test/IR/elements-attr-interface.mlir b/mlir/test/IR/elements-attr-interface.mlir --- a/mlir/test/IR/elements-attr-interface.mlir +++ b/mlir/test/IR/elements-attr-interface.mlir @@ -27,27 +27,6 @@ // expected-error@below {{Test iterating `IntegerAttr`: }} arith.constant dense<> : tensor<0xi64> -// expected-error@below {{Test iterating `bool`: true, false, true, false, true, false}} -// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} -arith.constant array -// expected-error@below {{Test iterating `int8_t`: 10, 11, -12, 13, 14}} -// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} -arith.constant array -// expected-error@below {{Test iterating `int16_t`: 10, 11, -12, 13, 14}} -// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} -arith.constant array -// expected-error@below {{Test iterating `int32_t`: 10, 11, -12, 13, 14}} -// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} -arith.constant array -// expected-error@below {{Test iterating `int64_t`: 10, 11, -12, 13, 14}} -arith.constant array -// expected-error@below {{Test iterating `float`: 10.00, 11.00, -12.00, 13.00, 14.00}} -// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} -arith.constant array -// expected-error@below {{Test iterating `double`: 10.00, 11.00, -12.00, 13.00, 14.00}} -// expected-error@below {{Test iterating `int64_t`: unable to iterate type}} -arith.constant array - // Check that we handle an external constant parsed from the config. // expected-error@below {{Test iterating `int64_t`: unable to iterate type}} // expected-error@below {{Test iterating `uint64_t`: 1, 2, 3}} diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir --- a/mlir/test/IR/invalid-builtin-attributes.mlir +++ b/mlir/test/IR/invalid-builtin-attributes.mlir @@ -546,18 +546,3 @@ // expected-error@below {{expected '>' to close an array attribute}} #attr = array - -// ----- - -// expected-error@below {{does not match parsed type}} -test.typed_attr tensor<1xi32> = array<> - -// ----- - -// expected-error@below {{does not match parsed type}} -test.typed_attr tensor<0xi32> = array<1> diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp --- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp +++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp @@ -21,10 +21,6 @@ static void printOneElement(InFlightDiagnostic &os, T value) { os << llvm::formatv("{0}", value).str(); } -template <> -void printOneElement(InFlightDiagnostic &os, int8_t value) { - os << llvm::formatv("{0}", static_cast(value)).str(); -} namespace { struct TestElementsAttrInterface @@ -41,32 +37,6 @@ auto elementsAttr = attr.getValue().dyn_cast(); if (!elementsAttr) continue; - if (auto concreteAttr = attr.getValue().dyn_cast()) { - llvm::TypeSwitch(concreteAttr) - .Case([&](DenseBoolArrayAttr attr) { - testElementsAttrIteration(op, attr, "bool"); - }) - .Case([&](DenseI8ArrayAttr attr) { - testElementsAttrIteration(op, attr, "int8_t"); - }) - .Case([&](DenseI16ArrayAttr attr) { - testElementsAttrIteration(op, attr, "int16_t"); - }) - .Case([&](DenseI32ArrayAttr attr) { - testElementsAttrIteration(op, attr, "int32_t"); - }) - .Case([&](DenseI64ArrayAttr attr) { - testElementsAttrIteration(op, attr, "int64_t"); - }) - .Case([&](DenseF32ArrayAttr attr) { - testElementsAttrIteration(op, attr, "float"); - }) - .Case([&](DenseF64ArrayAttr attr) { - testElementsAttrIteration(op, attr, "double"); - }); - testElementsAttrIteration(op, elementsAttr, "int64_t"); - continue; - } testElementsAttrIteration(op, elementsAttr, "int64_t"); testElementsAttrIteration(op, elementsAttr, "uint64_t"); testElementsAttrIteration(op, elementsAttr, "APInt");