diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -411,18 +411,50 @@ } Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } + /// Parse a reference to an attribute or type using the given reader. + LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { + uint64_t attrIdx; + if (failed(reader.parseVarInt(attrIdx))) + return failure(); + result = resolveAttribute(attrIdx); + return success(!!result); + } + LogicalResult parseType(EncodingReader &reader, Type &result) { + uint64_t typeIdx; + if (failed(reader.parseVarInt(typeIdx))) + return failure(); + result = resolveType(typeIdx); + return success(!!result); + } + + template <typename T> + LogicalResult parseAttribute(EncodingReader &reader, T &result) { + Attribute baseResult; + if (failed(parseAttribute(reader, baseResult))) + return failure(); + if ((result = baseResult.dyn_cast<T>())) + return success(); + return reader.emitError("expected attribute of type: ", + llvm::getTypeName<T>(), ", but got: ", baseResult); + } + private: /// Resolve the given entry at `index`. template <typename T> T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, StringRef entryType); - /// Parse the value defined within the given reader. `code` indicates how the - /// entry was encoded. - LogicalResult parseEntry(EncodingReader &reader, bool hasCustomEncoding, - Attribute &result); - LogicalResult parseEntry(EncodingReader &reader, bool hasCustomEncoding, - Type &result); + /// Parse an entry using the given reader that was encoded using the textual + /// assembly format. + template <typename T> + LogicalResult parseAsmEntry(T &result, EncodingReader &reader, + StringRef entryType); + + /// Parse an entry using the given reader that was encoded using a custom + /// bytecode format.
+ template <typename T> + LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader, + StringRef entryType); /// The set of attribute and type entries. SmallVector<AttrEntry> attributes; @@ -506,8 +538,15 @@ // Parse the entry. EncodingReader reader(entry.data, fileLoc); - if (failed(parseEntry(reader, entry.hasCustomEncoding, entry.entry))) + + // Parse based on how the entry was encoded. + if (entry.hasCustomEncoding) { + if (failed(parseCustomEntry(entry, reader, entryType))) + return T(); + } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { return T(); + } + if (!reader.empty()) { (void)reader.emitError("unexpected trailing bytes after " + entryType + " entry"); @@ -516,51 +555,37 @@ return entry.entry; } -LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader, - bool hasCustomEncoding, - Attribute &result) { - // Handle the fallback case, where the attribute was encoded using its - // assembly format. - if (!hasCustomEncoding) { - StringRef attrStr; - if (failed(reader.parseNullTerminatedString(attrStr))) - return failure(); - - size_t numRead = 0; - if (!(result = parseAttribute(attrStr, fileLoc->getContext(), numRead))) - return failure(); - if (numRead != attrStr.size()) { - return reader.emitError( - "trailing characters found after Attribute assembly format: ", - attrStr.drop_front(numRead)); - } - return success(); - } - - return reader.emitError("unexpected Attribute encoding"); -} +template <typename T> +LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, + StringRef entryType) { + StringRef asmStr; + if (failed(reader.parseNullTerminatedString(asmStr))) + return failure(); -LogicalResult AttrTypeReader::parseEntry(EncodingReader &reader, - bool hasCustomEncoding, Type &result) { - // Handle the fallback case, where the type was encoded using its - // assembly format.
- if (!hasCustomEncoding) { - StringRef typeStr; - if (failed(reader.parseNullTerminatedString(typeStr))) - return failure(); + // Invoke the MLIR assembly parser to parse the entry text. + size_t numRead = 0; + MLIRContext *context = fileLoc->getContext(); + if constexpr (std::is_same_v<T, Type>) + result = ::parseType(asmStr, context, numRead); + else + result = ::parseAttribute(asmStr, context, numRead); + if (!result) + return failure(); - size_t numRead = 0; - if (!(result = parseType(typeStr, fileLoc->getContext(), numRead))) - return failure(); - if (numRead != typeStr.size()) { - return reader.emitError( - "trailing characters found after Type assembly format: " + - typeStr.drop_front(numRead)); - } - return success(); + // Ensure there weren't dangling characters after the entry. + if (numRead != asmStr.size()) { + return reader.emitError("trailing characters found after ", entryType, + " assembly format: ", asmStr.drop_front(numRead)); } + return success(); +} - return reader.emitError("unexpected Type encoding"); +template <typename T> +LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, + EncodingReader &reader, + StringRef entryType) { + // FIXME: Add support for reading custom attribute/type encodings. + // Use `entryType` so Type entries report "Type" rather than "Attribute". + return reader.emitError("unexpected ", entryType, " encoding"); } //===----------------------------------------------------------------------===// @@ -600,20 +625,13 @@ //===--------------------------------------------------------------------===// // Attribute/Type Section - /// Parse an attribute or type using the given reader. Returns nullptr in the - /// case of failure. - Attribute parseAttribute(EncodingReader &reader); - Type parseType(EncodingReader &reader); - + /// Parse an attribute or type using the given reader.
template <typename T> - T parseAttribute(EncodingReader &reader) { - if (Attribute attr = parseAttribute(reader)) { - if (auto derivedAttr = attr.dyn_cast<T>()) - return derivedAttr; - (void)reader.emitError("expected attribute of type: ", - llvm::getTypeName<T>(), ", but got: ", attr); - } - return T(); + LogicalResult parseAttribute(EncodingReader &reader, T &result) { + return attrTypeReader.parseAttribute(reader, result); + } + LogicalResult parseType(EncodingReader &reader, Type &result) { + return attrTypeReader.parseType(reader, result); } //===--------------------------------------------------------------------===// @@ -863,23 +881,6 @@ return *opName->opName; } -//===----------------------------------------------------------------------===// -// Attribute/Type Section - -Attribute BytecodeReader::parseAttribute(EncodingReader &reader) { - uint64_t attrIdx; - if (failed(reader.parseVarInt(attrIdx))) - return Attribute(); - return attrTypeReader.resolveAttribute(attrIdx); -} - -Type BytecodeReader::parseType(EncodingReader &reader) { - uint64_t typeIdx; - if (failed(reader.parseVarInt(typeIdx))) - return Type(); - return attrTypeReader.resolveType(typeIdx); -} - //===----------------------------------------------------------------------===// // IR Section @@ -996,8 +997,8 @@ return failure(); /// Parse the location. - LocationAttr opLoc = parseAttribute<LocationAttr>(reader); - if (!opLoc) + LocationAttr opLoc; + if (failed(parseAttribute(reader, opLoc))) return failure(); // With the location and name resolved, we can start building the operation @@ -1006,8 +1007,8 @@ // Parse the attributes of the operation.
if (opMask & bytecode::OpEncodingMask::kHasAttrs) { - DictionaryAttr dictAttr = parseAttribute<DictionaryAttr>(reader); - if (!dictAttr) + DictionaryAttr dictAttr; + if (failed(parseAttribute(reader, dictAttr))) return failure(); opState.attributes = dictAttr; } @@ -1019,7 +1020,7 @@ return failure(); opState.types.resize(numResults); for (int i = 0, e = numResults; i < e; ++i) - if (!(opState.types[i] = parseType(reader))) + if (failed(parseType(reader, opState.types[i]))) return failure(); } @@ -1130,11 +1131,10 @@ argLocs.reserve(numArgs); while (numArgs--) { - Type argType = parseType(reader); - if (!argType) - return failure(); - LocationAttr argLoc = parseAttribute<LocationAttr>(reader); - if (!argLoc) + Type argType; + LocationAttr argLoc; + if (failed(parseType(reader, argType)) || + failed(parseAttribute(reader, argLoc))) return failure(); argTypes.push_back(argType);