diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -26,6 +26,8 @@
 class Location;
 class Operation;
 class ShapedType;
+class AsmParser;
+class AsmPrinter;
 //===----------------------------------------------------------------------===//
 // Elements Attributes
@@ -64,9 +66,60 @@
 template struct is_complex_t : public std::false_type {};
 template struct is_complex_t> : public std::true_type {};
+
+/// Uniqued storage backing DenseArrayAttrBase (defined in AttributeDetail.h).
+class DenseArrayAttributeStorage;
+/// Tag recording which element type a dense array attribute holds. Each
+/// enumerator other than `None` corresponds to one of the DenseXXArrayAttr
+/// aliases declared below.
+enum class DenseArrayAttributeElementType { None, I8, I16, I32, I64, F32, F64 };
 } // namespace detail
 
-/// An attribute that represents a reference to a dense vector or tensor object.
+/// Common base of the DenseArrayAttr family: a uniqued, flat array of
+/// fixed-width integer or floating-point elements. The element kind is kept
+/// in the storage and is queried with getElementType(); the typed view is
+/// provided by the DenseArrayAttr subclasses.
+class DenseArrayAttrBase
+    : public Attribute::AttrBase {
+public:
+  using Base::Base;
+
+  /// Allow implicit conversion to ElementsAttr; a null attribute converts to
+  /// a null ElementsAttr.
+  /// NOTE(review): confirm this attribute really implements the ElementsAttr
+  /// interface; otherwise the cast would presumably assert.
+  operator ElementsAttr() const {
+    return *this ? cast() : nullptr;
+  }
+
+  /// Return the element kind recorded in this attribute's storage.
+  detail::DenseArrayAttributeElementType getElementType() const;
+  /// Print the bracketed element list, dispatching on the element kind.
+  void print(AsmPrinter &printer) const;
+  void print(raw_ostream &os) const;
+};
+
+/// A dense array attribute with elements of one fixed C++ element type; use
+/// the DenseI8ArrayAttr ... DenseF64ArrayAttr aliases below.
+template
+class DenseArrayAttr : public DenseArrayAttrBase {
+public:
+  using DenseArrayAttrBase::DenseArrayAttrBase;
+  /// View the stored elements as a typed array.
+  operator ArrayRef() const;
+  /// Return the uniqued attribute holding a copy of `content`.
+  static DenseArrayAttr get(MLIRContext *context, ArrayRef content);
+
+  /// Parse / print the bracketed element list `[e0, e1, ...]`.
+  static Attribute parse(AsmParser &parser, Type odsType);
+  void print(AsmPrinter &printer) const;
+  void print(raw_ostream &os) const;
+
+  /// Support isa/cast/dyn_cast: true for dense array attributes whose
+  /// element kind matches this instantiation.
+  static bool classof(Attribute attr);
+};
+/// 8-bit elements are printed as numbers rather than as characters.
+template <>
+void DenseArrayAttr::print(raw_ostream &os) const;
+
+// Explicitly instantiated in BuiltinAttributes.cpp.
+extern template class DenseArrayAttr;
+extern template class DenseArrayAttr;
+extern template class DenseArrayAttr;
+extern template class DenseArrayAttr;
+extern template class DenseArrayAttr;
+extern template class DenseArrayAttr;
+// Element-type-specific spellings used throughout MLIR and ODS.
+using DenseI8ArrayAttr = DenseArrayAttr;
+using DenseI16ArrayAttr = DenseArrayAttr;
+using DenseI32ArrayAttr = DenseArrayAttr;
+using DenseI64ArrayAttr = DenseArrayAttr;
+using DenseF32ArrayAttr = DenseArrayAttr;
+using DenseF64ArrayAttr = DenseArrayAttr;
+
+/// An attribute that represents a reference to a dense vector or tensor
+/// object.
///
 class DenseElementsAttr : public Attribute {
 public:
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -1517,6 +1517,60 @@
   let convertFromStorage = "$_self";
 }
 
+def DenseI8ArrayAttr :
+  ElementsAttrBase<CPred<"$_self.isa<::mlir::DenseI8ArrayAttr>()">,
+                   "i8 dense array attribute"> {
+  let storageType = [{ ::mlir::DenseI8ArrayAttr }];
+  let returnType = [{ ::llvm::ArrayRef<int8_t> }];
+
+  let convertFromStorage = "$_self";
+}
+
+def DenseI16ArrayAttr :
+  ElementsAttrBase<CPred<"$_self.isa<::mlir::DenseI16ArrayAttr>()">,
+                   "i16 dense array attribute"> {
+  let storageType = [{ ::mlir::DenseI16ArrayAttr }];
+  let returnType = [{ ::llvm::ArrayRef<int16_t> }];
+
+  let convertFromStorage = "$_self";
+}
+
+def DenseI32ArrayAttr :
+  ElementsAttrBase<CPred<"$_self.isa<::mlir::DenseI32ArrayAttr>()">,
+                   "i32 dense array attribute"> {
+  let storageType = [{ ::mlir::DenseI32ArrayAttr }];
+  let returnType = [{ ::llvm::ArrayRef<int32_t> }];
+
+  let convertFromStorage = "$_self";
+}
+
+def DenseI64ArrayAttr :
+  ElementsAttrBase<CPred<"$_self.isa<::mlir::DenseI64ArrayAttr>()">,
+                   "i64 dense array attribute"> {
+  let storageType = [{ ::mlir::DenseI64ArrayAttr }];
+  let returnType = [{ ::llvm::ArrayRef<int64_t> }];
+
+  let convertFromStorage = "$_self";
+}
+
+def DenseF32ArrayAttr :
+  ElementsAttrBase<CPred<"$_self.isa<::mlir::DenseF32ArrayAttr>()">,
+                   "float dense array attribute"> {
+  let storageType = [{ ::mlir::DenseF32ArrayAttr }];
+  let returnType = [{ ::llvm::ArrayRef<float> }];
+
+  let convertFromStorage = "$_self";
+}
+
+def DenseF64ArrayAttr :
+  ElementsAttrBase<CPred<"$_self.isa<::mlir::DenseF64ArrayAttr>()">,
+                   "double dense array attribute"> {
+  let storageType = [{ ::mlir::DenseF64ArrayAttr }];
+  let returnType = [{ ::llvm::ArrayRef<double> }];
+
+  let convertFromStorage = "$_self";
+}
+
 def IndexElementsAttr : IntElementsAttrBase() .getType()
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1822,9 +1822,36 @@
       }
       os << '>';
     }
-
+  } else if (auto denseArrayAttr = attr.dyn_cast<DenseArrayAttrBase>()) {
+    // Syntax: `arrayref` element-type `[` elements `]`, the form accepted by
+    // Parser::parseDenseArrayAttr.
+    os << "arrayref ";
+    switch (denseArrayAttr.getElementType()) {
+    case detail::DenseArrayAttributeElementType::None:
+      os << " ";
+      break;
+    case detail::DenseArrayAttributeElementType::I8:
+      os << "i8 ";
+      break;
+    case detail::DenseArrayAttributeElementType::I16:
+      os << "i16 ";
+      break;
+    case detail::DenseArrayAttributeElementType::I32:
+      os << "i32 ";
+      break;
+    case detail::DenseArrayAttributeElementType::I64:
+      os << "i64 ";
+      break;
+    case detail::DenseArrayAttributeElementType::F32:
+      os << "f32 ";
+      break;
+    case detail::DenseArrayAttributeElementType::F64:
+      os << "f64 ";
+      break;
+    }
+    denseArrayAttr.print(os);
   } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
     printLocation(locAttr);
+  } else {
+    llvm::report_fatal_error("Unknown builtin attribute");
   }
   // Don't print the type if we must elide it, or if it is a None type.
   if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -40,6 +40,36 @@
   return eltType.getIntOrFloatBitWidth();
 }
 
+/// Uniqued storage for DenseArrayAttrBase: the element kind tag plus the raw
+/// bytes of the elements. DenseArrayAttr<T> reinterprets `elements` as an
+/// array of T, so the bytes are stored with an alignment suitable for every
+/// supported element type.
+struct DenseArrayAttributeStorage : public AttributeStorage {
+  using KeyTy =
+      std::tuple<DenseArrayAttributeElementType, ::llvm::ArrayRef<char>>;
+  DenseArrayAttributeStorage(DenseArrayAttributeElementType eltType,
+                             ::llvm::ArrayRef<char> elements)
+      : eltType(eltType), elements(elements) {}
+
+  /// Two storages are equal when both the element kind and the raw bytes
+  /// match.
+  bool operator==(const KeyTy &tblgenKey) const {
+    return (eltType == std::get<0>(tblgenKey)) &&
+           (elements == std::get<1>(tblgenKey));
+  }
+
+  static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {
+    return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey));
+  }
+
+  /// Copy the key's bytes into the context allocator and build the storage.
+  static DenseArrayAttributeStorage *
+  construct(::mlir::AttributeStorageAllocator &allocator,
+            const KeyTy &tblgenKey) {
+    DenseArrayAttributeElementType eltType = std::get<0>(tblgenKey);
+    ::llvm::ArrayRef<char> key = std::get<1>(tblgenKey);
+    ::llvm::ArrayRef<char> elements;
+    if (!key.empty()) {
+      // copyInto() would allocate with the alignment of `char`, but the
+      // buffer is later read as T (up to int64_t/double). Allocate with the
+      // strictest alignment any supported element type needs.
+      constexpr size_t alignment = alignof(int64_t) > alignof(double)
+                                       ? alignof(int64_t)
+                                       : alignof(double);
+      char *rawData =
+          static_cast<char *>(allocator.allocate(key.size(), alignment));
+      std::uninitialized_copy(key.begin(), key.end(), rawData);
+      elements = ::llvm::ArrayRef<char>(rawData, key.size());
+    }
+    return new (allocator.allocate<DenseArrayAttributeStorage>())
+        DenseArrayAttributeStorage(eltType, elements);
+  }
+
+  /// Which element type the raw bytes encode.
+  DenseArrayAttributeElementType eltType;
+  /// Raw element bytes, aligned for the widest supported element type.
+  ::llvm::ArrayRef<char> elements;
+};
+
 /// An
attribute representing a reference to a dense vector or tensor object.
 struct DenseElementsAttributeStorage : public AttributeStorage {
 public:
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -12,6 +12,7 @@
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/Types.h"
@@ -35,11 +36,11 @@
 //===----------------------------------------------------------------------===//
 
 void BuiltinDialect::registerAttributes() {
-  addAttributes();
+  addAttributes();
 }
 
 //===----------------------------------------------------------------------===//
@@ -664,6 +665,149 @@
                   readBits(getData(), offset + storageWidth, bitWidth)};
 }
 
+//===----------------------------------------------------------------------===//
+// DenseArrayAttr
+//===----------------------------------------------------------------------===//
+
+/// Return the element kind recorded in the uniqued storage.
+detail::DenseArrayAttributeElementType
+DenseArrayAttrBase::getElementType() const {
+  return getImpl()->eltType;
+}
+
+void DenseArrayAttrBase::print(AsmPrinter &printer) const {
+  print(printer.getStream());
+}
+
+/// Forward to the typed printer that matches the stored element kind.
+void DenseArrayAttrBase::print(raw_ostream &os) const {
+  switch (getElementType()) {
+  case detail::DenseArrayAttributeElementType::I8:
+    this->cast<DenseI8ArrayAttr>().print(os);
+    return;
+  case detail::DenseArrayAttributeElementType::I16:
+    this->cast<DenseI16ArrayAttr>().print(os);
+    return;
+  case detail::DenseArrayAttributeElementType::I32:
+    this->cast<DenseI32ArrayAttr>().print(os);
+    return;
+  case detail::DenseArrayAttributeElementType::I64:
+    this->cast<DenseI64ArrayAttr>().print(os);
+    return;
+  case detail::DenseArrayAttributeElementType::F32:
+    this->cast<DenseF32ArrayAttr>().print(os);
+    return;
+  case detail::DenseArrayAttributeElementType::F64:
+    this->cast<DenseF64ArrayAttr>().print(os);
+    return;
+  default:
+    os << "";
+    return;
+  }
+}
+
+template <typename T>
+void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
+  print(printer.getStream());
+}
+
+/// Print the elements as `[e0, e1, ...]` (an empty array prints as `[]`).
+template <typename T>
+void DenseArrayAttr<T>::print(raw_ostream &os) const {
+  ArrayRef<T> values{*this};
+  os << "[";
+  llvm::interleaveComma(values, os);
+  os << "]";
+}
+
+/// int8_t would otherwise be streamed as a character; widen it first.
+template <>
+void DenseArrayAttr<int8_t>::print(raw_ostream &os) const {
+  ArrayRef<int8_t> values{*this};
+  os << "[";
+  llvm::interleaveComma(values, os, [&](int64_t v) { os << v; });
+  os << "]";
+}
+
+namespace {
+/// Parse one integer element of a dense array.
+template <typename T>
+ParseResult parseDenseArrayElement(AsmParser &parser, T &value,
+                                   std::true_type /*isIntegral*/) {
+  return parser.parseInteger(value);
+}
+
+/// Parse one floating-point element of a dense array; it is read as a double
+/// and then narrowed to T.
+template <typename T>
+ParseResult parseDenseArrayElement(AsmParser &parser, T &value,
+                                   std::false_type /*isIntegral*/) {
+  double parsed;
+  if (parser.parseFloat(parsed))
+    return failure();
+  value = static_cast<T>(parsed);
+  return success();
+}
+} // namespace
+
+/// Parse `[` (element (`,` element)*)? `]`, the form emitted by print().
+template <typename T>
+Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
+  if (parser.parseLSquare())
+    return {};
+  SmallVector<T> data;
+  // print() emits `[]` for an empty array; accept it so the output
+  // round-trips.
+  if (failed(parser.parseOptionalRSquare())) {
+    do {
+      T value;
+      if (parseDenseArrayElement(parser, value, std::is_integral<T>()))
+        return {};
+      data.push_back(value);
+    } while (succeeded(parser.parseOptionalComma()));
+    if (parser.parseRSquare())
+      return {};
+  }
+  return get(parser.getContext(), data);
+}
+
+/// Reinterpret the raw storage bytes as an array of T.
+/// NOTE(review): relies on the storage buffer being suitably aligned for T.
+template <typename T>
+DenseArrayAttr<T>::operator ArrayRef<T>() const {
+  ArrayRef<char> raw = getImpl()->elements;
+  assert((raw.size() % sizeof(T)) == 0);
+  return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
+                     raw.size() / sizeof(T));
+}
+
+namespace {
+/// Maps a C++ element type to the storage tag; `None` for unsupported types.
+template <typename T>
+constexpr detail::DenseArrayAttributeElementType denseArrayAttrEltTypeMap =
+    detail::DenseArrayAttributeElementType::None;
+template <>
+constexpr detail::DenseArrayAttributeElementType
+    denseArrayAttrEltTypeMap<int8_t> =
+        detail::DenseArrayAttributeElementType::I8;
+template <>
+constexpr detail::DenseArrayAttributeElementType
+    denseArrayAttrEltTypeMap<int16_t> =
+        detail::DenseArrayAttributeElementType::I16;
+template <>
+constexpr detail::DenseArrayAttributeElementType
+    denseArrayAttrEltTypeMap<int32_t> =
+        detail::DenseArrayAttributeElementType::I32;
+template <>
+constexpr detail::DenseArrayAttributeElementType
+    denseArrayAttrEltTypeMap<int64_t> =
+        detail::DenseArrayAttributeElementType::I64;
+template <>
+constexpr detail::DenseArrayAttributeElementType
+    denseArrayAttrEltTypeMap<float> =
+        detail::DenseArrayAttributeElementType::F32;
+template <>
+constexpr detail::DenseArrayAttributeElementType
+    denseArrayAttrEltTypeMap<double> =
+        detail::DenseArrayAttributeElementType::F64;
+} // namespace
+
+/// Unique the element bytes of `content` together with the element tag.
+template <typename T>
+DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
+                                         ArrayRef<T> content) {
+  return Base::get(
+             context, denseArrayAttrEltTypeMap<T>,
+             ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
+                            content.size() * sizeof(T)))
+      .template cast<DenseArrayAttr<T>>();
+}
+
+template <typename T>
+bool DenseArrayAttr<T>::classof(Attribute attr) {
+  return attr.isa<DenseArrayAttrBase>() &&
+         attr.cast<DenseArrayAttrBase>().getElementType() ==
+             denseArrayAttrEltTypeMap<T>;
+}
+
+namespace mlir {
+template class DenseArrayAttr<int8_t>;
+template class DenseArrayAttr<int16_t>;
+template class DenseArrayAttr<int32_t>;
+template class DenseArrayAttr<int64_t>;
+template class DenseArrayAttr<float>;
+template class DenseArrayAttr<double>;
+} // namespace mlir
+
 //===----------------------------------------------------------------------===//
 // DenseElementsAttr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp
--- a/mlir/lib/Parser/AttributeParser.cpp
+++ b/mlir/lib/Parser/AttributeParser.cpp
@@ -14,6 +14,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
+#include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Parser/AsmParserState.h"
 #include "llvm/ADT/StringExtras.h"
@@ -35,6 +36,7 @@
 ///                    | symbol-ref-id (`::` symbol-ref-id)*
 ///                    | `dense` `<` tensor-literal `>` `:`
 ///                      (tensor-type | vector-type)
+///                    | `arrayref` (integer-type | float-type) `[` ... `]`
 ///                    | `sparse` `<` attribute-value `,` attribute-value `>`
 ///                      `:` (tensor-type | vector-type)
 ///                    | `opaque` `<` dialect-namespace `,` hex-string-literal
@@ -90,6 +92,10 @@
   case Token::kw_dense:
     return parseDenseElementsAttr(type);
 
+  // Parse a DenseArrayAttr.
+  case Token::kw_arrayref:
+    return parseDenseArrayAttr();
+
   // Parse a dictionary attribute.
case Token::l_brace: {
     NamedAttrList elements;
@@ -806,6 +812,126 @@
 // ElementsAttr Parser
 //===----------------------------------------------------------------------===//
 
+/// Parse a dense array attribute:
+///
+///   `arrayref` integer-type `[` (integer (`,` integer)*)? `]`
+///
+/// Only i8, i16, i32 and i64 element types are accepted; floating-point
+/// element types are diagnosed as not supported yet.
+Attribute Parser::parseDenseArrayAttr() {
+  consumeToken(Token::kw_arrayref);
+
+  auto typeLoc = getToken().getLoc();
+  Type type = parseType();
+  // Bail out on a failed type parse instead of querying a null Type.
+  if (!type)
+    return {};
+  auto intType = type.dyn_cast<IntegerType>();
+  auto floatType = type.dyn_cast<FloatType>();
+  if (!intType && !floatType) {
+    emitError(typeLoc, "expected integer or float type, got: ") << type;
+    return {};
+  }
+  if (floatType) {
+    // Report a diagnostic instead of aborting the whole process.
+    emitError(typeLoc, "dense array of floating-point elements is not "
+                       "supported yet, got: ")
+        << type;
+    return {};
+  }
+  unsigned bitWidth = intType.getWidth();
+  if (bitWidth != 8 && bitWidth != 16 && bitWidth != 32 && bitWidth != 64) {
+    emitError(typeLoc, "expected i8, i16, i32, or i64 but got: ") << type;
+    return {};
+  }
+
+  if (parseToken(Token::l_square, "expected '[' after 'arrayref'"))
+    return {};
+
+  // Parses a possibly-empty, comma-separated list of integer literals (each
+  // optionally preceded by '-') into `data`, leaving the closing ']' for the
+  // caller. Values that do not fit the element type are diagnosed.
+  auto parseIntegerList = [&](auto &data) -> LogicalResult {
+    using value_type =
+        typename std::remove_reference<decltype(data)>::type::value_type;
+    // An empty array prints as `[]`; accept it so the output round-trips.
+    if (getToken().is(Token::r_square))
+      return success();
+    do {
+      bool neg = consumeIf(Token::minus);
+      if (!getToken().is(Token::integer))
+        return emitError("expected integer literal");
+      auto optionalValue = getToken().getUInt64IntegerValue();
+      if (!optionalValue.hasValue())
+        return emitError("invalid integer literal");
+      // Range-check the magnitude as uint64_t: converting literals above
+      // INT64_MAX to int64_t first would silently wrap them.
+      uint64_t magnitude = optionalValue.getValue();
+      // Held as int64_t so 8-bit bounds are not streamed as characters.
+      int64_t min = std::numeric_limits<value_type>::min();
+      int64_t max = std::numeric_limits<value_type>::max();
+      // For two's-complement types |min| == max + 1.
+      uint64_t limit = neg ? static_cast<uint64_t>(max) + 1
+                           : static_cast<uint64_t>(max);
+      if (magnitude > limit) {
+        return emitError("overflow expected integer [")
+               << min << ", " << max << "]: " << (neg ? "-" : "")
+               << magnitude;
+      }
+      int64_t value;
+      if (!neg) {
+        value = static_cast<int64_t>(magnitude);
+      } else if (magnitude == 0) {
+        value = 0;
+      } else {
+        // Negate via (magnitude - 1) so that -2^63 does not overflow.
+        value = -static_cast<int64_t>(magnitude - 1) - 1;
+      }
+      data.push_back(static_cast<value_type>(value));
+      consumeToken();
+    } while (consumeIf(Token::comma));
+    return success();
+  };
+
+  Attribute result;
+  switch (bitWidth) {
+  case 8: {
+    SmallVector<int8_t> data;
+    if (failed(parseIntegerList(data)))
+      return {};
+    result = DenseI8ArrayAttr::get(getContext(), data);
+    break;
+  }
+  case 16: {
+    SmallVector<int16_t> data;
+    if (failed(parseIntegerList(data)))
+      return {};
+    result = DenseI16ArrayAttr::get(getContext(), data);
+    break;
+  }
+  case 32: {
+    SmallVector<int32_t> data;
+    if (failed(parseIntegerList(data)))
+      return {};
+    result = DenseI32ArrayAttr::get(getContext(), data);
+    break;
+  }
+  case 64: {
+    SmallVector<int64_t> data;
+    if (failed(parseIntegerList(data)))
+      return {};
+    result = DenseI64ArrayAttr::get(getContext(), data);
+    break;
+  }
+  }
+
+  if (parseToken(Token::r_square, "expected ']' after 'arrayref'"))
+    return {};
+  return result;
+}
+
 /// Parse a dense elements attribute.
 Attribute Parser::parseDenseElementsAttr(Type attrType) {
   auto attribLoc = getToken().getLoc();
diff --git a/mlir/lib/Parser/Parser.h b/mlir/lib/Parser/Parser.h
--- a/mlir/lib/Parser/Parser.h
+++ b/mlir/lib/Parser/Parser.h
@@ -265,6 +265,9 @@
   Attribute parseDenseElementsAttr(Type attrType);
   ShapedType parseElementsLiteralType(Type type);
 
+  /// Parse a DenseArrayAttr.
+  Attribute parseDenseArrayAttr();
+
   /// Parse a sparse elements attribute.
Attribute parseSparseElementsAttr(Type attrType);
 
diff --git a/mlir/lib/Parser/TokenKinds.def b/mlir/lib/Parser/TokenKinds.def
--- a/mlir/lib/Parser/TokenKinds.def
+++ b/mlir/lib/Parser/TokenKinds.def
@@ -78,6 +78,7 @@
 TOK_KEYWORD(affine_map)
 TOK_KEYWORD(affine_set)
+TOK_KEYWORD(arrayref)
 TOK_KEYWORD(attributes)
 TOK_KEYWORD(bf16)
 TOK_KEYWORD(ceildiv)
 TOK_KEYWORD(complex)
diff --git a/mlir/test/IR/attribute.mlir b/mlir/test/IR/attribute.mlir
--- a/mlir/test/IR/attribute.mlir
+++ b/mlir/test/IR/attribute.mlir
@@ -500,6 +500,30 @@
   return
 }
 
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// Test DenseArrayAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @dense_array_attr
+// CHECK: i16attr = arrayref i16 [1, -2, 3]
+func @dense_array_attr() attributes{ i16attr = arrayref i16 [1, -2, 3] } {
+// CHECK: test.dense_array_attr
+// CHECK-SAME: i8attr = [1, -2, 3]
+// CHECK-SAME: i16attr = [3, 5, -4, 10]
+// CHECK-SAME: i32attr = [1024, 453, -6435]
+// CHECK-SAME: i64attr = [-142]
+  test.dense_array_attr
+    i8attr = [1, -2, 3]
+    i16attr = [3, 5, -4, 10]
+    i32attr = [1024, 453, -6435]
+    i64attr = [-142]
+
+  return
+}
+
 // -----
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -289,6 +289,21 @@
   );
 }
 
+// Test op carrying one dense array attribute per integer element width,
+// printed and parsed through a declarative assembly format.
+def DenseArrayAttrOp : TEST_Op<"dense_array_attr"> {
+  let arguments = (ins
+    DenseI8ArrayAttr:$i8attr,
+    DenseI16ArrayAttr:$i16attr,
+    DenseI32ArrayAttr:$i32attr,
+    DenseI64ArrayAttr:$i64attr
+    // The floating-point variants are not exercised yet.
+    // DenseF32ArrayAttr:$f32attr,
+    // DenseF64ArrayAttr:$f64attr
+  );
+  let assemblyFormat = [{
+    `i8attr` `=` $i8attr `i16attr` `=` $i16attr `i32attr` `=` $i32attr `i64attr` `=` $i64attr attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // Test Enum Attributes
 //===----------------------------------------------------------------------===//