diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -17,8 +17,12 @@ namespace mlir { class AffineMap; +class AsmResourceBlob; class BoolAttr; +class BuiltinDialect; class DenseIntElementsAttr; +template +struct DialectResourceBlobHandle; class FlatSymbolRefAttr; class FunctionType; class IntegerSet; @@ -729,6 +733,13 @@ return denseAttr && denseAttr.isSplat(); } }; + +//===----------------------------------------------------------------------===// +// DenseResourceElementsAttr +//===----------------------------------------------------------------------===// + +using DenseResourceElementsHandle = DialectResourceBlobHandle; + } // namespace mlir //===----------------------------------------------------------------------===// @@ -743,6 +754,9 @@ //===----------------------------------------------------------------------===// namespace mlir { +//===----------------------------------------------------------------------===// +// DenseArrayAttr + namespace detail { /// Base class for DenseArrayAttr that is instantiated and specialized for each /// supported element type below. @@ -795,6 +809,71 @@ using DenseF32ArrayAttr = detail::DenseArrayAttr; using DenseF64ArrayAttr = detail::DenseArrayAttr; +//===----------------------------------------------------------------------===// +// DenseResourceElementsAttr + +namespace detail { +/// Base class for DenseResourceElementsAttr that is instantiated and +/// specialized for each supported element type below. +template +class DenseResourceElementsAttrBase : public DenseResourceElementsAttr { +public: + using DenseResourceElementsAttr::DenseResourceElementsAttr; + + /// A builder that inserts a new resource using the provided blob. The handle + /// of the inserted blob is used when building the attribute. 
The provided + /// `blobName` is used as a hint for the key of the new handle for the `blob` + /// resource, but may be changed if necessary to ensure uniqueness during + /// insertion. + static DenseResourceElementsAttrBase + get(ShapedType type, StringRef blobName, AsmResourceBlob blob); + + /// Return the data of this attribute as an ArrayRef if it is present, + /// returns None otherwise. + Optional> tryGetAsArrayRef() const; + + /// Support for isa<>/cast<>. + static bool classof(Attribute attr); +}; + +extern template class DenseResourceElementsAttrBase; +extern template class DenseResourceElementsAttrBase; +extern template class DenseResourceElementsAttrBase; +extern template class DenseResourceElementsAttrBase; +extern template class DenseResourceElementsAttrBase; +extern template class DenseResourceElementsAttrBase; +extern template class DenseResourceElementsAttrBase; +extern template class DenseResourceElementsAttrBase; +extern template class DenseResourceElementsAttrBase; +extern template class DenseResourceElementsAttrBase; +extern template class DenseResourceElementsAttrBase; +} // namespace detail + +// Public names for all the supported DenseResourceElementsAttr. 
+ +using DenseBoolResourceElementsAttr = + detail::DenseResourceElementsAttrBase; +using DenseI8ResourceElementsAttr = + detail::DenseResourceElementsAttrBase; +using DenseI16ResourceElementsAttr = + detail::DenseResourceElementsAttrBase; +using DenseI32ResourceElementsAttr = + detail::DenseResourceElementsAttrBase; +using DenseI64ResourceElementsAttr = + detail::DenseResourceElementsAttrBase; +using DenseUI8ResourceElementsAttr = + detail::DenseResourceElementsAttrBase; +using DenseUI16ResourceElementsAttr = + detail::DenseResourceElementsAttrBase; +using DenseUI32ResourceElementsAttr = + detail::DenseResourceElementsAttrBase; +using DenseUI64ResourceElementsAttr = + detail::DenseResourceElementsAttrBase; +using DenseF32ResourceElementsAttr = + detail::DenseResourceElementsAttrBase; +using DenseF64ResourceElementsAttr = + detail::DenseResourceElementsAttrBase; + //===----------------------------------------------------------------------===// // BoolAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -17,6 +17,7 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/IR/BuiltinDialect.td" include "mlir/IR/BuiltinAttributeInterfaces.td" +include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SubElementInterfaces.td" // TODO: Currently the attributes defined in this file are prefixed with @@ -424,6 +425,65 @@ let skipDefaultBuilders = 1; } +//===----------------------------------------------------------------------===// +// DenseResourceElementsAttr +//===----------------------------------------------------------------------===// + +def Builtin_DenseResourceElementsAttr : Builtin_Attr<"DenseResourceElements", [ + ElementsAttrInterface, TypedAttrInterface + ]> { + let summary = "An Attribute containing a dense multi-dimensional array " + "backed 
by a resource"; + let description = [{ + Syntax: + + ``` + dense-resource-elements-attribute ::= + `dense_resource` `<` resource-handle `>` `:` shaped-type + ``` + + A dense resource elements attribute is an elements attribute backed by a + handle to a builtin dialect resource containing a densely packed array of + values. This class provides the low-level attribute, which should only be + interacted with in very generic terms; actual access to the underlying + resource data is intended to be managed through one of the subclasses, such + as: `DenseBoolResourceElementsAttr`, `DenseUI64ResourceElementsAttr`, + `DenseI32ResourceElementsAttr`, `DenseF32ResourceElementsAttr`, + `DenseF64ResourceElementsAttr`, etc. + + Examples: + + ```mlir + // A tensor referencing a builtin dialect resource, `resource_1`, with two + // unsigned i32 elements. + dense_resource<resource_1> : tensor<2xui32> + ``` + }]; + let parameters = (ins + AttributeSelfTypeParameter<"", "ShapedType">:$type, + ResourceHandleParameter<"DenseResourceElementsHandle">:$rawHandle + ); + let builders = [ + AttrBuilderWithInferredContext<(ins + "ShapedType":$type, "DenseResourceElementsHandle":$handle + )> + ]; + let extraClassDeclaration = [{ + protected: + /// A builder that inserts a new resource into the builtin dialect's blob + /// manager using the provided blob. The handle of the inserted blob is used + /// when building the attribute. The provided `blobName` is used as a hint + /// for the key of the new handle for the `blob` resource, but may be + /// changed if necessary to ensure uniqueness during insertion.
+ static DenseResourceElementsAttr get( + ShapedType type, StringRef blobName, AsmResourceBlob blob + ); + + public: + }]; + let skipDefaultBuilders = 1; +} + //===----------------------------------------------------------------------===// // DictionaryAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -1023,8 +1023,17 @@ template FailureOr parseResourceHandle() { SMLoc handleLoc = getCurrentLocation(); - FailureOr handle = parseResourceHandle( - getContext()->getOrLoadDialect()); + + // Try to load the dialect that owns the handle. + auto *dialect = + getContext()->getOrLoadDialect(); + if (!dialect) { + return emitError(handleLoc) + << "dialect '" << ResourceT::Dialect::getDialectNamespace() + << "' is unknown"; + } + + FailureOr handle = parseResourceHandle(dialect); if (failed(handle)) return failure(); if (auto *result = dyn_cast(&*handle)) diff --git a/mlir/lib/AsmParser/AsmParserImpl.h b/mlir/lib/AsmParser/AsmParserImpl.h --- a/mlir/lib/AsmParser/AsmParserImpl.h +++ b/mlir/lib/AsmParser/AsmParserImpl.h @@ -460,7 +460,7 @@ /// Parse a handle to a resource within the assembly format. 
FailureOr parseResourceHandle(Dialect *dialect) override { - const auto *interface = dyn_cast_or_null(dialect); + const auto *interface = dyn_cast(dialect); if (!interface) { return parser.emitError() << "dialect '" << dialect->getNamespace() << "' does not expect resource handles"; diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -15,9 +15,10 @@ #include "AsmParserImpl.h" #include "mlir/AsmParser/AsmParserState.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Support/Endian.h" @@ -97,6 +98,10 @@ case Token::kw_dense: return parseDenseElementsAttr(type); + // Parse a dense resource elements attribute. + case Token::kw_dense_resource: + return parseDenseResourceElementsAttr(type); + // Parse a dictionary attribute. case Token::l_brace: { NamedAttrList elements; @@ -241,6 +246,7 @@ case Token::kw_affine_map: case Token::kw_affine_set: case Token::kw_dense: + case Token::kw_dense_resource: case Token::kw_false: case Token::kw_loc: case Token::kw_opaque: @@ -928,6 +934,39 @@ return literalParser.getAttr(loc, type); } +Attribute Parser::parseDenseResourceElementsAttr(Type attrType) { + auto loc = getToken().getLoc(); + consumeToken(Token::kw_dense_resource); + if (parseToken(Token::less, "expected '<' after 'dense_resource'")) + return nullptr; + + // Parse the resource handle. 
+ FailureOr rawHandle = + parseResourceHandle(getContext()->getLoadedDialect()); + if (failed(rawHandle) || parseToken(Token::greater, "expected '>'")) + return nullptr; + + auto *handle = dyn_cast(&*rawHandle); + if (!handle) + return emitError(loc, "invalid `dense_resource` handle type"), nullptr; + + // Parse the type of the attribute if the user didn't provide one. + SMLoc typeLoc = loc; + if (!attrType) { + typeLoc = getToken().getLoc(); + if (parseToken(Token::colon, "expected ':'") || !(attrType = parseType())) + return nullptr; + } + + ShapedType shapedType = attrType.dyn_cast(); + if (!shapedType) { + emitError(typeLoc, "`dense_resource` expected a shaped type"); + return nullptr; + } + + return DenseResourceElementsAttr::get(shapedType, *handle); +} + /// Parse an opaque elements attribute. Attribute Parser::parseOpaqueElementsAttr(Type attrType) { SMLoc loc = getToken().getLoc(); diff --git a/mlir/lib/AsmParser/Parser.h b/mlir/lib/AsmParser/Parser.h --- a/mlir/lib/AsmParser/Parser.h +++ b/mlir/lib/AsmParser/Parser.h @@ -160,6 +160,7 @@ /// Parse a handle to a dialect resource within the assembly format. FailureOr parseResourceHandle(const OpAsmDialectInterface *dialect, StringRef &name); + FailureOr parseResourceHandle(Dialect *dialect); //===--------------------------------------------------------------------===// // Type Parsing @@ -272,6 +273,9 @@ Attribute parseDenseElementsAttr(Type attrType); ShapedType parseElementsLiteralType(Type type); + /// Parse a dense resource elements attribute. + Attribute parseDenseResourceElementsAttr(Type attrType); + /// Parse a DenseArrayAttr. 
Attribute parseDenseArrayAttr(); diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -340,6 +340,17 @@ return entry.second; } +FailureOr +Parser::parseResourceHandle(Dialect *dialect) { + const auto *interface = dyn_cast(dialect); + if (!interface) { + return emitError() << "dialect '" << dialect->getNamespace() + << "' does not expect resource handles"; + } + StringRef resourceName; + return parseResourceHandle(interface, resourceName); +} + //===----------------------------------------------------------------------===// // Code Completion diff --git a/mlir/lib/AsmParser/TokenKinds.def b/mlir/lib/AsmParser/TokenKinds.def --- a/mlir/lib/AsmParser/TokenKinds.def +++ b/mlir/lib/AsmParser/TokenKinds.def @@ -87,6 +87,7 @@ TOK_KEYWORD(ceildiv) TOK_KEYWORD(complex) TOK_KEYWORD(dense) +TOK_KEYWORD(dense_resource) TOK_KEYWORD(f16) TOK_KEYWORD(f32) TOK_KEYWORD(f64) diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/OpImplementation.h" @@ -1896,6 +1897,10 @@ os << " "; denseArrayAttr.printWithoutBraces(os); os << "]"; + } else if (auto resourceAttr = attr.dyn_cast()) { + os << "dense_resource<"; + printResourceHandle(resourceAttr.getRawHandle()); + os << ">"; } else if (auto locAttr = attr.dyn_cast()) { printLocation(locAttr); } else { diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinDialect.h" #include "mlir/IR/Dialect.h" +#include 
"mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/Operation.h" @@ -36,11 +37,10 @@ //===----------------------------------------------------------------------===// void BuiltinDialect::registerAttributes() { - addAttributes(); + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/IR/BuiltinAttributes.cpp.inc" + >(); } //===----------------------------------------------------------------------===// @@ -1576,6 +1576,130 @@ return false; } +//===----------------------------------------------------------------------===// +// DenseResourceElementsAttr +//===----------------------------------------------------------------------===// + +DenseResourceElementsAttr +DenseResourceElementsAttr::get(ShapedType type, + DenseResourceElementsHandle handle) { + return Base::get(type.getContext(), type, handle); +} + +DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type, + StringRef blobName, + AsmResourceBlob blob) { + // Extract the builtin dialect resource manager from context and construct a + // handle by inserting a new resource using the provided blob. + auto &manager = + DenseResourceElementsHandle::getManagerInterface(type.getContext()); + return get(type, manager.insert(blobName, std::move(blob))); +} + +//===----------------------------------------------------------------------===// +// DenseResourceElementsAttrBase + +namespace { +/// Instantiations of this class provide utilities for interacting with native +/// data types in the context of DenseResourceElementsAttr. +template +struct DenseResourceAttrUtil; +template +struct DenseResourceElementsAttrIntUtil { + static bool checkElementType(Type eltType) { + IntegerType type = eltType.dyn_cast(); + if (!type || type.getWidth() != width) + return false; + return isSigned ? 
!type.isUnsigned() : !type.isSigned(); + } +}; +template <> +struct DenseResourceAttrUtil { + static bool checkElementType(Type eltType) { + return eltType.isSignlessInteger(1); + } +}; +template <> +struct DenseResourceAttrUtil + : public DenseResourceElementsAttrIntUtil<8, true> {}; +template <> +struct DenseResourceAttrUtil + : public DenseResourceElementsAttrIntUtil<8, false> {}; +template <> +struct DenseResourceAttrUtil + : public DenseResourceElementsAttrIntUtil<16, true> {}; +template <> +struct DenseResourceAttrUtil + : public DenseResourceElementsAttrIntUtil<16, false> {}; +template <> +struct DenseResourceAttrUtil + : public DenseResourceElementsAttrIntUtil<32, true> {}; +template <> +struct DenseResourceAttrUtil + : public DenseResourceElementsAttrIntUtil<32, false> {}; +template <> +struct DenseResourceAttrUtil + : public DenseResourceElementsAttrIntUtil<64, true> {}; +template <> +struct DenseResourceAttrUtil + : public DenseResourceElementsAttrIntUtil<64, false> {}; +template <> +struct DenseResourceAttrUtil { + static bool checkElementType(Type eltType) { return eltType.isF32(); } +}; +template <> +struct DenseResourceAttrUtil { + static bool checkElementType(Type eltType) { return eltType.isF64(); } +}; +} // namespace + +template +DenseResourceElementsAttrBase +DenseResourceElementsAttrBase::get(ShapedType type, StringRef blobName, + AsmResourceBlob blob) { + // Check that the blob is in the form we were expecting. 
+ assert(blob.getDataAlignment() == alignof(T) && + "alignment mismatch between expected alignment and blob alignment"); + assert(((blob.getData().size() % sizeof(T)) == 0) && + "size mismatch between expected element width and blob size"); + assert(DenseResourceAttrUtil::checkElementType(type.getElementType()) && + "invalid shape element type for provided type `T`"); + return DenseResourceElementsAttr::get(type, blobName, std::move(blob)) + .template cast>(); +} + +template +Optional> +DenseResourceElementsAttrBase::tryGetAsArrayRef() const { + if (AsmResourceBlob *blob = this->getRawHandle().getBlob()) + return blob->template getDataAs(); + return llvm::None; +} + +template +bool DenseResourceElementsAttrBase::classof(Attribute attr) { + auto resourceAttr = attr.dyn_cast(); + return resourceAttr && DenseResourceAttrUtil::checkElementType( + resourceAttr.getElementType()); +} + +namespace mlir { +namespace detail { +// Explicit instantiation for all the supported DenseResourceElementsAttr. 
+template class DenseResourceElementsAttrBase; +template class DenseResourceElementsAttrBase; +template class DenseResourceElementsAttrBase; +template class DenseResourceElementsAttrBase; +template class DenseResourceElementsAttrBase; +template class DenseResourceElementsAttrBase; +template class DenseResourceElementsAttrBase; +template class DenseResourceElementsAttrBase; +template class DenseResourceElementsAttrBase; +template class DenseResourceElementsAttrBase; +template class DenseResourceElementsAttrBase; +} // namespace detail +} // namespace mlir + //===----------------------------------------------------------------------===// // OpaqueElementsAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -16,6 +16,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" @@ -23,14 +24,27 @@ using namespace mlir; //===----------------------------------------------------------------------===// -// Builtin Dialect +// TableGen'erated dialect //===----------------------------------------------------------------------===// #include "mlir/IR/BuiltinDialect.cpp.inc" +//===----------------------------------------------------------------------===// +// BuiltinBlobManagerInterface +//===----------------------------------------------------------------------===// + +using BuiltinBlobManagerInterface = + ResourceBlobManagerDialectInterfaceBase; + +//===----------------------------------------------------------------------===// +// BuiltinOpAsmDialectInterface +//===----------------------------------------------------------------------===// + namespace { struct BuiltinOpAsmDialectInterface : public 
OpAsmDialectInterface { - using OpAsmDialectInterface::OpAsmDialectInterface; + BuiltinOpAsmDialectInterface(Dialect *dialect, + BuiltinBlobManagerInterface &mgr) + : OpAsmDialectInterface(dialect), blobManager(mgr) {} AliasResult getAlias(Attribute attr, raw_ostream &os) const override { if (attr.isa()) { @@ -57,6 +71,38 @@ } return AliasResult::NoAlias; } + + //===------------------------------------------------------------------===// + // Resources + //===------------------------------------------------------------------===// + + std::string + getResourceKey(const AsmDialectResourceHandle &handle) const override { + return cast(handle).getKey().str(); + } + FailureOr + declareResource(StringRef key) const final { + return blobManager.insert(key); + } + LogicalResult parseResource(AsmParsedResourceEntry &entry) const final { + FailureOr blob = entry.parseAsBlob(); + if (failed(blob)) + return failure(); + + // Update the blob for this entry. + blobManager.update(entry.getKey(), std::move(*blob)); + return success(); + } + void + buildResources(Operation *op, + const SetVector &referencedResources, + AsmResourceBuilder &provider) const final { + blobManager.buildResources(provider, referencedResources.getArrayRef()); + } + +private: + /// The blob manager for the dialect. + BuiltinBlobManagerInterface &blobManager; }; } // namespace @@ -68,7 +114,9 @@ #define GET_OP_LIST #include "mlir/IR/BuiltinOps.cpp.inc" >(); - addInterfaces(); + + auto &blobInterface = addInterface(); + addInterface(blobInterface); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/DialectResourceBlobManager.cpp b/mlir/lib/IR/DialectResourceBlobManager.cpp --- a/mlir/lib/IR/DialectResourceBlobManager.cpp +++ b/mlir/lib/IR/DialectResourceBlobManager.cpp @@ -57,7 +57,7 @@ Twine(nameCounter++).toVector(nameStorage); // Try inserting with the new name. 
- if (BlobEntry *entry = tryInsertion(name)) + if (BlobEntry *entry = tryInsertion(nameStorage)) return *entry; nameStorage.resize(name.size() + 1); } while (true); diff --git a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp --- a/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MLIRServer.cpp @@ -712,8 +712,9 @@ /// Signal a completion for an attribute. void completeAttribute(const llvm::StringMap &aliases) override { - appendSimpleCompletions({"affine_set", "affine_map", "dense", "false", - "loc", "opaque", "sparse", "true", "unit"}, + appendSimpleCompletions({"affine_set", "affine_map", "dense", + "dense_resource", "false", "loc", "opaque", + "sparse", "true", "unit"}, lsp::CompletionItemKind::Field, /*sortText=*/"1"); diff --git a/mlir/test/IR/dense-resource-elements-attr.mlir b/mlir/test/IR/dense-resource-elements-attr.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/dense-resource-elements-attr.mlir @@ -0,0 +1,13 @@ +// RUN: mlir-opt -allow-unregistered-dialect %s -verify-diagnostics -split-input-file | FileCheck %s + +// CHECK: attr = dense_resource : tensor<3xi64> +"test.user_op"() {attr = dense_resource : tensor<3xi64> } : () -> () + +{-# + dialect_resources: { + builtin: { + // CHECK: blob1: "0x08000000010000000000000002000000000000000300000000000000" + blob1: "0x08000000010000000000000002000000000000000300000000000000" + } + } +#-} diff --git a/mlir/test/IR/invalid-builtin-attributes.mlir b/mlir/test/IR/invalid-builtin-attributes.mlir --- a/mlir/test/IR/invalid-builtin-attributes.mlir +++ b/mlir/test/IR/invalid-builtin-attributes.mlir @@ -519,3 +519,23 @@ "J// ----- " // expected-error {{expected}} + +// ----- + +// expected-error@+1 {{expected '<' after 'dense_resource'}} +#attr = dense_resource> + +// ----- + +// expected-error@+1 {{expected '>'}} +#attr = dense_resource + +// ----- + +// expected-error@+1 {{`dense_resource` expected a shaped type}} +#attr = 
dense_resource : i32 diff --git a/mlir/test/IR/invalid-file-metadata.mlir b/mlir/test/IR/invalid-file-metadata.mlir --- a/mlir/test/IR/invalid-file-metadata.mlir +++ b/mlir/test/IR/invalid-file-metadata.mlir @@ -59,10 +59,10 @@ // ----- -// expected-error@+4 {{unknown 'resource' key 'unknown_entry' for dialect 'builtin'}} +// expected-error@+4 {{unknown 'resource' key 'unknown_entry' for dialect 'ml_program'}} {-# dialect_resources: { - builtin: { + ml_program: { unknown_entry: "foo" } } diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/AsmState.h" +#include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "gtest/gtest.h" @@ -13,6 +15,10 @@ using namespace mlir; using namespace mlir::detail; +//===----------------------------------------------------------------------===// +// DenseElementsAttr +//===----------------------------------------------------------------------===// + template static void testSplat(Type eltType, const EltTy &splatElt) { RankedTensorType shape = RankedTensorType::get({2, 1}, eltType); @@ -203,7 +209,119 @@ auto attr = DenseElementsAttr::get(shape, llvm::makeArrayRef({elementValue})); EXPECT_TRUE(attr.getValues()[0] == value); } +} // namespace + +//===----------------------------------------------------------------------===// +// DenseResourceElementsAttr +//===----------------------------------------------------------------------===// + +template +static void checkNativeAccess(MLIRContext *ctx, ArrayRef data, + Type elementType) { + auto type = RankedTensorType::get(data.size(), elementType); + auto attr = + AttrT::get(type, "resource", UnmanagedAsmResourceBlob::allocate(data)); + + // Check that we can access and iterate the data properly. 
+  Optional<ArrayRef<T>> attrData = attr.tryGetAsArrayRef();
+  ASSERT_TRUE(attrData.hasValue());
+  EXPECT_EQ(*attrData, data);
+
+  // Check that we cast to this attribute when possible.
+  Attribute genericAttr = attr;
+  EXPECT_TRUE(genericAttr.template isa<AttrT>());
+}
+template <typename AttrT, typename T>
+static void checkNativeIntAccess(Builder &builder, size_t intWidth) {
+  T data[] = {0, 1, 2};
+  checkNativeAccess<AttrT, T>(builder.getContext(), llvm::makeArrayRef(data),
+                              builder.getIntegerType(intWidth));
+}
+
+namespace {
+TEST(DenseResourceElementsAttrTest, CheckNativeAccess) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // Bool
+  bool boolData[] = {true, false, true};
+  checkNativeAccess<DenseBoolResourceElementsAttr>(
+      &context, llvm::makeArrayRef(boolData), builder.getI1Type());
+
+  // Unsigned integers
+  checkNativeIntAccess<DenseUI8ResourceElementsAttr, uint8_t>(builder, 8);
+  checkNativeIntAccess<DenseUI16ResourceElementsAttr, uint16_t>(builder, 16);
+  checkNativeIntAccess<DenseUI32ResourceElementsAttr, uint32_t>(builder, 32);
+  checkNativeIntAccess<DenseUI64ResourceElementsAttr, uint64_t>(builder, 64);
+
+  // Signed integers
+  checkNativeIntAccess<DenseI8ResourceElementsAttr, int8_t>(builder, 8);
+  checkNativeIntAccess<DenseI16ResourceElementsAttr, int16_t>(builder, 16);
+  checkNativeIntAccess<DenseI32ResourceElementsAttr, int32_t>(builder, 32);
+  checkNativeIntAccess<DenseI64ResourceElementsAttr, int64_t>(builder, 64);
+
+  // Float
+  float floatData[] = {0, 1, 2};
+  checkNativeAccess<DenseF32ResourceElementsAttr>(
+      &context, llvm::makeArrayRef(floatData), builder.getF32Type());
+
+  // Double
+  double doubleData[] = {0, 1, 2};
+  checkNativeAccess<DenseF64ResourceElementsAttr>(
+      &context, llvm::makeArrayRef(doubleData), builder.getF64Type());
+}
+
+TEST(DenseResourceElementsAttrTest, CheckNoCast) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // Create an i32 attribute.
+  ArrayRef<uint32_t> data;
+  auto type = RankedTensorType::get(data.size(), builder.getI32Type());
+  Attribute i32ResourceAttr = DenseI32ResourceElementsAttr::get(
+      type, "resource", UnmanagedAsmResourceBlob::allocate(data));
+
+  EXPECT_TRUE(i32ResourceAttr.isa<DenseI32ResourceElementsAttr>());
+  EXPECT_FALSE(i32ResourceAttr.isa<DenseF32ResourceElementsAttr>());
+  EXPECT_FALSE(i32ResourceAttr.isa<DenseBoolResourceElementsAttr>());
+}
+TEST(DenseResourceElementsAttrTest, CheckInvalidData) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // Create a bool attribute with data of the incorrect type.
+  ArrayRef<uint32_t> data;
+  auto type = RankedTensorType::get(data.size(), builder.getI32Type());
+  EXPECT_DEBUG_DEATH(
+      {
+        DenseBoolResourceElementsAttr::get(
+            type, "resource", UnmanagedAsmResourceBlob::allocate(data));
+      },
+      "alignment mismatch between expected alignment and blob alignment");
+}
+
+TEST(DenseResourceElementsAttrTest, CheckInvalidType) {
+  MLIRContext context;
+  Builder builder(&context);
+
+  // Create a bool attribute with incorrect type.
+  ArrayRef<bool> data;
+  auto type = RankedTensorType::get(data.size(), builder.getI32Type());
+  EXPECT_DEBUG_DEATH(
+      {
+        DenseBoolResourceElementsAttr::get(
+            type, "resource", UnmanagedAsmResourceBlob::allocate(data));
+      },
+      "invalid shape element type for provided type `T`");
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// SparseElementsAttr
+//===----------------------------------------------------------------------===//
+
+namespace {
 TEST(SparseElementsAttrTest, GetZero) {
   MLIRContext context;
   context.allowUnregisteredDialects();