diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -761,9 +761,9 @@ -/// Base class for DenseArrayAttr that is instantiated and specialized for each +/// Typed subclass of DenseArrayAttr, instantiated and specialized for each /// supported element type below. template -class DenseArrayAttr : public DenseArrayBaseAttr { +class DenseArrayAttrImpl : public DenseArrayAttr { public: - using DenseArrayBaseAttr::DenseArrayBaseAttr; + using DenseArrayAttr::DenseArrayAttr; /// Implicit conversion to ArrayRef. operator ArrayRef() const; @@ -773,7 +773,7 @@ T operator[](std::size_t index) const { return asArrayRef()[index]; } /// Builder from ArrayRef. - static DenseArrayAttr get(MLIRContext *context, ArrayRef content); + static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef content); /// Print the short form `[42, 100, -1]` without any type prefix. void print(AsmPrinter &printer) const; @@ -791,23 +791,23 @@ static bool classof(Attribute attr); }; -extern template class DenseArrayAttr; -extern template class DenseArrayAttr; -extern template class DenseArrayAttr; -extern template class DenseArrayAttr; -extern template class DenseArrayAttr; -extern template class DenseArrayAttr; -extern template class DenseArrayAttr; +extern template class DenseArrayAttrImpl; +extern template class DenseArrayAttrImpl; +extern template class DenseArrayAttrImpl; +extern template class DenseArrayAttrImpl; +extern template class DenseArrayAttrImpl; +extern template class DenseArrayAttrImpl; +extern template class DenseArrayAttrImpl; } // namespace detail // Public name for all the supported DenseArrayAttr -using DenseBoolArrayAttr = detail::DenseArrayAttr; -using DenseI8ArrayAttr = detail::DenseArrayAttr; -using DenseI16ArrayAttr = detail::DenseArrayAttr; -using DenseI32ArrayAttr = detail::DenseArrayAttr; -using DenseI64ArrayAttr = detail::DenseArrayAttr; -using DenseF32ArrayAttr = detail::DenseArrayAttr; -using 
DenseF64ArrayAttr = detail::DenseArrayAttr; +using DenseBoolArrayAttr = detail::DenseArrayAttrImpl; +using DenseI8ArrayAttr = detail::DenseArrayAttrImpl; +using DenseI16ArrayAttr = detail::DenseArrayAttrImpl; +using DenseI32ArrayAttr = detail::DenseArrayAttrImpl; +using DenseI64ArrayAttr = detail::DenseArrayAttrImpl; +using DenseF32ArrayAttr = detail::DenseArrayAttrImpl; +using DenseF64ArrayAttr = detail::DenseArrayAttrImpl; //===----------------------------------------------------------------------===// // DenseResourceElementsAttr diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -140,7 +140,7 @@ } //===----------------------------------------------------------------------===// -// DenseArrayBaseAttr +// DenseArrayAttr //===----------------------------------------------------------------------===// def Builtin_DenseArrayRawDataParameter : ArrayRefParameter< @@ -155,23 +155,28 @@ }]; } -def Builtin_DenseArrayBase : Builtin_Attr< - "DenseArrayBase", [ElementsAttrInterface, TypedAttrInterface]> { - let summary = "A dense array of i8, i16, i32, i64, f32, or f64."; +def Builtin_DenseArray : Builtin_Attr< + "DenseArray", [ElementsAttrInterface, TypedAttrInterface]> { + let summary = "A dense array of integer or floating point elements."; let description = [{ A dense array attribute is an attribute that represents a dense array of primitive element types. Contrary to DenseIntOrFPElementsAttr this is a flat unidimensional array which does not have a storage optimization for splat. This allows to expose the raw array through a C++ API as - `ArrayRef`. This is the base class attribute, the actual access is - intended to be managed through the subclasses `DenseI8ArrayAttr`, - `DenseI16ArrayAttr`, `DenseI32ArrayAttr`, `DenseI64ArrayAttr`, - `DenseF32ArrayAttr`, and `DenseF64ArrayAttr`. + `ArrayRef` for compatible types.
The element type must be bool or an + integer or float whose bitwidth is a multiple of 8. Bool elements are stored + as bytes. + + This is the base class attribute. Access to C++ types is intended to be + managed through the subclasses `DenseBoolArrayAttr`, `DenseI8ArrayAttr`, + `DenseI16ArrayAttr`, `DenseI32ArrayAttr`, `DenseI64ArrayAttr`, + `DenseF32ArrayAttr`, and `DenseF64ArrayAttr`. Syntax: ``` - dense-array-attribute ::= `[` `:` (integer-type | float-type) tensor-literal `]` + dense-array-attribute ::= `array` `<` (integer-type | float-type) + (`:` tensor-literal)? `>` ``` Examples: @@ -181,16 +186,26 @@ array ``` - when a specific subclass is used as argument of an operation, the declarative - assembly will omit the type and print directly: - ``` + When a specific subclass is used as argument of an operation, the + declarative assembly will omit the type and print directly: + + ```mlir [1, 2, 3] ``` }]; + let parameters = (ins AttributeSelfTypeParameter<"", "RankedTensorType">:$type, Builtin_DenseArrayRawDataParameter:$rawData ); + + let builders = [ + AttrBuilderWithInferredContext<(ins "RankedTensorType":$type, + "ArrayRef":$rawData), [{ + return $_get(type.getContext(), type, rawData); + }]>, + ]; + let extraClassDeclaration = [{ /// Allow implicit conversion to ElementsAttr. operator ElementsAttr() const { @@ -207,13 +222,9 @@ const int64_t *value_begin_impl(OverloadToken) const; const float *value_begin_impl(OverloadToken) const; const double *value_begin_impl(OverloadToken) const; - - /// Printer for the short form: will dispatch to the appropriate subclass. - void print(AsmPrinter &printer) const; - void print(raw_ostream &os) const; - /// Print the short form `42, 100, -1` without any braces or prefix.
- void printWithoutBraces(raw_ostream &os) const; }]; + + let genVerifyDecl = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -820,96 +820,149 @@ } //===----------------------------------------------------------------------===// -// ElementsAttr Parser +// DenseArrayAttr Parser //===----------------------------------------------------------------------===// namespace { -/// This class provides an implementation of AsmParser, allowing to call back -/// into the libMLIRIR-provided APIs for invoking attribute parsing code defined -/// in libMLIRIR. -class CustomAsmParser : public AsmParserImpl { +/// A generic dense array element parser. It parses integer and floating point +/// elements. +class DenseArrayElementParser { public: - CustomAsmParser(Parser &parser) - : AsmParserImpl(parser.getToken().getLoc(), parser) {} + explicit DenseArrayElementParser(Type type) : type(type) {} + + /// Parse an integer element. + ParseResult parseIntegerElement(Parser &p); + + /// Parse a floating point element. + ParseResult parseFloatElement(Parser &p); + + /// Convert the current contents to a dense array. + DenseArrayAttr getAttr() { + return DenseArrayAttr::get(RankedTensorType::get(size, type), rawData); + } + +private: + /// Append the raw data of an APInt to the result. + void append(const APInt &data); + + /// The array element type. + Type type; + /// The resultant byte array representing the contents of the array. + std::vector rawData; + /// The number of elements in the array.
+ int64_t size = 0; }; } // namespace +void DenseArrayElementParser::append(const APInt &data) { + // Bool (i1) elements are stored as one byte each; an integer literal parsed + // for an i1 array may be only one bit wide, so widen it first. + APInt value = data.getBitWidth() == 1 ? data.zext(8) : data; + // Store the element in host byte order, since the typed accessors + // reinterpret the raw bytes as the element's C++ type. Copying a prefix of + // `getRawData()` is only correct on little-endian hosts. + if (unsigned byteSize = value.getBitWidth() / 8) { + size_t offset = rawData.size(); + rawData.resize(offset + byteSize); + llvm::StoreIntToMemory( + value, reinterpret_cast<uint8_t *>(rawData.data() + offset), byteSize); + } + ++size; +} + +ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) { + bool isNegative = p.consumeIf(Token::minus); + + // Parse an integer literal as an APInt. + Optional value; + StringRef spelling = p.getToken().getSpelling(); + if (p.getToken().isAny(Token::kw_true, Token::kw_false)) { + if (!type.isInteger(1)) + return p.emitError("expected i1 type for 'true' or 'false' values"); + value = APInt(/*numBits=*/8, p.getToken().is(Token::kw_true), + !type.isUnsignedInteger()); + p.consumeToken(); + } else if (p.consumeIf(Token::integer)) { + value = buildAttributeAPInt(type, isNegative, spelling); + if (!value) + return p.emitError("integer constant out of range"); + } else { + return p.emitError("expected integer literal"); + } + append(*value); + return success(); +} + +ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) { + bool isNegative = p.consumeIf(Token::minus); + + Token token = p.getToken(); + Optional result; + auto floatType = type.cast(); + if (p.consumeIf(Token::integer)) { + // Parse an integer literal as a float. + if (p.parseFloatFromIntegerLiteral(result, token, isNegative, + floatType.getFloatSemantics(), + floatType.getWidth())) + return failure(); + } else if (p.consumeIf(Token::floatliteral)) { + // Parse a floating point literal. + Optional val = token.getFloatingPointValue(); + if (!val) + return failure(); + result = APFloat(isNegative ? -*val : *val); + if (!type.isF64()) { + bool unused; + result->convert(floatType.getFloatSemantics(), + APFloat::rmNearestTiesToEven, &unused); + } + } else { + return p.emitError("expected integer or floating point literal"); + } + + append(result->bitcastToAPInt()); + return success(); +} + /// Parse a dense array attribute.
Attribute Parser::parseDenseArrayAttr(Type type) { consumeToken(Token::kw_array); + if (parseToken(Token::less, "expected '<' after 'array'")) + return {}; + + // Only bool or integer and floating point elements divisible by bytes are + // supported; `index` is rejected as it has no fixed bitwidth. SMLoc typeLoc = getToken().getLoc(); - if (parseToken(Token::less, "expected '<' after 'array'") || - (!type && !(type = parseType()))) + if (!type && !(type = parseType())) + return {}; + if (!type.isIntOrFloat()) { + emitError(typeLoc, "expected integer or float type, got: ") << type; + return {}; + } + if (!type.isInteger(1) && type.getIntOrFloatBitWidth() % 8 != 0) { + emitError(typeLoc, "element type bitwidth must be a multiple of 8"); return {}; - CustomAsmParser parser(*this); - Attribute result; + } + // Check for empty list. - bool isEmptyList = getToken().is(Token::greater); - if (!isEmptyList && - parseToken(Token::colon, "expected ':' after dense array type")) + if (consumeIf(Token::greater)) + return DenseArrayAttr::get(RankedTensorType::get(0, type), {}); + if (parseToken(Token::colon, "expected ':' after dense array type")) return {}; - if (auto intType = type.dyn_cast()) { - switch (type.getIntOrFloatBitWidth()) { - case 1: - if (isEmptyList) - result = DenseBoolArrayAttr::get(parser.getContext(), {}); - else - result = DenseBoolArrayAttr::parseWithoutBraces(parser, Type{}); - break; - case 8: - if (isEmptyList) - result = DenseI8ArrayAttr::get(parser.getContext(), {}); - else - result = DenseI8ArrayAttr::parseWithoutBraces(parser, Type{}); - break; - case 16: - if (isEmptyList) - result = DenseI16ArrayAttr::get(parser.getContext(), {}); - else - result = DenseI16ArrayAttr::parseWithoutBraces(parser, Type{}); - break; - case 32: - if (isEmptyList) - result = DenseI32ArrayAttr::get(parser.getContext(), {}); - else - result = DenseI32ArrayAttr::parseWithoutBraces(parser, Type{}); - break; - case 64: - if (isEmptyList) - result = DenseI64ArrayAttr::get(parser.getContext(), {}); - else - result = 
DenseI64ArrayAttr::parseWithoutBraces(parser, Type{}); - break; - default: - emitError(typeLoc, "expected i1, i8, i16, i32, or i64 but got: ") << type; - return {}; - } - } else if (auto floatType = type.dyn_cast()) { - switch (type.getIntOrFloatBitWidth()) { - case 32: - if (isEmptyList) - result = DenseF32ArrayAttr::get(parser.getContext(), {}); - else - result = DenseF32ArrayAttr::parseWithoutBraces(parser, Type{}); - break; - case 64: - if (isEmptyList) - result = DenseF64ArrayAttr::get(parser.getContext(), {}); - else - result = DenseF64ArrayAttr::parseWithoutBraces(parser, Type{}); - break; - default: - emitError(typeLoc, "expected f32 or f64 but got: ") << type; + DenseArrayElementParser eltParser(type); + if (type.isIntOrIndex()) { + if (parseCommaSeparatedList( + [&] { return eltParser.parseIntegerElement(*this); })) return {}; - } } else { - emitError(typeLoc, "expected integer or float type, got: ") << type; - return {}; + if (parseCommaSeparatedList( + [&] { return eltParser.parseFloatElement(*this); })) + return {}; } if (parseToken(Token::greater, "expected '>' to close an array attribute")) return {}; - return result; + return eltParser.getAttr(); } /// Parse a dense elements attribute. diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -383,7 +383,7 @@ // Accessors. intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { - return unwrap(attr).cast().size(); + return unwrap(attr).cast().size(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1476,6 +1476,9 @@ void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, bool allowHex); + /// Print a dense array attribute.
+ void printDenseArrayAttr(DenseArrayAttr attr); + void printDialectAttribute(Attribute attr); void printDialectType(Type type); @@ -1858,12 +1861,13 @@ } os << '>'; } - } else if (auto denseArrayAttr = attr.dyn_cast()) { + } else if (auto denseArrayAttr = attr.dyn_cast()) { typeElision = AttrTypeElision::Must; os << "array<" << denseArrayAttr.getType().getElementType(); - if (!denseArrayAttr.empty()) + if (!denseArrayAttr.empty()) { os << ": "; - denseArrayAttr.printWithoutBraces(os); + printDenseArrayAttr(denseArrayAttr); + } os << ">"; } else if (auto resourceAttr = attr.dyn_cast()) { os << "dense_resource<"; @@ -2029,6 +2033,36 @@ printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); } +void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) { + Type type = attr.getElementType(); + // Bool (i1) elements are stored as one byte each. + bool isBool = type.isInteger(1); + unsigned bitwidth = isBool ? 8 : type.getIntOrFloatBitWidth(); + unsigned byteSize = bitwidth / 8; + ArrayRef<char> data = attr.getRawData(); + + auto printElementAt = [&](int64_t i) { + // Decode the element directly from its host-endian bytes; this avoids a + // word-padded copy and is also correct on big-endian hosts. + APInt value(bitwidth, 0); + if (byteSize) + llvm::LoadIntFromMemory( + value, reinterpret_cast<const uint8_t *>(data.data()) + byteSize * i, + byteSize); + if (isBool) + value = value.trunc(1); + // Print the data as-is or as a float.
+ if (type.isIntOrIndex()) { + printDenseIntElement(value, getStream(), !type.isUnsignedInteger()); + } else { + APFloat fltVal(type.cast<FloatType>().getFloatSemantics(), value); + printFloatValue(fltVal, getStream()); + } + }; + llvm::interleaveComma(llvm::seq<int64_t>(0, attr.size()), getStream(), + printElementAt); +} + void AsmPrinter::Impl::printType(Type type) { if (!type) { os << "<>"; diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -687,50 +687,55 @@ // DenseArrayAttr //===----------------------------------------------------------------------===// -const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { +LogicalResult +DenseArrayAttr::verify(function_ref emitError, + RankedTensorType type, ArrayRef rawData) { + if (type.getRank() != 1) + return emitError() << "expected rank 1 tensor type"; + // A dynamic extent leaves the element count, and thus the expected data + // size checked below, undefined. + if (!type.hasStaticShape()) + return emitError() << "expected static shape"; + // `index` has no fixed bitwidth, so it cannot be stored as raw bytes. + if (!type.getElementType().isIntOrFloat()) + return emitError() << "expected integer or floating point element type"; + int64_t dataSize = rawData.size(); + int64_t size = type.getShape().front(); + if (type.getElementType().isInteger(1)) { + if (size != dataSize) + return emitError() << "expected " << size + << " bytes for i1 array but got " << dataSize; + } else if (size * type.getElementTypeBitWidth() != dataSize * 8) { + return emitError() << "expected data size (" << size << " elements, " + << type.getElementTypeBitWidth() + << " bits each) does not match: " << dataSize + << " bytes"; + } + return success(); +} + +const bool *DenseArrayAttr::value_begin_impl(OverloadToken) const { return cast().asArrayRef().begin(); } -const int8_t * -DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { +const int8_t *DenseArrayAttr::value_begin_impl(OverloadToken) const { return cast().asArrayRef().begin(); } -const int16_t * -DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { +const int16_t *DenseArrayAttr::value_begin_impl(OverloadToken) const
{ return cast().asArrayRef().begin(); } -const int32_t * -DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { +const int32_t *DenseArrayAttr::value_begin_impl(OverloadToken) const { return cast().asArrayRef().begin(); } -const int64_t * -DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { +const int64_t *DenseArrayAttr::value_begin_impl(OverloadToken) const { return cast().asArrayRef().begin(); } -const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { +const float *DenseArrayAttr::value_begin_impl(OverloadToken) const { return cast().asArrayRef().begin(); } -const double * -DenseArrayBaseAttr::value_begin_impl(OverloadToken) const { +const double *DenseArrayAttr::value_begin_impl(OverloadToken) const { return cast().asArrayRef().begin(); } -void DenseArrayBaseAttr::print(AsmPrinter &printer) const { - print(printer.getStream()); -} - -void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const { - llvm::TypeSwitch(*this) - .Case([&](auto attr) { attr.printWithoutBraces(os); }); -} - -void DenseArrayBaseAttr::print(raw_ostream &os) const { - os << "["; - printWithoutBraces(os); - os << "]"; -} - namespace { /// Instantiations of this class provide utilities for interacting with native /// data types in the context of DenseArrayAttr.
@@ -815,19 +820,19 @@ } // namespace template -void DenseArrayAttr::print(AsmPrinter &printer) const { +void DenseArrayAttrImpl::print(AsmPrinter &printer) const { print(printer.getStream()); } template -void DenseArrayAttr::printWithoutBraces(raw_ostream &os) const { +void DenseArrayAttrImpl::printWithoutBraces(raw_ostream &os) const { llvm::interleaveComma(asArrayRef(), os, [&](T value) { DenseArrayAttrUtil::printElement(os, value); }); } template -void DenseArrayAttr::print(raw_ostream &os) const { +void DenseArrayAttrImpl::print(raw_ostream &os) const { os << "["; printWithoutBraces(os); os << "]"; @@ -835,8 +840,8 @@ /// Parse a DenseArrayAttr without the braces: `1, 2, 3` template -Attribute DenseArrayAttr::parseWithoutBraces(AsmParser &parser, - Type odsType) { +Attribute DenseArrayAttrImpl::parseWithoutBraces(AsmParser &parser, + Type odsType) { SmallVector data; if (failed(parser.parseCommaSeparatedList([&]() { T value; @@ -851,7 +856,7 @@ /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]` template -Attribute DenseArrayAttr::parse(AsmParser &parser, Type odsType) { +Attribute DenseArrayAttrImpl::parse(AsmParser &parser, Type odsType) { if (parser.parseLSquare()) return {}; // Handle empty list case. @@ -865,7 +870,7 @@ /// Conversion from DenseArrayAttr to ArrayRef. template -DenseArrayAttr::operator ArrayRef() const { +DenseArrayAttrImpl::operator ArrayRef() const { ArrayRef raw = getRawData(); assert((raw.size() % sizeof(T)) == 0); return ArrayRef(reinterpret_cast(raw.data()), @@ -874,19 +879,19 @@ /// Builds a DenseArrayAttr from an ArrayRef.
template -DenseArrayAttr DenseArrayAttr::get(MLIRContext *context, - ArrayRef content) { +DenseArrayAttrImpl DenseArrayAttrImpl::get(MLIRContext *context, + ArrayRef content) { auto shapedType = RankedTensorType::get( content.size(), DenseArrayAttrUtil::getElementType(context)); auto rawArray = ArrayRef(reinterpret_cast(content.data()), content.size() * sizeof(T)); return Base::get(context, shapedType, rawArray) - .template cast>(); + .template cast>(); } template -bool DenseArrayAttr::classof(Attribute attr) { - if (auto denseArray = attr.dyn_cast()) +bool DenseArrayAttrImpl::classof(Attribute attr) { + if (auto denseArray = attr.dyn_cast()) return DenseArrayAttrUtil::checkElementType(denseArray.getElementType()); return false; } @@ -894,13 +899,13 @@ namespace mlir { namespace detail { // Explicit instantiation for all the supported DenseArrayAttr. -template class DenseArrayAttr; -template class DenseArrayAttr; -template class DenseArrayAttr; -template class DenseArrayAttr; -template class DenseArrayAttr; -template class DenseArrayAttr; -template class DenseArrayAttr; +template class DenseArrayAttrImpl; +template class DenseArrayAttrImpl; +template class DenseArrayAttrImpl; +template class DenseArrayAttrImpl; +template class DenseArrayAttrImpl; +template class DenseArrayAttrImpl; +template class DenseArrayAttrImpl; } // namespace detail } // namespace mlir diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir --- a/mlir/test/IR/attribute.mlir +++ b/mlir/test/IR/attribute.mlir @@ -569,6 +569,30 @@ f64attr = [-142.]
// CHECK-SAME: emptyattr = [] emptyattr = [] + + // CHECK: array.sizes + // CHECK-SAME: i0 = array + // CHECK-SAME: ui0 = array + // CHECK-SAME: si0 = array + // CHECK-SAME: i24 = array + // CHECK-SAME: ui24 = array + // CHECK-SAME: si24 = array + // CHECK-SAME: bf16 = array + // CHECK-SAME: f16 = array + "array.sizes"() { + x0_i0 = array, + x1_ui0 = array, + x2_si0 = array, + x3_i24 = array, + x4_ui24 = array, + x5_si24 = array, + x6_bf16 = array, + x7_f16 = array + }: () -> () + + // CHECK: array.i1_literals + // CHECK-SAME: x = array<i1: true, false> + "array.i1_literals"() {x = array<i1: 1, 0>} : () -> () return } diff --git a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp --- a/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp +++ b/mlir/test/lib/IR/TestBuiltinAttributeInterfaces.cpp @@ -41,9 +41,8 @@ auto elementsAttr = attr.getValue().dyn_cast(); if (!elementsAttr) continue; - if (auto concreteAttr = - attr.getValue().dyn_cast()) { - llvm::TypeSwitch(concreteAttr) + if (auto concreteAttr = attr.getValue().dyn_cast()) { + llvm::TypeSwitch(concreteAttr) .Case([&](DenseBoolArrayAttr attr) { testElementsAttrIteration(op, attr, "bool"); })